File size: 3,215 Bytes
934f74d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a48ed2
934f74d
7b626c7
 
 
 
934f74d
 
 
 
 
 
 
 
 
b8c5764
 
 
 
608503a
b8c5764
 
 
24e4b2c
b8c5764
 
608503a
 
 
b8c5764
 
24e4b2c
608503a
 
 
 
b8c5764
6c34254
 
b8c5764
934f74d
6c34254
934f74d
 
6c34254
 
934f74d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import time
import json
import sys
from pathlib import Path

import streamlit as st

# Make project-local modules importable when the app is launched via
# `streamlit run` from the repository root.
path_root = Path("./")
sys.path.append(str(path_root))


st.set_page_config(page_title="PSC Runtime",
                   page_icon='🌸', layout="centered")


# Search box with the submit button in a narrow right-hand column.
col1, col2 = st.columns([9, 1])
with col1:
    search_query = st.text_input(label="search query", placeholder="Search")

with col2:
    st.write('#')  # vertical spacer so the button aligns with the text input
    button_clicked = st.button("🔎")

import torch

# Rankings produced offline (GPT-3.5 on the TREC DL19 topics) and serialized
# with torch.save.
fn = "dl19-gpt-3.5.pt"
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# files; consider weights_only=True if the payload allows it.
loaded = torch.load(fn)

# Element [2] holds the per-query output lists; each `output` is a list of
# result dicts that all share the same 'query' string — TODO confirm schema
# against the script that produced the .pt file.
outputs = loaded[2]
query2outputs = {}
for output in outputs:
    all_queries = {x['query'] for x in output}
    assert len(all_queries) == 1
    query = next(iter(all_queries))
    query2outputs[query] = [x['hits'] for x in output]

# Debug override: always display the alphabetically-first query rather than
# the text-box input.  TODO confirm this is intentional before shipping.
search_query = sorted(query2outputs)[0]


def aggregate(list_of_hits):
    """Fuse several rankings of the same document set into one ranking.

    Parameters
    ----------
    list_of_hits : list of rankings, where each ranking is a list of dicts
        carrying at least a 1-based ``"rank"`` key.  All rankings are assumed
        to cover the same documents as ``list_of_hits[0]`` — TODO confirm.

    Returns
    -------
    list of dict
        The documents of ``list_of_hits[0]`` reordered by the aggregated rank.
    """
    import numpy as np
    from permsc import BordaRankAggregator

    # One preference row per input ranking, converted to 0-based ranks.
    preferences = np.array(
        [[doc["rank"] - 1 for doc in result] for result in list_of_hits]
    )

    # Borda count was chosen over the (more expensive) Kemeny-optimal
    # aggregator; see version history for the alternative.
    y_optimal = BordaRankAggregator().aggregate(preferences)

    # Map each 0-based rank back to its document, taken from the first ranking.
    rank2doc = {doc["rank"] - 1: doc for doc in list_of_hits[0]}

    # Debug output kept on purpose while the aggregation is being tuned.
    print("preferences: ", preferences.shape, preferences[0])
    print("rank2doc:", rank2doc.keys())
    print("y_optimal: ", y_optimal)
    return [rank2doc[rank] for rank in y_optimal]

aggregated_ranking = aggregate(query2outputs[search_query])

if search_query or button_clicked:

    t_0 = time.time()
    search_results = aggregated_ranking

    # Elapsed wall-clock time for this render pass, in milliseconds
    # (the original string ended in a dangling "ms" with no value).
    elapsed_ms = (time.time() - t_0) * 1000
    st.write(
        f'<p align="right" style="color:grey;"> Before aggregation for query [{search_query}] {elapsed_ms:.1f} ms</p>',
        unsafe_allow_html=True)

    for i, result in enumerate(search_results):
        result_id = result["docid"]
        contents = result["content"]

        # Rank line; the closing </div> was missing in the original markup.
        output = f'<div class="row"> <b>Rank</b>: {i+1} | <b>Document ID</b>: {result_id}</div>'

        try:
            st.write(output, unsafe_allow_html=True)
            st.write(
                f'<div class="row">{contents}</div>', unsafe_allow_html=True)
        except Exception:
            # Best-effort rendering: skip documents whose content breaks
            # st.write (e.g. malformed markup) instead of crashing the app.
            pass
        st.write('---')