shenjingwen committed on
Commit
ab176c3
·
verified ·
1 Parent(s): 19e123a

Upload 2 files

Files changed (2)
  1. vector_search.py +140 -0
  2. whisper_asr.py +102 -0
vector_search.py ADDED
@@ -0,0 +1,140 @@
+ import torch
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
+ from langchain_text_splitters import SpacyTextSplitter
+ from sentence_transformers import SentenceTransformer
+ from typing import Dict, List
+ from qdrant_client import http, models, QdrantClient
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
+
+
+ class HybridVectorSearch:
+     # SPLADE sparse encoder (kept on CPU despite the attribute name)
+     cuda_device = torch.device("cpu")
+     sparse_model = "naver/splade-v3"
+     tokenizer = AutoTokenizer.from_pretrained(sparse_model)
+     model = AutoModelForMaskedLM.from_pretrained(sparse_model).to(cuda_device)
+
+     text_splitter = SpacyTextSplitter(chunk_size=1000)
+     dense_encoder = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
+
+     # T5 summarizer used to shorten inputs that exceed the 512-token SPLADE limit
+     model_name_t5 = "Falconsai/text_summarization"  # alternative: "t5-small"
+     tokenizer_t5 = T5Tokenizer.from_pretrained(model_name_t5)
+     model_t5 = T5ForConditionalGeneration.from_pretrained(model_name_t5).to("cuda")
+
+     client = QdrantClient(url="http://localhost:6333")
+     earnings_collection = "earnings_calls"
+
+     @staticmethod
+     def reciprocal_rank_fusion(
+         responses: List[List[http.models.ScoredPoint]], limit: int = 10
+     ) -> List[http.models.ScoredPoint]:
+         def compute_score(pos: int) -> float:
+             # The constant mitigates the impact of high rankings by outlier systems.
+             ranking_constant = 2
+             return 1 / (ranking_constant + pos)
+
+         scores: Dict[http.models.ExtendedPointId, float] = {}
+         point_pile = {}
+         for response in responses:
+             for i, scored_point in enumerate(response):
+                 if scored_point.id in scores:
+                     scores[scored_point.id] += compute_score(i)
+                 else:
+                     point_pile[scored_point.id] = scored_point
+                     scores[scored_point.id] = compute_score(i)
+
+         sorted_scores = sorted(scores.items(), key=lambda item: item[1], reverse=True)
+         sorted_points = []
+         for point_id, score in sorted_scores[:limit]:
+             point = point_pile[point_id]
+             point.score = score
+             sorted_points.append(point)
+         return sorted_points
+
+     @staticmethod
+     def summary(text: str) -> str:
+         inputs = HybridVectorSearch.tokenizer_t5.encode(
+             f"summarize: {text}", return_tensors="pt", max_length=1024, truncation=True
+         ).to("cuda")
+         summary_ids = HybridVectorSearch.model_t5.generate(
+             inputs,
+             max_length=512,
+             min_length=100,
+             length_penalty=2.0,
+             num_beams=4,
+             early_stopping=True,
+         )
+         summary = HybridVectorSearch.tokenizer_t5.decode(
+             summary_ids[0], skip_special_tokens=True
+         )
+         return summary
+
+     @staticmethod
+     def compute_vector(text):
+         # Build a SPLADE sparse vector; texts longer than 512 tokens are summarized first.
+         tokens = HybridVectorSearch.tokenizer(text, return_tensors="pt").to(
+             HybridVectorSearch.cuda_device
+         )
+         split_texts = []
+         if len(tokens["input_ids"][0]) >= 512:
+             summary = HybridVectorSearch.summary(text)
+             split_texts = HybridVectorSearch.text_splitter.split_text(text)
+             tokens = HybridVectorSearch.tokenizer(summary, return_tensors="pt").to(
+                 HybridVectorSearch.cuda_device
+             )
+
+         output = HybridVectorSearch.model(**tokens)
+         logits, attention_mask = output.logits, tokens.attention_mask
+         relu_log = torch.log(1 + torch.relu(logits))
+         weighted_log = relu_log * attention_mask.unsqueeze(-1)
+         max_val, _ = torch.max(weighted_log, dim=1)
+         vec = max_val.squeeze()
+
+         return vec, tokens, split_texts
+
+     @staticmethod
+     def search(query_text: str, symbol="AMD"):
+         vectors, tokens, split_texts = HybridVectorSearch.compute_vector(query_text)
+         indices = vectors.cpu().nonzero().numpy().flatten()
+         values = vectors.cpu().detach().numpy()[indices]
+
+         sparse_query_vector = models.SparseVector(indices=indices, values=values)
+
+         query_vector = HybridVectorSearch.dense_encoder.encode(query_text).tolist()
+         limit = 3
+
+         dense_request = models.SearchRequest(
+             vector=models.NamedVector(name="dense_vector", vector=query_vector),
+             limit=limit,
+             with_payload=True,
+         )
+         sparse_request = models.SearchRequest(
+             vector=models.NamedSparseVector(
+                 name="sparse_vector", vector=sparse_query_vector
+             ),
+             limit=limit,
+             with_payload=True,
+         )
+
+         dense_request_response, sparse_request_response = (
+             HybridVectorSearch.client.search_batch(
+                 collection_name=HybridVectorSearch.earnings_collection,
+                 requests=[dense_request, sparse_request],
+             )
+         )
+         # Merge the dense and sparse result lists into a single ranking.
+         ranked_search_response = HybridVectorSearch.reciprocal_rank_fusion(
+             [dense_request_response, sparse_request_response], limit=10
+         )
+         return ranked_search_response
+
+     @staticmethod
+     def chat_search(query: str, chat_history):
+         result = HybridVectorSearch.search(query)
+         chat_history.append((query, "Search Results"))
+         for search_result in result[:3]:
+             text = search_result.payload["conversation"]
+             summary = HybridVectorSearch.summary(text) + f"\n```\n{text}\n```"
+             chat_history.append((None, summary))
+         return "", chat_history
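For reference, a minimal end-to-end sketch of how this class could be exercised. It assumes a local Qdrant instance on port 6333, a qdrant-client release that still exposes the search_batch-style API used above, and that the collection is created with the same vector names (dense_vector, sparse_vector) and conversation payload field that search() reads; the sample text and point id are placeholders.

# Hypothetical setup/usage sketch; the collection layout is inferred from search() above.
from qdrant_client import QdrantClient, models

from vector_search import HybridVectorSearch

client = QdrantClient(url="http://localhost:6333")

# Named dense + sparse vectors must match the names queried in search().
client.recreate_collection(
    collection_name="earnings_calls",
    vectors_config={
        # 384 dimensions = output size of all-MiniLM-L6-v2
        "dense_vector": models.VectorParams(size=384, distance=models.Distance.COSINE)
    },
    sparse_vectors_config={"sparse_vector": models.SparseVectorParams()},
)

text = "In the third quarter, data center revenue grew strongly year over year."  # placeholder
dense = HybridVectorSearch.dense_encoder.encode(text).tolist()
sparse_vec, _, _ = HybridVectorSearch.compute_vector(text)
idx = sparse_vec.cpu().nonzero().numpy().flatten()
vals = sparse_vec.cpu().detach().numpy()[idx]

client.upsert(
    collection_name="earnings_calls",
    points=[
        models.PointStruct(
            id=1,
            vector={
                "dense_vector": dense,
                "sparse_vector": models.SparseVector(
                    indices=idx.tolist(), values=vals.tolist()
                ),
            },
            payload={"conversation": text},
        )
    ],
)

# Hybrid query: dense and sparse hits are merged with reciprocal rank fusion.
for point in HybridVectorSearch.search("How did the data center business perform?"):
    print(point.score, point.payload["conversation"])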
whisper_asr.py ADDED
@@ -0,0 +1,102 @@
+ import os
+ import time
+
+ import numpy as np
+ import whisperx
+ from scipy.signal import resample
+
+
+ class WhisperAutomaticSpeechRecognizer:
+     device = "cuda"
+     compute_type = "int8"  # change to "float16" if more GPU memory is available
+     batch_size = 4
+     model = whisperx.load_model(
+         "medium",
+         device,
+         language="en",
+         compute_type=compute_type,
+         asr_options={
+             "max_new_tokens": 448,
+             "clip_timestamps": True,
+             "hallucination_silence_threshold": 0.2,
+         },  # adjust these values as needed
+     )
+     diarize_model = whisperx.DiarizationPipeline(
+         use_auth_token=os.environ.get("HF_TOKEN"), device="cuda"
+     )
+     existing_speaker = None
+
+     @staticmethod
+     def downsample_audio_scipy(audio: np.ndarray, original_rate, target_rate=16000):
+         if original_rate == target_rate:
+             return audio
+
+         # Mix multi-channel audio down to mono.
+         if audio.ndim > 1:
+             audio = np.mean(audio, axis=1)
+
+         # Check that the audio now has a single channel.
+         if len(audio.shape) != 1:
+             raise ValueError("Input audio must have only one channel.")
+
+         # Calculate the number of samples in the downsampled audio.
+         num_samples = int(len(audio) * target_rate / original_rate)
+         downsampled_audio = resample(audio, num_samples)
+
+         return downsampled_audio
+
+     @staticmethod
+     def transcribe_with_diarization_file(filepath: str):
+         audio = whisperx.load_audio(filepath, 16000)
+         return WhisperAutomaticSpeechRecognizer.transcribe_with_diarization(
+             (16000, audio), None, "", False
+         )
+
+     @staticmethod
+     def transcribe_with_diarization(
+         stream, full_stream, full_transcript, streaming=True
+     ):
+         start_time = time.time()
+         sr, y = stream
+         if streaming:
+             # Streaming input arrives as 16-bit PCM; resample to 16 kHz and scale to [-1, 1] floats.
+             y = WhisperAutomaticSpeechRecognizer.downsample_audio_scipy(y, sr)
+             y = y.astype(np.float32)
+             y /= 32768.0
+
+         if full_transcript is None:
+             full_transcript = ""
+         transcribe_result = WhisperAutomaticSpeechRecognizer.model.transcribe(
+             y, batch_size=WhisperAutomaticSpeechRecognizer.batch_size
+         )
+         diarize_segments = WhisperAutomaticSpeechRecognizer.diarize_model(y)
+
+         diarize_result = whisperx.assign_word_speakers(
+             diarize_segments, transcribe_result
+         )
+
+         new_transcript = ""
+         default_first_speaker = "SPEAKER_00"
+         for segment in diarize_result["segments"]:
+             # Segments without a diarization label fall back to the first speaker.
+             current_speaker = segment.get("speaker", default_first_speaker)
+             if WhisperAutomaticSpeechRecognizer.existing_speaker is None:
+                 WhisperAutomaticSpeechRecognizer.existing_speaker = current_speaker
+                 new_transcript += (
+                     f"\n {WhisperAutomaticSpeechRecognizer.existing_speaker} - "
+                 )
+             if (
+                 current_speaker != WhisperAutomaticSpeechRecognizer.existing_speaker
+                 and current_speaker != default_first_speaker
+             ):
+                 WhisperAutomaticSpeechRecognizer.existing_speaker = current_speaker
+                 new_transcript += (
+                     f"\n {WhisperAutomaticSpeechRecognizer.existing_speaker} - "
+                 )
+             new_transcript += segment["text"]
+         full_transcript += new_transcript
+         end_time = time.time()
+         if streaming:
+             # Pace streaming callbacks to roughly 5-second windows; never sleep a negative amount.
+             time.sleep(max(0.0, 5 - (end_time - start_time)))
+         return full_transcript, stream, full_transcript
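Similarly, a minimal usage sketch for the transcriber. It assumes a CUDA-capable GPU (the class pins both WhisperX and the diarization pipeline to cuda), an HF_TOKEN environment variable accepted by the pyannote diarization models, and a local recording; the WAV path is a placeholder.

# Hypothetical usage sketch; "earnings_call.wav" is a placeholder path.
from whisper_asr import WhisperAutomaticSpeechRecognizer

# The file helper loads audio at 16 kHz and runs the non-streaming path,
# returning (transcript, stream, transcript).
transcript, _, _ = WhisperAutomaticSpeechRecognizer.transcribe_with_diarization_file(
    "earnings_call.wav"
)
print(transcript)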