Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -66,10 +66,8 @@ class LangChainAgentWrapper:
|
|
66 |
def __init__(self):
|
67 |
print("Initializing LangChainAgentWrapper...")
|
68 |
|
69 |
-
#
|
70 |
-
# This model is generally better at following the ReAct prompt format used by LangChain agents.
|
71 |
model_id = "google/gemma-2b-it"
|
72 |
-
# model_id = "bigcode/starcoderbase-1b" # You can still use starcoder if you prefer
|
73 |
|
74 |
try:
|
75 |
hf_auth_token = os.getenv("HF_TOKEN")
|
@@ -78,28 +76,41 @@ class LangChainAgentWrapper:
|
|
78 |
else:
|
79 |
print("HF_TOKEN secret found.")
|
80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
# Create the Hugging Face pipeline
|
82 |
-
print(f"Loading model pipeline for: {model_id}")
|
83 |
llm_pipeline = transformers.pipeline(
|
84 |
"text-generation",
|
85 |
model=model_id,
|
86 |
-
model_kwargs={"torch_dtype": "auto"},
|
87 |
-
device_map="auto",
|
88 |
token=hf_auth_token,
|
|
|
89 |
)
|
90 |
print("Model pipeline loaded successfully.")
|
91 |
|
92 |
# Wrap the pipeline in a LangChain LLM object
|
93 |
self.llm = HuggingFacePipeline(pipeline=llm_pipeline)
|
94 |
|
95 |
-
# Define the list of LangChain tools
|
96 |
self.tools = [
|
97 |
Tool(
|
98 |
name="get_current_time_in_timezone",
|
99 |
func=get_current_time_in_timezone_func,
|
100 |
description=get_current_time_in_timezone_func.__doc__
|
101 |
),
|
102 |
-
search_tool,
|
103 |
Tool(
|
104 |
name="safe_calculator",
|
105 |
func=safe_calculator_func,
|
@@ -108,8 +119,7 @@ class LangChainAgentWrapper:
|
|
108 |
]
|
109 |
print(f"Tools prepared for agent: {[tool.name for tool in self.tools]}")
|
110 |
|
111 |
-
# Create the ReAct agent prompt from a template
|
112 |
-
# The prompt is crucial for teaching the agent how to think and use tools.
|
113 |
react_prompt = PromptTemplate.from_template(
|
114 |
"""
|
115 |
You are a helpful assistant. Answer the following questions as best you can.
|
@@ -135,10 +145,10 @@ class LangChainAgentWrapper:
|
|
135 |
"""
|
136 |
)
|
137 |
|
138 |
-
# Create the agent
|
139 |
agent = create_react_agent(self.llm, self.tools, react_prompt)
|
140 |
|
141 |
-
# Create the agent executor
|
142 |
self.agent_executor = AgentExecutor(agent=agent, tools=self.tools, verbose=True, handle_parsing_errors=True)
|
143 |
print("LangChain agent created successfully.")
|
144 |
|
|
|
66 |
def __init__(self):
|
67 |
print("Initializing LangChainAgentWrapper...")
|
68 |
|
69 |
+
# We will keep using the gemma-2b-it model, but load it in 4-bit
|
|
|
70 |
model_id = "google/gemma-2b-it"
|
|
|
71 |
|
72 |
try:
|
73 |
hf_auth_token = os.getenv("HF_TOKEN")
|
|
|
76 |
else:
|
77 |
print("HF_TOKEN secret found.")
|
78 |
|
79 |
+
# --- NEW: 4-Bit Quantization Configuration ---
|
80 |
+
# Create a configuration for loading the model in 4-bit precision.
|
81 |
+
# This makes the model faster and use less memory.
|
82 |
+
print("Creating 4-bit quantization config...")
|
83 |
+
quantization_config = transformers.BitsAndBytesConfig(
|
84 |
+
load_in_4bit=True,
|
85 |
+
bnb_4bit_quant_type="nf4",
|
86 |
+
bnb_4bit_compute_dtype="bfloat16" # Use bfloat16 for faster computation
|
87 |
+
)
|
88 |
+
print("Quantization config created.")
|
89 |
+
# --- END NEW ---
|
90 |
+
|
91 |
# Create the Hugging Face pipeline
|
92 |
+
print(f"Loading model pipeline for: {model_id} with quantization")
|
93 |
llm_pipeline = transformers.pipeline(
|
94 |
"text-generation",
|
95 |
model=model_id,
|
96 |
+
model_kwargs={"torch_dtype": "auto"},
|
97 |
+
device_map="auto",
|
98 |
token=hf_auth_token,
|
99 |
+
quantization_config=quantization_config # <<< --- PASS THE NEW CONFIG HERE
|
100 |
)
|
101 |
print("Model pipeline loaded successfully.")
|
102 |
|
103 |
# Wrap the pipeline in a LangChain LLM object
|
104 |
self.llm = HuggingFacePipeline(pipeline=llm_pipeline)
|
105 |
|
106 |
+
# Define the list of LangChain tools (this part is unchanged)
|
107 |
self.tools = [
|
108 |
Tool(
|
109 |
name="get_current_time_in_timezone",
|
110 |
func=get_current_time_in_timezone_func,
|
111 |
description=get_current_time_in_timezone_func.__doc__
|
112 |
),
|
113 |
+
search_tool,
|
114 |
Tool(
|
115 |
name="safe_calculator",
|
116 |
func=safe_calculator_func,
|
|
|
119 |
]
|
120 |
print(f"Tools prepared for agent: {[tool.name for tool in self.tools]}")
|
121 |
|
122 |
+
# Create the ReAct agent prompt from a template (this part is unchanged)
|
|
|
123 |
react_prompt = PromptTemplate.from_template(
|
124 |
"""
|
125 |
You are a helpful assistant. Answer the following questions as best you can.
|
|
|
145 |
"""
|
146 |
)
|
147 |
|
148 |
+
# Create the agent (this part is unchanged)
|
149 |
agent = create_react_agent(self.llm, self.tools, react_prompt)
|
150 |
|
151 |
+
# Create the agent executor (this part is unchanged)
|
152 |
self.agent_executor = AgentExecutor(agent=agent, tools=self.tools, verbose=True, handle_parsing_errors=True)
|
153 |
print("LangChain agent created successfully.")
|
154 |
|