# chatbot-demo / app.py
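"""Minimal streaming chatbot demo.

Loads IlyaGusev/saiga_llama3_8b and serves it through a Gradio ChatInterface,
streaming tokens to the UI as they are generated.
"""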
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
tokenizer = AutoTokenizer.from_pretrained("IlyaGusev/saiga_llama3_8b")
model = AutoModelForCausalLM.from_pretrained("IlyaGusev/saiga_llama3_8b", torch_dtype=torch.bfloat16)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")  # use a GPU when one is available
def predict(message, history):
    # history is a list of {"role", "content"} dicts (ChatInterface type="messages")
    messages = history + [{"role": "user", "content": message}]
    model_inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,  # open an assistant turn instead of appending an empty one
        return_tensors="pt",
    ).to(model.device)
    # Streams decoded text chunks as they are generated; skip_prompt drops the echoed input
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,    # sample rather than decode greedily
        top_p=0.95,        # nucleus sampling threshold
        top_k=1000,
        temperature=1.0,
        num_beams=1,
    )
    # Run generation in a background thread so this generator can consume the streamer
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    # Yield the growing response so the UI updates token by token
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token  # special tokens are already stripped by the streamer
        yield partial_message
    t.join()
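# type="messages" makes Gradio pass history as role/content dicts;
# share=True additionally exposes the app via a temporary public gradio.live URL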
gr.ChatInterface(predict, type="messages").launch(share=True)