File size: 2,182 Bytes
51a31d4
51dabd6
ab5dfc2
 
51a31d4
ab5dfc2
51a31d4
51dabd6
ab5dfc2
 
 
 
 
51dabd6
ab5dfc2
 
 
 
51dabd6
 
51a31d4
 
 
 
 
 
ab5dfc2
51a31d4
51dabd6
ab5dfc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51dabd6
ab5dfc2
 
 
 
 
 
 
 
51a31d4
ab5dfc2
51dabd6
ab5dfc2
 
 
51dabd6
ab5dfc2
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from datasets import DatasetDict, load_dataset

from src.readers.dpr_reader import DprReader
from src.retrievers.faiss_retriever import FaissRetriever
from src.utils.log import get_logger
# from src.evaluation import evaluate
from typing import cast

from src.utils.preprocessing import result_to_reader_input

import torch
import transformers
import os

# Suppress the advisory warnings transformers emits on import/model load
# so the demo output below stays readable.
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'

logger = get_logger()
# Only surface errors from transformers; info/warning noise is hidden.
transformers.logging.set_verbosity_error()

if __name__ == '__main__':
    # Demo entry point: retrieve documents for one random test question and
    # print the reader's candidate answers with softmaxed scores.
    dataset_name = "GroNLP/ik-nlp-22_slp"

    # NOTE(review): `paragraphs` is loaded but never referenced below —
    # presumably kept to pre-download/cache the split the retriever indexes;
    # confirm before removing.
    paragraphs = load_dataset(dataset_name, "paragraphs")
    questions = cast(DatasetDict, load_dataset(dataset_name, "questions"))

    questions_test = questions["test"]

    # Initialize retriever
    retriever = FaissRetriever()

    # Pick a random test question and retrieve candidate documents for it.
    example_q = questions_test.shuffle()["question"][0]
    scores, result = retriever.retrieve(example_q)

    reader_input = result_to_reader_input(result)

    # Initialize reader and extract answer spans from the retrieved documents.
    reader = DprReader()
    answers = reader.read(example_q, reader_input)

    # Softmax the relevance/span scores across all answers (dim=0) so they
    # print as comparable, human-readable percentages.
    softmax = torch.nn.Softmax(dim=0)
    document_scores = softmax(torch.Tensor(
        [pred.relevance_score for pred in answers]))
    span_scores = softmax(torch.Tensor(
        [pred.span_score for pred in answers]))

    print(example_q)
    for answer_i, answer in enumerate(answers):
        print(f"[{answer_i + 1}]: {answer.text}")
        print(f"\tDocument {answer.doc_id}", end='')
        print(f"\t(score {document_scores[answer_i] * 100:.02f})")
        print(f"\tSpan {answer.start_index}-{answer.end_index}", end='')
        print(f"\t(score {span_scores[answer_i] * 100:.02f})")
        print()  # Newline

    # TODO: print per-document retrieval `scores` and compute overall
    # exact-match / F1 via src.evaluation.evaluate once it is re-enabled.