import os
# Disable hf_transfer for safer downloading
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
import gradio as gr
import requests
from sentence_transformers import SentenceTransformer, util
import torch
import json
import urllib.parse
import soundfile as sf
import time

# Read the Hugging Face API token from the environment (e.g. a Space secret named "HF")
HF_API_TOKEN = os.getenv("HF")

# Hugging Face Inference API endpoints for Whisper (ASR) and Tamil-LLaMA (generation)
WHISPER_API_URL = "https://api-inference.huggingface.co/models/openai/whisper-small"
LLAMA_API_URL = "https://api-inference.huggingface.co/models/abhinand/tamil-llama-7b-instruct-v0.2"

# Load SentenceTransformer model for retrieval
retriever_model = SentenceTransformer("distiluse-base-multilingual-cased-v2")

# Load dataset
with open("qa_dataset.json", "r", encoding="utf-8") as f:
    qa_data = json.load(f)
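
# Precompute embeddings for every stored question once at startup, so each
# request only needs to encode the user's query. (A small optimization over
# re-encoding the whole dataset per call; retrieval behavior is unchanged.)
qa_embeddings = retriever_model.encode(
    [qa["question"] for qa in qa_data], convert_to_tensor=True
)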

# Poll the Inference API until the model reports it has loaded (or we time out)
def wait_for_model_ready(model_url, headers, timeout=300):
    start_time = time.time()
    while time.time() - start_time < timeout:
        # Probe the endpoint; while the model is cold, the API answers with
        # an {"error": "... is currently loading ..."} JSON body
        response = requests.get(model_url, headers=headers)
        try:
            result = response.json()
        except ValueError:
            result = {}

        if not (isinstance(result, dict) and "loading" in str(result.get("error", "")).lower()):
            print("✅ Model is ready!")
            return True

        print("⏳ Model is still loading, waiting 10 seconds...")
        time.sleep(10)

    print("❌ Model did not become ready in time.")
    return False  # timed out
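
# Alternative (a sketch, not used below): the Inference API also accepts an
# "options" field that asks the server to hold the request until the model is
# loaded, which can replace this manual polling:
#
#     payload = {"inputs": ..., "options": {"wait_for_model": True}}
#     requests.post(model_url, headers=headers, json=payload)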

# Transcribe an audio file with the hosted Whisper model
def transcribe_audio(audio_file):
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}

    # Wait for Whisper model to be ready
    if not wait_for_model_ready(WHISPER_API_URL, headers):
        return "Error: Whisper model did not load in time. Please try again later."

    # Now send the audio after model is ready
    with open(audio_file, "rb") as f:
        response = requests.post(WHISPER_API_URL, headers=headers, data=f)

    try:
        result = response.json()
    except ValueError:
        return "Error: Unexpected response from the transcription service."
    print(result)  # log the raw API response

    return result.get("text", "Error: No transcription text returned.")

# Build a TTS audio URL via the unofficial Google Translate endpoint (Tamil voice by default)
def get_tts_audio_url(text, lang="ta"):
    # URL-encode the text so Tamil characters and spaces survive in the query string
    safe_text = urllib.parse.quote(text)
    return f"https://translate.google.com/translate_tts?ie=UTF-8&q={safe_text}&tl={lang}&client=tw-ob"

# Retrieve the closest stored Q&A pair and ask the LLaMA model to answer with that context
def get_bot_response(query, max_tokens=150, temperature=0.7, top_p=0.95):
    query_embedding = retriever_model.encode(query, convert_to_tensor=True)

    # Compare the query against the question embeddings precomputed at startup
    scores = util.pytorch_cos_sim(query_embedding, qa_embeddings)
    best_idx = torch.argmax(scores).item()

    top_qa = qa_data[best_idx]

    prompt = f"""நீ ஒரு அறிவாளியான தமிழ் உதவியாளர்.
    தகவல்கள்:
    கேள்வி: {top_qa['question']}
    பதில்: {top_qa['answer']}
    மேலே உள்ள தகவல்களைப் பயன்படுத்தி, தெளிவான மற்றும் சுருக்கமான பதிலை வழங்கவும்.
    உயர்கட்ட கேள்வி: {query}
    பதில்:"""

    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "temperature": 0.7,
            "max_new_tokens": 150,
            "return_full_text": False
        },
    }

    # The hosted model can be slow or still loading, so re-issue the request
    # until we get a result or give up (the original sent the POST only once,
    # so the "loading" branch could never make progress)
    start_time = time.time()
    max_wait_seconds = 180  # wait up to 3 minutes if necessary
    while True:
        try:
            response = requests.post(LLAMA_API_URL, headers=headers, json=payload, timeout=300)
            result = response.json()

            if isinstance(result, list) and "generated_text" in result[0]:
                return result[0]["generated_text"]
            elif isinstance(result, dict) and "loading" in str(result.get("error", "")).lower():
                if time.time() - start_time > max_wait_seconds:
                    return f"Error: Timeout while waiting for model prediction after {max_wait_seconds} seconds."
                print("⏳ Model is loading, waiting 10 seconds...")
                time.sleep(10)  # retry with a fresh request
            else:
                # Tamil fallback: "Sorry, I could not answer this question."
                return "மன்னிக்கவும், நான் இந்த கேள்விக்கு பதில் தர முடியவில்லை."

        except Exception as e:
            if time.time() - start_time > max_wait_seconds:
                return f"Error: Timeout while waiting for model prediction after {max_wait_seconds} seconds."
            print(f"Waiting for model to respond... {str(e)}")
            time.sleep(5)  # wait 5 seconds before retrying

# Gradio callback: take microphone audio and/or text, reply with Tamil text plus a TTS audio URL
def chatbot(audio, message, system_message, max_tokens, temperature, top_p):
    if audio is not None:
        sample_rate, audio_data = audio  # gr.Audio(type="numpy") yields (sample_rate, data)
        sf.write("temp.wav", audio_data, sample_rate)  # save to a temporary WAV for the API
        try:
            transcript = transcribe_audio("temp.wav")
            message = transcript  # prefer the transcribed speech over the text box
        except Exception as e:
            return f"Audio transcription failed: {str(e)}", None

    # Note: system_message is kept for the UI but is not injected into the
    # Tamil prompt that get_bot_response builds
    try:
        response = get_bot_response(message, max_tokens, temperature, top_p)
        audio_url = get_tts_audio_url(response)
        return response, audio_url
    except Exception as e:
        return f"Error in generating response: {str(e)}", None


# Define Gradio interface
demo = gr.Interface(
    fn=chatbot,
    inputs=[
        gr.Audio(type="numpy", label="Speak to the Bot"),  # Adjusted for microphone input
        gr.Textbox(value="How can I help you?", label="Text Input (optional)"),
        gr.Textbox(value="You are a friendly chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    outputs=[gr.Textbox(label="Response"), gr.Audio(label="Bot's Voice Response (Tamil)")],
    live=True,
)
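
# Note (a suggestion, not part of the original app): live=True re-runs the
# callback on every input change, which is expensive with remote API calls;
# enabling Gradio's request queue, e.g. demo.queue(), before launching is one option.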

if __name__ == "__main__":
    demo.launch()