File size: 30,469 Bytes
1721aea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
"""
Query interpreter component for causal inference.

This module provides functionality to match query concepts to actual dataset variables,
identifying treatment, outcome, and covariate variables for causal inference analysis.
"""

import re
from typing import Dict, List, Any, Optional, Union, Tuple
import pandas as pd
import logging
import numpy as np
from auto_causal.config import get_llm_client
# Import LLM and message types
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage
from langchain_core.exceptions import OutputParserException
# Import base Pydantic models needed directly
from pydantic import BaseModel, ValidationError
from dowhy import CausalModel
import json

# Import shared Pydantic models from the central location
from auto_causal.models import (
    LLMSelectedVariable,
    LLMSelectedCovariates,
    LLMIVars,
    LLMRDDVars,
    LLMRCTCheck,
    LLMTreatmentReferenceLevel,
    LLMInteractionSuggestion, 
    LLMEstimand, 
    # LLMDIDCheck,
    # LLMDiDTemporalVars,
    # LLMDiDGroupVars,
    # LLMRDDCheck,
    # LLMRDDVarsExtended
)

# Import the new prompt templates
from auto_causal.prompts.method_identification_prompts import (
    IV_IDENTIFICATION_PROMPT_TEMPLATE,
    RDD_IDENTIFICATION_PROMPT_TEMPLATE,
    RCT_IDENTIFICATION_PROMPT_TEMPLATE,
    TREATMENT_REFERENCE_IDENTIFICATION_PROMPT_TEMPLATE,
    INTERACTION_TERM_IDENTIFICATION_PROMPT_TEMPLATE,
    TREATMENT_VAR_IDENTIFICATION_PROMPT_TEMPLATE, 
    OUTCOME_VAR_IDENTIFICATION_PROMPT_TEMPLATE,
    COVARIATES_IDENTIFICATION_PROMPT_TEMPLATE, 
    ESTIMAND_PROMPT_TEMPLATE,
    CONFOUNDER_IDENTIFICATION_PROMPT_TEMPLATE,
    DID_TERM_IDENTIFICATION_PROMPT_TEMPLATE)


# Assume central models are defined elsewhere or keep local definitions for now
# from ..models import ... 

# --- Pydantic models for LLM structured output --- 
# REMOVED - Now defined in causalscientist/auto_causal/models.py
# class LLMSelectedVariable(BaseModel): ...
# class LLMSelectedCovariates(BaseModel): ...
# class LLMIVars(BaseModel): ...
# class LLMRDDVars(BaseModel): ...
# class LLMRCTCheck(BaseModel): ...


logger = logging.getLogger(__name__)

def infer_treatment_variable_type(treatment_variable: str, column_categories: Dict[str, str],
                                  dataset_analysis: Dict[str, Any]) -> str:
    """
    Classify the treatment column into a coarse type label.

    The column's category in ``column_categories`` selects the broad family; for
    discrete and categorical columns the distinct-value count stored under
    ``dataset_analysis["column_nunique_counts"]`` refines the label.

    Args:
        treatment_variable: name of the treatment column (may be empty/None).
        column_categories: mapping of column names to their category labels.
        dataset_analysis: exploratory analysis results; only
            ``"column_nunique_counts"`` is read here.

    Returns:
        str: one of "continuous", "binary", "discrete_multi_value",
        "discrete_numeric_unknown_cardinality", "categorical_multi_value",
        "categorical_unknown_cardinality", "other" or "unknown".
    """
    if not treatment_variable:
        logger.info("No treatment variable identified.")
        resolved_type = "unknown"
    elif treatment_variable not in column_categories:
        logger.warning(f"'{treatment_variable}' not found in column_categories.")
        resolved_type = "unknown"
    else:
        category = column_categories[treatment_variable]
        logger.info(f"Category for treatment '{treatment_variable}' is '{category}'.")
        resolved_type = _treatment_type_for_category(treatment_variable, category, dataset_analysis)

    logger.info(f"Final Determined Treatment Variable Type: {resolved_type}")
    return resolved_type


def _treatment_type_for_category(name: str, category: Any, dataset_analysis: Dict[str, Any]) -> str:
    """Map one known column category (plus its distinct-value count) to a treatment type label."""
    if category == "continuous_numeric":
        return "continuous"
    if category in ("binary", "binary_categorical"):
        return "binary"
    if category not in ("discrete_numeric", "categorical", "categorical_numeric"):
        logger.warning(f"Unmapped category '{category}' for '{name}', setting as 'other'.")
        return "other"

    # -1 means "count not recorded" and lands in the *_unknown_cardinality labels below.
    n_distinct = dataset_analysis.get("column_nunique_counts", {}).get(name, -1)

    if category == "discrete_numeric":
        if n_distinct > 10:
            logger.info(f"'{name}' has {n_distinct} unique values, treating as continuous.")
            return "continuous"
        if n_distinct == 2:
            logger.info(f"'{name}' has 2 unique values, treating as binary.")
            return "binary"
        if n_distinct > 0:
            logger.info(f"'{name}' has {n_distinct} unique values, treating as discrete_multi_value.")
            return "discrete_multi_value"
        logger.info(f"'{name}' unique value count unknown or too few.")
        return "discrete_numeric_unknown_cardinality"

    # Remaining categories: "categorical" / "categorical_numeric".
    if n_distinct == 2:
        return "binary"
    if n_distinct > 0:
        return "categorical_multi_value"
    return "categorical_unknown_cardinality"

def _sample_treatment_values(file_path: Optional[str], treatment_variable: str,
                             max_values: int = 10) -> List[Any]:
    """
    Return up to ``max_values`` distinct, non-missing values of one CSV column.

    Only the requested column is parsed, so large files are not loaded in full.
    Every read problem is logged and yields an empty list, so callers can still
    prompt the LLM without example values.
    """
    if not file_path:
        return []
    try:
        # Parse only the treatment column. An absent column is reported by the membership
        # check below (or, should pandas raise instead, by the generic handler).
        df = pd.read_csv(file_path, usecols=lambda col: col == treatment_variable)
    except FileNotFoundError:
        logger.warning(f"File not found at: {file_path}")
        return []
    except pd.errors.EmptyDataError:
        logger.warning(f"Empty file at: {file_path}")
        return []
    except Exception as e:
        logger.warning(f"Error reading dataset at '{file_path}' for '{treatment_variable}': {e}")
        return []

    if treatment_variable not in df.columns:
        logger.warning(f"'{treatment_variable}' not found in dataset columns at '{file_path}'.")
        return []

    # NaN is never a meaningful reference level, so drop missing values before sampling.
    unique_vals = df[treatment_variable].dropna().unique()
    # Convert numpy scalars to plain Python values for clean prompt rendering.
    sample = [v.item() if hasattr(v, "item") else v for v in unique_vals[:max_values]]
    if sample:
        logger.info(f"Successfully read treatment values sample from dataset at '{file_path}' for variable '{treatment_variable}'.")
    else:
        logger.info(f"'{treatment_variable}' in '{file_path}' has no unique values or is empty.")
    return sample


def determine_treatment_reference_level(is_rct: Optional[bool], llm: Optional["BaseChatModel"], treatment_variable: Optional[str],
                                      query_text: str, dataset_description: Optional[str], file_path: Optional[str],
                                      columns: List[str]) -> Optional[str]:
    """
    Ask the LLM which value of the treatment column should be used as the reference level.

    Up to 10 distinct non-missing treatment values are sampled from ``file_path`` (when
    given) and included in the prompt as examples.

    Args:
        is_rct: accepted for interface compatibility; not used by this function.
        llm: chat model to prompt; when falsy, no reference level is determined.
        treatment_variable: treatment column name; must appear in ``columns``.
        query_text: the user's causal question.
        dataset_description: optional dataset description ('N/A' in the prompt when missing).
        file_path: optional CSV path used to sample treatment values.
        columns: dataset column names.

    Returns:
        The reference level proposed by the LLM, or None if none could be determined.
    """
    if not (llm and treatment_variable and treatment_variable in columns):
        return None

    treatment_values_sample = _sample_treatment_values(file_path, treatment_variable)
    if not treatment_values_sample:
        logger.warning(f"No unique values found for treatment '{treatment_variable}'. LLM prompt will receive empty list.")
    else:
        logger.info(f"Final treatment values sample: {treatment_values_sample}")

    treatment_reference_level = None
    try:
        prompt = TREATMENT_REFERENCE_IDENTIFICATION_PROMPT_TEMPLATE.format(
            query=query_text,
            description=dataset_description or 'N/A',
            treatment_variable=treatment_variable,
            treatment_variable_values=treatment_values_sample,
        )
        ref_result = _call_llm_for_var(llm, prompt, LLMTreatmentReferenceLevel)
        if ref_result and ref_result.reference_level:
            # Sampled values keep native types (e.g. ints 0/1) while this function returns the
            # level as text, so compare string forms to avoid spurious mismatch warnings.
            sampled_as_text = {str(v) for v in treatment_values_sample}
            if treatment_values_sample and str(ref_result.reference_level) not in sampled_as_text:
                logger.warning(f"LLM reference level '{ref_result.reference_level}' not in sampled values for '{treatment_variable}'.")
            treatment_reference_level = ref_result.reference_level
            logger.info(f"LLM identified reference level: {treatment_reference_level} (Reason: {ref_result.reasoning})")
        elif ref_result:
            logger.info(f"LLM returned no reference level (Reason: {ref_result.reasoning})")
    except Exception as e:
        logger.error(f"LLM error for treatment reference level: {e}")

    return treatment_reference_level

def identify_interaction_term(llm: Optional[BaseChatModel], treatment_variable: Optional[str], covariates: List[str],
                              column_categories: Dict[str, str], query_text: str,
                              dataset_description: Optional[str]) -> Tuple[bool, Optional[str]]:
    """
    Ask the LLM whether the treatment effect should be modelled with an interaction term.

    Returns:
        (suggested, variable): ``suggested`` mirrors the LLM's ``interaction_needed``
        (None counts as False), but is reset to False when the LLM names a variable
        outside ``covariates``. ``variable`` is the named covariate when it is valid,
        otherwise None. Without an LLM, a treatment or covariates, returns (False, None).
        Errors are logged and whatever was determined so far is returned.
    """
    suggested, candidate = False, None
    if not (llm and treatment_variable and covariates):
        return suggested, candidate

    try:
        typed_covariates = "\n".join(
            f"- {name}: {column_categories.get(name, 'Unknown')}" for name in covariates
        )
        prompt = INTERACTION_TERM_IDENTIFICATION_PROMPT_TEMPLATE.format(
            query=query_text,
            description=dataset_description or 'N/A',
            treatment_variable=treatment_variable,
            covariates_list_with_types=typed_covariates or "No covariates identified or available.",
        )
        suggestion = _call_llm_for_var(llm, prompt, LLMInteractionSuggestion)

        if not suggestion:
            logger.warning("LLM returned no result for interaction term suggestion.")
        else:
            suggested = suggestion.interaction_needed if suggestion.interaction_needed is not None else False
            if not suggested:
                logger.info(f"LLM suggested no interaction is needed (Reason: {suggestion.reasoning})")
            elif not suggestion.interaction_variable:
                logger.info(f"LLM suggested interaction is needed but no variable provided (Reason: {suggestion.reasoning})")
            elif suggestion.interaction_variable not in covariates:
                logger.warning(f"LLM suggested variable '{suggestion.interaction_variable}' not in covariates {covariates}. Ignoring.")
                suggested = False
            else:
                candidate = suggestion.interaction_variable
                logger.info(f"LLM suggested interaction: needed={suggested}, variable='{candidate}' (Reason: {suggestion.reasoning})")
    except Exception as e:
        logger.error(f"LLM error during interaction term check: {e}")

    return suggested, candidate


def interpret_query(query_info: Dict[str, Any], dataset_analysis: Dict[str, Any],
                    dataset_description: Optional[str] = None) -> Dict[str, Any]:
    """
    Interpret a causal query by mapping its concepts onto concrete dataset columns.

    Heuristics over query/dataset hints are combined with LLM prompts. The LLM comes
    from ``get_llm_client()``; when it is falsy only heuristics run and the LLM-only
    fields (RCT/IV/RDD/DiD, reference level, interaction) stay at their defaults.
    Each LLM design check fails independently: one failing prompt does not discard
    the results of the others.

    Args:
        query_info: Information extracted from the user's query. Reads "query_text",
            "potential_treatments", "outcome_hints" and "covariates_hints".
        dataset_analysis: Information about the dataset structure. Reads "columns",
            "column_categories", "dataset_info" -> "file_path", "potential_treatments",
            "potential_outcomes", "temporal_structure" and (via
            infer_treatment_variable_type) "column_nunique_counts".
        dataset_description: Optional textual description of the dataset.

    Returns:
        Dict with keys treatment_variable, treatment_variable_type, outcome_variable,
        covariates, time_variable, group_variable, instrument_variable, running_variable,
        cutoff_value, is_rct, treatment_reference_level, interaction_term_suggested,
        interaction_variable_candidate, confounders and did_term.
    """
    logger.info("Interpreting query with hybrid approach...")
    llm = get_llm_client()

    query_text = query_info.get("query_text", "")
    columns = dataset_analysis.get("columns", [])
    column_categories = dataset_analysis.get("column_categories", {})
    # A missing "dataset_info" is tolerated: without a path the reference-level step skips value sampling.
    file_path = (dataset_analysis.get("dataset_info") or {}).get("file_path")

    # --- Identify Treatment ---
    treatment_variable = _identify_variable_hybrid(
        role="treatment",
        query_hints=query_info.get("potential_treatments", []),
        dataset_suggestions=dataset_analysis.get("potential_treatments", []),
        columns=columns,
        column_categories=column_categories,
        prioritize_types=["binary", "binary_categorical", "discrete_numeric", "continuous_numeric"],  # Prioritize binary/discrete
        query_text=query_text,
        dataset_description=dataset_description,
        llm=llm,
    )
    logger.info(f"Identified Treatment: {treatment_variable}")
    treatment_variable_type = infer_treatment_variable_type(treatment_variable, column_categories, dataset_analysis)

    # --- Identify Outcome (never the treatment itself) ---
    outcome_variable = _identify_variable_hybrid(
        role="outcome",
        query_hints=query_info.get("outcome_hints", []),
        dataset_suggestions=dataset_analysis.get("potential_outcomes", []),
        columns=columns,
        column_categories=column_categories,
        prioritize_types=["continuous_numeric", "discrete_numeric"],  # Prioritize numeric
        exclude_vars=[treatment_variable],
        query_text=query_text,
        dataset_description=dataset_description,
        llm=llm,
    )
    logger.info(f"Identified Outcome: {outcome_variable}")

    # --- Identify Covariates and Confounders (both seeded with the query's covariate hints) ---
    covariate_hints = query_info.get("covariates_hints", [])
    covariates = _identify_covariates_hybrid("covars", treatment_variable=treatment_variable, outcome_variable=outcome_variable,
                                             columns=columns, column_categories=column_categories, query_hints=covariate_hints,
                                             query_text=query_text, dataset_description=dataset_description, llm=llm)
    logger.info(f"Identified Covariates: {covariates}")

    confounders = _identify_covariates_hybrid("confounders", treatment_variable=treatment_variable, outcome_variable=outcome_variable,
                                              columns=columns, column_categories=column_categories, query_hints=covariate_hints,
                                              query_text=query_text, dataset_description=dataset_description, llm=llm)
    logger.info(f"Identified Confounders: {confounders}")

    # --- Identify Time/Group (from dataset analysis) ---
    time_variable, group_variable = _identify_time_and_group(dataset_analysis)

    # --- Identify RCT / IV / RDD / estimand / DiD term using the LLM ---
    design = (
        _run_llm_design_checks(llm, query_text, dataset_description, columns, column_categories,
                               treatment_variable, outcome_variable, time_variable, group_variable)
        if llm else {}
    )

    # --- Identify Treatment Reference Level ---
    treatment_reference_level = determine_treatment_reference_level(is_rct=design.get("is_rct"), llm=llm, treatment_variable=treatment_variable,
                                                                    query_text=query_text, dataset_description=dataset_description,
                                                                    file_path=file_path, columns=columns)

    # --- Identify Interaction Term Suggestion ---
    interaction_term_suggested, interaction_variable_candidate = identify_interaction_term(llm=llm, treatment_variable=treatment_variable,
                                                                                           covariates=covariates,
                                                                                           column_categories=column_categories, query_text=query_text,
                                                                                           dataset_description=dataset_description)

    # --- Consolidate ---
    return {
        "treatment_variable": treatment_variable,
        "treatment_variable_type": treatment_variable_type,
        "outcome_variable": outcome_variable,
        "covariates": covariates,
        "time_variable": time_variable,
        "group_variable": group_variable,
        "instrument_variable": design.get("instrument_variable"),
        "running_variable": design.get("running_variable"),
        "cutoff_value": design.get("cutoff_value"),
        "is_rct": design.get("is_rct"),
        "treatment_reference_level": treatment_reference_level,
        "interaction_term_suggested": interaction_term_suggested,
        "interaction_variable_candidate": interaction_variable_candidate,
        "confounders": confounders,
        "did_term": design.get("did_term"),
    }


def _identify_time_and_group(dataset_analysis: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
    """Read the time column and (for panel data) the unit-id column from ``dataset_analysis["temporal_structure"]``."""
    temporal_structure = dataset_analysis.get("temporal_structure", {}) or {}
    time_variable = None
    group_variable = None
    if temporal_structure.get("has_temporal_structure", False):
        # An empty "temporal_columns" list must not raise; fall back to None.
        temporal_columns = temporal_structure.get("temporal_columns") or [None]
        time_variable = temporal_structure.get("time_column") or temporal_columns[0]
        if temporal_structure.get("is_panel_data", False):
            group_variable = temporal_structure.get("id_column")
    logger.info(f"Identified Time Var: {time_variable}, Group Var: {group_variable}, temporal structure: {temporal_structure}")
    return time_variable, group_variable


def _run_llm_design_checks(llm: "BaseChatModel", query_text: str, dataset_description: Optional[str],
                           columns: List[str], column_categories: Dict[str, str],
                           treatment_variable: Optional[str], outcome_variable: Optional[str],
                           time_variable: Optional[str], group_variable: Optional[str]) -> Dict[str, Any]:
    """
    Run the LLM checks for RCT, IV, RDD, estimand and DiD term, each in its own try/except.

    Returns a dict with keys is_rct, instrument_variable, running_variable, cutoff_value,
    estimand and did_term; a failed or invalid check leaves its entry as None.
    """
    found: Dict[str, Any] = {"is_rct": None, "instrument_variable": None, "running_variable": None,
                             "cutoff_value": None, "estimand": None, "did_term": None}

    def identify_prompt(target: str) -> str:
        return _create_identify_prompt(target, query_text, dataset_description, columns, column_categories,
                                       treatment_variable, outcome_variable)

    # Check for RCT
    try:
        rct_result = _call_llm_for_var(llm, identify_prompt("whether data is from RCT"), LLMRCTCheck)
        found["is_rct"] = rct_result.is_rct if rct_result else None
        logger.info(f"LLM identified RCT: {found['is_rct']}")
    except Exception as e:
        logger.error(f"Error during LLM RCT check: {e}")

    # Check for IV (only accepted when it names a real column)
    try:
        iv_result = _call_llm_for_var(llm, identify_prompt("instrumental variable"), LLMIVars)
        instrument = iv_result.instrument_variable if iv_result else None
        found["instrument_variable"] = instrument if instrument in columns else None
        logger.info(f"LLM identified IV: {found['instrument_variable']}")
    except Exception as e:
        logger.error(f"Error during LLM IV check: {e}")

    # Check for RDD (needs both a real running column and a cutoff)
    try:
        rdd_result = _call_llm_for_var(llm, identify_prompt("regression discontinuity (running variable and cutoff)"), LLMRDDVars)
        running, cutoff = (rdd_result.running_variable, rdd_result.cutoff_value) if rdd_result else (None, None)
        if running in columns and cutoff is not None:
            found["running_variable"], found["cutoff_value"] = running, cutoff
        logger.info(f"LLM identified RDD: Running={found['running_variable']}, Cutoff={found['cutoff_value']}")
    except Exception as e:
        logger.error(f"Error during LLM RDD check: {e}")

    # Estimand (ATE vs ATT); currently only logged, not part of interpret_query's result.
    try:
        usable_covariates = [col for col in columns
                             if col is not None and col not in (treatment_variable, outcome_variable)
                             and column_categories.get(col) not in ["text_or_other"]]
        logger.info(f"Usable covariates for graph: {usable_covariates}")
        estimand_prompt = ESTIMAND_PROMPT_TEMPLATE.format(query=query_text, dataset_description=dataset_description,
                                                          dataset_columns=usable_covariates,
                                                          treatment=treatment_variable, outcome=outcome_variable)
        estimand_result = _call_llm_for_var(llm, estimand_prompt, LLMEstimand)
        if estimand_result and estimand_result.estimand:
            found["estimand"] = "ate" if "ate" in estimand_result.estimand.strip().lower() else "att"
            logger.info(f"LLM identified estimand: {found['estimand']}")
        else:
            logger.warning("LLM returned no estimand.")
    except Exception as e:
        logger.error(f"Error during LLM estimand check: {e}")

    # DiD term (only accepted when it names a real column)
    try:
        did_term_prompt = DID_TERM_IDENTIFICATION_PROMPT_TEMPLATE.format(query=query_text, description=dataset_description,
                                                                         column_info=columns, time_variable=time_variable,
                                                                         group_variable=group_variable, column_types=column_categories)
        # NOTE(review): this reuses the LLMRDDVars schema; confirm it defines a `did_term`
        # field, otherwise the LLM cannot return one and this stays None.
        did_term_result = _call_llm_for_var(llm, did_term_prompt, LLMRDDVars)
        did_term = getattr(did_term_result, "did_term", None)
        found["did_term"] = did_term if did_term in columns else None
        logger.info(f"LLM identified DiD term: {found['did_term']}")
    except Exception as e:
        logger.error(f"Error during LLM DiD term check: {e}")

    return found

def compute_smd(df: pd.DataFrame, treat, covars_list) -> Dict[str, float]:
    """
    Compute the mean absolute standardized mean difference (SMD) between treated and control rows.

    Rows with ``df[treat] == 1`` form the treated group and ``df[treat] == 0`` the control
    group. For each numeric covariate, using population (ddof=0) standard deviations:
      * ATE-style SMD = (mean_t - mean_c) / sqrt((sd_t**2 + sd_c**2) / 2)
      * ATT-style SMD = (mean_t - mean_c) / sd_t
    A zero denominator yields 0.0 for that covariate.

    Args:
        df (pd.DataFrame): The dataset.
        treat (str): Name of the binary treatment column (0/1).
        covars_list (List[str]): Covariates to evaluate; when empty/None, every column
            except ``treat`` is used.

    Returns:
        Dict{str -> float}: {"ate": ..., "att": ...}, the mean |SMD| over covariates that
        could be computed; NaN when none could (e.g. no numeric covariates).
    """
    logger.info(f"Computing SMD for treatment variable '{treat}' with covariates: {covars_list}")
    df_t = df[df[treat] == 1]
    df_c = df[df[treat] == 0]

    # Without an explicit list, use every column except the treatment itself.
    covariates = covars_list if covars_list else [col for col in df.columns if col != treat]
    smd_ate = []
    smd_att = []

    for col in covariates:
        if col not in df.columns or not pd.api.types.is_numeric_dtype(df[col]):
            logger.warning(f"Skipping SMD for column '{col}': missing or non-numeric.")
            continue
        try:
            m_t, m_c = df_t[col].mean(), df_c[col].mean()
            s_t, s_c = df_t[col].std(ddof=0), df_c[col].std(ddof=0)
            pooled = np.sqrt((s_t**2 + s_c**2) / 2)

            ate_val = 0.0 if pooled == 0 else (m_t - m_c) / pooled
            att_val = 0.0 if s_t == 0 else (m_t - m_c) / s_t
        except Exception as e:
            logger.warning(f"SMD computation failed for column '{col}': {e}")
            continue
        smd_ate.append(ate_val)
        smd_att.append(att_val)

    return {"ate": _mean_abs_ignoring_nan(smd_ate), "att": _mean_abs_ignoring_nan(smd_att)}


def _mean_abs_ignoring_nan(values) -> float:
    """Mean of |values| skipping NaN; NaN (without a RuntimeWarning) when nothing remains."""
    arr = np.abs(np.asarray(values, dtype=float))
    arr = arr[~np.isnan(arr)]
    return float(arr.mean()) if arr.size else float("nan")



# --- Helper Functions for Hybrid Identification --- 
def _identify_variable_hybrid(role: str, query_hints: List[str], dataset_suggestions: List[str],
                               columns: List[str], column_categories: Dict[str, str],
                               prioritize_types: List[str], query_text: str,
                               dataset_description: Optional[str], llm: Optional["BaseChatModel"],
                               exclude_vars: Optional[List[str]] = None) -> Optional[str]:
    """
    Identify a treatment or outcome column by prompting the LLM, with a heuristic fallback.

    Heuristic candidates are the query hints followed by the dataset suggestions that
    name an available column (duplicates dropped). They are restricted to columns whose
    category appears in ``prioritize_types`` and ordered by that list: earlier types come
    first, and ties keep hint order. The first such candidate is the fallback.

    Args:
        role: variable type ("treatment" or "outcome")
        query_hints: hints from the query for this variable
        dataset_suggestions: dataset-specific suggestions for this variable
        columns: list of available columns in the dataset
        column_categories: mapping of column names to their categories
        prioritize_types: acceptable categories, most preferred first
        query_text: the original query text
        dataset_description: description of the dataset ('N/A' in the prompt when missing)
        llm: language model; when falsy only the heuristic is used
        exclude_vars: variables that must not be selected (e.g. the treatment for the outcome)

    Returns:
        str: name of the identified variable, or None if not found

    Raises:
        ValueError: if an LLM is given and ``role`` is neither "treatment" nor "outcome".
    """
    excluded = exclude_vars or []
    available_columns = [c for c in columns if c not in excluded]
    if not available_columns:
        return None

    # dict.fromkeys de-duplicates while keeping insertion order; a set would make the
    # fallback choice depend on string-hash randomisation between runs.
    candidates = list(dict.fromkeys(
        c for c in [*(query_hints or []), *(dataset_suggestions or [])] if c in available_columns
    ))

    # Keep type-compatible candidates, ordered by the caller's type preference (stable sort).
    type_rank = {}
    for rank, type_name in enumerate(prioritize_types):
        type_rank.setdefault(type_name, rank)
    plausible_candidates = sorted(
        (c for c in candidates if column_categories.get(c) in type_rank),
        key=lambda c: type_rank[column_categories.get(c)],
    )

    if llm:
        if role == "treatment":
            prompt_template = TREATMENT_VAR_IDENTIFICATION_PROMPT_TEMPLATE
        elif role == "outcome":
            prompt_template = OUTCOME_VAR_IDENTIFICATION_PROMPT_TEMPLATE
        else:
            raise ValueError(f"Unsupported role for LLM variable identification: {role}")

        prompt = prompt_template.format(query=query_text, description=dataset_description or 'N/A',
                                        column_info=available_columns)
        llm_choice = _call_llm_for_var(llm, prompt, LLMSelectedVariable)

        if llm_choice and llm_choice.variable_name in available_columns:
            logger.info(f"LLM selected {role}: {llm_choice.variable_name}")
            return llm_choice.variable_name
        fallback = plausible_candidates[0] if plausible_candidates else None
        logger.warning(f"LLM failed to select valid {role}. Falling back to: {fallback}")
        return fallback

    if plausible_candidates:
        logger.info(f"No LLM provided. Using first plausible {role}: {plausible_candidates[0]}")
        return plausible_candidates[0]

    logger.warning(f"No plausible candidates for {role}. Cannot identify variable.")
    return None


def _identify_covariates_hybrid(role, treatment_variable: Optional[str], outcome_variable: Optional[str],
                                columns: List[str], column_categories: Dict[str, str], query_hints: List[str],
                                query_text: str, dataset_description: Optional[str], llm: Optional["BaseChatModel"],
                                max_covariates: int = 10) -> List[str]:
    """
    Select covariates (or confounders), refining a programmatic list with the LLM when available.

    The programmatic list is every column except the treatment/outcome (and None) whose
    category is not "text_or_other". With an LLM, its suggestions are kept only if they are
    in that list, deduplicated and in LLM order. Otherwise, or if nothing valid comes back,
    the programmatic list is returned.

    Args:
        role: "covars" for general covariates, "confounders" for confounders; selects the prompt.
        treatment_variable: treatment column to exclude.
        outcome_variable: outcome column to exclude.
        columns: dataset column names.
        column_categories: mapping of column names to their categories.
        query_hints: accepted for interface symmetry; not used here.
        query_text: the original query text.
        dataset_description: dataset description ('N/A' in the prompt when missing).
        llm: language model; when falsy only the programmatic list is used.
        max_covariates: maximum number of names returned (default 10).

    Returns:
        List of at most ``max_covariates`` column names.
    """
    # 1. Initial programmatic filtering
    exclude_cols = [treatment_variable, outcome_variable]
    usable_covariates = [
        col for col in columns
        if col is not None and col not in exclude_cols
        and column_categories.get(col) not in ["text_or_other"]
    ]
    logger.debug(f"Initial usable covariates: {usable_covariates}")

    # 2. LLM refinement (if an LLM is available)
    if llm:
        logger.info("Using LLM to refine covariate list...")
        templates = {
            "covars": COVARIATES_IDENTIFICATION_PROMPT_TEMPLATE,
            "confounders": CONFOUNDER_IDENTIFICATION_PROMPT_TEMPLATE,
        }
        template = templates.get(role)
        if template is None:
            # Never send an empty prompt; fall through to the programmatic list instead.
            logger.warning(f"Unsupported covariate role '{role}'; skipping LLM refinement.")
        else:
            prompt = template.format(query=query_text, description=dataset_description or 'N/A',
                                     column_info=", ".join(usable_covariates),
                                     treatment=treatment_variable, outcome=outcome_variable)
            llm_selection = _call_llm_for_var(llm, prompt, LLMSelectedCovariates)

            if llm_selection and llm_selection.covariates:
                unknown = [c for c in llm_selection.covariates if c not in usable_covariates]
                if unknown:
                    logger.warning(f"LLM suggested covariates not found in initial usable list: {unknown}")
                valid_llm_covs = list(dict.fromkeys(c for c in llm_selection.covariates if c in usable_covariates))
                if valid_llm_covs:
                    logger.info(f"LLM refined covariates to: {valid_llm_covs}")
                    return valid_llm_covs[:max_covariates]
                logger.warning("LLM refinement failed or returned empty/invalid list. Falling back.")
            else:
                logger.warning("LLM refinement call failed or returned no covariates. Falling back.")

    # 3. Fallback to the programmatic list (capped)
    fallback = usable_covariates[:max_covariates]
    logger.info(f"Using programmatically determined covariates (capped at {max_covariates}): {fallback}")
    return fallback

def _create_identify_prompt(target: str, query: str, description: Optional[str], columns: List[str], 
                            categories: Dict[str,str], treatment: Optional[str], outcome: Optional[str]) -> str:
    """
    Build the LLM prompt for one design-specific identification question (IV, RDD or RCT).

    ``target`` is matched case-insensitively and checked in this order: it selects the
    IV template if it mentions "instrumental variable", the RDD template for
    "regression discontinuity", and the RCT template for "rct". Missing description,
    treatment or outcome are rendered as 'N/A'.

    Raises:
        ValueError: if ``target`` matches none of the supported questions.
    """
    column_lines = [f"- '{name}' (Type: {categories.get(name, 'Unknown')})" for name in columns]
    column_info = "\n".join(column_lines)

    wanted = target.lower()
    if "instrumental variable" in wanted:
        template = IV_IDENTIFICATION_PROMPT_TEMPLATE
    elif "regression discontinuity" in wanted:
        template = RDD_IDENTIFICATION_PROMPT_TEMPLATE
    elif "rct" in wanted:
        template = RCT_IDENTIFICATION_PROMPT_TEMPLATE
    else:
        logger.error(f"Unsupported target for _create_identify_prompt: {target}")
        raise ValueError(f"Unsupported target for specific identification prompt: {target}")

    return template.format(
        query=query,
        description=description or 'N/A',
        column_info=column_info,
        treatment=treatment or 'N/A',
        outcome=outcome or 'N/A',
    )

def _call_llm_for_var(llm: BaseChatModel, prompt: str, pydantic_model: BaseModel) -> Optional[BaseModel]:
    """
    Send ``prompt`` to ``llm`` as one user message and parse the reply into ``pydantic_model``.

    ``pydantic_model`` is the model class handed to ``llm.with_structured_output``.
    Returns the parsed object, or None on failure: parse/validation errors are logged
    briefly, and any other exception is logged with its traceback.
    """
    try:
        conversation = [HumanMessage(content=prompt)]
        return llm.with_structured_output(pydantic_model).invoke(conversation)
    except (OutputParserException, ValidationError) as err:
        logger.error(f"LLM call failed parsing/validation for {pydantic_model.__name__}: {err}")
        return None
    except Exception as err:
        logger.error(f"LLM call failed unexpectedly for {pydantic_model.__name__}: {err}", exc_info=True)
        return None