Commit f5c8ef7 · 1 parent d56e62f · "added quick examples"
app.py CHANGED
@@ -4,49 +4,45 @@ import json
 from pathlib import Path
 import gradio as gr
 import time
+import smtplib
+from email.message import EmailMessage

 # Make your repo importable (expecting a folder named causal-agent at repo root)
 sys.path.append(str(Path(__file__).parent / "causal-agent"))
-
+EXAMPLE_CSV_PATH = os.getenv(
+    "EXAMPLE_CSV_PATH",
+    str(Path(__file__).parent / "data" / "synthetic_data")
+)
 from auto_causal.agent import run_causal_analysis  # uses env for provider/model

 # -------- LLM config (OpenAI only; key via HF Secrets) --------
 os.environ.setdefault("LLM_PROVIDER", "openai")
 os.environ.setdefault("LLM_MODEL", "gpt-4o")

-# Lazy import to avoid import-time errors if key missing
 def _get_openai_client():
     if os.getenv("LLM_PROVIDER", "openai") != "openai":
         raise RuntimeError("Only LLM_PROVIDER=openai is supported in this demo.")
     if not os.getenv("OPENAI_API_KEY"):
         raise RuntimeError("Missing OPENAI_API_KEY (set as a Space Secret).")
-
-
-        from openai import OpenAI
-        return OpenAI()
-    except Exception as e:
-        raise RuntimeError(f"OpenAI SDK not available: {e}")
+    from openai import OpenAI
+    return OpenAI()

-# -------- System prompt you asked for (verbatim) --------
 SYSTEM_PROMPT = """You are an expert in statistics and causal inference.
 You will be given:
 1) The original research question.
 2) The analysis method used.
 3) The estimated effects, confidence intervals, standard errors, and p-values for each treatment group compared to the control group.
 4) A brief dataset description.
-
 Your task is to produce a clear, concise, and non-technical summary that:
 - Directly answers the research question.
 - States whether the effect is statistically significant.
 - Quantifies the effect size and explains what it means in practical terms (e.g., percentage point change).
 - Mentions the method used in one sentence.
 - Optionally ranks the treatment effects from largest to smallest if multiple treatments exist.
-
 Formatting rules:
 - Use bullet points or short paragraphs.
 - Report effect sizes to two decimal places.
 - Clearly state the interpretation in plain English without technical jargon.
-
 Example Output Structure:
 - **Method:** [Name of method + 1-line rationale]
 - **Key Finding:** [Main answer to the research question]
@@ -57,18 +53,11 @@ Example Output Structure:
 """

 def _extract_minimal_payload(agent_result: dict) -> dict:
-    """
-    Extract the minimal, LLM-friendly payload from run_causal_analysis output.
-    Falls back gracefully if any fields are missing.
-    """
-    # Try both top-level and nested (your JSON showed both patterns)
     res = agent_result or {}
     results = res.get("results", {}) if isinstance(res.get("results"), dict) else {}
     inner = results.get("results", {}) if isinstance(results.get("results"), dict) else {}
-    vars_ = results.get("variables", {}) if isinstance(results.get("variables"), dict) else {}
     dataset_analysis = results.get("dataset_analysis", {}) if isinstance(results.get("dataset_analysis"), dict) else {}

-    # Pull best-available fields
     question = (
         results.get("original_query")
         or dataset_analysis.get("original_query")
@@ -82,32 +71,11 @@ def _extract_minimal_payload(agent_result: dict) -> dict:
         or "N/A"
     )

-    effect_estimate = (
-        inner.get("effect_estimate")
-        or res.get("effect_estimate")
-        or {}
-    )
-    confidence_interval = (
-        inner.get("confidence_interval")
-        or res.get("confidence_interval")
-        or {}
-    )
-    standard_error = (
-        inner.get("standard_error")
-        or res.get("standard_error")
-        or {}
-    )
-    p_value = (
-        inner.get("p_value")
-        or res.get("p_value")
-        or {}
-    )
-
-    dataset_desc = (
-        results.get("dataset_description")
-        or res.get("dataset_description")
-        or "N/A"
-    )
+    effect_estimate = inner.get("effect_estimate") or res.get("effect_estimate") or {}
+    confidence_interval = inner.get("confidence_interval") or res.get("confidence_interval") or {}
+    standard_error = inner.get("standard_error") or res.get("standard_error") or {}
+    p_value = inner.get("p_value") or res.get("p_value") or {}
+    dataset_desc = results.get("dataset_description") or res.get("dataset_description") or "N/A"

     return {
         "original_question": question,
@@ -121,168 +89,159 @@ def _extract_minimal_payload(agent_result: dict) -> dict:
         "dataset_description": dataset_desc,
     }

-def _format_effects_md(effect_estimate: dict) -> str:
-    """
-    Minimal human-readable view of effect estimates for display.
-    """
-    if not effect_estimate or not isinstance(effect_estimate, dict):
-        return "_No effect estimates found._"
-    # Render as bullet list
-    lines = []
-    for k, v in effect_estimate.items():
-        try:
-            lines.append(f"- **{k}**: {float(v):+.4f}")
-        except Exception:
-            lines.append(f"- **{k}**: {v}")
-    return "\n".join(lines)
-
 def _summarize_with_llm(payload: dict) -> str:
-    """
-    Calls OpenAI with the provided SYSTEM_PROMPT and the JSON payload as the user message.
-    Returns the model's text, or raises on error.
-    """
     client = _get_openai_client()
     model_name = os.getenv("LLM_MODEL", "gpt-4o-mini")

-    user_content = (
-        "Summarize the following causal analysis results:\n\n"
-        + json.dumps(payload, indent=2, ensure_ascii=False)
-    )
-
-    # Use Chat Completions for broad compatibility
+    user_content = "Summarize the following causal analysis results:\n\n" + json.dumps(payload, indent=2, ensure_ascii=False)
     resp = client.chat.completions.create(
         model=model_name,
-        messages=[
-
-            {"role": "user", "content": user_content},
-        ],
+        messages=[{"role": "system", "content": SYSTEM_PROMPT},
+                  {"role": "user", "content": user_content}],
         temperature=0
     )
-
-    return text
+    return resp.choices[0].message.content.strip()

-def
-    """
-
-
-
-    processing_html = """
-    <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
-        <div style='font-size: 16px; margin-bottom: 5px;'>🔄 Analysis in Progress...</div>
-        <div style='font-size: 14px; color: #666;'>This may take 1-2 minutes depending on dataset size</div>
+def _html_panel(title, body_html):
+    return f"""
+    <div style='padding:15px;border:1px solid #ddd;border-radius:8px;margin:5px 0;background-color:#333333;'>
+      <h3 style='margin:0 0 10px 0;font-size:18px;'>{title}</h3>
+      <div style='line-height:1.6;'>{body_html}</div>
     </div>
     """
-
-    yield (
-        processing_html,  # method_out
-        processing_html,  # effects_out
-        processing_html,  # explanation_out
-        {"status": "Processing started..."}  # raw_results
-    )

-
+def _warn_html(text):
+    return f"<div style='padding:10px;border:1px solid #ffc107;border-radius:5px;color:#ffc107;background-color:#333333;'>⚠️ {text}</div>"
+
+def _err_html(text):
+    return f"<div style='padding:10px;border:1px solid #dc3545;border-radius:5px;color:#dc3545;background-color:#333333;'>❌ {text}</div>"
+
+def _ok_html(text):
+    return f"<div style='padding:10px;border:1px solid #2ea043;border-radius:5px;color:#2ea043;background-color:#333333;'>✅ {text}</div>"
+
+# --- Email support ---
+def send_email(recipient: str, subject: str, body_text: str, attachment_name: str = None, attachment_json: dict = None) -> str:
+    """Returns '' on success, or error message string."""
+    host = os.getenv("SMTP_HOST")
+    port = int(os.getenv("SMTP_PORT", "587"))
+    user = os.getenv("SMTP_USER")
+    pwd = os.getenv("SMTP_PASS")
+    from_addr = os.getenv("EMAIL_FROM")
+
+    if not all([host, port, user, pwd, from_addr]):
+        return "Email is not configured (set SMTP_HOST, SMTP_PORT, SMTP_USER, SMTP_PASS, EMAIL_FROM)."
+
+    try:
+        msg = EmailMessage()
+        msg["From"] = from_addr
+        msg["To"] = recipient
+        msg["Subject"] = subject
+        msg.set_content(body_text)
+
+        if attachment_json is not None and attachment_name:
+            payload = json.dumps(attachment_json, indent=2).encode("utf-8")
+            msg.add_attachment(payload, maintype="application", subtype="json", filename=attachment_name)
+
+        with smtplib.SMTP(host, port, timeout=30) as s:
+            s.starttls()
+            s.login(user, pwd)
+            s.send_message(msg)
+        return ""
+    except Exception as e:
+        return f"Email send failed: {e}"
+
+def run_agent(query: str, csv_path: str, dataset_description: str, email: str):
+    start = time.time()
+
+    processing_html = _html_panel("🔄 Analysis in Progress...", "<div style='font-size:14px;color:#bbb;'>This may take 1–2 minutes depending on dataset size.</div>")
+    yield (processing_html, processing_html, processing_html, {"status": "Processing started..."})
+
     if not os.getenv("OPENAI_API_KEY"):
-
-        yield (error_html, "", "", {})
+        yield (_err_html("Set a Space Secret named OPENAI_API_KEY"), "", "", {})
         return
-
     if not csv_path:
-
-        yield (error_html, "", "", {})
+        yield (_warn_html("Please upload a CSV dataset."), "", "", {})
         return

     try:
-
-
-
-        <div style='font-size: 16px; margin-bottom: 5px;'>📊 Running Causal Analysis...</div>
-        <div style='font-size: 14px; color: #666;'>Analyzing dataset and selecting optimal method</div>
-        </div>
-        """
-
-        yield (
-            analysis_html,
-            analysis_html,
-            analysis_html,
-            {"status": "Running causal analysis..."}
-        )
-
+        step_html = _html_panel("📊 Running Causal Analysis...", "<div style='font-size:14px;color:#bbb;'>Analyzing dataset and selecting optimal method…</div>")
+        yield (step_html, step_html, step_html, {"status": "Running causal analysis..."})
+
         result = run_causal_analysis(
             query=(query or "What is the effect of treatment T on outcome Y controlling for X?").strip(),
             dataset_path=csv_path,
             dataset_description=(dataset_description or "").strip(),
         )
-
-
-        llm_html
-
-        <div style='font-size: 16px; margin-bottom: 5px;'>🤖 Generating Summary...</div>
-        <div style='font-size: 14px; color: #666;'>Creating human-readable interpretation</div>
-        </div>
-        """
-
-        yield (
-            llm_html,
-            llm_html,
-            llm_html,
-            {"status": "Generating explanation...", "raw_analysis": result if isinstance(result, dict) else {}}
-        )
-
+
+        llm_html = _html_panel("🤖 Generating Summary...", "<div style='font-size:14px;color:#bbb;'>Creating human-readable interpretation…</div>")
+        yield (llm_html, llm_html, llm_html, {"status": "Generating explanation...", "raw_analysis": result if isinstance(result, dict) else {}})
+
     except Exception as e:
-
-        yield (error_html, "", "", {})
+        yield (_err_html(str(e)), "", "", {})
         return

     try:
         payload = _extract_minimal_payload(result if isinstance(result, dict) else {})
         method = payload.get("method_used", "N/A")
-
-
-
-        <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
-            <h3 style='margin: 0 0 10px 0; font-size: 18px;'>Selected Method</h3>
-            <p style='margin: 0; font-size: 16px;'>{method}</p>
-        </div>
-        """
-
-        # Format effects with simple styling
+
+        method_html = _html_panel("Selected Method", f"<p style='margin:0;font-size:16px;'>{method}</p>")
+
         effect_estimate = payload.get("estimates", {}).get("effect_estimate", {})
         if effect_estimate:
-            effects_html = "<
-            effects_html += "<h3 style='margin: 0 0 10px 0; font-size: 18px;'>Effect Estimates</h3>"
-            # for k, v in effect_estimate.items():
-            #     try:
-            #         value = f"{float(v):+.4f}"
-            #         effects_html += f"<div style='margin: 8px 0; padding: 8px; border: 1px solid #eee; border-radius: 4px; background-color: #ffffff;'><strong>{k}:</strong> <span style='font-size: 16px;'>{value}</span></div>"
-            #     except:
-            effects_html += f"<div style='margin: 8px 0; padding: 8px; border: 1px solid #eee; border-radius: 4px; background-color: #333333;'>{effect_estimate}</div>"
-            effects_html += "</div>"
+            effects_html = _html_panel("Effect Estimates", f"<pre style='white-space:pre-wrap;margin:0;'>{json.dumps(effect_estimate, indent=2)}</pre>")
         else:
-            effects_html = "
+            effects_html = _warn_html("No effect estimates found")

-        # Generate explanation and format it
         try:
             explanation = _summarize_with_llm(payload)
-            explanation_html = f""
-            <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
-                <h3 style='margin: 0 0 15px 0; font-size: 18px;'>Detailed Explanation</h3>
-                <div style='line-height: 1.6; white-space: pre-wrap;'>{explanation}</div>
-            </div>
-            """
+            explanation_html = _html_panel("Detailed Explanation", f"<div style='white-space:pre-wrap;'>{explanation}</div>")
         except Exception as e:
-            explanation_html = f"
+            explanation_html = _warn_html(f"LLM summary failed: {e}")

     except Exception as e:
-
-        yield (error_html, "", "", {})
+        yield (_err_html(f"Failed to parse results: {e}"), "", "", {})
         return

-    #
+    # Optional email send (best-effort)
+    elapsed = time.time() - start
+    if email and "@" in email:
+        # Always send once results are ready; if you prefer thresholded behavior, check (elapsed > 600)
+        subject = "Causal Agent Results"
+        body = (
+            "Here are your Causal Agent results.\n\n"
+            f"Question: {payload.get('original_question','N/A')}\n"
+            f"Method: {method}\n\n"
+            f"Summary:\n{explanation}\n\n"
+            "Raw JSON is attached.\n"
+        )
+        email_err = send_email(
+            recipient=email.strip(),
+            subject=subject,
+            body_text=body,
+            attachment_name="causal_results.json",
+            attachment_json=(result if isinstance(result, dict) else {"results": result})
+        )
+        if email_err:
+            explanation_html += _warn_html(email_err)
+        else:
+            explanation_html += _ok_html(f"Results emailed to {email.strip()}")
+
     yield (method_html, effects_html, explanation_html, result if isinstance(result, dict) else {})

 with gr.Blocks() as demo:
-    gr.Markdown("# Causal
-    gr.Markdown(
+    gr.Markdown("# Causal AI Scientist")
+    # gr.Markdown(
+    #     """
+    #     **Tips**
+    #     - Be specific about your treatment, outcome, and control variables.
+    #     - Include relevant context in the dataset description.
+    #     - If you enter an email, we’ll send results when ready (only if SMTP is configured via env).
+    #     """
+    # )
+    gr.Markdown(
+        "Upload your dataset and ask causal questions in natural language. "
+        "The system will automatically select the appropriate causal inference method and provide clear explanations."
+    )

     with gr.Row():
         query = gr.Textbox(
@@ -290,20 +249,56 @@ with gr.Blocks() as demo:
             placeholder="e.g., What is the effect of attending the program (T) on income (Y), controlling for education and age?",
             lines=2,
         )
-
+
     with gr.Row():
-        csv_file = gr.File(
-
-            file_types=[".csv"],
-            type="filepath"
-        )
-
+        csv_file = gr.File(label="Dataset (CSV)", file_types=[".csv"], type="filepath")
+
         dataset_description = gr.Textbox(
             label="Dataset description (optional)",
             placeholder="Brief schema, how it was collected, time period, units, treatment/outcome variables, etc.",
             lines=4,
         )

+    # NEW: optional email field
+    email = gr.Textbox(
+        label="Email (optional)",
+        placeholder="you@example.com — we'll email the results when ready (if email is configured).",
+    )
+
+    # Helpful examples (question + description)
+    gr.Examples(
+        examples=[
+            [
+                "Does the adoption of the industrial reform policy increase the production output in factories?",
+                """This dataset has been compiled to study the effect of an industrial reform policy on production output in several manufacturing factories. The data was collected over two-time frames, before and after the reform, from a group of factories which adopted the reform and another group that did not. The data collected includes continuous variables such as labor hours spent on production and the amount of raw materials used. Binary variables include the use of automation, energy efficiency of the machines used (energy efficient or not), and worker satisfaction (satisfied or not).
+
+- factory_id: Unique identifier for each factory
+- post_reform: Indicator if the data was collected after the reform (1) or before the reform (0)
+- labor_hours: The number of labor hours spent on production
+- raw_materials: The quantity of raw materials used in kilograms
+- automation_use: Indicator if the factory uses automation in its production process (1) or not (0)
+- energy_efficiency: Indicator if the factory uses energy-efficient machines (1) or not (0)
+- worker_satisfaction: Indicator if workers reported being satisfied with their work environment (1) or not (0)""",
+                EXAMPLE_CSV_PATH + '/did_canonical_data_1.csv'
+            ],
+            [
+                "Does taking the newly developed medication have an impact on improving the lung capacity of patients with chronic obstructive pulmonary disease?",
+                """This dataset is collected from a clinical trial study conducted by a pharmaceutical company. The study aims to understand if their newly developed medication can enhance the lung capacity of patients suffering from chronic obstructive pulmonary disease (COPD). Participants for the study, who are all COPD patients, were recruited from various healthcare centers and were randomly assigned to either receive the new medication or a placebo. Data was collected on various factors including the age of the participants, the number of years they have been smoking, their gender, whether or not they have a history of smoking, and whether or not they have a regular exercise habit.
+
+'age' is the age of the participants in years.
+'smoking_years' is the number of years the participant has been smoking.
+'gender' is a binary variable where 1 represents male and 0 represents female.
+'smoking_history' is a binary variable where 1 indicates the participant has a history of smoking, while 0 indicates no such history.
+'exercise_habit' is a binary variable where 1 indicates the participant exercises regularly, while 0 indicates the participant does not.
+'new_medication' is a binary variable where 1 indicates the participant was assigned the new medication, while 0 indicates the participant was assigned a placebo.
+'lung_capacity' is the measured lung capacity of the participant.""",
+                EXAMPLE_CSV_PATH + '/rct_data_4.csv',
+            ],
+        ],
+        inputs=[query, dataset_description, csv_file],  # include the file component here
+        label="Quick Examples (click to fill)",
+    )
+
     run_btn = gr.Button("Run analysis", variant="primary")

     with gr.Row():
@@ -315,25 +310,17 @@ with gr.Blocks() as demo:
     with gr.Row():
         explanation_out = gr.HTML(label="Detailed Explanation")

-    # Add the collapsible raw results section
     with gr.Accordion("Raw Results (Advanced)", open=False):
         raw_results = gr.JSON(label="Complete Analysis Output", show_label=False)

     run_btn.click(
         fn=run_agent,
-        inputs=[query, csv_file, dataset_description],
+        inputs=[query, csv_file, dataset_description, email],
         outputs=[method_out, effects_out, explanation_out, raw_results],
         show_progress=True
     )

-
-        """
-        **Tips:**
-        - Be specific about your treatment, outcome, and control variables
-        - Include relevant context in the dataset description
-        - The analysis may take 1-2 minutes for complex datasets
-        """
-    )
+

 if __name__ == "__main__":
-    demo.queue().launch()
+    demo.queue().launch()
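For reference, the revised app.py is configured entirely through environment variables that appear in the diff: OPENAI_API_KEY for the summary step, the optional LLM_PROVIDER and LLM_MODEL overrides, EXAMPLE_CSV_PATH pointing at the quick-example CSVs (did_canonical_data_1.csv, rct_data_4.csv), and SMTP_HOST, SMTP_PORT, SMTP_USER, SMTP_PASS, EMAIL_FROM for send_email(). Below is a minimal sketch of a local launcher; the file name run_local.py and every value are placeholders rather than anything from this repository, and on Hugging Face Spaces these would normally be set as Secrets instead of in code.

# run_local.py (hypothetical): set placeholder config, then start the Gradio app locally.
import os
import subprocess

placeholder_env = {
    "OPENAI_API_KEY": "sk-placeholder",          # required by _get_openai_client(); placeholder value
    "LLM_PROVIDER": "openai",                    # app.py already defaults these two via setdefault
    "LLM_MODEL": "gpt-4o",
    "EXAMPLE_CSV_PATH": "data/synthetic_data",   # folder expected to hold the example CSVs
    # Only needed if the optional email feature should actually send mail:
    "SMTP_HOST": "smtp.example.com",
    "SMTP_PORT": "587",
    "SMTP_USER": "smtp-user",
    "SMTP_PASS": "smtp-password",
    "EMAIL_FROM": "causal-agent@example.com",
}
for key, value in placeholder_env.items():
    os.environ.setdefault(key, value)

# Launch app.py in a child process that inherits the environment set above.
subprocess.run(["python", "app.py"], check=True)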