Spaces:
Runtime error
Runtime error
Commit
·
5bffde4
1
Parent(s):
5da05af
Update app.py
Browse files
app.py
CHANGED
@@ -14,30 +14,6 @@ sys.path.append(str(path_root))
|
|
14 |
st.set_page_config(page_title="PSC Runtime",
|
15 |
page_icon='🌸', layout="centered")
|
16 |
|
17 |
-
# cola, colb, colc = st.columns([5, 4, 5])
|
18 |
-
|
19 |
-
# colaa, colbb, colcc = st.columns([1, 8, 1])
|
20 |
-
# with colbb:
|
21 |
-
# runtime = st.select_slider(
|
22 |
-
# 'Select a runtime type',
|
23 |
-
# options=['PyTorch', 'ONNX Runtime'])
|
24 |
-
# st.write('Now using: ', runtime)
|
25 |
-
|
26 |
-
|
27 |
-
# colaa, colbb, colcc = st.columns([1, 8, 1])
|
28 |
-
# with colbb:
|
29 |
-
# encoder = st.select_slider(
|
30 |
-
# 'Select a query encoder',
|
31 |
-
# options=['uniCOIL', 'SPLADE++ Ensemble Distil', 'SPLADE++ Self Distil'])
|
32 |
-
# st.write('Now Running Encoder: ', encoder)
|
33 |
-
|
34 |
-
# if runtime == 'PyTorch':
|
35 |
-
# runtime = 'pytorch'
|
36 |
-
# runtime_index = 1
|
37 |
-
# else:
|
38 |
-
# runtime = 'onnx'
|
39 |
-
# runtime_index = 0
|
40 |
-
|
41 |
|
42 |
col1, col2 = st.columns([9, 1])
|
43 |
with col1:
|
@@ -51,9 +27,8 @@ with col2:
|
|
51 |
import torch
|
52 |
fn = "dl19-gpt-3.5.pt"
|
53 |
object = torch.load(fn)
|
54 |
-
|
55 |
-
|
56 |
-
# outputs = [x[2] for x in object]
|
57 |
outputs = object[2]
|
58 |
query2outputs = {}
|
59 |
for output in outputs:
|
@@ -62,8 +37,11 @@ for output in outputs:
|
|
62 |
query = list(all_queries)[0]
|
63 |
query2outputs[query] = [x['hits'] for x in output]
|
64 |
|
65 |
-
search_query = sorted(query2outputs)[0]
|
66 |
-
|
|
|
|
|
|
|
67 |
|
68 |
def preferences_from_hits(list_of_hits):
|
69 |
docid2id = {}
|
@@ -109,12 +87,6 @@ def aggregate(list_of_hits):
|
|
109 |
y_optimal = KemenyOptimalAggregator().aggregate(preferences)
|
110 |
# y_optimal = BordaRankAggregator().aggregate(preferences)
|
111 |
|
112 |
-
# print("-------------------------------------")
|
113 |
-
# print("preference:")
|
114 |
-
# print(preferences)
|
115 |
-
# print("preferences shape: ", preferences.shape)
|
116 |
-
# print("y_optimal: ", y_optimal)
|
117 |
-
|
118 |
return [id2doc[id] for id in y_optimal]
|
119 |
|
120 |
|
@@ -130,7 +102,7 @@ def write_ranking(search_results):
|
|
130 |
result_id = result["docid"]
|
131 |
contents = result["content"]
|
132 |
|
133 |
-
label = qrels[qid].get(result_id, 0)
|
134 |
if label == 3:
|
135 |
style = "style=\"color:blue;\""
|
136 |
elif label == 2:
|
@@ -155,13 +127,13 @@ def write_ranking(search_results):
|
|
155 |
|
156 |
aggregated_ranking = aggregate(query2outputs[search_query])
|
157 |
qrels = load_qrels("dl19")
|
158 |
-
|
159 |
col1, col2 = st.columns([5, 5])
|
160 |
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
|
|
|
14 |
st.set_page_config(page_title="PSC Runtime",
|
15 |
page_icon='🌸', layout="centered")
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
col1, col2 = st.columns([9, 1])
|
19 |
with col1:
|
|
|
27 |
import torch
|
28 |
fn = "dl19-gpt-3.5.pt"
|
29 |
object = torch.load(fn)
|
30 |
+
|
31 |
+
|
|
|
32 |
outputs = object[2]
|
33 |
query2outputs = {}
|
34 |
for output in outputs:
|
|
|
37 |
query = list(all_queries)[0]
|
38 |
query2outputs[query] = [x['hits'] for x in output]
|
39 |
|
40 |
+
# search_query = sorted(query2outputs)[0]
|
41 |
+
search_query = st.selectbox(
|
42 |
+
"Choose a query from the list",
|
43 |
+
sorted(query2outputs)
|
44 |
+
)
|
45 |
|
46 |
def preferences_from_hits(list_of_hits):
|
47 |
docid2id = {}
|
|
|
87 |
y_optimal = KemenyOptimalAggregator().aggregate(preferences)
|
88 |
# y_optimal = BordaRankAggregator().aggregate(preferences)
|
89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
return [id2doc[id] for id in y_optimal]
|
91 |
|
92 |
|
|
|
102 |
result_id = result["docid"]
|
103 |
contents = result["content"]
|
104 |
|
105 |
+
label = qrels[str(qid)].get(str(result_id), 0)
|
106 |
if label == 3:
|
107 |
style = "style=\"color:blue;\""
|
108 |
elif label == 2:
|
|
|
127 |
|
128 |
aggregated_ranking = aggregate(query2outputs[search_query])
|
129 |
qrels = load_qrels("dl19")
|
|
|
130 |
col1, col2 = st.columns([5, 5])
|
131 |
|
132 |
+
if search_query:
|
133 |
+
with col1:
|
134 |
+
if search_query or button_clicked:
|
135 |
+
write_ranking(search_results=query2outputs[search_query][0])
|
136 |
+
|
137 |
+
with col2:
|
138 |
+
if search_query or button_clicked:
|
139 |
+
write_ranking(search_results=aggregated_ranking)
|