File size: 2,657 Bytes
98333ca
e2f65f6
 
 
86fab4a
98333ca
e2f65f6
 
 
 
 
 
 
98333ca
e2f65f6
 
 
 
 
fbc6758
86fab4a
fbc6758
400fc00
fbc6758
 
 
 
 
 
 
400fc00
 
 
 
 
 
 
 
86fab4a
 
 
 
e2f65f6
fbc6758
 
 
 
 
 
86fab4a
e2f65f6
 
 
 
 
fbc6758
 
 
 
 
86fab4a
e2f65f6
 
 
 
fbc6758
e2f65f6
 
fbc6758
e2f65f6
 
400fc00
e2f65f6
 
fbc6758
e2f65f6
 
 
400fc00
86fab4a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import gradio as gr
import transformers
import librosa
import torch
import numpy as np

# Build the Shuka pipeline once at import time so every request reuses it.
pipe = transformers.pipeline(
    model="sarvamai/shuka_v1",
    trust_remote_code=True,
    # A single CUDA probe selects both settings: first GPU with bfloat16
    # weights when available, otherwise CPU (-1) with the default dtype.
    **(
        dict(device=0, torch_dtype=torch.bfloat16)
        if torch.cuda.is_available()
        else dict(device=-1, torch_dtype=None)
    ),
)

# Sampling rate (Hz) the audio is resampled to before it is sent to the model.
_TARGET_SAMPLE_RATE = 16000


def _to_mono_float32(audio_data):
    """Return *audio_data* as a 1-D float32 waveform.

    Integer PCM is rescaled to the [-1.0, 1.0) range and multi-channel
    input is downmixed by averaging the channels.
    """
    samples = np.asarray(audio_data)

    if np.issubdtype(samples.dtype, np.integer):
        # Scale by the dtype's full-scale value (e.g. 32768 for int16);
        # a bare astype() would leave the raw integer magnitudes in place.
        info = np.iinfo(samples.dtype)
        full_scale = float(2 ** (info.bits - 1))
        samples = samples.astype(np.float32)
        if info.min == 0:
            # Unsigned PCM is offset-binary: silence sits at mid-range.
            samples -= full_scale
        samples /= full_scale
    else:
        samples = samples.astype(np.float32)

    # Drop singleton axes first so (n, 1) and (1, n) become plain 1-D.
    samples = np.squeeze(samples)
    if samples.ndim > 1:
        # Treat axis 0 as time (the (samples, channels) layout) and average
        # everything else; squeeze alone would leave stereo input 2-D.
        samples = samples.reshape(samples.shape[0], -1).mean(axis=1)

    # squeeze() turns a single-sample input into a 0-d array; keep it 1-D.
    return np.atleast_1d(samples)


def process_audio(audio):
    """
    Generate a text response from the Shuka model for a Gradio audio input.

    Args:
        audio: A ``(sample_rate, samples)`` tuple, or ``None`` when no audio
            was submitted.

    Returns:
        The model's text response, or a human-readable error message.
        Failures are reported in-band as strings so the text output can
        display them.
    """
    if audio is None:
        return "No audio provided. Please upload or record an audio file."

    try:
        sample_rate, audio_data = audio
    except (TypeError, ValueError) as e:
        # TypeError: not iterable; ValueError: wrong number of items.
        return f"Error processing audio input: {e}"

    if audio_data is None:
        return "Audio data is empty. Please try again with a valid audio file."

    # Convert before the emptiness check so scalars and arbitrary sequences
    # are handled uniformly (len() would raise on a scalar).
    audio_data = np.asarray(audio_data)
    if audio_data.size == 0:
        return "Audio data is empty. Please try again with a valid audio file."

    try:
        audio_data = _to_mono_float32(audio_data)
    except (TypeError, ValueError) as e:
        # Raised for dtypes that cannot be cast to float (e.g. object arrays).
        return f"Error processing audio input: {e}"

    # Resample to 16000 Hz if necessary.
    if sample_rate != _TARGET_SAMPLE_RATE:
        try:
            audio_data = librosa.resample(
                audio_data, orig_sr=sample_rate, target_sr=_TARGET_SAMPLE_RATE
            )
            sample_rate = _TARGET_SAMPLE_RATE
        except Exception as e:
            return f"Error during resampling: {e}"

    # Conversation turns; '<|audio|>' marks where the audio clip is inserted.
    turns = [
        {'role': 'system', 'content': 'Respond naturally and informatively.'},
        {'role': 'user', 'content': '<|audio|>'}
    ]

    try:
        result = pipe(
            {'audio': audio_data, 'turns': turns, 'sampling_rate': sample_rate},
            max_new_tokens=512,
        )
    except Exception as e:
        # UI boundary: surface any model failure as text instead of crashing.
        return f"Error during model processing: {e}"

    # Accept either a list of result dicts or a bare value such as a string.
    if isinstance(result, list) and result:
        first = result[0]
        if isinstance(first, dict):
            return first.get('generated_text', '')
        return str(first)
    return str(result)

# Build the web UI: a single audio input wired to process_audio, text output.
iface = gr.Interface(
    title="Sarvam AI Shuka Voice Demo",
    description="Upload an audio file and get a response using Sarvam AI's Shuka model.",
    fn=process_audio,
    # type="numpy" hands the callback a (sample_rate, samples) tuple.
    inputs=gr.Audio(type="numpy"),
    outputs="text",
)

if __name__ == "__main__":
    # Start the server only when run as a script: listen on port 7861
    # (not Gradio's default) and request a public share link.
    iface.launch(server_port=7861, share=True)