Spaces:
Running
Running
File size: 4,038 Bytes
1721aea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import os, re, io, time, json, logging, contextlib, textwrap
from typing import Dict, Any
import pandas as pd
import argparse
from auto_causal.agent import run_causal_analysis
# Constants
RATE_LIMIT_SECONDS = 2
def run_caia(desc, question, df):
    """Thin adapter around ``run_causal_analysis`` using this script's argument order.

    NOTE(review): despite the name, ``df`` is a filesystem *path* to the
    dataset CSV (it is forwarded as ``dataset_path``), not a DataFrame —
    confirm against ``run_causal_analysis``'s signature.
    """
    call_kwargs = {
        "query": question,
        "dataset_path": df,
        "dataset_description": desc,
    }
    return run_causal_analysis(**call_kwargs)
def parse_args():
    """Parse the command-line options for a batch causal-analysis run.

    All five options are required string flags; they are declared from a
    (flag, help) table since they share identical argparse settings.
    """
    parser = argparse.ArgumentParser(description="Run batch causal analysis.")
    option_table = [
        ("--csv_path", "CSV file with queries, descriptions, and file names."),
        ("--data_folder", "Folder containing data CSVs."),
        ("--data_category", "Dataset category (e.g., real, qrdata, synthetic)."),
        ("--output_folder", "Folder to save output."),
        ("--llm_name", "Name of the LLM used."),
    ]
    for flag, help_text in option_table:
        parser.add_argument(flag, type=str, required=True, help=help_text)
    return parser.parse_args()
def main():
    """Batch-run causal analysis over every row of a metadata CSV.

    Reads the metadata CSV named by ``--csv_path``; for each row, resolves
    the dataset file inside ``--data_folder``, runs the analysis, formats
    the result, and finally writes all results as one JSON file inside
    ``--output_folder``.
    """
    args = parse_args()
    # BUG FIX: the previous code read args.csv_meta / args.data_dir /
    # args.output_json, but parse_args() defines --csv_path, --data_folder
    # and --output_folder — the old attribute accesses raised AttributeError
    # before any work was done.
    csv_meta = args.csv_path
    data_dir = args.data_folder
    # Name the output file after category + model so different runs into the
    # same folder do not clobber each other (also uses --data_category,
    # which was previously accepted but never read).
    output_json = os.path.join(
        args.output_folder, f"{args.data_category}_{args.llm_name}_results.json"
    )
    os.environ["LLM_MODEL"] = args.llm_name

    print("[main] Starting batch processing…")
    if not os.path.exists(csv_meta):
        logging.error(f"Meta file not found: {csv_meta}")
        return
    meta_df = pd.read_csv(csv_meta)
    print(f"[main] Loaded metadata CSV with {len(meta_df)} rows.")

    results: Dict[int, Dict[str, Any]] = {}
    for idx, row in meta_df.iterrows():
        data_path = os.path.join(data_dir, str(row["data_files"]))
        print(f"\n[main] Row {idx+1}/{len(meta_df)} → Dataset: {data_path}")
        try:
            res = run_caia(
                desc=row["data_description"],
                question=row["natural_language_query"],
                df=data_path,
            )
            # Hoist the nested lookups once instead of repeating the
            # res['results'][...] chain for every field.
            estimates = res["results"]["results"]
            variables = res["results"]["variables"]
            # Format result according to the specified output structure.
            formatted_result = {
                "query": row["natural_language_query"],
                "method": row["method"],
                "answer": row["answer"],
                "dataset_description": row["data_description"],
                "dataset_path": data_path,
                "keywords": row.get("keywords", "Causality, Average treatment effect"),
                "final_result": {
                    "method": estimates.get("method_used"),
                    "causal_effect": estimates.get("effect_estimate"),
                    "standard_deviation": estimates.get("standard_error"),
                    "treatment_variable": variables.get("treatment_variable", None),
                    "outcome_variable": variables.get("outcome_variable", None),
                    "covariates": variables.get("covariates", []),
                    "instrument_variable": variables.get("instrument_variable", None),
                    "running_variable": variables.get("running_variable", None),
                    "temporal_variable": variables.get("time_variable", None),
                    "statistical_test_results": res.get("summary", ""),
                    "explanation_for_model_choice": res.get("explanation", ""),
                    "regression_equation": res.get("regression_equation", ""),
                },
            }
            results[idx] = formatted_result
            print(f"[main] Formatted result for row {idx+1}:", formatted_result)
        except Exception as e:
            # Broad catch is deliberate: one bad row must not abort the batch.
            # logging.exception keeps the traceback for debugging.
            logging.exception(f"[{idx+1}] Error: {e}")
            results[idx] = {"answer": str(e)}
        # Crude rate limiting between LLM-backed calls.
        time.sleep(RATE_LIMIT_SECONDS)

    os.makedirs(args.output_folder, exist_ok=True)
    with open(output_json, "w") as f:
        json.dump(results, f, indent=2)
    print(f"[main] Done. Predictions saved to {output_json}")
if __name__ == "__main__":
main()
|