Bagratuni commited on
Commit
33a913f
·
1 Parent(s): 001dc3f

update mmlu chart

Browse files
Files changed (1) hide show
  1. data_handler.py +36 -9
data_handler.py CHANGED
@@ -82,13 +82,25 @@ def unified_exam_chart(unified_exam_df, plot_column):
82
 
83
  def mmlu_chart(mmlu_df, plot_column):
84
  df = mmlu_df.copy()
85
- subject_cols = ['Biology', 'Business', 'Chemistry', 'Computer Science', 'Economics', 'Engineering', 'Health', 'History', 'Law', 'Math', 'Other', 'Philosophy', 'Physics', 'Psychology']
 
 
 
 
 
86
  df['Average'] = df[subject_cols].mean(axis=1)
87
- df = df.sort_values(by=plot_column, ascending=False).reset_index(drop=True)
88
- x_col = plot_column
89
- title = f'{plot_column}'
90
- x_range_max = 1.0
91
- fig = px.bar(df,
 
 
 
 
 
 
 
92
  x=x_col,
93
  y='Model',
94
  color=x_col,
@@ -96,15 +108,30 @@ def mmlu_chart(mmlu_df, plot_column):
96
  labels={x_col: 'Accuracy', 'Model': 'Model'},
97
  title=title,
98
  orientation='h',
99
- range_color=[0,1]
100
  )
101
 
102
  fig.update_layout(
 
 
 
 
103
  xaxis=dict(range=[0, x_range_max]),
104
  title=dict(text=title, font=dict(size=16)),
105
  xaxis_title=dict(font=dict(size=12)),
106
  yaxis_title=dict(font=dict(size=12)),
107
- yaxis=dict(autorange="reversed"),
108
- width=1000
 
 
 
 
 
 
 
109
  )
 
 
 
110
  return fig
 
 
82
 
83
  def mmlu_chart(mmlu_df, plot_column):
84
  df = mmlu_df.copy()
85
+
86
+ subject_cols = [
87
+ 'Biology', 'Business', 'Chemistry', 'Computer Science', 'Economics',
88
+ 'Engineering', 'Health', 'History', 'Law', 'Math', 'Other',
89
+ 'Philosophy', 'Physics', 'Psychology'
90
+ ]
91
  df['Average'] = df[subject_cols].mean(axis=1)
92
+
93
+ df = df.sort_values(by=[plot_column, 'Model'],
94
+ ascending=[False, True]
95
+ ).reset_index(drop=True)
96
+
97
+ x_col = plot_column
98
+ title = f'{plot_column}'
99
+ x_range_max = 1.0
100
+ bar_height_px = 28
101
+
102
+ fig = px.bar(
103
+ df,
104
  x=x_col,
105
  y='Model',
106
  color=x_col,
 
108
  labels={x_col: 'Accuracy', 'Model': 'Model'},
109
  title=title,
110
  orientation='h',
111
+ range_color=[0, 1]
112
  )
113
 
114
  fig.update_layout(
115
+ height=bar_height_px * len(df) + 120,
116
+ margin=dict(l=220, r=40, t=60, b=40),
117
+ width=1000,
118
+
119
  xaxis=dict(range=[0, x_range_max]),
120
  title=dict(text=title, font=dict(size=16)),
121
  xaxis_title=dict(font=dict(size=12)),
122
  yaxis_title=dict(font=dict(size=12)),
123
+
124
+ yaxis=dict(
125
+ automargin=True,
126
+ tickmode='array',
127
+ tickvals=df['Model'],
128
+ ticktext=df['Model'],
129
+ dtick=1,
130
+ autorange='reversed'
131
+ )
132
  )
133
+
134
+ fig.update_yaxes(tickfont=dict(size=10))
135
+
136
  return fig
137
+