shenjingwen committed on
Commit
d3c7abc
Β·
verified Β·
1 Parent(s): 295cea1

Update vector_search.py

Browse files
Files changed (1) hide show
  1. vector_search.py +4 -2
vector_search.py CHANGED
@@ -9,6 +9,8 @@ from transformers import T5Tokenizer, T5ForConditionalGeneration
9
 
10
 
11
  class HybridVectorSearch:
 
 
12
  cuda_device = torch.device("cpu")
13
  sparse_model = "naver/splade-v3"
14
  tokenizer = AutoTokenizer.from_pretrained(sparse_model)
@@ -19,7 +21,7 @@ class HybridVectorSearch:
19
 
20
  model_name_t5 = "Falconsai/text_summarization" # "t5-small"
21
  tokenizer_t5 = T5Tokenizer.from_pretrained(model_name_t5)
22
- model_t5 = T5ForConditionalGeneration.from_pretrained(model_name_t5).to("cuda")
23
 
24
  client = QdrantClient(url="http://localhost:6333")
25
  earnings_collection = "earnings_calls"
@@ -54,7 +56,7 @@ class HybridVectorSearch:
54
  def summary(text: str):
55
  inputs = HybridVectorSearch.tokenizer_t5.encode(
56
  f"summarize: {text}", return_tensors="pt", max_length=1024, truncation=True
57
- ).to("cuda")
58
  summary_ids = HybridVectorSearch.model_t5.generate(
59
  inputs,
60
  max_length=512,
 
9
 
10
 
11
  class HybridVectorSearch:
12
+ # dd="cuda"
13
+ dd="cpu"
14
  cuda_device = torch.device("cpu")
15
  sparse_model = "naver/splade-v3"
16
  tokenizer = AutoTokenizer.from_pretrained(sparse_model)
 
21
 
22
  model_name_t5 = "Falconsai/text_summarization" # "t5-small"
23
  tokenizer_t5 = T5Tokenizer.from_pretrained(model_name_t5)
24
+ model_t5 = T5ForConditionalGeneration.from_pretrained(model_name_t5).to(dd)
25
 
26
  client = QdrantClient(url="http://localhost:6333")
27
  earnings_collection = "earnings_calls"
 
56
  def summary(text: str):
57
  inputs = HybridVectorSearch.tokenizer_t5.encode(
58
  f"summarize: {text}", return_tensors="pt", max_length=1024, truncation=True
59
+ ).to(dd)
60
  summary_ids = HybridVectorSearch.model_t5.generate(
61
  inputs,
62
  max_length=512,