FireShadow commited on
Commit
f5c8ef7
·
1 Parent(s): d56e62f

added quick examples

Browse files
Files changed (1) hide show
  1. app.py +169 -182
app.py CHANGED
@@ -4,49 +4,45 @@ import json
4
  from pathlib import Path
5
  import gradio as gr
6
  import time
 
 
7
 
8
  # Make your repo importable (expecting a folder named causal-agent at repo root)
9
  sys.path.append(str(Path(__file__).parent / "causal-agent"))
10
-
 
 
 
11
  from auto_causal.agent import run_causal_analysis # uses env for provider/model
12
 
13
  # -------- LLM config (OpenAI only; key via HF Secrets) --------
14
  os.environ.setdefault("LLM_PROVIDER", "openai")
15
  os.environ.setdefault("LLM_MODEL", "gpt-4o")
16
 
17
- # Lazy import to avoid import-time errors if key missing
18
  def _get_openai_client():
19
  if os.getenv("LLM_PROVIDER", "openai") != "openai":
20
  raise RuntimeError("Only LLM_PROVIDER=openai is supported in this demo.")
21
  if not os.getenv("OPENAI_API_KEY"):
22
  raise RuntimeError("Missing OPENAI_API_KEY (set as a Space Secret).")
23
- try:
24
- # OpenAI SDK v1+
25
- from openai import OpenAI
26
- return OpenAI()
27
- except Exception as e:
28
- raise RuntimeError(f"OpenAI SDK not available: {e}")
29
 
30
- # -------- System prompt you asked for (verbatim) --------
31
  SYSTEM_PROMPT = """You are an expert in statistics and causal inference.
32
  You will be given:
33
  1) The original research question.
34
  2) The analysis method used.
35
  3) The estimated effects, confidence intervals, standard errors, and p-values for each treatment group compared to the control group.
36
  4) A brief dataset description.
37
-
38
  Your task is to produce a clear, concise, and non-technical summary that:
39
  - Directly answers the research question.
40
  - States whether the effect is statistically significant.
41
  - Quantifies the effect size and explains what it means in practical terms (e.g., percentage point change).
42
  - Mentions the method used in one sentence.
43
  - Optionally ranks the treatment effects from largest to smallest if multiple treatments exist.
44
-
45
  Formatting rules:
46
  - Use bullet points or short paragraphs.
47
  - Report effect sizes to two decimal places.
48
  - Clearly state the interpretation in plain English without technical jargon.
49
-
50
  Example Output Structure:
51
  - **Method:** [Name of method + 1-line rationale]
52
  - **Key Finding:** [Main answer to the research question]
@@ -57,18 +53,11 @@ Example Output Structure:
57
  """
58
 
59
  def _extract_minimal_payload(agent_result: dict) -> dict:
60
- """
61
- Extract the minimal, LLM-friendly payload from run_causal_analysis output.
62
- Falls back gracefully if any fields are missing.
63
- """
64
- # Try both top-level and nested (your JSON showed both patterns)
65
  res = agent_result or {}
66
  results = res.get("results", {}) if isinstance(res.get("results"), dict) else {}
67
  inner = results.get("results", {}) if isinstance(results.get("results"), dict) else {}
68
- vars_ = results.get("variables", {}) if isinstance(results.get("variables"), dict) else {}
69
  dataset_analysis = results.get("dataset_analysis", {}) if isinstance(results.get("dataset_analysis"), dict) else {}
70
 
71
- # Pull best-available fields
72
  question = (
73
  results.get("original_query")
74
  or dataset_analysis.get("original_query")
@@ -82,32 +71,11 @@ def _extract_minimal_payload(agent_result: dict) -> dict:
82
  or "N/A"
83
  )
84
 
85
- effect_estimate = (
86
- inner.get("effect_estimate")
87
- or res.get("effect_estimate")
88
- or {}
89
- )
90
- confidence_interval = (
91
- inner.get("confidence_interval")
92
- or res.get("confidence_interval")
93
- or {}
94
- )
95
- standard_error = (
96
- inner.get("standard_error")
97
- or res.get("standard_error")
98
- or {}
99
- )
100
- p_value = (
101
- inner.get("p_value")
102
- or res.get("p_value")
103
- or {}
104
- )
105
-
106
- dataset_desc = (
107
- results.get("dataset_description")
108
- or res.get("dataset_description")
109
- or "N/A"
110
- )
111
 
112
  return {
113
  "original_question": question,
@@ -121,168 +89,159 @@ def _extract_minimal_payload(agent_result: dict) -> dict:
121
  "dataset_description": dataset_desc,
122
  }
123
 
124
- def _format_effects_md(effect_estimate: dict) -> str:
125
- """
126
- Minimal human-readable view of effect estimates for display.
127
- """
128
- if not effect_estimate or not isinstance(effect_estimate, dict):
129
- return "_No effect estimates found._"
130
- # Render as bullet list
131
- lines = []
132
- for k, v in effect_estimate.items():
133
- try:
134
- lines.append(f"- **{k}**: {float(v):+.4f}")
135
- except Exception:
136
- lines.append(f"- **{k}**: {v}")
137
- return "\n".join(lines)
138
-
139
  def _summarize_with_llm(payload: dict) -> str:
140
- """
141
- Calls OpenAI with the provided SYSTEM_PROMPT and the JSON payload as the user message.
142
- Returns the model's text, or raises on error.
143
- """
144
  client = _get_openai_client()
145
  model_name = os.getenv("LLM_MODEL", "gpt-4o-mini")
146
 
147
- user_content = (
148
- "Summarize the following causal analysis results:\n\n"
149
- + json.dumps(payload, indent=2, ensure_ascii=False)
150
- )
151
-
152
- # Use Chat Completions for broad compatibility
153
  resp = client.chat.completions.create(
154
  model=model_name,
155
- messages=[
156
- {"role": "system", "content": SYSTEM_PROMPT},
157
- {"role": "user", "content": user_content},
158
- ],
159
  temperature=0
160
  )
161
- text = resp.choices[0].message.content.strip()
162
- return text
163
 
164
- def run_agent(query: str, csv_path: str, dataset_description: str):
165
- """
166
- Modified to use yield for progressive updates and immediate feedback
167
- """
168
- # Immediate feedback - show processing has started
169
- processing_html = """
170
- <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
171
- <div style='font-size: 16px; margin-bottom: 5px;'>🔄 Analysis in Progress...</div>
172
- <div style='font-size: 14px; color: #666;'>This may take 1-2 minutes depending on dataset size</div>
173
  </div>
174
  """
175
-
176
- yield (
177
- processing_html, # method_out
178
- processing_html, # effects_out
179
- processing_html, # explanation_out
180
- {"status": "Processing started..."} # raw_results
181
- )
182
 
183
- # Input validation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  if not os.getenv("OPENAI_API_KEY"):
185
- error_html = "<div style='padding: 10px; border: 1px solid #dc3545; border-radius: 5px; color: #dc3545; background-color: #333333;'>⚠️ Set a Space Secret named OPENAI_API_KEY</div>"
186
- yield (error_html, "", "", {})
187
  return
188
-
189
  if not csv_path:
190
- error_html = "<div style='padding: 10px; border: 1px solid #ffc107; border-radius: 5px; color: #856404; background-color: #333333;'>Please upload a CSV dataset.</div>"
191
- yield (error_html, "", "", {})
192
  return
193
 
194
  try:
195
- # Update status to show causal analysis is running
196
- analysis_html = """
197
- <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
198
- <div style='font-size: 16px; margin-bottom: 5px;'>📊 Running Causal Analysis...</div>
199
- <div style='font-size: 14px; color: #666;'>Analyzing dataset and selecting optimal method</div>
200
- </div>
201
- """
202
-
203
- yield (
204
- analysis_html,
205
- analysis_html,
206
- analysis_html,
207
- {"status": "Running causal analysis..."}
208
- )
209
-
210
  result = run_causal_analysis(
211
  query=(query or "What is the effect of treatment T on outcome Y controlling for X?").strip(),
212
  dataset_path=csv_path,
213
  dataset_description=(dataset_description or "").strip(),
214
  )
215
-
216
- # Update to show LLM summarization step
217
- llm_html = """
218
- <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
219
- <div style='font-size: 16px; margin-bottom: 5px;'>🤖 Generating Summary...</div>
220
- <div style='font-size: 14px; color: #666;'>Creating human-readable interpretation</div>
221
- </div>
222
- """
223
-
224
- yield (
225
- llm_html,
226
- llm_html,
227
- llm_html,
228
- {"status": "Generating explanation...", "raw_analysis": result if isinstance(result, dict) else {}}
229
- )
230
-
231
  except Exception as e:
232
- error_html = f"<div style='padding: 10px; border: 1px solid #dc3545; border-radius: 5px; color: #dc3545; background-color: #333333;'>❌ Error: {e}</div>"
233
- yield (error_html, "", "", {})
234
  return
235
 
236
  try:
237
  payload = _extract_minimal_payload(result if isinstance(result, dict) else {})
238
  method = payload.get("method_used", "N/A")
239
-
240
- # Format method output with simple styling
241
- method_html = f"""
242
- <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
243
- <h3 style='margin: 0 0 10px 0; font-size: 18px;'>Selected Method</h3>
244
- <p style='margin: 0; font-size: 16px;'>{method}</p>
245
- </div>
246
- """
247
-
248
- # Format effects with simple styling
249
  effect_estimate = payload.get("estimates", {}).get("effect_estimate", {})
250
  if effect_estimate:
251
- effects_html = "<div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>"
252
- effects_html += "<h3 style='margin: 0 0 10px 0; font-size: 18px;'>Effect Estimates</h3>"
253
- # for k, v in effect_estimate.items():
254
- # try:
255
- # value = f"{float(v):+.4f}"
256
- # effects_html += f"<div style='margin: 8px 0; padding: 8px; border: 1px solid #eee; border-radius: 4px; background-color: #ffffff;'><strong>{k}:</strong> <span style='font-size: 16px;'>{value}</span></div>"
257
- # except:
258
- effects_html += f"<div style='margin: 8px 0; padding: 8px; border: 1px solid #eee; border-radius: 4px; background-color: #333333;'>{effect_estimate}</div>"
259
- effects_html += "</div>"
260
  else:
261
- effects_html = "<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; color: #666; font-style: italic; background-color: #333333;'>No effect estimates found</div>"
262
 
263
- # Generate explanation and format it
264
  try:
265
  explanation = _summarize_with_llm(payload)
266
- explanation_html = f"""
267
- <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
268
- <h3 style='margin: 0 0 15px 0; font-size: 18px;'>Detailed Explanation</h3>
269
- <div style='line-height: 1.6; white-space: pre-wrap;'>{explanation}</div>
270
- </div>
271
- """
272
  except Exception as e:
273
- explanation_html = f"<div style='padding: 10px; border: 1px solid #ffc107; border-radius: 5px; color: #856404; background-color: #333333;'>⚠️ LLM summary failed: {e}</div>"
274
 
275
  except Exception as e:
276
- error_html = f"<div style='padding: 10px; border: 1px solid #dc3545; border-radius: 5px; color: #dc3545; background-color: #333333;'>❌ Failed to parse results: {e}</div>"
277
- yield (error_html, "", "", {})
278
  return
279
 
280
- # Final result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  yield (method_html, effects_html, explanation_html, result if isinstance(result, dict) else {})
282
 
283
  with gr.Blocks() as demo:
284
- gr.Markdown("# Causal Agent")
285
- gr.Markdown("Upload your dataset and ask causal questions in natural language. The system will automatically select the appropriate causal inference method and provide clear explanations.")
 
 
 
 
 
 
 
 
 
 
 
286
 
287
  with gr.Row():
288
  query = gr.Textbox(
@@ -290,20 +249,56 @@ with gr.Blocks() as demo:
290
  placeholder="e.g., What is the effect of attending the program (T) on income (Y), controlling for education and age?",
291
  lines=2,
292
  )
293
-
294
  with gr.Row():
295
- csv_file = gr.File(
296
- label="Dataset (CSV)",
297
- file_types=[".csv"],
298
- type="filepath"
299
- )
300
-
301
  dataset_description = gr.Textbox(
302
  label="Dataset description (optional)",
303
  placeholder="Brief schema, how it was collected, time period, units, treatment/outcome variables, etc.",
304
  lines=4,
305
  )
306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  run_btn = gr.Button("Run analysis", variant="primary")
308
 
309
  with gr.Row():
@@ -315,25 +310,17 @@ with gr.Blocks() as demo:
315
  with gr.Row():
316
  explanation_out = gr.HTML(label="Detailed Explanation")
317
 
318
- # Add the collapsible raw results section
319
  with gr.Accordion("Raw Results (Advanced)", open=False):
320
  raw_results = gr.JSON(label="Complete Analysis Output", show_label=False)
321
 
322
  run_btn.click(
323
  fn=run_agent,
324
- inputs=[query, csv_file, dataset_description],
325
  outputs=[method_out, effects_out, explanation_out, raw_results],
326
  show_progress=True
327
  )
328
 
329
- gr.Markdown(
330
- """
331
- **Tips:**
332
- - Be specific about your treatment, outcome, and control variables
333
- - Include relevant context in the dataset description
334
- - The analysis may take 1-2 minutes for complex datasets
335
- """
336
- )
337
 
338
  if __name__ == "__main__":
339
- demo.queue().launch()
 
4
  from pathlib import Path
5
  import gradio as gr
6
  import time
7
+ import smtplib
8
+ from email.message import EmailMessage
9
 
10
  # Make your repo importable (expecting a folder named causal-agent at repo root)
11
  sys.path.append(str(Path(__file__).parent / "causal-agent"))
12
+ EXAMPLE_CSV_PATH = os.getenv(
13
+ "EXAMPLE_CSV_PATH",
14
+ str(Path(__file__).parent / "data" / "synthetic_data")
15
+ )
16
  from auto_causal.agent import run_causal_analysis # uses env for provider/model
17
 
18
  # -------- LLM config (OpenAI only; key via HF Secrets) --------
19
  os.environ.setdefault("LLM_PROVIDER", "openai")
20
  os.environ.setdefault("LLM_MODEL", "gpt-4o")
21
 
 
22
  def _get_openai_client():
23
  if os.getenv("LLM_PROVIDER", "openai") != "openai":
24
  raise RuntimeError("Only LLM_PROVIDER=openai is supported in this demo.")
25
  if not os.getenv("OPENAI_API_KEY"):
26
  raise RuntimeError("Missing OPENAI_API_KEY (set as a Space Secret).")
27
+ from openai import OpenAI
28
+ return OpenAI()
 
 
 
 
29
 
 
30
  SYSTEM_PROMPT = """You are an expert in statistics and causal inference.
31
  You will be given:
32
  1) The original research question.
33
  2) The analysis method used.
34
  3) The estimated effects, confidence intervals, standard errors, and p-values for each treatment group compared to the control group.
35
  4) A brief dataset description.
 
36
  Your task is to produce a clear, concise, and non-technical summary that:
37
  - Directly answers the research question.
38
  - States whether the effect is statistically significant.
39
  - Quantifies the effect size and explains what it means in practical terms (e.g., percentage point change).
40
  - Mentions the method used in one sentence.
41
  - Optionally ranks the treatment effects from largest to smallest if multiple treatments exist.
 
42
  Formatting rules:
43
  - Use bullet points or short paragraphs.
44
  - Report effect sizes to two decimal places.
45
  - Clearly state the interpretation in plain English without technical jargon.
 
46
  Example Output Structure:
47
  - **Method:** [Name of method + 1-line rationale]
48
  - **Key Finding:** [Main answer to the research question]
 
53
  """
54
 
55
  def _extract_minimal_payload(agent_result: dict) -> dict:
 
 
 
 
 
56
  res = agent_result or {}
57
  results = res.get("results", {}) if isinstance(res.get("results"), dict) else {}
58
  inner = results.get("results", {}) if isinstance(results.get("results"), dict) else {}
 
59
  dataset_analysis = results.get("dataset_analysis", {}) if isinstance(results.get("dataset_analysis"), dict) else {}
60
 
 
61
  question = (
62
  results.get("original_query")
63
  or dataset_analysis.get("original_query")
 
71
  or "N/A"
72
  )
73
 
74
+ effect_estimate = inner.get("effect_estimate") or res.get("effect_estimate") or {}
75
+ confidence_interval = inner.get("confidence_interval") or res.get("confidence_interval") or {}
76
+ standard_error = inner.get("standard_error") or res.get("standard_error") or {}
77
+ p_value = inner.get("p_value") or res.get("p_value") or {}
78
+ dataset_desc = results.get("dataset_description") or res.get("dataset_description") or "N/A"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  return {
81
  "original_question": question,
 
89
  "dataset_description": dataset_desc,
90
  }
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  def _summarize_with_llm(payload: dict) -> str:
 
 
 
 
93
  client = _get_openai_client()
94
  model_name = os.getenv("LLM_MODEL", "gpt-4o-mini")
95
 
96
+ user_content = "Summarize the following causal analysis results:\n\n" + json.dumps(payload, indent=2, ensure_ascii=False)
 
 
 
 
 
97
  resp = client.chat.completions.create(
98
  model=model_name,
99
+ messages=[{"role": "system", "content": SYSTEM_PROMPT},
100
+ {"role": "user", "content": user_content}],
 
 
101
  temperature=0
102
  )
103
+ return resp.choices[0].message.content.strip()
 
104
 
105
+ def _html_panel(title, body_html):
106
+ return f"""
107
+ <div style='padding:15px;border:1px solid #ddd;border-radius:8px;margin:5px 0;background-color:#333333;'>
108
+ <h3 style='margin:0 0 10px 0;font-size:18px;'>{title}</h3>
109
+ <div style='line-height:1.6;'>{body_html}</div>
 
 
 
 
110
  </div>
111
  """
 
 
 
 
 
 
 
112
 
113
+ def _warn_html(text):
114
+ return f"<div style='padding:10px;border:1px solid #ffc107;border-radius:5px;color:#ffc107;background-color:#333333;'>⚠️ {text}</div>"
115
+
116
+ def _err_html(text):
117
+ return f"<div style='padding:10px;border:1px solid #dc3545;border-radius:5px;color:#dc3545;background-color:#333333;'>❌ {text}</div>"
118
+
119
+ def _ok_html(text):
120
+ return f"<div style='padding:10px;border:1px solid #2ea043;border-radius:5px;color:#2ea043;background-color:#333333;'>✅ {text}</div>"
121
+
122
+ # --- Email support ---
123
+ def send_email(recipient: str, subject: str, body_text: str, attachment_name: str = None, attachment_json: dict = None) -> str:
124
+ """Returns '' on success, or error message string."""
125
+ host = os.getenv("SMTP_HOST")
126
+ port = int(os.getenv("SMTP_PORT", "587"))
127
+ user = os.getenv("SMTP_USER")
128
+ pwd = os.getenv("SMTP_PASS")
129
+ from_addr = os.getenv("EMAIL_FROM")
130
+
131
+ if not all([host, port, user, pwd, from_addr]):
132
+ return "Email is not configured (set SMTP_HOST, SMTP_PORT, SMTP_USER, SMTP_PASS, EMAIL_FROM)."
133
+
134
+ try:
135
+ msg = EmailMessage()
136
+ msg["From"] = from_addr
137
+ msg["To"] = recipient
138
+ msg["Subject"] = subject
139
+ msg.set_content(body_text)
140
+
141
+ if attachment_json is not None and attachment_name:
142
+ payload = json.dumps(attachment_json, indent=2).encode("utf-8")
143
+ msg.add_attachment(payload, maintype="application", subtype="json", filename=attachment_name)
144
+
145
+ with smtplib.SMTP(host, port, timeout=30) as s:
146
+ s.starttls()
147
+ s.login(user, pwd)
148
+ s.send_message(msg)
149
+ return ""
150
+ except Exception as e:
151
+ return f"Email send failed: {e}"
152
+
153
+ def run_agent(query: str, csv_path: str, dataset_description: str, email: str):
154
+ start = time.time()
155
+
156
+ processing_html = _html_panel("🔄 Analysis in Progress...", "<div style='font-size:14px;color:#bbb;'>This may take 1–2 minutes depending on dataset size.</div>")
157
+ yield (processing_html, processing_html, processing_html, {"status": "Processing started..."})
158
+
159
  if not os.getenv("OPENAI_API_KEY"):
160
+ yield (_err_html("Set a Space Secret named OPENAI_API_KEY"), "", "", {})
 
161
  return
 
162
  if not csv_path:
163
+ yield (_warn_html("Please upload a CSV dataset."), "", "", {})
 
164
  return
165
 
166
  try:
167
+ step_html = _html_panel("📊 Running Causal Analysis...", "<div style='font-size:14px;color:#bbb;'>Analyzing dataset and selecting optimal method…</div>")
168
+ yield (step_html, step_html, step_html, {"status": "Running causal analysis..."})
169
+
 
 
 
 
 
 
 
 
 
 
 
 
170
  result = run_causal_analysis(
171
  query=(query or "What is the effect of treatment T on outcome Y controlling for X?").strip(),
172
  dataset_path=csv_path,
173
  dataset_description=(dataset_description or "").strip(),
174
  )
175
+
176
+ llm_html = _html_panel("🤖 Generating Summary...", "<div style='font-size:14px;color:#bbb;'>Creating human-readable interpretation…</div>")
177
+ yield (llm_html, llm_html, llm_html, {"status": "Generating explanation...", "raw_analysis": result if isinstance(result, dict) else {}})
178
+
 
 
 
 
 
 
 
 
 
 
 
 
179
  except Exception as e:
180
+ yield (_err_html(str(e)), "", "", {})
 
181
  return
182
 
183
  try:
184
  payload = _extract_minimal_payload(result if isinstance(result, dict) else {})
185
  method = payload.get("method_used", "N/A")
186
+
187
+ method_html = _html_panel("Selected Method", f"<p style='margin:0;font-size:16px;'>{method}</p>")
188
+
 
 
 
 
 
 
 
189
  effect_estimate = payload.get("estimates", {}).get("effect_estimate", {})
190
  if effect_estimate:
191
+ effects_html = _html_panel("Effect Estimates", f"<pre style='white-space:pre-wrap;margin:0;'>{json.dumps(effect_estimate, indent=2)}</pre>")
 
 
 
 
 
 
 
 
192
  else:
193
+ effects_html = _warn_html("No effect estimates found")
194
 
 
195
  try:
196
  explanation = _summarize_with_llm(payload)
197
+ explanation_html = _html_panel("Detailed Explanation", f"<div style='white-space:pre-wrap;'>{explanation}</div>")
 
 
 
 
 
198
  except Exception as e:
199
+ explanation_html = _warn_html(f"LLM summary failed: {e}")
200
 
201
  except Exception as e:
202
+ yield (_err_html(f"Failed to parse results: {e}"), "", "", {})
 
203
  return
204
 
205
+ # Optional email send (best-effort)
206
+ elapsed = time.time() - start
207
+ if email and "@" in email:
208
+ # Always send once results are ready; if you prefer thresholded behavior, check (elapsed > 600)
209
+ subject = "Causal Agent Results"
210
+ body = (
211
+ "Here are your Causal Agent results.\n\n"
212
+ f"Question: {payload.get('original_question','N/A')}\n"
213
+ f"Method: {method}\n\n"
214
+ f"Summary:\n{explanation}\n\n"
215
+ "Raw JSON is attached.\n"
216
+ )
217
+ email_err = send_email(
218
+ recipient=email.strip(),
219
+ subject=subject,
220
+ body_text=body,
221
+ attachment_name="causal_results.json",
222
+ attachment_json=(result if isinstance(result, dict) else {"results": result})
223
+ )
224
+ if email_err:
225
+ explanation_html += _warn_html(email_err)
226
+ else:
227
+ explanation_html += _ok_html(f"Results emailed to {email.strip()}")
228
+
229
  yield (method_html, effects_html, explanation_html, result if isinstance(result, dict) else {})
230
 
231
  with gr.Blocks() as demo:
232
+ gr.Markdown("# Causal AI Scientist")
233
+ # gr.Markdown(
234
+ # """
235
+ # **Tips**
236
+ # - Be specific about your treatment, outcome, and control variables.
237
+ # - Include relevant context in the dataset description.
238
+ # - If you enter an email, we’ll send results when ready (only if SMTP is configured via env).
239
+ # """
240
+ # )
241
+ gr.Markdown(
242
+ "Upload your dataset and ask causal questions in natural language. "
243
+ "The system will automatically select the appropriate causal inference method and provide clear explanations."
244
+ )
245
 
246
  with gr.Row():
247
  query = gr.Textbox(
 
249
  placeholder="e.g., What is the effect of attending the program (T) on income (Y), controlling for education and age?",
250
  lines=2,
251
  )
252
+
253
  with gr.Row():
254
+ csv_file = gr.File(label="Dataset (CSV)", file_types=[".csv"], type="filepath")
255
+
 
 
 
 
256
  dataset_description = gr.Textbox(
257
  label="Dataset description (optional)",
258
  placeholder="Brief schema, how it was collected, time period, units, treatment/outcome variables, etc.",
259
  lines=4,
260
  )
261
 
262
+ # NEW: optional email field
263
+ email = gr.Textbox(
264
+ label="Email (optional)",
265
+ placeholder="you@example.com — we'll email the results when ready (if email is configured).",
266
+ )
267
+
268
+ # Helpful examples (question + description)
269
+ gr.Examples(
270
+ examples=[
271
+ [
272
+ "Does the adoption of the industrial reform policy increase the production output in factories?",
273
+ """This dataset has been compiled to study the effect of an industrial reform policy on production output in several manufacturing factories. The data was collected over two-time frames, before and after the reform, from a group of factories which adopted the reform and another group that did not. The data collected includes continuous variables such as labor hours spent on production and the amount of raw materials used. Binary variables include the use of automation, energy efficiency of the machines used (energy efficient or not), and worker satisfaction (satisfied or not).
274
+
275
+ - factory_id: Unique identifier for each factory
276
+ - post_reform: Indicator if the data was collected after the reform (1) or before the reform (0)
277
+ - labor_hours: The number of labor hours spent on production
278
+ - raw_materials: The quantity of raw materials used in kilograms
279
+ - automation_use: Indicator if the factory uses automation in its production process (1) or not (0)
280
+ - energy_efficiency: Indicator if the factory uses energy-efficient machines (1) or not (0)
281
+ - worker_satisfaction: Indicator if workers reported being satisfied with their work environment (1) or not (0)""",
282
+ EXAMPLE_CSV_PATH + '/did_canonical_data_1.csv'
283
+ ],
284
+ [
285
+ "Does taking the newly developed medication have an impact on improving the lung capacity of patients with chronic obstructive pulmonary disease?",
286
+ """This dataset is collected from a clinical trial study conducted by a pharmaceutical company. The study aims to understand if their newly developed medication can enhance the lung capacity of patients suffering from chronic obstructive pulmonary disease (COPD). Participants for the study, who are all COPD patients, were recruited from various healthcare centers and were randomly assigned to either receive the new medication or a placebo. Data was collected on various factors including the age of the participants, the number of years they have been smoking, their gender, whether or not they have a history of smoking, and whether or not they have a regular exercise habit.
287
+
288
+ 'age' is the age of the participants in years.
289
+ 'smoking_years' is the number of years the participant has been smoking.
290
+ 'gender' is a binary variable where 1 represents male and 0 represents female.
291
+ 'smoking_history' is a binary variable where 1 indicates the participant has a history of smoking, while 0 indicates no such history.
292
+ 'exercise_habit' is a binary variable where 1 indicates the participant exercises regularly, while 0 indicates the participant does not.
293
+ 'new_medication' is a binary variable where 1 indicates the participant was assigned the new medication, while 0 indicates the participant was assigned a placebo.
294
+ 'lung_capacity' is the measured lung capacity of the participant.""",
295
+ EXAMPLE_CSV_PATH + '/rct_data_4.csv',
296
+ ],
297
+ ],
298
+ inputs=[query, dataset_description, csv_file], # include the file component here
299
+ label="Quick Examples (click to fill)",
300
+ )
301
+
302
  run_btn = gr.Button("Run analysis", variant="primary")
303
 
304
  with gr.Row():
 
310
  with gr.Row():
311
  explanation_out = gr.HTML(label="Detailed Explanation")
312
 
 
313
  with gr.Accordion("Raw Results (Advanced)", open=False):
314
  raw_results = gr.JSON(label="Complete Analysis Output", show_label=False)
315
 
316
  run_btn.click(
317
  fn=run_agent,
318
+ inputs=[query, csv_file, dataset_description, email],
319
  outputs=[method_out, effects_out, explanation_out, raw_results],
320
  show_progress=True
321
  )
322
 
323
+
 
 
 
 
 
 
 
324
 
325
  if __name__ == "__main__":
326
+ demo.queue().launch()