from typing import List, Optional, Union, Dict, Any, Tuple
from pydantic import BaseModel, Field, validator
import json
# --- Pydantic models for LLM structured output ---
# These models are used by query_interpreter and potentially other components
# to structure the output received from Language Models.
class LLMSelectedVariable(BaseModel):
    """Pydantic model for selecting a single variable."""
    variable_name: Optional[str] = Field(None, description="The single best column name selected.")


class LLMSelectedCovariates(BaseModel):
    """Pydantic model for selecting a list of covariates."""
    covariates: List[str] = Field(default_factory=list, description="The list of selected covariate column names.")


class LLMIVars(BaseModel):
    """Pydantic model for identifying instrumental variables (IVs)."""
    instrument_variable: Optional[str] = Field(None, description="The identified instrumental variable column name.")


class LLMEstimand(BaseModel):
    """Pydantic model for identifying the estimand."""
    estimand: Optional[str] = Field(None, description="The identified estimand.")


class LLMRDDVars(BaseModel):
    """Pydantic model for identifying RDD variables."""
    running_variable: Optional[str] = Field(None, description="The identified running variable column name.")
    cutoff_value: Optional[Union[float, int]] = Field(None, description="The identified cutoff value.")


class LLMRCTCheck(BaseModel):
    """Pydantic model for checking whether the data is from an RCT."""
    is_rct: Optional[bool] = Field(None, description="True if the data is from a randomized controlled trial, False otherwise, None if unsure.")
    reasoning: Optional[str] = Field(None, description="Brief reasoning for the RCT conclusion.")


class LLMTreatmentReferenceLevel(BaseModel):
    """Pydantic model for identifying the treatment reference level."""
    reference_level: Optional[str] = Field(None, description="The identified reference/control level for the treatment variable, if specified in the query. Should be one of the actual values in the treatment column.")
    reasoning: Optional[str] = Field(None, description="Brief reasoning for identifying this reference level.")


class LLMInteractionSuggestion(BaseModel):
    """Pydantic model for LLM suggestion on interaction terms."""
    interaction_needed: Optional[bool] = Field(None, description="True if an interaction term is strongly suggested by the query or context. LLM should provide true, false, or omit for None.")
    interaction_variable: Optional[str] = Field(None, description="The name of the covariate that should interact with the treatment. Null if not applicable or if the interaction is complex/multiple.")
    reasoning: Optional[str] = Field(None, description="Brief reasoning for the suggestion for or against an interaction term.")
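# Illustrative sketch (not part of the pipeline): parsing a raw LLM JSON
# response into one of the structured-output models above. The variable
# `raw_response` and the response contents are hypothetical.
#
#   raw_response = '{"running_variable": "test_score", "cutoff_value": 50}'
#   rdd_vars = LLMRDDVars(**json.loads(raw_response))
#   assert rdd_vars.cutoff_value == 50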
# --- Pydantic models for Tool Inputs/Outputs and Data Structures ---
class TemporalStructure(BaseModel):
    """Represents detected temporal structure in the data."""
    has_temporal_structure: bool
    temporal_columns: List[str]
    is_panel_data: bool
    id_column: Optional[str] = None
    time_column: Optional[str] = None
    time_periods: Optional[int] = None
    units: Optional[int] = None


class DatasetInfo(BaseModel):
    """Basic information about the dataset file."""
    num_rows: int
    num_columns: int
    file_path: str
    file_name: str


class DatasetAnalysis(BaseModel):
    """Results from the dataset analysis component."""
    dataset_info: DatasetInfo
    columns: List[str]
    potential_treatments: List[str]
    potential_outcomes: List[str]
    temporal_structure_detected: bool
    panel_data_detected: bool
    potential_instruments_detected: bool
    discontinuities_detected: bool
    temporal_structure: TemporalStructure
    column_categories: Optional[Dict[str, str]] = None
    column_nunique_counts: Optional[Dict[str, int]] = None
    sample_size: int
    num_covariates_estimate: int
    per_group_summary_stats: Optional[Dict[str, Dict[str, Any]]] = None
    potential_instruments: Optional[List[str]] = None
    overlap_assessment: Optional[Dict[str, Any]] = None
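# Illustrative sketch: a minimal DatasetAnalysis instance as the dataset
# analyzer might produce it. Column names, counts, and the file path are hypothetical.
#
#   analysis = DatasetAnalysis(
#       dataset_info=DatasetInfo(num_rows=500, num_columns=6, file_path="data/trial.csv", file_name="trial.csv"),
#       columns=["unit_id", "period", "treated", "outcome", "age", "region"],
#       potential_treatments=["treated"],
#       potential_outcomes=["outcome"],
#       temporal_structure_detected=True,
#       panel_data_detected=True,
#       potential_instruments_detected=False,
#       discontinuities_detected=False,
#       temporal_structure=TemporalStructure(
#           has_temporal_structure=True,
#           temporal_columns=["period"],
#           is_panel_data=True,
#           id_column="unit_id",
#           time_column="period",
#       ),
#       sample_size=500,
#       num_covariates_estimate=2,
#   )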
# --- Model for Dataset Analyzer Tool Output ---
class DatasetAnalyzerOutput(BaseModel):
    """Structured output for the dataset analyzer tool."""
    analysis_results: DatasetAnalysis
    dataset_description: Optional[str] = None
    workflow_state: Dict[str, Any]
# TODO: make QueryInfo consistent with the dataset analysis output
class QueryInfo(BaseModel):
    """Information extracted from the user's initial query."""
    query_text: str
    potential_treatments: Optional[List[str]] = None
    potential_outcomes: Optional[List[str]] = None
    covariates_hints: Optional[List[str]] = None
    instrument_hints: Optional[List[str]] = None
    running_variable_hints: Optional[List[str]] = None
    cutoff_value_hint: Optional[Union[float, int]] = None


class QueryInterpreterInput(BaseModel):
    """Input structure for the query interpreter tool."""
    query_info: QueryInfo
    dataset_analysis: DatasetAnalysis
    dataset_description: str
    # Add original_query if it should be part of the standard input
    original_query: Optional[str] = None


class Variables(BaseModel):
    """Structured variables identified by the query interpreter component."""
    treatment_variable: Optional[str] = None
    treatment_variable_type: Optional[str] = Field(None, description="Type of the treatment variable (e.g., 'binary', 'continuous', 'categorical_multi_value')")
    outcome_variable: Optional[str] = None
    instrument_variable: Optional[str] = None
    covariates: Optional[List[str]] = Field(default_factory=list)
    time_variable: Optional[str] = None
    group_variable: Optional[str] = None  # Often the unit ID
    running_variable: Optional[str] = None
    cutoff_value: Optional[Union[float, int]] = None
    is_rct: Optional[bool] = Field(False, description="Flag indicating if the dataset is from an RCT.")
    treatment_reference_level: Optional[str] = Field(None, description="The specified reference/control level for a multi-valued treatment variable.")
    interaction_term_suggested: Optional[bool] = Field(False, description="Whether the query or context suggests an interaction term with the treatment might be relevant.")
    interaction_variable_candidate: Optional[str] = Field(None, description="The covariate identified as a candidate for interaction with the treatment.")
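# Illustrative sketch: a Variables instance as the query interpreter might
# return it for a simple RCT-style question. All column names are hypothetical.
#
#   variables = Variables(
#       treatment_variable="treated",
#       treatment_variable_type="binary",
#       outcome_variable="outcome",
#       covariates=["age", "region"],
#       is_rct=True,
#   )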
class QueryInterpreterOutput(BaseModel):
    """Structured output for the query interpreter tool."""
    variables: Variables
    dataset_analysis: DatasetAnalysis
    dataset_description: Optional[str]
    workflow_state: Dict[str, Any]
    original_query: Optional[str] = None


# Input model for Method Selector Tool
class MethodSelectorInput(BaseModel):
    """Input structure for the method selector tool."""
    variables: Variables  # Uses the Variables model identified by QueryInterpreter
    dataset_analysis: DatasetAnalysis  # Uses the DatasetAnalysis model
    dataset_description: Optional[str] = None
    original_query: Optional[str] = None
    # Note: is_rct is expected inside inputs.variables
# --- Models for Method Validator Tool ---
class MethodInfo(BaseModel):
    """Information about the selected causal inference method."""
    selected_method: Optional[str] = None
    method_name: Optional[str] = None  # Often a title-cased version for display
    method_justification: Optional[str] = None
    method_assumptions: Optional[List[str]] = Field(default_factory=list)
    # Add alternative_methods if it should be part of the standard info passed around
    alternative_methods: Optional[List[str]] = Field(default_factory=list)


class MethodValidatorInput(BaseModel):
    """Input structure for the method validator tool."""
    method_info: MethodInfo
    variables: Variables
    dataset_analysis: DatasetAnalysis
    dataset_description: Optional[str] = None
    original_query: Optional[str] = None
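# Illustrative sketch: building a MethodValidatorInput from the outputs of the
# earlier steps. `variables` and `analysis` refer to the hypothetical instances
# sketched above; the method shown here is only an example value.
#
#   validator_input = MethodValidatorInput(
#       method_info=MethodInfo(
#           selected_method="propensity_score_matching",
#           method_name="Propensity Score Matching",
#       ),
#       variables=variables,
#       dataset_analysis=analysis,
#   )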
# --- Model for Method Executor Tool ---
class MethodExecutorInput(BaseModel):
    """Input structure for the method executor tool."""
    method: str = Field(..., description="The causal method name (use recommended method if validation failed).")
    variables: Variables  # Contains T, O, C, etc.
    dataset_path: str
    dataset_analysis: DatasetAnalysis
    dataset_description: Optional[str] = None
    # Include validation_info from validator output if needed by estimator or LLM assist later?
    validation_info: Optional[Any] = None
    original_query: Optional[str] = None
# --- Model for Explanation Generator Tool ---
class ExplainerInput(BaseModel):
    """Input structure for the explanation generator tool."""
    # Based on expected output from method_executor_tool and validator
    method_info: MethodInfo
    validation_info: Optional[Dict[str, Any]] = None  # From validator tool
    variables: Variables
    results: Dict[str, Any]  # Numerical results from executor
    dataset_analysis: DatasetAnalysis
    dataset_description: Optional[str] = None
    # Add original query if needed for explanation context
    original_query: Optional[str] = None
# Add other shared models/schemas below as needed.
class FormattedOutput(BaseModel):
    """
    Structured output containing the final formatted results and explanations
    from a causal analysis run.
    """
    query: str = Field(description="The original user query.")
    method_used: str = Field(description="The user-friendly name of the causal inference method used.")
    causal_effect: Optional[float] = Field(None, description="The point estimate of the causal effect.")
    standard_error: Optional[float] = Field(None, description="The standard error of the causal effect estimate.")
    confidence_interval: Optional[Tuple[Optional[float], Optional[float]]] = Field(None, description="The confidence interval for the causal effect (e.g., 95% CI).")
    p_value: Optional[float] = Field(None, description="The p-value associated with the causal effect estimate.")
    summary: str = Field(description="A concise summary paragraph interpreting the main findings.")
    method_explanation: Optional[str] = Field("", description="Explanation of the causal inference method used.")
    interpretation_guide: Optional[str] = Field("", description="Guidance on how to interpret the results.")
    limitations: Optional[List[str]] = Field(default_factory=list, description="List of limitations or potential issues with the analysis.")
    assumptions: Optional[str] = Field("", description="Discussion of the key assumptions underlying the method and their validity.")
    practical_implications: Optional[str] = Field("", description="Discussion of the practical implications or significance of the findings.")
    # Optionally add dataset_analysis and dataset_description if they should be part of the final structure
    # dataset_analysis: Optional[DatasetAnalysis] = None  # Example if using DatasetAnalysis model
    # dataset_description: Optional[str] = None
    # This model itself doesn't include workflow_state, as it represents the *content*.
    # The tool using this component will add the workflow_state separately.
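# Illustrative sketch: a FormattedOutput carrying only the required fields plus
# the main numerical results. All values shown are hypothetical.
#
#   formatted = FormattedOutput(
#       query="Does the training program increase wages?",
#       method_used="Difference-in-Differences",
#       causal_effect=1.23,
#       standard_error=0.45,
#       confidence_interval=(0.35, 2.11),
#       p_value=0.006,
#       summary="The training program is estimated to raise wages by 1.23 units.",
#   )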
class LLMParameterDetails(BaseModel):
    """Pydantic model for the statistics of a single extracted model parameter."""
    parameter_name: str = Field(description="The full parameter name as found in the model results.")
    estimate: float
    p_value: float
    conf_int_low: float
    conf_int_high: float
    std_err: float
    reasoning: Optional[str] = Field(None, description="Brief reasoning for selecting this parameter and its values.")


class LLMTreatmentEffectResults(BaseModel):
    """Pydantic model for treatment effect details extracted by the LLM."""
    effects: Optional[Dict[str, LLMParameterDetails]] = Field(description="Dictionary where keys are treatment level names (e.g., 'LevelA', 'LevelB' if multi-level) or a generic key like 'treatment_effect' for binary/continuous treatments. Values are the statistical details for that effect.")
    all_parameters_successfully_identified: Optional[bool] = Field(description="True if all expected treatment effect parameters were identified and their values extracted, False otherwise.")
    overall_reasoning: Optional[str] = Field(None, description="Overall reasoning for the extraction process or if issues were encountered.")


class RelevantParamInfo(BaseModel):
    """Pydantic model for a single parameter identified as relevant."""
    param_name: str = Field(description="The exact parameter name as it appears in the statsmodels results.")
    param_index: int = Field(description="The index of this parameter in the original list of parameter names.")


class LLMIdentifiedRelevantParams(BaseModel):
    """Pydantic model for the set of parameters the LLM identified as relevant."""
    identified_params: List[RelevantParamInfo] = Field(description="A list of parameters identified as relevant to the query or representing all treatment effects for a general query.")
    all_parameters_successfully_identified: bool = Field(description="True if the LLM is confident it identified all necessary params based on query type (e.g., all levels for a general query).")
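# Minimal, self-contained check that the extraction models round-trip from the
# kind of JSON an LLM might return. The payload below is hypothetical.
if __name__ == "__main__":
    sample_response = json.dumps({
        "identified_params": [
            {"param_name": "C(treated)[T.1]", "param_index": 1},
        ],
        "all_parameters_successfully_identified": True,
    })
    parsed = LLMIdentifiedRelevantParams(**json.loads(sample_response))
    print(parsed.identified_params[0].param_name)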