crystina-z commited on
Commit
5bffde4
·
1 Parent(s): 5da05af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -44
app.py CHANGED
@@ -14,30 +14,6 @@ sys.path.append(str(path_root))
14
  st.set_page_config(page_title="PSC Runtime",
15
  page_icon='🌸', layout="centered")
16
 
17
- # cola, colb, colc = st.columns([5, 4, 5])
18
-
19
- # colaa, colbb, colcc = st.columns([1, 8, 1])
20
- # with colbb:
21
- # runtime = st.select_slider(
22
- # 'Select a runtime type',
23
- # options=['PyTorch', 'ONNX Runtime'])
24
- # st.write('Now using: ', runtime)
25
-
26
-
27
- # colaa, colbb, colcc = st.columns([1, 8, 1])
28
- # with colbb:
29
- # encoder = st.select_slider(
30
- # 'Select a query encoder',
31
- # options=['uniCOIL', 'SPLADE++ Ensemble Distil', 'SPLADE++ Self Distil'])
32
- # st.write('Now Running Encoder: ', encoder)
33
-
34
- # if runtime == 'PyTorch':
35
- # runtime = 'pytorch'
36
- # runtime_index = 1
37
- # else:
38
- # runtime = 'onnx'
39
- # runtime_index = 0
40
-
41
 
42
  col1, col2 = st.columns([9, 1])
43
  with col1:
@@ -51,9 +27,8 @@ with col2:
51
  import torch
52
  fn = "dl19-gpt-3.5.pt"
53
  object = torch.load(fn)
54
- # for x for x in object:
55
-
56
- # outputs = [x[2] for x in object]
57
  outputs = object[2]
58
  query2outputs = {}
59
  for output in outputs:
@@ -62,8 +37,11 @@ for output in outputs:
62
  query = list(all_queries)[0]
63
  query2outputs[query] = [x['hits'] for x in output]
64
 
65
- search_query = sorted(query2outputs)[0]
66
-
 
 
 
67
 
68
  def preferences_from_hits(list_of_hits):
69
  docid2id = {}
@@ -109,12 +87,6 @@ def aggregate(list_of_hits):
109
  y_optimal = KemenyOptimalAggregator().aggregate(preferences)
110
  # y_optimal = BordaRankAggregator().aggregate(preferences)
111
 
112
- # print("-------------------------------------")
113
- # print("preference:")
114
- # print(preferences)
115
- # print("preferences shape: ", preferences.shape)
116
- # print("y_optimal: ", y_optimal)
117
-
118
  return [id2doc[id] for id in y_optimal]
119
 
120
 
@@ -130,7 +102,7 @@ def write_ranking(search_results):
130
  result_id = result["docid"]
131
  contents = result["content"]
132
 
133
- label = qrels[qid].get(result_id, 0)
134
  if label == 3:
135
  style = "style=\"color:blue;\""
136
  elif label == 2:
@@ -155,13 +127,13 @@ def write_ranking(search_results):
155
 
156
  aggregated_ranking = aggregate(query2outputs[search_query])
157
  qrels = load_qrels("dl19")
158
-
159
  col1, col2 = st.columns([5, 5])
160
 
161
- with col1:
162
- if search_query or button_clicked:
163
- write_ranking(search_results=query2outputs[search_query][0])
164
-
165
- with col2:
166
- if search_query or button_clicked:
167
- write_ranking(search_results=aggregated_ranking)
 
 
14
  st.set_page_config(page_title="PSC Runtime",
15
  page_icon='🌸', layout="centered")
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  col1, col2 = st.columns([9, 1])
19
  with col1:
 
27
  import torch
28
  fn = "dl19-gpt-3.5.pt"
29
  object = torch.load(fn)
30
+
31
+
 
32
  outputs = object[2]
33
  query2outputs = {}
34
  for output in outputs:
 
37
  query = list(all_queries)[0]
38
  query2outputs[query] = [x['hits'] for x in output]
39
 
40
+ # search_query = sorted(query2outputs)[0]
41
+ search_query = st.selectbox(
42
+ "Choose a query from the list",
43
+ sorted(query2outputs)
44
+ )
45
 
46
  def preferences_from_hits(list_of_hits):
47
  docid2id = {}
 
87
  y_optimal = KemenyOptimalAggregator().aggregate(preferences)
88
  # y_optimal = BordaRankAggregator().aggregate(preferences)
89
 
 
 
 
 
 
 
90
  return [id2doc[id] for id in y_optimal]
91
 
92
 
 
102
  result_id = result["docid"]
103
  contents = result["content"]
104
 
105
+ label = qrels[str(qid)].get(str(result_id), 0)
106
  if label == 3:
107
  style = "style=\"color:blue;\""
108
  elif label == 2:
 
127
 
128
  aggregated_ranking = aggregate(query2outputs[search_query])
129
  qrels = load_qrels("dl19")
 
130
  col1, col2 = st.columns([5, 5])
131
 
132
+ if search_query:
133
+ with col1:
134
+ if search_query or button_clicked:
135
+ write_ranking(search_results=query2outputs[search_query][0])
136
+
137
+ with col2:
138
+ if search_query or button_clicked:
139
+ write_ranking(search_results=aggregated_ranking)