# AI-chatbot/app.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# 1) load model & tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    load_in_4bit=True,  # 4-bit quantization via bitsandbytes; comment out for full precision
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)
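
# Note: passing load_in_4bit directly is deprecated in newer transformers releases.
# A sketch of the explicit form, assuming bitsandbytes is installed:
#
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_ID,
#       quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#       torch_dtype=torch.float16,
#       device_map="auto",
#       trust_remote_code=True,
#   )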
# 2) define inference function
def generate(message, history):
    """
    message: the user's latest message (str)
    history: prior turns as [{"role": ..., "content": ...}, ...],
             the format gr.ChatInterface(type="messages") passes in
    returns: the model's reply (str); ChatInterface appends it to the history
    """
    # rebuild a single prompt string from the conversation so far
    prompt = ""
    for turn in history:
        speaker = "User" if turn["role"] == "user" else "Assistant"
        prompt += f"{speaker}: {turn['content']}\n"
    prompt += f"User: {message}\nAssistant:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.7,
    )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # keep only the text after the last "Assistant:" marker, i.e. the new reply
    reply = text.split("Assistant:")[-1].strip()
    return reply
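
# TinyLlama-1.1B-Chat ships its own chat template, so an alternative to the
# hand-built "User:/Assistant:" prompt is to let the tokenizer format the turns.
# A sketch (inside generate, replacing the prompt-building loop above):
#
#   chat = history + [{"role": "user", "content": message}]
#   prompt = tokenizer.apply_chat_template(
#       chat, tokenize=False, add_generation_prompt=True
#   )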
# 3) build Gradio ChatInterface
demo = gr.ChatInterface(
    fn=generate,
    title="TinyLlama-1.1B Chat API",
    description="Chat with TinyLlama-1.1B and call via /api/predict",
    type="messages",
)
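
# Example of calling the deployed app from Python with gradio_client
# (a sketch; the endpoint name depends on the Gradio version, and newer
# ChatInterface builds expose "/chat" rather than "/api/predict"):
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860")
#   print(client.predict("Hello, TinyLlama!", api_name="/chat"))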
# 4) launch
if __name__ == "__main__":
    demo.launch()