WillHeld committed on
Commit a891312 · verified · 1 Parent(s): 403c2fe

Update app.py

Files changed (1)
  1. app.py +22 -13
app.py CHANGED
@@ -1,6 +1,7 @@
 import spaces
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import gradio as gr
+from threading import Thread
 
 checkpoint = "WillHeld/soft-raccoon"
 device = "cuda"
@@ -13,20 +14,28 @@ def predict(message, history, temperature, top_p):
     input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
     inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
 
-    streamer = gr.TelegramStreamer() # Use Gradio's built-in streamer
+    # Create a streamer
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
-    # Generate with streaming
-    model.generate(
-        inputs,
-        max_new_tokens=1024,
-        temperature=float(temperature),
-        top_p=float(top_p),
-        do_sample=True,
-        streamer=streamer
-    )
+    # Set up generation parameters
+    generation_kwargs = {
+        "input_ids": inputs,
+        "max_new_tokens": 1024,
+        "temperature": float(temperature),
+        "top_p": float(top_p),
+        "do_sample": True,
+        "streamer": streamer,
+    }
+
+    # Run generation in a separate thread
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
 
-    # The streamer will handle returning the tokens
-    return streamer
+    # Yield from the streamer as tokens are generated
+    partial_text = ""
+    for new_text in streamer:
+        partial_text += new_text
+        yield partial_text
 
 with gr.Blocks() as demo:
     chatbot = gr.ChatInterface(
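
For context, below is a minimal sketch of how the new streaming predict() could fit into the rest of app.py. Everything outside the diff is an assumption, not part of the commit: the model/tokenizer loading, the handling of the incoming message (the hunk starts at line 13, so any code that folds message into history is not visible), and the ChatInterface wiring, whose slider names, ranges, and defaults are hypothetical. The ZeroGPU integration implied by import spaces (presumably an @spaces.GPU decorator on predict) is also omitted here.

# Sketch only: reconstructs app.py around the diffed predict() body.
# Model loading, message handling, and the slider wiring are assumptions.
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

checkpoint = "WillHeld/soft-raccoon"
device = "cuda"

# Assumed loading code; the diff does not show how these are constructed.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).to(device)


def predict(message, history, temperature, top_p):
    # Assumption: fold the new user turn into the chat before templating.
    # The diff only shows apply_chat_template(history, ...), so the real app
    # may handle `message` somewhere above the visible hunk.
    history = history + [{"role": "user", "content": message}]
    input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # From the commit: stream tokens out of generate() via a background thread.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        "input_ids": inputs,
        "max_new_tokens": 1024,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "do_sample": True,
        "streamer": streamer,
    }
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield the accumulated text so Gradio re-renders the growing reply.
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text


with gr.Blocks() as demo:
    # Hypothetical wiring: slider labels, ranges, and defaults are guesses.
    chatbot = gr.ChatInterface(
        predict,
        type="messages",
        additional_inputs=[
            gr.Slider(0.0, 2.0, value=0.7, label="Temperature"),
            gr.Slider(0.0, 1.0, value=0.9, label="Top-p"),
        ],
    )

demo.launch()

The key behavioral change in the commit is that predict() becomes a generator. model.generate() blocks until generation finishes, so it is moved onto a background Thread while the main thread iterates the TextIteratorStreamer, yielding the growing partial transcript after each decoded chunk; gr.ChatInterface treats a generator fn as a streaming response, which is what makes tokens appear incrementally in the UI.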