dharmendra committed on
Commit
dca8b66
·
1 Parent(s): 81d2ef5

quantisation added

Browse files
Files changed (2) hide show
  1. app.py +74 -38
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import os
 
 
2
  from fastapi import FastAPI, HTTPException
3
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
4
  import torch
5
  from pydantic import BaseModel
6
  import traceback
@@ -28,44 +30,37 @@ try:
28
  print("Successfully logged into Hugging Face Hub.")
29
  except Exception as e:
30
  print(f"Failed to log into Hugging Face Hub: {e}")
31
- # The app will likely fail to load the model if login fails, so this print is for debugging.
32
 
33
- # --- Use Mistral 7B Instruct v0.3 model ---
34
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=HUGGINGFACEHUB_API_TOKEN)
37
  model = AutoModelForCausalLM.from_pretrained(
38
  model_id,
39
- device_map="auto", # 'auto' handles device placement, including offloading
40
- torch_dtype=torch.bfloat16,
 
41
  trust_remote_code=True,
42
  token=HUGGINGFACEHUB_API_TOKEN
43
  )
44
 
45
- # --- REMOVED: model.to(device) ---
46
- # When device_map="auto" is used, accelerate handles device placement.
47
- # Manually moving the model can cause conflicts and RuntimeErrors.
48
- # if torch.backends.mps.is_available():
49
- # device = "mps"
50
- # elif torch.cuda.is_available():
51
- # device = "cuda"
52
- # else:
53
- # device = "cpu"
54
- # model.to(device) # This line is removed
55
-
56
- # k=5 means it will keep the last 5 human-AI interaction pairs (10 messages total)
57
- memory = ConversationBufferWindowMemory(k=5)
58
-
59
- # Initialize Langchain HuggingFacePipeline
60
- llm = HuggingFacePipeline(pipeline=pipeline(
61
- "text-generation",
62
- model=model,
63
- tokenizer=tokenizer,
64
- max_new_tokens=512,
65
- return_full_text=True,
66
- temperature=0.2,
67
- do_sample=True,
68
- ))
69
 
70
  # --- UPDATED PROMPT TEMPLATE ---
71
  template = """<|im_start|>system
@@ -83,21 +78,61 @@ If you do not know the answer to a question, you truthfully state that it does n
83
 
84
  PROMPT = PromptTemplate(input_variables=["history", "input"], template=template)
85
 
86
- # Initialize Langchain ConversationChain
87
- conversation = ConversationChain(llm=llm, memory=memory, prompt=PROMPT, verbose=True)
88
-
89
  class QuestionRequest(BaseModel):
90
  question: str
 
91
 
92
  class ChatResponse(BaseModel):
93
  response: str
 
94
 
95
  @app.post("/api/generate")
96
  async def generate_text(request: QuestionRequest):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  async def generate_stream():
 
 
 
 
 
98
  started_streaming_ai_response = False
99
 
100
  try:
 
 
 
 
101
  response_stream = conversation.stream({"input": request.question})
102
 
103
  stop_sequences_to_check = ["Human:", "AI:", "\nHuman:", "\nAI:", "<|im_end|>"]
@@ -123,22 +158,23 @@ async def generate_text(request: QuestionRequest):
123
  if stop_seq in token_content:
124
  token_content = token_content.split(stop_seq, 1)[0]
125
  if token_content:
126
- yield json.dumps({"content": token_content}) + "\n"
127
  await asyncio.sleep(0.01)
128
- yield json.dumps({"status": "completed"}) + "\n"
129
  return
130
 
131
  if token_content:
132
- yield json.dumps({"content": token_content}) + "\n"
133
  await asyncio.sleep(0.01)
134
 
135
- yield json.dumps({"status": "completed"}) + "\n"
136
 
137
  except Exception as e:
138
- print("Error during streaming generation:")
139
  traceback.print_exc()
140
- yield json.dumps({"error": str(e)}) + "\n"
141
 
 
142
  return StreamingResponse(generate_stream(), media_type="application/json")
143
 
144
  if __name__ == "__main__":
 
1
  import os
2
+ import uuid
3
+ from typing import Dict, Optional
4
  from fastapi import FastAPI, HTTPException
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig # Import BitsAndBytesConfig
6
  import torch
7
  from pydantic import BaseModel
8
  import traceback
 
30
  print("Successfully logged into Hugging Face Hub.")
31
  except Exception as e:
32
  print(f"Failed to log into Hugging Face Hub: {e}")
 
33
 
34
+ # --- Initialize tokenizer and model globally (heavy to load, shared across sessions) ---
35
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
36
 
37
+ # --- NEW: Quantization configuration for 4-bit loading, optimized for T4 ---
38
+ # This configuration tells Hugging Face Transformers to load the model weights
39
+ # in 4-bit precision using the bitsandbytes library.
40
+ bnb_config = BitsAndBytesConfig(
41
+ load_in_4bit=True, # Enable 4-bit quantization
42
+ bnb_4bit_quant_type="nf4", # Specify the quantization type: "nf4" (NormalFloat 4-bit) is recommended for transformers
43
+ # --- IMPORTANT CHANGE: Use float16 for compute dtype for T4 compatibility ---
44
+ # T4 GPUs (Turing architecture) do not have native bfloat16 support.
45
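+ # Using float16 for computations is more efficient on this hardware; it is the 4-bit weight storage (not the compute dtype) that keeps the model small enough to avoid CPU offloading.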
46
+ bnb_4bit_compute_dtype=torch.float16,
47
+ bnb_4bit_use_double_quant=True, # Use double (nested) quantization: the quantization constants are quantized too, saving additional memory
48
+ )
49
+
50
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=HUGGINGFACEHUB_API_TOKEN)
51
  model = AutoModelForCausalLM.from_pretrained(
52
  model_id,
53
+ device_map="auto", # 'auto' handles device placement, including offloading to CPU if necessary (but quantization aims to prevent this)
54
+ quantization_config=bnb_config, # Pass the quantization configuration here
55
+ # torch_dtype=torch.bfloat16, # REMOVED: This is now handled by bnb_4bit_compute_dtype
56
  trust_remote_code=True,
57
  token=HUGGINGFACEHUB_API_TOKEN
58
  )
59
 
60
+ # Global dictionary to store active conversation chains, keyed by session_id.
61
+ # IMPORTANT: In a production environment, this in-memory dictionary will reset
62
+ # if the server restarts. For true persistence, you would use a database (e.g., Redis, Firestore).
63
+ active_conversations: Dict[str, ConversationChain] = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  # --- UPDATED PROMPT TEMPLATE ---
66
  template = """<|im_start|>system
 
78
 
79
  PROMPT = PromptTemplate(input_variables=["history", "input"], template=template)
80
 
 
 
 
81
  class QuestionRequest(BaseModel):
82
  question: str
83
+ session_id: Optional[str] = None # Optional session ID for continuing conversations
84
 
85
  class ChatResponse(BaseModel):
86
  response: str
87
+ session_id: str # Include session_id in the response for client to track
88
 
89
  @app.post("/api/generate")
90
  async def generate_text(request: QuestionRequest):
91
+ """
92
+ Handles text generation requests, maintaining conversation history per session.
93
+ """
94
+ session_id = request.session_id
95
+
96
+ # If no session_id is provided, generate a new one.
97
+ # This signifies the start of a new conversation.
98
+ if session_id is None:
99
+ session_id = str(uuid.uuid4())
100
+ print(f"Starting new conversation with session_id: {session_id}")
101
+
102
+ # Retrieve or create a ConversationChain for this session_id
103
+ if session_id not in active_conversations:
104
+ print(f"Creating new ConversationChain for session_id: {session_id}")
105
+ # Initialize Langchain HuggingFacePipeline for this session
106
+ llm = HuggingFacePipeline(pipeline=pipeline(
107
+ "text-generation",
108
+ model=model, # Use the globally loaded model
109
+ tokenizer=tokenizer, # Use the globally loaded tokenizer
110
+ max_new_tokens=512,
111
+ return_full_text=True,
112
+ temperature=0.2,
113
+ do_sample=True,
114
+ ))
115
+ # Initialize memory for this specific session
116
+ memory = ConversationBufferWindowMemory(k=5) # Remembers the last 5 human-AI interaction pairs
117
+ conversation = ConversationChain(llm=llm, memory=memory, prompt=PROMPT, verbose=True)
118
+ active_conversations[session_id] = conversation
119
+ else:
120
+ print(f"Continuing conversation for session_id: {session_id}")
121
+ conversation = active_conversations[session_id]
122
+
123
  async def generate_stream():
124
+ """
125
+ An asynchronous generator function to stream text responses token-by-token.
126
+ Each yielded item will be a JSON string representing a part of the stream.
127
+ """
128
+ # Flag to indicate when we've started streaming the AI's actual response
129
  started_streaming_ai_response = False
130
 
131
  try:
132
+ # First, send a JSON object containing the session_id.
133
+ # This allows the client to immediately get the session ID.
134
+ yield json.dumps({"type": "session_info", "session_id": session_id}) + "\n"
135
+
136
  response_stream = conversation.stream({"input": request.question})
137
 
138
  stop_sequences_to_check = ["Human:", "AI:", "\nHuman:", "\nAI:", "<|im_end|>"]
 
158
  if stop_seq in token_content:
159
  token_content = token_content.split(stop_seq, 1)[0]
160
  if token_content:
161
+ yield json.dumps({"type": "token", "content": token_content}) + "\n"
162
  await asyncio.sleep(0.01)
163
+ yield json.dumps({"type": "end", "status": "completed", "session_id": session_id}) + "\n"
164
  return
165
 
166
  if token_content:
167
+ yield json.dumps({"type": "token", "content": token_content}) + "\n"
168
  await asyncio.sleep(0.01)
169
 
170
+ yield json.dumps({"type": "end", "status": "completed", "session_id": session_id}) + "\n"
171
 
172
  except Exception as e:
173
+ print(f"Error during streaming generation for session {session_id}:")
174
  traceback.print_exc()
175
+ yield json.dumps({"type": "error", "message": str(e), "session_id": session_id}) + "\n"
176
 
177
+ # Return a StreamingResponse with application/json media type
178
  return StreamingResponse(generate_stream(), media_type="application/json")
179
 
180
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -69,3 +69,4 @@ uvicorn==0.34.0
69
  yarl==1.19.0
70
  zstandard==0.23.0
71
  protobuf
 
 
69
  yarl==1.19.0
70
  zstandard==0.23.0
71
  protobuf
72
+ bitsandbytes==0.43.0