File size: 5,919 Bytes
79bcb1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import html
import json
import os
from datetime import datetime
from typing import Dict, List, Any

import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

# Constants
# Hex colour for each risk level; used by create_risk_gauge and
# create_risk_probability_chart to colour their bars.
RISK_COLORS = dict(
    Low="#7FD8BE",     # soft mint green
    Medium="#FFC857",  # warm amber
    High="#E84855",    # bright red
)

def highlight_text_with_entities(text: str, entities: List[Dict[str, Any]]) -> str:
    """
    Format text with HTML to highlight extracted entities.

    The input text is HTML-escaped (``<``, ``>`` and ``&``), so characters in
    it are displayed literally instead of being interpreted as markup.

    Args:
        text: Original input text
        entities: List of entity dictionaries with 'start' and 'end' keys
            (character offsets into ``text``, end exclusive). A 'text' key,
            if present, is not used here. Empty spans, and spans that overlap
            an earlier-starting (or, for equal starts, longer) entity, are
            not highlighted separately.

    Returns:
        HTML formatted string with highlighted entities
    """
    if not entities:
        return html.escape(text, quote=False)

    span_open = "<span style='background-color: rgba(255, 200, 87, 0.3); border-radius: 3px; padding: 0px 3px;'>"

    pieces = []
    cursor = 0  # index in `text` up to which output has been emitted
    # Walk entities left to right; for equal starts the longest span wins.
    for entity in sorted(entities, key=lambda e: (e['start'], -e['end'])):
        start = max(entity['start'], 0)
        end = min(entity['end'], len(text))
        # Skip empty spans and spans overlapping one already highlighted:
        # splicing them in would cut through the earlier <span> markup.
        if start < cursor or end <= start:
            continue
        pieces.append(html.escape(text[cursor:start], quote=False))
        pieces.append(span_open + html.escape(text[start:end], quote=False) + "</span>")
        cursor = end
    pieces.append(html.escape(text[cursor:], quote=False))

    return "".join(pieces)

def format_duration(duration_entities: List[Dict[str, Any]]) -> str:
    """Join the "text" of every duration entity with ", ".

    A fixed placeholder sentence is returned when the list is empty (or
    otherwise falsy).
    """
    if duration_entities:
        return ", ".join(entity["text"] for entity in duration_entities)
    return "No specific duration mentioned"

def create_risk_gauge(risk_level: str, confidence: float) -> go.Figure:
    """Create a gauge chart for risk level visualization.

    Args:
        risk_level: Expected to be "Low", "Medium" or "High". Any other value
            is drawn as Medium (needle position and bar colour); the title
            still shows the value that was passed in.
        confidence: Not used by this function.
            NOTE(review): kept for API compatibility — confirm whether it was
            meant to feed the number/delta display.

    Returns:
        A Plotly figure containing a single gauge indicator.
    """

    # Map risk levels to positions on the 0-3 gauge axis.
    risk_value_map = {"Low": 1, "Medium": 2, "High": 3}
    # Both lookups fall back to Medium, so an unknown level renders instead
    # of raising KeyError on the colour lookup.
    risk_value = risk_value_map.get(risk_level, 2)
    bar_color = RISK_COLORS.get(risk_level, RISK_COLORS["Medium"])

    fig = go.Figure(go.Indicator(
        mode="gauge+number+delta",
        value=risk_value,
        domain={'x': [0, 1], 'y': [0, 1]},
        gauge={
            'axis': {'range': [0, 3], 'tickvals': [1, 2, 3], 'ticktext': ['Low', 'Medium', 'High']},
            'bar': {'color': bar_color},
            # Translucent background bands matching the Low/Medium/High colours.
            'steps': [
                {'range': [0, 1.5], 'color': "rgba(127, 216, 190, 0.3)"},
                {'range': [1.5, 2.5], 'color': "rgba(255, 200, 87, 0.3)"},
                {'range': [2.5, 3], 'color': "rgba(232, 72, 85, 0.3)"}
            ],
            'threshold': {
                'line': {'color': "white", 'width': 2},
                'thickness': 0.85,
                'value': risk_value
            }
        },
        number={'valueformat': '.0f', 'font': {'size': 36}},
        title={
            'text': f"Risk Level: {risk_level}",
            'font': {'size': 24}
        },
    ))

    fig.update_layout(
        height=250,
        margin=dict(l=10, r=10, t=50, b=10),
        paper_bgcolor='white',
        font={'color': "#2C363F", 'family': "Arial"}
    )

    return fig

def create_risk_probability_chart(probabilities: Dict[str, float]) -> go.Figure:
    """Build a horizontal bar chart of the probability assigned to each risk level.

    Args:
        probabilities: Mapping of risk-level label to probability. Each label
            is looked up in ``RISK_COLORS``; a label missing there raises
            KeyError.

    Returns:
        A Plotly figure with one bar per label, annotated with its percentage.
    """
    entries = list(probabilities.items())
    bar_labels = [label for label, _ in entries]
    bar_values = [value for _, value in entries]

    bar = go.Bar(
        x=bar_values,
        y=bar_labels,
        orientation='h',
        marker_color=[RISK_COLORS[label] for label in bar_labels],
        text=[format(value, ".1%") for value in bar_values],
        textposition='auto',
    )
    figure = go.Figure(bar)

    figure.update_layout(
        title="Risk Probability Distribution",
        xaxis_title="Probability",
        yaxis_title="Risk Level",
        height=250,
        margin={"l": 10, "r": 10, "t": 50, "b": 10},
        xaxis={"range": [0, 1], "tickformat": ".0%"},
        paper_bgcolor='white',
        plot_bgcolor='white',
        font={'color': "#2C363F", 'family': "Arial"},
    )
    return figure

def save_consultation(consultation_data: Dict[str, Any]):
    """Save consultation data to a new JSON file under ``consultation_history/``.

    A ``"timestamp"`` key (``datetime.now().isoformat()``, i.e. naive local
    time) is added to ``consultation_data`` in place before it is written.

    Args:
        consultation_data: JSON-serializable mapping describing one consultation.

    Returns:
        The relative path of the file written, normally
        ``consultation_history/consultation_YYYYmmdd_HHMMSS.json``. If that
        name is already taken (e.g. two saves within the same second), a
        ``_1``, ``_2``, ... suffix is appended instead of overwriting the
        earlier file.

    Raises:
        TypeError: if a value in ``consultation_data`` is not JSON-serializable.
            No consultation file is written in that case.
    """
    # Create history directory if it doesn't exist
    os.makedirs("consultation_history", exist_ok=True)

    # One clock reading so the filename and the stored timestamp agree.
    now = datetime.now()
    consultation_data["timestamp"] = now.isoformat()

    # Serialize before touching the filesystem so a failure cannot leave a
    # truncated .json file behind for load_consultation_history to trip over.
    payload = json.dumps(consultation_data, indent=2)

    stem = f"consultation_history/consultation_{now.strftime('%Y%m%d_%H%M%S')}"
    attempt = 0
    while True:
        suffix = f"_{attempt}" if attempt else ""
        filename = f"{stem}{suffix}.json"
        try:
            # Exclusive-create mode ("x") refuses to overwrite an existing file.
            with open(filename, "x", encoding="utf-8") as f:
                f.write(payload)
        except FileExistsError:
            attempt += 1
            continue
        return filename

def load_consultation_history() -> List[Dict[str, Any]]:
    """Load all saved consultations from the history directory.

    Every ``*.json`` file in ``consultation_history/`` is parsed. Files that
    cannot be read or decoded, or whose top-level value is not a JSON object,
    are reported with ``st.error`` and skipped instead of aborting the load.

    Returns:
        The consultation dicts, newest first by their ``"timestamp"`` value
        (compared as strings); entries with a missing or null timestamp sort
        last. An empty list when the directory does not exist.
    """
    history_dir = "consultation_history"
    if not os.path.isdir(history_dir):
        return []

    history = []
    # sorted() makes the relative order of tied timestamps deterministic.
    for filename in sorted(os.listdir(history_dir)):
        if not filename.endswith(".json"):
            continue
        try:
            with open(os.path.join(history_dir, filename), "r", encoding="utf-8") as f:
                consultation = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            st.error(f"Error loading {filename}: {str(e)}")
            continue
        if not isinstance(consultation, dict):
            # A non-object would break the .get() used for sorting below.
            st.error(f"Error loading {filename}: expected a JSON object")
            continue
        history.append(consultation)

    # Sort by timestamp (newest first); missing/null becomes "" and sorts last.
    history.sort(key=lambda x: str(x.get("timestamp") or ""), reverse=True)
    return history

def init_session_state():
    """Seed Streamlit's session state with default values.

    Only keys that are not already present are assigned, so values set on
    earlier reruns are preserved.
    """
    defaults = {
        "consultation_history": [],
        "current_result": None,
        "is_processing": False,
        "loaded_models": False,
    }
    for key, default in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default