""" | |
Utility functions for LLM interactions within the auto_causal module. | |
""" | |
from typing import Dict, Any, Optional, List | |
# Assume pandas is available for get_columns_info and sample_data | |
import pandas as pd | |
import logging | |
import json # Ensure json is imported | |
# Added import for type hint | |
from langchain.chat_models.base import BaseChatModel | |
from langchain_core.messages import AIMessage # For type hinting llm.invoke response | |
logger = logging.getLogger(__name__) | |

# Placeholder for actual LLM calling logic
def call_llm_with_json_output(llm: Optional[BaseChatModel], prompt: str) -> Optional[Dict[str, Any]]:
    """
    Calls the provided LLM with a prompt, expecting a JSON object in the response.

    It parses the JSON string (after attempting to remove markdown fences)
    and returns it as a Python dictionary.

    Args:
        llm: An instance of BaseChatModel (e.g., from Langchain). If None,
            the function will log a warning and return None.
        prompt: The prompt string to send to the LLM.

    Returns:
        A dictionary parsed from the LLM's JSON response, or None if:
        - llm is None.
        - The LLM call fails.
        - The LLM response content cannot be extracted as a string.
        - The response content is empty after stripping markdown.
        - The response is not valid JSON.
        - The parsed JSON is not a dictionary.
    """
    if not llm:
        logger.warning("LLM client (BaseChatModel) not provided to call_llm_with_json_output. Cannot make LLM call.")
        return None

    logger.info(f"Attempting LLM call with {type(llm).__name__} for JSON output.")
    # Full prompt logging can be verbose, so use DEBUG level.
    logger.debug(f"LLM Prompt for JSON output:\n{prompt}")

    raw_response_content = ""  # For logging in case of errors before parsing
    processed_content_for_json = ""  # For logging in case of a JSON parsing error

    try:
        llm_response_obj = llm.invoke(prompt)

        # Extract string content from the LLM response object.
        if hasattr(llm_response_obj, 'content') and isinstance(llm_response_obj.content, str):
            raw_response_content = llm_response_obj.content
        elif isinstance(llm_response_obj, str):
            raw_response_content = llm_response_obj
        else:
            # Fallback for other potential response structures.
            logger.warning(
                f"LLM response is not a string and has no '.content' attribute of type string. "
                f"Type: {type(llm_response_obj)}. Trying '.text' attribute."
            )
            if hasattr(llm_response_obj, 'text') and isinstance(llm_response_obj.text, str):
                raw_response_content = llm_response_obj.text

        if not raw_response_content:
            logger.warning(f"LLM invocation returned no extractable string content. Response object type: {type(llm_response_obj)}")
            return None

        # Prepare content for JSON parsing: strip whitespace and markdown fences.
        # Uses the same stripping logic as llm_identify_temporal_and_unit_vars for consistency.
        processed_content_for_json = raw_response_content.strip()
        if processed_content_for_json.startswith("```json"):
            # Removes the "```json" prefix and "```" suffix, then strips whitespace.
            # Assumes the format is "```json\nCONTENT\n```" or similar.
            processed_content_for_json = processed_content_for_json[7:-3].strip()
        elif processed_content_for_json.startswith("```"):
            # Removes the generic "```" prefix and "```" suffix, then strips.
            processed_content_for_json = processed_content_for_json[3:-3].strip()
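        # Illustrative example (hypothetical content, not from a real run): a raw
        # response of '```json\n{"caliper": 0.02}\n```' is reduced by the slicing
        # above to '{"caliper": 0.02}' before json.loads is applied.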

        if not processed_content_for_json:  # Check if empty after stripping
            logger.warning(
                "LLM response content became empty after attempting to strip markdown. "
                f"Original raw content snippet: '{raw_response_content[:200]}...'"
            )
            return None

        parsed_json = json.loads(processed_content_for_json)

        if not isinstance(parsed_json, dict):
            logger.warning(
                "LLM response was successfully parsed as JSON, but it is not a dictionary. "
                f"Type: {type(parsed_json)}. Parsed content snippet: '{str(parsed_json)[:200]}...'"
            )
            return None

        logger.info(f"Successfully received and parsed JSON response from {type(llm).__name__}.")
        return parsed_json

    except json.JSONDecodeError as e:
        logger.error(
            f"Failed to decode JSON from LLM response. Error: {e}. "
            f"Content processed for parsing (snippet): '{processed_content_for_json[:500]}...'"
        )
        return None
    except Exception as e:
        # Catches errors from llm.invoke() or other unexpected issues.
        logger.error(f"An unexpected error occurred during LLM call or JSON processing: {e}", exc_info=True)
        # Log raw content if available and different from the processed content, for easier debugging.
        if raw_response_content and raw_response_content[:500] != processed_content_for_json[:500]:
            logger.debug(f"Original raw LLM response content (snippet): '{raw_response_content[:500]}...'")
        return None
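
# Example usage (a sketch, not part of the module API; the langchain_openai
# import and model name below are illustrative assumptions):
#
#   from langchain_openai import ChatOpenAI
#   llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
#   result = call_llm_with_json_output(llm, 'Return {"status": "ok"} as JSON.')
#   if result is not None:
#       print(result.get("status"))  # -> "ok"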

# Placeholder for processing the LLM response
def process_llm_response(response: Optional[Dict[str, Any]], method: str) -> Optional[Dict[str, Any]]:
    # Validate and structure the LLM response based on the method.
    # For now, just return the response unchanged (which may be None if the
    # upstream LLM call failed).
    return response
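
# A possible validation step for this placeholder (hypothetical, not wired in):
# check that the response matches the JSON structure requested in
# analyze_dataset_for_method below.
#
#   def _validate_method_response(response: Dict[str, Any]) -> bool:
#       return (isinstance(response.get("parameters"), dict)
#               and isinstance(response.get("validation"), dict))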

# Placeholder for getting column info
def get_columns_info(df: pd.DataFrame) -> Dict[str, str]:
    return {col: str(dtype) for col, dtype in df.dtypes.items()}
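
# Example (illustrative data): for df = pd.DataFrame({"id": [1], "score": [0.5]}),
# get_columns_info(df) returns {"id": "int64", "score": "float64"}.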

def analyze_dataset_for_method(df: pd.DataFrame, query: str, method: str) -> Optional[Dict[str, Any]]:
    """Use LLM to analyze a dataset and suggest appropriate method parameters.

    Args:
        df: Input DataFrame.
        query: User's causal query.
        method: The causal method being considered.

    Returns:
        Dictionary with suggested parameters and validation checks from the LLM,
        or None if the LLM call fails.
    """
    # Prepare prompt with dataset information
    columns_info = get_columns_info(df)
    try:
        # Attempt to get sample data safely
        sample_data = df.head(5).to_dict(orient='records')
    except Exception:
        sample_data = "Error retrieving sample data."

    # --- Revised Prompt ---
    prompt = f"""
Given the dataset with columns {columns_info} and the causal query "{query}",
suggest SENSIBLE INITIAL DEFAULT parameters for applying the {method} method.
Do NOT attempt complex optimization; provide common starting points.

The first 5 rows of data look like:
{sample_data}

Specifically for {method}:
- If PS.Matching:
  - For 'caliper': Suggest a common heuristic value like 0.01, 0.02, or 0.05 (this is relative to the std dev of the logit score, but just suggest the number). If unsure, suggest 0.02.
  - For 'n_neighbors': Suggest 1.
  - For 'propensity_model_type': Suggest 'logistic' unless the context strongly implies a more complex model is needed.
- If PS.Weighting:
  - For 'weight_type': Suggest 'ATE' unless the query specifically asks for ATT or ATC.
  - For 'trim_threshold': Suggest a small value like 0.01 or 0.05 if the data seems noisy or has extreme propensity scores; otherwise suggest null (no trimming). Default to null if unsure.
- Add other parameters if relevant for the specific method.

Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
{{
    "parameters": {{
        // method-specific parameters based on the guidelines above
    }},
    "validation": {{
        // validation checks typically needed (e.g., check_balance: true for PSM)
    }}
}}
"""
    # --- End Revised Prompt ---

    # Call LLM with the prompt. This function does not yet receive an llm
    # instance, so the call below passes None and will return None; update the
    # signature to accept a BaseChatModel when wiring this in.
    response = call_llm_with_json_output(None, prompt)  # Passing None for llm temporarily

    # Process and validate the response. This step might involve ensuring the
    # structure is correct, parameter types are valid, etc.
    processed_response = process_llm_response(response, method)

    return processed_response
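
# Example usage (a sketch; the data and method name are illustrative):
#
#   df = pd.DataFrame({"treated": [0, 1, 0, 1], "outcome": [2.1, 3.4, 1.9, 3.8]})
#   suggestions = analyze_dataset_for_method(
#       df, "Does the program raise outcomes?", "PS.Matching"
#   )
#   # With the placeholder llm=None, this currently returns None; once an LLM is
#   # wired in, expect a dict like {"parameters": {...}, "validation": {...}}.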

def llm_identify_temporal_and_unit_vars(
    column_names: List[str],
    column_dtypes: Dict[str, str],
    dataset_description: str,
    dataset_summary: str,
    heuristic_time_candidates: Optional[List[str]] = None,  # No longer used in the revised prompt
    heuristic_id_candidates: Optional[List[str]] = None,  # No longer used in the revised prompt
    query: str = "No query provided.",
    llm: Optional[BaseChatModel] = None
) -> Dict[str, Optional[str]]:
    """Uses an LLM to identify the primary time and unit (entity) variables.

    Args:
        column_names: List of all column names.
        column_dtypes: Dictionary mapping column names to string representations of data types.
        dataset_description: Textual description of the dataset.
        dataset_summary: Summary of the dataset (currently unused by the prompt).
        heuristic_time_candidates: Optional list of columns identified as time vars by heuristics (currently unused by the prompt).
        heuristic_id_candidates: Optional list of columns identified as unit ID vars by heuristics (currently unused by the prompt).
        query: The user's causal query.
        llm: The language model client instance.

    Returns:
        A dictionary with keys 'time_variable' and 'unit_variable',
        whose values are the identified column names or None.
    """
    if not llm:
        logger.warning("LLM client not provided for temporal/unit identification. Returning None.")
        return {"time_variable": None, "unit_variable": None}

    logger.info("Attempting LLM identification of time and unit variables...")

    # Construct the prompt (revised based on user feedback in conversation)
    prompt = f"""
You are a data analysis expert tasked with determining whether a dataset supports a Difference-in-Differences (DiD) or Two-Way Fixed Effects (TWFE) design to answer the following query:
{query}

You are given the following information:

Dataset Description:
{dataset_description}

Columns and Data Types:
{column_dtypes}

First, based on the above information, check if any columns represent information about the time/periods associated directly with intervention application. It could be either:

1. A variable that represents **time periods associated with the intervention**. This must satisfy one of the following:
   - A binary indicator showing pre/post-intervention status,
   - A discrete or continuous variable that records **when units were observed**, which can be aligned with treatment application periods.
   Do **not** select generic time-related variables that merely describe time as a feature, such as **'date of birth'**, **'year of graduation'**, 'week of sign-up', or **'years of schooling'**, unless they directly represent **observation times relevant to treatment**.

2. A variable that represents the **unit of observation** (e.g., individual, region, school) — the entity over which we compare treated vs. untreated groups across time.

Return ONLY a valid JSON object with this structure and no surrounding explanation:
{{
    "time_variable": "<column_name_or_null>",
    "unit_variable": "<column_name_or_null>"
}}
"""
    parsed_response = None
    response_content = ""  # Initialized before the try block so error logging below is always safe
    try:
        llm_response_obj = llm.invoke(prompt)
        if hasattr(llm_response_obj, 'content'):
            response_content = llm_response_obj.content
        elif isinstance(llm_response_obj, str):  # Some LLMs might return a str directly
            response_content = llm_response_obj
        else:
            logger.warning(f"LLM response object type not recognized for content extraction: {type(llm_response_obj)}")

        if response_content:
            # Attempt to strip markdown ```json ... ``` fences if present
            if response_content.strip().startswith("```json"):
                response_content = response_content.strip()[7:-3].strip()
            elif response_content.strip().startswith("```"):
                response_content = response_content.strip()[3:-3].strip()
            parsed_response = json.loads(response_content)
        else:
            logger.warning("LLM invocation returned no content.")
    except json.JSONDecodeError as e:
        logger.error(f"Failed to decode JSON from LLM response for time/unit vars: {e}. Response content: '{response_content[:500]}...'")  # Log a snippet
    except Exception as e:
        logger.error(f"Error during LLM invocation or processing for time/unit vars: {e}", exc_info=True)

    # Process the response
    if parsed_response and isinstance(parsed_response, dict):
        time_var = parsed_response.get("time_variable")
        unit_var = parsed_response.get("unit_variable")

        # Basic validation: ensure returned names are actual columns or None
        if time_var is not None and time_var not in column_names:
            logger.warning(f"LLM identified time variable '{time_var}' not found in columns. Setting to None.")
            time_var = None
        if unit_var is not None and unit_var not in column_names:
            logger.warning(f"LLM identified unit variable '{unit_var}' not found in columns. Setting to None.")
            unit_var = None

        logger.info(f"LLM identified time='{time_var}', unit='{unit_var}'")
        return {"time_variable": time_var, "unit_variable": unit_var}
    else:
        logger.warning("LLM call failed or returned invalid/unparsable JSON for time/unit identification.")
        return {"time_variable": None, "unit_variable": None}