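"""Batch driver for the auto_causal agent.

Reads a metadata CSV of natural-language causal queries, runs
``run_causal_analysis`` on each referenced dataset, and writes the
formatted predictions to a JSON file.
"""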
import os, time, json, logging
import argparse
from typing import Dict, Any
import pandas as pd
from auto_causal.agent import run_causal_analysis

# Constants
RATE_LIMIT_SECONDS = 2

def run_caia(desc, question, data_path):
    """Run the auto_causal agent on a single query/dataset pair (data_path is a CSV path)."""
    return run_causal_analysis(query=question, dataset_path=data_path, dataset_description=desc)

def parse_args():
    parser = argparse.ArgumentParser(description="Run batch causal analysis.")
    parser.add_argument("--csv_path", type=str, required=True, help="CSV file with queries, descriptions, and file names.")
    parser.add_argument("--data_folder", type=str, required=True, help="Folder containing data CSVs.")
    parser.add_argument("--data_category", type=str, required=True, help="Dataset category (e.g., real, qrdata, synthetic).")
    parser.add_argument("--output_folder", type=str, required=True, help="Folder to save output.")
    parser.add_argument("--llm_name", type=str, required=True, help="Name of the LLM used.")
    return parser.parse_args()
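
# Example invocation (file and model names are illustrative):
#   python run_batch_analysis.py --csv_path meta/queries.csv --data_folder data/ \
#       --data_category real --output_folder outputs/ --llm_name gpt-4o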

def main():
    """Drive the batch loop: load metadata, run each query, and dump predictions."""
    args = parse_args()
    csv_meta = args.csv_path
    data_dir = args.data_folder
    # Output file name derived from category and LLM name (assumed naming convention).
    output_json = os.path.join(
        args.output_folder, f"{args.data_category}_{args.llm_name}_predictions.json"
    )
    os.environ["LLM_MODEL"] = args.llm_name
    print("[main] Starting batch processing…")

    if not os.path.exists(csv_meta):
        logging.error(f"Meta file not found: {csv_meta}")
        return

    meta_df = pd.read_csv(csv_meta)
    print(f"[main] Loaded metadata CSV with {len(meta_df)} rows.")

    results: Dict[int, Dict[str, Any]] = {}

    for idx, row in meta_df.iterrows():
        data_path = os.path.join(data_dir, str(row["data_files"]))
        print(f"\n[main] Row {idx+1}/{len(meta_df)} → Dataset: {data_path}")

        try:
            res = run_caia(
                desc=row["data_description"],
                question=row["natural_language_query"],
                data_path=data_path,
            )
            
            # Format result according to specified structure
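            # `res` is assumed to expose nested dicts: res['results']['results'] holds the
            # estimates and res['results']['variables'] the identified variables.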
            formatted_result = {
                "query": row["natural_language_query"],
                "method": row["method"],
                "answer": row["answer"],
                "dataset_description": row["data_description"],
                "dataset_path": data_path,
                "keywords": row.get("keywords", "Causality, Average treatment effect"),
                "final_result": {
                    "method": res['results']['results'].get("method_used"),
                    "causal_effect": res['results']['results'].get("effect_estimate"),
                    "standard_deviation": res['results']['results'].get("standard_error"),
                    "treatment_variable": res['results']['variables'].get("treatment_variable", None),
                    "outcome_variable": res['results']['variables'].get("outcome_variable", None),
                    "covariates": res['results']['variables'].get("covariates", []),
                    "instrument_variable": res['results']['variables'].get("instrument_variable", None),
                    "running_variable": res['results']['variables'].get("running_variable", None),
                    "temporal_variable": res['results']['variables'].get("time_variable", None),
                    "statistical_test_results": res.get("summary", ""),
                    "explanation_for_model_choice": res.get("explanation", ""),
                    "regression_equation": res.get("regression_equation", "")
                }
            }
            results[idx] = formatted_result
            print(f"[main] Formatted result for row {idx+1}:", formatted_result)
        except Exception as e:
            logging.error(f"[{idx+1}] Error: {e}")
            results[idx] = {"answer": str(e)}

        time.sleep(RATE_LIMIT_SECONDS)

    os.makedirs(os.path.dirname(output_json), exist_ok=True)
    with open(output_json, "w") as f:
        json.dump(results, f, indent=2)
    print(f"[main] Done. Predictions saved to {output_json}")

if __name__ == "__main__":
    main()