File size: 2,448 Bytes
10e9b7d
 
eccf8e4
3c4371f
8d9e499
10e9b7d
8d9e499
 
c0a6618
 
 
 
8d9e499
c0a6618
 
e80aab9
3db6293
e80aab9
c0a6618
8d9e499
 
31243f4
8d9e499
 
 
 
31243f4
8d9e499
31243f4
8d9e499
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e80aab9
8d9e499
 
 
 
 
7d65c66
8d9e499
 
 
e80aab9
 
8d9e499
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import gradio as gr
import requests
import pandas as pd
from io import BytesIO

# --- LangChain & Groq Imports ---
from groq import Groq
from langchain_groq import ChatGroq
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.prompts import ChatPromptTemplate
from langchain.tools import Tool


# --- Constants ---
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"


# --- Custom Tool Definition using Groq ---
def transcribe_audio_from_task_id(task_id: str) -> str:
    """
    Downloads an audio file for a given task_id from the scoring server,
    transcribes it using the GROQ API with Whisper, and returns the text.
    Use this tool ONLY when a question explicitly mentions an audio file or recording.
    The task_id MUST be provided as the input.

    Args:
        task_id: Identifier of the task whose file is fetched from
            ``{DEFAULT_API_URL}/files/{task_id}``. Surrounding whitespace
            is ignored.

    Returns:
        The transcription text on success. On any failure (blank task_id,
        download error, transcription error) a message starting with
        "Error in Groq audio transcription tool:" is returned rather than
        an exception being raised.
    """
    print(f"Tool 'transcribe_audio_from_task_id' (using Groq) called with task_id: {task_id}")

    # Reject a blank id up front instead of requesting the bare /files/ endpoint.
    task_id = "" if task_id is None else str(task_id).strip()
    if not task_id:
        error_message = "Error in Groq audio transcription tool: task_id is empty"
        print(error_message)
        return error_message

    # requests has no default timeout; without one a stalled server would
    # block this call forever. Applies to connect and to each read.
    download_timeout_seconds = 30

    try:
        # Step 1: Download the file
        file_url = f"{DEFAULT_API_URL}/files/{task_id}"
        print(f"Downloading audio file from: {file_url}")
        audio_response = requests.get(file_url, timeout=download_timeout_seconds)
        audio_response.raise_for_status()

        # Step 2: Prepare the file for the Groq API
        # The API expects a file-like object with a name.
        audio_bytes = BytesIO(audio_response.content)
        # NOTE(review): the extension is hard-coded to .mp3 — confirm that
        # every audio attachment served by the scoring server is MP3.
        audio_bytes.name = f"{task_id}.mp3"

        # Step 3: Initialize the Groq client and transcribe
        print("Initializing Groq client for transcription...")
        client = Groq(api_key=os.getenv("GROQ_API_KEY"))

        print("Transcribing audio with Groq's Whisper...")
        transcription = client.audio.transcriptions.create(
            file=audio_bytes,
            model="whisper-large-v3",
            response_format="text",
        )

        transcribed_text = str(transcription)
        print(f"Transcription successful. Result: {transcribed_text}")
        return transcribed_text

    except Exception as e:
        # Deliberately broad: every failure is reported back as text
        # instead of propagating to the caller.
        error_message = f"Error in Groq audio transcription tool: {e}"
        print(error_message)
        return error_message


# --- Agent Definition ---
class LangChainAgent:
    def __init__(self, groq_api_key: str, tavily_api_key: str):
        print("Initializing LangChainAgent...")
        self.llm = ChatGroq(model_name="llama3-70b-8192", groq_api_key=groq