dharmendra committed on
Commit d00f229 · 1 Parent(s): 0b5b6d7

using Llama 3.1 8B Instruct

Files changed (1)
  1. app.py +93 -62
app.py CHANGED
@@ -4,72 +4,80 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from pydantic import BaseModel
import traceback
- #from langchain.memory import ConversationBufferMemory
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationChain
- from starlette.responses import StreamingResponse  # <-- NEW IMPORT
- import asyncio
- from langchain_community.llms import HuggingFacePipeline
- import json
from langchain.prompts import PromptTemplate
+ from starlette.responses import StreamingResponse
+ import asyncio
+ import json
+ from langchain_community.llms import HuggingFacePipeline
+ import uvicorn

app = FastAPI()
+
# Get the Hugging Face API token from environment variables (BEST PRACTICE)
HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

if HUGGINGFACEHUB_API_TOKEN is None:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")

- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
+ # --- UPDATED: Use Llama 3.1 8B Instruct model ---
+ model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
-     "Qwen/Qwen2.5-1.5B-Instruct",
-     device_map="auto",
-     torch_dtype=torch.bfloat16,
-     trust_remote_code=True,
-     token=HUGGINGFACEHUB_API_TOKEN)
- #print(f"Tokenizer attributes: {dir(tokenizer)}")
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,  # bfloat16 is generally a good fit for Llama; try torch.float16 if there are issues
+     trust_remote_code=True,
+     token=HUGGINGFACEHUB_API_TOKEN
+ )

if torch.backends.mps.is_available():
-     device = "mps"
+     device = "mps"
elif torch.cuda.is_available():
-     device= "cuda"
- else :
-     device = "cpu"
+     device = "cuda"
+ else:
+     device = "cpu"

model.to(device)
+
# k=5 means it will keep the last 5 human-AI interaction pairs (10 messages total)
memory = ConversationBufferWindowMemory(k=5)

# Initialize Langchain HuggingFacePipeline
llm = HuggingFacePipeline(pipeline=pipeline(
-     "text-generation",
-     model=model,
+     "text-generation",
+     model=model,
    tokenizer=tokenizer,
-     max_new_tokens=512,  # Adjust as needed for desired response length
-     return_full_text=False,  # Crucial for getting only the AI's response, especially when the answer is short
+     max_new_tokens=512,  # Allows for longer, detailed answers when required
+     # --- IMPORTANT FIX: Set return_full_text to True and handle slicing manually ---
+     return_full_text=True,
    temperature=0.2,  # Controls randomness (0.0 for deterministic, 1.0 for very creative)
-     do_sample=True  # Enable sampling for more varied outputs
-     #stop_sequence=["Human:", "AI:", "\nHuman:", "\nAI:"]  # to prevent generating unwanted conversation turns
- ))
+     do_sample=True,  # Enable sampling for more varied outputs
+     # --- IMPORTANT FIX: REMOVED stop_sequence from pipeline initialization ---
+     # This prevents the TypeError; stopping is handled manually below.
+ ))

- template = """The following is a concise and direct conversation between a human and an AI.
- The AI should provide a direct answer to the human's question and strictly avoid asking any follow-up questions.
- The AI should not generate any additional conversational turns (e.g., "Human: ...").
- If the AI is asked for its name, it should respond with "I am Siddhi."
- If the AI does not know the answer to a question, it should truthfully state that it does not know.
-
- Current conversation:
+ # --- UPDATED PROMPT TEMPLATE ---
+ # Using the recommended chat format for Llama models and explicit instructions.
+ template = """<|im_start|>system
+ You are a concise and direct AI assistant named Siddhi.
+ You strictly avoid asking any follow-up questions.
+ You do not generate any additional conversational turns (e.g., "Human: ...").
+ If asked for your name, you respond with "I am Siddhi."
+ If you do not know the answer to a question, you truthfully state that you do not know.
+ <|im_end|>
+ <|im_start|>user
{history}
- Human: {input}
- AI:"""
+ {input}<|im_end|>
+ <|im_start|>assistant
+ """

PROMPT = PromptTemplate(input_variables=["history", "input"], template=template)

# Initialize Langchain ConversationChain
- # verbose=True for debugging LangChain's pro
- conversation = ConversationChain(llm=llm, memory=memory, prompt = PROMPT, verbose=True)
+ conversation = ConversationChain(llm=llm, memory=memory, prompt=PROMPT, verbose=True)

class QuestionRequest(BaseModel):
    question: str
@@ -80,42 +88,65 @@ class ChatResponse(BaseModel):
@app.post("/api/generate")
async def generate_text(request: QuestionRequest):
    async def generate_stream():
+         # Flag to indicate when we've started streaming the AI's actual response
+         started_streaming_ai_response = False
+
        try:
-             # Use LangChain's .stream() method for token-by-token generation
-             # This will yield chunks of the response as they are produced
            response_stream = conversation.stream({"input": request.question})

+             # Define stop sequences for manual checking
+             stop_sequences_to_check = ["Human:", "AI:", "\nHuman:", "\nAI:", "<|im_end|>"]
+             assistant_start_marker = "<|im_start|>assistant\n"  # Marker from the prompt template
+
            for chunk in response_stream:
-                 token_content = ""
-                 # Each chunk is typically a dictionary with a 'content' key
-                 # We want to send just the new token/text back.
-                 # Ensure the chunk is stringified and followed by a newline for client parsing.
-                 # For more robust streaming, consider Server-Sent Events (SSE) format:
-                 # yield f"data: {json.dumps({'token': chunk.content})}\n\n"
-                 # For simplicity, we'll just yield the content directly for now.
+                 full_text_chunk = ""
                if 'response' in chunk:
-                     token_content = chunk['response']
+                     full_text_chunk = chunk['response']
                else:
-                     token_content = str(chunk)
-
-                 yield json.dumps({"content": token_content}) + "\n"
-
-                 await asyncio.sleep(0.01)  # Small delay to allow client to process chunks
-             # optionally send final end msg
-             yield json.dumps({"status": "completed"}) + "\n"
+                     full_text_chunk = str(chunk)  # Fallback for unexpected chunk format
+
+                 # Logic to extract only the AI's response from the full text chunk
+                 if not started_streaming_ai_response:
+                     if assistant_start_marker in full_text_chunk:
+                         # Split the chunk at the assistant's start marker and take the part after it
+                         token_content = full_text_chunk.split(assistant_start_marker, 1)[1]
+                         started_streaming_ai_response = True
+                     else:
+                         # The marker is not in the chunk yet, so this chunk is still part of the prompt
+                         # and we don't yield anything.
+                         token_content = ""
+                 else:
+                     # Once we've started, all subsequent chunks are the AI's response
+                     token_content = full_text_chunk
+
+                 # --- Manual stopping logic ---
+                 # If the generated content contains a stop sequence, truncate it and finish.
+                 for stop_seq in stop_sequences_to_check:
+                     if stop_seq in token_content:
+                         token_content = token_content.split(stop_seq, 1)[0]  # Truncate at the stop sequence
+                         if token_content:  # Yield any content before the stop sequence
+                             yield json.dumps({"content": token_content}) + "\n"
+                             await asyncio.sleep(0.01)
+                         yield json.dumps({"status": "completed"}) + "\n"  # Signal completion
+                         return  # Exit the generator function
+
+                 # Only yield if there's actual content to send after processing
+                 if token_content:
+                     yield json.dumps({"content": token_content}) + "\n"
+                     await asyncio.sleep(0.01)
+
+             # Send a final completion message if the stream finishes naturally
+             yield json.dumps({"status": "completed"}) + "\n"

        except Exception as e:
            print("Error during streaming generation:")
            traceback.print_exc()
-             # You might want to yield an error message to the client here
-             yield f"ERROR: {str(e)}\n"
+             # Yield the error message in JSON format
+             yield json.dumps({"error": str(e)}) + "\n"

-     # Return a StreamingResponse, which will send chunks as they are yielded by generate_stream()
-     # media_type can be "text/event-stream" for SSE, or "text/plain" for simple newline-delimited text.
-     # For simplicity, we'll start with "text/plain" for easier initial client parsing.
+     # Return a StreamingResponse with application/json media type
    return StreamingResponse(generate_stream(), media_type="application/json")
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
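Note on the prompt format: the new template marks turns with ChatML-style <|im_start|>/<|im_end|> tokens and a matching assistant_start_marker. Llama 3.1 Instruct tokenizers also ship their own chat template, so an alternative (not part of this commit) is to let the tokenizer render the turns. A minimal sketch, assuming the same model_id and a hypothetical user question:

from transformers import AutoTokenizer

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# The same system instructions as in app.py, expressed as chat messages.
messages = [
    {"role": "system", "content": "You are a concise and direct AI assistant named Siddhi."},
    {"role": "user", "content": "What is the capital of France?"},  # hypothetical question
]

# Render the prompt string the model expects; add_generation_prompt=True appends
# the assistant header so generation continues with the answer.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)

The rendered string could then be fed to the same text-generation pipeline in place of the hand-written template string.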
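Because /api/generate now streams newline-delimited JSON (one {"content": ...} object per chunk, then {"status": "completed"} or {"error": ...}), a client should read the response line by line. A minimal consumption sketch, assuming the app is running locally on port 7860 as in the __main__ block:

import json
import requests  # assumes the requests library is installed

url = "http://localhost:7860/api/generate"  # hypothetical local URL
payload = {"question": "What is the capital of France?"}  # hypothetical question

with requests.post(url, json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line:
            continue  # skip empty lines between JSON objects
        msg = json.loads(line)
        if "content" in msg:
            print(msg["content"], end="", flush=True)  # print chunks as they arrive
        elif msg.get("status") == "completed":
            break  # server signalled the end of the stream
        elif "error" in msg:
            raise RuntimeError(msg["error"])

An equivalent quick check from a terminal is curl -N -X POST -H "Content-Type: application/json" -d '{"question": "..."}' http://localhost:7860/api/generate, where -N disables curl's output buffering so chunks print as they stream.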