"""Gradio front-end for the Causal AI Scientist agent.

Users upload a CSV, ask a causal question in natural language, and receive:
the method the agent selected, the raw effect estimates, and an LLM-written
plain-English summary. Results can optionally be emailed via the Gmail API.
"""
import base64
import html
import json
import os
import smtplib  # noqa: F401 - currently unused (email is sent via the Gmail API)
import sys
import time  # noqa: F401 - currently unused
from email.message import EmailMessage
from pathlib import Path

import gradio as gr
import requests

# Make the bundled repo importable (expects a folder named "causal-agent" next to this file).
sys.path.append(str(Path(__file__).parent / "causal-agent"))

# Directory containing the example CSVs referenced by the "Quick Examples" widget.
EXAMPLE_CSV_PATH = os.getenv("EXAMPLE_CSV_PATH", str(Path(__file__).parent))

# Must come after the sys.path tweak above; the agent reads provider/model from env.
from auto_causal.agent import run_causal_analysis  # noqa: E402

# -------- LLM config (OpenAI only; key via HF Secrets) --------
DEFAULT_LLM_PROVIDER = "openai"
DEFAULT_LLM_MODEL = "gpt-4o"
os.environ.setdefault("LLM_PROVIDER", DEFAULT_LLM_PROVIDER)
os.environ.setdefault("LLM_MODEL", DEFAULT_LLM_MODEL)


def _get_openai_client():
    """Return an OpenAI client, validating the provider and API-key env vars first.

    Raises:
        RuntimeError: if LLM_PROVIDER is not "openai" or OPENAI_API_KEY is unset.
    """
    if os.getenv("LLM_PROVIDER", DEFAULT_LLM_PROVIDER) != "openai":
        raise RuntimeError("Only LLM_PROVIDER=openai is supported in this demo.")
    if not os.getenv("OPENAI_API_KEY"):
        raise RuntimeError("Missing OPENAI_API_KEY (set as a Space Secret).")
    from openai import OpenAI  # imported lazily so the UI can start without the key

    return OpenAI()


SYSTEM_PROMPT = """You are an expert in statistics and causal inference.
You will be given:
1) The original research question.
2) The analysis method used.
3) The estimated effects, confidence intervals, standard errors, and p-values for each treatment group compared to the control group.
4) A brief dataset description.

Your task is to produce a clear, concise, and non-technical summary that:
- Directly answers the research question.
- States whether the effect is statistically significant.
- Quantifies the effect size and explains what it means in practical terms (e.g., percentage point change).
- Mentions the method used in one sentence.
- Optionally ranks the treatment effects from largest to smallest if multiple treatments exist.

Formatting rules:
- Use bullet points or short paragraphs.
- Report effect sizes to two decimal places.
- Clearly state the interpretation in plain English without technical jargon.

Example Output Structure:
- **Method:** [Name of method + 1-line rationale]
- **Key Finding:** [Main answer to the research question]
- **Details:**
  - [Treatment name]: +X.XX percentage points (95% CI: [L, U]), p < 0.001 — [Significance comment]
  - …
- **Rank Order of Effects:** [Largest → Smallest]
"""


def _as_dict(value) -> dict:
    """Return *value* if it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def _is_missing(value) -> bool:
    """True for None and empty strings/containers.

    Numbers (including 0 and 0.0) and non-container objects such as arrays
    count as present, so a zero effect or a p-value of 0.0 is not discarded.
    """
    if value is None:
        return True
    return isinstance(value, (str, bytes, dict, list, tuple, set)) and len(value) == 0


def _first_present(*candidates, default=None):
    """Return the first candidate that is not missing (see `_is_missing`), else *default*."""
    for value in candidates:
        if not _is_missing(value):
            return value
    return default


def _extract_minimal_payload(agent_result: dict) -> dict:
    """Pull the fields the summarizer needs out of the agent's nested result dict.

    Values are looked up in several places because the agent may report them
    at the top level, under "results", or under "results" -> "results".

    Returns:
        dict with keys "original_question", "method_used", "estimates"
        (effect_estimate / confidence_interval / standard_error / p_value)
        and "dataset_description". Missing text fields become "N/A";
        missing estimates become {}.
    """
    res = _as_dict(agent_result)
    results = _as_dict(res.get("results"))
    inner = _as_dict(results.get("results"))
    dataset_analysis = _as_dict(results.get("dataset_analysis"))

    question = _first_present(
        results.get("original_query"),
        dataset_analysis.get("original_query"),
        res.get("query"),
        default="N/A",
    )
    method = _first_present(
        inner.get("method_used"),
        res.get("method_used"),
        results.get("method_used"),
        default="N/A",
    )
    estimates = {
        key: _first_present(inner.get(key), res.get(key), default={})
        for key in ("effect_estimate", "confidence_interval", "standard_error", "p_value")
    }
    dataset_desc = _first_present(
        results.get("dataset_description"),
        res.get("dataset_description"),
        default="N/A",
    )

    return {
        "original_question": question,
        "method_used": method,
        "estimates": estimates,
        "dataset_description": dataset_desc,
    }


def _summarize_with_llm(payload: dict) -> str:
    """Ask the configured OpenAI model for a plain-English summary of *payload*.

    Returns the stripped model reply ("" if the model returned no content).
    Raises whatever `_get_openai_client` or the OpenAI client raises.
    """
    client = _get_openai_client()
    model_name = os.getenv("LLM_MODEL", DEFAULT_LLM_MODEL)
    # default=str: agent output may hold values json cannot serialize natively.
    user_content = "Summarize the following causal analysis results:\n\n" + json.dumps(
        payload, indent=2, ensure_ascii=False, default=str
    )
    resp = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ],
        temperature=0,
    )
    content = resp.choices[0].message.content
    return (content or "").strip()


# --- HTML helpers (all free text is escaped; only *body_html* is trusted markup) ---

def _html_panel(title, body_html):
    """Wrap already-rendered *body_html* in a titled panel; *title* is HTML-escaped."""
    return (
        '<div style="border:1px solid #ddd;border-radius:8px;padding:12px;margin:6px 0;">'
        f'<h3 style="margin-top:0;">{html.escape(str(title))}</h3>'
        f"{body_html}"
        "</div>"
    )


def _status_html(icon, text, color):
    """Render a one-line status message with an icon; *text* is HTML-escaped."""
    return f'<div style="color:{color};padding:6px 0;">{icon} {html.escape(str(text))}</div>'


def _warn_html(text):
    return _status_html("⚠️", text, "#b45309")


def _err_html(text):
    return _status_html("❌", text, "#b91c1c")


def _ok_html(text):
    return _status_html("✅", text, "#15803d")


# --- Email support (Gmail API via OAuth refresh token) ---

def _gmail_access_token() -> str:
    """Exchange GMAIL_REFRESH_TOKEN for a short-lived access token.

    Raises requests.HTTPError on a non-2xx response, KeyError if the response
    has no "access_token".
    """
    token_url = "https://oauth2.googleapis.com/token"
    data = {
        "client_id": os.getenv("GMAIL_CLIENT_ID"),
        "client_secret": os.getenv("GMAIL_CLIENT_SECRET"),
        "refresh_token": os.getenv("GMAIL_REFRESH_TOKEN"),
        "grant_type": "refresh_token",
    }
    r = requests.post(token_url, data=data, timeout=20)
    r.raise_for_status()
    return r.json()["access_token"]


def send_email(recipient: str, subject: str, body_text: str,
               attachment_name: str = None, attachment_json: dict = None) -> str:
    """Send an email via the Gmail API.

    Returns '' on success, or an error string (never raises).
    """
    from_addr = os.getenv("EMAIL_FROM")
    if not all([os.getenv("GMAIL_CLIENT_ID"), os.getenv("GMAIL_CLIENT_SECRET"),
                os.getenv("GMAIL_REFRESH_TOKEN"), from_addr]):
        return ("Gmail API not configured (set GMAIL_CLIENT_ID, GMAIL_CLIENT_SECRET, "
                "GMAIL_REFRESH_TOKEN, EMAIL_FROM).")

    try:
        # Build the MIME message.
        msg = EmailMessage()
        msg["From"] = from_addr
        msg["To"] = recipient
        msg["Subject"] = subject
        msg.set_content(body_text)

        if attachment_json is not None and attachment_name:
            # default=str keeps non-JSON-native values from aborting the send.
            payload = json.dumps(attachment_json, indent=2, default=str).encode("utf-8")
            msg.add_attachment(payload, maintype="application", subtype="json",
                               filename=attachment_name)

        # The Gmail API expects the raw RFC 822 message, base64url-encoded.
        raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")

        access_token = _gmail_access_token()
        api_url = "https://gmail.googleapis.com/gmail/v1/users/me/messages/send"
        headers = {"Authorization": f"Bearer {access_token}",
                   "Content-Type": "application/json"}
        r = requests.post(api_url, headers=headers, json={"raw": raw}, timeout=20)
        if r.status_code >= 400:
            return f"Gmail API error {r.status_code}: {r.text[:300]}"
        return ""
    except Exception as e:  # best-effort: report, don't crash the UI
        return f"Email send failed: {e}"


def run_agent(query: str, csv_path: str, dataset_description: str, email: str):
    """Gradio generator driving one analysis run.

    Yields 4-tuples matching the UI outputs:
    (method_html, effects_html, explanation_html, raw_results_json).
    Intermediate yields show progress panels; the last yield carries results.
    """
    processing_html = _html_panel(
        "🔄 Analysis in Progress...",
        "<p>This may take 1–2 minutes depending on dataset size.</p>",
    )
    yield (processing_html, processing_html, processing_html, {"status": "Processing started..."})

    if not os.getenv("OPENAI_API_KEY"):
        yield (_err_html("Set a Space Secret named OPENAI_API_KEY"), "", "", {})
        return

    if not csv_path:
        yield (_warn_html("Please upload a CSV dataset."), "", "", {})
        return

    # --- Step 1: run the causal agent ---
    try:
        step_html = _html_panel(
            "📊 Running Causal Analysis...",
            "<p>Analyzing dataset and selecting optimal method…</p>",
        )
        yield (step_html, step_html, step_html, {"status": "Running causal analysis..."})

        result = run_causal_analysis(
            query=(query or "What is the effect of treatment T on outcome Y controlling for X?").strip(),
            dataset_path=csv_path,
            dataset_description=(dataset_description or "").strip(),
        )

        llm_html = _html_panel(
            "🤖 Generating Summary...",
            "<p>Creating human-readable interpretation…</p>",
        )
        yield (llm_html, llm_html, llm_html,
               {"status": "Generating explanation...",
                "raw_analysis": result if isinstance(result, dict) else {}})
    except Exception as e:
        yield (_err_html(str(e)), "", "", {})
        return

    # --- Step 2: render results and the LLM summary ---
    result_dict = result if isinstance(result, dict) else {}
    # Stays None if the LLM summary fails, so the email step below never sees an unbound name.
    explanation = None
    try:
        payload = _extract_minimal_payload(result_dict)
        method = payload.get("method_used", "N/A")
        method_html = _html_panel(
            "Selected Method",
            f"<p><strong>{html.escape(str(method))}</strong></p>",
        )

        effect_estimate = payload.get("estimates", {}).get("effect_estimate", {})
        if _is_missing(effect_estimate):
            effects_html = _warn_html("No effect estimates found")
        else:
            effects_json = json.dumps(effect_estimate, indent=2, default=str)
            effects_html = _html_panel("Effect Estimates", f"<pre>{html.escape(effects_json)}</pre>")

        try:
            explanation = _summarize_with_llm(payload)
            explanation_html = _html_panel(
                "Detailed Explanation",
                f'<div style="white-space:pre-wrap;">{html.escape(explanation)}</div>',
            )
        except Exception as e:
            explanation_html = _warn_html(f"LLM summary failed: {e}")
    except Exception as e:
        yield (_err_html(f"Failed to parse results: {e}"), "", "", {})
        return

    # --- Step 3: optional email (best-effort) ---
    # Always sent once results are ready; for thresholded behavior, time the run
    # and only send when it exceeded some limit (e.g. 600 s).
    recipient = (email or "").strip()
    if recipient and "@" in recipient:
        body = (
            "Here are your Causal Agent results.\n\n"
            f"Question: {payload.get('original_question', 'N/A')}\n"
            f"Method: {method}\n\n"
            f"Summary:\n{explanation or '(summary unavailable)'}\n\n"
            "Raw JSON is attached.\n"
        )
        email_err = send_email(
            recipient=recipient,
            subject="Causal Agent Results",
            body_text=body,
            attachment_name="causal_results.json",
            attachment_json=(result if isinstance(result, dict) else {"results": result}),
        )
        if email_err:
            explanation_html += _warn_html(email_err)
        else:
            explanation_html += _ok_html(f"Results emailed to {recipient}")

    yield (method_html, effects_html, explanation_html, result_dict)


# --- Example data shown in the "Quick Examples" widget ---

_EXAMPLE_DID_QUESTION = (
    "Does the adoption of the industrial reform policy increase the production output in factories?"
)
_EXAMPLE_DID_DESCRIPTION = """This dataset has been compiled to study the effect of an industrial reform policy on production output in several manufacturing factories. The data was collected over two-time frames, before and after the reform, from a group of factories which adopted the reform and another group that did not. The data collected includes continuous variables such as labor hours spent on production and the amount of raw materials used. Binary variables include the use of automation, energy efficiency of the machines used (energy efficient or not), and worker satisfaction (satisfied or not).

- factory_id: Unique identifier for each factory
- post_reform: Indicator if the data was collected after the reform (1) or before the reform (0)
- labor_hours: The number of labor hours spent on production
- raw_materials: The quantity of raw materials used in kilograms
- automation_use: Indicator if the factory uses automation in its production process (1) or not (0)
- energy_efficiency: Indicator if the factory uses energy-efficient machines (1) or not (0)
- worker_satisfaction: Indicator if workers reported being satisfied with their work environment (1) or not (0)"""

_EXAMPLE_RCT_QUESTION = (
    "Does taking the newly developed medication have an impact on improving the lung capacity "
    "of patients with chronic obstructive pulmonary disease?"
)
_EXAMPLE_RCT_DESCRIPTION = """This dataset is collected from a clinical trial study conducted by a pharmaceutical company. The study aims to understand if their newly developed medication can enhance the lung capacity of patients suffering from chronic obstructive pulmonary disease (COPD). Participants for the study, who are all COPD patients, were recruited from various healthcare centers and were randomly assigned to either receive the new medication or a placebo. Data was collected on various factors including the age of the participants, the number of years they have been smoking, their gender, whether or not they have a history of smoking, and whether or not they have a regular exercise habit.

'age' is the age of the participants in years.
'smoking_years' is the number of years the participant has been smoking.
'gender' is a binary variable where 1 represents male and 0 represents female.
'smoking_history' is a binary variable where 1 indicates the participant has a history of smoking, while 0 indicates no such history.
'exercise_habit' is a binary variable where 1 indicates the participant exercises regularly, while 0 indicates the participant does not.
'new_medication' is a binary variable where 1 indicates the participant was assigned the new medication, while 0 indicates the participant was assigned a placebo.
'lung_capacity' is the measured lung capacity of the participant."""


with gr.Blocks() as demo:
    gr.Markdown("# Causal AI Scientist")
    gr.Markdown(
        "Upload your dataset and ask causal questions in natural language. "
        "The system will automatically select the appropriate causal inference method and provide clear explanations."
    )

    with gr.Row():
        query = gr.Textbox(
            label="Your causal question (natural language)",
            placeholder="e.g., What is the effect of attending the program (T) on income (Y), controlling for education and age?",
            lines=2,
        )

    with gr.Row():
        csv_file = gr.File(label="Dataset (CSV)", file_types=[".csv"], type="filepath")
        dataset_description = gr.Textbox(
            label="Dataset description (optional)",
            placeholder="Brief schema, how it was collected, time period, units, treatment/outcome variables, etc.",
            lines=4,
        )

    # Optional email field: results are sent there when ready (if Gmail is configured).
    email = gr.Textbox(
        label="Email (optional)",
        placeholder="you@example.com — we'll email the results when ready (if email is configured).",
    )

    # Helpful examples (question + description + CSV path).
    gr.Examples(
        examples=[
            [
                _EXAMPLE_DID_QUESTION,
                _EXAMPLE_DID_DESCRIPTION,
                str(Path(EXAMPLE_CSV_PATH) / "did_canonical_data_1.csv"),
            ],
            [
                _EXAMPLE_RCT_QUESTION,
                _EXAMPLE_RCT_DESCRIPTION,
                str(Path(EXAMPLE_CSV_PATH) / "rct_data_4.csv"),
            ],
        ],
        inputs=[query, dataset_description, csv_file],  # order matches each example row
        label="Quick Examples (click to fill)",
    )

    run_btn = gr.Button("Run analysis", variant="primary")

    with gr.Row():
        with gr.Column(scale=1):
            method_out = gr.HTML(label="Selected Method")
        with gr.Column(scale=1):
            effects_out = gr.HTML(label="Effect Estimates")

    with gr.Row():
        explanation_out = gr.HTML(label="Detailed Explanation")

    with gr.Accordion("Raw Results (Advanced)", open=False):
        raw_results = gr.JSON(label="Complete Analysis Output", show_label=False)

    run_btn.click(
        fn=run_agent,
        inputs=[query, csv_file, dataset_description, email],
        outputs=[method_out, effects_out, explanation_out, raw_results],
        show_progress=True,
    )


if __name__ == "__main__":
    demo.queue().launch()