# Spaces:
# Runtime error
# Runtime error
import html
import json
import sys
import time
from collections import defaultdict
from pathlib import Path

import numpy as np
import streamlit as st
# Append the current working directory to the module search path.
path_root = Path("./")
sys.path.append(str(path_root))

# Browser-tab title, icon and page layout for the app.
# NOTE(review): the page_icon literal looks like a mis-decoded emoji — confirm
# the intended character against the original file.
st.set_page_config(
    page_title="PSC Runtime",
    page_icon='πΈ',
    layout="centered",
)
# cola, colb, colc = st.columns([5, 4, 5]) | |
# colaa, colbb, colcc = st.columns([1, 8, 1]) | |
# with colbb: | |
# runtime = st.select_slider( | |
# 'Select a runtime type', | |
# options=['PyTorch', 'ONNX Runtime']) | |
# st.write('Now using: ', runtime) | |
# colaa, colbb, colcc = st.columns([1, 8, 1]) | |
# with colbb: | |
# encoder = st.select_slider( | |
# 'Select a query encoder', | |
# options=['uniCOIL', 'SPLADE++ Ensemble Distil', 'SPLADE++ Self Distil']) | |
# st.write('Now Running Encoder: ', encoder) | |
# if runtime == 'PyTorch': | |
# runtime = 'pytorch' | |
# runtime_index = 1 | |
# else: | |
# runtime = 'onnx' | |
# runtime_index = 0 | |
# Search bar: a wide column for the query box and a narrow one for the button.
query_col, button_col = st.columns([9, 1])
with query_col:
    search_query = st.text_input(label="search query", placeholder="Search")
with button_col:
    # Presumably a spacer so the button lines up with the text box — verify
    # visually.
    st.write('#')
    # NOTE(review): the button label looks like a mis-decoded emoji — confirm
    # the intended character against the original file.
    button_clicked = st.button("π")
import torch

fn = "dl19-gpt-3.5.pt"
# NOTE(security): torch.load unpickles the file, and unpickling can execute
# arbitrary code. Only point this at a file from a trusted source.
saved_run = torch.load(fn)

# Element 2 of the saved object is iterated as one entry per query; every entry
# is a collection of records that each carry a 'query' string and a 'hits' list.
outputs = saved_run[2]

# Map each query string to the list of hit lists recorded for it.
query2outputs = {}
for output in outputs:
    all_queries = {x['query'] for x in output}
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(all_queries) != 1:
        raise ValueError(
            f"expected exactly one query per output group, got {all_queries!r}")
    (query,) = all_queries
    query2outputs[query] = [x['hits'] for x in output]

# Honour the query typed into the search box when it names a known query;
# otherwise (empty box, unknown text) fall back to the alphabetically first
# query so the page always has something to show.
if search_query not in query2outputs:
    search_query = sorted(query2outputs)[0]
def preferences_from_hits(list_of_hits):
    """Convert ranked hit lists into integer preference lists.

    Every distinct ``docid`` gets a dense integer id, assigned in order of
    first appearance across all rankings, and each ranking is rewritten as
    the list of those ids in its own order.

    Args:
        list_of_hits: Iterable of rankings. Each ranking is an iterable of
            hit dicts that contain at least a ``"docid"`` key.

    Returns:
        A tuple ``(preferences, id2doc)``. ``preferences`` is
        ``np.array`` applied to the list of per-ranking id lists, and
        ``id2doc`` maps each integer id to the first hit dict seen with
        that ``docid``.
    """
    docid2id = {}
    id2doc = {}
    preferences = []
    for ranking in list_of_hits:
        ranking_ids = []
        for doc in ranking:
            docid = doc["docid"]
            doc_index = docid2id.get(docid)
            if doc_index is None:
                # First sighting: allocate the next free id and remember the
                # hit so callers can map ids back to documents.
                doc_index = len(docid2id)
                docid2id[docid] = doc_index
                id2doc[doc_index] = doc
            ranking_ids.append(doc_index)
        preferences.append(ranking_ids)
    # NOTE(review): presumably all rankings have the same length; rankings of
    # differing lengths would make this a ragged array — confirm with callers.
    return np.array(preferences), id2doc
def load_qrels(name):
    """Load relevance judgments for one of the supported collections.

    Args:
        name: Short collection name, ``"dl19"`` or ``"dl20"``.

    Returns:
        A ``defaultdict`` mapping ``query_id`` to a dict of
        ``{doc_id: relevance}`` built from the dataset's qrels.

    Raises:
        ValueError: If ``name`` is not one of the supported short names.
    """
    dataset_names = {
        "dl19": "msmarco-passage/trec-dl-2019/judged",
        "dl20": "msmarco-passage/trec-dl-2020/judged",
    }
    try:
        ds_name = dataset_names[name]
    except (KeyError, TypeError):
        raise ValueError(
            f"unknown qrels name {name!r}; expected one of {sorted(dataset_names)}"
        ) from None

    # Imported only after the name is validated, so a bad name fails fast
    # without loading the dataset library.
    import ir_datasets

    dataset = ir_datasets.load(ds_name)
    qrels = defaultdict(dict)
    for qrel in dataset.qrels_iter():
        qrels[qrel.query_id][qrel.doc_id] = qrel.relevance
    return qrels
def aggregate(list_of_hits):
    """Merge several rankings of hits into a single ranking.

    The rankings are converted to integer preference lists with
    ``preferences_from_hits`` and aggregated with permsc's
    ``KemenyOptimalAggregator``; the aggregated ids are then mapped back to
    hit dicts.

    Args:
        list_of_hits: Sequence of rankings, each a sequence of hit dicts
            with a ``"docid"`` key.

    Returns:
        A list of hit dicts in the aggregated order. An empty input yields
        an empty list.
    """
    # Nothing to aggregate: skip the aggregator entirely.
    if not list_of_hits:
        return []

    from permsc import KemenyOptimalAggregator

    preferences, id2doc = preferences_from_hits(list_of_hits)
    # BordaRankAggregator from the same package was previously tried here as
    # an alternative aggregator.
    y_optimal = KemenyOptimalAggregator().aggregate(preferences)
    # Presumably y_optimal is a sequence of the integer ids produced by
    # preferences_from_hits — confirm against the permsc documentation.
    return [id2doc[doc_index] for doc_index in y_optimal]
# Aggregate all recorded rankings for the selected query and load the
# relevance judgments used to colour the results.
aggregated_ranking = aggregate(query2outputs[search_query])
qrels = load_qrels("dl19")

# Text colour for each relevance label; any other label renders grey.
LABEL_COLORS = {3: "blue", 2: "green", 1: "red"}

col1, col2 = st.columns([5, 5])
with col2:
    if search_query or button_clicked:
        # Not read anywhere in the visible code; kept in case later code
        # relies on them.
        num_results = None
        t_0 = time.time()

        search_results = aggregated_ranking

        # The query is interpolated into raw HTML, so escape it first.
        safe_query = html.escape(search_query)
        # NOTE(review): the caption says "Before aggregation" and ends in a
        # bare "ms", yet the list below is the aggregated ranking — confirm
        # the intended wording.
        st.write(
            f'<p align="right" style="color:grey;"> Before aggregation for query [{safe_query}] ms</p>',
            unsafe_allow_html=True)

        # All hits must belong to the same query so one qrels entry applies.
        qids = {result["qid"] for result in search_results}
        if len(qids) != 1:
            raise ValueError(
                f"expected hits for exactly one query id, got {qids!r}")
        qid = next(iter(qids))
        # Looked up once; .get avoids inserting an empty entry for an
        # unjudged query.
        judged = qrels.get(qid, {})

        for rank, result in enumerate(search_results, start=1):
            result_id = result["docid"]
            # Passage text is rendered with unsafe_allow_html, so escape it
            # to keep '<', '&' etc. from being interpreted as markup.
            contents = html.escape(str(result["content"]))
            label = judged.get(result_id, 0)
            color = LABEL_COLORS.get(label, "grey")
            style = f'style="color:{color};"'
            # The closing </div> was missing from this header row.
            output = (
                f'<div class="row" {style}> <b>Rank</b>: {rank} | '
                f'<b>Document ID</b>: {result_id}</div>'
            )
            try:
                st.write(output, unsafe_allow_html=True)
                st.write(
                    f'<div class="row" {style}>{contents}</div>',
                    unsafe_allow_html=True)
            except Exception as exc:
                # Stay best-effort for a single bad result, but report it
                # instead of silently swallowing every exception.
                st.warning(f"Could not render result {rank}: {exc}")
            st.write('---')