# kgrammar-2 / app.py
# davidkim205's picture
# Update app.py
# 1a2430b verified
# raw
# history blame
# 1.91 kB
import gradio as gr
from huggingface_hub import InferenceClient
import re
import time
def create_inference_client(model_name):
    """Build a Hugging Face ``InferenceClient`` for the given model id.

    A fresh client is constructed on every call; ``respond`` calls this once
    per request with the model id picked in the "Model Selection" dropdown.
    """
    client = InferenceClient(model=model_name)
    return client
def respond(
    message,
    system_message,
    model,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion for a single user message.

    Sends a two-message conversation (system prompt, then the user message)
    to ``model`` via ``create_inference_client`` with streaming enabled, and
    yields the accumulated response text after each streamed chunk so the
    Gradio output box updates incrementally.

    Args:
        message: The user's input text.
        system_message: System prompt placed before the user message.
        model: Model id passed to ``create_inference_client``.
        max_tokens: Upper bound on generated tokens (coerced to ``int``).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The response text accumulated so far.
    """
    client = create_inference_client(model)
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": message},
    ]

    response = ""
    # Named `chunk` (not `message`) so the `message` parameter is not shadowed.
    for chunk in client.chat_completion(
        messages,
        # NOTE(review): Gradio slider values may arrive as floats; the API's
        # max_tokens is an integer count, so coerce defensively.
        max_tokens=int(max_tokens),
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # A streamed chunk may have no choices, or a delta whose content is
        # None (e.g. role-only or final chunks); treat those as empty text
        # instead of failing on `str += None` / an IndexError.
        choices = chunk.choices
        token = choices[0].delta.content if choices else None
        response += token or ""
        yield response

    # Log the complete response once, rather than the whole growing string
    # on every token.
    print(response)
# Gradio UI. The input components map positionally, in this order, onto
# respond(message, system_message, model, max_tokens, temperature, top_p).
# The default system prompt (Korean) asks the model to find unnatural phrasing
# in Korean text, wrapping each erroneous sentence plus explanation in
# <incorrect grammar> tags and the error count in <wrong count> tags.
demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="User Message"),
        gr.Textbox(
            label="System message",
            value="한국어 문맥상 부자연스러운 부분을 찾으시오. 오류 문장과 개수는 <incorrect grammar> </incorrect grammar> tag, 즉 <incorrect grammar> - 오류 문장과 설명 </incorrect grammar> 안에 담겨 있으며, <wrong count> </wrong count> tag, 즉 <wrong count> 오류 개수 </wrong count> 이다.",
        ),
        gr.Dropdown(
            label="Model Selection",
            choices=["davidkim205/kgrammar-2-1b", "davidkim205/kgrammar-2-3b"],
            value="davidkim205/kgrammar-2-1b",
        ),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=2048,
            step=1,
            value=1024,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=1.0,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
    ],
    outputs="textbox",
)

if __name__ == "__main__":
    # Start the web app only when this file is executed as a script.
    demo.launch()