File size: 5,017 Bytes
83870cc
51dabd6
 
be1f224
 
 
51a31d4
 
be1f224
51a31d4
be1f224
 
 
51a31d4
be1f224
 
51dabd6
b06298d
51a31d4
b7158e7
b06298d
8bbe3aa
51a31d4
 
ab5dfc2
51a31d4
83870cc
51a31d4
83870cc
51a31d4
be1f224
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab5dfc2
8bbe3aa
 
 
 
be1f224
 
8bbe3aa
 
be1f224
 
8bbe3aa
be1f224
 
8bbe3aa
 
be1f224
 
8bbe3aa
e9df5ab
be1f224
1fb8ae3
 
8bbe3aa
be1f224
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fb8ae3
ab5dfc2
 
1fb8ae3
e9df5ab
b7158e7
 
1fb8ae3
 
 
83870cc
 
 
be1f224
83870cc
1fb8ae3
8bbe3aa
83870cc
ab5dfc2
1fb8ae3
 
8bbe3aa
1fb8ae3
8bbe3aa
b06298d
 
be1f224
1fb8ae3
83870cc
 
2827202
8bbe3aa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import os.path
import torch

from datasets import DatasetDict
from dataclasses import dataclass
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizerFast,
    LongformerModel,
    LongformerTokenizerFast
)
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

from src.retrievers.base_retriever import RetrieveType, Retriever
from src.utils.log import get_logger
from src.utils.preprocessing import remove_formulas
from src.utils.timing import timeit

# Hacky workaround for a FAISS error on macOS: allow a duplicate OpenMP
# runtime to be loaded instead of aborting.
# See https://stackoverflow.com/a/63374568/4545692
os.environ.update(KMP_DUPLICATE_LIB_OK="True")


# Module-level logger obtained from the project's logging helper.
logger = get_logger()


@dataclass
class FaissRetrieverOptions:
    """Encoders, tokenizers and settings consumed by ``FaissRetriever``.

    Use the ``dpr`` or ``longformer`` factory to build a ready-made
    configuration for one of the pretrained checkpoints.
    """

    # Encoder/tokenizer pair used to embed the indexed paragraphs.
    ctx_encoder: PreTrainedModel
    ctx_tokenizer: PreTrainedTokenizerFast
    # Encoder/tokenizer pair used to embed incoming questions.
    q_encoder: PreTrainedModel
    q_tokenizer: PreTrainedTokenizerFast
    # File path the FAISS index is loaded from / saved to.
    embedding_path: str
    # Model identifier; the factories below set "dpr" or "longformer".
    lm: str

    @staticmethod
    def dpr(embedding_path: str):
        """Build options from the single-NQ DPR context/question encoders.

        Args:
            embedding_path: where the FAISS index is stored.
        """
        ctx_checkpoint = "facebook/dpr-ctx_encoder-single-nq-base"
        q_checkpoint = "facebook/dpr-question_encoder-single-nq-base"

        # Load order matches: context encoder, context tokenizer,
        # question encoder, question tokenizer.
        ctx_model = DPRContextEncoder.from_pretrained(ctx_checkpoint)
        ctx_tok = DPRContextEncoderTokenizerFast.from_pretrained(
            ctx_checkpoint
        )
        q_model = DPRQuestionEncoder.from_pretrained(q_checkpoint)
        q_tok = DPRQuestionEncoderTokenizerFast.from_pretrained(
            q_checkpoint
        )

        return FaissRetrieverOptions(
            ctx_encoder=ctx_model,
            ctx_tokenizer=ctx_tok,
            q_encoder=q_model,
            q_tokenizer=q_tok,
            embedding_path=embedding_path,
            lm="dpr"
        )

    @staticmethod
    def longformer(embedding_path: str):
        """Build options that share one Longformer model for both roles.

        The same encoder and tokenizer objects embed contexts and questions.

        Args:
            embedding_path: where the FAISS index is stored.
        """
        checkpoint = "allenai/longformer-base-4096"
        shared_model = LongformerModel.from_pretrained(checkpoint)
        shared_tok = LongformerTokenizerFast.from_pretrained(checkpoint)

        return FaissRetrieverOptions(
            ctx_encoder=shared_model,
            ctx_tokenizer=shared_tok,
            q_encoder=shared_model,
            q_tokenizer=shared_tok,
            embedding_path=embedding_path,
            lm="longformer"
        )


class FaissRetriever(Retriever):
    """A class used to retrieve relevant documents based on some query.
    based on https://huggingface.co/docs/datasets/faiss_es#faiss.
    """

    def __init__(self, paragraphs: DatasetDict,
                 options: FaissRetrieverOptions) -> None:
        torch.set_grad_enabled(False)

        self.lm = options.lm

        # Context encoding and tokenization
        self.ctx_encoder = options.ctx_encoder
        self.ctx_tokenizer = options.ctx_tokenizer

        # Question encoding and tokenization
        self.q_encoder = options.q_encoder
        self.q_tokenizer = options.q_tokenizer

        self.paragraphs = paragraphs
        self.embedding_path = options.embedding_path

        self.index = self._init_index()

    def _embed_question(self, q):
        match self.lm:
            case "dpr":
                tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
                return self.q_encoder(**tok)[0][0].numpy()
            case "longformer":
                tok = self.q_tokenizer(q, return_tensors="pt")
                return self.q_encoder(**tok).last_hidden_state[0][0].numpy()

    def _embed_context(self, row):
        p = row["text"]

        match self.lm:
            case "dpr":
                tok = self.ctx_tokenizer(
                    p, return_tensors="pt", truncation=True)
                enc = self.ctx_encoder(**tok)[0][0].numpy()
                return {"embeddings": enc}
            case "longformer":
                tok = self.ctx_tokenizer(p, return_tensors="pt")
                enc = self.ctx_encoder(**tok).last_hidden_state[0][0].numpy()
                return {"embeddings": enc}

    def _init_index(
            self,
            force_new_embedding: bool = False):

        ds = self.paragraphs["train"]
        ds = ds.map(remove_formulas)

        if not force_new_embedding and os.path.exists(self.embedding_path):
            ds.load_faiss_index(
                'embeddings', self.embedding_path)  # type: ignore
            return ds
        else:
            # Add FAISS embeddings
            index = ds.map(self._embed_context)  # type: ignore

            index.add_faiss_index(column="embeddings")

            # save dataset w/ embeddings
            os.makedirs("./src/models/", exist_ok=True)
            index.save_faiss_index(
                "embeddings", self.embedding_path)

            return index

    @timeit("faissretriever.retrieve")
    def retrieve(self, query: str, k: int = 5) -> RetrieveType:
        question_embedding = self._embed_question(query)
        scores, results = self.index.get_nearest_examples(
            "embeddings", question_embedding, k=k
        )

        return scores, results