# Spaces:
# Runtime error
# Runtime error
import time
from datetime import datetime, timedelta

import plotly.graph_objects as go
# Samples per second. Chunk boundaries are divided by this value to turn
# sample offsets into seconds before plotting.
SAMPLING_RATE = 16_000

# Marker colour used for each emotion label in the plot legend.
COLOR_MAP = {
    "Neutralità": "rgb(178, 178, 178)",
    "Rabbia": "rgb(160, 61, 62)",
    "Paura": "rgb(91, 57, 136)",
    "Gioia": "rgb(255, 255, 0)",
    "Sorpresa": "rgb(60, 175, 175)",
    "Tristezza": "rgb(64, 106, 173)",
    "Disgusto": "rgb(100, 153, 65)",
}
# Arbitrary anchor for the time axis. The x-axis is of type 'date' but only
# shows the clock part (tickformat '%H:%M:%S'), so chunk offsets in seconds
# are plotted as datetimes counted from this midnight.
_PLOT_EPOCH = datetime(2000, 1, 1)


def _format_hms(seconds):
    """Format a non-negative duration in seconds as HH:MM:SS.

    Fractional seconds are truncated. The hour field keeps counting past 23
    instead of wrapping back to 00 after a full day, which is what formatting
    through ``time.gmtime`` would do.
    """
    minutes, secs = divmod(int(seconds), 60)
    hours, minutes = divmod(minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"


def create_behaviour_gantt_plot(behaviour_chunks, confidence_threshold=60):
    """Build a Plotly figure of per-chunk confidence over time.

    The figure contains:

    * a shaded band from ``confidence_threshold`` up to 100 spanning the
      whole time range covered by the chunks, plus a dashed line at the
      threshold;
    * a smoothed orange line through every chunk's midpoint;
    * one marker trace per emotion in a fixed legend order, holding only the
      chunks whose height (in percent) is at or above the threshold. An
      emotion with no such chunk still gets an empty trace so that its
      legend entry is always present.

    Args:
        behaviour_chunks: Iterable of 5-tuples
            ``(start, end, label, height, emotion)``. ``start`` and ``end``
            are divided by ``SAMPLING_RATE`` to obtain seconds. ``height`` is
            multiplied by 100 and shown as a percentage (presumably a 0-1
            score -- confirm against the caller). ``label`` and ``emotion``
            are interpolated into the hover text. An empty iterable yields a
            figure with no data points and no shaded band.
        confidence_threshold: Percentage (on the 0-100 y-axis) at which the
            dashed line is drawn and above which chunks receive an emotion
            marker.

    Returns:
        The populated ``plotly.graph_objects.Figure``.
    """
    print("Creating behaviour Gantt plot...")

    # Fixed legend order; only these emotions get marker traces.
    emotion_order = [
        "Gioia",
        "Sorpresa",
        "Disgusto",
        "Tristezza",
        "Paura",
        "Rabbia",
        "Neutralità",
    ]

    fig = go.Figure()

    # Single pass over the input so one-shot iterables work too.
    start_times = []
    end_times = []
    mid_times = []
    heights = []
    emotions = []
    hover_texts = []
    for start, end, label, height, emotion in behaviour_chunks:
        start_s = start / SAMPLING_RATE
        end_s = end / SAMPLING_RATE
        duration_s = (end - start) / SAMPLING_RATE

        start_times.append(_PLOT_EPOCH + timedelta(seconds=start_s))
        end_times.append(_PLOT_EPOCH + timedelta(seconds=end_s))
        # Midpoint of the chunk: x position of its point on the trend line.
        mid_times.append(_PLOT_EPOCH + timedelta(seconds=(start_s + end_s) / 2))
        heights.append(height * 100)
        emotions.append(emotion)
        hover_texts.append(
            f"Inizio: {_format_hms(start_s)}<br>"
            f"Fine: {_format_hms(end_s)}<br>"
            f"Durata: {_format_hms(duration_s)}<br>"
            f"Testo: {label}<br>"
            f"Attendibilità: {height*100:.2f}%<br>"
            f"Emozione: {emotion}"
        )

    # Highlight the region above the threshold. Skipped when there are no
    # chunks, because there is no time range to span. min/max rather than
    # first/last so the band covers everything even if chunks are unsorted.
    if start_times:
        fig.add_shape(
            type="rect",
            x0=min(start_times),
            x1=max(end_times),
            y0=confidence_threshold,
            y1=100,
            fillcolor="rgba(188,223,241,0.8)",
            opacity=0.8,
            layer="below",
            line_width=0,
        )
    fig.add_hline(
        y=confidence_threshold,
        line_dash="dash",
        line_color="black",
        line_width=1,
    )

    # Smoothed trend line through all chunk midpoints.
    fig.add_trace(
        go.Scatter(
            x=mid_times,
            y=heights,
            mode='lines',
            name='Disregolazione',
            line=dict(
                color='orange',
                width=2,
                shape='spline',
                smoothing=1.0,
            ),
            text=hover_texts,
            hoverinfo='text',
            showlegend=False,
        )
    )

    # Group the chunks at or above the threshold by emotion.
    emotion_data = {}
    for mid_time, height, emotion, hover_text in zip(
        mid_times, heights, emotions, hover_texts
    ):
        if height >= confidence_threshold:
            bucket = emotion_data.setdefault(
                emotion, {'times': [], 'heights': [], 'hover_texts': []}
            )
            bucket['times'].append(mid_time)
            bucket['heights'].append(height)
            bucket['hover_texts'].append(hover_text)

    # NOTE(review): emotions that are not listed in emotion_order are grouped
    # above but never plotted as markers -- confirm this is intended.
    for emotion in emotion_order:
        marker = dict(
            size=15,
            color=COLOR_MAP.get(emotion, '#000000'),
            symbol='circle',
        )
        data = emotion_data.get(emotion)
        if data is None:
            # Empty placeholder trace: keeps the legend entry visible even
            # though no chunk of this emotion reached the threshold.
            fig.add_trace(
                go.Scatter(
                    x=[None],
                    y=[None],
                    mode='markers',
                    name=emotion.capitalize(),
                    marker=marker,
                    showlegend=True,
                )
            )
        else:
            fig.add_trace(
                go.Scatter(
                    x=data['times'],
                    y=data['heights'],
                    mode='markers',
                    name=emotion.capitalize(),
                    marker=marker,
                    text=data['hover_texts'],
                    hoverinfo='text',
                    showlegend=True,
                )
            )

    fig.update_layout(
        title='Distribuzione della disregolazione',
        xaxis_title='Tempo',
        yaxis_title='Attendibilità',
        xaxis=dict(
            type='date',
            tickformat='%H:%M:%S',
            showline=True,
            zeroline=False,
            side='bottom',
            showgrid=False,
        ),
        yaxis=dict(
            range=[0, 100],
            tickvals=[0, 20, 40, 60, 80, 100],
            ticktext=['0%', '20%', '40%', '60%', '80%', '100%'],
            tickmode='array',
            showgrid=False,
        ),
        legend_title=None,
        legend=dict(
            yanchor="top"
        ),
        hoverlabel=dict(
            font_size=12,
            font_family="Arial"
        ),
        paper_bgcolor='white',
        plot_bgcolor='white',
    )
    fig.update_traces(hovertemplate=None)
    return fig