wishwakankanamg committed on
Commit
42c164c
·
1 Parent(s): 54868b2
Files changed (1) hide show
  1. agent.py +17 -14
agent.py CHANGED
@@ -156,16 +156,20 @@ if not hf_token:
156
 
157
  # Build graph function
158
  def build_graph(provider: str = "huggingface"):
159
- """Build the graph"""
160
-
161
- repo_id = "togethercomputer/evo-1-131k-base"
162
-
163
- if not hf_token:
164
- raise ValueError("HF_TOKEN environment variable not set. It's required for Hugging Face provider.")
165
 
166
- try:
167
- # Initialize the HuggingFaceEndpoint
168
- endpoint_llm = HuggingFaceEndpoint(
 
 
 
 
 
 
 
 
 
 
169
  repo_id=repo_id,
170
  temperature=0,
171
  huggingfacehub_api_token=hf_token,
@@ -174,13 +178,12 @@ def build_graph(provider: str = "huggingface"):
174
  # max_new_tokens=512,
175
  # model_kwargs={"top_k": 10}
176
  )
177
- # Wrap it with ChatHuggingFace for a chat interface
178
  llm = ChatHuggingFace(llm=endpoint_llm)
179
- print(f"Successfully initialized Hugging Face LLM: {repo_id}")
 
 
 
180
 
181
- except Exception as e:
182
- # This error is specific to Hugging Face LLM instantiation failing
183
- raise RuntimeError(f"Failed to initialize Hugging Face model ('{repo_id}'): {e}") from e
184
 
185
  llm_with_tools = llm.bind_tools(tools)
186
 
 
156
 
157
  # Build graph function
158
  def build_graph(provider: str = "huggingface"):
 
 
 
 
 
 
159
 
160
+ """Build the graph"""
161
+ # Load environment variables from .env file
162
+ if provider == "google":
163
+ # Google Gemini
164
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
165
+ elif provider == "groq":
166
+ # Groq https://console.groq.com/docs/models
167
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
168
+ elif provider == "huggingface":
169
+ repo_id = "togethercomputer/evo-1-131k-base"
170
+ if not hf_token:
171
+ raise ValueError("HF_TOKEN environment variable not set. It's required for Hugging Face provider.")
172
+ llm = HuggingFaceEndpoint(
173
  repo_id=repo_id,
174
  temperature=0,
175
  huggingfacehub_api_token=hf_token,
 
178
  # max_new_tokens=512,
179
  # model_kwargs={"top_k": 10}
180
  )
 
181
  llm = ChatHuggingFace(llm=endpoint_llm)
182
+ else:
183
+ raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
184
+ # Bind tools to LLM
185
+ """Build the graph"""
186
 
 
 
 
187
 
188
  llm_with_tools = llm.bind_tools(tools)
189