"""
Utility functions for LLM interactions within the auto_causal module.
"""
from typing import Dict, Any, Optional, List
# Assume pandas is available for get_columns_info and sample_data
import pandas as pd
import logging
import json
from langchain.chat_models.base import BaseChatModel
from langchain_core.messages import AIMessage # For type hinting llm.invoke response
logger = logging.getLogger(__name__)
# Generic helper for LLM calls that must return a JSON object.
def call_llm_with_json_output(llm: Optional[BaseChatModel], prompt: str) -> Optional[Dict[str, Any]]:
"""
Calls the provided LLM with a prompt, expecting a JSON object in the response.
It parses the JSON string (after attempting to remove markdown fences)
and returns it as a Python dictionary.
Args:
llm: An instance of BaseChatModel (e.g., from Langchain). If None,
the function will log a warning and return None.
prompt: The prompt string to send to the LLM.
Returns:
A dictionary parsed from the LLM's JSON response, or None if:
- llm is None.
- The LLM call fails.
- The LLM response content cannot be extracted as a string.
- The response content is empty after stripping markdown.
- The response is not valid JSON.
- The parsed JSON is not a dictionary.
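    Example (illustrative sketch, not part of the module's tested surface;
    assumes the optional langchain_openai package and a configured API key,
    though any BaseChatModel instance works):
        >>> from langchain_openai import ChatOpenAI
        >>> llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
        >>> result = call_llm_with_json_output(llm, 'Reply with JSON only: {"ok": true}')
        >>> result  # expected (model-dependent): {'ok': True}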
"""
if not llm:
logger.warning("LLM client (BaseChatModel) not provided to call_llm_with_json_output. Cannot make LLM call.")
return None
logger.info(f"Attempting LLM call with {type(llm).__name__} for JSON output.")
# Full prompt logging can be verbose, using DEBUG level.
logger.debug(f"LLM Prompt for JSON output:\\n{prompt}")
raw_response_content = "" # For logging in case of errors before parsing
processed_content_for_json = "" # For logging in case of JSON parsing error
try:
llm_response_obj = llm.invoke(prompt)
# Extract string content from LLM response object
if hasattr(llm_response_obj, 'content') and isinstance(llm_response_obj.content, str):
raw_response_content = llm_response_obj.content
elif isinstance(llm_response_obj, str):
raw_response_content = llm_response_obj
else:
# Fallback for other potential response structures
logger.warning(
f"LLM response is not a string and has no '.content' attribute of type string. "
f"Type: {type(llm_response_obj)}. Trying '.text' attribute."
)
if hasattr(llm_response_obj, 'text') and isinstance(llm_response_obj.text, str):
raw_response_content = llm_response_obj.text
if not raw_response_content:
logger.warning(f"LLM invocation returned no extractable string content. Response object type: {type(llm_response_obj)}")
return None
# Prepare content for JSON parsing: strip whitespace and markdown fences.
# Using the same stripping logic as in llm_identify_temporal_and_unit_vars for consistency.
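        # e.g. '```json\n{"caliper": 0.02}\n```'  ->  '{"caliper": 0.02}'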
        processed_content_for_json = raw_response_content.strip()
        if processed_content_for_json.startswith("```json"):
            processed_content_for_json = processed_content_for_json[len("```json"):]
        elif processed_content_for_json.startswith("```"):
            processed_content_for_json = processed_content_for_json[len("```"):]
        if processed_content_for_json.endswith("```"):
            # Drop the closing fence only when it is actually present, so that
            # fence-less responses are not silently truncated.
            processed_content_for_json = processed_content_for_json[:-3]
        processed_content_for_json = processed_content_for_json.strip()
if not processed_content_for_json: # Check if empty after stripping
logger.warning(
"LLM response content became empty after attempting to strip markdown. "
f"Original raw content snippet: '{raw_response_content[:200]}...'"
)
return None
parsed_json = json.loads(processed_content_for_json)
if not isinstance(parsed_json, dict):
logger.warning(
"LLM response was successfully parsed as JSON, but it is not a dictionary. "
f"Type: {type(parsed_json)}. Parsed content snippet: '{str(parsed_json)[:200]}...'"
)
return None
logger.info(f"Successfully received and parsed JSON response from {type(llm).__name__}.")
return parsed_json
except json.JSONDecodeError as e:
logger.error(
f"Failed to decode JSON from LLM response. Error: {e}. "
f"Content processed for parsing (snippet): '{processed_content_for_json[:500]}...'"
)
return None
except Exception as e:
# This catches errors from llm.invoke() or other unexpected issues.
logger.error(f"An unexpected error occurred during LLM call or JSON processing: {e}", exc_info=True)
# Log raw content if available and different from processed, for better debugging
if raw_response_content and raw_response_content[:500] != processed_content_for_json[:500]:
logger.debug(f"Original raw LLM response content (snippet): '{raw_response_content[:500]}...'")
return None
# Placeholder for validating/structuring the LLM response per method.
def process_llm_response(response: Optional[Dict[str, Any]], method: str) -> Dict[str, Any]:
    # TODO: validate the structure and parameter types for the given method.
    # For now, pass the response through, substituting an empty dict when the
    # LLM call failed and returned None.
    return response or {}
def get_columns_info(df: pd.DataFrame) -> Dict[str, str]:
    """Map each column name to the string form of its pandas dtype."""
    return {col: str(dtype) for col, dtype in df.dtypes.items()}
def analyze_dataset_for_method(df: pd.DataFrame, query: str, method: str) -> Dict[str, Any]:
"""Use LLM to analyze dataset for appropriate method parameters.
Args:
df: Input DataFrame
query: User's causal query
method: The causal method being considered
Returns:
Dictionary with suggested parameters and validation checks from LLM.
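    Example (illustrative sketch; with no llm wired in yet, the call
    currently yields an empty dict):
        >>> toy = pd.DataFrame({"treated": [0, 1, 0, 1], "outcome": [1.0, 2.5, 1.1, 2.7]})
        >>> analyze_dataset_for_method(toy, "Does treatment raise outcomes?", "PS.Matching")
        {}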
"""
# Prepare prompt with dataset information
columns_info = get_columns_info(df)
try:
# Attempt to get sample data safely
sample_data = df.head(5).to_dict(orient='records')
except Exception:
sample_data = "Error retrieving sample data."
    # --- Prompt ---
prompt = f"""
Given the dataset with columns {columns_info} and the causal query "{query}",
suggest SENSIBLE INITIAL DEFAULT parameters for applying the {method} method.
Do NOT attempt complex optimization; provide common starting points.
The first 5 rows of data look like:
{sample_data}
Specifically for {method}:
- If PS.Matching:
        - For 'caliper': Suggest a common heuristic value like 0.01, 0.02, or 0.05 (interpreted relative to the standard deviation of the logit of the propensity score; just suggest the number). If unsure, suggest 0.02.
- For 'n_neighbors': Suggest 1.
- For 'propensity_model_type': Suggest 'logistic' unless the context strongly implies a more complex model is needed.
- If PS.Weighting:
- For 'weight_type': Suggest 'ATE' unless the query specifically asks for ATT or ATC.
- For 'trim_threshold': Suggest a small value like 0.01 or 0.05 if the data seems noisy or has extreme propensity scores, otherwise suggest null (no trimming). Default to null if unsure.
- Add other parameters if relevant for the specific method.
Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
{{
"parameters": {{
// method-specific parameters based on the guidelines above
}},
"validation": {{
// validation checks typically needed (e.g., check_balance: true for PSM)
}}
}}
"""
    # --- End Prompt ---
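    # Illustrative well-formed response for PS.Matching, per the guidelines above:
    # {"parameters": {"caliper": 0.02, "n_neighbors": 1, "propensity_model_type": "logistic"},
    #  "validation": {"check_balance": true}}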
    # NOTE: this function does not yet accept an llm instance, so None is
    # passed below; call_llm_with_json_output will log a warning and return
    # None until an llm parameter is threaded through from callers.
    response = call_llm_with_json_output(None, prompt)
# Process and validate response
# This step might involve ensuring the structure is correct,
# parameters are valid types, etc.
processed_response = process_llm_response(response, method)
return processed_response
def llm_identify_temporal_and_unit_vars(
column_names: List[str],
column_dtypes: Dict[str, str],
dataset_description: str,
dataset_summary: str,
    heuristic_time_candidates: Optional[List[str]] = None,  # Unused by the current prompt
    heuristic_id_candidates: Optional[List[str]] = None,  # Unused by the current prompt
query: str = "No query provided.",
llm: Optional[BaseChatModel] = None
) -> Dict[str, Optional[str]]:
"""Uses LLM to identify the primary time:
Args:
column_names: List of all column names.
column_dtypes: Dictionary mapping column names to string representation of data types.
dataset_description: Textual description of the dataset.
dataset_summary: Summary of the dataset
heuristic_time_candidates: Optional list of columns identified as time vars by heuristics (currently unused by prompt).
        heuristic_id_candidates: Optional list of columns identified as unit ID vars by heuristics (currently unused by prompt).
        query: The user's causal query, quoted verbatim in the prompt.
        llm: The language model client instance.
Returns:
A dictionary with keys 'time_variable' and 'unit_variable',
whose values are the identified column names or None.
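    Example (illustrative sketch; with the default llm=None the function
    short-circuits without calling a model):
        >>> llm_identify_temporal_and_unit_vars(
        ...     column_names=["state", "year", "outcome"],
        ...     column_dtypes={"state": "object", "year": "int64", "outcome": "float64"},
        ...     dataset_description="State-level panel of policy outcomes.",
        ...     dataset_summary="50 states observed annually, 2000-2010.",
        ...     query="Did the policy raise outcomes?",
        ... )
        {'time_variable': None, 'unit_variable': None}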
"""
if not llm:
logger.warning("LLM client not provided for temporal/unit identification. Returning None.")
return {"time_variable": None, "unit_variable": None}
logger.info("Attempting LLM identification of time and unit variables...")
    # Construct the prompt.
prompt = f"""
You are a data analysis expert tasked with determining whether a dataset supports a Difference-in-Differences (DiD) or Two-Way Fixed Effects (TWFE) design to answer the following query:
{query}
You are given the following information:
Dataset Description:
{dataset_description}
Columns and Data Types:
{column_dtypes}
    First, based on the above information, check whether any columns play either of the following two roles:
1. A variable that represents **time periods associated with the intervention**. This must satisfy one of the following:
- A binary indicator showing pre/post-intervention status,
- A discrete or continuous variable that records **when units were observed**, which can be aligned with treatment application periods.
Do **not** select generic time-related variables that merely describe time as a feature, such as **'date of birth'**, **'year of graduation'**, 'week of sign-up', **'years of schooling'** unless they directly represent **observation times relevant to treatment**.
2. A variable that represents the **unit of observation** (e.g., individual, region, school) — the entity over which we compare treated vs. untreated groups across time.
Return ONLY a valid JSON object with this structure and no surrounding explanation:
{{
"time_variable": "<column_name_or_null>",
"unit_variable": "<column_name_or_null>"
}}
"""
    parsed_response = None
    response_content = ""  # Defined before the try so except handlers can log it safely.
    try:
        llm_response_obj = llm.invoke(prompt)
if hasattr(llm_response_obj, 'content'):
response_content = llm_response_obj.content
elif isinstance(llm_response_obj, str): # Some LLMs might return str directly
response_content = llm_response_obj
else:
logger.warning(f"LLM response object type not recognized for content extraction: {type(llm_response_obj)}")
        if response_content:
            # Strip markdown fences (```json ... ``` or ``` ... ```) if present,
            # tolerating responses that omit the closing fence.
            response_content = response_content.strip()
            if response_content.startswith("```json"):
                response_content = response_content[len("```json"):]
            elif response_content.startswith("```"):
                response_content = response_content[len("```"):]
            if response_content.endswith("```"):
                response_content = response_content[:-3]
            response_content = response_content.strip()
            parsed_response = json.loads(response_content)
else:
logger.warning("LLM invocation returned no content.")
except json.JSONDecodeError as e:
logger.error(f"Failed to decode JSON from LLM response for time/unit vars: {e}. Response content: '{response_content[:500]}...'") # Log snippet
except Exception as e:
logger.error(f"Error during LLM invocation or processing for time/unit vars: {e}", exc_info=True)
# Process the response
if parsed_response and isinstance(parsed_response, dict):
time_var = parsed_response.get("time_variable")
unit_var = parsed_response.get("unit_variable")
# Basic validation: ensure returned names are actual columns or None
if time_var is not None and time_var not in column_names:
logger.warning(f"LLM identified time variable '{time_var}' not found in columns. Setting to None.")
time_var = None
if unit_var is not None and unit_var not in column_names:
logger.warning(f"LLM identified unit variable '{unit_var}' not found in columns. Setting to None.")
unit_var = None
logger.info(f"LLM identified time='{time_var}', unit='{unit_var}'")
return {"time_variable": time_var, "unit_variable": unit_var}
else:
logger.warning("LLM call failed or returned invalid/unparsable JSON for time/unit identification.")
return {"time_variable": None, "unit_variable": None} |