Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| import os.path as osp | |
| import torch | |
| from typing import Any | |
| from src.models.model import ModelForSemiStructQA | |
| from src.models.vss import VSS | |
| from src.tools.api import get_openai_embeddings | |
| from src.tools.process_text import chunk_text | |
| class MultiVSS(ModelForSemiStructQA): | |
| def __init__(self, | |
| kb, | |
| query_emb_dir, | |
| candidates_emb_dir, | |
| chunk_emb_dir, | |
| emb_model='text-embedding-ada-002', | |
| aggregate='top3_avg', | |
| max_k=50, | |
| chunk_size=256): | |
| ''' | |
| Multivector Vector Similarity Search | |
| Args: | |
| kb (src.benchmarks.semistruct.SemiStruct): kb | |
| query_emb_dir (str): directory to query embeddings | |
| candidates_emb_dir (str): directory to candidate embeddings | |
| chunk_emb_dir (str): directory to chunk embeddings | |
| ''' | |
| super().__init__(kb) | |
| self.kb = kb | |
| self.aggregate = aggregate # 'max', 'avg', 'top{k}_avg' | |
| self.max_k = max_k | |
| self.chunk_size = chunk_size | |
| self.emb_model = emb_model | |
| self.query_emb_dir = query_emb_dir | |
| self.chunk_emb_dir = chunk_emb_dir | |
| self.candidates_emb_dir = candidates_emb_dir | |
| self.parent_vss = VSS(kb, query_emb_dir, candidates_emb_dir, emb_model=emb_model) | |
| def forward(self, | |
| query, | |
| query_id, | |
| **kwargs: Any): | |
| query_emb = self._get_query_emb(query, query_id) | |
| initial_score_dict = self.parent_vss(query, query_id) | |
| node_ids = list(initial_score_dict.keys()) | |
| node_scores = list(initial_score_dict.values()) | |
| # get the ids with top k highest scores | |
| top_k_idx = torch.topk(torch.FloatTensor(node_scores), | |
| min(self.max_k, len(node_scores)), | |
| dim=-1 | |
| ).indices.view(-1).tolist() | |
| top_k_node_ids = [node_ids[i] for i in top_k_idx] | |
| pred_dict = {} | |
| for node_id in top_k_node_ids: | |
| doc = self.kb.get_doc_info(node_id, add_rel=True, compact=True) | |
| chunks = chunk_text(doc, chunk_size=self.chunk_size) | |
| chunk_path = osp.join(self.chunk_emb_dir, f'{node_id}_size={self.chunk_size}.pt') | |
| if osp.exists(chunk_path): | |
| chunk_embs = torch.load(chunk_path) | |
| else: | |
| chunk_embs = get_openai_embeddings(chunks, | |
| model=self.emb_model) | |
| torch.save(chunk_embs, chunk_path) | |
| print(f'chunk_embs.shape: {chunk_embs.shape}') | |
| similarity = torch.matmul(query_emb.cuda(), chunk_embs.cuda().T).cpu().view(-1) | |
| if self.aggregate == 'max': | |
| pred_dict[node_id] = torch.max(similarity).item() | |
| elif self.aggregate == 'avg': | |
| pred_dict[node_id] = torch.mean(similarity).item() | |
| elif 'top' in self.aggregate: | |
| k = int(self.aggregate.split('_')[0][len('top'):]) | |
| pred_dict[node_id] = torch.mean(torch.topk(similarity, k=min(k, len(chunks)), dim=-1).values).item() | |
| return pred_dict |