# The wildcard import pulls in the pieces of LMDeploy's built-in Gradio demo used
# below: InterFace, AsyncEngine, gr, CSS, THEME and the chat_stream_local,
# cancel_local_func and reset_local_func event handlers.
from lmdeploy.serve.gradio.turbomind_coupled import *

from lmdeploy.messages import TurbomindEngineConfig

# TurboMind engine settings: AWQ 4-bit weights, one request at a time, and a
# k/v cache capped at 5% of GPU memory.
backend_config = TurbomindEngineConfig(max_batch_size=1,
                                       cache_max_entry_count=0.05,
                                       model_format='awq')
model_path = 'internlm/internlm2-chat-20b-4bits'

InterFace.async_engine = AsyncEngine(
    model_path=model_path,
    backend='turbomind',
    backend_config=backend_config,
    tp=1)

with gr.Blocks(css=CSS, theme=THEME) as demo:
    state_chatbot = gr.State([])
    state_session_id = gr.State(0)

    with gr.Column(elem_id='container'):
        gr.Markdown('## LMDeploy Playground')

        chatbot = gr.Chatbot(
            elem_id='chatbot',
            label=InterFace.async_engine.engine.model_name)
        instruction_txtbox = gr.Textbox(
            placeholder='Please input the instruction',
            label='Instruction')
        with gr.Row():
            cancel_btn = gr.Button(value='Cancel', interactive=False)
            reset_btn = gr.Button(value='Reset')
        with gr.Row():
            request_output_len = gr.Slider(1,
                                           2048,
                                           value=512,
                                           step=1,
                                           label='Maximum new tokens')
            top_p = gr.Slider(0.01, 1, value=0.8, step=0.01, label='Top_p')
            temperature = gr.Slider(0.01,
                                    1.5,
                                    value=0.7,
                                    step=0.01,
                                    label='Temperature')

    # Submitting the instruction streams the reply into the chatbot ...
    send_event = instruction_txtbox.submit(chat_stream_local, [
        instruction_txtbox, state_chatbot, cancel_btn, reset_btn,
        state_session_id, top_p, temperature, request_output_len
    ], [state_chatbot, chatbot, cancel_btn, reset_btn])
    # ... and a second handler clears the input box.
    instruction_txtbox.submit(
        lambda: gr.Textbox.update(value=''),
        [],
        [instruction_txtbox],
    )
    cancel_btn.click(
        cancel_local_func,
        [state_chatbot, cancel_btn, reset_btn, state_session_id],
        [state_chatbot, cancel_btn, reset_btn],
        cancels=[send_event])

    reset_btn.click(reset_local_func,
                    [instruction_txtbox, state_chatbot, state_session_id],
                    [state_chatbot, chatbot, instruction_txtbox],
                    cancels=[send_event])

    def init():
        """Hand every page load a fresh, unique session id."""
        with InterFace.lock:
            InterFace.global_session_id += 1
            new_session_id = InterFace.global_session_id
        return new_session_id

    demo.load(init, inputs=None, outputs=[state_session_id])

demo.queue(concurrency_count=InterFace.async_engine.instance_num,
           max_size=100).launch()
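# Usage note (assumption: the snippet above is saved as a standalone script,
# e.g. app.py; the filename is illustrative):
#   python app.py
# By default Gradio's launch() serves the UI at http://127.0.0.1:7860; pass
# server_name='0.0.0.0' and/or server_port=<port> to launch() to make the
# demo reachable from other machines.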
|