File size: 16,136 Bytes
1721aea
 
 
 
 
 
f5c8ef7
 
1721aea
 
b7dc123
 
 
 
 
1721aea
 
 
 
 
 
 
 
 
 
 
f5c8ef7
 
1721aea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5c8ef7
 
 
 
 
1721aea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5c8ef7
1721aea
 
f5c8ef7
 
1721aea
 
f5c8ef7
1721aea
f5c8ef7
 
 
 
 
1721aea
 
 
f5c8ef7
 
 
 
 
 
 
 
 
 
ca3ce07
 
f5c8ef7
ca3ce07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5c8ef7
 
ca3ce07
f5c8ef7
 
 
 
 
 
 
 
 
 
ca3ce07
 
 
 
 
 
 
 
 
 
 
f5c8ef7
 
 
 
 
 
 
 
 
 
1721aea
f5c8ef7
1721aea
 
f5c8ef7
1721aea
 
 
f5c8ef7
 
 
1721aea
 
 
 
 
f5c8ef7
 
 
 
1721aea
f5c8ef7
1721aea
 
 
 
 
f5c8ef7
 
 
1721aea
 
f5c8ef7
1721aea
f5c8ef7
1721aea
 
 
f5c8ef7
1721aea
f5c8ef7
1721aea
 
f5c8ef7
1721aea
 
f5c8ef7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1721aea
 
 
f5c8ef7
 
 
 
 
 
 
 
 
 
 
 
 
1721aea
 
 
 
 
 
 
f5c8ef7
1721aea
f5c8ef7
 
1721aea
 
 
 
 
 
f5c8ef7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1721aea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5c8ef7
1721aea
 
 
 
f5c8ef7
1721aea
 
f5c8ef7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
import os
import sys
import json
from pathlib import Path
import gradio as gr
import time
import smtplib
from email.message import EmailMessage

# Make your repo importable (expecting a folder named causal-agent at repo root)
# NOTE: this must run before the `auto_causal` import below, which presumably
# lives inside that folder.
sys.path.append(str(Path(__file__).parent / "causal-agent"))
# Directory holding the example CSVs referenced by gr.Examples further down;
# overridable via the EXAMPLE_CSV_PATH env var, defaulting to this file's directory.
EXAMPLE_CSV_PATH = os.getenv(
    "EXAMPLE_CSV_PATH",
    str(Path(__file__).parent )
)
from auto_causal.agent import run_causal_analysis  # uses env for provider/model

# -------- LLM config (OpenAI only; key via HF Secrets) --------
# setdefault only fills these in when the environment does not already define
# them, so deployment-level settings take precedence.
os.environ.setdefault("LLM_PROVIDER", "openai")
os.environ.setdefault("LLM_MODEL", "gpt-4o")

def _get_openai_client():
    """Validate the LLM environment and return a new OpenAI client.

    Raises:
        RuntimeError: if LLM_PROVIDER is set to anything other than "openai",
            or if OPENAI_API_KEY is unset or empty.

    Returns:
        A fresh ``openai.OpenAI`` instance.
    """
    provider = os.getenv("LLM_PROVIDER", "openai")
    if provider != "openai":
        raise RuntimeError("Only LLM_PROVIDER=openai is supported in this demo.")

    api_key_present = bool(os.getenv("OPENAI_API_KEY"))
    if not api_key_present:
        raise RuntimeError("Missing OPENAI_API_KEY (set as a Space Secret).")

    # Deferred import: the openai package is only loaded once a client is needed.
    from openai import OpenAI as _OpenAIClient
    return _OpenAIClient()

# System prompt for the post-hoc LLM summary. Effect-size wording is kept
# unit-agnostic: outcomes may be continuous (e.g. lung capacity, production
# output) rather than rates. The example template also avoids a fixed unit or
# a fixed p-value so it does not bias the model's report.
SYSTEM_PROMPT = """You are an expert in statistics and causal inference.
You will be given:
1) The original research question.
2) The analysis method used.
3) The estimated effect(s), confidence interval(s), standard error(s), and p-value(s) — for a single treatment, or for each treatment group compared to the control group.
4) A brief dataset description.
Your task is to produce a clear, concise, and non-technical summary that:
- Directly answers the research question.
- States whether the effect is statistically significant.
- Quantifies the effect size in the outcome's own units and explains what it means in practical terms (use percentage points only when the outcome is a rate or proportion).
- Mentions the method used in one sentence.
- If multiple treatments exist, ranks the treatment effects from largest to smallest.
- If a value is missing or "N/A", says so instead of guessing.
Formatting rules:
- Use bullet points or short paragraphs.
- Report effect sizes to two decimal places.
- Clearly state the interpretation in plain English without technical jargon.
Example Output Structure:
- **Method:** [Name of method + 1-line rationale]
- **Key Finding:** [Main answer to the research question]
- **Details:**
  - [Treatment name]: [+/-X.XX] [outcome units] (95% CI: [L, U]), p = [p-value] — [Significance comment]
  - …
- **Rank Order of Effects:** [Largest → Smallest] (only if multiple treatments)
"""

def _is_missing(value) -> bool:
    """Return True for values the summary payload treats as "not provided".

    ``None`` and empty strings/dicts/lists/tuples count as missing. Numeric
    zeros do NOT: a 0.0 effect estimate or a p-value reported as 0.0 is a
    real result. Other objects (e.g. array-likes) are never evaluated for
    truthiness.
    """
    if value is None:
        return True
    return isinstance(value, (str, dict, list, tuple)) and len(value) == 0


def _first_present(*candidates, default):
    """Return the first candidate that is not missing (per ``_is_missing``), else ``default``."""
    for candidate in candidates:
        if not _is_missing(candidate):
            return candidate
    return default


def _extract_minimal_payload(agent_result: dict) -> dict:
    """Pull the fields needed for the LLM summary out of the agent's result.

    The result is searched at up to three nesting levels: the top-level dict
    (``res``), its ``"results"`` dict, and that dict's own ``"results"`` dict
    (``inner``). Each field uses a fixed lookup order, and the first non-missing
    candidate wins.

    Args:
        agent_result: value returned by ``run_causal_analysis``; anything that
            is not a dict is treated as empty.

    Returns:
        A dict with ``original_question``, ``method_used``, ``estimates``
        (``effect_estimate``, ``confidence_interval``, ``standard_error``,
        ``p_value``) and ``dataset_description``. Missing text fields become
        ``"N/A"``; missing estimate fields become ``{}``.
    """
    res = agent_result if isinstance(agent_result, dict) else {}
    results = res.get("results") if isinstance(res.get("results"), dict) else {}
    inner = results.get("results") if isinstance(results.get("results"), dict) else {}
    dataset_analysis = results.get("dataset_analysis")
    if not isinstance(dataset_analysis, dict):
        dataset_analysis = {}

    question = _first_present(
        results.get("original_query"),
        dataset_analysis.get("original_query"),
        res.get("query"),
        default="N/A",
    )
    method = _first_present(
        inner.get("method_used"),
        res.get("method_used"),
        results.get("method_used"),
        default="N/A",
    )

    def _estimate(key):
        # Estimates are looked up in the innermost results dict first, then top level.
        return _first_present(inner.get(key), res.get(key), default={})

    return {
        "original_question": question,
        "method_used": method,
        "estimates": {
            "effect_estimate": _estimate("effect_estimate"),
            "confidence_interval": _estimate("confidence_interval"),
            "standard_error": _estimate("standard_error"),
            "p_value": _estimate("p_value"),
        },
        "dataset_description": _first_present(
            results.get("dataset_description"),
            res.get("dataset_description"),
            default="N/A",
        ),
    }

def _summarize_with_llm(payload: dict) -> str:
    """Ask the configured OpenAI chat model for a plain-English summary.

    Args:
        payload: the minimal results dict (question, method, estimates,
            dataset description) to be summarized.

    Returns:
        The model reply with surrounding whitespace stripped, or an empty
        string when the reply carries no text content.

    Raises:
        RuntimeError: from ``_get_openai_client`` when the provider or API key
            is not configured. Errors raised by the OpenAI SDK also propagate.
    """
    client = _get_openai_client()
    # Same fallback as the module-level LLM_MODEL default so the two agree.
    model_name = os.getenv("LLM_MODEL", "gpt-4o")

    # default=str: analysis outputs may contain values the json module cannot
    # encode natively; stringify them for the prompt rather than failing.
    payload_json = json.dumps(payload, indent=2, ensure_ascii=False, default=str)
    user_content = "Summarize the following causal analysis results:\n\n" + payload_json
    resp = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ],
        temperature=0,
    )
    # message.content can be None (e.g. refusals / tool-call replies).
    content = resp.choices[0].message.content
    return (content or "").strip()

def _html_panel(title, body_html):
    """Render a dark rounded card with an ``<h3>`` heading.

    ``title`` and ``body_html`` are inserted verbatim (no escaping), so
    ``body_html`` may carry markup.
    """
    panel_lines = [
        "",
        "    <div style='padding:15px;border:1px solid #ddd;border-radius:8px;margin:5px 0;background-color:#333333;'>",
        f"      <h3 style='margin:0 0 10px 0;font-size:18px;'>{title}</h3>",
        f"      <div style='line-height:1.6;'>{body_html}</div>",
        "    </div>",
        "    ",
    ]
    return "\n".join(panel_lines)

def _status_html(text, color, icon):
    """Render a one-line coloured status box on the app's dark background.

    ``text`` is HTML-escaped. Callers pass plain messages (exception text,
    remote API responses, user-entered addresses) that may contain ``<``,
    ``>`` or ``&``. Unescaped, those would be swallowed as tags or injected
    as markup.
    """
    import html

    safe_text = html.escape(str(text))
    return (
        f"<div style='padding:10px;border:1px solid {color};border-radius:5px;"
        f"color:{color};background-color:#333333;'>{icon} {safe_text}</div>"
    )


def _warn_html(text):
    """Return a yellow warning box containing the escaped ``text``."""
    return _status_html(text, "#ffc107", "⚠️")


def _err_html(text):
    """Return a red error box containing the escaped ``text``."""
    return _status_html(text, "#dc3545", "❌")


def _ok_html(text):
    """Return a green success box containing the escaped ``text``."""
    return _status_html(text, "#2ea043", "✅")

# --- Email support ---
import base64, json, requests
from email.message import EmailMessage

def _gmail_access_token() -> str:
    """Exchange the configured Gmail OAuth refresh token for an access token.

    Reads GMAIL_CLIENT_ID, GMAIL_CLIENT_SECRET and GMAIL_REFRESH_TOKEN from the
    environment and POSTs them to Google's OAuth2 token endpoint (20 s timeout).

    Raises:
        requests.HTTPError: when the token endpoint answers with a 4xx/5xx status.
        KeyError: when the response JSON has no "access_token".
    """
    form = {
        "client_id": os.getenv("GMAIL_CLIENT_ID"),
        "client_secret": os.getenv("GMAIL_CLIENT_SECRET"),
        "refresh_token": os.getenv("GMAIL_REFRESH_TOKEN"),
        "grant_type": "refresh_token",
    }
    response = requests.post("https://oauth2.googleapis.com/token", data=form, timeout=20)
    response.raise_for_status()
    token_payload = response.json()
    return token_payload["access_token"]

def send_email(recipient: str, subject: str, body_text: str,
               attachment_name: str = None, attachment_json: dict = None) -> str:
    """Send a plain-text email, optionally with a JSON attachment, via the Gmail API.

    Args:
        recipient: address for the "To" header.
        subject: message subject.
        body_text: plain-text body.
        attachment_name: attachment filename. The attachment is only added when
            this is set AND ``attachment_json`` is not None.
        attachment_json: object serialized as an ``application/json``
            attachment. Values the json module cannot encode natively are
            converted with ``str``, so one odd value does not abort the send.

    Returns:
        '' on success, otherwise a human-readable error string. Sending is
        best-effort: every failure after the configuration check is caught and
        reported as a string.
    """
    from_addr = os.getenv("EMAIL_FROM")
    if not all([os.getenv("GMAIL_CLIENT_ID"), os.getenv("GMAIL_CLIENT_SECRET"),
                os.getenv("GMAIL_REFRESH_TOKEN"), from_addr]):
        return "Gmail API not configured (set GMAIL_CLIENT_ID, GMAIL_CLIENT_SECRET, GMAIL_REFRESH_TOKEN, EMAIL_FROM)."

    try:
        # Build MIME message
        msg = EmailMessage()
        msg["From"] = from_addr
        msg["To"] = recipient
        msg["Subject"] = subject
        msg.set_content(body_text)

        if attachment_json is not None and attachment_name:
            # default=str keeps the email deliverable even when the analysis
            # result holds values json cannot encode natively.
            payload = json.dumps(attachment_json, indent=2, default=str).encode("utf-8")
            msg.add_attachment(payload, maintype="application", subtype="json", filename=attachment_name)

        # The Gmail send endpoint expects the raw RFC 822 message, base64url-encoded.
        raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")

        # Get access token and send
        access_token = _gmail_access_token()
        api_url = "https://gmail.googleapis.com/gmail/v1/users/me/messages/send"
        headers = {"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"}
        r = requests.post(api_url, headers=headers, json={"raw": raw}, timeout=20)

        if r.status_code >= 400:
            return f"Gmail API error {r.status_code}: {r.text[:300]}"
        return ""
    except Exception as e:
        return f"Email send failed: {e}"

def run_agent(query: str, csv_path: str, dataset_description: str, email: str):
    """Gradio generator handler: run the causal agent and stream UI updates.

    Yields 4-tuples in the order of the click handler's outputs:
    ``(method_html, effects_html, explanation_html, raw_json)``. Progress
    panels come first, then either an error/warning (after which the generator
    stops) or the final results. When ``email`` contains "@", the results are
    also emailed. This is best-effort: the outcome is appended to the
    explanation panel.

    Args:
        query: natural-language causal question. A generic default is used when
            it is empty or whitespace-only.
        csv_path: filesystem path of the uploaded CSV.
        dataset_description: optional free-text description of the dataset.
        email: optional recipient address.
    """
    import html

    default_query = "What is the effect of treatment T on outcome Y controlling for X?"

    processing_html = _html_panel("🔄 Analysis in Progress...", "<div style='font-size:14px;color:#bbb;'>This may take 1–2 minutes depending on dataset size.</div>")
    yield (processing_html, processing_html, processing_html, {"status": "Processing started..."})

    if not os.getenv("OPENAI_API_KEY"):
        yield (_err_html("Set a Space Secret named OPENAI_API_KEY"), "", "", {})
        return
    if not csv_path:
        yield (_warn_html("Please upload a CSV dataset."), "", "", {})
        return

    try:
        step_html = _html_panel("📊 Running Causal Analysis...", "<div style='font-size:14px;color:#bbb;'>Analyzing dataset and selecting optimal method…</div>")
        yield (step_html, step_html, step_html, {"status": "Running causal analysis..."})

        result = run_causal_analysis(
            # Fall back to the generic question when the box is empty OR whitespace-only.
            query=(query or "").strip() or default_query,
            dataset_path=csv_path,
            dataset_description=(dataset_description or "").strip(),
        )
        raw = result if isinstance(result, dict) else {}

        llm_html = _html_panel("🤖 Generating Summary...", "<div style='font-size:14px;color:#bbb;'>Creating human-readable interpretation…</div>")
        yield (llm_html, llm_html, llm_html, {"status": "Generating explanation...", "raw_analysis": raw})

    except Exception as e:
        yield (_err_html(str(e)), "", "", {})
        return

    try:
        payload = _extract_minimal_payload(raw)
        method = payload.get("method_used", "N/A")
        # Agent- and LLM-provided text is escaped before being embedded in HTML.
        method_html = _html_panel("Selected Method", f"<p style='margin:0;font-size:16px;'>{html.escape(str(method))}</p>")

        effect_estimate = payload.get("estimates", {}).get("effect_estimate", {})
        # A 0.0 estimate is a real result; only None or an empty container means "absent".
        has_estimate = effect_estimate is not None and not (
            isinstance(effect_estimate, (dict, list, tuple, str)) and len(effect_estimate) == 0
        )
        if has_estimate:
            effects_json = json.dumps(effect_estimate, indent=2, default=str)
            effects_html = _html_panel("Effect Estimates", f"<pre style='white-space:pre-wrap;margin:0;'>{html.escape(effects_json)}</pre>")
        else:
            effects_html = _warn_html("No effect estimates found")

        try:
            explanation = _summarize_with_llm(payload)
            explanation_html = _html_panel("Detailed Explanation", f"<div style='white-space:pre-wrap;'>{html.escape(explanation)}</div>")
        except Exception as e:
            # Keep `explanation` bound: the email body below interpolates it.
            explanation = f"LLM summary failed: {e}"
            explanation_html = _warn_html(f"LLM summary failed: {e}")

    except Exception as e:
        yield (_err_html(f"Failed to parse results: {e}"), "", "", {})
        return

    # Optional email send (best-effort; send_email reports failures as a string).
    if email and "@" in email:
        recipient = email.strip()
        body = (
            "Here are your Causal Agent results.\n\n"
            f"Question: {payload.get('original_question','N/A')}\n"
            f"Method: {method}\n\n"
            f"Summary:\n{explanation}\n\n"
            "Raw JSON is attached.\n"
        )
        email_err = send_email(
            recipient=recipient,
            subject="Causal Agent Results",
            body_text=body,
            attachment_name="causal_results.json",
            attachment_json=(result if isinstance(result, dict) else {"results": result}),
        )
        if email_err:
            explanation_html += _warn_html(email_err)
        else:
            explanation_html += _ok_html(f"Results emailed to {recipient}")

    yield (method_html, effects_html, explanation_html, raw)

with gr.Blocks() as demo:
    gr.Markdown("# Causal AI Scientist")
    gr.Markdown(
        "Upload your dataset and ask causal questions in natural language. "
        "The system will automatically select the appropriate causal inference method and provide clear explanations."
    )

    with gr.Row():
        query = gr.Textbox(
            label="Your causal question (natural language)",
            placeholder="e.g., What is the effect of attending the program (T) on income (Y), controlling for education and age?",
            lines=2,
        )

    with gr.Row():
        csv_file = gr.File(label="Dataset (CSV)", file_types=[".csv"], type="filepath")

    dataset_description = gr.Textbox(
        label="Dataset description (optional)",
        placeholder="Brief schema, how it was collected, time period, units, treatment/outcome variables, etc.",
        lines=4,
    )

    # Optional email field: run_agent emails results here when the value contains "@"
    # (delivery also requires the Gmail API env vars checked in send_email).
    email = gr.Textbox(
        label="Email (optional)",
        placeholder="you@example.com — we'll email the results when ready (if email is configured).",
    )

    # Clickable examples that fill the question, description and CSV inputs.
    # Example CSV paths are built with pathlib rather than string concatenation.
    gr.Examples(
        examples=[
            [
                "Does the adoption of the industrial reform policy increase the production output in factories?",
                """This dataset has been compiled to study the effect of an industrial reform policy on production output in several manufacturing factories. The data was collected over two-time frames, before and after the reform, from a group of factories which adopted the reform and another group that did not. The data collected includes continuous variables such as labor hours spent on production and the amount of raw materials used. Binary variables include the use of automation, energy efficiency of the machines used (energy efficient or not), and worker satisfaction (satisfied or not). 

- factory_id: Unique identifier for each factory 
- post_reform: Indicator if the data was collected after the reform (1) or before the reform (0) 
- labor_hours: The number of labor hours spent on production 
- raw_materials: The quantity of raw materials used in kilograms 
- automation_use: Indicator if the factory uses automation in its production process (1) or not (0) 
- energy_efficiency: Indicator if the factory uses energy-efficient machines (1) or not (0) 
- worker_satisfaction: Indicator if workers reported being satisfied with their work environment (1) or not (0)""",
                str(Path(EXAMPLE_CSV_PATH) / "did_canonical_data_1.csv"),
            ],
            [
                "Does taking the newly developed medication have an impact on improving the lung capacity of patients with chronic obstructive pulmonary disease?",
                """This dataset is collected from a clinical trial study conducted by a pharmaceutical company. The study aims to understand if their newly developed medication can enhance the lung capacity of patients suffering from chronic obstructive pulmonary disease (COPD). Participants for the study, who are all COPD patients, were recruited from various healthcare centers and were randomly assigned to either receive the new medication or a placebo. Data was collected on various factors including the age of the participants, the number of years they have been smoking, their gender, whether or not they have a history of smoking, and whether or not they have a regular exercise habit.

'age' is the age of the participants in years.
'smoking_years' is the number of years the participant has been smoking.
'gender' is a binary variable where 1 represents male and 0 represents female.
'smoking_history' is a binary variable where 1 indicates the participant has a history of smoking, while 0 indicates no such history.
'exercise_habit' is a binary variable where 1 indicates the participant exercises regularly, while 0 indicates the participant does not.
'new_medication' is a binary variable where 1 indicates the participant was assigned the new medication, while 0 indicates the participant was assigned a placebo.
'lung_capacity' is the measured lung capacity of the participant.""",
                str(Path(EXAMPLE_CSV_PATH) / "rct_data_4.csv"),
            ],
        ],
        inputs=[query, dataset_description, csv_file],  # order matches each example row
        label="Quick Examples (click to fill)",
    )

    run_btn = gr.Button("Run analysis", variant="primary")

    with gr.Row():
        with gr.Column(scale=1):
            method_out = gr.HTML(label="Selected Method")
        with gr.Column(scale=1):
            effects_out = gr.HTML(label="Effect Estimates")

    with gr.Row():
        explanation_out = gr.HTML(label="Detailed Explanation")

    with gr.Accordion("Raw Results (Advanced)", open=False):
        raw_results = gr.JSON(label="Complete Analysis Output", show_label=False)

    # run_agent is a generator: each yielded 4-tuple updates these outputs in order.
    run_btn.click(
        fn=run_agent,
        inputs=[query, csv_file, dataset_description, email],
        outputs=[method_out, effects_out, explanation_out, raw_results],
        show_progress=True,
    )

    

if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then start the web server.
    app = demo.queue()
    app.launch()