crystina-z committed on
Commit
7a18ba5
·
1 Parent(s): 9caffdf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -7
app.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
 
5
  import streamlit as st
6
  from pathlib import Path
 
7
 
8
  import sys
9
  path_root = Path("./")
@@ -83,6 +84,22 @@ def preferences_from_hits(list_of_hits):
83
  return np.array(preferences), id2doc
84
 
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  def aggregate(list_of_hits):
87
  import numpy as np
88
  from permsc import KemenyOptimalAggregator, sum_kendall_tau, ranks_from_preferences
@@ -92,15 +109,16 @@ def aggregate(list_of_hits):
92
  y_optimal = KemenyOptimalAggregator().aggregate(preferences)
93
  # y_optimal = BordaRankAggregator().aggregate(preferences)
94
 
95
- print("-------------------------------------")
96
- print("preference:")
97
- print(preferences)
98
- print("preferences shape: ", preferences.shape)
99
- print("y_optimal: ", y_optimal)
100
 
101
  return [id2doc[id] for id in y_optimal]
102
 
103
  aggregated_ranking = aggregate(query2outputs[search_query])
 
104
 
105
  if search_query or button_clicked:
106
 
@@ -112,17 +130,22 @@ if search_query or button_clicked:
112
  st.write(
113
  f'<p align=\"right\" style=\"color:grey;\"> Before aggregation for query [{search_query}] ms</p>', unsafe_allow_html=True)
114
 
 
 
 
 
115
  for i, result in enumerate(search_results):
116
  result_id = result["docid"]
117
  contents = result["content"]
118
 
 
119
  # output = f'<div class="row"> <b>Rank</b>: {i+1} | <b>Document ID</b>: {result_id} | <b>Score</b>:{result_score:.2f}</div>'
120
- output = f'<div class="row"> <b>Rank</b>: {i+1} | <b>Document ID</b>: {result_id}'
121
 
122
  try:
123
  st.write(output, unsafe_allow_html=True)
124
  st.write(
125
- f'<div class="row">{contents}</div>', unsafe_allow_html=True)
126
 
127
  except:
128
  pass
 
4
 
5
  import streamlit as st
6
  from pathlib import Path
7
+ from collections import defaultdict
8
 
9
  import sys
10
  path_root = Path("./")
 
84
  return np.array(preferences), id2doc
85
 
86
 
87
+ def load_qrels(name):
88
+ import ir_datasets
89
+ if name == "dl19":
90
+ ds_name = "msmarco-passage/trec-dl-2019/judged"
91
+ elif name == "dl20":
92
+ ds_name = "msmarco-passage/trec-dl-2020/judged"
93
+ else:
94
+ raise ValueError(name)
95
+
96
+ dataset = ir_datasets.load(ds_name)
97
+ qrels = defaultdict(dict)
98
+ for qrel in dataset.qrels_iter():
99
+ qrels[qrel.query_id][qrel.doc_id] = qrel.relevance
100
+ return qrels
101
+
102
+
103
  def aggregate(list_of_hits):
104
  import numpy as np
105
  from permsc import KemenyOptimalAggregator, sum_kendall_tau, ranks_from_preferences
 
109
  y_optimal = KemenyOptimalAggregator().aggregate(preferences)
110
  # y_optimal = BordaRankAggregator().aggregate(preferences)
111
 
112
+ # print("-------------------------------------")
113
+ # print("preference:")
114
+ # print(preferences)
115
+ # print("preferences shape: ", preferences.shape)
116
+ # print("y_optimal: ", y_optimal)
117
 
118
  return [id2doc[id] for id in y_optimal]
119
 
120
  aggregated_ranking = aggregate(query2outputs[search_query])
121
+ qrels = load_qrels("dl19")
122
 
123
  if search_query or button_clicked:
124
 
 
130
  st.write(
131
  f'<p align=\"right\" style=\"color:grey;\"> Before aggregation for query [{search_query}] ms</p>', unsafe_allow_html=True)
132
 
133
+ qid = {result["qid"] for result in search_results}
134
+ assert len(qid) == 1
135
+ qid = list(qid)[0]
136
+
137
  for i, result in enumerate(search_results):
138
  result_id = result["docid"]
139
  contents = result["content"]
140
 
141
+ style = "style=\"color:grey;\"" if qrels[qid].get(result_id, 0) else "style=\"color:red;\""
142
  # output = f'<div class="row"> <b>Rank</b>: {i+1} | <b>Document ID</b>: {result_id} | <b>Score</b>:{result_score:.2f}</div>'
143
+ output = f'<div class="row" {style}> <b>Rank</b>: {i+1} | <b>Document ID</b>: {result_id}'
144
 
145
  try:
146
  st.write(output, unsafe_allow_html=True)
147
  st.write(
148
+ f'<div class="row" {style}>{contents}</div>', unsafe_allow_html=True)
149
 
150
  except:
151
  pass