ShenghaoYummy commited on
Commit
4287d7f
·
verified ·
1 Parent(s): d8c1d71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -14
app.py CHANGED
@@ -24,15 +24,10 @@ def generate(message, history):
24
  """
25
  # rebuild a single prompt string from history + current message
26
  prompt = ""
27
-
28
- # Add conversation history
29
  for user_msg, assistant_msg in history:
30
  prompt += f"User: {user_msg}\n"
31
  prompt += f"Assistant: {assistant_msg}\n"
32
-
33
- # Add current user message
34
- prompt += f"User: {message}\n"
35
- prompt += "Assistant:"
36
 
37
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
38
  outputs = model.generate(
@@ -42,17 +37,18 @@ def generate(message, history):
42
  temperature=0.7,
43
  )
44
  text = tokenizer.decode(outputs[0], skip_special_tokens=True)
45
- # strip everything before the last "Assistant:"
46
  reply = text.split("Assistant:")[-1].strip()
47
-
48
  return reply
49
 
50
- # 3) build Gradio ChatInterface
51
- demo = gr.ChatInterface(
52
- fn=generate,
53
- title="TinyLlama-1.1B Chat API",
54
- description="Chat with TinyLlama-1.1B and call via /api/predict",
55
- type="messages",
 
 
 
56
  )
57
 
58
  # 4) launch
 
24
  """
25
  # rebuild a single prompt string from history + current message
26
  prompt = ""
 
 
27
  for user_msg, assistant_msg in history:
28
  prompt += f"User: {user_msg}\n"
29
  prompt += f"Assistant: {assistant_msg}\n"
30
+ prompt += f"User: {message}\nAssistant:"
 
 
 
31
 
32
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
33
  outputs = model.generate(
 
37
  temperature=0.7,
38
  )
39
  text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
40
  reply = text.split("Assistant:")[-1].strip()
 
41
  return reply
42
 
43
+ # 3) build Gradio ChatInterface *with open_routes enabled*
44
+ demo = (
45
+ gr.ChatInterface(
46
+ fn=generate,
47
+ title="TinyLlama-1.1B Chat API",
48
+ description="Chat with TinyLlama-1.1B and call via /api/predict",
49
+ type="messages",
50
+ )
51
+ .queue(open_routes=True) # ← allow direct HTTP POST to /api/predict
52
  )
53
 
54
  # 4) launch