# causal-agent / app.py
# (Hugging Face Space file — original upload by FireShadow, "Initial clean
#  commit", revision 1721aea, 13.5 kB; viewer metadata converted to comments.)
import html
import json
import os
import sys
import time  # NOTE(review): not referenced in this file — kept for compatibility
from pathlib import Path

import gradio as gr
# Make the local repo importable (expects a folder named "causal-agent"
# next to this file at the repo root).
sys.path.append(str(Path(__file__).parent / "causal-agent"))
from auto_causal.agent import run_causal_analysis  # reads provider/model from env
# -------- LLM config (OpenAI only; key supplied via HF Space Secrets) --------
# setdefault: real deployments may override these via environment variables.
os.environ.setdefault("LLM_PROVIDER", "openai")
os.environ.setdefault("LLM_MODEL", "gpt-4o")
# Lazy import to avoid import-time errors if the OpenAI SDK or key is missing.
def _get_openai_client():
    """Validate LLM configuration and return an OpenAI SDK v1 client.

    Raises:
        RuntimeError: if LLM_PROVIDER is not "openai", OPENAI_API_KEY is
            unset, or the OpenAI SDK cannot be imported/initialized.
    """
    if os.getenv("LLM_PROVIDER", "openai") != "openai":
        raise RuntimeError("Only LLM_PROVIDER=openai is supported in this demo.")
    if not os.getenv("OPENAI_API_KEY"):
        raise RuntimeError("Missing OPENAI_API_KEY (set as a Space Secret).")
    try:
        # OpenAI SDK v1+
        from openai import OpenAI
        return OpenAI()
    except Exception as e:
        # Chain the original error so the real cause survives in tracebacks.
        raise RuntimeError(f"OpenAI SDK not available: {e}") from e
# -------- System prompt for the LLM summarizer (sent verbatim as the
# "system" message in _summarize_with_llm). Mojibake (mis-encoded em dash,
# ellipsis, arrow) from the original file has been repaired. --------
SYSTEM_PROMPT = """You are an expert in statistics and causal inference.
You will be given:
1) The original research question.
2) The analysis method used.
3) The estimated effects, confidence intervals, standard errors, and p-values for each treatment group compared to the control group.
4) A brief dataset description.
Your task is to produce a clear, concise, and non-technical summary that:
- Directly answers the research question.
- States whether the effect is statistically significant.
- Quantifies the effect size and explains what it means in practical terms (e.g., percentage point change).
- Mentions the method used in one sentence.
- Optionally ranks the treatment effects from largest to smallest if multiple treatments exist.
Formatting rules:
- Use bullet points or short paragraphs.
- Report effect sizes to two decimal places.
- Clearly state the interpretation in plain English without technical jargon.
Example Output Structure:
- **Method:** [Name of method + 1-line rationale]
- **Key Finding:** [Main answer to the research question]
- **Details:**
- [Treatment name]: +X.XX percentage points (95% CI: [L, U]), p < 0.001 — [Significance comment]
- …
- **Rank Order of Effects:** [Largest → Smallest]
"""
def _extract_minimal_payload(agent_result: dict) -> dict:
    """
    Extract the minimal, LLM-friendly payload from run_causal_analysis output.

    The agent output nests fields both at the top level and under
    results/results, so every field is looked up in several places and the
    first truthy hit wins. Falls back gracefully if any fields are missing.

    Args:
        agent_result: Raw dict returned by run_causal_analysis (may be None).

    Returns:
        Dict with keys "original_question", "method_used", "estimates"
        (effect_estimate / confidence_interval / standard_error / p_value),
        and "dataset_description".
    """

    def _dict_at(container: dict, key: str) -> dict:
        # container[key] only when it is a dict; anything else collapses to {}.
        value = container.get(key)
        return value if isinstance(value, dict) else {}

    res = agent_result or {}
    results = _dict_at(res, "results")
    inner = _dict_at(results, "results")
    dataset_analysis = _dict_at(results, "dataset_analysis")

    question = (
        results.get("original_query")
        or dataset_analysis.get("original_query")
        or res.get("query")
        or "N/A"
    )
    method = (
        inner.get("method_used")
        or res.get("method_used")
        or results.get("method_used")
        or "N/A"
    )
    # NOTE(review): unlike method_used, the numeric fields are never looked up
    # at the `results` level — preserved as-is; confirm against the agent schema.
    effect_estimate = inner.get("effect_estimate") or res.get("effect_estimate") or {}
    confidence_interval = inner.get("confidence_interval") or res.get("confidence_interval") or {}
    standard_error = inner.get("standard_error") or res.get("standard_error") or {}
    p_value = inner.get("p_value") or res.get("p_value") or {}
    dataset_desc = (
        results.get("dataset_description")
        or res.get("dataset_description")
        or "N/A"
    )

    return {
        "original_question": question,
        "method_used": method,
        "estimates": {
            "effect_estimate": effect_estimate,
            "confidence_interval": confidence_interval,
            "standard_error": standard_error,
            "p_value": p_value,
        },
        "dataset_description": dataset_desc,
    }
def _format_effects_md(effect_estimate: dict) -> str:
    """Render effect estimates as a Markdown bullet list.

    Values coercible to float are shown signed with four decimals; anything
    else is shown verbatim. A placeholder sentence is returned when the input
    is empty or not a dict.
    """
    if not isinstance(effect_estimate, dict) or not effect_estimate:
        return "_No effect estimates found._"

    def _bullet(name, value):
        try:
            return f"- **{name}**: {float(value):+.4f}"
        except Exception:
            # Non-numeric value (e.g. a label or None) — render it as-is.
            return f"- **{name}**: {value}"

    return "\n".join(_bullet(k, v) for k, v in effect_estimate.items())
def _summarize_with_llm(payload: dict) -> str:
    """
    Call OpenAI chat completions with SYSTEM_PROMPT as the system message and
    the JSON-serialized payload as the user message.

    Args:
        payload: Output of _extract_minimal_payload.

    Returns:
        The model's summary text, stripped; empty string if the model
        returned no content.

    Raises:
        RuntimeError: from _get_openai_client on misconfiguration.
        Exception: any OpenAI SDK/API error is propagated to the caller.
    """
    client = _get_openai_client()
    # Fallback matches the module-level setdefault ("gpt-4o") so behavior is
    # consistent even if LLM_MODEL is somehow unset at this point.
    model_name = os.getenv("LLM_MODEL", "gpt-4o")
    user_content = (
        "Summarize the following causal analysis results:\n\n"
        + json.dumps(payload, indent=2, ensure_ascii=False)
    )
    # Chat Completions for broad compatibility; temperature=0 for determinism.
    resp = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ],
        temperature=0,
    )
    # content may be None (e.g. refusal/tool-call responses); guard before strip.
    text = resp.choices[0].message.content or ""
    return text.strip()
def run_agent(query: str, csv_path: str, dataset_description: str):
    """
    Gradio generator callback: validate inputs, run the causal agent, then
    summarize with the LLM — yielding progressive status updates so the UI
    gives immediate feedback.

    Args:
        query: Natural-language causal question; a generic default is used
            when empty.
        csv_path: Filesystem path of the uploaded CSV (from gr.File).
        dataset_description: Optional free-text description of the dataset.

    Yields:
        4-tuples (method_html, effects_html, explanation_html, raw_results)
        matching the outputs wired to run_btn.click in the UI.
    """

    def _status_card(title: str, subtitle: str) -> str:
        # Shared markup for the transient "in progress" cards.
        return (
            "<div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>"
            f"<div style='font-size: 16px; margin-bottom: 5px;'>{title}</div>"
            f"<div style='font-size: 14px; color: #666;'>{subtitle}</div>"
            "</div>"
        )

    def _error_card(message: str, border: str = "#dc3545", color: str = "#dc3545") -> str:
        # Shared markup for inline error/warning banners.
        return (
            f"<div style='padding: 10px; border: 1px solid {border}; border-radius: 5px; "
            f"color: {color}; background-color: #333333;'>{message}</div>"
        )

    # Immediate feedback — show processing has started before any heavy work.
    processing_html = _status_card(
        "🔄 Analysis in Progress...",
        "This may take 1-2 minutes depending on dataset size",
    )
    yield (processing_html, processing_html, processing_html,
           {"status": "Processing started..."})

    # Input validation.
    if not os.getenv("OPENAI_API_KEY"):
        yield (_error_card("⚠️ Set a Space Secret named OPENAI_API_KEY"), "", "", {})
        return
    if not csv_path:
        yield (_error_card("Please upload a CSV dataset.", border="#ffc107", color="#856404"),
               "", "", {})
        return

    try:
        # Status: the causal analysis itself is running.
        analysis_html = _status_card(
            "📊 Running Causal Analysis...",
            "Analyzing dataset and selecting optimal method",
        )
        yield (analysis_html, analysis_html, analysis_html,
               {"status": "Running causal analysis..."})

        result = run_causal_analysis(
            query=(query or "What is the effect of treatment T on outcome Y controlling for X?").strip(),
            dataset_path=csv_path,
            dataset_description=(dataset_description or "").strip(),
        )

        # Status: LLM summarization step.
        llm_html = _status_card(
            "🤖 Generating Summary...",
            "Creating human-readable interpretation",
        )
        yield (llm_html, llm_html, llm_html,
               {"status": "Generating explanation...",
                "raw_analysis": result if isinstance(result, dict) else {}})
    except Exception as e:
        # Escape exception text: it may contain <, >, & that would break HTML.
        yield (_error_card(f"❌ Error: {html.escape(str(e))}"), "", "", {})
        return

    try:
        payload = _extract_minimal_payload(result if isinstance(result, dict) else {})
        method = payload.get("method_used", "N/A")

        method_html = (
            "<div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>"
            "<h3 style='margin: 0 0 10px 0; font-size: 18px;'>Selected Method</h3>"
            f"<p style='margin: 0; font-size: 16px;'>{html.escape(str(method))}</p>"
            "</div>"
        )

        effect_estimate = payload.get("estimates", {}).get("effect_estimate", {})
        if effect_estimate:
            # Whole-dict dump (the original per-key rendering was disabled).
            effects_html = (
                "<div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>"
                "<h3 style='margin: 0 0 10px 0; font-size: 18px;'>Effect Estimates</h3>"
                "<div style='margin: 8px 0; padding: 8px; border: 1px solid #eee; border-radius: 4px; background-color: #333333;'>"
                f"{html.escape(str(effect_estimate))}</div>"
                "</div>"
            )
        else:
            effects_html = (
                "<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; "
                "color: #666; font-style: italic; background-color: #333333;'>"
                "No effect estimates found</div>"
            )

        # Generate the plain-English explanation; degrade to a warning banner
        # (rather than failing the whole run) if the LLM call errors out.
        try:
            explanation = _summarize_with_llm(payload)
            explanation_html = (
                "<div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>"
                "<h3 style='margin: 0 0 15px 0; font-size: 18px;'>Detailed Explanation</h3>"
                f"<div style='line-height: 1.6; white-space: pre-wrap;'>{html.escape(explanation)}</div>"
                "</div>"
            )
        except Exception as e:
            explanation_html = _error_card(
                f"⚠️ LLM summary failed: {html.escape(str(e))}",
                border="#ffc107", color="#856404",
            )
    except Exception as e:
        yield (_error_card(f"❌ Failed to parse results: {html.escape(str(e))}"), "", "", {})
        return

    # Final result replaces all three panels plus the raw JSON accordion.
    yield (method_html, effects_html, explanation_html,
           result if isinstance(result, dict) else {})
# ---------------- Gradio UI ----------------
with gr.Blocks() as demo:
    gr.Markdown("# Causal Agent")
    gr.Markdown("Upload your dataset and ask causal questions in natural language. The system will automatically select the appropriate causal inference method and provide clear explanations.")
    with gr.Row():
        query = gr.Textbox(
            label="Your causal question (natural language)",
            placeholder="e.g., What is the effect of attending the program (T) on income (Y), controlling for education and age?",
            lines=2,
        )
    with gr.Row():
        csv_file = gr.File(
            label="Dataset (CSV)",
            file_types=[".csv"],
            type="filepath"  # hand run_agent a path string, not file bytes
        )
        dataset_description = gr.Textbox(
            label="Dataset description (optional)",
            placeholder="Brief schema, how it was collected, time period, units, treatment/outcome variables, etc.",
            lines=4,
        )
    run_btn = gr.Button("Run analysis", variant="primary")
    # Output panels: method + effects side by side, explanation full-width.
    with gr.Row():
        with gr.Column(scale=1):
            method_out = gr.HTML(label="Selected Method")
        with gr.Column(scale=1):
            effects_out = gr.HTML(label="Effect Estimates")
    with gr.Row():
        explanation_out = gr.HTML(label="Detailed Explanation")
    # Collapsible raw results section for debugging / advanced users.
    with gr.Accordion("Raw Results (Advanced)", open=False):
        raw_results = gr.JSON(label="Complete Analysis Output", show_label=False)
    # run_agent is a generator, so these outputs update progressively per yield.
    run_btn.click(
        fn=run_agent,
        inputs=[query, csv_file, dataset_description],
        outputs=[method_out, effects_out, explanation_out, raw_results],
        show_progress=True
    )
    gr.Markdown(
        """
        **Tips:**
        - Be specific about your treatment, outcome, and control variables
        - Include relevant context in the dataset description
        - The analysis may take 1-2 minutes for complex datasets
        """
    )
if __name__ == "__main__":
    # queue() is required for generator callbacks to stream updates.
    demo.queue().launch()