DragonProgrammer committed on
Commit
451d4d7
·
verified ·
1 Parent(s): d771153

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -12
app.py CHANGED
@@ -66,10 +66,8 @@ class LangChainAgentWrapper:
66
  def __init__(self):
67
  print("Initializing LangChainAgentWrapper...")
68
 
69
- # Using a newer, more capable instruction-tuned model.
70
- # This model is generally better at following the ReAct prompt format used by LangChain agents.
71
  model_id = "google/gemma-2b-it"
72
- # model_id = "bigcode/starcoderbase-1b" # You can still use starcoder if you prefer
73
 
74
  try:
75
  hf_auth_token = os.getenv("HF_TOKEN")
@@ -78,28 +76,41 @@ class LangChainAgentWrapper:
78
  else:
79
  print("HF_TOKEN secret found.")
80
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  # Create the Hugging Face pipeline
82
- print(f"Loading model pipeline for: {model_id}")
83
  llm_pipeline = transformers.pipeline(
84
  "text-generation",
85
  model=model_id,
86
- model_kwargs={"torch_dtype": "auto"}, # Use "auto" for dtype
87
- device_map="auto", # Requires accelerate
88
  token=hf_auth_token,
 
89
  )
90
  print("Model pipeline loaded successfully.")
91
 
92
  # Wrap the pipeline in a LangChain LLM object
93
  self.llm = HuggingFacePipeline(pipeline=llm_pipeline)
94
 
95
- # Define the list of LangChain tools
96
  self.tools = [
97
  Tool(
98
  name="get_current_time_in_timezone",
99
  func=get_current_time_in_timezone_func,
100
  description=get_current_time_in_timezone_func.__doc__
101
  ),
102
- search_tool, # This is already a LangChain Tool instance
103
  Tool(
104
  name="safe_calculator",
105
  func=safe_calculator_func,
@@ -108,8 +119,7 @@ class LangChainAgentWrapper:
108
  ]
109
  print(f"Tools prepared for agent: {[tool.name for tool in self.tools]}")
110
 
111
- # Create the ReAct agent prompt from a template
112
- # The prompt is crucial for teaching the agent how to think and use tools.
113
  react_prompt = PromptTemplate.from_template(
114
  """
115
  You are a helpful assistant. Answer the following questions as best you can.
@@ -135,10 +145,10 @@ class LangChainAgentWrapper:
135
  """
136
  )
137
 
138
- # Create the agent
139
  agent = create_react_agent(self.llm, self.tools, react_prompt)
140
 
141
- # Create the agent executor, which runs the agent loop
142
  self.agent_executor = AgentExecutor(agent=agent, tools=self.tools, verbose=True, handle_parsing_errors=True)
143
  print("LangChain agent created successfully.")
144
 
 
66
  def __init__(self):
67
  print("Initializing LangChainAgentWrapper...")
68
 
69
+ # We will keep using the gemma-2b-it model, but load it in 4-bit
 
70
  model_id = "google/gemma-2b-it"
 
71
 
72
  try:
73
  hf_auth_token = os.getenv("HF_TOKEN")
 
76
  else:
77
  print("HF_TOKEN secret found.")
78
 
79
+ # --- NEW: 4-Bit Quantization Configuration ---
80
+ # Create a configuration for loading the model in 4-bit precision.
81
+ # This makes the model faster and uses less memory.
82
+ print("Creating 4-bit quantization config...")
83
+ quantization_config = transformers.BitsAndBytesConfig(
84
+ load_in_4bit=True,
85
+ bnb_4bit_quant_type="nf4",
86
+ bnb_4bit_compute_dtype="bfloat16" # Use bfloat16 for faster computation
87
+ )
88
+ print("Quantization config created.")
89
+ # --- END NEW ---
90
+
91
  # Create the Hugging Face pipeline
92
+ print(f"Loading model pipeline for: {model_id} with quantization")
93
  llm_pipeline = transformers.pipeline(
94
  "text-generation",
95
  model=model_id,
96
+ model_kwargs={"torch_dtype": "auto"},
97
+ device_map="auto",
98
  token=hf_auth_token,
99
+ quantization_config=quantization_config # <<< --- PASS THE NEW CONFIG HERE
100
  )
101
  print("Model pipeline loaded successfully.")
102
 
103
  # Wrap the pipeline in a LangChain LLM object
104
  self.llm = HuggingFacePipeline(pipeline=llm_pipeline)
105
 
106
+ # Define the list of LangChain tools (this part is unchanged)
107
  self.tools = [
108
  Tool(
109
  name="get_current_time_in_timezone",
110
  func=get_current_time_in_timezone_func,
111
  description=get_current_time_in_timezone_func.__doc__
112
  ),
113
+ search_tool,
114
  Tool(
115
  name="safe_calculator",
116
  func=safe_calculator_func,
 
119
  ]
120
  print(f"Tools prepared for agent: {[tool.name for tool in self.tools]}")
121
 
122
+ # Create the ReAct agent prompt from a template (this part is unchanged)
 
123
  react_prompt = PromptTemplate.from_template(
124
  """
125
  You are a helpful assistant. Answer the following questions as best you can.
 
145
  """
146
  )
147
 
148
+ # Create the agent (this part is unchanged)
149
  agent = create_react_agent(self.llm, self.tools, react_prompt)
150
 
151
+ # Create the agent executor (this part is unchanged)
152
  self.agent_executor = AgentExecutor(agent=agent, tools=self.tools, verbose=True, handle_parsing_errors=True)
153
  print("LangChain agent created successfully.")
154