Update app.py
app.py CHANGED

@@ -3,15 +3,19 @@ import torch
 from tqdm import tqdm
 import transformers
 
+# Check for GPU availability
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+
 # Load the model pipeline
 pipe = transformers.pipeline(
     model='sarvamai/shuka_v1',
     trust_remote_code=True,
-    device=
-    torch_dtype=torch.
+    device=device,
+    torch_dtype=torch.float16 if device.type == 'cuda' else torch.float32
 )
 
-def process_audio_batched(audio_file, system_prompt, user_prompt, batch_size=
+def process_audio_batched(audio_file, system_prompt, user_prompt, batch_size=4, segment_length=10):
     # Load audio
     audio, sr = librosa.load(audio_file, sr=16000)
 
@@ -32,12 +36,14 @@ def process_audio_batched(audio_file, system_prompt, user_prompt, batch_size=2,
             {'role': 'user', 'content': f'<|audio|>{user_prompt}'}
         ]
 
-
+        # Move batch to GPU if available
+        batch_gpu = [torch.tensor(seg, device=device) for seg in batch]
+
+        batch_results = pipe([{'audio': seg, 'turns': turns, 'sampling_rate': sr} for seg in batch_gpu], max_new_tokens=512)
         full_result.extend([result[0]['generated_text'] for result in batch_results])
 
-        # Clear GPU memory
-
-        torch.cuda.empty_cache()
+        # Clear GPU memory
+        torch.cuda.empty_cache()
 
     # Combine results
     return ' '.join(full_result)
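
For context, here is a minimal sketch of how the updated setup and batching loop fit together after this commit, assembled from the two hunks above. The parts the diff does not show are assumptions: the segment-splitting code implied by the segment_length parameter, the full_result initialization, the shape of the batching loop, and the system turn in turns are reconstructions (with hypothetical names like seg_samples), not the Space's actual code.

import torch
import librosa
import transformers
from tqdm import tqdm

# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the model pipeline; fp16 is only useful on CUDA
pipe = transformers.pipeline(
    model='sarvamai/shuka_v1',
    trust_remote_code=True,
    device=device,
    torch_dtype=torch.float16 if device.type == 'cuda' else torch.float32
)

def process_audio_batched(audio_file, system_prompt, user_prompt, batch_size=4, segment_length=10):
    # Load audio at the 16 kHz rate the model expects
    audio, sr = librosa.load(audio_file, sr=16000)

    # Split into fixed-length segments (assumed; the diff does not show this part)
    seg_samples = segment_length * sr
    segments = [audio[i:i + seg_samples] for i in range(0, len(audio), seg_samples)]

    full_result = []
    for start in tqdm(range(0, len(segments), batch_size)):
        batch = segments[start:start + batch_size]

        # Conversation template; the system turn is an assumption,
        # only the user turn appears in the diff
        turns = [
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': f'<|audio|>{user_prompt}'}
        ]

        # Move batch to GPU if available
        batch_gpu = [torch.tensor(seg, device=device) for seg in batch]

        # One pipeline call per batch, as in the diff
        batch_results = pipe(
            [{'audio': seg, 'turns': turns, 'sampling_rate': sr} for seg in batch_gpu],
            max_new_tokens=512
        )
        full_result.extend([result[0]['generated_text'] for result in batch_results])

        # Clear GPU memory between batches
        torch.cuda.empty_cache()

    # Combine results
    return ' '.join(full_result)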
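
A hypothetical call, with placeholder file name and prompts (the Space presumably wires this into its UI instead):

result = process_audio_batched(
    "sample.wav",
    system_prompt="You are a helpful assistant.",
    user_prompt="Transcribe this audio.",
    batch_size=4,
    segment_length=10
)
print(result)

Processing fixed-length segments in batches and calling torch.cuda.empty_cache() after each batch keeps peak GPU memory roughly proportional to batch_size rather than to the length of the input audio file.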