Upload 2 files
- vector_search.py +140 -0
- whisper_asr.py +102 -0
vector_search.py
ADDED
@@ -0,0 +1,140 @@
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
from transformers import T5Tokenizer, T5ForConditionalGeneration
from langchain_text_splitters import SpacyTextSplitter
from sentence_transformers import SentenceTransformer
from typing import Dict, List
from qdrant_client import http, models, QdrantClient


class HybridVectorSearch:
    # SPLADE sparse encoder runs on CPU here; point this at a GPU device if one is available.
    cuda_device = torch.device("cpu")
    sparse_model = "naver/splade-v3"
    tokenizer = AutoTokenizer.from_pretrained(sparse_model)
    model = AutoModelForMaskedLM.from_pretrained(sparse_model).to(cuda_device)

    text_splitter = SpacyTextSplitter(chunk_size=1000)
    dense_encoder = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")

    # T5 summarization model (runs on GPU).
    model_name_t5 = "Falconsai/text_summarization"  # "t5-small"
    tokenizer_t5 = T5Tokenizer.from_pretrained(model_name_t5)
    model_t5 = T5ForConditionalGeneration.from_pretrained(model_name_t5).to("cuda")

    client = QdrantClient(url="http://localhost:6333")
    earnings_collection = "earnings_calls"

    @staticmethod
    def reciprocal_rank_fusion(
        responses: List[List[http.models.ScoredPoint]], limit: int = 10
    ) -> List[http.models.ScoredPoint]:
        def compute_score(pos: int) -> float:
            # The constant mitigates the impact of high rankings by outlier systems.
            ranking_constant = 2
            return 1 / (ranking_constant + pos)

        scores: Dict[http.models.ExtendedPointId, float] = {}
        point_pile = {}
        for response in responses:
            for i, scored_point in enumerate(response):
                if scored_point.id in scores:
                    scores[scored_point.id] += compute_score(i)
                else:
                    point_pile[scored_point.id] = scored_point
                    scores[scored_point.id] = compute_score(i)

        sorted_scores = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        sorted_points = []
        for point_id, score in sorted_scores[:limit]:
            point = point_pile[point_id]
            point.score = score
            sorted_points.append(point)
        return sorted_points

    @staticmethod
    def summary(text: str):
        inputs = HybridVectorSearch.tokenizer_t5.encode(
            f"summarize: {text}", return_tensors="pt", max_length=1024, truncation=True
        ).to("cuda")
        summary_ids = HybridVectorSearch.model_t5.generate(
            inputs,
            max_length=512,
            min_length=100,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )
        summary = HybridVectorSearch.tokenizer_t5.decode(
            summary_ids[0], skip_special_tokens=True
        )
        return summary

    @staticmethod
    def compute_vector(text):
        tokens = HybridVectorSearch.tokenizer(text, return_tensors="pt").to(
            HybridVectorSearch.cuda_device
        )
        split_texts = []
        # If the text exceeds the 512-token context, summarize it and encode the summary;
        # the chunked original text is returned separately for downstream use.
        if len(tokens["input_ids"][0]) >= 512:
            summary = HybridVectorSearch.summary(text)
            split_texts = HybridVectorSearch.text_splitter.split_text(text)
            tokens = HybridVectorSearch.tokenizer(summary, return_tensors="pt").to(
                HybridVectorSearch.cuda_device
            )

        # SPLADE sparse vector: max over tokens of log(1 + ReLU(logits)), masked by attention.
        output = HybridVectorSearch.model(**tokens)
        logits, attention_mask = output.logits, tokens.attention_mask
        relu_log = torch.log(1 + torch.relu(logits))
        weighted_log = relu_log * attention_mask.unsqueeze(-1)
        max_val, _ = torch.max(weighted_log, dim=1)
        vec = max_val.squeeze()

        return vec, tokens, split_texts

    @staticmethod
    def search(query_text: str, symbol="AMD"):
        vectors, tokens, split_texts = HybridVectorSearch.compute_vector(query_text)
        indices = vectors.cpu().nonzero().numpy().flatten()
        values = vectors.cpu().detach().numpy()[indices]

        sparse_query_vector = models.SparseVector(indices=indices, values=values)

        query_vector = HybridVectorSearch.dense_encoder.encode(query_text).tolist()
        limit = 3

        dense_request = models.SearchRequest(
            vector=models.NamedVector(name="dense_vector", vector=query_vector),
            limit=limit,
            with_payload=True,
        )
        sparse_request = models.SearchRequest(
            vector=models.NamedSparseVector(
                name="sparse_vector", vector=sparse_query_vector
            ),
            limit=limit,
            with_payload=True,
        )

        # Run the dense and sparse searches in one batch, then fuse the two rankings.
        (dense_request_response, sparse_request_response) = (
            HybridVectorSearch.client.search_batch(
                collection_name=HybridVectorSearch.earnings_collection,
                requests=[dense_request, sparse_request],
            )
        )
        ranked_search_response = HybridVectorSearch.reciprocal_rank_fusion(
            [dense_request_response, sparse_request_response], limit=10
        )

        # Concatenated text is currently unused; the ranked points are returned directly.
        search_response = ""
        for search_result in ranked_search_response:
            search_response += search_result.payload["conversation"] + "\n"
        return ranked_search_response

    @staticmethod
    def chat_search(query: str, chat_history):
        result = HybridVectorSearch.search(query)
        chat_history.append((query, "Search Results"))
        for search_result in result[:3]:
            text = search_result.payload["conversation"]
            summary = HybridVectorSearch.summary(text) + f'\n```\n{text} \n```'
            chat_history.append((None, summary))
        return "", chat_history
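A minimal usage sketch for HybridVectorSearch (not part of this commit): it assumes a Qdrant instance is running at localhost:6333 and that the earnings_calls collection has already been populated with points carrying dense_vector and sparse_vector named vectors and a conversation payload field; the query string is only an example.

# Hypothetical driver script; the Qdrant collection and its data are assumed to exist.
from vector_search import HybridVectorSearch

results = HybridVectorSearch.search("What was said about data center revenue?")
for point in results[:3]:
    # Each result is a qdrant ScoredPoint whose score is the fused RRF score.
    print(f"{point.score:.3f}", point.payload["conversation"][:120])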
whisper_asr.py
ADDED
@@ -0,0 +1,102 @@
import numpy as np
import os
import time

import whisperx
from scipy.signal import resample


class WhisperAutomaticSpeechRecognizer:
    device = "cuda"
    compute_type = "int8"  # use a higher-precision type (e.g. "float16") if more GPU memory is available
    batch_size = 4
    model = whisperx.load_model(
        "medium",
        device,
        language="en",
        compute_type=compute_type,
        # ASR options required by recent whisperx/faster-whisper versions; adjust as needed.
        asr_options={
            "max_new_tokens": 448,
            "clip_timestamps": True,
            "hallucination_silence_threshold": 0.2,
        },
    )
    diarize_model = whisperx.DiarizationPipeline(
        use_auth_token=os.environ.get("HF_TOKEN"), device="cuda"
    )
    existing_speaker = None

    @staticmethod
    def downsample_audio_scipy(audio: np.ndarray, original_rate, target_rate=16000):
        if original_rate == target_rate:
            return audio

        # Mix down to mono if the audio has more than one channel.
        if audio.ndim > 1:
            audio = np.mean(audio, axis=1)

        if len(audio.shape) != 1:
            raise ValueError("Input audio must have only one channel.")

        # Resample to the target rate.
        num_samples = int(len(audio) * target_rate / original_rate)
        downsampled_audio = resample(audio, num_samples)

        return downsampled_audio

    @staticmethod
    def transcribe_with_diarization_file(filepath: str):
        audio = whisperx.load_audio(filepath, 16000)
        return WhisperAutomaticSpeechRecognizer.transcribe_with_diarization(
            (16000, audio), None, "", False
        )

    @staticmethod
    def transcribe_with_diarization(
        stream, full_stream, full_transcript, streaming=True
    ):
        start_time = time.time()
        sr, y = stream
        if streaming:
            # Streamed audio arrives as int16 at the microphone sample rate;
            # convert to 16 kHz float32 in [-1, 1] for whisperx.
            y = WhisperAutomaticSpeechRecognizer.downsample_audio_scipy(y, sr)
            y = y.astype(np.float32)
            y /= 32768.0

        if full_transcript is None:
            full_transcript = ""
        transcribe_result = WhisperAutomaticSpeechRecognizer.model.transcribe(
            y, batch_size=WhisperAutomaticSpeechRecognizer.batch_size
        )
        diarize_segments = WhisperAutomaticSpeechRecognizer.diarize_model(y)

        diarize_result = whisperx.assign_word_speakers(
            diarize_segments, transcribe_result
        )

        new_transcript = ""
        for segment in diarize_result["segments"]:
            default_first_speaker = "SPEAKER_00"
            # Segments with no assigned speaker fall back to the first speaker label.
            current_speaker = segment.get("speaker", default_first_speaker)
            if WhisperAutomaticSpeechRecognizer.existing_speaker is None:
                WhisperAutomaticSpeechRecognizer.existing_speaker = current_speaker
                new_transcript += (
                    f"\n {WhisperAutomaticSpeechRecognizer.existing_speaker} - "
                )
            # Emit a new speaker label only when the speaker actually changes.
            if (
                current_speaker != WhisperAutomaticSpeechRecognizer.existing_speaker
                and current_speaker != default_first_speaker
            ):
                WhisperAutomaticSpeechRecognizer.existing_speaker = current_speaker
                new_transcript += (
                    f"\n {WhisperAutomaticSpeechRecognizer.existing_speaker} - "
                )
            new_transcript += segment["text"]
        full_transcript += new_transcript
        end_time = time.time()
        if streaming:
            # Pace the streaming loop to roughly one update every 5 seconds,
            # without ever sleeping for a negative duration.
            time.sleep(max(0.0, 5 - (end_time - start_time)))
        return full_transcript, stream, full_transcript
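A minimal usage sketch for WhisperAutomaticSpeechRecognizer (not part of this commit), assuming HF_TOKEN is set in the environment for the pyannote diarization models and that earnings_call.wav is a hypothetical local audio file:

# Hypothetical driver script; HF_TOKEN and the audio file path are assumptions.
import os
from whisper_asr import WhisperAutomaticSpeechRecognizer

assert os.environ.get("HF_TOKEN"), "pyannote diarization requires a Hugging Face token"
transcript, _, _ = WhisperAutomaticSpeechRecognizer.transcribe_with_diarization_file(
    "earnings_call.wav"
)
print(transcript)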