import os
# Disable hf_transfer so downloads use huggingface_hub's default pure-Python
# path; set before the Hugging Face imports below so the flag is picked up
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
import gradio as gr
import requests
from sentence_transformers import SentenceTransformer, util
import torch
import json
import urllib.parse
import soundfile as sf
import time

# Read the Hugging Face API token from the environment rather than hard-coding it
HF_API_TOKEN = os.getenv("HF")

# Serverless Inference API endpoints for Whisper (speech-to-text) and Tamil LLaMA (generation)
WHISPER_API_URL = "https://api-inference.huggingface.co/models/openai/whisper-small"
LLAMA_API_URL = "https://api-inference.huggingface.co/models/abhinand/tamil-llama-7b-instruct-v0.2"

# Load SentenceTransformer model for retrieval
retriever_model = SentenceTransformer("distiluse-base-multilingual-cased-v2")

# Load the Q&A dataset and pre-compute the question embeddings once at startup,
# so each incoming query only needs to encode the user text
with open("qa_dataset.json", "r", encoding="utf-8") as f:
    qa_data = json.load(f)
qa_embeddings = retriever_model.encode([qa["question"] for qa in qa_data], convert_to_tensor=True)
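
# Illustrative shape of qa_dataset.json (hypothetical sample; the code only
# relies on the "question" and "answer" keys):
# [
#   {"question": "...", "answer": "..."},
#   {"question": "...", "answer": "..."}
# ]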

# Poll the Inference API until the hosted model has finished loading
# (serverless models are cold-started on first use)
def wait_for_model_ready(model_url, headers, timeout=300):
    start_time = time.time()
    while time.time() - start_time < timeout:
        # Probe the endpoint; while the model is still loading, the JSON body
        # carries an "error" field that mentions "loading"
        response = requests.get(model_url, headers=headers)
        try:
            result = response.json()
        except ValueError:  # non-JSON body; treat as still loading and keep polling
            result = {"error": "loading"}

        if not ("error" in result and "loading" in str(result["error"]).lower()):
            print("✅ Model is ready!")
            return True

        print("⏳ Model is still loading, waiting 10 seconds...")
        time.sleep(10)

    print("❌ Model did not become ready in time.")
    return False  # timeout
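
# Hedged note: for JSON requests the serverless Inference API also documents an
# "options": {"wait_for_model": true} field that makes the POST itself block
# until the model is loaded; the LLaMA request below opts into it. Explicit
# polling is still used for Whisper, whose request body is raw audio bytes
# rather than JSON.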

def transcribe_audio(audio_file):
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}

    # Wait for Whisper model to be ready
    if not wait_for_model_ready(WHISPER_API_URL, headers):
        return "Error: Whisper model did not load in time. Please try again later."

    # Now send the audio after model is ready
    with open(audio_file, "rb") as f:
        response = requests.post(WHISPER_API_URL, headers=headers, data=f)

    result = response.json()
    print(result)  # log the raw API response for debugging

    if not isinstance(result, dict):
        return "Error: Unexpected response from the Whisper API."
    if "error" in result:
        return f"Error: {result['error']}"
    return result.get("text", "Error: No transcription text returned.")

# Build a TTS audio URL for the reply using Google Translate's unofficial
# translate_tts endpoint (Tamil voice; may be rate-limited)
def get_tts_audio_url(text, lang="ta"):
    # URL-encode the text so Tamil characters and punctuation survive the query string
    safe_text = urllib.parse.quote(text)
    return f"https://translate.google.com/translate_tts?ie=UTF-8&q={safe_text}&tl={lang}&client=tw-ob"
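
# Hedged sketch: the translate_tts endpoint often refuses hotlinked or
# non-browser requests, so if the raw URL fails to play, one workaround is to
# download the clip server-side and hand Gradio a local file instead. The helper
# name, User-Agent string, and output path are illustrative assumptions.
def fetch_tts_to_file(text, lang="ta", out_path="tts_reply.mp3"):
    headers = {"User-Agent": "Mozilla/5.0"}  # the endpoint tends to reject requests without a browser UA
    resp = requests.get(get_tts_audio_url(text, lang), headers=headers, timeout=30)
    resp.raise_for_status()
    with open(out_path, "wb") as f:
        f.write(resp.content)
    return out_path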

# Retrieve the most relevant FAQ entry via embedding similarity, then ask LLaMA
# to phrase a fluent Tamil answer around it
def get_bot_response(query, system_message="", temperature=0.7, max_new_tokens=150, top_p=0.95):
    query_embedding = retriever_model.encode(query, convert_to_tensor=True)

    # Compare the query against the FAQ embeddings pre-computed at startup
    scores = util.pytorch_cos_sim(query_embedding, qa_embeddings)
    best_idx = int(torch.argmax(scores))

    top_qa = qa_data[best_idx]
    prompt = (
        f"{system_message}\n"
        f"User asked: {query}\n"
        f"Relevant FAQ: {top_qa['question']}\n"
        f"Answer: {top_qa['answer']}\n"
        "Now generate a helpful and fluent Tamil response to the user query."
    ).strip()

    # Ask LLaMA (via the serverless Inference API) for the refined response
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
    payload = {
        "inputs": prompt,
        "parameters": {"temperature": temperature, "max_new_tokens": max_new_tokens, "top_p": top_p, "return_full_text": False},
        "options": {"wait_for_model": True},  # block on a cold start instead of failing with a "loading" error
    }
    response = requests.post(LLAMA_API_URL, headers=headers, json=payload)
    result = response.json()

    if isinstance(result, list) and result and "generated_text" in result[0]:
        return result[0]["generated_text"]
    else:
        # Fallback, in Tamil: "Sorry, I could not answer this question."
        return "மன்னிக்கவும், நான் இந்த கேள்விக்கு பதில் தர முடியவில்லை."

# Gradio interface function: transcribe speech if provided, retrieve/generate a
# Tamil answer, and return it with a TTS audio URL
def chatbot(audio, message, system_message, max_tokens, temperature, top_p):
    if audio is not None:
        sample_rate, audio_data = audio  # Gradio numpy audio arrives as (sample_rate, data)
        sf.write("temp.wav", audio_data, sample_rate)  # soundfile expects (file, data, samplerate)
        try:
            transcript = transcribe_audio("temp.wav")
            message = transcript  # prefer the transcribed speech over any typed text
        except Exception as e:
            return f"Audio transcription failed: {str(e)}", None

    try:
        response = get_bot_response(
            message,
            system_message=system_message,
            temperature=temperature,
            max_new_tokens=int(max_tokens),
            top_p=top_p,
        )
        audio_url = get_tts_audio_url(response)
        return response, audio_url
    except Exception as e:
        return f"Error in generating response: {str(e)}", None


# Define the Gradio interface
demo = gr.Interface(
    fn=chatbot,
    inputs=[
        gr.Audio(type="numpy", label="Speak to the Bot"),  # numpy mode yields (sample_rate, data) for mic or upload
        gr.Textbox(value="How can I help you?", label="Text Input (optional)"),
        gr.Textbox(value="You are a friendly chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    outputs=[gr.Textbox(label="Response"), gr.Audio(label="Bot's Voice Response (Tamil)")],
    # live mode is avoided: it would re-run the whole pipeline (and hit the
    # remote APIs) on every keystroke and slider move
)
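
# Local usage sketch (assumes this file is saved as app.py; the "HF" environment
# variable must hold a valid Hugging Face API token, per os.getenv("HF") above):
#   export HF=hf_xxxxxxxxxxxx
#   python app.py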

if __name__ == "__main__":
    demo.launch()