import os
from fastapi import FastAPI, HTTPException
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from pydantic import BaseModel
import traceback
import json  # needed for json.dumps in the streaming generator below
#from langchain.memory import ConversationBufferMemory
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationChain
from starlette.responses import StreamingResponse  # used to stream chunks back to the client
import asyncio
from langchain_community.llms import HuggingFacePipeline
app = FastAPI()
# Get the Hugging Face API token from environment variables (BEST PRACTICE)
HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if HUGGINGFACEHUB_API_TOKEN is None:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-1.5B-Instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    token=HUGGINGFACEHUB_API_TOKEN,
)
#print(f"Tokenizer attributes: {dir(tokenizer)}")
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
# k=5 means it will keep the last 5 human-AI interaction pairs (10 messages total)
memory = ConversationBufferWindowMemory(k=5)
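# Illustrative sketch (not called anywhere in the app): shows how the k-window
# drops older turns. With k=2, only the two most recent exchanges survive in
# the buffer returned by load_memory_variables().
def _demo_window_memory() -> str:
    demo_memory = ConversationBufferWindowMemory(k=2)
    for i in range(4):
        demo_memory.save_context({"input": f"question {i}"}, {"output": f"answer {i}"})
    # Only "question 2"/"answer 2" and "question 3"/"answer 3" are kept.
    return demo_memory.load_memory_variables({})["history"]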
# Initialize Langchain HuggingFacePipeline
llm = HuggingFacePipeline(pipeline=pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,       # Adjust as needed for desired response length
    return_full_text=False,   # Crucial for getting only the AI's response, especially when the answer is short
    temperature=0.7,          # Controls randomness (0.0 for deterministic, 1.0 for very creative)
    do_sample=True            # Enable sampling for more varied outputs
))
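# Optional sanity-check sketch (illustrative, not used by the endpoint): calling
# the wrapped pipeline directly should return only the newly generated text,
# because return_full_text=False above.
def _smoke_test_llm(prompt: str = "Say hello in one short sentence.") -> str:
    return llm.invoke(prompt)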
# Initialize Langchain ConversationChain
# verbose=True prints LangChain's prompt construction, useful for debugging
conversation = ConversationChain(llm=llm, memory=memory, verbose=True)
class QuestionRequest(BaseModel):
    question: str

class ChatResponse(BaseModel):
    response: str
@app.post("/api/generate")
async def generate_text(request: QuestionRequest):
    async def generate_stream():
        try:
            # Use LangChain's .stream() method for token-by-token generation.
            # It yields chunks of the response as they are produced.
            response_stream = conversation.stream({"input": request.question})
            for chunk in response_stream:
                token_content = ""
                # Each chunk is typically a dictionary with a 'response' key;
                # send just the new text back, stringified and newline-terminated
                # so the client can parse one JSON object per line.
                # For more robust streaming, consider Server-Sent Events (SSE) format:
                #   yield f"data: {json.dumps({'token': token_content})}\n\n"
                # For simplicity, we yield newline-delimited JSON for now.
                if 'response' in chunk:
                    token_content = chunk['response']
                else:
                    token_content = str(chunk)
                yield json.dumps({"content": token_content}) + "\n"
                await asyncio.sleep(0.01)  # Small delay to allow the client to process chunks
            # Optionally send a final end-of-stream message
            yield json.dumps({"status": "completed"}) + "\n"
        except Exception as e:
            print("Error during streaming generation:")
            traceback.print_exc()
            # Surface the error to the client in the same newline-delimited JSON format
            yield json.dumps({"error": str(e)}) + "\n"

    # Return a StreamingResponse, which sends chunks as they are yielded by generate_stream().
    # media_type can be "text/event-stream" for SSE, or "text/plain" for simple newline-delimited text.
    # For simplicity, we start with "text/plain" for easier initial client parsing
    # (see the illustrative client sketch at the end of this file).
    return StreamingResponse(generate_stream(), media_type="text/plain")
# Below: the non-streaming path used before switching fully to LangChain, kept for reference.
# try:
#     # Retrieve history
#     history = memory.load_memory_variables({})['history']
#     # Create prompt with history and current question
#     prompt = f"History:\n{history}\nQuestion: {request.question}\nAnswer:"
#     inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device)
#     with torch.no_grad():
#         outputs = model.generate(
#             inputs=inputs['input_ids'],  # Pass the 'input_ids' tensor
#             attention_mask=inputs['attention_mask'],
#             max_length=300,
#             num_beams=5,
#             no_repeat_ngram_size=2,
#             temperature=0.7,
#             top_k=50,
#             top_p=0.95,
#             do_sample=True,
#             eos_token_id=tokenizer.convert_tokens_to_ids("<|endoftext|>"),
#             pad_token_id=tokenizer.convert_tokens_to_ids("<|endoftext|>")
#         )
#     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
#     return {"response": response}
# except Exception as e:
#     print("Error during generation:")
#     traceback.print_exc()
#     raise HTTPException(status_code=500, detail=str(e))
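# Illustrative client sketch (not part of the server; assumes the `requests`
# package is installed and that the API is reachable at the URL below). It shows
# one way to consume the newline-delimited JSON chunks yielded by generate_stream().
def example_stream_client(question: str, url: str = "http://localhost:8000/api/generate") -> None:
    import requests  # local import so the server itself does not require requests

    with requests.post(url, json={"question": question}, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines():
            if not line:
                continue
            data = json.loads(line)
            if "content" in data:
                print(data["content"], end="", flush=True)
            elif "error" in data:
                raise RuntimeError(data["error"])
            elif data.get("status") == "completed":
                break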