import html
import json
import os
from datetime import datetime
from typing import Any, Dict, List

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st

# Constants
RISK_COLORS = {
    "Low": "#7FD8BE",  # Soft mint green
    "Medium": "#FFC857",  # Warm amber
    "High": "#E84855",  # Bright red
}

# Colour used for any risk label that is not a key of RISK_COLORS.
FALLBACK_RISK_COLOR = "#9AA5B1"

# Directory (relative to the working directory) where consultations are saved.
HISTORY_DIR = "consultation_history"

# Inline CSS applied to each highlighted entity span.
ENTITY_HIGHLIGHT_STYLE = "background-color: #FFC857; padding: 0 2px; border-radius: 3px;"


def highlight_text_with_entities(
    text: str,
    entities: List[Dict[str, Any]],
    *,
    css_style: str = ENTITY_HIGHLIGHT_STYLE,
) -> str:
    """
    Format text with HTML to highlight extracted entities.

    The returned string is HTML: all of ``text`` is escaped (so user input
    cannot inject markup when rendered with ``unsafe_allow_html=True``) and
    every entity span is wrapped in a ``<mark>`` element.

    Entities are applied left to right. An entity that overlaps an entity
    already highlighted, is empty, or whose offsets fall outside ``text`` is
    skipped instead of corrupting the output.

    Args:
        text: Original input text.
        entities: List of entity dictionaries with 'start' and 'end'
            character offsets into ``text`` (end exclusive).
        css_style: Inline CSS for the ``<mark>`` wrapper.

    Returns:
        HTML formatted string with highlighted entities.
    """
    if not entities:
        return html.escape(text)

    style_attr = html.escape(css_style, quote=True)
    pieces: List[str] = []
    cursor = 0  # index in `text` up to which output has been emitted
    # Leftmost first; for equal starts prefer the longest span.
    for entity in sorted(entities, key=lambda e: (e["start"], -e["end"])):
        start, end = entity["start"], entity["end"]
        # `start < cursor` also rejects negative offsets and overlaps.
        if start < cursor or end <= start or end > len(text):
            continue
        pieces.append(html.escape(text[cursor:start]))
        pieces.append(f'<mark style="{style_attr}">{html.escape(text[start:end])}</mark>')
        cursor = end
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)


def format_duration(duration_entities: List[Dict[str, Any]]) -> str:
    """
    Format duration entities into a readable string.

    Args:
        duration_entities: Entity dictionaries, each with a 'text' key.

    Returns:
        The entity texts joined with ", ", or a fixed message when empty.
    """
    if not duration_entities:
        return "No specific duration mentioned"
    return ", ".join(entity["text"] for entity in duration_entities)


def create_risk_gauge(risk_level: str, confidence: float) -> go.Figure:
    """
    Create a gauge chart for risk level visualization.

    The level is mapped to 1 (Low), 2 (Medium) or 3 (High) on a 0-3 gauge.
    Unknown levels are drawn at the Medium position in the Medium colour.

    Args:
        risk_level: Expected to be "Low", "Medium" or "High".
        confidence: Currently not used by this chart.

    Returns:
        A Plotly Indicator figure.
    """
    # NOTE(review): `confidence` is accepted but never rendered; callers may
    # expect it to appear on the gauge - confirm intended design.
    risk_value_map = {"Low": 1, "Medium": 2, "High": 3}
    risk_value = risk_value_map.get(risk_level, 2)  # Default to Medium if unknown
    # Keep the colour consistent with the Medium fallback above instead of
    # raising KeyError for an unrecognised level.
    bar_color = RISK_COLORS.get(risk_level, RISK_COLORS["Medium"])

    fig = go.Figure(go.Indicator(
        mode="gauge+number+delta",
        value=risk_value,
        domain={'x': [0, 1], 'y': [0, 1]},
        gauge={
            'axis': {'range': [0, 3], 'tickvals': [1, 2, 3], 'ticktext': ['Low', 'Medium', 'High']},
            'bar': {'color': bar_color},
            'steps': [
                {'range': [0, 1.5], 'color': "rgba(127, 216, 190, 0.3)"},
                {'range': [1.5, 2.5], 'color': "rgba(255, 200, 87, 0.3)"},
                {'range': [2.5, 3], 'color': "rgba(232, 72, 85, 0.3)"},
            ],
            'threshold': {
                'line': {'color': "white", 'width': 2},
                'thickness': 0.85,
                'value': risk_value,
            },
        },
        number={'valueformat': '.0f', 'font': {'size': 36}},
        title={
            'text': f"Risk Level: {risk_level}",
            'font': {'size': 24},
        },
    ))

    fig.update_layout(
        height=250,
        margin=dict(l=10, r=10, t=50, b=10),
        paper_bgcolor='white',
        font={'color': "#2C363F", 'family': "Arial"},
    )

    return fig


def create_risk_probability_chart(probabilities: Dict[str, float]) -> go.Figure:
    """
    Create a horizontal bar chart for risk probabilities.

    Args:
        probabilities: Mapping of risk label to probability. Values are
            formatted as percentages, so they are presumably fractions in
            [0, 1].

    Returns:
        A Plotly bar figure. Labels missing from RISK_COLORS are drawn in
        FALLBACK_RISK_COLOR instead of raising KeyError.
    """
    labels = list(probabilities.keys())
    values = list(probabilities.values())
    colors = [RISK_COLORS.get(label, FALLBACK_RISK_COLOR) for label in labels]

    fig = go.Figure(go.Bar(
        x=values,
        y=labels,
        orientation='h',
        marker_color=colors,
        text=[f"{v:.1%}" for v in values],
        textposition='auto',
    ))

    fig.update_layout(
        title="Risk Probability Distribution",
        xaxis_title="Probability",
        yaxis_title="Risk Level",
        height=250,
        margin=dict(l=10, r=10, t=50, b=10),
        xaxis=dict(range=[0, 1], tickformat=".0%"),
        paper_bgcolor='white',
        plot_bgcolor='white',
        font={'color': "#2C363F", 'family': "Arial"},
    )

    return fig


def save_consultation(consultation_data: Dict[str, Any]) -> str:
    """
    Save consultation data to a JSON file in HISTORY_DIR.

    Side effect: adds a "timestamp" key (ISO-8601 local time) to
    ``consultation_data`` in place, as callers may rely on it.

    Args:
        consultation_data: JSON-serializable dictionary.

    Returns:
        Path of the written file.

    Raises:
        TypeError: If the data is not JSON-serializable. This is raised
            before any file is created, so no truncated file is left behind.
        OSError: If the directory or file cannot be written.
    """
    os.makedirs(HISTORY_DIR, exist_ok=True)

    # One clock read so the filename and stored timestamp agree; microseconds
    # stop two saves within the same second from overwriting each other.
    now = datetime.now()
    consultation_data["timestamp"] = now.isoformat()

    # Serialize first so a failure cannot leave a half-written .json file
    # that would later break load_consultation_history.
    payload = json.dumps(consultation_data, indent=2)

    filename = os.path.join(HISTORY_DIR, f"consultation_{now:%Y%m%d_%H%M%S_%f}.json")
    with open(filename, "w", encoding="utf-8") as f:
        f.write(payload)

    return filename


def load_consultation_history() -> List[Dict[str, Any]]:
    """
    Load all saved consultations from HISTORY_DIR.

    Files that cannot be read or parsed, or whose top level is not a JSON
    object, are reported with ``st.error`` and skipped.

    Returns:
        Consultation dictionaries sorted by their "timestamp" value, newest
        first. Entries without a timestamp sort last.
    """
    if not os.path.isdir(HISTORY_DIR):
        return []

    history: List[Dict[str, Any]] = []
    # Sorted listing keeps the order of equal timestamps deterministic.
    for filename in sorted(os.listdir(HISTORY_DIR)):
        if not filename.endswith(".json"):
            continue
        try:
            with open(os.path.join(HISTORY_DIR, filename), "r", encoding="utf-8") as f:
                consultation = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
            st.error(f"Error loading {filename}: {str(e)}")
            continue
        if not isinstance(consultation, dict):
            st.error(f"Error loading {filename}: expected a JSON object")
            continue
        history.append(consultation)

    # ISO-8601 strings from save_consultation sort chronologically as text.
    history.sort(key=lambda x: str(x.get("timestamp", "")), reverse=True)
    return history


def init_session_state():
    """
    Initialize session state variables that are not already set.

    Existing values are left untouched, so this is safe to call on every
    Streamlit rerun.
    """
    # Built fresh on each call so the list default is never shared.
    defaults = {
        "consultation_history": [],
        "current_result": None,
        "is_processing": False,
        "loaded_models": False,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value