Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import plotly.express as px | |
import plotly.graph_objects as go | |
from datetime import datetime | |
import json | |
import os | |
from typing import Dict, List, Any | |
# Constants
# Hex colour for each risk-level label, used by the gauge and bar charts.
RISK_COLORS = dict(
    Low="#7FD8BE",     # Soft mint green
    Medium="#FFC857",  # Warm amber
    High="#E84855",    # Bright red
)
def highlight_text_with_entities(text: str, entities: List[Dict[str, Any]]) -> str:
    """
    Format text as HTML with the extracted entities highlighted.

    The whole string is HTML-escaped, so user-supplied text cannot inject
    markup when the result is rendered as HTML. Entity spans are clamped to
    the bounds of ``text``. Empty spans, and spans that overlap an earlier
    (already highlighted) span, are skipped, so the generated markup is
    always well-formed.

    Args:
        text: Original input text.
        entities: List of entity dictionaries with integer 'start' and 'end'
            character offsets into ``text`` (end exclusive). Other keys such
            as 'text' are not used here.

    Returns:
        HTML string in which each highlighted entity is wrapped in a <span>.
    """
    import html  # local import keeps this helper self-contained

    if not entities:
        return html.escape(text, quote=False)

    pieces: List[str] = []
    cursor = 0  # offset in `text` up to which output has been emitted
    # Walk left-to-right; for equal starts, the longest span wins.
    for entity in sorted(entities, key=lambda e: (e["start"], -e["end"])):
        start = max(entity["start"], 0)
        end = min(entity["end"], len(text))
        if start < cursor or end <= start:
            # Overlaps an already-highlighted span, or is empty/out of range.
            continue
        pieces.append(html.escape(text[cursor:start], quote=False))
        pieces.append(
            "<span style='background-color: rgba(255, 200, 87, 0.3); "
            "border-radius: 3px; padding: 0px 3px;'>"
            f"{html.escape(text[start:end], quote=False)}</span>"
        )
        cursor = end
    pieces.append(html.escape(text[cursor:], quote=False))
    return "".join(pieces)
def format_duration(duration_entities: List[Dict[str, Any]]) -> str:
    """Join the 'text' of each duration entity with ", ", or return a fallback message when there are none."""
    if duration_entities:
        return ", ".join(item["text"] for item in duration_entities)
    return "No specific duration mentioned"
def create_risk_gauge(risk_level: str, confidence: float) -> go.Figure:
    """
    Create a gauge chart for risk level visualization.

    Args:
        risk_level: "Low", "Medium" or "High". Any other label is drawn at
            the "Medium" position and in the "Medium" colour; the title still
            shows the label as given.
        confidence: Not used by the chart. NOTE(review): kept for interface
            compatibility; confirm whether it should be displayed.

    Returns:
        A plotly Figure holding a single gauge Indicator on a 0-3 axis.
    """
    # Map the categorical level onto the gauge axis ticks 1/2/3.
    risk_value_map = {"Low": 1, "Medium": 2, "High": 3}
    risk_value = risk_value_map.get(risk_level, 2)  # Default to Medium if unknown
    # Apply the same Medium fallback to the bar colour; indexing RISK_COLORS
    # directly raised KeyError for exactly the labels tolerated above.
    bar_color = RISK_COLORS.get(risk_level, RISK_COLORS["Medium"])

    fig = go.Figure(go.Indicator(
        # "+delta" dropped: no delta reference is configured, so the delta
        # readout carried no information.
        mode="gauge+number",
        value=risk_value,
        domain={'x': [0, 1], 'y': [0, 1]},
        gauge={
            'axis': {'range': [0, 3], 'tickvals': [1, 2, 3], 'ticktext': ['Low', 'Medium', 'High']},
            'bar': {'color': bar_color},
            # Translucent background bands around each tick position.
            'steps': [
                {'range': [0, 1.5], 'color': "rgba(127, 216, 190, 0.3)"},
                {'range': [1.5, 2.5], 'color': "rgba(255, 200, 87, 0.3)"},
                {'range': [2.5, 3], 'color': "rgba(232, 72, 85, 0.3)"}
            ],
            # White marker line drawn at the current value.
            'threshold': {
                'line': {'color': "white", 'width': 2},
                'thickness': 0.85,
                'value': risk_value
            }
        },
        number={'valueformat': '.0f', 'font': {'size': 36}},
        title={
            'text': f"Risk Level: {risk_level}",
            'font': {'size': 24}
        },
    ))
    fig.update_layout(
        height=250,
        margin=dict(l=10, r=10, t=50, b=10),
        paper_bgcolor='white',
        font={'color': "#2C363F", 'family': "Arial"}
    )
    return fig
def create_risk_probability_chart(probabilities: Dict[str, float]) -> go.Figure:
    """
    Create a horizontal bar chart for risk probabilities.

    Args:
        probabilities: Mapping of risk-level label to probability. Values are
            labelled as percentages and the x axis spans 0-100%, so they
            presumably lie in [0, 1]; confirm against the caller. Labels not
            in RISK_COLORS are drawn in a neutral grey instead of raising
            KeyError.

    Returns:
        A plotly Figure with one horizontal bar per label, in dict order.
    """
    fallback_color = "#A0A0A0"  # neutral grey for labels outside the palette
    labels = list(probabilities.keys())
    values = list(probabilities.values())
    colors = [RISK_COLORS.get(label, fallback_color) for label in labels]

    fig = go.Figure(go.Bar(
        x=values,
        y=labels,
        orientation='h',
        marker_color=colors,
        text=[f"{v:.1%}" for v in values],
        textposition='auto'
    ))
    fig.update_layout(
        title="Risk Probability Distribution",
        xaxis_title="Probability",
        yaxis_title="Risk Level",
        height=250,
        margin=dict(l=10, r=10, t=50, b=10),
        xaxis=dict(range=[0, 1], tickformat=".0%"),
        paper_bgcolor='white',
        plot_bgcolor='white',
        font={'color': "#2C363F", 'family': "Arial"}
    )
    return fig
def save_consultation(consultation_data: Dict[str, Any]):
    """
    Save consultation data to a new JSON file under ``consultation_history/``.

    A ``"timestamp"`` key (``datetime.now().isoformat()``) is added to
    ``consultation_data`` in place before saving. Files are named
    ``consultation_<YYYYmmdd_HHMMSS>.json``. If that name already exists
    (for example, two saves in the same second), a ``_1``, ``_2``, ... suffix
    is appended rather than overwriting the earlier file.

    Args:
        consultation_data: JSON-serialisable mapping to persist.

    Returns:
        The path of the file that was written.

    Raises:
        TypeError / ValueError: If ``consultation_data`` cannot be serialised
            to JSON. This is raised before any file is created.
    """
    history_dir = "consultation_history"
    # Create history directory if it doesn't exist
    os.makedirs(history_dir, exist_ok=True)

    # One clock reading, so the filename and the stored timestamp agree.
    now = datetime.now()
    consultation_data["timestamp"] = now.isoformat()

    # Serialise first, so a bad value cannot leave a truncated file behind.
    payload = json.dumps(consultation_data, indent=2)

    stem = f"{history_dir}/consultation_{now.strftime('%Y%m%d_%H%M%S')}"
    filename = f"{stem}.json"
    counter = 0
    while True:
        try:
            # Mode "x" fails on an existing file instead of overwriting it.
            with open(filename, "x", encoding="utf-8") as f:
                f.write(payload)
            return filename
        except FileExistsError:
            counter += 1
            filename = f"{stem}_{counter}.json"
def load_consultation_history() -> List[Dict[str, Any]]:
    """
    Load all saved consultations from ``consultation_history/``, newest first.

    Every ``*.json`` file whose top-level value is a JSON object is returned.
    A file is reported with ``st.error`` and skipped if it cannot be read or
    parsed, or if its top-level value is not an object. Records are ordered
    by their ``"timestamp"`` string; records whose timestamp is missing, null
    or empty sort last.

    Returns:
        List of consultation dicts, newest first. The list is empty when the
        history directory does not exist.
    """
    history_dir = "consultation_history"
    # isdir, not exists: a plain file with this name would crash os.listdir.
    if not os.path.isdir(history_dir):
        return []

    history: List[Dict[str, Any]] = []
    # Sorted listing gives a deterministic order among equal timestamps.
    for filename in sorted(os.listdir(history_dir)):
        if not filename.endswith(".json"):
            continue
        try:
            with open(os.path.join(history_dir, filename), "r", encoding="utf-8") as f:
                consultation = json.load(f)
        except Exception as e:
            # Best-effort load: report the bad file and keep going.
            st.error(f"Error loading {filename}: {str(e)}")
            continue
        if not isinstance(consultation, dict):
            # The sort below needs .get(); skip arrays/scalars instead of crashing.
            st.error(f"Error loading {filename}: expected a JSON object")
            continue
        history.append(consultation)

    # Coerce to str so null or non-string timestamps cannot break the comparison.
    history.sort(key=lambda x: str(x.get("timestamp") or ""), reverse=True)
    return history
def init_session_state():
    """Give every expected st.session_state key a default value if it is not already set."""
    # Built on each call, so the list default is a fresh object every time.
    defaults = {
        "consultation_history": [],
        "current_result": None,
        "is_processing": False,
        "loaded_models": False,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value