Spaces:
Running
Running
final
Browse files- app.py +1 -1
- notebooks/distilroberta_base_main.ipynb +49 -17
app.py
CHANGED
@@ -59,7 +59,7 @@ def get_category_probs_dict(model, title: str, summary: str) -> Dict[str, float]
|
|
59 |
current_index += 1
|
60 |
index_to_category = {value: key for key, value in category_to_index.items()}
|
61 |
|
62 |
-
text = f'{title} $ {summary}'
|
63 |
category_logits = model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function(text).items()}).logits
|
64 |
sigmoid = torch.nn.Sigmoid()
|
65 |
category_probs = sigmoid(category_logits.squeeze().cpu()).numpy()
|
|
|
59 |
current_index += 1
|
60 |
index_to_category = {value: key for key, value in category_to_index.items()}
|
61 |
|
62 |
+
text = f'{title} $ {summary or ""}'
|
63 |
category_logits = model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function(text).items()}).logits
|
64 |
sigmoid = torch.nn.Sigmoid()
|
65 |
category_probs = sigmoid(category_logits.squeeze().cpu()).numpy()
|
notebooks/distilroberta_base_main.ipynb
CHANGED
@@ -59,7 +59,7 @@
|
|
59 |
},
|
60 |
{
|
61 |
"cell_type": "code",
|
62 |
-
"execution_count":
|
63 |
"metadata": {},
|
64 |
"outputs": [],
|
65 |
"source": [
|
@@ -68,7 +68,7 @@
|
|
68 |
},
|
69 |
{
|
70 |
"cell_type": "code",
|
71 |
-
"execution_count":
|
72 |
"metadata": {},
|
73 |
"outputs": [],
|
74 |
"source": [
|
@@ -91,7 +91,7 @@
|
|
91 |
},
|
92 |
{
|
93 |
"cell_type": "code",
|
94 |
-
"execution_count":
|
95 |
"metadata": {},
|
96 |
"outputs": [],
|
97 |
"source": [
|
@@ -100,7 +100,7 @@
|
|
100 |
},
|
101 |
{
|
102 |
"cell_type": "code",
|
103 |
-
"execution_count":
|
104 |
"metadata": {},
|
105 |
"outputs": [
|
106 |
{
|
@@ -117,7 +117,7 @@
|
|
117 |
" 'year': 2018}"
|
118 |
]
|
119 |
},
|
120 |
-
"execution_count":
|
121 |
"metadata": {},
|
122 |
"output_type": "execute_result"
|
123 |
}
|
@@ -135,7 +135,7 @@
|
|
135 |
},
|
136 |
{
|
137 |
"cell_type": "code",
|
138 |
-
"execution_count":
|
139 |
"metadata": {},
|
140 |
"outputs": [
|
141 |
{
|
@@ -215,7 +215,7 @@
|
|
215 |
"4 cs.CG Computational Geometry Computer Science"
|
216 |
]
|
217 |
},
|
218 |
-
"execution_count":
|
219 |
"metadata": {},
|
220 |
"output_type": "execute_result"
|
221 |
}
|
@@ -228,7 +228,7 @@
|
|
228 |
},
|
229 |
{
|
230 |
"cell_type": "code",
|
231 |
-
"execution_count":
|
232 |
"metadata": {},
|
233 |
"outputs": [],
|
234 |
"source": [
|
@@ -523,7 +523,7 @@
|
|
523 |
},
|
524 |
{
|
525 |
"cell_type": "code",
|
526 |
-
"execution_count":
|
527 |
"metadata": {},
|
528 |
"outputs": [],
|
529 |
"source": [
|
@@ -988,13 +988,13 @@
|
|
988 |
},
|
989 |
{
|
990 |
"cell_type": "code",
|
991 |
-
"execution_count":
|
992 |
"metadata": {},
|
993 |
"outputs": [],
|
994 |
"source": [
|
995 |
"@torch.no_grad\n",
|
996 |
"def get_category_probs_dict(model, title: str, summary: str) -> Dict[str, float]:\n",
|
997 |
-
" text = f'{title} $ {summary}'\n",
|
998 |
" category_logits = model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function(text).items()}).logits\n",
|
999 |
" sigmoid = torch.nn.Sigmoid()\n",
|
1000 |
" category_probs = sigmoid(category_logits.squeeze().cpu()).numpy()\n",
|
@@ -1007,7 +1007,7 @@
|
|
1007 |
},
|
1008 |
{
|
1009 |
"cell_type": "code",
|
1010 |
-
"execution_count":
|
1011 |
"metadata": {},
|
1012 |
"outputs": [],
|
1013 |
"source": [
|
@@ -1071,7 +1071,7 @@
|
|
1071 |
},
|
1072 |
{
|
1073 |
"cell_type": "code",
|
1074 |
-
"execution_count":
|
1075 |
"metadata": {},
|
1076 |
"outputs": [],
|
1077 |
"source": [
|
@@ -1086,7 +1086,7 @@
|
|
1086 |
},
|
1087 |
{
|
1088 |
"cell_type": "code",
|
1089 |
-
"execution_count":
|
1090 |
"metadata": {},
|
1091 |
"outputs": [
|
1092 |
{
|
@@ -1098,7 +1098,7 @@
|
|
1098 |
" 'Physics (0.07676041126251221)']"
|
1099 |
]
|
1100 |
},
|
1101 |
-
"execution_count":
|
1102 |
"metadata": {},
|
1103 |
"output_type": "execute_result"
|
1104 |
}
|
@@ -1118,7 +1118,7 @@
|
|
1118 |
},
|
1119 |
{
|
1120 |
"cell_type": "code",
|
1121 |
-
"execution_count":
|
1122 |
"metadata": {},
|
1123 |
"outputs": [
|
1124 |
{
|
@@ -1130,7 +1130,7 @@
|
|
1130 |
" 'Statistics (0.02984526939690113)']"
|
1131 |
]
|
1132 |
},
|
1133 |
-
"execution_count":
|
1134 |
"metadata": {},
|
1135 |
"output_type": "execute_result"
|
1136 |
}
|
@@ -1148,6 +1148,38 @@
|
|
1148 |
")"
|
1149 |
]
|
1150 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1151 |
{
|
1152 |
"cell_type": "code",
|
1153 |
"execution_count": null,
|
|
|
59 |
},
|
60 |
{
|
61 |
"cell_type": "code",
|
62 |
+
"execution_count": 3,
|
63 |
"metadata": {},
|
64 |
"outputs": [],
|
65 |
"source": [
|
|
|
68 |
},
|
69 |
{
|
70 |
"cell_type": "code",
|
71 |
+
"execution_count": 12,
|
72 |
"metadata": {},
|
73 |
"outputs": [],
|
74 |
"source": [
|
|
|
91 |
},
|
92 |
{
|
93 |
"cell_type": "code",
|
94 |
+
"execution_count": 14,
|
95 |
"metadata": {},
|
96 |
"outputs": [],
|
97 |
"source": [
|
|
|
100 |
},
|
101 |
{
|
102 |
"cell_type": "code",
|
103 |
+
"execution_count": 15,
|
104 |
"metadata": {},
|
105 |
"outputs": [
|
106 |
{
|
|
|
117 |
" 'year': 2018}"
|
118 |
]
|
119 |
},
|
120 |
+
"execution_count": 15,
|
121 |
"metadata": {},
|
122 |
"output_type": "execute_result"
|
123 |
}
|
|
|
135 |
},
|
136 |
{
|
137 |
"cell_type": "code",
|
138 |
+
"execution_count": 16,
|
139 |
"metadata": {},
|
140 |
"outputs": [
|
141 |
{
|
|
|
215 |
"4 cs.CG Computational Geometry Computer Science"
|
216 |
]
|
217 |
},
|
218 |
+
"execution_count": 16,
|
219 |
"metadata": {},
|
220 |
"output_type": "execute_result"
|
221 |
}
|
|
|
228 |
},
|
229 |
{
|
230 |
"cell_type": "code",
|
231 |
+
"execution_count": 17,
|
232 |
"metadata": {},
|
233 |
"outputs": [],
|
234 |
"source": [
|
|
|
523 |
},
|
524 |
{
|
525 |
"cell_type": "code",
|
526 |
+
"execution_count": 18,
|
527 |
"metadata": {},
|
528 |
"outputs": [],
|
529 |
"source": [
|
|
|
988 |
},
|
989 |
{
|
990 |
"cell_type": "code",
|
991 |
+
"execution_count": 22,
|
992 |
"metadata": {},
|
993 |
"outputs": [],
|
994 |
"source": [
|
995 |
"@torch.no_grad\n",
|
996 |
"def get_category_probs_dict(model, title: str, summary: str) -> Dict[str, float]:\n",
|
997 |
+
" text = f'{title} $ {summary or \"\"}'\n",
|
998 |
" category_logits = model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function(text).items()}).logits\n",
|
999 |
" sigmoid = torch.nn.Sigmoid()\n",
|
1000 |
" category_probs = sigmoid(category_logits.squeeze().cpu()).numpy()\n",
|
|
|
1007 |
},
|
1008 |
{
|
1009 |
"cell_type": "code",
|
1010 |
+
"execution_count": 21,
|
1011 |
"metadata": {},
|
1012 |
"outputs": [],
|
1013 |
"source": [
|
|
|
1071 |
},
|
1072 |
{
|
1073 |
"cell_type": "code",
|
1074 |
+
"execution_count": 19,
|
1075 |
"metadata": {},
|
1076 |
"outputs": [],
|
1077 |
"source": [
|
|
|
1086 |
},
|
1087 |
{
|
1088 |
"cell_type": "code",
|
1089 |
+
"execution_count": 23,
|
1090 |
"metadata": {},
|
1091 |
"outputs": [
|
1092 |
{
|
|
|
1098 |
" 'Physics (0.07676041126251221)']"
|
1099 |
]
|
1100 |
},
|
1101 |
+
"execution_count": 23,
|
1102 |
"metadata": {},
|
1103 |
"output_type": "execute_result"
|
1104 |
}
|
|
|
1118 |
},
|
1119 |
{
|
1120 |
"cell_type": "code",
|
1121 |
+
"execution_count": 24,
|
1122 |
"metadata": {},
|
1123 |
"outputs": [
|
1124 |
{
|
|
|
1130 |
" 'Statistics (0.02984526939690113)']"
|
1131 |
]
|
1132 |
},
|
1133 |
+
"execution_count": 24,
|
1134 |
"metadata": {},
|
1135 |
"output_type": "execute_result"
|
1136 |
}
|
|
|
1148 |
")"
|
1149 |
]
|
1150 |
},
|
1151 |
+
{
|
1152 |
+
"cell_type": "code",
|
1153 |
+
"execution_count": null,
|
1154 |
+
"metadata": {},
|
1155 |
+
"outputs": [
|
1156 |
+
{
|
1157 |
+
"data": {
|
1158 |
+
"text/plain": [
|
1159 |
+
"['Quantitative Biology (0.45450547337532043)',\n",
|
1160 |
+
" 'Computer Science (0.3519783318042755)',\n",
|
1161 |
+
" 'Physics (0.07536326348781586)',\n",
|
1162 |
+
" 'Statistics (0.06953499466180801)']"
|
1163 |
+
]
|
1164 |
+
},
|
1165 |
+
"execution_count": 26,
|
1166 |
+
"metadata": {},
|
1167 |
+
"output_type": "execute_result"
|
1168 |
+
}
|
1169 |
+
],
|
1170 |
+
"source": [
|
1171 |
+
"# правильный ответ Quantitative Biology\n",
|
1172 |
+
"get_most_probable_keys(\n",
|
1173 |
+
" probs_dict=get_category_probs_dict(\n",
|
1174 |
+
" model=model,\n",
|
1175 |
+
" title='Simulating cell populations with explicit cell cycle length -- implications to cell cycle dependent tumour therapy',\n",
|
1176 |
+
" summary=''\n",
|
1177 |
+
" ),\n",
|
1178 |
+
" target_probability=0.95,\n",
|
1179 |
+
" print_probabilities=True\n",
|
1180 |
+
")"
|
1181 |
+
]
|
1182 |
+
},
|
1183 |
{
|
1184 |
"cell_type": "code",
|
1185 |
"execution_count": null,
|