bumchik2 committed on
Commit
b1eb861
·
1 Parent(s): 97c6f33
Files changed (2) hide show
  1. app.py +1 -1
  2. notebooks/distilroberta_base_main.ipynb +49 -17
app.py CHANGED
@@ -59,7 +59,7 @@ def get_category_probs_dict(model, title: str, summary: str) -> Dict[str, float]
59
  current_index += 1
60
  index_to_category = {value: key for key, value in category_to_index.items()}
61
 
62
- text = f'{title} $ {summary}'
63
  category_logits = model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function(text).items()}).logits
64
  sigmoid = torch.nn.Sigmoid()
65
  category_probs = sigmoid(category_logits.squeeze().cpu()).numpy()
 
59
  current_index += 1
60
  index_to_category = {value: key for key, value in category_to_index.items()}
61
 
62
+ text = f'{title} $ {summary or ""}'
63
  category_logits = model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function(text).items()}).logits
64
  sigmoid = torch.nn.Sigmoid()
65
  category_probs = sigmoid(category_logits.squeeze().cpu()).numpy()
notebooks/distilroberta_base_main.ipynb CHANGED
@@ -59,7 +59,7 @@
59
  },
60
  {
61
  "cell_type": "code",
62
- "execution_count": 2,
63
  "metadata": {},
64
  "outputs": [],
65
  "source": [
@@ -68,7 +68,7 @@
68
  },
69
  {
70
  "cell_type": "code",
71
- "execution_count": 3,
72
  "metadata": {},
73
  "outputs": [],
74
  "source": [
@@ -91,7 +91,7 @@
91
  },
92
  {
93
  "cell_type": "code",
94
- "execution_count": 4,
95
  "metadata": {},
96
  "outputs": [],
97
  "source": [
@@ -100,7 +100,7 @@
100
  },
101
  {
102
  "cell_type": "code",
103
- "execution_count": 5,
104
  "metadata": {},
105
  "outputs": [
106
  {
@@ -117,7 +117,7 @@
117
  " 'year': 2018}"
118
  ]
119
  },
120
- "execution_count": 5,
121
  "metadata": {},
122
  "output_type": "execute_result"
123
  }
@@ -135,7 +135,7 @@
135
  },
136
  {
137
  "cell_type": "code",
138
- "execution_count": 6,
139
  "metadata": {},
140
  "outputs": [
141
  {
@@ -215,7 +215,7 @@
215
  "4 cs.CG Computational Geometry Computer Science"
216
  ]
217
  },
218
- "execution_count": 6,
219
  "metadata": {},
220
  "output_type": "execute_result"
221
  }
@@ -228,7 +228,7 @@
228
  },
229
  {
230
  "cell_type": "code",
231
- "execution_count": 7,
232
  "metadata": {},
233
  "outputs": [],
234
  "source": [
@@ -523,7 +523,7 @@
523
  },
524
  {
525
  "cell_type": "code",
526
- "execution_count": 8,
527
  "metadata": {},
528
  "outputs": [],
529
  "source": [
@@ -988,13 +988,13 @@
988
  },
989
  {
990
  "cell_type": "code",
991
- "execution_count": 11,
992
  "metadata": {},
993
  "outputs": [],
994
  "source": [
995
  "@torch.no_grad\n",
996
  "def get_category_probs_dict(model, title: str, summary: str) -> Dict[str, float]:\n",
997
- " text = f'{title} $ {summary}'\n",
998
  " category_logits = model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function(text).items()}).logits\n",
999
  " sigmoid = torch.nn.Sigmoid()\n",
1000
  " category_probs = sigmoid(category_logits.squeeze().cpu()).numpy()\n",
@@ -1007,7 +1007,7 @@
1007
  },
1008
  {
1009
  "cell_type": "code",
1010
- "execution_count": 12,
1011
  "metadata": {},
1012
  "outputs": [],
1013
  "source": [
@@ -1071,7 +1071,7 @@
1071
  },
1072
  {
1073
  "cell_type": "code",
1074
- "execution_count": 14,
1075
  "metadata": {},
1076
  "outputs": [],
1077
  "source": [
@@ -1086,7 +1086,7 @@
1086
  },
1087
  {
1088
  "cell_type": "code",
1089
- "execution_count": 15,
1090
  "metadata": {},
1091
  "outputs": [
1092
  {
@@ -1098,7 +1098,7 @@
1098
  " 'Physics (0.07676041126251221)']"
1099
  ]
1100
  },
1101
- "execution_count": 15,
1102
  "metadata": {},
1103
  "output_type": "execute_result"
1104
  }
@@ -1118,7 +1118,7 @@
1118
  },
1119
  {
1120
  "cell_type": "code",
1121
- "execution_count": 16,
1122
  "metadata": {},
1123
  "outputs": [
1124
  {
@@ -1130,7 +1130,7 @@
1130
  " 'Statistics (0.02984526939690113)']"
1131
  ]
1132
  },
1133
- "execution_count": 16,
1134
  "metadata": {},
1135
  "output_type": "execute_result"
1136
  }
@@ -1148,6 +1148,38 @@
1148
  ")"
1149
  ]
1150
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1151
  {
1152
  "cell_type": "code",
1153
  "execution_count": null,
 
59
  },
60
  {
61
  "cell_type": "code",
62
+ "execution_count": 3,
63
  "metadata": {},
64
  "outputs": [],
65
  "source": [
 
68
  },
69
  {
70
  "cell_type": "code",
71
+ "execution_count": 12,
72
  "metadata": {},
73
  "outputs": [],
74
  "source": [
 
91
  },
92
  {
93
  "cell_type": "code",
94
+ "execution_count": 14,
95
  "metadata": {},
96
  "outputs": [],
97
  "source": [
 
100
  },
101
  {
102
  "cell_type": "code",
103
+ "execution_count": 15,
104
  "metadata": {},
105
  "outputs": [
106
  {
 
117
  " 'year': 2018}"
118
  ]
119
  },
120
+ "execution_count": 15,
121
  "metadata": {},
122
  "output_type": "execute_result"
123
  }
 
135
  },
136
  {
137
  "cell_type": "code",
138
+ "execution_count": 16,
139
  "metadata": {},
140
  "outputs": [
141
  {
 
215
  "4 cs.CG Computational Geometry Computer Science"
216
  ]
217
  },
218
+ "execution_count": 16,
219
  "metadata": {},
220
  "output_type": "execute_result"
221
  }
 
228
  },
229
  {
230
  "cell_type": "code",
231
+ "execution_count": 17,
232
  "metadata": {},
233
  "outputs": [],
234
  "source": [
 
523
  },
524
  {
525
  "cell_type": "code",
526
+ "execution_count": 18,
527
  "metadata": {},
528
  "outputs": [],
529
  "source": [
 
988
  },
989
  {
990
  "cell_type": "code",
991
+ "execution_count": 22,
992
  "metadata": {},
993
  "outputs": [],
994
  "source": [
995
  "@torch.no_grad\n",
996
  "def get_category_probs_dict(model, title: str, summary: str) -> Dict[str, float]:\n",
997
+ " text = f'{title} $ {summary or \"\"}'\n",
998
  " category_logits = model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function(text).items()}).logits\n",
999
  " sigmoid = torch.nn.Sigmoid()\n",
1000
  " category_probs = sigmoid(category_logits.squeeze().cpu()).numpy()\n",
 
1007
  },
1008
  {
1009
  "cell_type": "code",
1010
+ "execution_count": 21,
1011
  "metadata": {},
1012
  "outputs": [],
1013
  "source": [
 
1071
  },
1072
  {
1073
  "cell_type": "code",
1074
+ "execution_count": 19,
1075
  "metadata": {},
1076
  "outputs": [],
1077
  "source": [
 
1086
  },
1087
  {
1088
  "cell_type": "code",
1089
+ "execution_count": 23,
1090
  "metadata": {},
1091
  "outputs": [
1092
  {
 
1098
  " 'Physics (0.07676041126251221)']"
1099
  ]
1100
  },
1101
+ "execution_count": 23,
1102
  "metadata": {},
1103
  "output_type": "execute_result"
1104
  }
 
1118
  },
1119
  {
1120
  "cell_type": "code",
1121
+ "execution_count": 24,
1122
  "metadata": {},
1123
  "outputs": [
1124
  {
 
1130
  " 'Statistics (0.02984526939690113)']"
1131
  ]
1132
  },
1133
+ "execution_count": 24,
1134
  "metadata": {},
1135
  "output_type": "execute_result"
1136
  }
 
1148
  ")"
1149
  ]
1150
  },
1151
+ {
1152
+ "cell_type": "code",
1153
+ "execution_count": null,
1154
+ "metadata": {},
1155
+ "outputs": [
1156
+ {
1157
+ "data": {
1158
+ "text/plain": [
1159
+ "['Quantitative Biology (0.45450547337532043)',\n",
1160
+ " 'Computer Science (0.3519783318042755)',\n",
1161
+ " 'Physics (0.07536326348781586)',\n",
1162
+ " 'Statistics (0.06953499466180801)']"
1163
+ ]
1164
+ },
1165
+ "execution_count": 26,
1166
+ "metadata": {},
1167
+ "output_type": "execute_result"
1168
+ }
1169
+ ],
1170
+ "source": [
1171
+ "# правильный ответ Quantitative Biology\n",
1172
+ "get_most_probable_keys(\n",
1173
+ " probs_dict=get_category_probs_dict(\n",
1174
+ " model=model,\n",
1175
+ " title='Simulating cell populations with explicit cell cycle length -- implications to cell cycle dependent tumour therapy',\n",
1176
+ " summary=''\n",
1177
+ " ),\n",
1178
+ " target_probability=0.95,\n",
1179
+ " print_probabilities=True\n",
1180
+ ")"
1181
+ ]
1182
+ },
1183
  {
1184
  "cell_type": "code",
1185
  "execution_count": null,