sagar007 committed · Commit 68ce9a1 (verified) · 1 Parent(s): 80d0076

Update app.py

Files changed (1): app.py (+13 -7)
app.py CHANGED
@@ -3,15 +3,19 @@ import torch
 from tqdm import tqdm
 import transformers
 
+# Check for GPU availability
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+
 # Load the model pipeline
 pipe = transformers.pipeline(
     model='sarvamai/shuka_v1',
     trust_remote_code=True,
-    device='cpu',  # Force CPU usage
-    torch_dtype=torch.float32  # Use float32 instead of bfloat16
+    device=device,
+    torch_dtype=torch.float16 if device.type == 'cuda' else torch.float32
 )
 
-def process_audio_batched(audio_file, system_prompt, user_prompt, batch_size=2, segment_length=5):  # Reduced batch_size and segment_length
+def process_audio_batched(audio_file, system_prompt, user_prompt, batch_size=4, segment_length=10):
     # Load audio
     audio, sr = librosa.load(audio_file, sr=16000)
 
@@ -32,12 +36,14 @@ def process_audio_batched(audio_file, system_prompt, user_prompt, batch_size=2,
             {'role': 'user', 'content': f'<|audio|>{user_prompt}'}
         ]
 
-        batch_results = pipe([{'audio': seg, 'turns': turns, 'sampling_rate': sr} for seg in batch], max_new_tokens=512)
+        # Move batch to GPU if available
+        batch_gpu = [torch.tensor(seg, device=device) for seg in batch]
+
+        batch_results = pipe([{'audio': seg, 'turns': turns, 'sampling_rate': sr} for seg in batch_gpu], max_new_tokens=512)
         full_result.extend([result[0]['generated_text'] for result in batch_results])
 
-        # Clear GPU memory if using GPU
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        # Clear GPU memory
+        torch.cuda.empty_cache()
 
     # Combine results
     return ' '.join(full_result)
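For context, a minimal sketch of how the whole function plausibly reads after this commit. Everything outside the two hunks above is inferred, not shown in the diff: the librosa import, the segmentation loop, the system turn inside turns, and the tqdm progress bar are assumptions reconstructed from the visible context lines.

import librosa
import torch
import transformers
from tqdm import tqdm

# Device and pipeline setup, as added by this commit
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

pipe = transformers.pipeline(
    model='sarvamai/shuka_v1',
    trust_remote_code=True,
    device=device,
    torch_dtype=torch.float16 if device.type == 'cuda' else torch.float32
)

def process_audio_batched(audio_file, system_prompt, user_prompt, batch_size=4, segment_length=10):
    # Load audio at the 16 kHz rate the model expects
    audio, sr = librosa.load(audio_file, sr=16000)

    # Split into fixed-length segments (assumed: this loop lies outside the diff context)
    samples_per_segment = segment_length * sr
    segments = [audio[i:i + samples_per_segment]
                for i in range(0, len(audio), samples_per_segment)]

    full_result = []
    for i in tqdm(range(0, len(segments), batch_size)):
        batch = segments[i:i + batch_size]
        # The system turn is assumed; the diff only shows the user turn as context
        turns = [
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': f'<|audio|>{user_prompt}'}
        ]

        # Move batch to GPU if available (as in the commit)
        batch_gpu = [torch.tensor(seg, device=device) for seg in batch]

        batch_results = pipe([{'audio': seg, 'turns': turns, 'sampling_rate': sr} for seg in batch_gpu],
                             max_new_tokens=512)
        full_result.extend([result[0]['generated_text'] for result in batch_results])

        # Clear GPU memory (a no-op when CUDA is unavailable)
        torch.cuda.empty_cache()

    # Combine results
    return ' '.join(full_result)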
 
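A hypothetical invocation, with a placeholder file name and prompts; the commit does not show how the Space calls this function (e.g. via a Gradio interface):

result = process_audio_batched(
    'sample.wav',
    system_prompt='You are a helpful assistant. Describe the audio you hear.',
    user_prompt='Summarize this recording.'
)
print(result)

One note on the change itself: pre-moving each segment to the GPU with torch.tensor(seg, device=device) may be redundant, since transformers pipelines typically accept NumPy arrays and handle device placement during preprocessing; whether shuka_v1's remote pipeline code accepts GPU tensors as its 'audio' input is not confirmed here, so the plain NumPy batch is the safer fallback if this raises errors. Dropping the torch.cuda.is_available() guard around empty_cache() is harmless, as that call is a no-op when CUDA is not initialized.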