import os
from fastapi import FastAPI, HTTPException
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from pydantic import BaseModel
import traceback
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationChain
from langchain.prompts import PromptTemplate
from starlette.responses import StreamingResponse
import asyncio
import json
from langchain_community.llms import HuggingFacePipeline
import uvicorn

app = FastAPI()

# Get the Hugging Face API token from environment variables (BEST PRACTICE)
HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

if HUGGINGFACEHUB_API_TOKEN is None:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")

# --- UPDATED: Use Llama 3.1 8B Instruct model ---
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id, token=HUGGINGFACEHUB_API_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16, # bfloat16 generally works well for Llama; try torch.float16 if you hit issues
    trust_remote_code=True,
    token=HUGGINGFACEHUB_API_TOKEN
)

# Note: device_map="auto" above lets accelerate place the model, so we only detect
# the device here for logging; moving a dispatched model with .to() can fail.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Running on device: {device}")

# k=5 means it will keep the last 5 human-AI interaction pairs (10 messages total)
memory = ConversationBufferWindowMemory(k=5)

# Initialize Langchain HuggingFacePipeline
llm = HuggingFacePipeline(pipeline=pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,  # Allows for longer, detailed answers when required
    # --- IMPORTANT FIX: Set return_full_text to True and handle slicing manually ---
    return_full_text=True,
    temperature=0.2,      # Controls randomness (0.0 for deterministic, 1.0 for very creative)
    do_sample=True,        # Enable sampling for more varied outputs
    # --- IMPORTANT FIX: REMOVED stop_sequence from pipeline initialization ---
    # This prevents the TypeError and we handle stopping manually below.
))

# --- UPDATED PROMPT TEMPLATE ---
# Uses a ChatML-style template (<|im_start|>/<|im_end|> markers) with explicit
# instructions. Note that Llama 3.1's native chat format uses <|start_header_id|>
# headers and <|eot_id|>; the markers below are matched manually in the streaming
# handler, so keep the template and that logic in sync.
template = """<|im_start|>system
You are a concise and direct AI assistant named Siddhi.
You strictly avoid asking any follow-up questions.
You do not generate any additional conversational turns (e.g., "Human: ...").
If asked for your name, you respond with "I am Siddhi."
If you do not know the answer to a question, you truthfully state that you do not know.
<|im_end|>
<|im_start|>user
{history}
{input}<|im_end|>
<|im_start|>assistant
"""

PROMPT = PromptTemplate(input_variables=["history", "input"], template=template)
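# ConversationChain fills {history} from the window memory and {input} with the
# incoming question on each call.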

# Initialize Langchain ConversationChain
conversation = ConversationChain(llm=llm, memory=memory, prompt=PROMPT, verbose=True)
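# verbose=True logs the fully rendered prompt on each call, which helps when
# debugging the template and memory contents.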

class QuestionRequest(BaseModel):
    question: str

class ChatResponse(BaseModel):
    response: str

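# Streaming chat endpoint: runs the ConversationChain on the incoming question and
# streams the assistant's reply as newline-delimited JSON objects of the form
# {"content": "..."}, ending with {"status": "completed"} (or {"error": "..."}).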
@app.post("/api/generate")
async def generate_text(request: QuestionRequest):
    async def generate_stream():
        # Flag to indicate when we've started streaming the AI's actual response
        started_streaming_ai_response = False
        
        try:
            response_stream = conversation.stream({"input": request.question})

            # Define stop sequences for manual checking
            stop_sequences_to_check = ["Human:", "AI:", "\nHuman:", "\nAI:", "<|im_end|>"]
            assistant_start_marker = "<|im_start|>assistant\n" # Marker from the prompt template

            for chunk in response_stream:
                full_text_chunk = ""
                if 'response' in chunk:
                    full_text_chunk = chunk['response']
                else:
                    full_text_chunk = str(chunk) # Fallback for unexpected chunk format

                # Logic to extract only the AI's response from the full text chunk
                if not started_streaming_ai_response:
                    if assistant_start_marker in full_text_chunk:
                        # Split the chunk at the assistant's start marker and take the part after it
                        token_content = full_text_chunk.split(assistant_start_marker, 1)[1]
                        started_streaming_ai_response = True
                    else:
                        # If the marker is not yet in the chunk, this chunk is still part of the prompt.
                        # We don't yield anything yet.
                        token_content = ""
                else:
                    # Once we've started, all subsequent chunks are AI's response
                    token_content = full_text_chunk

                # --- Manual stopping logic ---
                # Check if the generated content contains a stop sequence.
                # If it does, truncate the content and break the loop.
                for stop_seq in stop_sequences_to_check:
                    if stop_seq in token_content:
                        token_content = token_content.split(stop_seq, 1)[0] # Truncate at the stop sequence
                        if token_content: # Yield any content before stop sequence
                            yield json.dumps({"content": token_content}) + "\n"
                            await asyncio.sleep(0.01)
                        yield json.dumps({"status": "completed"}) + "\n" # Signal completion
                        return # Exit the generator function

                # Only yield if there's actual content to send after processing
                if token_content:
                    yield json.dumps({"content": token_content}) + "\n"
                    await asyncio.sleep(0.01)

            # Send a final completion message if the stream finishes naturally
            yield json.dumps({"status": "completed"}) + "\n"

        except Exception as e:
            print("Error during streaming generation:")
            traceback.print_exc()
            # Yield error message in JSON format
            yield json.dumps({"error": str(e)}) + "\n"

    # Return a StreamingResponse with application/json media type
    return StreamingResponse(generate_stream(), media_type="application/json")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
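
# Example client (a minimal sketch, assuming the server is running locally on the
# default port 7860 configured above):
#
#   import json, requests
#   with requests.post(
#       "http://localhost:7860/api/generate",
#       json={"question": "What is your name?"},
#       stream=True,
#   ) as resp:
#       for line in resp.iter_lines():
#           if line:
#               print(json.loads(line))
#
# Each line is a JSON object such as {"content": "..."}, followed by a final
# {"status": "completed"} (or {"error": "..."} if generation fails).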