"""
Dataset analyzer component for causal inference.

This module provides functionality to analyze datasets to detect characteristics
relevant for causal inference methods, including temporal structure, potential
instrumental variables, discontinuities, and variable relationships.
"""

import os
import re
import json
import logging
from typing import Dict, List, Any, Optional

import numpy as np
import pandas as pd
from langchain_core.language_models import BaseChatModel

from auto_causal.utils.llm_helpers import llm_identify_temporal_and_unit_vars

logger = logging.getLogger(__name__)

def _calculate_per_group_stats(df: pd.DataFrame, potential_treatments: List[str]) -> Dict[str, Dict]:
    """Calculates summary stats for numeric covariates grouped by potential binary treatments."""
    stats_dict = {}
    numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
    
    for treat_var in potential_treatments:
        if treat_var not in df.columns:
            logger.warning(f"Potential treatment '{treat_var}' not found in DataFrame columns.")
            continue
            
        # Ensure treatment is binary (0/1 or similar)
        unique_vals = df[treat_var].dropna().unique()
        if len(unique_vals) != 2:
            logger.info(f"Skipping stats for potential treatment '{treat_var}' as it is not binary ({len(unique_vals)} unique values).")
            continue
            
        # Attempt to map values to 0 and 1 if possible
        try:
            # Ensure boolean is converted to int
            if df[treat_var].dtype == 'bool':
                df[treat_var] = df[treat_var].astype(int)
                unique_vals = df[treat_var].dropna().unique()

            # Basic check that values are interpretable as 0/1
            if not set(unique_vals).issubset({0, 1}):
                logger.warning(f"Potential treatment '{treat_var}' has values {unique_vals}, not {{0, 1}}. Cannot calculate group stats reliably.")
                continue
        except Exception as e:
            logger.warning(f"Could not process potential treatment '{treat_var}' values ({unique_vals}): {e}")
            continue
            
        logger.info(f"Calculating group stats for treatment: '{treat_var}'")
        treat_stats = {'group_sizes': {}, 'covariate_stats': {}}
        
        try:
            grouped = df.groupby(treat_var)
            sizes = grouped.size()
            treat_stats['group_sizes']['treated'] = int(sizes.get(1, 0))
            treat_stats['group_sizes']['control'] = int(sizes.get(0, 0))
            
            if treat_stats['group_sizes']['treated'] == 0 or treat_stats['group_sizes']['control'] == 0:
                logger.warning(f"Treatment '{treat_var}' has zero samples in one group. Skipping covariate stats.")
                stats_dict[treat_var] = treat_stats
                continue

            # Mean and std of each numeric covariate per treatment group.
            # agg() yields columns (covariate, stat); unstack() pivots the group
            # index in as the innermost level, so keys are (covariate, stat, group).
            cov_stats = grouped[numeric_cols].agg(['mean', 'std']).unstack()

            for cov in numeric_cols:
                if cov == treat_var:
                    continue  # Skip treatment variable itself

                mean_control = cov_stats.get((cov, 'mean', 0), np.nan)
                std_control = cov_stats.get((cov, 'std', 0), np.nan)
                mean_treated = cov_stats.get((cov, 'mean', 1), np.nan)
                std_treated = cov_stats.get((cov, 'std', 1), np.nan)

                treat_stats['covariate_stats'][cov] = {
                    'mean_control': float(mean_control) if pd.notna(mean_control) else None,
                    'std_control': float(std_control) if pd.notna(std_control) else None,
                    'mean_treat': float(mean_treated) if pd.notna(mean_treated) else None,
                    'std_treat': float(std_treated) if pd.notna(std_treated) else None,
                }
            stats_dict[treat_var] = treat_stats
        except Exception as e:
            logger.error(f"Error calculating stats for treatment '{treat_var}': {e}", exc_info=True)
            # Store partial info if possible
            if treat_var not in stats_dict: 
                 stats_dict[treat_var] = {'error': str(e)}
            elif 'error' not in stats_dict[treat_var]:
                 stats_dict[treat_var]['error'] = str(e)
                 
    return stats_dict
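
# Illustrative sketch (hypothetical column names and values): for a binary
# treatment 'treat' and a numeric covariate 'age', the returned structure
# looks like
#   {'treat': {'group_sizes': {'treated': 120, 'control': 380},
#              'covariate_stats': {'age': {'mean_control': 41.2, 'std_control': 9.8,
#                                          'mean_treat': 39.5, 'std_treat': 10.1}}}}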

def analyze_dataset(
    dataset_path: str, 
    llm_client: Optional[BaseChatModel] = None,
    dataset_description: Optional[str] = None,
    original_query: Optional[str] = None
) -> Dict[str, Any]:
    """
    Analyze a dataset to identify important characteristics for causal inference.
    
    Args:
        dataset_path: Path to the dataset file
        llm_client: Optional LLM client for enhanced analysis
        dataset_description: Optional description of the dataset for context
        original_query: Optional user query, used as extra context when identifying temporal/unit variables
        
    Returns:
        Dict containing dataset analysis results:
            - dataset_info: Basic information about the dataset
            - columns: List of column names
            - potential_treatments: List of potential treatment variables (possibly LLM augmented)
            - potential_outcomes: List of potential outcome variables (possibly LLM augmented)
            - temporal_structure_detected: Whether temporal structure was detected
            - panel_data_detected: Whether panel data structure was detected
            - potential_instruments_detected: Whether potential instruments were detected
            - discontinuities_detected: Whether discontinuities were detected
            - potential_instruments: Names of candidate instrumental variables
            - discontinuities / temporal_structure / column_categories / column_nunique_counts: Detailed sub-analyses
            - sample_size / num_covariates_estimate: Basic size information
            - llm_augmentation: Status of LLM augmentation if used
    """
    llm_augmentation = "Not used" if not llm_client else "Initialized"
    
    # Check if file exists
    if not os.path.exists(dataset_path):
        logger.error(f"Dataset file not found at {dataset_path}")
        return {"error": f"Dataset file not found at {dataset_path}"}
    
    try:
        # Load the dataset
        df = pd.read_csv(dataset_path)
        
        # Basic dataset information
        sample_size = len(df)
        columns_list = df.columns.tolist()
        num_covariates = len(columns_list) - 2 # Rough estimate (total - T - Y)
        dataset_info = {
            "num_rows": sample_size,
            "num_columns": len(columns_list),
            "file_path": dataset_path,
            "file_name": os.path.basename(dataset_path)
        }
        
        # --- Detailed Analysis (Keep internal) ---
        column_types_detailed = {col: str(df[col].dtype) for col in df.columns}
        missing_values_detailed = df.isnull().sum().to_dict()
        column_categories_detailed = _categorize_columns(df)
        column_nunique_counts_detailed = {col: df[col].nunique() for col in df.columns} # Calculate nunique
        numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
        correlations_detailed = df[numeric_cols].corr() if numeric_cols else pd.DataFrame()
        temporal_structure_detailed = detect_temporal_structure(df, llm_client, dataset_description, original_query)
        
        # First, identify potential treatment and outcome variables
        potential_variables = _identify_potential_variables(
            df, 
            column_categories_detailed,
            llm_client=llm_client,
            dataset_description=dataset_description
        )
        
        if llm_client:
            llm_augmentation = "Used for variable identification"
        
        # Then use that info to help find potential instrumental variables
        potential_instruments_detailed = find_potential_instruments(
            df,
            llm_client=llm_client,
            potential_treatments=potential_variables.get("potential_treatments", []),
            potential_outcomes=potential_variables.get("potential_outcomes", []),
            dataset_description=dataset_description
        )
        
        # Other analyses
        discontinuities_detailed = detect_discontinuities(df)
        variable_relationships_detailed = assess_variable_relationships(df, correlations_detailed)
        
        # Calculate per-group stats for potential binary treatments
        potential_binary_treatments = [
            t for t in potential_variables["potential_treatments"]
            if column_categories_detailed.get(t) in ('binary', 'binary_categorical')
        ]
        per_group_stats = _calculate_per_group_stats(df.copy(), potential_binary_treatments)

        # --- Summarized Analysis (For Output) ---
        
        # Get boolean flags and essential lists
        has_temporal = temporal_structure_detailed.get("has_temporal_structure", False)
        is_panel = temporal_structure_detailed.get("is_panel_data", False)
        logger.info(f"Detailed potential instruments: {potential_instruments_detailed}")
        has_instruments = len(potential_instruments_detailed) > 0
        has_discontinuities = discontinuities_detailed.get("has_discontinuities", False)
        
        # --- Extract only instrument names for the final output ---
        potential_instrument_names = [
            inst_dict.get('variable') 
            for inst_dict in potential_instruments_detailed 
            if isinstance(inst_dict, dict) and 'variable' in inst_dict
        ]
        logger.info(f"Potential instrument names: {potential_instrument_names}")
        # --- Final Output Dictionary (Highly Summarized) ---
        return {
            "dataset_info": dataset_info, # Keep basic info
            "columns": columns_list,
            "potential_treatments": potential_variables["potential_treatments"],
            "potential_outcomes": potential_variables["potential_outcomes"],
            # Return concise flags instead of detailed dicts/lists
            "temporal_structure_detected": has_temporal, 
            "panel_data_detected": is_panel,
            "potential_instruments_detected": has_instruments,
            "discontinuities_detected": has_discontinuities,
            # Use the extracted list of names here
            "potential_instruments": potential_instrument_names,
            "discontinuities": discontinuities_detailed,
            "temporal_structure": temporal_structure_detailed,
            "column_categories": column_categories_detailed,
            "column_nunique_counts": column_nunique_counts_detailed, # Add nunique counts to output
            "sample_size": sample_size,
            "num_covariates_estimate": num_covariates,
            "llm_augmentation": llm_augmentation
        }
    
    except Exception as e:
        logger.error(f"Error analyzing dataset '{dataset_path}': {e}", exc_info=True)
        return {
            "error": f"Error analyzing dataset: {str(e)}",
            "llm_augmentation": llm_augmentation
        }


def _categorize_columns(df: pd.DataFrame) -> Dict[str, str]:
    """
    Categorize columns into types relevant for causal inference.
    
    Args:
        df: DataFrame to analyze
        
    Returns:
        Dict mapping column names to their types
    """
    result = {}
    
    for col in df.columns:
        # Check if column is numeric
        if pd.api.types.is_numeric_dtype(df[col]):
            # Count number of unique values
            n_unique = df[col].nunique()
            
            # Binary numeric variable
            if n_unique == 2:
                result[col] = "binary"
            # Likely categorical represented as numeric
            elif n_unique < 10:
                result[col] = "categorical_numeric"
            # Discrete numeric (integers)
            elif pd.api.types.is_integer_dtype(df[col]):
                result[col] = "discrete_numeric"
            # Continuous numeric
            else:
                result[col] = "continuous_numeric"
        
        # Check for datetime
        elif pd.api.types.is_datetime64_any_dtype(df[col]) or _is_date_string(df, col):
            result[col] = "datetime"
        
        # Check for categorical (explicit categorical dtype or low cardinality)
        elif isinstance(df[col].dtype, pd.CategoricalDtype) or df[col].nunique() < 20:
            if df[col].nunique() == 2:
                result[col] = "binary_categorical"
            else:
                result[col] = "categorical"
        
        # Must be text or other
        else:
            result[col] = "text_or_other"
    
    return result
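
# Illustrative sketch of the resulting mapping: a 0/1 column maps to "binary",
# an integer column with many distinct values to "discrete_numeric", a float
# column to "continuous_numeric", and a parseable date-string column to "datetime".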


def _is_date_string(df: pd.DataFrame, col: str) -> bool:
    """
    Check if a column contains date strings.
    
    Args:
        df: DataFrame to check
        col: Column name to check
        
    Returns:
        True if the column appears to contain date strings
    """
    # Only string columns can contain date strings
    if not pd.api.types.is_string_dtype(df[col]):
        return False
    
    # Check a sample of non-null values
    non_null = df[col].dropna()
    if non_null.empty:
        return False
    sample = non_null.sample(min(10, len(non_null))).tolist()
    
    try:
        for val in sample:
            pd.to_datetime(val)
        return True
    except (ValueError, TypeError):
        return False


def _identify_potential_variables(
    df: pd.DataFrame, 
    column_categories: Dict[str, str],
    llm_client: Optional[BaseChatModel] = None,
    dataset_description: Optional[str] = None
) -> Dict[str, List[str]]:
    """
    Identify potential treatment and outcome variables in the dataset, using LLM if available.
    Falls back to heuristic method if LLM fails or is not available.
    
    Args:
        df: DataFrame to analyze
        column_categories: Dictionary mapping column names to their types
        llm_client: Optional LLM client for enhanced identification
        dataset_description: Optional description of the dataset for context
        
    Returns:
        Dict with potential treatment and outcome variables
    """
    # Try LLM approach if client is provided
    if llm_client:
        try:
            logger.info("Using LLM to identify potential treatment and outcome variables")
            
            # Create a concise prompt with just column information
            columns_list = df.columns.tolist()
            column_types = {col: str(df[col].dtype) for col in columns_list}
            
            # Get binary columns for extra context
            binary_cols = [col for col in columns_list 
                          if pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() == 2]
            
            # Add dataset description if available
            description_text = f"\nDataset Description: {dataset_description}" if dataset_description else ""
            
            prompt = f"""
You are an expert causal inference data scientist. Identify potential treatment and outcome variables from this dataset.{description_text}

Dataset columns:
{columns_list}

Column types:
{column_types}

Binary columns (good treatment candidates):
{binary_cols}

Instructions:
1. Identify TREATMENT variables: interventions, treatments, programs, policies, or binary state changes.
   Look for binary variables or names with 'treatment', 'intervention', 'program', 'policy', etc.

2. Identify OUTCOME variables: results, effects, or responses to treatments.
   Look for numeric variables (especially non-binary) or names with 'outcome', 'result', 'effect', 'score', etc.

Return ONLY a valid JSON object with two lists: "potential_treatments" and "potential_outcomes".
Example: {{"potential_treatments": ["treatment_a", "program_b"], "potential_outcomes": ["result_score", "outcome_measure"]}}
"""
            
            # Call the LLM and parse the response
            response = llm_client.invoke(prompt)
            response_text = response.content if hasattr(response, 'content') else str(response)
            
            # Extract JSON from the response text
            json_match = re.search(r'{.*}', response_text, re.DOTALL)
            
            if json_match:
                result = json.loads(json_match.group(0))
                
                # Validate the response
                if (isinstance(result, dict) and 
                    "potential_treatments" in result and 
                    "potential_outcomes" in result and
                    isinstance(result["potential_treatments"], list) and
                    isinstance(result["potential_outcomes"], list)):
                    
                    # Ensure all suggestions are valid columns
                    valid_treatments = [col for col in result["potential_treatments"] if col in df.columns]
                    valid_outcomes = [col for col in result["potential_outcomes"] if col in df.columns]
                    
                    if valid_treatments and valid_outcomes:
                        logger.info(f"LLM identified {len(valid_treatments)} treatments and {len(valid_outcomes)} outcomes")
                        return {
                            "potential_treatments": valid_treatments,
                            "potential_outcomes": valid_outcomes
                        }
                    else:
                        logger.warning("LLM suggested invalid columns, falling back to heuristic method")
                else:
                    logger.warning("Invalid LLM response format, falling back to heuristic method")
            else:
                logger.warning("Could not extract JSON from LLM response, falling back to heuristic method")
                
        except Exception as e:
            logger.error(f"Error in LLM identification: {e}", exc_info=True)
            logger.info("Falling back to heuristic method")
    
    # Fallback to heuristic method
    logger.info("Using heuristic method to identify potential treatment and outcome variables")
    
    # Identify potential treatment variables
    potential_treatments = []
    
    # Look for binary variables (good treatment candidates)
    binary_cols = [col for col in df.columns 
                   if pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() == 2]
    
    # Look for variables with names suggesting treatment
    treatment_keywords = ['treatment', 'treat', 'intervention', 'program', 'policy', 
                         'exposed', 'assigned', 'received', 'participated']
    
    for col in df.columns:
        col_lower = col.lower()
        if any(keyword in col_lower for keyword in treatment_keywords):
            potential_treatments.append(col)
    
    # Add binary variables if we don't have enough candidates
    if len(potential_treatments) < 3:
        for col in binary_cols:
            if col not in potential_treatments:
                potential_treatments.append(col)
                if len(potential_treatments) >= 3:
                    break
    
    # Identify potential outcome variables
    potential_outcomes = []
    
    # Look for numeric variables that aren't binary
    numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
    non_binary_numeric = [col for col in numeric_cols if col not in binary_cols]
    
    # Look for variables with names suggesting outcomes
    outcome_keywords = ['outcome', 'result', 'effect', 'response', 'score', 'performance',
                       'achievement', 'success', 'failure', 'improvement']
    
    for col in df.columns:
        col_lower = col.lower()
        if any(keyword in col_lower for keyword in outcome_keywords):
            potential_outcomes.append(col)
    
    # Add numeric non-binary variables if we don't have enough candidates
    if len(potential_outcomes) < 3:
        for col in non_binary_numeric:
            if col not in potential_outcomes and col not in potential_treatments:
                potential_outcomes.append(col)
                if len(potential_outcomes) >= 3:
                    break
    
    return {
        "potential_treatments": potential_treatments,
        "potential_outcomes": potential_outcomes
    }
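
# Illustrative sketch (hypothetical columns): with columns
# ['treatment_group', 'outcome_score', 'age', 'income'] and no LLM client,
# the keyword heuristics would return something like
#   {'potential_treatments': ['treatment_group', ...],
#    'potential_outcomes': ['outcome_score', ...]}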


def detect_temporal_structure(
    df: pd.DataFrame, 
    llm_client: Optional[BaseChatModel] = None, 
    dataset_description: Optional[str] = None,
    original_query: Optional[str] = None
) -> Dict[str, Any]:
    """
    Detect temporal structure in the dataset, using LLM for enhanced identification.
    
    Args:
        df: DataFrame to analyze
        llm_client: Optional LLM client for enhanced identification
        dataset_description: Optional description of the dataset for context
        original_query: Optional user query, passed to the LLM for additional context
        
    Returns:
        Dict with information about temporal structure:
            - has_temporal_structure: Whether temporal structure exists
            - temporal_columns: List containing the primary time column (or all heuristic candidates if no primary was chosen)
            - is_panel_data: Whether data is in panel format
            - time_column: Primary time column identified for panel data
            - id_column: Primary unit ID column identified for panel data
            - time_periods: Number of time periods (if panel data)
            - units: Number of unique units (if panel data)
            - identification_method: How time/unit vars were identified ('LLM', 'Heuristic', 'None')
    """
    result = {
        "has_temporal_structure": False,
        "temporal_columns": [], # Will store primary time column or heuristic list
        "is_panel_data": False,
        "time_column": None,
        "id_column": None,
        "time_periods": None,
        "units": None,
        "identification_method": "None"
    }
    
    # --- Step 1: Heuristic identification ---
    heuristic_datetime_cols = []
    for col in df.columns:
        if pd.api.types.is_datetime64_any_dtype(df[col]):
            heuristic_datetime_cols.append(col)
        elif pd.api.types.is_string_dtype(df[col]):
            try:
                if pd.to_datetime(df[col], errors='coerce').notna().any():
                    heuristic_datetime_cols.append(col)
            except (ValueError, TypeError):
                pass  # Ignore conversion errors

    time_keywords = ['year', 'month', 'day', 'date', 'time', 'period', 'quarter', 'week']
    for col in df.columns:
        col_lower = col.lower()
        if any(keyword in col_lower for keyword in time_keywords) and col not in heuristic_datetime_cols:
            heuristic_datetime_cols.append(col)

    id_keywords = ['id', 'individual', 'person', 'unit', 'entity', 'firm', 'company', 'state', 'country']
    heuristic_potential_id_cols = []
    for col in df.columns:
        col_lower = col.lower()
        # Exclude columns already identified as time-related by heuristics
        if any(keyword in col_lower for keyword in id_keywords) and col not in heuristic_datetime_cols:
            heuristic_potential_id_cols.append(col)

    # --- Step 2: LLM-assisted identification --- 
    llm_identified_time_var = None
    llm_identified_unit_var = None
    dataset_summary = df.describe(include='all')

    if llm_client:
        logger.info("Attempting LLM-assisted identification of temporal/unit variables.")
        column_names = df.columns.tolist()
        column_dtypes_dict = {col: str(df[col].dtype) for col in column_names}
        
        try:
            llm_suggestions = llm_identify_temporal_and_unit_vars(
                column_names=column_names,
                column_dtypes=column_dtypes_dict,
                dataset_description=dataset_description if dataset_description else "No dataset description provided.",
                dataset_summary=dataset_summary,
                heuristic_time_candidates=heuristic_datetime_cols,
                heuristic_id_candidates=heuristic_potential_id_cols,
                query=original_query if original_query else "No query provided.",
                llm=llm_client
            )
            llm_identified_time_var = llm_suggestions.get("time_variable")
            llm_identified_unit_var = llm_suggestions.get("unit_variable")
            result["identification_method"] = "LLM"
            
            if not llm_identified_time_var and not llm_identified_unit_var:
                result["identification_method"] = "LLM_NoIdentification"
        except Exception as e:
            logger.warning(f"LLM call for temporal/unit vars failed: {e}. Falling back to heuristics.")
            result["identification_method"] = "Heuristic_LLM_Error"
    else:
        result["identification_method"] = "Heuristic_NoLLM"

    # --- Step 3: Combine LLM and Heuristic Results --- 
    final_time_var = None
    final_unit_var = None

    if llm_identified_time_var:
        final_time_var = llm_identified_time_var
        logger.info(f"Prioritizing LLM identified time variable: {final_time_var}")
    elif heuristic_datetime_cols:
        final_time_var = heuristic_datetime_cols[0] # Fallback to first heuristic time col
        logger.info(f"Using heuristic time variable: {final_time_var}")
    
    if llm_identified_unit_var:
        final_unit_var = llm_identified_unit_var
        logger.info(f"Prioritizing LLM identified unit variable: {final_unit_var}")
    elif heuristic_potential_id_cols:
        final_unit_var = heuristic_potential_id_cols[0] # Fallback to first heuristic ID col
        logger.info(f"Using heuristic unit variable: {final_unit_var}")

    # Update results based on final selections
    if final_time_var:
        result["has_temporal_structure"] = True
        result["temporal_columns"] = [final_time_var] # Store as a list with the primary time var
        result["time_column"] = final_time_var
    else: # If no time var found by LLM or heuristic, use original heuristic list for temporal_columns
        if heuristic_datetime_cols:
            result["has_temporal_structure"] = True
            result["temporal_columns"] = heuristic_datetime_cols
        # time_column remains None

    if final_unit_var:
        result["id_column"] = final_unit_var

    # --- Step 4: Update Panel Data Logic (based on final_time_var and final_unit_var) ---
    if final_time_var and final_unit_var:
        # Check if there are multiple time periods per unit using the identified variables
        try:
            # Ensure columns exist before groupby
            if final_time_var in df.columns and final_unit_var in df.columns:
                if df.groupby(final_unit_var)[final_time_var].nunique().mean() > 1.0:
                    result["is_panel_data"] = True
                    result["time_periods"] = df[final_time_var].nunique()
                    result["units"] = df[final_unit_var].nunique()
                    logger.info(f"Panel data detected: Time='{final_time_var}', Unit='{final_unit_var}', Periods={result['time_periods']}, Units={result['units']}")
                else:
                    logger.info("Not panel data: Each unit does not have multiple time periods.")
            else:
                logger.warning(f"Final time ('{final_time_var}') or unit ('{final_unit_var}') var not in DataFrame. Cannot confirm panel structure.")
        except Exception as e:
            logger.error(f"Error checking panel data structure with time='{final_time_var}', unit='{final_unit_var}': {e}")
            result["is_panel_data"] = False # Default to false on error
    else:
        logger.info("Not panel data: Missing either time or unit variable for panel structure.")

    logger.debug(f"Final temporal structure detection result: {result}")
    return result
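
# Illustrative sketch (hypothetical data): panel detection hinges on each unit
# appearing in multiple periods. With columns 'state' (unit) and 'year' (time),
#   df.groupby('state')['year'].nunique().mean() > 1.0
# holds for a balanced state-year panel, so is_panel_data would be True.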


def find_potential_instruments(
    df: pd.DataFrame, 
    llm_client: Optional[BaseChatModel] = None,
    potential_treatments: Optional[List[str]] = None,
    potential_outcomes: Optional[List[str]] = None,
    dataset_description: Optional[str] = None
) -> List[Dict[str, Any]]:
    """
    Find potential instrumental variables in the dataset, using LLM if available.
    Falls back to heuristic method if LLM fails or is not available.
    
    Args:
        df: DataFrame to analyze
        llm_client: Optional LLM client for enhanced identification
        potential_treatments: Optional list of potential treatment variables
        potential_outcomes: Optional list of potential outcome variables
        dataset_description: Optional description of the dataset for context
        
    Returns:
        List of potential instrumental variables with their properties
    """
    # Try LLM approach if client is provided
    if llm_client:
        try:
            logger.info("Using LLM to identify potential instrumental variables")
            
            # Create a concise prompt with just column information
            columns_list = df.columns.tolist()
            
            # Exclude known treatment and outcome variables from consideration
            excluded_columns = []
            if potential_treatments:
                excluded_columns.extend(potential_treatments)
            if potential_outcomes:
                excluded_columns.extend(potential_outcomes)
                
            # Filter columns to exclude treatments and outcomes
            candidate_columns = [col for col in columns_list if col not in excluded_columns]
            
            if not candidate_columns:
                logger.warning("No eligible columns for instrumental variables after filtering treatments and outcomes")
                return []
                
            # Get column types for context
            column_types = {col: str(df[col].dtype) for col in candidate_columns}
            
            # Add dataset description if available
            description_text = f"\nDataset Description: {dataset_description}" if dataset_description else ""
            
            prompt = f"""
You are an expert causal inference data scientist. Identify potential instrumental variables from this dataset.{description_text}

DEFINITION: Instrumental variables must:
1. Be correlated with the treatment variable (relevance)
2. Only affect the outcome through the treatment (exclusion restriction)
3. Not be correlated with unmeasured confounders (exogeneity)

Treatment variables: {potential_treatments if potential_treatments else "Unknown"}
Outcome variables: {potential_outcomes if potential_outcomes else "Unknown"}

Available columns (excluding treatments and outcomes):
{candidate_columns}

Column types:
{column_types}

Look for variables likely to be:
- Random assignments
- Policy changes
- Geographic or temporal variations
- Variables with names containing: 'instrument', 'iv', 'assigned', 'random', 'lottery', 'exogenous'

Return ONLY a JSON array of objects, each with "variable", "reason", and "data_type" fields.
Example: 
[
  {{"variable": "random_assignment", "reason": "Random assignment variable", "data_type": "int64"}},
  {{"variable": "distance_to_facility", "reason": "Geographic variation", "data_type": "float64"}}
]
"""
            
            # Call the LLM and parse the response
            response = llm_client.invoke(prompt)
            response_text = response.content if hasattr(response, 'content') else str(response)
            
            # Extract JSON from the response text
            json_match = re.search(r'\[\s*{.*}\s*\]', response_text, re.DOTALL)
            
            if json_match:
                result = json.loads(json_match.group(0))
                
                # Validate the response
                if isinstance(result, list) and len(result) > 0:
                    # Filter for valid entries
                    valid_instruments = []
                    for item in result:
                        if not isinstance(item, dict) or "variable" not in item:
                            continue
                            
                        if item["variable"] not in df.columns:
                            continue
                            
                        # Ensure all required fields are present
                        if "reason" not in item:
                            item["reason"] = "Identified by LLM"
                        if "data_type" not in item:
                            item["data_type"] = str(df[item["variable"]].dtype)
                            
                        valid_instruments.append(item)
                    
                    if valid_instruments:
                        logger.info(f"LLM identified {len(valid_instruments)} potential instrumental variables {valid_instruments}")
                        return valid_instruments
                    else:
                        logger.warning("No valid instruments found by LLM, falling back to heuristic method")
                else:
                    logger.warning("Invalid LLM response format, falling back to heuristic method")
            else:
                logger.warning("Could not extract JSON from LLM response, falling back to heuristic method")
                
        except Exception as e:
            logger.error(f"Error in LLM identification of instruments: {e}", exc_info=True)
            logger.info("Falling back to heuristic method")
    
    # Fallback to heuristic method
    logger.info("Using heuristic method to identify potential instrumental variables")
    potential_instruments = []
    
    # Look for variables with instrumental-related names
    instrument_keywords = ['instrument', 'iv', 'assigned', 'random', 'lottery', 'exogenous']
    
    for col in df.columns:
        # Skip treatment and outcome variables
        if potential_treatments and col in potential_treatments:
            continue
        if potential_outcomes and col in potential_outcomes:
            continue
            
        col_lower = col.lower()
        matched_keywords = [kw for kw in instrument_keywords if kw in col_lower]
        if matched_keywords:
            potential_instruments.append({
                "variable": col,
                "reason": f"Name contains instrument-related keyword(s): {', '.join(matched_keywords)}",
                "data_type": str(df[col].dtype)
            })
    
    return potential_instruments
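
# Illustrative sketch: a column named 'lottery_number' (hypothetical) would be
# returned by the heuristic as
#   [{'variable': 'lottery_number',
#     'reason': "Name contains instrument-related keyword(s): lottery",
#     'data_type': 'int64'}]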


def detect_discontinuities(df: pd.DataFrame) -> Dict[str, Any]:
    """
    Identify discontinuities in continuous variables (for RDD).
    
    Args:
        df: DataFrame to analyze
        
    Returns:
        Dict with information about detected discontinuities
    """
    discontinuities = []
    
    # For each numeric column, check for potential discontinuities
    numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
    
    for col in numeric_cols:
        # Skip columns with too many unique values
        if df[col].nunique() > 100:
            continue
        
        # Use sorted unique values so repeated observations don't produce
        # zero-width gaps that distort the gap statistics
        values = np.sort(df[col].dropna().unique())
        
        # Calculate gaps between consecutive values
        if len(values) > 10:
            gaps = np.diff(values)
            mean_gap = np.mean(gaps)
            std_gap = np.std(gaps)
            
            # Look for unusually large gaps (potential discontinuities)
            large_gaps = np.where(gaps > mean_gap + 2*std_gap)[0]
            
            if len(large_gaps) > 0:
                for idx in large_gaps:
                    cutpoint = (values[idx] + values[idx+1]) / 2
                    discontinuities.append({
                        "variable": col,
                        "cutpoint": float(cutpoint),
                        "gap_size": float(gaps[idx]),
                        "mean_gap": float(mean_gap)
                    })
    
    return {
        "has_discontinuities": len(discontinuities) > 0,
        "discontinuities": discontinuities
    }
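
# Illustrative sketch (synthetic data): a running variable with no mass between
# 40 and 60 yields one large gap in its sorted unique values, e.g.
#   rng = np.random.default_rng(0)
#   df = pd.DataFrame({"score": np.round(np.concatenate(
#       [rng.uniform(0, 40, 500), rng.uniform(60, 100, 500)]))})
#   detect_discontinuities(df)  # -> one gap flagged, cutpoint near 50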


def assess_variable_relationships(df: pd.DataFrame, corr_matrix: pd.DataFrame) -> Dict[str, Any]:
    """
    Assess relationships between variables in the dataset.
    
    Args:
        df: DataFrame to analyze
        corr_matrix: Precomputed correlation matrix for numeric columns
        
    Returns:
        Dict with information about variable relationships:
            - strongly_correlated_pairs: Pairs of strongly correlated variables
            - potential_confounders: Variables that might be confounders
    """
    result = {"strongly_correlated_pairs": [], "potential_confounders": []}
    
    numeric_cols = corr_matrix.columns.tolist()
    if len(numeric_cols) < 2:
        return result
    
    # Use the precomputed correlation matrix
    corr_matrix_abs = corr_matrix.abs()
    
    # Find strongly correlated variable pairs
    for i in range(len(numeric_cols)):
        for j in range(i+1, len(numeric_cols)):
            if corr_matrix_abs.iloc[i, j] > 0.7:  # Absolute correlation threshold
                result["strongly_correlated_pairs"].append({
                    "variables": [numeric_cols[i], numeric_cols[j]],
                    "correlation": float(corr_matrix.iloc[i, j])
                })
    
    # Identify potential confounders (variables correlated with multiple others)
    confounder_counts = {col: 0 for col in numeric_cols}
    
    for pair in result["strongly_correlated_pairs"]:
        confounder_counts[pair["variables"][0]] += 1
        confounder_counts[pair["variables"][1]] += 1
    
    # Variables correlated with multiple others are potential confounders
    for col, count in confounder_counts.items():
        if count >= 2:
            result["potential_confounders"].append({"variable": col, "num_correlations": count})
    
    return result
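

if __name__ == "__main__":
    # Minimal smoke-test sketch (no LLM client): writes a small synthetic
    # dataset to a temporary CSV and runs the heuristic-only analysis. The
    # column names below are hypothetical, chosen to trigger the keyword
    # heuristics above (treatment/outcome/time/unit detection).
    import tempfile

    logging.basicConfig(level=logging.INFO)
    rng = np.random.default_rng(42)
    n = 200
    demo_df = pd.DataFrame({
        "unit_id": np.repeat(np.arange(n // 4), 4),
        "year": np.tile([2018, 2019, 2020, 2021], n // 4),
        "treatment": rng.integers(0, 2, n),
        "random_assignment": rng.integers(0, 2, n),
        "age": rng.normal(40, 10, n),
        "outcome_score": rng.normal(50, 15, n),
    })
    tmp_path = os.path.join(tempfile.gettempdir(), "auto_causal_demo.csv")
    demo_df.to_csv(tmp_path, index=False)

    analysis = analyze_dataset(tmp_path)
    print(json.dumps(analysis, indent=2, default=str))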