dalybuilds's picture
Update app.py
8d9e499 verified
raw
history blame
2.45 kB
import os
import gradio as gr
import requests
import pandas as pd
from io import BytesIO
# --- LangChain & Groq Imports ---
from groq import Groq
from langchain_groq import ChatGroq
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.prompts import ChatPromptTemplate
from langchain.tools import Tool
# --- Constants ---
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# --- Custom Tool Definition using Groq ---
def transcribe_audio_from_task_id(task_id: str) -> str:
    """
    Downloads an audio file for a given task_id from the scoring server,
    transcribes it using the GROQ API with Whisper, and returns the text.
    Use this tool ONLY when a question explicitly mentions an audio file or recording.
    The task_id MUST be provided as the input.
    """
    # Contract: always returns a string. On success it is the transcription;
    # on any failure it is a message prefixed with
    # "Error in Groq audio transcription tool: " so the calling agent can
    # read the failure instead of crashing on an exception.
    print(f"Tool 'transcribe_audio_from_task_id' (using Groq) called with task_id: {task_id}")

    # Seconds to wait on the scoring server (connect/read) before giving up;
    # without a timeout, requests would block indefinitely on a stalled server.
    download_timeout_s = 30

    try:
        # Step 0: Validate the input. Tolerate surrounding whitespace in the
        # tool input, and reject an empty id before making any network call.
        clean_task_id = str(task_id).strip() if task_id is not None else ""
        if not clean_task_id:
            raise ValueError("task_id must be a non-empty string")

        # Step 1: Download the file
        file_url = f"{DEFAULT_API_URL}/files/{clean_task_id}"
        print(f"Downloading audio file from: {file_url}")
        audio_response = requests.get(file_url, timeout=download_timeout_s)
        audio_response.raise_for_status()

        # Step 2: Prepare the file for the Groq API
        # The API expects a file-like object with a name.
        audio_bytes = BytesIO(audio_response.content)
        audio_bytes.name = f"{clean_task_id}.mp3"  # Give the file-like object a name

        # Step 3: Initialize the Groq client and transcribe
        print("Initializing Groq client for transcription...")
        client = Groq(api_key=os.getenv("GROQ_API_KEY"))
        print("Transcribing audio with Groq's Whisper...")
        transcription = client.audio.transcriptions.create(
            file=audio_bytes,
            model="whisper-large-v3",
            response_format="text",
        )
        transcribed_text = str(transcription)
        print(f"Transcription successful. Result: {transcribed_text}")
        return transcribed_text
    except Exception as e:
        # Deliberate best-effort boundary: the tool reports the failure as
        # text so the agent loop can continue with the error in context.
        error_message = f"Error in Groq audio transcription tool: {e}"
        print(error_message)
        return error_message
# --- Agent Definition ---
class LangChainAgent:
def __init__(self, groq_api_key: str, tavily_api_key: str):
print("Initializing LangChainAgent...")
self.llm = ChatGroq(model_name="llama3-70b-8192", groq_api_key=groq