# FireShadow — Initial clean commit (1721aea)
"""
Input parser component for extracting information from causal queries.
This module provides functionality to parse user queries and extract key
elements such as the causal question, relevant variables, and constraints.
"""
import re
import os
import json
import logging # Added for better logging
from typing import Dict, List, Any, Optional, Union
import pandas as pd
from pydantic import BaseModel, Field, ValidationError
from functools import partial # Import partial
# Add dotenv import
from dotenv import load_dotenv
# LangChain Imports
from langchain_openai import ChatOpenAI # Example, replace if using another provider
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.exceptions import OutputParserException # Correct path
from langchain_core.language_models import BaseChatModel # Import BaseChatModel
# --- Load .env file ---
load_dotenv()  # Load environment variables (e.g. OPENAI_API_KEY) from a local .env file
# --- Configure Logging ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- Instantiate LLM Client ---
# Ensure OPENAI_API_KEY environment variable is set
# Consider making model name configurable
try:
    # Using with_structured_output later, so instantiate base model here
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    # Add a check or allow configuration for different providers if needed
except ImportError:
    # NOTE(review): ChatOpenAI is imported at module top, so this call is
    # unlikely to raise ImportError itself — confirm whether this guard fires.
    logger.error("langchain_openai not installed. Please install it to use OpenAI models.")
    llm = None
except Exception as e:
    # On any init failure, leave llm as None; downstream functions check for
    # a None client and degrade to regex-based fallbacks.
    logger.error(f"Error initializing LLM: {e}. Input parsing will rely on fallbacks.")
    llm = None
# --- Pydantic Models for Structured Output ---
class ParsedVariables(BaseModel):
    """Variable roles the LLM extracts from a causal query.

    All fields default to empty lists so partially-specified queries still
    parse; the Field descriptions are part of the schema sent to the LLM.
    """
    treatment: List[str] = Field(default_factory=list, description="Variable(s) representing the treatment/intervention.")
    outcome: List[str] = Field(default_factory=list, description="Variable(s) representing the outcome/result.")
    covariates_mentioned: Optional[List[str]] = Field(default_factory=list, description="Covariate/control variable(s) explicitly mentioned in the query.")
    grouping_vars: Optional[List[str]] = Field(default_factory=list, description="Variable(s) identifying groups or units for analysis.")
    instruments_mentioned: Optional[List[str]] = Field(default_factory=list, description="Potential instrumental variable(s) mentioned.")
class ParsedQueryInfo(BaseModel):
    """Top-level structured output for LLM query parsing.

    Bound to the LLM via `with_structured_output`; query_type and variables
    are required, constraints and the mentioned dataset path are optional.
    """
    query_type: str = Field(..., description="Type of query (e.g., EFFECT_ESTIMATION, COUNTERFACTUAL, CORRELATION, DESCRIPTIVE, OTHER). Required.")
    variables: ParsedVariables = Field(..., description="Variables identified in the query.")
    constraints: Optional[List[str]] = Field(default_factory=list, description="Constraints or conditions mentioned (e.g., 'X > 10', 'country = USA').")
    dataset_path_mentioned: Optional[str] = Field(None, description="Dataset path explicitly mentioned in the query, if any.")
# Add Pydantic model for path extraction
class ExtractedPath(BaseModel):
    """Minimal schema for the dedicated LLM path-extraction fallback."""
    dataset_path: Optional[str] = Field(None, description="File path or URL for the dataset mentioned in the query.")
# --- End Pydantic Models ---
def _build_llm_prompt(query: str, dataset_info: Optional[Dict] = None) -> str:
    """Builds the prompt for the LLM to extract query information.

    Args:
        query: The user's causal question text.
        dataset_info: Optional dataset context; reads keys 'columns',
            'column_types', and 'sample_rows' if present.

    Returns:
        The full prompt string instructing the LLM to emit ONLY a JSON
        object matching the ParsedQueryInfo schema.
    """
    dataset_context = "No dataset context provided."
    if dataset_info:
        columns = dataset_info.get('columns', [])
        column_details = "\n".join([f"- {col} (Type: {dataset_info.get('column_types', {}).get(col, 'Unknown')})" for col in columns])
        sample_rows = dataset_info.get('sample_rows', 'Not available')
        # Ensure sample rows are formatted reasonably
        if isinstance(sample_rows, list):
            sample_rows_str = json.dumps(sample_rows[:3], indent=2)  # Show first 3 sample rows
        elif isinstance(sample_rows, str):
            sample_rows_str = sample_rows
        else:
            sample_rows_str = 'Not available'
        dataset_context = f"""
Dataset Context:
Columns:
{column_details}
Sample Rows (first few):
{sample_rows_str}
"""
    # NOTE: the body below is runtime prompt text; literal braces in the JSON
    # example are escaped as {{ }} inside the f-string.
    prompt = f"""
Analyze the following causal query **strictly in the context of the provided dataset information (if available)**. Identify the query type, key variables (mapping query terms to actual column names when possible), constraints, and any explicitly mentioned dataset path.
User Query: "{query}"
{dataset_context}
# Add specific guidance for query types
Guidance for Identifying Query Type:
- EFFECT_ESTIMATION: Look for keywords like 'effect', 'impact', 'influence', 'cause', 'affect', 'consequence'. Also consider questions asking "how does X affect Y?" or comparing outcomes between groups based on an intervention.
- COUNTERFACTUAL: Look for hypothetical scenarios, often using phrases like 'what if', 'if X had been', 'would Y have changed', 'imagine if', 'counterfactual'.
- CORRELATION: Look for keywords like 'correlation', 'association', 'relationship', 'linked to', 'related to'. These queries ask about statistical relationships without necessarily implying causality.
- DESCRIPTIVE: These queries ask for summaries, descriptions, trends, or statistics about the data without investigating causal links or relationships (e.g., "Show sales over time", "What is the average age?").
- OTHER: Use this if the query does not fit any of the above categories.
Choose the most appropriate type from: EFFECT_ESTIMATION, COUNTERFACTUAL, CORRELATION, DESCRIPTIVE, OTHER.
Variable Roles to Identify:
- treatment: The intervention or variable whose effect is being studied.
- outcome: The result or variable being measured.
- covariates_mentioned: Variables explicitly mentioned to control for or adjust for.
- grouping_vars: Variables identifying specific subgroups for analysis (e.g., 'for men', 'in the sales department').
- instruments_mentioned: Variables explicitly mentioned as potential instruments.
Constraints: Conditions applied to the analysis (e.g., filters on columns, specific time periods).
Dataset Path Mentioned: Extract the file path or URL if explicitly stated in the query.
**Output ONLY a valid JSON object** matching this exact schema (no explanations, notes, or surrounding text):
```json
{{
"query_type": "<Identified Query Type>",
"variables": {{
"treatment": ["<Treatment Variable(s) Mentioned>"],
"outcome": ["<Outcome Variable(s) Mentioned>"],
"covariates_mentioned": ["<Covariate(s) Mentioned>"],
"grouping_vars": ["<Grouping Variable(s) Mentioned>"],
"instruments_mentioned": ["<Instrument(s) Mentioned>"]
}},
"constraints": ["<Constraint 1>", "<Constraint 2>"],
"dataset_path_mentioned": "<Path Mentioned or null>"
}}
```
If Dataset Context is provided, ensure variable names in the output JSON correspond to actual column names where possible. If no context is provided, or if a mentioned variable doesn't map directly, use the phrasing from the query.
Respond with only the JSON object.
"""
    return prompt
def _validate_llm_output(parsed_info: ParsedQueryInfo, dataset_info: Optional[Dict] = None) -> bool:
    """Perform basic validation on the parsed LLM output.

    Args:
        parsed_info: Structured query info returned by the LLM.
        dataset_info: Optional dataset context; if it contains a 'columns'
            list, every extracted variable must map to a real column.

    Returns:
        True if all checks pass.

    Raises:
        AssertionError: If any check fails. Callers catch AssertionError to
            trigger a retry with feedback to the LLM.
    """
    # 1. Required fields exist — enforced by Pydantic at parse time.
    # 2. Query type must be one of the allowed categories.
    # Use explicit raises instead of `assert` so validation is not stripped
    # when running under `python -O`. (Also removes a leftover debug print.)
    allowed_types = {"EFFECT_ESTIMATION", "COUNTERFACTUAL", "CORRELATION", "DESCRIPTIVE", "OTHER"}
    if parsed_info.query_type not in allowed_types:
        raise AssertionError(f"Invalid query_type: {parsed_info.query_type}")
    # 3. Effect queries must name at least one treatment and one outcome.
    if parsed_info.query_type == "EFFECT_ESTIMATION":
        if not parsed_info.variables.treatment:
            raise AssertionError("Treatment variable list is empty for effect query.")
        if not parsed_info.variables.outcome:
            raise AssertionError("Outcome variable list is empty for effect query.")
    # 4. If dataset columns are known, check extracted variables against them.
    if dataset_info and (columns := dataset_info.get('columns')):
        all_extracted_vars = set()
        for var_list in parsed_info.variables.model_dump().values():  # Iterate through variable lists
            if var_list:  # Skip None/empty entries
                all_extracted_vars.update(var_list)
        unknown_vars = all_extracted_vars - set(columns)
        if unknown_vars:
            logger.warning(f"LLM mentioned variables potentially not in dataset columns: {unknown_vars}")
            # Hard failure (not just a warning) to force the LLM to map query
            # terms onto actual dataset columns on retry.
            raise AssertionError(f"LLM hallucinated variables not in dataset columns: {unknown_vars}")
    logger.info("LLM output validation passed.")
    return True
def _extract_query_information_with_llm(query: str, dataset_info: Optional[Dict] = None, llm: Optional[BaseChatModel] = None, max_retries: int = 3) -> Optional[ParsedQueryInfo]:
    """Extracts query type, variables, and constraints using LLM with retries and validation.

    Args:
        query: The user's causal question text.
        dataset_info: Optional dataset context forwarded to the prompt builder.
        llm: LLM client supporting `with_structured_output`.
        max_retries: Maximum number of invoke/validate attempts.

    Returns:
        A validated ParsedQueryInfo on success, or None if extraction failed.
    """
    if not llm:
        logger.error("LLM client not provided. Cannot perform LLM extraction.")
        return None
    last_error = None
    # Bind the Pydantic model to the LLM for structured output
    structured_llm = llm.with_structured_output(ParsedQueryInfo)
    # Initial prompt construction
    system_prompt_content = _build_llm_prompt(query, dataset_info)
    messages = [HumanMessage(content=system_prompt_content)]  # Start with just the detailed prompt as Human message
    for attempt in range(max_retries):
        logger.info(f"LLM Extraction Attempt {attempt + 1}/{max_retries}...")
        try:
            # Invoke LangChain LLM with structured output (using passed llm).
            # (Removed leftover debug prints that dumped the full prompt and
            # parsed result to stdout; logging is the proper channel.)
            parsed_info = structured_llm.invoke(messages)
            # Perform custom assertions/validation
            if _validate_llm_output(parsed_info, dataset_info):
                return parsed_info  # Success!
        # Catch errors specific to structured output parsing or Pydantic validation
        except (OutputParserException, ValidationError, AssertionError) as e:
            logger.warning(f"Validation/Parsing Error (Attempt {attempt + 1}): {e}")
            last_error = e
            # Add feedback message so the LLM can correct itself on retry
            messages.append(SystemMessage(content=f"Your previous response failed validation: {str(e)}. Please revise your response to be valid JSON conforming strictly to the schema and ensure variable names exist in the dataset context."))
            continue  # Go to next retry
        except Exception as e:  # Catch other potential LLM API errors
            logger.error(f"Unexpected LLM Error (Attempt {attempt + 1}): {e}", exc_info=True)
            last_error = e
            break  # Stop retrying on unexpected API errors
    logger.error(f"LLM extraction failed after {max_retries} attempts.")
    if last_error:
        logger.error(f"Last error: {last_error}")
    return None  # Indicate failure
# Add helper function to call LLM for path - needs llm argument
def _call_llm_for_path(query: str, llm: Optional[BaseChatModel] = None, max_retries: int = 2) -> Optional[str]:
    """Uses LLM as a fallback to extract just the dataset path.

    Returns the extracted path string, or None when no LLM is available, the
    LLM reports no path, or all retry attempts fail.
    """
    if not llm:
        logger.warning("LLM client not provided. Cannot perform LLM path fallback.")
        return None
    logger.info("Attempting LLM fallback for dataset path extraction...")
    extractor = llm.with_structured_output(ExtractedPath)
    request = f"Extract the dataset file path (e.g., /path/to/file.csv or https://...) mentioned in the following query. Respond ONLY with the JSON object.\nQuery: \"{query}\""
    conversation = [HumanMessage(content=request)]
    failure = None
    for attempt in range(max_retries):
        try:
            extraction = extractor.invoke(conversation)
        except (OutputParserException, ValidationError) as e:
            # Recoverable: append the parse error as feedback and retry.
            logger.warning(f"LLM path extraction parsing/validation error (Attempt {attempt+1}): {e}")
            failure = e
            conversation.append(SystemMessage(content=f"Parsing Error: {e}. Please ensure you provide valid JSON with only the 'dataset_path' key."))
            continue
        except Exception as e:
            # Unexpected API errors are not retried.
            logger.error(f"Unexpected LLM Error during path fallback (Attempt {attempt+1}): {e}", exc_info=True)
            failure = e
            break
        if extraction.dataset_path:
            logger.info(f"LLM fallback extracted path: {extraction.dataset_path}")
            return extraction.dataset_path
        logger.info("LLM fallback did not find a path.")
        return None  # LLM explicitly found no path
    logger.error(f"LLM path fallback failed after {max_retries} attempts. Last error: {failure}")
    return None
# Renamed and modified function for regex path extraction + LLM fallback - needs llm argument
def extract_dataset_path(query: str, llm: Optional[BaseChatModel] = None) -> Optional[str]:
"""
Extract dataset path from the query using regex patterns, with LLM fallback.
Args:
query: The user's causal question text
llm: The shared LLM client instance for fallback.
Returns:
String with dataset path or None if not found
"""
# --- Regex Part (existing logic) ---
# Check for common patterns indicating dataset paths
path_patterns = [
# More specific patterns first
r"(?:dataset|data|file) (?:at|in|from|located at) [\"\']?([^\"\'.,\s]+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"\']?", # Handles subdirs in path
r"(?:use|using|analyze|analyse) (?:the |)(?:dataset|data|file) [\"\']?([^\"\'.,\s]+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"\']?",
# Simpler patterns
r"[\"']([^\"']+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"']", # Path in quotes
r"([a-zA-Z0-9_/.:-]+[\\/][a-zA-Z0-9_.:-]+\.csv)", # More generic path-like structure ending in .csv
r"([^\"\'.,\s]+\.csv)" # Just a .csv file name (least specific)
]
for pattern in path_patterns:
matches = re.search(pattern, query, re.IGNORECASE)
if matches:
path = matches.group(1).strip()
# Basic check if it looks like a path
if '/' in path or '\\' in path or os.path.exists(path):
# Check if this is a valid file path immediately
if os.path.exists(path):
logger.info(f"Regex found existing path: {path}")
return path
# Check if it's in common data directories
data_dir_paths = ["data/", "datasets/", "causalscientist/data/"]
for data_dir in data_dir_paths:
potential_path = os.path.join(data_dir, os.path.basename(path))
if os.path.exists(potential_path):
logger.info(f"Regex found path in {data_dir}: {potential_path}")
return potential_path
# If not found but looks like a path, return it anyway - let downstream handle non-existence
logger.info(f"Regex found potential path (existence not verified): {path}")
return path
# Else: it might just be a word ending in .csv, ignore unless it exists
elif os.path.exists(path):
logger.info(f"Regex found existing path (simple pattern): {path}")
return path
# --- LLM Fallback ---
logger.info("Regex did not find dataset path. Trying LLM fallback...")
llm_fallback_path = _call_llm_for_path(query, llm=llm)
if llm_fallback_path:
# Optional: Add existence check here too? Or let downstream handle it.
# For now, return what LLM found.
return llm_fallback_path
logger.info("No dataset path found via regex or LLM fallback.")
return None
def parse_input(query: str, dataset_path_arg: Optional[str] = None, dataset_info: Optional[Dict] = None, llm: Optional[BaseChatModel] = None) -> Dict[str, Any]:
    """
    Parse the user's causal query using LLM and regex.

    Args:
        query: The user's causal question text.
        dataset_path_arg: Path to dataset if provided directly as an argument.
        dataset_info: Dictionary with dataset context (columns, types, etc.).
        llm: The shared LLM client instance.

    Returns:
        Dict with keys 'original_query', 'dataset_path', 'query_type',
        'extracted_variables', and 'constraints'.
    """
    result = {
        "original_query": query,
        "dataset_path": dataset_path_arg,  # Start with argument path
        "query_type": "OTHER",  # Default values
        "extracted_variables": {},
        "constraints": []
    }
    # --- 1. Use LLM for core NLP tasks ---
    parsed_llm_info = _extract_query_information_with_llm(query, dataset_info, llm=llm)
    if parsed_llm_info:
        result["query_type"] = parsed_llm_info.query_type
        # Normalize None entries to empty lists for downstream consumers.
        result["extracted_variables"] = {k: v if v is not None else [] for k, v in parsed_llm_info.variables.model_dump().items()}
        result["constraints"] = parsed_llm_info.constraints if parsed_llm_info.constraints is not None else []
        llm_mentioned_path = parsed_llm_info.dataset_path_mentioned
    else:
        logger.warning("LLM-based query information extraction failed.")
        llm_mentioned_path = None
        # Consider falling back to old regex methods here if critical
    # --- 2. Determine Dataset Path (Hybrid Approach) ---
    final_dataset_path = dataset_path_arg  # Priority 1: Explicit argument
    if not final_dataset_path and llm_mentioned_path:
        # Priority 2: Path mentioned in query (extracted by main LLM call)
        if os.path.exists(llm_mentioned_path):
            logger.info(f"Using dataset path mentioned by LLM: {llm_mentioned_path}")
            final_dataset_path = llm_mentioned_path
        else:
            # Not found as given; look for the bare filename in common data dirs.
            base_name = os.path.basename(llm_mentioned_path)
            for data_dir in ["data/", "datasets/", "causalscientist/data/"]:
                potential_path = os.path.join(data_dir, base_name)
                if os.path.exists(potential_path):
                    logger.info(f"Using dataset path mentioned by LLM (found in {data_dir}): {potential_path}")
                    final_dataset_path = potential_path
                    break
        if not final_dataset_path:
            logger.warning(f"LLM mentioned path '{llm_mentioned_path}' but it was not found.")
    if not final_dataset_path:
        # Priority 3: Path extracted by dedicated Regex + LLM fallback function.
        # Direct call replaces the former functools.partial indirection, which
        # added a layer of wrapping for a single invocation.
        logger.info("Attempting dedicated dataset path extraction (Regex + LLM Fallback)...")
        extracted_path = extract_dataset_path(query, llm=llm)
        if extracted_path:
            final_dataset_path = extracted_path
    result["dataset_path"] = final_dataset_path
    # Check if a path was found ultimately
    if not result["dataset_path"]:
        logger.warning("Could not determine dataset path from query or arguments.")
    else:
        logger.info(f"Final dataset path determined: {result['dataset_path']}")
    return result
# --- Old Regex-based functions (Commented out or removed) ---
# def determine_query_type(query: str) -> str:
# ... (implementation removed)
# def extract_variables(query: str) -> Dict[str, Any]:
# ... (implementation removed)
# def detect_constraints(query: str) -> List[str]:
# ... (implementation removed)
# --- End Old Functions ---
# Renamed function for regex path extraction
def extract_dataset_path_regex(query: str) -> Optional[str]:
    """
    Extract dataset path from the query using regex patterns.

    Args:
        query: The user's causal question text

    Returns:
        String with dataset path or None if not found
    """
    # Patterns are ordered from most to least specific so precise phrasings win.
    path_patterns = (
        r"(?:dataset|data|file) (?:at|in|from|located at) [\"\']?([^\"\'.,\s]+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"\']?",  # Handles subdirs in path
        r"(?:use|using|analyze|analyse) (?:the |)(?:dataset|data|file) [\"\']?([^\"\'.,\s]+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"\']?",
        r"[\"']([^\"']+\.csv(?:[\\/][^\"\'.,\s]+)*)[\"']",  # Path in quotes
        r"([a-zA-Z0-9_/.:-]+[\\/][a-zA-Z0-9_.:-]+\.csv)",  # More generic path-like structure ending in .csv
        r"([^\"\'.,\s]+\.csv)",  # Just a .csv file name (least specific)
    )
    for pattern in path_patterns:
        found = re.search(pattern, query, re.IGNORECASE)
        if not found:
            continue
        candidate = found.group(1).strip()
        # A bare word ending in .csv (no separators) is only accepted if the
        # file actually exists on disk.
        if '/' not in candidate and '\\' not in candidate and not os.path.exists(candidate):
            continue
        if os.path.exists(candidate):
            logger.info(f"Regex found existing path: {candidate}")
            return candidate
        # Not found as given: try the bare filename inside common data dirs.
        for data_dir in ("data/", "datasets/", "causalscientist/data/"):
            relocated = os.path.join(data_dir, os.path.basename(candidate))
            if os.path.exists(relocated):
                logger.info(f"Regex found path in {data_dir}: {relocated}")
                return relocated
        # Looks like a path but doesn't exist; return it anyway and let
        # downstream code handle the missing file.
        logger.info(f"Regex found potential path (existence not verified): {candidate}")
        return candidate
    # TODO: Optional: Add LLM fallback call here if regex fails
    return None