import os
from fastapi import FastAPI, HTTPException
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from pydantic import BaseModel
import traceback
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationChain
from langchain.prompts import PromptTemplate
from starlette.responses import StreamingResponse
import asyncio
import json
from langchain_community.llms import HuggingFacePipeline 
import uvicorn
from huggingface_hub import login

app = FastAPI()

# Get the Hugging Face API token from environment variables (BEST PRACTICE)
HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

if HUGGINGFACEHUB_API_TOKEN is None:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")

# --- Explicitly log in to Hugging Face Hub ---
try:
    login(token=HUGGINGFACEHUB_API_TOKEN)
    print("Successfully logged into Hugging Face Hub.")
except Exception as e:
    print(f"Failed to log into Hugging Face Hub: {e}")
    # The app will likely fail to load the model if login fails, so this print is for debugging.

# --- Use Mistral 7B Instruct v0.3 model ---
model_id = "mistralai/Mistral-7B-Instruct-v0.3"

tokenizer = AutoTokenizer.from_pretrained(model_id, token=HUGGINGFACEHUB_API_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto", # 'auto' handles device placement, including offloading
    torch_dtype=torch.bfloat16, 
    trust_remote_code=True,
    token=HUGGINGFACEHUB_API_TOKEN 
)

# --- REMOVED: model.to(device) ---
# When device_map="auto" is used, accelerate handles device placement.
# Manually moving the model can cause conflicts and RuntimeErrors.
# if torch.backends.mps.is_available():
#     device = "mps"
# elif torch.cuda.is_available():
#     device = "cuda"
# else:
#     device = "cpu"
# model.to(device) # This line is removed

# k=5 means it will keep the last 5 human-AI interaction pairs (10 messages total)
memory = ConversationBufferWindowMemory(k=5)

# Initialize Langchain HuggingFacePipeline
llm = HuggingFacePipeline(pipeline=pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,  
    return_full_text=True, 
    temperature=0.2,      
    do_sample=True,        
))

# --- UPDATED PROMPT TEMPLATE ---
template = """<|im_start|>system
You are a concise and direct AI assistant named Siddhi.
You strictly avoid asking any follow-up questions.
You do not generate any additional conversational turns (e.g., "Human: ...").
If asked for your name, you respond with "I am Siddhi."
If you do not know the answer to a question, you truthfully state that you do not know.
<|im_end|>
<|im_start|>user
{history}
{input}<|im_end|>
<|im_start|>assistant
"""

PROMPT = PromptTemplate(input_variables=["history", "input"], template=template)

# Initialize Langchain ConversationChain
conversation = ConversationChain(llm=llm, memory=memory, prompt=PROMPT, verbose=True)

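# Request/response schemas; ChatResponse is declared but not referenced by the streaming endpoint below.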
class QuestionRequest(BaseModel):
    question: str

class ChatResponse(BaseModel):
    response: str

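# POST /api/generate: streams the model's reply as newline-delimited JSON objects
# ({"content": ...} chunks, then a final {"status": "completed"} or {"error": ...}).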
@app.post("/api/generate")
async def generate_text(request: QuestionRequest):
    async def generate_stream():
        started_streaming_ai_response = False
        
        try:
            response_stream = conversation.stream({"input": request.question})

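            # Because the pipeline was created with return_full_text=True, chunks include the
            # prompt text as well; emit nothing until the assistant marker appears, then cut
            # the stream at the first stop sequence.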
            stop_sequences_to_check = ["Human:", "AI:", "\nHuman:", "\nAI:", "<|im_end|>"]
            assistant_start_marker = "<|im_start|>assistant\n" 

            for chunk in response_stream:
                full_text_chunk = ""
                if 'response' in chunk:
                    full_text_chunk = chunk['response']
                else:
                    full_text_chunk = str(chunk) 

                if not started_streaming_ai_response:
                    if assistant_start_marker in full_text_chunk:
                        token_content = full_text_chunk.split(assistant_start_marker, 1)[1]
                        started_streaming_ai_response = True
                    else:
                        token_content = ""
                else:
                    token_content = full_text_chunk

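                # Truncate at the first stop sequence and finish the stream early.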
                for stop_seq in stop_sequences_to_check:
                    if stop_seq in token_content:
                        token_content = token_content.split(stop_seq, 1)[0] 
                        if token_content: 
                            yield json.dumps({"content": token_content}) + "\n"
                            await asyncio.sleep(0.01)
                        yield json.dumps({"status": "completed"}) + "\n" 
                        return 

                if token_content:
                    yield json.dumps({"content": token_content}) + "\n"
                    await asyncio.sleep(0.01)

            yield json.dumps({"status": "completed"}) + "\n"

        except Exception as e:
            print("Error during streaming generation:")
            traceback.print_exc()
            yield json.dumps({"error": str(e)}) + "\n"

    return StreamingResponse(generate_stream(), media_type="application/json")

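# Launch the API; PORT defaults to 7860 (the Hugging Face Spaces convention).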
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
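
# Example client (a minimal sketch, assuming the server is reachable at localhost:7860;
# the question text is just an illustration):
#
#   import requests, json
#   with requests.post("http://localhost:7860/api/generate",
#                      json={"question": "What is your name?"}, stream=True) as r:
#       for line in r.iter_lines():
#           if line:
#               print(json.loads(line))
#
# Each streamed line is a JSON object carrying "content", "status", or "error".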