Spaces:
Sleeping
Sleeping
dolphinium
add history to chatbot and update solr query generation prompt errors. TODO: fix code generation for visualizations.
840c57d
import gradio as gr | |
import json | |
import re | |
import datetime | |
import pandas as pd | |
import pysolr | |
import google.generativeai as genai | |
from sshtunnel import SSHTunnelForwarder | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import io | |
import os | |
import logging | |
from IPython.display import display, Markdown | |
# --- Logging ---
# Keep Matplotlib's logger at WARNING so its debug output stays out of the console.
logging.getLogger('matplotlib').setLevel(logging.WARNING)

# --- SSH tunnel settings ---
# Host and credentials are read from the environment; the port is fixed here.
SSH_HOST = os.environ.get('SSH_HOST')
SSH_PORT = 5322
SSH_USER = os.environ.get('SSH_USER')
SSH_PASS = os.environ.get('SSH_PASS')

# --- Solr settings ---
# Solr is reached through the tunnel: REMOTE_* is the far end, LOCAL_BIND_PORT
# is the port opened on this machine.
REMOTE_SOLR_HOST = '69.167.186.48'
REMOTE_SOLR_PORT = 8983
LOCAL_BIND_PORT = 8983
SOLR_CORE_NAME = 'news'
SOLR_USER = os.environ.get('SOLR_USER')
SOLR_PASS = os.environ.get('SOLR_PASS')

# --- Google Gemini settings ---
try:
    genai.configure(api_key=os.environ.get('GEMINI_API_KEY'))
except Exception as e:
    print(f"β Gemini API Key Error: {e}. Please ensure 'GEMINI_API_KEY' is set in your environment.")

# --- Shared handles ---
# Populated by the start-up sequence below; `is_initialized` only becomes True
# when every step succeeded and is what gates the Gradio launch.
ssh_tunnel_server = None
solr_client = None
llm_model = None
is_initialized = False

try:
    # Step 1: open the SSH tunnel to the remote Solr host.
    ssh_tunnel_server = SSHTunnelForwarder(
        (SSH_HOST, SSH_PORT),
        ssh_username=SSH_USER,
        ssh_password=SSH_PASS,
        remote_bind_address=(REMOTE_SOLR_HOST, REMOTE_SOLR_PORT),
        local_bind_address=('127.0.0.1', LOCAL_BIND_PORT)
    )
    ssh_tunnel_server.start()
    print(f"π SSH tunnel established: Local Port {ssh_tunnel_server.local_bind_port} -> Remote Solr.")

    # Step 2: point pysolr at the local end of the tunnel and verify it answers.
    solr_url = f'http://127.0.0.1:{ssh_tunnel_server.local_bind_port}/solr/{SOLR_CORE_NAME}'
    solr_client = pysolr.Solr(solr_url, auth=(SOLR_USER, SOLR_PASS), always_commit=True)
    solr_client.ping()
    print(f"β Solr connection successful on core '{SOLR_CORE_NAME}'.")

    # Step 3: create the Gemini model handle (temperature 0).
    llm_model = genai.GenerativeModel('gemini-1.5-flash', generation_config=genai.types.GenerationConfig(temperature=0))
    print(f"β LLM Model '{llm_model.model_name}' initialized.")

    print("β System Initialized Successfully.")
    is_initialized = True
except Exception as e:
    print(f"\nβ An error occurred during setup: {e}")
    # Do not leave a half-open tunnel behind when a later step failed.
    if ssh_tunnel_server and ssh_tunnel_server.is_active:
        ssh_tunnel_server.stop()
# Schema guide for the Solr core: one entry per field, each with its name,
# a type description, sample values and usage guidance. It is rendered into
# text for the LLM prompts, so the wording here is effectively prompt content.
field_metadata = [
    {
        "field_name": "business_model",
        "type": "string (categorical)",
        "example_values": ["pharma/bio", "drug delivery", "pharma services"],
        "definition": "The primary business category of the company involved in the news. Use for filtering by high-level industry segments.",
    },
    {
        "field_name": "news_type",
        "type": "string (categorical)",
        "example_values": ["product news", "financial news", "regulatory news"],
        "definition": "The category of the news article itself (e.g., financial, regulatory, acquisition). Use for filtering by the type of event being reported.",
    },
    {
        "field_name": "event_type",
        "type": "string (categorical)",
        "example_values": ["phase 2", "phase 1", "pre clinical", "marketed"],
        "definition": "The clinical or developmental stage of a product or event discussed in the article. Essential for queries about clinical trial phases.",
    },
    {
        "field_name": "source",
        "type": "string (categorical)",
        "example_values": ["Press Release", "PR Newswire", "Business Wire"],
        "definition": "The original source of the news article, such as a newswire or official report.",
    },
    {
        "field_name": "company_name",
        "type": "string (exact match, for faceting)",
        "example_values": ["pfizer inc.", "astrazeneca plc", "roche"],
        "definition": "The canonical, standardized name of a company. **Crucially, you MUST use this field for `terms` faceting** to group results by a unique company. Do NOT use this for searching.",
    },
    {
        "field_name": "company_name_s",
        "type": "string (multi-valued, for searching)",
        "example_values": ["pfizer inc.", "roche", "f. hoffmann-la roche ag", "nih"],
        "definition": "A field containing all known names and synonyms for a company. **You MUST use this field for all `query` parameter searches involving a company name** to ensure comprehensive results. Do NOT use for `terms` faceting.",
    },
    {
        "field_name": "territory_hq_s",
        "type": "string (multi-valued, hierarchical)",
        "example_values": ["united states of america", "europe", "europe western"],
        "definition": "The geographic location (country and continent) of a company's headquarters. It is hierarchical. Use for filtering by location.",
    },
    {
        "field_name": "therapeutic_category",
        "type": "string (specific)",
        "example_values": ["cancer, other", "cancer, nsclc metastatic", "alzheimer's"],
        "definition": "The specific disease or therapeutic area being targeted. Use for very specific disease queries.",
    },
    {
        "field_name": "therapeutic_category_s",
        "type": "string (multi-valued, for searching)",
        "example_values": ["cancer", "oncology", "infections", "cns"],
        "definition": "Broader, multi-valued therapeutic categories and their synonyms. **Use this field for broad category searches** in the `query` parameter.",
    },
    {
        "field_name": "compound_name",
        "type": "string (exact match, for faceting)",
        "example_values": ["opdivo injection solution", "keytruda injection solution"],
        "definition": "The specific, full trade name of a drug. **Use this field for `terms` faceting** on compounds.",
    },
    {
        "field_name": "compound_name_s",
        "type": "string (multi-valued, for searching)",
        "example_values": ["nivolumab injection solution", "opdivo injection solution", "ono-4538 injection solution"],
        "definition": "A field with all known trade names and synonyms for a drug. **Use this field for all `query` parameter searches** involving a compound name.",
    },
    {
        "field_name": "molecule_name",
        "type": "string (exact match, for faceting)",
        "example_values": ["cannabidiol", "paclitaxel", "pembrolizumab"],
        "definition": "The generic, non-proprietary name of the active molecule. **Use this field for `terms` faceting** on molecules.",
    },
    {
        "field_name": "molecule_name_s",
        "type": "string (multi-valued, for searching)",
        "example_values": ["cbd", "s1-220", "a1002n5s"],
        "definition": "A field with all known generic names and synonyms for a molecule. **Use this field for all `query` parameter searches** involving a molecule name.",
    },
    {
        "field_name": "highest_phase",
        "type": "string (categorical)",
        "example_values": ["marketed", "phase 2", "phase 1"],
        "definition": "The highest stage of development a drug has ever reached.",
    },
    {
        "field_name": "drug_delivery_branch_s",
        "type": "string (multi-valued, for searching)",
        "example_values": ["injection", "parenteral", "oral", "injection, other", "oral, other"],
        "definition": "The method of drug administration. **Use this for `query` parameter searches about route of administration** as it contains broader, search-friendly terms.",
    },
    {
        "field_name": "drug_delivery_branch",
        "type": "string (categorical, specific, for faceting)",
        "example_values": ["injection, other", "prefilled syringes", "np liposome", "oral enteric/delayed release"],
        "definition": "The most specific category of drug delivery technology. **Use this field for `terms` faceting** on specific delivery technologies.",
    },
    {
        "field_name": "route_branch",
        "type": "string (categorical)",
        "example_values": ["injection", "oral", "topical", "inhalation"],
        "definition": "The primary route of drug administration. Good for faceting on exact routes.",
    },
    {
        "field_name": "molecule_api_group",
        "type": "string (categorical)",
        "example_values": ["small molecules", "biologics", "nucleic acids"],
        "definition": "High-level classification of the drug's molecular type.",
    },
    {
        "field_name": "content",
        "type": "text (full-text search)",
        "example_values": ["The largest study to date...", "balstilimab..."],
        "definition": "The full text content of the news article. Use for keyword searches on topics not covered by other specific fields.",
    },
    {
        "field_name": "date",
        "type": "date",
        "example_values": ["2020-10-22T00:00:00Z"],
        "definition": "The full publication date and time in ISO 8601 format. Use for precise date range queries.",
    },
    {
        "field_name": "date_year",
        "type": "number (year)",
        "example_values": [2020, 2021, 2022],
        "definition": "The 4-digit year of publication. **Use this for queries involving whole years** (e.g., 'in 2023', 'last year', 'since 2020').",
    },
    {
        "field_name": "total_deal_value_in_million",
        "type": "number (metric)",
        "example_values": [50, 120.5, 176.157, 1000],
        "definition": "The total value of a financial deal, in millions of USD. This is the primary numeric field for financial aggregations (sum, avg, etc.). To use this, you must also filter for news that has a deal value, e.g., 'total_deal_value_in_million:[0 TO *]'.",
    },
]
# Renders the field metadata as a Markdown bullet list for use in prompts.
def format_metadata_for_prompt(metadata):
    """Return a Markdown description of every field in *metadata*.

    Each entry must provide ``field_name``, ``type``, ``definition`` and
    ``example_values``. An entry becomes a top-level bullet with the field
    name followed by Type / Definition / Examples sub-bullets and a blank
    line. Example values are converted with ``str`` and comma-separated.
    An empty *metadata* yields an empty string.
    """
    sections = []
    for entry in metadata:
        sections.append(
            f"- **{entry['field_name']}**\n"
            f" - **Type**: {entry['type']}\n"
            f" - **Definition**: {entry['definition']}\n"
            f" - **Examples**: {', '.join(str(value) for value in entry['example_values'])}\n\n"
        )
    return "".join(sections)
# Field guide rendered once at import time; the query-generation and report
# prompts in this file interpolate this string.
formatted_field_info = format_metadata_for_prompt(field_metadata)
def parse_suggestions_from_report(report_text):
    """Return the numbered follow-up suggestions found in a Markdown report.

    Looks for a ``###`` heading titled either "Deeper Dive: Suggested
    Follow-up Analyses" or "Suggestions for Further Exploration"
    (case-insensitive) and takes everything after it to the end of the text.
    Every line in that tail of the form ``<number>. <text>`` contributes its
    text, stripped of surrounding whitespace. Returns an empty list when the
    heading is not present.
    """
    section = re.search(
        r"### (?:Deeper Dive: Suggested Follow-up Analyses|Suggestions for Further Exploration)\s*\n(.*?)$",
        report_text,
        re.DOTALL | re.IGNORECASE,
    )
    if section is None:
        return []

    numbered_items = re.findall(r"^\s*\d+\.\s*(.*)", section.group(1), re.MULTILINE)
    return [item.strip() for item in numbered_items]
def llm_generate_solr_query_with_history(natural_language_query, field_metadata, chat_history):
    """Ask the LLM for a Solr query + JSON facet object for a user question.

    Args:
        natural_language_query: The user's current question.
        field_metadata: Accepted for interface compatibility. The prompt is
            built from the module-level ``formatted_field_info`` string, so
            this argument is not read here.
        chat_history: Iterable of ``(user_msg, bot_msg)`` pairs. Only non-empty
            user messages are forwarded to the model as context.

    Returns:
        The parsed JSON value from the model's reply (expected to be a dict
        with ``query`` and ``json.facet`` keys), or ``None`` if the model call
        fails or the reply is not valid JSON.
    """
    # Only the user's side of the conversation is needed for context; the
    # bot's long status/report messages would just bloat the prompt.
    formatted_history = "".join(
        f"- User: \"{user_msg}\"\n"
        for user_msg, _bot_msg in (chat_history or [])
        if user_msg
    )

    # Use the real current date so relative expressions ("this year",
    # "last year") keep resolving correctly instead of being pinned to the
    # day the prompt was written.
    today = datetime.datetime.now()
    today_str = today.strftime('%Y-%m-%d')
    current_year = today.year

    prompt = f"""
You are an expert Solr query engineer who converts natural language questions into precise Solr JSON Facet API query objects. Your primary goal is to create a valid JSON object with `query` and `json.facet` keys.
---
### CONVERSATIONAL CONTEXT & RULES
1. **Today's Date for Calculations**: {today_str}
2. **Allowed Facet Types**: The `type` key for any facet MUST be one of the following: `terms`, `query`, or `range`. **Do not use `date_histogram`**. For time-series analysis, use a `range` facet on a date field.
3. **Field Usage**: You MUST use the fields described in the 'Field Definitions' section. Pay close attention to the definitions to select the correct field.
4. **Facet vs. Query Field Distinction**: This is critical.
* For searching in the main `query` parameter, ALWAYS use the multi-valued search fields (ending in `_s`, like `company_name_s`) to get comprehensive results.
* For grouping in a `terms` facet, ALWAYS use the canonical, single-value field (e.g., `company_name`, `molecule_name`) to ensure unique and accurate grouping.
5. **No `count(*)`**: Do NOT use functions like `count(*)`. The default facet bucket count is sufficient for counting documents.
6. **Allowed Aggregations**: For statistical facets, only use these functions: `sum`, `avg`, `min`, `max`, `unique`. The primary metric field is `total_deal_value_in_million`. The aggregation MUST be a simple string like `"sum(total_deal_value_in_million)"` and not a nested JSON object.
7. **Term Facet Limits**: Every `terms` facet MUST include a `limit` key. Default to `limit: 10` unless the user specifies a different number of top results.
8. **Output Format**: Your final output must be a single, raw JSON object and nothing else. Do not add comments, explanations, or markdown formatting like ```json.
---
### FIELD DEFINITIONS (Your Source of Truth)
`{formatted_field_info}`
---
### CHAT HISTORY
`{formatted_history}`
---
### EXAMPLE OF A FOLLOW-UP QUERY
**Initial User Query:** "What are the infections news in this year?"
```json
{{
"query": "date_year:{current_year} AND therapeutic_category_s:infections",
"json.facet": {{
"infections_news_by_type": {{
"type": "terms",
"field": "news_type",
"limit": 10
}}
}}
}}
```
**Follow-up User Query:** "Compare deal values for injection vs oral."
**Correct JSON Output for the Follow-up:**
```json
{{
"query": "therapeutic_category_s:infections AND date_year:{current_year} AND total_deal_value_in_million:[0 TO *]",
"json.facet": {{
"injection_deals": {{
"type": "query",
"q": "route_branch:injection",
"facet": {{
"total_deal_value": "sum(total_deal_value_in_million)"
}}
}},
"oral_deals": {{
"type": "query",
"q": "route_branch:oral",
"facet": {{
"total_deal_value": "sum(total_deal_value_in_million)"
}}
}}
}}
}}
```
---
### YOUR TASK
Now, convert the following user query into a single, raw JSON object with 'query' and 'json.facet' keys, strictly following all rules and field definitions provided above and considering the chat history.
**Current User Query:** `{natural_language_query}`
"""
    # Captured once, inside the try. The error handler must never touch
    # `response.text` again: if that accessor is what raised, a second read
    # would raise from within the handler and escape this function.
    raw_response_text = 'N/A'
    try:
        response = llm_model.generate_content(prompt)
        raw_response_text = response.text
        # Strip Markdown code fences in case the model ignored rule 8.
        cleaned_text = re.sub(r'```json\s*|\s*```', '', raw_response_text, flags=re.MULTILINE | re.DOTALL).strip()
        return json.loads(cleaned_text)
    except Exception as e:
        print(f"Error in llm_generate_solr_query_with_history: {e}\nRaw Response:\n{raw_response_text}")
        return None
def llm_generate_visualization_code(query_context, facet_data):
    """Ask the LLM for Matplotlib/Seaborn code that charts *facet_data*.

    Args:
        query_context: The user's question, quoted in the prompt as the
            analytical goal.
        facet_data: Facet results; serialized into the prompt with
            ``json.dumps``.

    Returns:
        The model's reply with ```` ```python ```` / ```` ``` ```` fence
        markers removed, or ``None`` if the model call fails.
    """
    facet_json = json.dumps(facet_data, indent=2)
    prompt = f"""
You are a Python Data Visualization expert specializing in Matplotlib and Seaborn.
Your task is to generate Python code to create a single, insightful visualization.
**Context:**
1. **User's Analytical Goal:** "{query_context}"
2. **Aggregated Data (from Solr Facets):**
```json
{facet_json}
```
**Instructions:**
1. **Goal:** Write Python code to generate a chart that best visualizes the answer to the user's goal using the provided data.
2. **Data Access:** The data is available in a Python dictionary named `facet_data`. Your code must parse this dictionary.
3. **Code Requirements:**
* Start with `import matplotlib.pyplot as plt` and `import seaborn as sns`.
* Use `plt.style.use('seaborn-v0_8-whitegrid')` and `fig, ax = plt.subplots(figsize=(12, 7))`. Plot using the `ax` object.
* Always include a clear `ax.set_title(...)`, `ax.set_xlabel(...)`, and `ax.set_ylabel(...)`.
* Dynamically find the primary facet key and extract the 'buckets'.
* For each bucket, extract the 'val' (label) and the relevant metric ('count' or a nested metric).
* Use `plt.tight_layout()` and rotate x-axis labels if needed.
4. **Output Format:** ONLY output raw Python code. Do not wrap it in ```python ... ```. Do not include `plt.show()` or any explanation.
"""
    try:
        reply = llm_model.generate_content(prompt)
        # The model is told not to fence its output; remove fences anyway.
        return re.sub(r'^```python\s*|\s*```$', '', reply.text, flags=re.MULTILINE)
    except Exception as e:
        print(f"Error in llm_generate_visualization_code: {e}")
        return None
def execute_viz_code_and_get_path(viz_code, facet_data):
    """Run generated plotting code and save the resulting figure as a PNG.

    Args:
        viz_code: Python source expected to build a Matplotlib figure,
            normally bound to a variable named ``fig``.
        facet_data: Data exposed to the executed code as ``facet_data``.

    Returns:
        Path of the PNG written under ``/tmp/plots``, or ``None`` when
        *viz_code* is empty, the code raises, or it produces no figure.
    """
    if not viz_code:
        return None

    # Remember which figures already exist so that only figures created by
    # this call are used as a fallback and closed afterwards.
    figures_before = set(plt.get_fignums())
    try:
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs('/tmp/plots', exist_ok=True)
        plot_path = f"/tmp/plots/plot_{datetime.datetime.now().timestamp()}.png"

        # SECURITY NOTE: this executes LLM-generated code with full interpreter
        # privileges. The namespace below limits what names are handed in, but
        # it is not a sandbox; treat the model output as untrusted.
        exec_globals = {'facet_data': facet_data, 'plt': plt, 'sns': sns, 'pd': pd}
        exec(viz_code, exec_globals)

        fig = exec_globals.get('fig')
        if fig is None:
            # The code may have drawn a chart without naming it `fig`
            # (e.g. via plt.figure()/plt.bar()); use the newest figure it made.
            created = [num for num in plt.get_fignums() if num not in figures_before]
            if created:
                fig = plt.figure(created[-1])
        if fig:
            fig.savefig(plot_path, bbox_inches='tight')
            return plot_path
        return None
    except Exception as e:
        print(f"ERROR executing visualization code: {e}\n---Code---\n{viz_code}")
        return None
    finally:
        # Close every figure this call created, on success and on failure,
        # so repeated analyses do not accumulate open figures in memory.
        for num in plt.get_fignums():
            if num not in figures_before:
                plt.close(num)
def llm_generate_summary_and_suggestions_stream(query_context, facet_data):
    """Stream an analytical report with follow-up suggestions from the LLM.

    This is a generator: it yields the text of each streamed response chunk
    in order. If the model call or a chunk read fails, the error is printed
    and a single apology sentence is yielded instead of raising.

    Args:
        query_context: The user's question, quoted in the prompt.
        facet_data: Facet results; serialized into the prompt with
            ``json.dumps``.
    """
    today_str = datetime.datetime.now().strftime('%Y-%m-%d')
    evidence_json = json.dumps(facet_data, indent=2)
    prompt = f"""
You are a leading business intelligence analyst and strategist. Your audience is an executive or decision-maker who relies on you to not just present data, but to uncover its meaning and suggest smart next steps.
Your task is to analyze the provided data, deliver a concise, insightful report, and then propose logical follow-up analyses that could uncover deeper trends or causes.
**Today's Date for Context:** {today_str}
**Analysis Context:**
* **User's Core Question:** "{query_context}"
* **Structured Data (Your Evidence):**
```json
{evidence_json}
```
**--- INSTRUCTIONS ---**
**PART 1: THE ANALYTICAL REPORT**
Structure your report using Markdown. Your tone should be insightful, data-driven, and forward-looking.
* `## Executive Summary`: A 1-2 sentence, top-line answer to the user's core question. Get straight to the point.
* `### Key Findings & Insights`: Use bullet points. Don't just state the data; interpret it.
* Highlight the most significant figures, patterns, or anomalies.
* Where relevant, calculate key differences or growth rates (e.g., "X is 25% higher than Y").
* Pinpoint what the visualization or data reveals about the core business question.
* **Data Note:** Briefly mention any important caveats if apparent from the data (e.g., a short time frame, a small sample size).
* `### Context & Implications`: Briefly explain the "so what?" of these findings. What might this mean for our strategy, the market, or operations?
**PART 2: DEEPER DIVE: SUGGESTED FOLLOW-UP ANALYSES**
After the report, create a final section titled `### Deeper Dive: Suggested Follow-up Analyses`.
* **Think like a strategist.** Based on the findings, what would you ask next to validate a trend, understand a change, or uncover a root cause?
* **Propose 2-3 logical next questions.** These should be concise and framed as natural language questions that inspire further exploration.
* **Focus on comparative and trend analysis.** For example:
* If the user asked for "this year," suggest a comparison: *"How does this year's performance in [X] compare to last year?"*
* If a category is a clear leader, suggest breaking it down: *"What are the top sub-categories driving the growth in [Leading Category]?"*
* If there's a time-based trend, suggest exploring correlations: *"Is the decline in [Metric Z] correlated with changes in any other category during the same period?"*
* Format them as a numbered list.
* Ensure your suggestions are answerable using the available field definitions below.
### FIELD DEFINITIONS (Your Source of Truth)
{formatted_field_info}
**--- YOUR TASK ---**
Generate the full report and the strategic suggestions based on the user's question and the data provided.
"""
    try:
        for piece in llm_model.generate_content(prompt, stream=True):
            yield piece.text
    except Exception as e:
        print(f"Error in llm_generate_summary_and_suggestions_stream: {e}")
        yield "Sorry, I was unable to generate a summary for this data."
# Single entry point for every chat submission.
def process_analysis_flow(user_input, history, state):
    """Drive one full analysis turn, yielding UI updates as it progresses.

    This is a generator used as a Gradio event handler. Every yielded value
    is a 6-tuple matching the handler's outputs:
    ``(chat history, state, plot, report, Solr query panel, Solr data panel)``.

    Steps: generate a Solr query from the question (using earlier user
    questions as context), run it, ask the LLM for plotting code and execute
    it, then stream a written report and record its follow-up suggestions.

    Args:
        user_input: The text the user submitted.
        history: List of ``(user_msg, bot_msg)`` pairs, or ``None`` after a
            reset. It is appended to in place.
        state: Per-session dict with ``query_count`` and ``last_suggestions``,
            or ``None`` on the first call.
    """
    # Initialize state on the first run
    if state is None:
        state = {'query_count': 0, 'last_suggestions': []}
    # If history is None (from a reset), initialize it as an empty list
    if history is None:
        history = []

    def _shown(value):
        """Build an update that sets *value* and makes the component visible."""
        return gr.update(value=value, visible=True)

    def _hidden():
        """Build an update that clears and hides a component."""
        return gr.update(value=None, visible=False)

    # Reset the result panels for the new analysis, but keep the chat history.
    yield (history, state, _hidden(), _hidden(), _hidden(), _hidden())

    query_context = user_input.strip()
    if not query_context:
        history.append((user_input, "Please enter a question to analyze."))
        yield (history, state, None, None, None, None)
        return

    # Snapshot the earlier turns BEFORE recording this one: the query
    # generator receives the current question separately, so including it in
    # the history as well would present it to the model twice.
    prior_history = list(history)

    # 1. Acknowledge and start the process
    history.append((user_input, f"Analyzing: '{query_context}'\n\n*Generating Solr query...*"))
    yield (history, state, None, None, None, None)

    # 2. Generate Solr Query with history
    llm_solr_obj = llm_generate_solr_query_with_history(query_context, field_metadata, prior_history)
    # The reply is arbitrary JSON; require a dict so the key checks below are
    # real key checks (a bare string would turn `in` into a substring test).
    if not isinstance(llm_solr_obj, dict) or 'query' not in llm_solr_obj or 'json.facet' not in llm_solr_obj:
        history.append((None, "I'm sorry, I couldn't generate a valid Solr query for that request. Please try rephrasing your question."))
        yield (history, state, None, None, None, None)
        return

    solr_q, solr_facet = llm_solr_obj.get('query'), llm_solr_obj.get('json.facet')
    history.append((None, "β Solr query generated!"))
    formatted_query = f"**Query:**\n```\n{solr_q}\n```\n\n**Facet JSON:**\n```json\n{json.dumps(solr_facet, indent=2)}\n```"
    yield (history, state, None, None, _shown(formatted_query), None)

    # 3. Execute Query
    try:
        history.append((None, "*Executing query against the database...*"))
        yield (history, state, None, None, _shown(formatted_query), None)

        # rows=0: only the facet aggregations are needed, not the documents.
        search_params = {"rows": 0, "json.facet": json.dumps(solr_facet)}
        results = solr_client.search(q=solr_q, **search_params)
        facet_data = results.raw_response.get("facets", {})
        formatted_data = f"**Facet Data:**\n```json\n{json.dumps(facet_data, indent=2)}\n```"

        if not facet_data or facet_data.get('count', 0) == 0:
            history.append((None, "No data was found for your query. Please try a different question."))
            yield (history, state, None, None, _shown(formatted_query), _shown(formatted_data))
            return

        # 4. Generate Visualization
        history.append((None, "β Data retrieved. Generating visualization..."))
        yield (history, state, None, None, _shown(formatted_query), _shown(formatted_data))

        viz_code = llm_generate_visualization_code(query_context, facet_data)
        plot_path = execute_viz_code_and_get_path(viz_code, facet_data)
        output_plot = _shown(plot_path) if plot_path else gr.update(visible=False)
        if not plot_path:
            history.append((None, "*I was unable to generate a plot for this data.*\n"))
        yield (history, state, output_plot, None, _shown(formatted_query), _shown(formatted_data))

        # 5. Generate and Stream Report
        # Only claim a plot was created when one actually exists; otherwise
        # this message would contradict the failure notice just above.
        if plot_path:
            history.append((None, "β Plot created. Streaming final report..."))
        else:
            history.append((None, "*Streaming final report...*"))
        output_report = _shown("")  # Make it visible before streaming
        yield (history, state, output_plot, output_report, _shown(formatted_query), _shown(formatted_data))

        report_text = ""
        # The chat log does not change while streaming; only the report
        # panel grows, so a frozen copy of the history is re-sent each time.
        stream_history = history[:]
        for chunk in llm_generate_summary_and_suggestions_stream(query_context, facet_data):
            report_text += chunk
            yield (stream_history, state, output_plot, report_text, _shown(formatted_query), _shown(formatted_data))

        # Update the main history with the final report text
        history.append((None, report_text))

        # 6. Finalize and prompt for next action
        state['query_count'] += 1
        state['last_suggestions'] = parse_suggestions_from_report(report_text)
        next_prompt = "Analysis complete. What would you like to explore next? You can ask a follow-up question, or ask something new."
        history.append((None, next_prompt))
        yield (history, state, output_plot, report_text, _shown(formatted_query), _shown(formatted_data))
    except Exception as e:
        error_message = f"An unexpected error occurred during analysis: {e}"
        history.append((None, error_message))
        print(f"Error during analysis execution: {e}")
        yield (history, state, None, None, _shown(formatted_query), None)
# --- Gradio UI ---
# Left column: chat log, question box and reset button.
# Right column: generated query, raw facet data, plot and streamed report.
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    state = gr.State()

    gr.Markdown("# π PharmaCircle AI Data Analyst")
    gr.Markdown("Ask a question to begin your analysis. I will generate a Solr query, retrieve the data, create a visualization, and write a report. You can then ask follow-up questions freely.")

    with gr.Row():
        with gr.Column(scale=1):
            chatbot = gr.Chatbot(label="Analysis Chat Log", height=700, show_copy_button=True, avatar_images=(None, "https://pharma-circle.com/images/favicon.png"))
            msg_textbox = gr.Textbox(placeholder="Ask a question, e.g., 'Show me the top 5 companies by total deal value in 2023'", label="Your Question", interactive=True)
            with gr.Row():
                clear_button = gr.Button("π Start New Analysis", variant="primary")

        with gr.Column(scale=2):
            with gr.Accordion("Generated Solr Query", open=False):
                solr_query_display = gr.Markdown("Query will appear here...", visible=True)
            with gr.Accordion("Retrieved Solr Data", open=False):
                solr_data_display = gr.Markdown("Data will appear here...", visible=False)
            plot_display = gr.Image(label="Visualization", type="filepath", visible=False)
            report_display = gr.Markdown("Report will be streamed here...", visible=False)

    # --- Event Wiring ---
    def reset_all():
        """Return values that clear the chat, state and textbox and hide all result panels."""
        def _blank_and_hide():
            return gr.update(value=None, visible=False)

        return (
            [],                  # chatbot (cleared)
            None,                # state (reset)
            "",                  # msg_textbox (cleared)
            _blank_and_hide(),   # plot_display
            _blank_and_hide(),   # report_display
            _blank_and_hide(),   # solr_query_display
            _blank_and_hide()    # solr_data_display
        )

    # Submitting a question runs the analysis generator; once it has finished,
    # the follow-up step empties the question box.
    analysis_run = msg_textbox.submit(
        fn=process_analysis_flow,
        inputs=[msg_textbox, chatbot, state],
        outputs=[chatbot, state, plot_display, report_display, solr_query_display, solr_data_display],
    )
    analysis_run.then(
        lambda: gr.update(value=""),
        None,
        [msg_textbox],
        queue=False,
    )

    # The reset button's outputs are listed in the same order reset_all returns them.
    clear_button.click(
        fn=reset_all,
        inputs=None,
        outputs=[chatbot, state, msg_textbox, plot_display, report_display, solr_query_display, solr_data_display],
        queue=False
    )

# Only serve the app when the tunnel, Solr client and LLM all came up.
if is_initialized:
    demo.queue().launch(debug=True, share=True)
else:
    print("\nSkipping Gradio launch due to initialization errors.")