Spaces:
Sleeping
Sleeping
File size: 5,919 Bytes
79bcb1b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import html
import json
import os
from datetime import datetime
from typing import Dict, List, Any

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
# Constants: hex display colour for each risk level label used by the charts.
RISK_COLORS = dict(
    Low="#7FD8BE",     # soft mint green
    Medium="#FFC857",  # warm amber
    High="#E84855",    # bright red
)
def highlight_text_with_entities(text: str, entities: List[Dict[str, Any]]) -> str:
    """
    Format text with HTML to highlight extracted entities.

    The input text is HTML-escaped so characters such as ``<``, ``>`` and
    ``&`` render literally instead of being interpreted as markup. Each
    entity span is wrapped in a highlighted ``<span>``.

    Args:
        text: Original input text.
        entities: List of entity dictionaries with 'start' and 'end' keys
            giving a character span of ``text`` (end exclusive, slice
            semantics). Any 'text' key is not used. Spans are clamped to the
            text; empty spans and spans overlapping an earlier (lower-start)
            entity are skipped.

    Returns:
        HTML formatted string with highlighted entities.
    """
    if not entities:
        return html.escape(text, quote=False)

    open_tag = (
        "<span style='background-color: rgba(255, 200, 87, 0.3); "
        "border-radius: 3px; padding: 0px 3px;'>"
    )
    pieces: List[str] = []
    cursor = 0  # index in `text` up to which output has been emitted

    # Walk entities left to right and build the output from slices of the
    # ORIGINAL text. This avoids index drift from inserted markup and lets
    # every segment be escaped exactly once.
    for entity in sorted(entities, key=lambda e: (e['start'], e['end'])):
        start = max(entity['start'], 0)
        end = min(entity['end'], len(text))
        if start < cursor or start >= end:
            # Overlaps an already-highlighted span, or is empty/out of range.
            continue
        pieces.append(html.escape(text[cursor:start], quote=False))
        pieces.append(f"{open_tag}{html.escape(text[start:end], quote=False)}</span>")
        cursor = end
    pieces.append(html.escape(text[cursor:], quote=False))
    return "".join(pieces)
def format_duration(duration_entities: List[Dict[str, Any]]) -> str:
    """Return the entities' 'text' values joined by ", ", or a placeholder if there are none."""
    if duration_entities:
        return ", ".join(item["text"] for item in duration_entities)
    return "No specific duration mentioned"
def create_risk_gauge(risk_level: str, confidence: float) -> go.Figure:
    """Create a gauge chart for risk level visualization.

    Args:
        risk_level: Expected to be "Low", "Medium" or "High". Any other value
            is drawn as Medium (needle position and bar colour); the title
            still shows the value that was passed in.
        confidence: Accepted for API compatibility but not used by this chart.
            NOTE(review): presumably meant to be displayed; confirm intent.

    Returns:
        A Plotly Figure containing a single gauge Indicator.
    """
    # Map risk levels onto the gauge axis (ticks 1/2/3 on a 0-3 range).
    risk_value_map = {"Low": 1, "Medium": 2, "High": 3}
    risk_value = risk_value_map.get(risk_level, 2)  # Default to Medium if unknown
    # Use the same Medium fallback for the colour. Indexing RISK_COLORS
    # directly would raise KeyError for the unknown levels handled above.
    bar_color = RISK_COLORS.get(risk_level, RISK_COLORS["Medium"])

    fig = go.Figure(go.Indicator(
        mode="gauge+number+delta",
        value=risk_value,
        domain={'x': [0, 1], 'y': [0, 1]},
        gauge={
            'axis': {'range': [0, 3], 'tickvals': [1, 2, 3], 'ticktext': ['Low', 'Medium', 'High']},
            'bar': {'color': bar_color},
            # Translucent bands behind the bar, one per risk level.
            'steps': [
                {'range': [0, 1.5], 'color': "rgba(127, 216, 190, 0.3)"},
                {'range': [1.5, 2.5], 'color': "rgba(255, 200, 87, 0.3)"},
                {'range': [2.5, 3], 'color': "rgba(232, 72, 85, 0.3)"}
            ],
            'threshold': {
                'line': {'color': "white", 'width': 2},
                'thickness': 0.85,
                'value': risk_value
            }
        },
        number={'valueformat': '.0f', 'font': {'size': 36}},
        title={
            'text': f"Risk Level: {risk_level}",
            'font': {'size': 24}
        },
    ))
    fig.update_layout(
        height=250,
        margin=dict(l=10, r=10, t=50, b=10),
        paper_bgcolor='white',
        font={'color': "#2C363F", 'family': "Arial"}
    )
    return fig
def create_risk_probability_chart(probabilities: Dict[str, float]) -> go.Figure:
    """Create a horizontal bar chart for risk probabilities.

    Args:
        probabilities: Mapping of risk label to probability. Values are shown
            as percentages on a 0-1 axis. Bars are drawn in the mapping's
            iteration order. Labels without an entry in RISK_COLORS are drawn
            in a neutral grey instead of raising KeyError.

    Returns:
        A Plotly Figure with one horizontal bar per label.
    """
    fallback_color = "#9E9E9E"  # neutral grey for labels not in RISK_COLORS
    labels = list(probabilities.keys())
    values = list(probabilities.values())
    colors = [RISK_COLORS.get(label, fallback_color) for label in labels]
    fig = go.Figure(go.Bar(
        x=values,
        y=labels,
        orientation='h',
        marker_color=colors,
        text=[f"{v:.1%}" for v in values],
        textposition='auto'
    ))
    fig.update_layout(
        title="Risk Probability Distribution",
        xaxis_title="Probability",
        yaxis_title="Risk Level",
        height=250,
        margin=dict(l=10, r=10, t=50, b=10),
        xaxis=dict(range=[0, 1], tickformat=".0%"),
        paper_bgcolor='white',
        plot_bgcolor='white',
        font={'color': "#2C363F", 'family': "Arial"}
    )
    return fig
def save_consultation(consultation_data: Dict[str, Any]):
    """Save consultation data to a new JSON file in ``consultation_history/``.

    A ``"timestamp"`` key (``datetime.now().isoformat()``) is added to
    ``consultation_data`` in place before it is written, as before.

    Args:
        consultation_data: JSON-serialisable mapping describing the consultation.

    Returns:
        The path of the written file, e.g.
        ``consultation_history/consultation_20240101_120000_123456.json``.

    Raises:
        TypeError: If ``consultation_data`` is not JSON-serialisable. No file
            is created in that case.
    """
    history_dir = "consultation_history"
    os.makedirs(history_dir, exist_ok=True)

    # Read the clock once so the filename and stored timestamp agree.
    now = datetime.now()
    consultation_data["timestamp"] = now.isoformat()

    # Serialise before touching the filesystem so a bad payload cannot leave
    # a truncated file behind.
    payload = json.dumps(consultation_data, indent=2)

    # Microseconds in the name plus exclusive-create ("x") mode guarantee that
    # two saves in quick succession never overwrite each other.
    stem = f"{history_dir}/consultation_{now.strftime('%Y%m%d_%H%M%S_%f')}"
    attempt = 0
    while True:
        filename = f"{stem}.json" if attempt == 0 else f"{stem}_{attempt}.json"
        try:
            with open(filename, "x", encoding="utf-8") as f:
                f.write(payload)
        except FileExistsError:
            attempt += 1
            continue
        return filename
def load_consultation_history() -> List[Dict[str, Any]]:
    """Load all saved consultations from the ``consultation_history`` directory.

    Every ``*.json`` file in the directory is parsed. Files that cannot be
    read or decoded, or whose top-level value is not a JSON object, are
    reported with ``st.error`` and skipped.

    Returns:
        The consultation dicts sorted newest first by their ``"timestamp"``
        string. Entries without a string timestamp sort last. Returns an
        empty list if the directory does not exist.
    """
    history_dir = "consultation_history"
    if not os.path.isdir(history_dir):
        return []

    history = []
    # Sorting the listing makes the order of equal-timestamp entries deterministic.
    for filename in sorted(os.listdir(history_dir)):
        if not filename.endswith(".json"):
            continue
        try:
            with open(os.path.join(history_dir, filename), "r", encoding="utf-8") as f:
                consultation = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            st.error(f"Error loading {filename}: {str(e)}")
            continue
        if not isinstance(consultation, dict):
            # A list/scalar payload would break the timestamp sort below.
            st.error(f"Error loading {filename}: expected a JSON object")
            continue
        history.append(consultation)

    def _timestamp_key(item: Dict[str, Any]) -> str:
        # Non-string timestamps would make the comparison raise TypeError.
        ts = item.get("timestamp", "")
        return ts if isinstance(ts, str) else ""

    # Sort by timestamp (newest first)
    history.sort(key=_timestamp_key, reverse=True)
    return history
def init_session_state():
    """Initialize session state variables.

    Each key below is set to its default only if it is not already present
    in ``st.session_state``, so values survive Streamlit reruns.
    """
    # Built inside the function so every call gets a fresh list object.
    defaults = {
        "consultation_history": [],
        "current_result": None,
        "is_processing": False,
        "loaded_models": False,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value