import os
from fastapi import FastAPI, HTTPException
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from pydantic import BaseModel
import traceback
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationChain
from langchain.prompts import PromptTemplate
from starlette.responses import StreamingResponse
import asyncio
import json
from langchain_community.llms import HuggingFacePipeline
import uvicorn
from huggingface_hub import login
app = FastAPI()
# Get the Hugging Face API token from environment variables (BEST PRACTICE)
HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if HUGGINGFACEHUB_API_TOKEN is None:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")
# --- Explicitly log in to Hugging Face Hub ---
try:
    login(token=HUGGINGFACEHUB_API_TOKEN)
    print("Successfully logged into Hugging Face Hub.")
except Exception as e:
    print(f"Failed to log into Hugging Face Hub: {e}")
    # The app will likely fail to load the model if login fails, so this print is for debugging.
# --- Use Mistral 7B Instruct v0.3 model ---
model_id = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HUGGINGFACEHUB_API_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # "auto" handles device placement, including offloading
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    token=HUGGINGFACEHUB_API_TOKEN,
)
# --- REMOVED: model.to(device) ---
# When device_map="auto" is used, accelerate handles device placement.
# Manually moving the model can cause conflicts and RuntimeErrors.
# if torch.backends.mps.is_available():
# device = "mps"
# elif torch.cuda.is_available():
# device = "cuda"
# else:
# device = "cpu"
# model.to(device) # This line is removed
# k=5 means it will keep the last 5 human-AI interaction pairs (10 messages total)
memory = ConversationBufferWindowMemory(k=5)
# Initialize Langchain HuggingFacePipeline
llm = HuggingFacePipeline(pipeline=pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    return_full_text=True,
    temperature=0.2,
    do_sample=True,
))
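# Illustrative sanity check (not part of the original app; assumes LangChain's
# Runnable interface, which exposes .invoke() on HuggingFacePipeline). Uncomment
# to verify the model loads and generates before wiring it into the chain:
#   print(llm.invoke("Say hello in one short sentence."))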
# --- UPDATED PROMPT TEMPLATE ---
template = """<|im_start|>system
You are a concise and direct AI assistant named Siddhi.
You strictly avoid asking any follow-up questions.
You do not generate any additional conversational turns (e.g., "Human: ...").
If asked for your name, you respond with "I am Siddhi."
If you do not know the answer to a question, you truthfully state that you do not know.
<|im_end|>
<|im_start|>user
{history}
{input}<|im_end|>
<|im_start|>assistant
"""
PROMPT = PromptTemplate(input_variables=["history", "input"], template=template)
# Initialize Langchain ConversationChain
conversation = ConversationChain(llm=llm, memory=memory, prompt=PROMPT, verbose=True)
class QuestionRequest(BaseModel):
    question: str

class ChatResponse(BaseModel):
    response: str
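# For reference, the streaming endpoint below consumes and emits newline-delimited JSON:
#   request body:  {"question": "..."}
#   stream lines:  {"content": "..."} (one or more), then {"status": "completed"}
#   on failure:    {"error": "..."}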
@app.post("/api/generate")
async def generate_text(request: QuestionRequest):
async def generate_stream():
started_streaming_ai_response = False
try:
response_stream = conversation.stream({"input": request.question})
stop_sequences_to_check = ["Human:", "AI:", "\nHuman:", "\nAI:", "<|im_end|>"]
assistant_start_marker = "<|im_start|>assistant\n"
for chunk in response_stream:
full_text_chunk = ""
if 'response' in chunk:
full_text_chunk = chunk['response']
else:
full_text_chunk = str(chunk)
if not started_streaming_ai_response:
if assistant_start_marker in full_text_chunk:
token_content = full_text_chunk.split(assistant_start_marker, 1)[1]
started_streaming_ai_response = True
else:
token_content = ""
else:
token_content = full_text_chunk
for stop_seq in stop_sequences_to_check:
if stop_seq in token_content:
token_content = token_content.split(stop_seq, 1)[0]
if token_content:
yield json.dumps({"content": token_content}) + "\n"
await asyncio.sleep(0.01)
yield json.dumps({"status": "completed"}) + "\n"
return
if token_content:
yield json.dumps({"content": token_content}) + "\n"
await asyncio.sleep(0.01)
yield json.dumps({"status": "completed"}) + "\n"
except Exception as e:
print("Error during streaming generation:")
traceback.print_exc()
yield json.dumps({"error": str(e)}) + "\n"
return StreamingResponse(generate_stream(), media_type="application/json")
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
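# Example client (illustrative sketch, assuming the server runs on localhost:7860):
#
#   import json, requests
#   with requests.post(
#       "http://localhost:7860/api/generate",
#       json={"question": "Who are you?"},
#       stream=True,
#   ) as resp:
#       for line in resp.iter_lines():
#           if line:
#               print(json.loads(line))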