# finalProject / utils.py
# Uploaded by BillyZ1129 ("Upload 7 files"), commit 79bcb1b (verified).
import html
import json
import os
from datetime import datetime
from typing import Dict, List, Any

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
# Constants
# Hex colour per risk level; keys are the risk labels used throughout this module.
RISK_COLORS = dict(
    Low="#7FD8BE",  # Soft mint green
    Medium="#FFC857",  # Warm amber
    High="#E84855",  # Bright red
)
def highlight_text_with_entities(text: str, entities: List[Dict[str, Any]]) -> str:
    """
    Format text with HTML to highlight extracted entities.

    The input is HTML-escaped so characters such as ``<`` or ``&`` in the
    user's text cannot inject markup when the result is rendered as HTML.
    Each entity span is wrapped in a ``<span>`` with a translucent amber
    background.

    Args:
        text: Original input text.
        entities: List of entity dictionaries with integer 'start' and 'end'
            character offsets into ``text``. The highlighted content is taken
            from ``text[start:end]``; the 'text' key is not used here.

    Returns:
        HTML formatted string with highlighted entities. Offsets are clamped
        to the bounds of ``text``. Empty spans, and spans that overlap an
        earlier highlighted span, are skipped. Among spans with the same start,
        the longest is kept.
    """
    if not entities:
        return html.escape(text, quote=False)

    pieces: List[str] = []
    cursor = 0  # end offset of the last span already emitted
    # Walk the text left to right, so no offsets shift as markup is inserted.
    for entity in sorted(entities, key=lambda e: (e["start"], -e["end"])):
        start = max(entity["start"], 0)
        end = min(entity["end"], len(text))
        if start < cursor or start >= end:
            # Overlaps text that is already highlighted, or is an empty/inverted span.
            continue
        pieces.append(html.escape(text[cursor:start], quote=False))
        pieces.append(
            "<span style='background-color: rgba(255, 200, 87, 0.3); "
            "border-radius: 3px; padding: 0px 3px;'>"
            f"{html.escape(text[start:end], quote=False)}</span>"
        )
        cursor = end
    pieces.append(html.escape(text[cursor:], quote=False))
    return "".join(pieces)
def format_duration(duration_entities: List[Dict[str, Any]]) -> str:
    """
    Turn duration entities into a human-readable string.

    Returns the entities' ``"text"`` values joined by ", ", or a fixed
    fallback message when there are no entities (empty or falsy input).
    """
    if duration_entities:
        return ", ".join(item["text"] for item in duration_entities)
    return "No specific duration mentioned"
def create_risk_gauge(risk_level: str, confidence: float) -> go.Figure:
    """
    Create a gauge chart for risk level visualization.

    Args:
        risk_level: "Low", "Medium" or "High". Any other value is drawn as
            Medium (both gauge value and bar colour); the title still shows
            the string that was passed in.
        confidence: Not used by this function at present.

    Returns:
        A Plotly figure containing a single gauge indicator whose value is
        1 (Low), 2 (Medium) or 3 (High).
    """
    # Map risk levels to numerical values for the gauge
    risk_value_map = {"Low": 1, "Medium": 2, "High": 3}
    risk_value = risk_value_map.get(risk_level, 2)  # Default to Medium if unknown
    # Apply the same Medium fallback to the colour; indexing RISK_COLORS
    # directly raised KeyError for labels the line above already tolerates.
    bar_color = RISK_COLORS.get(risk_level, RISK_COLORS["Medium"])
    fig = go.Figure(go.Indicator(
        # NOTE(review): mode includes "delta" but no delta reference is set;
        # confirm whether a delta readout is actually intended.
        mode="gauge+number+delta",
        value=risk_value,
        domain={'x': [0, 1], 'y': [0, 1]},
        gauge={
            'axis': {'range': [0, 3], 'tickvals': [1, 2, 3], 'ticktext': ['Low', 'Medium', 'High']},
            'bar': {'color': bar_color},
            # Background bands for Low / Medium / High: translucent RISK_COLORS values
            'steps': [
                {'range': [0, 1.5], 'color': "rgba(127, 216, 190, 0.3)"},
                {'range': [1.5, 2.5], 'color': "rgba(255, 200, 87, 0.3)"},
                {'range': [2.5, 3], 'color': "rgba(232, 72, 85, 0.3)"}
            ],
            # White marker line drawn at the current value
            'threshold': {
                'line': {'color': "white", 'width': 2},
                'thickness': 0.85,
                'value': risk_value
            }
        },
        number={'valueformat': '.0f', 'font': {'size': 36}},
        title={
            'text': f"Risk Level: {risk_level}",
            'font': {'size': 24}
        },
    ))
    fig.update_layout(
        height=250,
        margin=dict(l=10, r=10, t=50, b=10),
        paper_bgcolor='white',
        font={'color': "#2C363F", 'family': "Arial"}
    )
    return fig
def create_risk_probability_chart(probabilities: Dict[str, float]) -> go.Figure:
    """
    Create a horizontal bar chart for risk probabilities.

    Args:
        probabilities: Mapping of risk label to probability, drawn as one bar
            per entry in the mapping's iteration order. Values are presumably
            fractions in [0, 1], since the x-axis is fixed to that range and
            formatted as percentages. Labels found in RISK_COLORS use their
            risk colour. Any other label is drawn in neutral grey instead of
            raising KeyError.

    Returns:
        A Plotly figure with a single horizontal bar trace.
    """
    fallback_color = "#A0A0A0"  # neutral grey for labels outside RISK_COLORS
    labels = list(probabilities.keys())
    values = list(probabilities.values())
    colors = [RISK_COLORS.get(label, fallback_color) for label in labels]
    fig = go.Figure(go.Bar(
        x=values,
        y=labels,
        orientation='h',
        marker_color=colors,
        text=[f"{v:.1%}" for v in values],  # bar labels rendered as percentages
        textposition='auto'
    ))
    fig.update_layout(
        title="Risk Probability Distribution",
        xaxis_title="Probability",
        yaxis_title="Risk Level",
        height=250,
        margin=dict(l=10, r=10, t=50, b=10),
        xaxis=dict(range=[0, 1], tickformat=".0%"),
        paper_bgcolor='white',
        plot_bgcolor='white',
        font={'color': "#2C363F", 'family': "Arial"}
    )
    return fig
def save_consultation(consultation_data: Dict[str, Any]):
    """
    Save consultation data to a JSON file under ``consultation_history/``.

    A ``"timestamp"`` key (``datetime.now().isoformat()``) is added to
    ``consultation_data`` in place before the data is written.

    Args:
        consultation_data: JSON-serialisable dictionary for the consultation.

    Returns:
        The path written, e.g.
        ``consultation_history/consultation_20240101_120000.json``. If that
        name already exists (two saves within the same second), a numeric
        suffix such as ``_1`` is appended rather than overwriting the earlier
        file.

    Raises:
        TypeError or ValueError: if ``consultation_data`` is not
            JSON-serialisable. No file is created in that case.
    """
    history_dir = "consultation_history"
    # Create history directory if it doesn't exist
    os.makedirs(history_dir, exist_ok=True)

    # One clock reading, so the filename and stored timestamp always agree.
    now = datetime.now()
    consultation_data["timestamp"] = now.isoformat()

    # Serialise before touching the filesystem, so a bad payload never leaves
    # a truncated JSON file behind.
    payload = json.dumps(consultation_data, indent=2)

    stem = f"{history_dir}/consultation_{now.strftime('%Y%m%d_%H%M%S')}"
    filename = f"{stem}.json"
    suffix = 0
    while True:
        try:
            # "x" mode fails if the file exists, so earlier saves are never clobbered.
            handle = open(filename, "x", encoding="utf-8")
        except FileExistsError:
            suffix += 1
            filename = f"{stem}_{suffix}.json"
            continue
        with handle:
            handle.write(payload)
        return filename
def load_consultation_history() -> List[Dict[str, Any]]:
    """
    Load all saved consultations from the ``consultation_history`` directory.

    Every ``*.json`` file in the directory is parsed. A file is reported with
    ``st.error`` and skipped when any of the following applies:

    * it cannot be read;
    * it cannot be decoded;
    * its top-level value is not a JSON object.

    Returns:
        List of consultation dicts, sorted newest first by their
        ``"timestamp"`` value compared as strings. Records without a
        timestamp sort last. Returns ``[]`` if the directory does not exist
        (or the path is not a directory).
    """
    history_dir = "consultation_history"
    if not os.path.isdir(history_dir):
        return []

    history: List[Dict[str, Any]] = []
    # sorted() gives a deterministic order for records with equal timestamps.
    for filename in sorted(os.listdir(history_dir)):
        if not filename.endswith(".json"):
            continue
        path = os.path.join(history_dir, filename)
        try:
            with open(path, "r", encoding="utf-8") as f:
                consultation = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            st.error(f"Error loading {filename}: {str(e)}")
            continue
        if not isinstance(consultation, dict):
            # A list or scalar payload would break the .get() in the sort below.
            st.error(f"Error loading {filename}: expected a JSON object")
            continue
        history.append(consultation)

    # Sort by timestamp (newest first); str() guards against non-string values.
    history.sort(key=lambda x: str(x.get("timestamp", "")), reverse=True)
    return history
def init_session_state():
    """Give each expected Streamlit session-state key its default value, unless it is already set."""
    state = st.session_state
    # Built per call, so each session gets its own fresh list.
    defaults = {
        "consultation_history": [],
        "current_result": None,
        "is_processing": False,
        "loaded_models": False,
    }
    for key, value in defaults.items():
        if key not in state:
            setattr(state, key, value)