Spaces:
Running
Running
update mmlu chart
Browse files- data_handler.py +36 -9
data_handler.py
CHANGED
@@ -82,13 +82,25 @@ def unified_exam_chart(unified_exam_df, plot_column):
|
|
82 |
|
83 |
def mmlu_chart(mmlu_df, plot_column):
|
84 |
df = mmlu_df.copy()
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
86 |
df['Average'] = df[subject_cols].mean(axis=1)
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
x=x_col,
|
93 |
y='Model',
|
94 |
color=x_col,
|
@@ -96,15 +108,30 @@ def mmlu_chart(mmlu_df, plot_column):
|
|
96 |
labels={x_col: 'Accuracy', 'Model': 'Model'},
|
97 |
title=title,
|
98 |
orientation='h',
|
99 |
-
range_color=[0,1]
|
100 |
)
|
101 |
|
102 |
fig.update_layout(
|
|
|
|
|
|
|
|
|
103 |
xaxis=dict(range=[0, x_range_max]),
|
104 |
title=dict(text=title, font=dict(size=16)),
|
105 |
xaxis_title=dict(font=dict(size=12)),
|
106 |
yaxis_title=dict(font=dict(size=12)),
|
107 |
-
|
108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
)
|
|
|
|
|
|
|
110 |
return fig
|
|
|
|
82 |
|
83 |
def mmlu_chart(mmlu_df, plot_column):
|
84 |
df = mmlu_df.copy()
|
85 |
+
|
86 |
+
subject_cols = [
|
87 |
+
'Biology', 'Business', 'Chemistry', 'Computer Science', 'Economics',
|
88 |
+
'Engineering', 'Health', 'History', 'Law', 'Math', 'Other',
|
89 |
+
'Philosophy', 'Physics', 'Psychology'
|
90 |
+
]
|
91 |
df['Average'] = df[subject_cols].mean(axis=1)
|
92 |
+
|
93 |
+
df = df.sort_values(by=[plot_column, 'Model'],
|
94 |
+
ascending=[False, True]
|
95 |
+
).reset_index(drop=True)
|
96 |
+
|
97 |
+
x_col = plot_column
|
98 |
+
title = f'{plot_column}'
|
99 |
+
x_range_max = 1.0
|
100 |
+
bar_height_px = 28
|
101 |
+
|
102 |
+
fig = px.bar(
|
103 |
+
df,
|
104 |
x=x_col,
|
105 |
y='Model',
|
106 |
color=x_col,
|
|
|
108 |
labels={x_col: 'Accuracy', 'Model': 'Model'},
|
109 |
title=title,
|
110 |
orientation='h',
|
111 |
+
range_color=[0, 1]
|
112 |
)
|
113 |
|
114 |
fig.update_layout(
|
115 |
+
height=bar_height_px * len(df) + 120,
|
116 |
+
margin=dict(l=220, r=40, t=60, b=40),
|
117 |
+
width=1000,
|
118 |
+
|
119 |
xaxis=dict(range=[0, x_range_max]),
|
120 |
title=dict(text=title, font=dict(size=16)),
|
121 |
xaxis_title=dict(font=dict(size=12)),
|
122 |
yaxis_title=dict(font=dict(size=12)),
|
123 |
+
|
124 |
+
yaxis=dict(
|
125 |
+
automargin=True,
|
126 |
+
tickmode='array',
|
127 |
+
tickvals=df['Model'],
|
128 |
+
ticktext=df['Model'],
|
129 |
+
dtick=1,
|
130 |
+
autorange='reversed'
|
131 |
+
)
|
132 |
)
|
133 |
+
|
134 |
+
fig.update_yaxes(tickfont=dict(size=10))
|
135 |
+
|
136 |
return fig
|
137 |
+
|