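"""Gradio Space for the causal-agent demo.

Takes a natural-language causal question and an uploaded CSV, runs
run_causal_analysis from the causal-agent package, and asks an OpenAI model
(configured via LLM_PROVIDER / LLM_MODEL / OPENAI_API_KEY) to turn the raw
estimates into a plain-English summary.
"""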
import os
import sys
import json
import time
from pathlib import Path

import gradio as gr

# Make your repo importable (expecting a folder named causal-agent at repo root)
sys.path.append(str(Path(__file__).parent / "causal-agent"))

from auto_causal.agent import run_causal_analysis  # uses env for provider/model

# -------- LLM config (OpenAI only; key via HF Secrets) --------
os.environ.setdefault("LLM_PROVIDER", "openai")
os.environ.setdefault("LLM_MODEL", "gpt-4o")
# Lazy import to avoid import-time errors if key missing
def _get_openai_client():
    if os.getenv("LLM_PROVIDER", "openai") != "openai":
        raise RuntimeError("Only LLM_PROVIDER=openai is supported in this demo.")
    if not os.getenv("OPENAI_API_KEY"):
        raise RuntimeError("Missing OPENAI_API_KEY (set as a Space Secret).")
    try:
        # OpenAI SDK v1+
        from openai import OpenAI
        return OpenAI()
    except Exception as e:
        raise RuntimeError(f"OpenAI SDK not available: {e}")
# -------- System prompt for the LLM summary --------
SYSTEM_PROMPT = """You are an expert in statistics and causal inference.

You will be given:
1) The original research question.
2) The analysis method used.
3) The estimated effects, confidence intervals, standard errors, and p-values for each treatment group compared to the control group.
4) A brief dataset description.

Your task is to produce a clear, concise, and non-technical summary that:
- Directly answers the research question.
- States whether the effect is statistically significant.
- Quantifies the effect size and explains what it means in practical terms (e.g., percentage point change).
- Mentions the method used in one sentence.
- Optionally ranks the treatment effects from largest to smallest if multiple treatments exist.

Formatting rules:
- Use bullet points or short paragraphs.
- Report effect sizes to two decimal places.
- Clearly state the interpretation in plain English without technical jargon.

Example Output Structure:
- **Method:** [Name of method + 1-line rationale]
- **Key Finding:** [Main answer to the research question]
- **Details:**
  - [Treatment name]: +X.XX percentage points (95% CI: [L, U]), p < 0.001 → [Significance comment]
  - …
- **Rank Order of Effects:** [Largest → Smallest]
"""
def _extract_minimal_payload(agent_result: dict) -> dict:
    """
    Extract the minimal, LLM-friendly payload from run_causal_analysis output.
    Falls back gracefully if any fields are missing.
    """
    # Try both top-level and nested locations (the agent's JSON has shown both patterns)
    res = agent_result or {}
    results = res.get("results", {}) if isinstance(res.get("results"), dict) else {}
    inner = results.get("results", {}) if isinstance(results.get("results"), dict) else {}
    vars_ = results.get("variables", {}) if isinstance(results.get("variables"), dict) else {}
    dataset_analysis = results.get("dataset_analysis", {}) if isinstance(results.get("dataset_analysis"), dict) else {}

    # Pull best-available fields
    question = (
        results.get("original_query")
        or dataset_analysis.get("original_query")
        or res.get("query")
        or "N/A"
    )
    method = (
        inner.get("method_used")
        or res.get("method_used")
        or results.get("method_used")
        or "N/A"
    )
    effect_estimate = (
        inner.get("effect_estimate")
        or res.get("effect_estimate")
        or {}
    )
    confidence_interval = (
        inner.get("confidence_interval")
        or res.get("confidence_interval")
        or {}
    )
    standard_error = (
        inner.get("standard_error")
        or res.get("standard_error")
        or {}
    )
    p_value = (
        inner.get("p_value")
        or res.get("p_value")
        or {}
    )
    dataset_desc = (
        results.get("dataset_description")
        or res.get("dataset_description")
        or "N/A"
    )

    return {
        "original_question": question,
        "method_used": method,
        "estimates": {
            "effect_estimate": effect_estimate,
            "confidence_interval": confidence_interval,
            "standard_error": standard_error,
            "p_value": p_value,
        },
        "dataset_description": dataset_desc,
    }
def _format_effects_md(effect_estimate: dict) -> str:
    """
    Minimal human-readable view of effect estimates for display.
    """
    if not effect_estimate or not isinstance(effect_estimate, dict):
        return "_No effect estimates found._"
    # Render as a bullet list
    lines = []
    for k, v in effect_estimate.items():
        try:
            lines.append(f"- **{k}**: {float(v):+.4f}")
        except Exception:
            lines.append(f"- **{k}**: {v}")
    return "\n".join(lines)
def _summarize_with_llm(payload: dict) -> str:
    """
    Calls OpenAI with SYSTEM_PROMPT and the JSON payload as the user message.
    Returns the model's text, or raises on error.
    """
    client = _get_openai_client()
    model_name = os.getenv("LLM_MODEL", "gpt-4o")
    user_content = (
        "Summarize the following causal analysis results:\n\n"
        + json.dumps(payload, indent=2, ensure_ascii=False)
    )
    # Use Chat Completions for broad compatibility
    resp = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ],
        temperature=0,
    )
    return (resp.choices[0].message.content or "").strip()
def run_agent(query: str, csv_path: str, dataset_description: str):
    """
    Generator that yields progressive updates so the UI shows feedback immediately.
    """
    # Immediate feedback - show that processing has started
    processing_html = """
    <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
        <div style='font-size: 16px; margin-bottom: 5px;'>🔄 Analysis in Progress...</div>
        <div style='font-size: 14px; color: #666;'>This may take 1-2 minutes depending on dataset size</div>
    </div>
    """
    yield (
        processing_html,                      # method_out
        processing_html,                      # effects_out
        processing_html,                      # explanation_out
        {"status": "Processing started..."},  # raw_results
    )

    # Input validation
    if not os.getenv("OPENAI_API_KEY"):
        error_html = "<div style='padding: 10px; border: 1px solid #dc3545; border-radius: 5px; color: #dc3545; background-color: #333333;'>⚠️ Set a Space Secret named OPENAI_API_KEY</div>"
        yield (error_html, "", "", {})
        return
    if not csv_path:
        error_html = "<div style='padding: 10px; border: 1px solid #ffc107; border-radius: 5px; color: #856404; background-color: #333333;'>Please upload a CSV dataset.</div>"
        yield (error_html, "", "", {})
        return

    try:
        # Update status to show the causal analysis is running
        analysis_html = """
        <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
            <div style='font-size: 16px; margin-bottom: 5px;'>🔍 Running Causal Analysis...</div>
            <div style='font-size: 14px; color: #666;'>Analyzing dataset and selecting optimal method</div>
        </div>
        """
        yield (
            analysis_html,
            analysis_html,
            analysis_html,
            {"status": "Running causal analysis..."},
        )

        result = run_causal_analysis(
            query=(query or "What is the effect of treatment T on outcome Y controlling for X?").strip(),
            dataset_path=csv_path,
            dataset_description=(dataset_description or "").strip(),
        )

        # Update to show the LLM summarization step
        llm_html = """
        <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
            <div style='font-size: 16px; margin-bottom: 5px;'>🤖 Generating Summary...</div>
            <div style='font-size: 14px; color: #666;'>Creating human-readable interpretation</div>
        </div>
        """
        yield (
            llm_html,
            llm_html,
            llm_html,
            {"status": "Generating explanation...", "raw_analysis": result if isinstance(result, dict) else {}},
        )
    except Exception as e:
        error_html = f"<div style='padding: 10px; border: 1px solid #dc3545; border-radius: 5px; color: #dc3545; background-color: #333333;'>❌ Error: {e}</div>"
        yield (error_html, "", "", {})
        return

    try:
        payload = _extract_minimal_payload(result if isinstance(result, dict) else {})
        method = payload.get("method_used", "N/A")

        # Format the method output with simple styling
        method_html = f"""
        <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
            <h3 style='margin: 0 0 10px 0; font-size: 18px;'>Selected Method</h3>
            <p style='margin: 0; font-size: 16px;'>{method}</p>
        </div>
        """

        # Format effect estimates with simple styling
        effect_estimate = payload.get("estimates", {}).get("effect_estimate", {})
        if effect_estimate:
            effects_html = "<div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>"
            effects_html += "<h3 style='margin: 0 0 10px 0; font-size: 18px;'>Effect Estimates</h3>"
            if isinstance(effect_estimate, dict):
                # One row per treatment; fall back to the raw value if it is not numeric
                for k, v in effect_estimate.items():
                    try:
                        value = f"{float(v):+.4f}"
                    except (TypeError, ValueError):
                        value = str(v)
                    effects_html += f"<div style='margin: 8px 0; padding: 8px; border: 1px solid #eee; border-radius: 4px; background-color: #333333;'><strong>{k}:</strong> <span style='font-size: 16px;'>{value}</span></div>"
            else:
                effects_html += f"<div style='margin: 8px 0; padding: 8px; border: 1px solid #eee; border-radius: 4px; background-color: #333333;'>{effect_estimate}</div>"
            effects_html += "</div>"
        else:
            effects_html = "<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; color: #666; font-style: italic; background-color: #333333;'>No effect estimates found</div>"

        # Generate the explanation and format it
        try:
            explanation = _summarize_with_llm(payload)
            explanation_html = f"""
            <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
                <h3 style='margin: 0 0 15px 0; font-size: 18px;'>Detailed Explanation</h3>
                <div style='line-height: 1.6; white-space: pre-wrap;'>{explanation}</div>
            </div>
            """
        except Exception as e:
            explanation_html = f"<div style='padding: 10px; border: 1px solid #ffc107; border-radius: 5px; color: #856404; background-color: #333333;'>⚠️ LLM summary failed: {e}</div>"
    except Exception as e:
        error_html = f"<div style='padding: 10px; border: 1px solid #dc3545; border-radius: 5px; color: #dc3545; background-color: #333333;'>❌ Failed to parse results: {e}</div>"
        yield (error_html, "", "", {})
        return

    # Final result
    yield (method_html, effects_html, explanation_html, result if isinstance(result, dict) else {})
with gr.Blocks() as demo:
    gr.Markdown("# Causal Agent")
    gr.Markdown("Upload your dataset and ask causal questions in natural language. The system will automatically select the appropriate causal inference method and provide clear explanations.")

    with gr.Row():
        query = gr.Textbox(
            label="Your causal question (natural language)",
            placeholder="e.g., What is the effect of attending the program (T) on income (Y), controlling for education and age?",
            lines=2,
        )
    with gr.Row():
        csv_file = gr.File(
            label="Dataset (CSV)",
            file_types=[".csv"],
            type="filepath",
        )
        dataset_description = gr.Textbox(
            label="Dataset description (optional)",
            placeholder="Brief schema, how it was collected, time period, units, treatment/outcome variables, etc.",
            lines=4,
        )

    run_btn = gr.Button("Run analysis", variant="primary")

    with gr.Row():
        with gr.Column(scale=1):
            method_out = gr.HTML(label="Selected Method")
        with gr.Column(scale=1):
            effects_out = gr.HTML(label="Effect Estimates")
    with gr.Row():
        explanation_out = gr.HTML(label="Detailed Explanation")

    # Collapsible raw results section
    with gr.Accordion("Raw Results (Advanced)", open=False):
        raw_results = gr.JSON(label="Complete Analysis Output", show_label=False)

    run_btn.click(
        fn=run_agent,
        inputs=[query, csv_file, dataset_description],
        outputs=[method_out, effects_out, explanation_out, raw_results],
        show_progress=True,
    )

    gr.Markdown(
        """
        **Tips:**
        - Be specific about your treatment, outcome, and control variables
        - Include relevant context in the dataset description
        - The analysis may take 1-2 minutes for complex datasets
        """
    )

if __name__ == "__main__":
    demo.queue().launch()