crystina-z committed on
Commit
aad2fb9
·
1 Parent(s): 45c093d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -107
app.py CHANGED
@@ -15,118 +15,139 @@ st.set_page_config(page_title="PSC Runtime",
15
  page_icon='🌸', layout="centered")
16
 
17
 
18
- import torch
19
- fn = "dl19-gpt-3.5.pt"
20
- object = torch.load(fn)
21
 
22
-
23
- outputs = object[2]
24
- query2outputs = {}
25
- for output in outputs:
26
- all_queries = {x['query'] for x in output}
27
- assert len(all_queries) == 1
28
- query = list(all_queries)[0]
29
- query2outputs[query] = [x['hits'] for x in output]
30
-
31
-
32
- search_query = st.selectbox(
33
  "",
34
- sorted(query2outputs),
35
  index=None,
36
- placeholder="Choose a query from the list..."
37
  )
38
 
39
- def preferences_from_hits(list_of_hits):
40
- docid2id = {}
41
- id2doc = {}
42
- preferences = []
43
-
44
- for result in list_of_hits:
45
- for doc in result:
46
- if doc["docid"] not in docid2id:
47
- id = len(docid2id)
48
- docid2id[doc["docid"]] = id
49
- id2doc[id] = doc
50
- print([doc["docid"] for doc in result])
51
- print([docid2id[doc["docid"]] for doc in result])
52
- preferences.append([docid2id[doc["docid"]] for doc in result])
53
-
54
- # = {v: k for k, v in docid2id.items()}
55
- return np.array(preferences), id2doc
56
-
57
-
58
- def load_qrels(name):
59
- import ir_datasets
60
- if name == "dl19":
61
- ds_name = "msmarco-passage/trec-dl-2019/judged"
62
- elif name == "dl20":
63
- ds_name = "msmarco-passage/trec-dl-2020/judged"
64
- else:
65
- raise ValueError(name)
66
-
67
- dataset = ir_datasets.load(ds_name)
68
- qrels = defaultdict(dict)
69
- for qrel in dataset.qrels_iter():
70
- qrels[qrel.query_id][qrel.doc_id] = qrel.relevance
71
- return qrels
72
-
73
-
74
- def aggregate(list_of_hits):
75
- import numpy as np
76
- from permsc import KemenyOptimalAggregator, sum_kendall_tau, ranks_from_preferences
77
- from permsc import BordaRankAggregator
78
-
79
- preferences, id2doc = preferences_from_hits(list_of_hits)
80
- y_optimal = KemenyOptimalAggregator().aggregate(preferences)
81
- # y_optimal = BordaRankAggregator().aggregate(preferences)
82
-
83
- return [id2doc[id] for id in y_optimal]
84
-
85
-
86
- def write_ranking(search_results):
87
- # st.write(
88
- # f'<p align=\"right\" style=\"color:grey;\"> Before aggregation for query [{search_query}] ms</p>', unsafe_allow_html=True)
89
-
90
- qid = {result["qid"] for result in search_results}
91
- assert len(qid) == 1
92
- qid = list(qid)[0]
93
 
94
- for i, result in enumerate(search_results):
95
- result_id = result["docid"]
96
- contents = result["content"]
97
 
98
- label = qrels[str(qid)].get(str(result_id), 0)
99
- if label == 3:
100
- style = "style=\"color:blue;\""
101
- elif label == 2:
102
- style = "style=\"color:green;\""
103
- elif label == 1:
104
- style = "style=\"color:red;\""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  else:
106
- style = "style=\"color:grey;\""
107
-
108
- print(qid, result_id, label, style)
109
- # output = f'<div class="row"> <b>Rank</b>: {i+1} | <b>Document ID</b>: {result_id} | <b>Score</b>:{result_score:.2f}</div>'
110
- output = f'<div class="row" {style}> <b>Rank</b>: {i+1} | <b>Document ID</b>: {result_id}'
111
-
112
- try:
113
- st.write(output, unsafe_allow_html=True)
114
- st.write(
115
- f'<div class="row" {style}>{contents}</div>', unsafe_allow_html=True)
116
-
117
- except:
118
- pass
119
- st.write('---')
120
-
121
- aggregated_ranking = aggregate(query2outputs[search_query])
122
- qrels = load_qrels("dl19")
123
- col1, col2 = st.columns([5, 5])
124
-
125
- if search_query:
126
- with col1:
127
- if search_query or button_clicked:
128
- write_ranking(search_results=query2outputs[search_query][0])
129
-
130
- with col2:
131
- if search_query or button_clicked:
132
- write_ranking(search_results=aggregated_ranking)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  page_icon='🌸', layout="centered")
16
 
17
 
 
 
 
18
 
19
# Dataset selector. `index=None` leaves the box empty until the user picks
# an entry, so `name` is None on first render.
name = st.selectbox(
    "",
    ["dl19", "dl20"],  # the comma after the options list was missing (SyntaxError)
    index=None,
    placeholder="Choose a dataset...",
)

# Model selector. Together with `name` it determines which results file
# ("<name>-<model_name>.pt") is loaded further down.
model_name = st.selectbox(
    "",
    ["gpt-3.5", "gpt-4"],  # same missing comma fixed here
    index=None,
    placeholder="Choose a model...",
)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ # "dl19"
 
 
34
 
35
# Nothing can be loaded until both the dataset and the model are chosen.
if name and model_name:
    import torch

    # Results are stored per (dataset, model) pair, e.g. "dl19-gpt-3.5.pt".
    fn = f"{name}-{model_name}.pt"
    # NOTE(review): torch.load unpickles the file; only point this at
    # trusted, locally produced result files.
    loaded = torch.load(fn)  # renamed from `object`, which shadowed the builtin

    # Element 2 of the loaded container holds one group of runs per query.
    outputs = loaded[2]
    query2outputs = {}
    for output in outputs:
        all_queries = {x["query"] for x in output}
        # Explicit check instead of `assert`, which is stripped under `-O`.
        if len(all_queries) != 1:
            raise ValueError(
                f"expected exactly one query per output group, got {all_queries!r}"
            )
        (query,) = all_queries
        # Keep only the ranked hits of every run for this query.
        query2outputs[query] = [x["hits"] for x in output]

    # Query selector; stays None until the user picks a query.
    search_query = st.selectbox(
        "",
        sorted(query2outputs),
        index=None,
        placeholder="Choose a query from the list...",
    )
58
+
59
def preferences_from_hits(list_of_hits):
    """Turn several rankings of documents into integer preference lists.

    Every distinct ``doc["docid"]`` gets a dense integer id in order of first
    appearance, and each ranking is rewritten as the list of those ids.

    Args:
        list_of_hits: iterable of rankings; each ranking is a sequence of
            dicts that contain a ``"docid"`` key.

    Returns:
        A tuple ``(preferences, id2doc)`` where ``preferences`` is
        ``np.array`` applied to the per-ranking id lists and ``id2doc`` maps
        each integer id to the first document dict seen with that docid.
    """
    # Imported here so the function does not rely on a module-level `np`;
    # the only numpy import visible nearby is local to `aggregate`.
    import numpy as np

    docid2id = {}
    id2doc = {}
    preferences = []

    for hits in list_of_hits:
        ranking = []
        for doc in hits:
            # setdefault assigns the next free id only for unseen docids.
            doc_idx = docid2id.setdefault(doc["docid"], len(docid2id))
            # Keep the first document object encountered for each id.
            id2doc.setdefault(doc_idx, doc)
            ranking.append(doc_idx)
        preferences.append(ranking)

    return np.array(preferences), id2doc
76
+
77
+
78
def load_qrels(name):
    """Load relevance judgments for the dataset identified by *name*.

    Args:
        name: ``"dl19"`` or ``"dl20"``; these select the
            ``msmarco-passage/trec-dl-2019/judged`` and
            ``msmarco-passage/trec-dl-2020/judged`` ir_datasets ids.

    Returns:
        A ``defaultdict(dict)`` mapping ``query_id -> {doc_id: relevance}``.

    Raises:
        ValueError: if *name* is neither of the supported keys.
    """
    import ir_datasets

    known_datasets = (
        ("dl19", "msmarco-passage/trec-dl-2019/judged"),
        ("dl20", "msmarco-passage/trec-dl-2020/judged"),
    )
    # First matching key wins; None means the name is not supported.
    ds_name = next((ds for key, ds in known_datasets if name == key), None)
    if ds_name is None:
        raise ValueError(name)

    judgments = defaultdict(dict)
    for record in ir_datasets.load(ds_name).qrels_iter():
        judgments[record.query_id][record.doc_id] = record.relevance
    return judgments
92
+
93
+
94
def aggregate(list_of_hits):
    """Fuse several rankings of the same documents into a single ranking.

    The rankings are converted to integer preference lists with
    ``preferences_from_hits`` and combined by permsc's
    ``KemenyOptimalAggregator``.

    Args:
        list_of_hits: iterable of rankings, each a sequence of document
            dicts carrying a ``"docid"`` key.

    Returns:
        A list of document dicts in the order given by the aggregator.
    """
    # Only the aggregator actually used is imported; the unused numpy,
    # sum_kendall_tau, ranks_from_preferences and BordaRankAggregator
    # imports were dropped.
    from permsc import KemenyOptimalAggregator

    preferences, id2doc = preferences_from_hits(list_of_hits)
    y_optimal = KemenyOptimalAggregator().aggregate(preferences)

    # Map the aggregated integer ids back to their documents.
    return [id2doc[doc_idx] for doc_idx in y_optimal]
104
+
105
+
106
def write_ranking(search_results):
    """Render a ranked list of documents, coloured by relevance label.

    Labels are looked up in the module-level ``qrels`` mapping
    (``query id -> {doc id: label}``); a missing judgment counts as 0.
    Colours: 3 -> blue, 2 -> green, 1 -> red, anything else -> grey.

    Args:
        search_results: sequence of dicts with ``"qid"``, ``"docid"`` and
            ``"content"`` keys, all belonging to the same query.

    Raises:
        ValueError: if the results span more than one query id.
    """
    import html

    if not search_results:
        return  # nothing to render

    qids = {result["qid"] for result in search_results}
    # Explicit check instead of `assert`, which is stripped under `-O`.
    if len(qids) != 1:
        raise ValueError(f"expected results for exactly one query, got {qids!r}")
    (qid,) = qids

    colours = {3: "blue", 2: "green", 1: "red"}

    for rank, result in enumerate(search_results, start=1):
        result_id = result["docid"]
        # Passage text goes into raw HTML below, so escape markup characters.
        contents = html.escape(str(result["content"]))

        label = qrels[str(qid)].get(str(result_id), 0)
        colour = colours.get(label, "grey")
        style = f'style="color:{colour};"'

        # The header div was previously left unclosed.
        header = (
            f'<div class="row" {style}> <b>Rank</b>: {rank} | '
            f'<b>Document ID</b>: {result_id}</div>'
        )
        body = f'<div class="row" {style}>{contents}</div>'

        try:
            st.write(header, unsafe_allow_html=True)
            st.write(body, unsafe_allow_html=True)
        except Exception as err:
            # Rendering stays best-effort, but failures are no longer silent.
            print(f"failed to render document {result_id}: {err!r}")
        st.write('---')
140
+
141
+
142
# Rank and render only once a dataset, a model and a query have been chosen.
# Previously the aggregation ran before the query check and raised KeyError
# while the query selector was still empty (index=None -> None).
if name and model_name and search_query:
    # `write_ranking` reads `qrels` at module level, so load it first.
    qrels = load_qrels(name)
    aggregated_ranking = aggregate(query2outputs[search_query])
    col1, col2 = st.columns([5, 5])

    # Left column: the first individual ranking for this query.
    with col1:
        write_ranking(search_results=query2outputs[search_query][0])

    # Right column: the aggregated ranking.
    with col2:
        write_ranking(search_results=aggregated_ranking)