""" Utility functions for LLM interactions within the auto_causal module. """ from typing import Dict, Any, Optional, List # Assume pandas is available for get_columns_info and sample_data import pandas as pd import logging import json # Ensure json is imported # Added import for type hint from langchain.chat_models.base import BaseChatModel from langchain_core.messages import AIMessage # For type hinting llm.invoke response logger = logging.getLogger(__name__) # Placeholder for actual LLM calling logic def call_llm_with_json_output(llm: Optional[BaseChatModel], prompt: str) -> Optional[Dict[str, Any]]: """ Calls the provided LLM with a prompt, expecting a JSON object in the response. It parses the JSON string (after attempting to remove markdown fences) and returns it as a Python dictionary. Args: llm: An instance of BaseChatModel (e.g., from Langchain). If None, the function will log a warning and return None. prompt: The prompt string to send to the LLM. Returns: A dictionary parsed from the LLM's JSON response, or None if: - llm is None. - The LLM call fails. - The LLM response content cannot be extracted as a string. - The response content is empty after stripping markdown. - The response is not valid JSON. - The parsed JSON is not a dictionary. """ if not llm: logger.warning("LLM client (BaseChatModel) not provided to call_llm_with_json_output. Cannot make LLM call.") return None logger.info(f"Attempting LLM call with {type(llm).__name__} for JSON output.") # Full prompt logging can be verbose, using DEBUG level. logger.debug(f"LLM Prompt for JSON output:\\n{prompt}") raw_response_content = "" # For logging in case of errors before parsing processed_content_for_json = "" # For logging in case of JSON parsing error try: llm_response_obj = llm.invoke(prompt) # Extract string content from LLM response object if hasattr(llm_response_obj, 'content') and isinstance(llm_response_obj.content, str): raw_response_content = llm_response_obj.content elif isinstance(llm_response_obj, str): raw_response_content = llm_response_obj else: # Fallback for other potential response structures logger.warning( f"LLM response is not a string and has no '.content' attribute of type string. " f"Type: {type(llm_response_obj)}. Trying '.text' attribute." ) if hasattr(llm_response_obj, 'text') and isinstance(llm_response_obj.text, str): raw_response_content = llm_response_obj.text if not raw_response_content: logger.warning(f"LLM invocation returned no extractable string content. Response object type: {type(llm_response_obj)}") return None # Prepare content for JSON parsing: strip whitespace and markdown fences. # Using the same stripping logic as in llm_identify_temporal_and_unit_vars for consistency. processed_content_for_json = raw_response_content.strip() if processed_content_for_json.startswith("```json"): # Removes "```json" prefix and "```" suffix, then strips whitespace. # Assumes the format is "```json\\nCONTENT\\n```" or similar. processed_content_for_json = processed_content_for_json[7:-3].strip() elif processed_content_for_json.startswith("```"): # Removes generic "```" prefix and "```" suffix, then strips. processed_content_for_json = processed_content_for_json[3:-3].strip() if not processed_content_for_json: # Check if empty after stripping logger.warning( "LLM response content became empty after attempting to strip markdown. " f"Original raw content snippet: '{raw_response_content[:200]}...'" ) return None parsed_json = json.loads(processed_content_for_json) if not isinstance(parsed_json, dict): logger.warning( "LLM response was successfully parsed as JSON, but it is not a dictionary. " f"Type: {type(parsed_json)}. Parsed content snippet: '{str(parsed_json)[:200]}...'" ) return None logger.info(f"Successfully received and parsed JSON response from {type(llm).__name__}.") return parsed_json except json.JSONDecodeError as e: logger.error( f"Failed to decode JSON from LLM response. Error: {e}. " f"Content processed for parsing (snippet): '{processed_content_for_json[:500]}...'" ) return None except Exception as e: # This catches errors from llm.invoke() or other unexpected issues. logger.error(f"An unexpected error occurred during LLM call or JSON processing: {e}", exc_info=True) # Log raw content if available and different from processed, for better debugging if raw_response_content and raw_response_content[:500] != processed_content_for_json[:500]: logger.debug(f"Original raw LLM response content (snippet): '{raw_response_content[:500]}...'") return None # Placeholder for processing LLM response def process_llm_response(response: Dict[str, Any], method: str) -> Dict[str, Any]: # Validate and structure the LLM response based on the method # For now, just return the response return response # Placeholder for getting column info def get_columns_info(df: pd.DataFrame) -> Dict[str, str]: return {col: str(dtype) for col, dtype in df.dtypes.items()} def analyze_dataset_for_method(df: pd.DataFrame, query: str, method: str) -> Dict[str, Any]: """Use LLM to analyze dataset for appropriate method parameters. Args: df: Input DataFrame query: User's causal query method: The causal method being considered Returns: Dictionary with suggested parameters and validation checks from LLM. """ # Prepare prompt with dataset information columns_info = get_columns_info(df) try: # Attempt to get sample data safely sample_data = df.head(5).to_dict(orient='records') except Exception: sample_data = "Error retrieving sample data." # --- Revised Prompt --- prompt = f""" Given the dataset with columns {columns_info} and the causal query "{query}", suggest SENSIBLE INITIAL DEFAULT parameters for applying the {method} method. Do NOT attempt complex optimization; provide common starting points. The first 5 rows of data look like: {sample_data} Specifically for {method}: - If PS.Matching: - For 'caliper': Suggest a common heuristic value like 0.01, 0.02, or 0.05 (this is relative to std dev of logit score, but just suggest the number). If unsure, suggest 0.02. - For 'n_neighbors': Suggest 1. - For 'propensity_model_type': Suggest 'logistic' unless the context strongly implies a more complex model is needed. - If PS.Weighting: - For 'weight_type': Suggest 'ATE' unless the query specifically asks for ATT or ATC. - For 'trim_threshold': Suggest a small value like 0.01 or 0.05 if the data seems noisy or has extreme propensity scores, otherwise suggest null (no trimming). Default to null if unsure. - Add other parameters if relevant for the specific method. Return ONLY a valid JSON object with the following structure (no explanations or surrounding text): {{ "parameters": {{ // method-specific parameters based on the guidelines above }}, "validation": {{ // validation checks typically needed (e.g., check_balance: true for PSM) }} }} """ # --- End Revised Prompt --- # Call LLM with prompt - Assuming analyze_dataset_for_method provides the llm object # For now, this internal call still uses the placeholder without passing llm # This needs to be updated if analyze_dataset_for_method is intended to use a passed llm response = call_llm_with_json_output(None, prompt) # Passing None for llm temporarily # Process and validate response # This step might involve ensuring the structure is correct, # parameters are valid types, etc. processed_response = process_llm_response(response, method) return processed_response def llm_identify_temporal_and_unit_vars( column_names: List[str], column_dtypes: Dict[str, str], dataset_description: str, dataset_summary: str, heuristic_time_candidates: Optional[List[str]] = None, # These are no longer used in the revised prompt heuristic_id_candidates: Optional[List[str]] = None, # These are no longer used in the revised prompt query: str = "No query provided.", llm: Optional[BaseChatModel] = None ) -> Dict[str, Optional[str]]: """Uses LLM to identify the primary time: Args: column_names: List of all column names. column_dtypes: Dictionary mapping column names to string representation of data types. dataset_description: Textual description of the dataset. dataset_summary: Summary of the dataset heuristic_time_candidates: Optional list of columns identified as time vars by heuristics (currently unused by prompt). heuristic_id_candidates: Optional list of columns identified as unit ID vars by heuristics (currently unused by prompt). llm: The language model client instance. Returns: A dictionary with keys 'time_variable' and 'unit_variable', whose values are the identified column names or None. """ if not llm: logger.warning("LLM client not provided for temporal/unit identification. Returning None.") return {"time_variable": None, "unit_variable": None} logger.info("Attempting LLM identification of time and unit variables...") # Construct the prompt (revised based on user feedback in conversation) prompt = f""" You are a data analysis expert tasked with determining whether a dataset supports a Difference-in-Differences (DiD) or Two-Way Fixed Effects (TWFE) design to answer the following query: {query} You are given the following information: Dataset Description: {dataset_description} Columns and Data Types: {column_dtypes} First, based on the above information, check if any columns represent information about the time/periods associated directly with intervention application. It could be either: 1. A variable that represents **time periods associated with the intervention**. This must satisfy one of the following: - A binary indicator showing pre/post-intervention status, - A discrete or continuous variable that records **when units were observed**, which can be aligned with treatment application periods. Do **not** select generic time-related variables that merely describe time as a feature, such as **'date of birth'**, **'year of graduation'**, 'week of sign-up', **'years of schooling'** unless they directly represent **observation times relevant to treatment**. 2. A variable that represents the **unit of observation** (e.g., individual, region, school) — the entity over which we compare treated vs. untreated groups across time. Return ONLY a valid JSON object with this structure and no surrounding explanation: {{ "time_variable": "", "unit_variable": "" }} """ parsed_response = None try: llm_response_obj = llm.invoke(prompt) response_content = "" if hasattr(llm_response_obj, 'content'): response_content = llm_response_obj.content elif isinstance(llm_response_obj, str): # Some LLMs might return str directly response_content = llm_response_obj else: logger.warning(f"LLM response object type not recognized for content extraction: {type(llm_response_obj)}") if response_content: # Attempt to strip markdown ```json ... ``` if present if response_content.strip().startswith("```json"): response_content = response_content.strip()[7:-3].strip() elif response_content.strip().startswith("```"): response_content = response_content.strip()[3:-3].strip() parsed_response = json.loads(response_content) else: logger.warning("LLM invocation returned no content.") except json.JSONDecodeError as e: logger.error(f"Failed to decode JSON from LLM response for time/unit vars: {e}. Response content: '{response_content[:500]}...'") # Log snippet except Exception as e: logger.error(f"Error during LLM invocation or processing for time/unit vars: {e}", exc_info=True) # Process the response if parsed_response and isinstance(parsed_response, dict): time_var = parsed_response.get("time_variable") unit_var = parsed_response.get("unit_variable") # Basic validation: ensure returned names are actual columns or None if time_var is not None and time_var not in column_names: logger.warning(f"LLM identified time variable '{time_var}' not found in columns. Setting to None.") time_var = None if unit_var is not None and unit_var not in column_names: logger.warning(f"LLM identified unit variable '{unit_var}' not found in columns. Setting to None.") unit_var = None logger.info(f"LLM identified time='{time_var}', unit='{unit_var}'") return {"time_variable": time_var, "unit_variable": unit_var} else: logger.warning("LLM call failed or returned invalid/unparsable JSON for time/unit identification.") return {"time_variable": None, "unit_variable": None}