Spaces:
Running
Running
Update eb_agent_module.py
Browse files- eb_agent_module.py +191 -218
eb_agent_module.py
CHANGED
@@ -11,24 +11,68 @@ import textwrap
|
|
11 |
try:
|
12 |
from google import genai
|
13 |
from google.genai import types as genai_types
|
14 |
-
# from google.api_core import retry_async # For async retries if needed
|
15 |
except ImportError:
|
16 |
print("Google Generative AI library not found. Please install it: pip install google-generativeai")
|
17 |
# Define dummy classes/functions if the import fails, to allow the rest of the script to be parsed
|
18 |
class genai: # type: ignore
|
19 |
@staticmethod
|
20 |
def configure(api_key): pass
|
|
|
|
|
|
|
21 |
@staticmethod
|
22 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
@staticmethod
|
24 |
-
def embed_content(model, content, task_type, title=None):
|
|
|
|
|
25 |
|
26 |
class genai_types: # type: ignore
|
27 |
@staticmethod
|
28 |
-
def GenerateContentConfig(**kwargs): return
|
29 |
-
class BlockReason:
|
30 |
SAFETY = "SAFETY"
|
31 |
-
class HarmCategory:
|
32 |
HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"
|
33 |
HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
|
34 |
HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
|
@@ -39,62 +83,55 @@ except ImportError:
|
|
39 |
|
40 |
|
41 |
# --- Configuration ---
|
42 |
-
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
|
43 |
-
LLM_MODEL_NAME = "gemini-2.0-flash"
|
44 |
-
GEMINI_EMBEDDING_MODEL_NAME = "gemini-embedding-exp-03-07" #
|
45 |
|
46 |
# Generation configuration for the LLM
|
47 |
GENERATION_CONFIG_PARAMS = {
|
48 |
"temperature": 0.2,
|
49 |
"top_p": 1.0,
|
50 |
"top_k": 32,
|
51 |
-
"max_output_tokens": 4096,
|
52 |
}
|
53 |
|
54 |
# Safety settings for Gemini
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
|
63 |
# Logging setup
|
64 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(module)s - %(message)s')
|
65 |
|
66 |
-
#
|
67 |
if GEMINI_API_KEY:
|
68 |
try:
|
69 |
genai.configure(api_key=GEMINI_API_KEY)
|
70 |
-
logging.info(f"Gemini API key configured. Target model for generation: '{LLM_MODEL_NAME}', Embedding model: '{GEMINI_EMBEDDING_MODEL_NAME}'")
|
71 |
except Exception as e:
|
72 |
-
logging.error(f"Failed to configure Gemini API: {e}", exc_info=True)
|
73 |
else:
|
74 |
logging.warning("GEMINI_API_KEY environment variable not set. LLM and Embedding functionalities will be limited.")
|
75 |
|
76 |
|
77 |
-
# --- RAG Documents Definition
|
78 |
-
# This will be used by the AdvancedRAGSystem.
|
79 |
-
# You can replace this with more relevant documents for your LinkedIn dashboard context if needed.
|
80 |
rag_documents_data = {
|
81 |
'Title': [
|
82 |
-
"Employer Branding Best Practices 2024",
|
83 |
-
"
|
84 |
-
"Understanding Company Culture for Talent Acquisition",
|
85 |
-
"Diversity and Inclusion in Modern Hiring Processes",
|
86 |
-
"Leveraging LinkedIn Data for Recruitment Insights",
|
87 |
-
"Analyzing Employee Engagement Metrics",
|
88 |
-
"Content Strategies for LinkedIn Company Pages"
|
89 |
],
|
90 |
'Text': [
|
91 |
-
"Focus on authentic employee stories
|
92 |
-
"
|
93 |
-
"Company culture is defined by shared values, beliefs, and behaviors. It's crucial for attracting and retaining talent that aligns with the organization. Assess culture through employee surveys, feedback sessions, and by observing daily interactions. Promote a positive culture actively.",
|
94 |
-
"Promote diversity and inclusion by using inclusive language in job descriptions, ensuring diverse interview panels, and highlighting D&I initiatives. Track diversity metrics and be transparent about your goals and progress. An inclusive culture boosts innovation.",
|
95 |
-
"LinkedIn data provides rich insights into talent pools, competitor strategies, and industry trends. Analyze follower demographics, content engagement, and employee advocacy to refine your employer branding and recruitment efforts. Use LinkedIn Analytics effectively.",
|
96 |
-
"High employee engagement correlates with better retention and productivity. Key metrics include employee Net Promoter Score (eNPS), satisfaction surveys, and participation in company initiatives. Address feedback promptly to foster a positive work environment.",
|
97 |
-
"Develop a content calendar for your LinkedIn Company Page that includes a mix of thought leadership, company news, employee spotlights, job postings, and industry insights. Use visuals and videos to increase engagement. Analyze post performance to optimize your strategy."
|
98 |
]
|
99 |
}
|
100 |
df_rag_documents = pd.DataFrame(rag_documents_data)
|
@@ -102,31 +139,25 @@ df_rag_documents = pd.DataFrame(rag_documents_data)
|
|
102 |
|
103 |
# --- Schema Representation ---
|
104 |
def get_schema_representation(df_name: str, df: pd.DataFrame) -> str:
|
105 |
-
"""Generates a string representation of a DataFrame's schema and a sample of its data."""
|
106 |
if df.empty:
|
107 |
return f"Schema for DataFrame '{df_name}':\n - DataFrame is empty.\n"
|
108 |
-
|
109 |
cols = df.columns.tolist()
|
110 |
dtypes = df.dtypes.to_dict()
|
111 |
-
schema_str = f"Schema for DataFrame 'df_{df_name}':\n"
|
112 |
for col in cols:
|
113 |
schema_str += f" - Column '{col}': {dtypes[col]}\n"
|
114 |
-
|
115 |
-
# Add notes for complex data types or common pitfalls
|
116 |
for col in cols:
|
117 |
if 'date' in col.lower() or 'time' in col.lower():
|
118 |
-
schema_str += f" - Note: Column '{col}' seems to be date/time related
|
119 |
if df[col].apply(type).eq(list).any() or df[col].apply(type).eq(dict).any():
|
120 |
-
schema_str += f" - Note: Column '{col}' may contain list-like or dict-like data
|
121 |
-
if df[col].dtype == 'object' and df[col].nunique() < 20 and df.shape[0] > 20:
|
122 |
-
schema_str += f" - Note: Column '{col}'
|
123 |
-
|
124 |
schema_str += f"Sample of first 2 rows of 'df_{df_name}':\n{df.head(2).to_string()}\n"
|
125 |
return schema_str
|
126 |
|
127 |
def get_all_schemas_representation(dataframes_dict: dict) -> str:
|
128 |
-
|
129 |
-
full_schema_str = "You have access to the following Pandas DataFrames. In your Python code, refer to them with the 'df_' prefix (e.g., df_follower_stats, df_posts).\n\n"
|
130 |
for name, df_instance in dataframes_dict.items():
|
131 |
full_schema_str += get_schema_representation(name, df_instance) + "\n"
|
132 |
return full_schema_str
|
@@ -135,36 +166,42 @@ def get_all_schemas_representation(dataframes_dict: dict) -> str:
|
|
135 |
# --- Advanced RAG System ---
|
136 |
class AdvancedRAGSystem:
|
137 |
def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
|
|
|
138 |
if not GEMINI_API_KEY:
|
139 |
logging.warning("RAG System: GEMINI_API_KEY not set. Embeddings will not be generated.")
|
140 |
-
self.documents_df = documents_df.copy()
|
141 |
if 'Embeddings' not in self.documents_df.columns:
|
142 |
self.documents_df['Embeddings'] = pd.Series(dtype='object')
|
143 |
-
self.embedding_model_name = embedding_model_name
|
144 |
self.embeddings_generated = False
|
145 |
return
|
146 |
|
147 |
self.documents_df = documents_df.copy()
|
148 |
-
self.embedding_model_name = embedding_model_name
|
149 |
self.embeddings_generated = False
|
150 |
try:
|
151 |
-
|
152 |
-
|
153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
except Exception as e:
|
155 |
logging.error(f"Error during RAG embedding precomputation: {e}", exc_info=True)
|
156 |
if 'Embeddings' not in self.documents_df.columns:
|
157 |
self.documents_df['Embeddings'] = pd.Series(dtype='object')
|
158 |
|
159 |
-
|
160 |
def _embed_fn(self, title: str, text: str) -> list[float]:
|
161 |
try:
|
162 |
-
if
|
163 |
-
|
164 |
-
|
|
|
165 |
|
166 |
embedding_result = genai.embed_content(
|
167 |
-
model=self.embedding_model_name,
|
168 |
content=text,
|
169 |
task_type="retrieval_document",
|
170 |
title=title
|
@@ -177,145 +214,143 @@ class AdvancedRAGSystem:
|
|
177 |
def _precompute_embeddings(self):
|
178 |
if 'Embeddings' not in self.documents_df.columns:
|
179 |
self.documents_df['Embeddings'] = pd.Series(dtype='object')
|
180 |
-
|
181 |
-
# Only compute for rows where 'Embeddings' is None, not a list, or a zero vector
|
182 |
for index, row in self.documents_df.iterrows():
|
183 |
current_embedding = row['Embeddings']
|
184 |
-
is_valid_embedding = isinstance(current_embedding, list) and len(current_embedding) > 0 and sum(abs(x) for x in current_embedding) > 1e-6
|
185 |
-
|
186 |
if not is_valid_embedding:
|
187 |
self.documents_df.at[index, 'Embeddings'] = self._embed_fn(row['Title'], row['Text'])
|
188 |
-
logging.info("Embeddings precomputation finished.")
|
189 |
-
|
190 |
|
191 |
def retrieve_relevant_info(self, query_text: str, top_k: int = 2) -> str:
|
192 |
-
|
193 |
-
|
|
|
|
|
|
|
194 |
return "\n[RAG Context]\nNo specific pre-defined context found (RAG system inactive or no embeddings).\n"
|
195 |
-
|
196 |
try:
|
197 |
query_embedding_result = genai.embed_content(
|
198 |
-
model=self.embedding_model_name,
|
199 |
content=query_text,
|
200 |
task_type="retrieval_query"
|
201 |
)
|
202 |
query_embedding = np.array(query_embedding_result["embedding"])
|
203 |
-
|
204 |
valid_embeddings_df = self.documents_df.dropna(subset=['Embeddings'])
|
205 |
valid_embeddings_df = valid_embeddings_df[valid_embeddings_df['Embeddings'].apply(lambda x: isinstance(x, list) and len(x) > 0 and sum(abs(val) for val in x) > 1e-6)]
|
206 |
-
|
207 |
-
|
208 |
if valid_embeddings_df.empty:
|
209 |
-
logging.warning("No valid document embeddings found for RAG.")
|
210 |
return "\n[RAG Context]\nNo valid document embeddings available for retrieval.\n"
|
211 |
-
|
212 |
document_embeddings = np.stack(valid_embeddings_df['Embeddings'].apply(np.array).values)
|
213 |
-
|
214 |
if query_embedding.shape[0] != document_embeddings.shape[1]:
|
215 |
-
logging.error(f"Query embedding dim ({query_embedding.shape[0]}) != Document embedding dim ({document_embeddings.shape[1]})")
|
216 |
return "\n[RAG Context]\nEmbedding dimension mismatch.\n"
|
217 |
-
|
218 |
dot_products = np.dot(document_embeddings, query_embedding)
|
219 |
-
|
220 |
-
# Get top_k indices, ensure top_k is not greater than available docs
|
221 |
num_available_docs = len(valid_embeddings_df)
|
222 |
actual_top_k = min(top_k, num_available_docs)
|
223 |
-
|
224 |
-
if actual_top_k == 0:
|
225 |
-
return "\n[RAG Context]\nNo documents to retrieve from.\n"
|
226 |
-
|
227 |
-
if actual_top_k == 1 and num_available_docs > 0:
|
228 |
-
idx = [np.argmax(dot_products)]
|
229 |
-
elif num_available_docs > 0 :
|
230 |
-
idx = np.argsort(dot_products)[-actual_top_k:][::-1]
|
231 |
-
else: # Should not happen if actual_top_k > 0
|
232 |
-
idx = []
|
233 |
-
|
234 |
-
|
235 |
relevant_passages = ""
|
236 |
-
for
|
237 |
-
passage_title = valid_embeddings_df.iloc[
|
238 |
-
passage_text = valid_embeddings_df.iloc[
|
239 |
relevant_passages += f"\n[RAG Context from: '{passage_title}']\n{passage_text}\n"
|
240 |
-
|
241 |
-
logging.info(f"RAG System retrieved: {relevant_passages[:200]}...")
|
242 |
return relevant_passages if relevant_passages else "\n[RAG Context]\nNo highly relevant passages found.\n"
|
243 |
-
|
244 |
except Exception as e:
|
245 |
logging.error(f"Error retrieving relevant info from RAG: {e}", exc_info=True)
|
246 |
return f"\n[RAG Context]\nError during RAG retrieval: {str(e)}\n"
|
247 |
|
248 |
-
|
249 |
# --- PandasLLM Class (Gemini-Powered) ---
|
250 |
class PandasLLM:
|
251 |
-
def __init__(self, llm_model_name: str, generation_config_params: dict,
|
|
|
|
|
252 |
self.llm_model_name = llm_model_name
|
253 |
self.generation_config_params = generation_config_params
|
254 |
-
self.safety_settings = safety_settings
|
255 |
self.data_privacy = data_privacy
|
256 |
-
self.force_sandbox = force_sandbox
|
|
|
|
|
257 |
|
258 |
if not GEMINI_API_KEY:
|
259 |
logging.warning("PandasLLM: GEMINI_API_KEY not set. LLM functionalities will be limited.")
|
260 |
-
self.model = None
|
261 |
else:
|
262 |
try:
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
268 |
except Exception as e:
|
269 |
-
logging.error(f"Failed to initialize
|
270 |
-
|
271 |
|
272 |
async def _call_gemini_api_async(self, prompt_text: str, history: list = None) -> str:
|
273 |
-
if not self.
|
274 |
-
logging.error("PandasLLM:
|
275 |
-
return "# Error: Gemini
|
276 |
|
277 |
-
contents_for_api = [
|
278 |
-
if history:
|
279 |
-
formatted_history = []
|
280 |
for entry in history:
|
281 |
-
role = entry.get("role", "user")
|
282 |
-
if role == "assistant": role = "model"
|
283 |
-
|
284 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
285 |
|
286 |
|
287 |
-
|
288 |
-
gen_config_obj = genai_types.GenerateContentConfig(**self.generation_config_params)
|
289 |
-
except Exception as e:
|
290 |
-
logging.error(f"Error creating GenerateContentConfig: {e}. Using dict directly.")
|
291 |
-
gen_config_obj = self.generation_config_params
|
292 |
-
|
293 |
-
logging.info(f"\n--- Calling Gemini API with prompt (first 500 chars of last message): ---\n{contents_for_api[-1]['parts'][0]['text'][:500]}...\n-------------------------------------------------------\n")
|
294 |
|
295 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
296 |
response = await asyncio.to_thread(
|
297 |
-
self.
|
298 |
-
|
299 |
-
|
|
|
|
|
300 |
)
|
301 |
|
302 |
-
if response.prompt_feedback and response.prompt_feedback.block_reason:
|
303 |
reason = response.prompt_feedback.block_reason
|
304 |
reason_name = getattr(reason, 'name', str(reason))
|
305 |
logging.warning(f"Gemini API call blocked by prompt feedback: {reason_name}")
|
306 |
return f"# Error: Prompt blocked due to content policy: {reason_name}."
|
307 |
|
308 |
llm_output = ""
|
309 |
-
if hasattr(response, 'text') and response.text:
|
310 |
llm_output = response.text
|
311 |
-
elif response.candidates:
|
312 |
candidate = response.candidates[0]
|
313 |
-
if candidate.content and candidate.content.parts:
|
314 |
llm_output = "".join(part.text for part in candidate.content.parts if hasattr(part, 'text'))
|
315 |
|
316 |
-
if not llm_output:
|
317 |
finish_reason_val = candidate.finish_reason
|
318 |
-
finish_reason = getattr(finish_reason_val, 'name', str(finish_reason_val))
|
319 |
logging.warning(f"No text content in response candidate. Finish reason: {finish_reason}")
|
320 |
if finish_reason == "SAFETY":
|
321 |
return f"# Error: Response generation stopped due to safety reasons ({finish_reason})."
|
@@ -323,31 +358,29 @@ class PandasLLM:
|
|
323 |
return f"# Error: Response generation stopped due to recitation policy ({finish_reason})."
|
324 |
return f"# Error: The AI model returned an empty response. Finish reason: {finish_reason}."
|
325 |
else:
|
326 |
-
logging.warning("Gemini API response structure not recognized or empty.")
|
327 |
return "# Error: The AI model returned an unexpected or empty response structure."
|
328 |
|
329 |
logging.info(f"--- Gemini API Response (first 300 chars): ---\n{llm_output[:300]}...\n--------------------------------------------------\n")
|
330 |
return llm_output
|
331 |
|
332 |
except AttributeError as ae:
|
333 |
-
logging.error(f"AttributeError during Gemini call
|
334 |
-
return f"# Error (Attribute): {type(ae).__name__} - {ae}. Check
|
335 |
except Exception as e:
|
336 |
-
logging.error(f"Error calling Gemini API: {e}", exc_info=True)
|
337 |
if "API_KEY_INVALID" in str(e) or "API key not valid" in str(e):
|
338 |
-
return "# Error: Gemini API key is not valid.
|
339 |
-
if "400" in str(e) and "model" in str(e).lower() and ("not found" in str(e).lower() or "does not exist" in str(e).lower()):
|
340 |
-
return f"# Error: Gemini Model '{self.llm_model_name}' not found or not accessible with your API key. Check model name and permissions."
|
341 |
-
if "DeadlineExceeded" in str(e) or "504" in str(e):
|
342 |
-
return "# Error: The request to Gemini API timed out. Please try again later."
|
343 |
if "PermissionDenied" in str(e) or "403" in str(e):
|
344 |
-
return "# Error: Permission denied
|
345 |
-
|
|
|
|
|
|
|
346 |
|
347 |
|
348 |
async def query(self, prompt_with_query_and_context: str, dataframes_dict: dict, history: list = None) -> str:
|
349 |
llm_response_text = await self._call_gemini_api_async(prompt_with_query_and_context, history)
|
350 |
-
|
351 |
if self.force_sandbox:
|
352 |
code_to_execute = ""
|
353 |
if "```python" in llm_response_text:
|
@@ -361,54 +394,39 @@ class PandasLLM:
|
|
361 |
except IndexError:
|
362 |
code_to_execute = ""
|
363 |
logging.warning("Could not extract Python code using primary or secondary split method.")
|
364 |
-
|
365 |
-
llm_response_text_for_sandbox_error = "" # Initialize this variable
|
366 |
if llm_response_text.startswith("# Error:") or not code_to_execute:
|
367 |
-
error_prefix = "LLM did not return
|
368 |
if llm_response_text.startswith("# Error:"): error_prefix = "An error occurred during LLM call."
|
369 |
elif not code_to_execute: error_prefix = "Could not extract Python code from LLM response."
|
370 |
-
|
371 |
safe_llm_response = str(llm_response_text).replace("'''", "'").replace('"""', '"')
|
372 |
llm_response_text_for_sandbox_error = f"print(f'''{error_prefix}\\nRaw LLM Response (may be truncated):\\n{safe_llm_response[:1000]}''')"
|
373 |
logging.warning(f"Problem with LLM response for sandbox: {error_prefix}")
|
374 |
-
|
375 |
logging.info(f"\n--- Code to Execute (from LLM, if sandbox): ---\n{code_to_execute}\n------------------------------------------------\n")
|
376 |
-
|
377 |
-
# --- THIS IS THE CORRECTED SECTION ---
|
378 |
-
# In the exec environment, __builtins__ is a dict.
|
379 |
-
# We iterate over its items directly.
|
380 |
safe_builtins = {}
|
381 |
if isinstance(__builtins__, dict):
|
382 |
safe_builtins = {name: obj for name, obj in __builtins__.items() if not name.startswith('_')}
|
383 |
-
else:
|
384 |
safe_builtins = {name: obj for name, obj in __builtins__.__dict__.items() if not name.startswith('_')}
|
385 |
-
# --- END OF CORRECTION ---
|
386 |
-
|
387 |
unsafe_builtins = ['eval', 'exec', 'open', 'compile', 'input', 'memoryview', 'vars', 'globals', 'locals', '__import__']
|
388 |
for ub in unsafe_builtins:
|
389 |
safe_builtins.pop(ub, None)
|
390 |
-
|
391 |
exec_globals = {'pd': pd, 'np': np, '__builtins__': safe_builtins}
|
392 |
for name, df_instance in dataframes_dict.items():
|
393 |
exec_globals[f"df_{name}"] = df_instance
|
394 |
-
|
395 |
from io import StringIO
|
396 |
import sys
|
397 |
old_stdout = sys.stdout
|
398 |
sys.stdout = captured_output = StringIO()
|
399 |
-
|
400 |
final_output_str = ""
|
401 |
try:
|
402 |
if code_to_execute:
|
403 |
exec(code_to_execute, exec_globals, {})
|
404 |
output_val = captured_output.getvalue()
|
405 |
final_output_str = output_val if output_val else "# Code executed successfully, but no explicit print() output was generated by the code."
|
406 |
-
logging.info(f"--- Sandbox Execution Output: ---\n{final_output_str}\n-------------------------\n")
|
407 |
else:
|
408 |
exec(llm_response_text_for_sandbox_error, exec_globals, {})
|
409 |
final_output_str = captured_output.getvalue()
|
410 |
-
logging.warning(f"--- Sandbox Fallback Output (No Code Executed): ---\n{final_output_str}\n-------------------------\n")
|
411 |
-
|
412 |
except Exception as e:
|
413 |
error_msg = f"# Error executing LLM-generated code:\n# {type(e).__name__}: {str(e)}\n# --- Code that caused error: ---\n{textwrap.indent(code_to_execute, '# ')}"
|
414 |
final_output_str = error_msg
|
@@ -419,7 +437,6 @@ class PandasLLM:
|
|
419 |
else:
|
420 |
return llm_response_text
|
421 |
|
422 |
-
|
423 |
# --- Employer Branding Agent ---
|
424 |
class EmployerBrandingAgent:
|
425 |
def __init__(self, llm_model_name: str, generation_config_params: dict, safety_settings: dict,
|
@@ -433,77 +450,33 @@ class EmployerBrandingAgent:
|
|
433 |
logging.info("EmployerBrandingAgent Initialized.")
|
434 |
|
435 |
def _build_prompt(self, user_query: str, role="Employer Branding Analyst", task_decomposition_hint=None, cot_hint=True) -> str:
|
436 |
-
prompt = f"You are a helpful and expert '{role}'
|
437 |
-
|
438 |
-
|
439 |
-
if self.pandas_llm.data_privacy:
|
440 |
-
prompt += "IMPORTANT: Be mindful of data privacy. Do not output raw Personally Identifiable Information (PII) like names or specific user details unless explicitly asked and absolutely necessary for the query. Summarize or aggregate data where possible.\n"
|
441 |
-
|
442 |
-
if self.pandas_llm.force_sandbox:
|
443 |
-
prompt += "Your main task is to GENERATE PYTHON CODE using the Pandas library to answer the user query based on the provided DataFrames. Output ONLY the Python code block.\n"
|
444 |
-
prompt += "The available DataFrames are already loaded and can be accessed by their dictionary keys prefixed with 'df_' (e.g., df_follower_stats, df_posts) within the execution environment.\n"
|
445 |
-
prompt += "Example of accessing a DataFrame: `df_follower_stats['country']`.\n"
|
446 |
-
prompt += "Your Python code MUST include `print()` statements for any results, DataFrames, or values you want to display. The output of these print statements will be the final answer.\n"
|
447 |
-
prompt += "If a column contains lists (e.g., 'skills' in a hypothetical 'df_employees'), you might need to use methods like `.explode()` or `.apply(pd.Series)` or `.apply(lambda x: ...)` for analysis.\n"
|
448 |
-
prompt += "If the query is ambiguous or requires clarification, ask for it instead of making assumptions. If the query cannot be answered with the given data, state that clearly.\n"
|
449 |
-
prompt += "If the query is not about data analysis or code generation (e.g. 'hello', 'how are you?'), respond politely and briefly, do not attempt to generate code.\n"
|
450 |
-
prompt += "Structure your code clearly. Add comments (#) to explain each step of your logic.\n"
|
451 |
-
else:
|
452 |
-
prompt += "Your task is to analyze the data and provide a comprehensive textual answer to the user query. You can explain your reasoning step-by-step.\n"
|
453 |
-
|
454 |
prompt += "\n--- AVAILABLE DATA AND SCHEMAS ---\n"
|
455 |
prompt += self.schemas_representation
|
456 |
-
|
457 |
rag_context = self.rag_system.retrieve_relevant_info(user_query)
|
458 |
if rag_context and "[RAG Context]" in rag_context and "No specific pre-defined context found" not in rag_context and "No highly relevant passages found" not in rag_context:
|
459 |
prompt += f"\n--- ADDITIONAL CONTEXT (from internal knowledge base, consider this information) ---\n{rag_context}\n"
|
460 |
-
|
461 |
prompt += f"\n--- USER QUERY ---\n{user_query}\n"
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
if cot_hint:
|
467 |
-
if self.pandas_llm.force_sandbox:
|
468 |
-
prompt += "\n--- INSTRUCTIONS FOR PYTHON CODE GENERATION (Chain of Thought) ---\n"
|
469 |
-
prompt += "1. Understand the query: What specific information is requested?\n"
|
470 |
-
prompt += "2. Identify relevant DataFrame(s) and column(s) from the schemas provided.\n"
|
471 |
-
prompt += "3. Plan the steps: Outline the Pandas operations needed (filtering, grouping, aggregation, merging, etc.) as comments in your code.\n"
|
472 |
-
prompt += "4. Write the code: Implement the steps using Pandas. Remember to use `df_name_of_dataframe` (e.g. `df_follower_stats`).\n"
|
473 |
-
prompt += "5. Ensure output: Use `print()` for all results that should be displayed. For DataFrames, you can print the DataFrame directly, or `df.to_string()` if it's large.\n"
|
474 |
-
prompt += "6. Review: Check for correctness, efficiency, and adherence to the prompt (especially the `print()` requirement).\n"
|
475 |
-
prompt += "7. Generate ONLY the Python code block starting with ```python and ending with ```. No explanations outside the code block's comments.\n"
|
476 |
-
else:
|
477 |
-
prompt += "\n--- INSTRUCTIONS FOR RESPONSE (Chain of Thought) ---\n"
|
478 |
-
prompt += "1. Understand the query fully.\n"
|
479 |
-
prompt += "2. Identify the relevant data sources (DataFrames and columns).\n"
|
480 |
-
prompt += "3. Explain your analytical approach step-by-step.\n"
|
481 |
-
prompt += "4. Perform the analysis (mentally or by outlining the steps).\n"
|
482 |
-
prompt += "5. Present the findings clearly and concisely. If you performed calculations, show or describe them.\n"
|
483 |
-
prompt += "6. If applicable, incorporate insights from the 'ADDITIONAL CONTEXT' (RAG system).\n"
|
484 |
-
|
485 |
return prompt
|
486 |
|
487 |
async def process_query(self, user_query: str, role="Employer Branding Analyst", task_decomposition_hint=None, cot_hint=True) -> str:
|
488 |
-
logging.info(f"\n=== Processing Query for Role: {role} ===")
|
489 |
-
logging.info(f"User Query: {user_query}")
|
490 |
-
|
491 |
self.chat_history.append({"role": "user", "content": user_query})
|
492 |
-
|
493 |
full_prompt = self._build_prompt(user_query, role, task_decomposition_hint, cot_hint)
|
494 |
-
|
495 |
response_text = await self.pandas_llm.query(full_prompt, self.all_dataframes, history=self.chat_history[:-1])
|
496 |
-
|
497 |
self.chat_history.append({"role": "assistant", "content": response_text})
|
498 |
-
|
499 |
MAX_HISTORY_TURNS = 5
|
500 |
if len(self.chat_history) > MAX_HISTORY_TURNS * 2:
|
501 |
self.chat_history = self.chat_history[-(MAX_HISTORY_TURNS * 2):]
|
502 |
-
|
503 |
return response_text
|
504 |
|
505 |
def update_dataframes(self, new_dataframes: dict):
|
506 |
-
"""Updates the agent's DataFrames and their schema representation."""
|
507 |
self.all_dataframes = new_dataframes
|
508 |
self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
|
509 |
logging.info("EmployerBrandingAgent DataFrames updated.")
|
|
|
11 |
try:
|
12 |
from google import genai
|
13 |
from google.genai import types as genai_types
|
|
|
14 |
except ImportError:
|
15 |
print("Google Generative AI library not found. Please install it: pip install google-generativeai")
|
16 |
# Define dummy classes/functions if the import fails, to allow the rest of the script to be parsed
|
17 |
class genai: # type: ignore
|
18 |
@staticmethod
|
19 |
def configure(api_key): pass
|
20 |
+
|
21 |
+
# Making dummy Client return a dummy client object that has a dummy 'models' attribute
|
22 |
+
# which in turn has a dummy 'generate_content' method.
|
23 |
@staticmethod
|
24 |
+
def Client(api_key=None): # api_key can be optional if configure is used
|
25 |
+
class DummyModels:
|
26 |
+
@staticmethod
|
27 |
+
def generate_content(model=None, contents=None, generation_config=None, safety_settings=None):
|
28 |
+
print(f"Dummy genai.Client.models.generate_content called for model: {model}")
|
29 |
+
# Simulate a minimal valid-looking response structure
|
30 |
+
class DummyPart:
|
31 |
+
def __init__(self, text):
|
32 |
+
self.text = text
|
33 |
+
class DummyContent:
|
34 |
+
def __init__(self):
|
35 |
+
self.parts = [DummyPart("# Dummy response from dummy client")]
|
36 |
+
class DummyCandidate:
|
37 |
+
def __init__(self):
|
38 |
+
self.content = DummyContent()
|
39 |
+
self.finish_reason = "DUMMY"
|
40 |
+
self.safety_ratings = []
|
41 |
+
class DummyResponse:
|
42 |
+
def __init__(self):
|
43 |
+
self.candidates = [DummyCandidate()]
|
44 |
+
self.prompt_feedback = None
|
45 |
+
@property
|
46 |
+
def text(self): # Add a text property for compatibility
|
47 |
+
if self.candidates and self.candidates[0].content and self.candidates[0].content.parts:
|
48 |
+
return "".join(p.text for p in self.candidates[0].content.parts)
|
49 |
+
return ""
|
50 |
+
return DummyResponse()
|
51 |
+
|
52 |
+
class DummyClient:
|
53 |
+
def __init__(self):
|
54 |
+
self.models = DummyModels()
|
55 |
+
|
56 |
+
if api_key: # Only return a DummyClient if api_key is provided, mimicking real client
|
57 |
+
return DummyClient()
|
58 |
+
return None # If no API key, client init might fail or return None
|
59 |
+
|
60 |
+
@staticmethod
|
61 |
+
def GenerativeModel(model_name): # Keep dummy GenerativeModel for other parts if any
|
62 |
+
print(f"Dummy genai.GenerativeModel called for model: {model_name}")
|
63 |
+
return None
|
64 |
+
|
65 |
@staticmethod
|
66 |
+
def embed_content(model, content, task_type, title=None):
|
67 |
+
print(f"Dummy genai.embed_content called for model: {model}")
|
68 |
+
return {"embedding": [0.1] * 768}
|
69 |
|
70 |
class genai_types: # type: ignore
|
71 |
@staticmethod
|
72 |
+
def GenerateContentConfig(**kwargs): return kwargs # Return the dict itself for dummy
|
73 |
+
class BlockReason:
|
74 |
SAFETY = "SAFETY"
|
75 |
+
class HarmCategory:
|
76 |
HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"
|
77 |
HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
|
78 |
HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
|
|
|
83 |
|
84 |
|
85 |
# --- Configuration ---
|
86 |
+
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
|
87 |
+
LLM_MODEL_NAME = "gemini-2.0-flash" # Updated model name
|
88 |
+
GEMINI_EMBEDDING_MODEL_NAME = "gemini-embedding-exp-03-07" # Updated embedding model name
|
89 |
|
90 |
# Generation configuration for the LLM
|
91 |
GENERATION_CONFIG_PARAMS = {
|
92 |
"temperature": 0.2,
|
93 |
"top_p": 1.0,
|
94 |
"top_k": 32,
|
95 |
+
"max_output_tokens": 4096,
|
96 |
}
|
97 |
|
98 |
# Safety settings for Gemini
|
99 |
+
# Ensure genai_types is the real one or the dummy has these attributes
|
100 |
+
try:
|
101 |
+
DEFAULT_SAFETY_SETTINGS = {
|
102 |
+
genai_types.HarmCategory.HARM_CATEGORY_HARASSMENT: genai_types.HarmBlockThreshold.BLOCK_NONE,
|
103 |
+
genai_types.HarmCategory.HARM_CATEGORY_HATE_SPEECH: genai_types.HarmBlockThreshold.BLOCK_NONE,
|
104 |
+
genai_types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: genai_types.HarmBlockThreshold.BLOCK_NONE,
|
105 |
+
genai_types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: genai_types.HarmBlockThreshold.BLOCK_NONE,
|
106 |
+
}
|
107 |
+
except AttributeError: # If genai_types is the dummy and doesn't have these, create placeholder
|
108 |
+
logging.warning("Could not define DEFAULT_SAFETY_SETTINGS using genai_types. Using placeholder.")
|
109 |
+
DEFAULT_SAFETY_SETTINGS = {}
|
110 |
|
111 |
|
112 |
# Logging setup
|
113 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(module)s - %(message)s')
|
114 |
|
115 |
+
# Configure Gemini API key globally if available
|
116 |
if GEMINI_API_KEY:
|
117 |
try:
|
118 |
genai.configure(api_key=GEMINI_API_KEY)
|
119 |
+
logging.info(f"Gemini API key configured globally. Target model for generation: '{LLM_MODEL_NAME}', Embedding model: '{GEMINI_EMBEDDING_MODEL_NAME}'")
|
120 |
except Exception as e:
|
121 |
+
logging.error(f"Failed to configure Gemini API globally: {e}", exc_info=True)
|
122 |
else:
|
123 |
logging.warning("GEMINI_API_KEY environment variable not set. LLM and Embedding functionalities will be limited.")
|
124 |
|
125 |
|
126 |
+
# --- RAG Documents Definition ---
|
|
|
|
|
127 |
rag_documents_data = {
|
128 |
'Title': [
|
129 |
+
"Employer Branding Best Practices 2024", "Attracting Tech Talent",
|
130 |
+
"Understanding Company Culture", "Diversity and Inclusion in Hiring"
|
|
|
|
|
|
|
|
|
|
|
131 |
],
|
132 |
'Text': [
|
133 |
+
"Focus on authentic employee stories...", "Tech candidates value challenging projects...",
|
134 |
+
"Company culture is defined by shared values...", "Promote diversity and inclusion by using inclusive language..."
|
|
|
|
|
|
|
|
|
|
|
135 |
]
|
136 |
}
|
137 |
df_rag_documents = pd.DataFrame(rag_documents_data)
|
|
|
139 |
|
140 |
# --- Schema Representation ---
|
141 |
def get_schema_representation(df_name: str, df: pd.DataFrame) -> str:
|
|
|
142 |
if df.empty:
|
143 |
return f"Schema for DataFrame '{df_name}':\n - DataFrame is empty.\n"
|
|
|
144 |
cols = df.columns.tolist()
|
145 |
dtypes = df.dtypes.to_dict()
|
146 |
+
schema_str = f"Schema for DataFrame 'df_{df_name}':\n"
|
147 |
for col in cols:
|
148 |
schema_str += f" - Column '{col}': {dtypes[col]}\n"
|
|
|
|
|
149 |
for col in cols:
|
150 |
if 'date' in col.lower() or 'time' in col.lower():
|
151 |
+
schema_str += f" - Note: Column '{col}' seems to be date/time related...\n"
|
152 |
if df[col].apply(type).eq(list).any() or df[col].apply(type).eq(dict).any():
|
153 |
+
schema_str += f" - Note: Column '{col}' may contain list-like or dict-like data...\n"
|
154 |
+
if df[col].dtype == 'object' and df[col].nunique() < 20 and df.shape[0] > 20:
|
155 |
+
schema_str += f" - Note: Column '{col}' might be categorical...\n"
|
|
|
156 |
schema_str += f"Sample of first 2 rows of 'df_{df_name}':\n{df.head(2).to_string()}\n"
|
157 |
return schema_str
|
158 |
|
159 |
def get_all_schemas_representation(dataframes_dict: dict) -> str:
|
160 |
+
full_schema_str = "You have access to the following Pandas DataFrames...\n\n"
|
|
|
161 |
for name, df_instance in dataframes_dict.items():
|
162 |
full_schema_str += get_schema_representation(name, df_instance) + "\n"
|
163 |
return full_schema_str
|
|
|
166 |
# --- Advanced RAG System ---
|
167 |
class AdvancedRAGSystem:
|
168 |
def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
|
169 |
+
self.embedding_model_name = embedding_model_name # Store the model name
|
170 |
if not GEMINI_API_KEY:
|
171 |
logging.warning("RAG System: GEMINI_API_KEY not set. Embeddings will not be generated.")
|
172 |
+
self.documents_df = documents_df.copy()
|
173 |
if 'Embeddings' not in self.documents_df.columns:
|
174 |
self.documents_df['Embeddings'] = pd.Series(dtype='object')
|
|
|
175 |
self.embeddings_generated = False
|
176 |
return
|
177 |
|
178 |
self.documents_df = documents_df.copy()
|
|
|
179 |
self.embeddings_generated = False
|
180 |
try:
|
181 |
+
# Check if genai.embed_content is available (not the dummy one)
|
182 |
+
if hasattr(genai, 'embed_content') and not (hasattr(genai.embed_content, '__func__') and genai.embed_content.__func__.__qualname__.startswith('genai.embed_content')): # Basic check if it's not the dummy's staticmethod
|
183 |
+
self._precompute_embeddings()
|
184 |
+
self.embeddings_generated = True
|
185 |
+
logging.info("AdvancedRAGSystem Initialized and embeddings precomputed.")
|
186 |
+
else:
|
187 |
+
logging.warning("AdvancedRAGSystem: Real genai.embed_content not available. Skipping embedding precomputation.")
|
188 |
+
if 'Embeddings' not in self.documents_df.columns:
|
189 |
+
self.documents_df['Embeddings'] = pd.Series(dtype='object')
|
190 |
+
|
191 |
except Exception as e:
|
192 |
logging.error(f"Error during RAG embedding precomputation: {e}", exc_info=True)
|
193 |
if 'Embeddings' not in self.documents_df.columns:
|
194 |
self.documents_df['Embeddings'] = pd.Series(dtype='object')
|
195 |
|
|
|
196 |
def _embed_fn(self, title: str, text: str) -> list[float]:
|
197 |
try:
|
198 |
+
# Check if genai.embed_content is available and not the dummy's
|
199 |
+
if not self.embeddings_generated or not hasattr(genai, 'embed_content') or (hasattr(genai.embed_content, '__func__') and genai.embed_content.__func__.__qualname__.startswith('genai.embed_content')):
|
200 |
+
logging.warning(f"genai.embed_content not available or using dummy. Returning zero vector for title: {title}")
|
201 |
+
return [0.0] * 768 # Default embedding size
|
202 |
|
203 |
embedding_result = genai.embed_content(
|
204 |
+
model=self.embedding_model_name, # Use the stored model name
|
205 |
content=text,
|
206 |
task_type="retrieval_document",
|
207 |
title=title
|
|
|
214 |
def _precompute_embeddings(self):
|
215 |
if 'Embeddings' not in self.documents_df.columns:
|
216 |
self.documents_df['Embeddings'] = pd.Series(dtype='object')
|
|
|
|
|
217 |
for index, row in self.documents_df.iterrows():
|
218 |
current_embedding = row['Embeddings']
|
219 |
+
is_valid_embedding = isinstance(current_embedding, list) and len(current_embedding) > 0 and sum(abs(x) for x in current_embedding) > 1e-6
|
|
|
220 |
if not is_valid_embedding:
|
221 |
self.documents_df.at[index, 'Embeddings'] = self._embed_fn(row['Title'], row['Text'])
|
222 |
+
logging.info("Embeddings precomputation finished (or skipped if dummy).")
|
|
|
223 |
|
224 |
def retrieve_relevant_info(self, query_text: str, top_k: int = 2) -> str:
|
225 |
+
# Check if embeddings were actually generated and if the real embed_content is available
|
226 |
+
if not self.embeddings_generated or not hasattr(genai, 'embed_content') or \
|
227 |
+
(hasattr(genai.embed_content, '__func__') and genai.embed_content.__func__.__qualname__.startswith('genai.embed_content')) or \
|
228 |
+
'Embeddings' not in self.documents_df.columns or self.documents_df['Embeddings'].isnull().all():
|
229 |
+
logging.warning("RAG System: Cannot retrieve info. Conditions not met (API key, embeddings, or real genai functions).")
|
230 |
return "\n[RAG Context]\nNo specific pre-defined context found (RAG system inactive or no embeddings).\n"
|
|
|
231 |
try:
|
232 |
query_embedding_result = genai.embed_content(
|
233 |
+
model=self.embedding_model_name, # Use the stored model name
|
234 |
content=query_text,
|
235 |
task_type="retrieval_query"
|
236 |
)
|
237 |
query_embedding = np.array(query_embedding_result["embedding"])
|
|
|
238 |
valid_embeddings_df = self.documents_df.dropna(subset=['Embeddings'])
|
239 |
valid_embeddings_df = valid_embeddings_df[valid_embeddings_df['Embeddings'].apply(lambda x: isinstance(x, list) and len(x) > 0 and sum(abs(val) for val in x) > 1e-6)]
|
|
|
|
|
240 |
if valid_embeddings_df.empty:
|
|
|
241 |
return "\n[RAG Context]\nNo valid document embeddings available for retrieval.\n"
|
|
|
242 |
document_embeddings = np.stack(valid_embeddings_df['Embeddings'].apply(np.array).values)
|
|
|
243 |
if query_embedding.shape[0] != document_embeddings.shape[1]:
|
|
|
244 |
return "\n[RAG Context]\nEmbedding dimension mismatch.\n"
|
|
|
245 |
dot_products = np.dot(document_embeddings, query_embedding)
|
|
|
|
|
246 |
num_available_docs = len(valid_embeddings_df)
|
247 |
actual_top_k = min(top_k, num_available_docs)
|
248 |
+
if actual_top_k == 0: return "\n[RAG Context]\nNo documents to retrieve from.\n"
|
249 |
+
idx = [np.argmax(dot_products)] if actual_top_k == 1 and num_available_docs > 0 else (np.argsort(dot_products)[-actual_top_k:][::-1] if num_available_docs > 0 else [])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
relevant_passages = ""
|
251 |
+
for i_val in idx:
|
252 |
+
passage_title = valid_embeddings_df.iloc[i_val]['Title']
|
253 |
+
passage_text = valid_embeddings_df.iloc[i_val]['Text']
|
254 |
relevant_passages += f"\n[RAG Context from: '{passage_title}']\n{passage_text}\n"
|
|
|
|
|
255 |
return relevant_passages if relevant_passages else "\n[RAG Context]\nNo highly relevant passages found.\n"
|
|
|
256 |
except Exception as e:
|
257 |
logging.error(f"Error retrieving relevant info from RAG: {e}", exc_info=True)
|
258 |
return f"\n[RAG Context]\nError during RAG retrieval: {str(e)}\n"
|
259 |
|
|
|
260 |
# --- PandasLLM Class (Gemini-Powered) ---
|
261 |
class PandasLLM:
|
262 |
+
def __init__(self, llm_model_name: str, generation_config_params: dict,
|
263 |
+
safety_settings: dict, # safety_settings might not be used by client.models.generate_content
|
264 |
+
data_privacy=True, force_sandbox=True):
|
265 |
self.llm_model_name = llm_model_name
|
266 |
self.generation_config_params = generation_config_params
|
267 |
+
self.safety_settings = safety_settings # Store it, might be usable
|
268 |
self.data_privacy = data_privacy
|
269 |
+
self.force_sandbox = force_sandbox
|
270 |
+
self.client = None
|
271 |
+
self.generative_model_service = None # To store client.models
|
272 |
|
273 |
if not GEMINI_API_KEY:
|
274 |
logging.warning("PandasLLM: GEMINI_API_KEY not set. LLM functionalities will be limited.")
|
|
|
275 |
else:
|
276 |
try:
|
277 |
+
# Global genai.configure should have been called already
|
278 |
+
# User's suggestion: client = genai.Client(api_key="GEMINI_API_KEY")
|
279 |
+
# If genai.configure was called, api_key might not be needed for genai.Client()
|
280 |
+
# However, to be safe and follow user's hint structure:
|
281 |
+
self.client = genai.Client(api_key=GEMINI_API_KEY)
|
282 |
+
|
283 |
+
if self.client and hasattr(self.client, 'models') and hasattr(self.client.models, 'generate_content'):
|
284 |
+
self.generative_model_service = self.client.models
|
285 |
+
logging.info(f"PandasLLM Initialized with genai.Client. Using client.models for '{self.llm_model_name}'.")
|
286 |
+
elif self.client and hasattr(self.client, 'generate_content'): # Fallback: client itself has generate_content
|
287 |
+
self.generative_model_service = self.client # Use client directly
|
288 |
+
logging.info(f"PandasLLM Initialized with genai.Client. Using client.generate_content for '{self.llm_model_name}'.")
|
289 |
+
else:
|
290 |
+
logging.warning(f"PandasLLM: genai.Client initialized, but suitable 'generate_content' method not found on client or client.models. LLM calls may fail.")
|
291 |
+
except AttributeError as ae: # Catch if genai.Client itself is missing (e.g. very old dummy or lib issue)
|
292 |
+
logging.error(f"Failed to initialize genai.Client: {ae}. The 'genai' module might be a dummy or library is missing/old.", exc_info=True)
|
293 |
except Exception as e:
|
294 |
+
logging.error(f"Failed to initialize PandasLLM with genai.Client: {e}", exc_info=True)
|
295 |
+
|
296 |
|
297 |
async def _call_gemini_api_async(self, prompt_text: str, history: list = None) -> str:
|
298 |
+
if not self.generative_model_service:
|
299 |
+
logging.error("PandasLLM: Generative model service (e.g., client.models or client) not initialized. Cannot call API.")
|
300 |
+
return "# Error: Gemini client or service not available. Check API key and library installation."
|
301 |
|
302 |
+
contents_for_api = []
|
303 |
+
if history:
|
|
|
304 |
for entry in history:
|
305 |
+
role = entry.get("role", "user")
|
306 |
+
if role == "assistant": role = "model"
|
307 |
+
contents_for_api.append({"role": role, "parts": [{"text": entry.get("content", "")}]})
|
308 |
+
contents_for_api.append({"role": "user", "parts": [{"text": prompt_text}]})
|
309 |
+
|
310 |
+
generation_config_to_pass = self.generation_config_params
|
311 |
+
# For client.models.generate_content or client.generate_content, safety_settings might be a direct param
|
312 |
+
# or part of generation_config. This depends on the specific client API.
|
313 |
+
# Assuming it might be a direct parameter based on some Google API styles.
|
314 |
+
safety_settings_to_pass = self.safety_settings
|
315 |
|
316 |
|
317 |
+
logging.info(f"\n--- Calling Gemini API via Client with prompt (first 500 chars of last message): ---\n{contents_for_api[-1]['parts'][0]['text'][:500]}...\n-------------------------------------------------------\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
318 |
|
319 |
try:
|
320 |
+
# Construct the model name string, usually 'models/model-name'
|
321 |
+
# self.llm_model_name is "gemini-2.0-flash", so "models/gemini-2.0-flash"
|
322 |
+
model_id_for_api = self.llm_model_name
|
323 |
+
if not model_id_for_api.startswith("models/"):
|
324 |
+
model_id_for_api = f"models/{model_id_for_api}"
|
325 |
+
|
326 |
+
|
327 |
+
# Try to call self.generative_model_service.generate_content
|
328 |
+
# This service could be client.models or client itself.
|
329 |
response = await asyncio.to_thread(
|
330 |
+
self.generative_model_service.generate_content,
|
331 |
+
model=model_id_for_api,
|
332 |
+
contents=contents_for_api,
|
333 |
+
generation_config=generation_config_to_pass,
|
334 |
+
safety_settings=safety_settings_to_pass
|
335 |
)
|
336 |
|
337 |
+
if hasattr(response, 'prompt_feedback') and response.prompt_feedback and response.prompt_feedback.block_reason:
|
338 |
reason = response.prompt_feedback.block_reason
|
339 |
reason_name = getattr(reason, 'name', str(reason))
|
340 |
logging.warning(f"Gemini API call blocked by prompt feedback: {reason_name}")
|
341 |
return f"# Error: Prompt blocked due to content policy: {reason_name}."
|
342 |
|
343 |
llm_output = ""
|
344 |
+
if hasattr(response, 'text') and response.text: # Common for newer SDK responses
|
345 |
llm_output = response.text
|
346 |
+
elif hasattr(response, 'candidates') and response.candidates:
|
347 |
candidate = response.candidates[0]
|
348 |
+
if hasattr(candidate, 'content') and candidate.content and hasattr(candidate.content, 'parts') and candidate.content.parts:
|
349 |
llm_output = "".join(part.text for part in candidate.content.parts if hasattr(part, 'text'))
|
350 |
|
351 |
+
if not llm_output and hasattr(candidate, 'finish_reason'):
|
352 |
finish_reason_val = candidate.finish_reason
|
353 |
+
finish_reason = getattr(finish_reason_val, 'name', str(finish_reason_val))
|
354 |
logging.warning(f"No text content in response candidate. Finish reason: {finish_reason}")
|
355 |
if finish_reason == "SAFETY":
|
356 |
return f"# Error: Response generation stopped due to safety reasons ({finish_reason})."
|
|
|
358 |
return f"# Error: Response generation stopped due to recitation policy ({finish_reason})."
|
359 |
return f"# Error: The AI model returned an empty response. Finish reason: {finish_reason}."
|
360 |
else:
|
361 |
+
logging.warning(f"Gemini API response structure not recognized or empty. Response: {response}")
|
362 |
return "# Error: The AI model returned an unexpected or empty response structure."
|
363 |
|
364 |
logging.info(f"--- Gemini API Response (first 300 chars): ---\n{llm_output[:300]}...\n--------------------------------------------------\n")
|
365 |
return llm_output
|
366 |
|
367 |
except AttributeError as ae:
|
368 |
+
logging.error(f"AttributeError during Gemini client call: {ae}. This might indicate the client object or 'models' attribute doesn't have 'generate_content' or is None.", exc_info=True)
|
369 |
+
return f"# Error (Attribute): {type(ae).__name__} - {ae}. Check client structure."
|
370 |
except Exception as e:
|
371 |
+
logging.error(f"Error calling Gemini API via Client: {e}", exc_info=True)
|
372 |
if "API_KEY_INVALID" in str(e) or "API key not valid" in str(e):
|
373 |
+
return "# Error: Gemini API key is not valid."
|
|
|
|
|
|
|
|
|
374 |
if "PermissionDenied" in str(e) or "403" in str(e):
|
375 |
+
return f"# Error: Permission denied for model '{model_id_for_api}' or service."
|
376 |
+
# Check for model not found specifically
|
377 |
+
if ("not found" in str(e).lower() or "does not exist" in str(e).lower()) and model_id_for_api in str(e):
|
378 |
+
return f"# Error: Model '{model_id_for_api}' not found or not accessible with your API key via client."
|
379 |
+
return f"# Error: An unexpected error occurred while contacting the AI model via Client: {type(e).__name__}."
|
380 |
|
381 |
|
382 |
async def query(self, prompt_with_query_and_context: str, dataframes_dict: dict, history: list = None) -> str:
|
383 |
llm_response_text = await self._call_gemini_api_async(prompt_with_query_and_context, history)
|
|
|
384 |
if self.force_sandbox:
|
385 |
code_to_execute = ""
|
386 |
if "```python" in llm_response_text:
|
|
|
394 |
except IndexError:
|
395 |
code_to_execute = ""
|
396 |
logging.warning("Could not extract Python code using primary or secondary split method.")
|
397 |
+
llm_response_text_for_sandbox_error = ""
|
|
|
398 |
if llm_response_text.startswith("# Error:") or not code_to_execute:
|
399 |
+
error_prefix = "LLM did not return valid Python code or an error occurred."
|
400 |
if llm_response_text.startswith("# Error:"): error_prefix = "An error occurred during LLM call."
|
401 |
elif not code_to_execute: error_prefix = "Could not extract Python code from LLM response."
|
|
|
402 |
safe_llm_response = str(llm_response_text).replace("'''", "'").replace('"""', '"')
|
403 |
llm_response_text_for_sandbox_error = f"print(f'''{error_prefix}\\nRaw LLM Response (may be truncated):\\n{safe_llm_response[:1000]}''')"
|
404 |
logging.warning(f"Problem with LLM response for sandbox: {error_prefix}")
|
|
|
405 |
logging.info(f"\n--- Code to Execute (from LLM, if sandbox): ---\n{code_to_execute}\n------------------------------------------------\n")
|
|
|
|
|
|
|
|
|
406 |
safe_builtins = {}
|
407 |
if isinstance(__builtins__, dict):
|
408 |
safe_builtins = {name: obj for name, obj in __builtins__.items() if not name.startswith('_')}
|
409 |
+
else:
|
410 |
safe_builtins = {name: obj for name, obj in __builtins__.__dict__.items() if not name.startswith('_')}
|
|
|
|
|
411 |
unsafe_builtins = ['eval', 'exec', 'open', 'compile', 'input', 'memoryview', 'vars', 'globals', 'locals', '__import__']
|
412 |
for ub in unsafe_builtins:
|
413 |
safe_builtins.pop(ub, None)
|
|
|
414 |
exec_globals = {'pd': pd, 'np': np, '__builtins__': safe_builtins}
|
415 |
for name, df_instance in dataframes_dict.items():
|
416 |
exec_globals[f"df_{name}"] = df_instance
|
|
|
417 |
from io import StringIO
|
418 |
import sys
|
419 |
old_stdout = sys.stdout
|
420 |
sys.stdout = captured_output = StringIO()
|
|
|
421 |
final_output_str = ""
|
422 |
try:
|
423 |
if code_to_execute:
|
424 |
exec(code_to_execute, exec_globals, {})
|
425 |
output_val = captured_output.getvalue()
|
426 |
final_output_str = output_val if output_val else "# Code executed successfully, but no explicit print() output was generated by the code."
|
|
|
427 |
else:
|
428 |
exec(llm_response_text_for_sandbox_error, exec_globals, {})
|
429 |
final_output_str = captured_output.getvalue()
|
|
|
|
|
430 |
except Exception as e:
|
431 |
error_msg = f"# Error executing LLM-generated code:\n# {type(e).__name__}: {str(e)}\n# --- Code that caused error: ---\n{textwrap.indent(code_to_execute, '# ')}"
|
432 |
final_output_str = error_msg
|
|
|
437 |
else:
|
438 |
return llm_response_text
|
439 |
|
|
|
440 |
# --- Employer Branding Agent ---
|
441 |
class EmployerBrandingAgent:
|
442 |
def __init__(self, llm_model_name: str, generation_config_params: dict, safety_settings: dict,
|
|
|
450 |
logging.info("EmployerBrandingAgent Initialized.")
|
451 |
|
452 |
def _build_prompt(self, user_query: str, role="Employer Branding Analyst", task_decomposition_hint=None, cot_hint=True) -> str:
|
453 |
+
prompt = f"You are a helpful and expert '{role}'...\n" # Truncated for brevity
|
454 |
+
# ... (rest of the prompt building logic remains the same)
|
455 |
+
prompt += "Your main task is to GENERATE PYTHON CODE using the Pandas library...\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
456 |
prompt += "\n--- AVAILABLE DATA AND SCHEMAS ---\n"
|
457 |
prompt += self.schemas_representation
|
|
|
458 |
rag_context = self.rag_system.retrieve_relevant_info(user_query)
|
459 |
if rag_context and "[RAG Context]" in rag_context and "No specific pre-defined context found" not in rag_context and "No highly relevant passages found" not in rag_context:
|
460 |
prompt += f"\n--- ADDITIONAL CONTEXT (from internal knowledge base, consider this information) ---\n{rag_context}\n"
|
|
|
461 |
prompt += f"\n--- USER QUERY ---\n{user_query}\n"
|
462 |
+
if self.pandas_llm.force_sandbox:
|
463 |
+
prompt += "\n--- INSTRUCTIONS FOR PYTHON CODE GENERATION (Chain of Thought) ---\n"
|
464 |
+
prompt += "1. Understand the query...\n"
|
465 |
+
prompt += "7. Generate ONLY the Python code block starting with ```python and ending with ```...\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
466 |
return prompt
|
467 |
|
468 |
async def process_query(self, user_query: str, role="Employer Branding Analyst", task_decomposition_hint=None, cot_hint=True) -> str:
|
469 |
+
logging.info(f"\n=== Processing Query for Role: {role}, Query: {user_query} ===")
|
|
|
|
|
470 |
self.chat_history.append({"role": "user", "content": user_query})
|
|
|
471 |
full_prompt = self._build_prompt(user_query, role, task_decomposition_hint, cot_hint)
|
|
|
472 |
response_text = await self.pandas_llm.query(full_prompt, self.all_dataframes, history=self.chat_history[:-1])
|
|
|
473 |
self.chat_history.append({"role": "assistant", "content": response_text})
|
|
|
474 |
MAX_HISTORY_TURNS = 5
|
475 |
if len(self.chat_history) > MAX_HISTORY_TURNS * 2:
|
476 |
self.chat_history = self.chat_history[-(MAX_HISTORY_TURNS * 2):]
|
|
|
477 |
return response_text
|
478 |
|
479 |
def update_dataframes(self, new_dataframes: dict):
|
|
|
480 |
self.all_dataframes = new_dataframes
|
481 |
self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
|
482 |
logging.info("EmployerBrandingAgent DataFrames updated.")
|