dalybuilds commited on
Commit
a38b536
·
verified ·
1 Parent(s): ff4c65d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -83
app.py CHANGED
@@ -3,90 +3,113 @@ import gradio as gr
3
  import requests
4
  import pandas as pd
5
  from io import BytesIO
 
 
 
 
 
6
 
7
  # --- LangChain & Groq Imports ---
8
  from groq import Groq
9
  from langchain_groq import ChatGroq
10
  from langchain.agents import AgentExecutor, create_tool_calling_agent
11
- from langchain_community.tools.tavily_search import TavilySearchResults
12
  from langchain_core.prompts import ChatPromptTemplate
13
  from langchain.tools import Tool
14
 
15
 
16
  # --- Constants ---
17
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
18
 
19
 
20
- # --- Custom Tool Definition using Groq ---
21
- def transcribe_audio_from_task_id(task_id: str) -> str:
22
  """
23
- Downloads an audio file for a given task_id from the scoring server,
24
- transcribes it using the GROQ API with Whisper, and returns the text.
25
- Use this tool ONLY when a question explicitly mentions an audio file or recording.
26
- The task_id MUST be provided as the input.
27
  """
28
- print(f"Tool 'transcribe_audio_from_task_id' (using Groq) called with task_id: {task_id}")
29
  try:
30
- # Step 1: Download the file
31
  file_url = f"{DEFAULT_API_URL}/files/{task_id}"
32
- print(f"Downloading audio file from: {file_url}")
33
  audio_response = requests.get(file_url)
34
  audio_response.raise_for_status()
35
-
36
- # Step 2: Prepare the file for the Groq API
37
  audio_bytes = BytesIO(audio_response.content)
38
- audio_bytes.name = f"{task_id}.mp3" # Give the file-like object a name
39
-
40
- # Step 3: Initialize the Groq client and transcribe
41
- print("Initializing Groq client for transcription...")
42
  client = Groq(api_key=os.getenv("GROQ_API_KEY"))
 
 
 
 
43
 
44
- print("Transcribing audio with Groq's Whisper...")
45
- transcription = client.audio.transcriptions.create(
46
- file=audio_bytes,
47
- model="whisper-large-v3",
48
- response_format="text",
49
- )
 
 
 
 
 
 
 
50
 
51
- transcribed_text = str(transcription)
52
- print(f"Transcription successful. Result: {transcribed_text}")
53
- return transcribed_text
54
-
 
 
 
 
55
  except Exception as e:
56
- error_message = f"Error in Groq audio transcription tool: {e}"
57
- print(error_message)
58
- return error_message
 
59
 
60
 
61
  # --- Agent Definition ---
62
  class LangChainAgent:
63
  def __init__(self, groq_api_key: str, tavily_api_key: str):
64
- print("Initializing LangChainAgent...")
65
-
66
- # THIS IS THE CORRECTED LINE
67
  self.llm = ChatGroq(model_name="llama3-70b-8192", groq_api_key=groq_api_key, temperature=0.0)
68
 
69
- # Define all available tools
70
- audio_tool = Tool(
71
- name="audio_transcriber",
72
- func=transcribe_audio_from_task_id,
73
- description="Use this tool to transcribe an audio file. The input must be the task_id of the question.",
74
- )
75
  self.tools = [
76
- TavilySearchResults(max_results=3, tavily_api_key=tavily_api_key, name="web_search"),
77
- audio_tool,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  ]
79
 
80
- # Define the strict system prompt
81
  prompt = ChatPromptTemplate.from_messages([
82
  ("system", (
83
  "You are a powerful problem-solving agent. Your goal is to answer the user's question accurately. "
84
- "You have access to the following tools: a web search tool and an audio transcription tool.\n"
85
- "RULES:\n"
86
- "- Carefully analyze the user's question to determine if a tool is needed.\n"
87
- "- For questions requiring current information or facts, use the 'web_search' tool.\n"
88
- "- For questions that mention an audio file (.mp3, recording, voice memo, etc.), use the 'audio_transcriber' tool with the provided 'task_id'.\n"
89
- "- Once you have all the necessary information, you MUST provide ONLY THE FINAL ANSWER to the user's question. Do not include any extra conversation, explanations, apologies, or introductory phrases."
 
 
 
90
  )),
91
  ("human", "Question: {input}\nTask ID: {task_id}"),
92
  ("placeholder", "{agent_scratchpad}"),
@@ -94,59 +117,52 @@ class LangChainAgent:
94
 
95
  agent = create_tool_calling_agent(self.llm, self.tools, prompt)
96
  self.agent_executor = AgentExecutor(agent=agent, tools=self.tools, verbose=True, handle_parsing_errors=True)
97
- print("LangChainAgent initialized.")
98
 
99
  def __call__(self, question: str, task_id: str) -> str:
100
- print(f"Agent received question (ID: {task_id}): {question[:50]}...")
 
 
 
 
 
101
  try:
102
- response = self.agent_executor.invoke({"input": question, "task_id": task_id})
103
- answer = response.get("output", "Agent failed to produce an answer.")
104
  except Exception as e:
105
- answer = f"Agent execution failed with an error: {e}"
106
- print(f"Agent generated answer: {answer}")
107
- return answer
108
-
109
 
110
- # --- Main Application Logic ---
111
  def run_and_submit_all(profile: gr.OAuthProfile | None):
112
  space_id = os.getenv("SPACE_ID")
113
- if not profile:
114
- return "Please Login to Hugging Face with the button.", None
115
  username = profile.username
116
- print(f"User logged in: {username}")
117
-
118
  try:
119
  groq_api_key = os.getenv("GROQ_API_KEY")
120
  tavily_api_key = os.getenv("TAVILY_API_KEY")
121
- if not all([groq_api_key, tavily_api_key]):
122
- raise ValueError("An API key secret (GROQ or TAVILY) is missing.")
123
  agent = LangChainAgent(groq_api_key=groq_api_key, tavily_api_key=tavily_api_key)
124
- except Exception as e:
125
- return f"Error initializing agent: {e}", None
126
-
127
  questions_url = f"{DEFAULT_API_URL}/questions"
128
- print(f"Fetching questions from: {questions_url}")
129
  try:
130
  response = requests.get(questions_url, timeout=20)
131
  response.raise_for_status()
132
  questions_data = response.json()
133
- except Exception as e:
134
- return f"Error fetching questions: {e}", None
135
-
136
  results_log, answers_payload = [], []
137
  for item in questions_data:
138
- task_id, question_text = item.get("task_id"), item.get("question")
139
- if not task_id or not question_text: continue
140
- submitted_answer = agent(question=question_text, task_id=task_id)
141
- answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
142
- results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
143
-
144
  agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
145
  submission_data = {"username": username, "agent_code": agent_code, "answers": answers_payload}
146
  submit_url = f"{DEFAULT_API_URL}/submit"
147
- print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
148
  try:
149
- response = requests.post(submit_url, json=submission_data, timeout=90) # Increased timeout
150
  response.raise_for_status()
151
  result_data = response.json()
152
  final_status = (f"Submission Successful!\nUser: {result_data.get('username')}\n"
@@ -154,14 +170,12 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
154
  f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
155
  f"Message: {result_data.get('message', 'No message received.')}")
156
  return final_status, pd.DataFrame(results_log)
157
- except Exception as e:
158
- return f"Submission Failed: {e}", pd.DataFrame(results_log)
159
-
160
 
161
- # --- Gradio Interface ---
162
  with gr.Blocks() as demo:
163
- gr.Markdown("# Advanced Agent Evaluation Runner (Search + Groq Audio)")
164
- gr.Markdown("This agent can search the web with Tavily and transcribe audio with Groq's Whisper.")
165
  gr.LoginButton()
166
  run_button = gr.Button("Run Evaluation & Submit All Answers")
167
  status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
 
3
  import requests
4
  import pandas as pd
5
  from io import BytesIO
6
+ import re
7
+
8
+ # --- New Imports for Video Tool ---
9
+ from pytube import YouTube
10
+ import moviepy.editor as mp
11
 
12
  # --- LangChain & Groq Imports ---
13
  from groq import Groq
14
  from langchain_groq import ChatGroq
15
  from langchain.agents import AgentExecutor, create_tool_calling_agent
16
+ from langchain_tavily_search import TavilySearchResults
17
  from langchain_core.prompts import ChatPromptTemplate
18
  from langchain.tools import Tool
19
 
20
 
21
  # --- Constants ---
22
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
23
+ TEMP_DIR = "/tmp" # Use the /tmp directory for temporary files in HF Spaces
24
 
25
 
26
+ # --- Tool Definition: Audio File Transcription ---
27
+ def transcribe_audio_file(task_id: str) -> str:
28
  """
29
+ Downloads an audio file (.mp3) for a given task_id, transcribes it, and returns the text.
30
+ Use this tool ONLY when a question explicitly mentions an audio file, .mp3, recording, or voice memo.
 
 
31
  """
32
+ print(f"Tool 'transcribe_audio_file' called with task_id: {task_id}")
33
  try:
 
34
  file_url = f"{DEFAULT_API_URL}/files/{task_id}"
 
35
  audio_response = requests.get(file_url)
36
  audio_response.raise_for_status()
 
 
37
  audio_bytes = BytesIO(audio_response.content)
38
+ audio_bytes.name = f"{task_id}.mp3"
39
+
 
 
40
  client = Groq(api_key=os.getenv("GROQ_API_KEY"))
41
+ transcription = client.audio.transcriptions.create(file=audio_bytes, model="whisper-large-v3", response_format="text")
42
+ return str(transcription)
43
+ except Exception as e:
44
+ return f"Error during audio file transcription: {e}"
45
 
46
+ # --- Tool Definition: Video Transcription ---
47
+ def transcribe_youtube_video(video_url: str) -> str:
48
+ """
49
+ Downloads a YouTube video from a URL, extracts its audio, and transcribes it to text.
50
+ Use this tool ONLY when a question provides a youtube.com URL.
51
+ """
52
+ print(f"Tool 'transcribe_youtube_video' called with URL: {video_url}")
53
+ video_path, audio_path = None, None
54
+ try:
55
+ os.makedirs(TEMP_DIR, exist_ok=True)
56
+ yt = YouTube(video_url)
57
+ stream = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
58
+ video_path = stream.download(output_path=TEMP_DIR)
59
 
60
+ video_clip = mp.VideoFileClip(video_path)
61
+ audio_path = os.path.join(TEMP_DIR, "temp_audio.mp3")
62
+ video_clip.audio.write_audiofile(audio_path, codec='mp3', logger=None)
63
+
64
+ client = Groq(api_key=os.getenv("GROQ_API_KEY"))
65
+ with open(audio_path, "rb") as audio_file:
66
+ transcription = client.audio.transcriptions.create(file=audio_file, model="whisper-large-v3", response_format="text")
67
+ return str(transcription)
68
  except Exception as e:
69
+ return f"Error during YouTube transcription: {e}"
70
+ finally:
71
+ if video_path and os.path.exists(video_path): os.remove(video_path)
72
+ if audio_path and os.path.exists(audio_path): os.remove(audio_path)
73
 
74
 
75
  # --- Agent Definition ---
76
  class LangChainAgent:
77
  def __init__(self, groq_api_key: str, tavily_api_key: str):
 
 
 
78
  self.llm = ChatGroq(model_name="llama3-70b-8192", groq_api_key=groq_api_key, temperature=0.0)
79
 
80
+ # Updated tools with much more specific descriptions
 
 
 
 
 
81
  self.tools = [
82
+ TavilySearchResults(
83
+ name="web_search",
84
+ max_results=3,
85
+ tavily_api_key=tavily_api_key,
86
+ description="A search engine for finding up-to-date information, facts, and news on the internet."
87
+ ),
88
+ Tool(
89
+ name="audio_file_transcriber",
90
+ func=transcribe_audio_file,
91
+ description="Use this ONLY for questions mentioning an audio file (.mp3, recording). Input MUST be the task_id.",
92
+ ),
93
+ Tool(
94
+ name="youtube_video_transcriber",
95
+ func=transcribe_youtube_video,
96
+ description="Use this ONLY for questions providing a youtube.com URL. Input MUST be the URL.",
97
+ ),
98
  ]
99
 
100
+ # Updated, rule-based system prompt
101
  prompt = ChatPromptTemplate.from_messages([
102
  ("system", (
103
  "You are a powerful problem-solving agent. Your goal is to answer the user's question accurately. "
104
+ "You have access to a web search tool, an audio file transcriber, and a YouTube video transcriber.\n\n"
105
+ "**REASONING PROCESS:**\n"
106
+ "1. **Analyze the question:** Is it a general knowledge question, or does it mention a file/URL?\n"
107
+ "2. **Select ONE tool:**\n"
108
+ " - If the question requires current events, facts, or general knowledge, use `web_search`.\n"
109
+ " - If the question *explicitly* mentions an audio file, .mp3, or voice memo, use `audio_file_transcriber` with the provided `task_id`.\n"
110
+ " - If the question *explicitly* provides a `youtube.com` URL, use `youtube_video_transcriber` with that URL.\n"
111
+ " - If no tool is needed (e.g., math, logic puzzles), answer directly.\n"
112
+ "3. **Execute and Answer:** After using a tool, analyze the result and provide ONLY THE FINAL ANSWER. Do not explain your actions or apologize for errors."
113
  )),
114
  ("human", "Question: {input}\nTask ID: {task_id}"),
115
  ("placeholder", "{agent_scratchpad}"),
 
117
 
118
  agent = create_tool_calling_agent(self.llm, self.tools, prompt)
119
  self.agent_executor = AgentExecutor(agent=agent, tools=self.tools, verbose=True, handle_parsing_errors=True)
 
120
 
121
  def __call__(self, question: str, task_id: str) -> str:
122
+ # The agent sometimes needs the URL directly, so let's extract it if present.
123
+ urls = re.findall(r'https?://[^\s]+', question)
124
+ input_for_agent = {"input": question, "task_id": task_id}
125
+ if urls:
126
+ input_for_agent['video_url'] = urls[0] # Pass the URL if found
127
+
128
  try:
129
+ response = self.agent_executor.invoke(input_for_agent)
130
+ return response.get("output", "Agent failed to produce an answer.")
131
  except Exception as e:
132
+ return f"Agent execution failed with an error: {e}"
 
 
 
133
 
134
+ # --- Main Application Logic (Unchanged) ---
135
  def run_and_submit_all(profile: gr.OAuthProfile | None):
136
  space_id = os.getenv("SPACE_ID")
137
+ if not profile: return "Please Login to Hugging Face with the button.", None
 
138
  username = profile.username
 
 
139
  try:
140
  groq_api_key = os.getenv("GROQ_API_KEY")
141
  tavily_api_key = os.getenv("TAVILY_API_KEY")
142
+ if not all([groq_api_key, tavily_api_key]): raise ValueError("GROQ or TAVILY API key is missing.")
 
143
  agent = LangChainAgent(groq_api_key=groq_api_key, tavily_api_key=tavily_api_key)
144
+ except Exception as e: return f"Error initializing agent: {e}", None
145
+
 
146
  questions_url = f"{DEFAULT_API_URL}/questions"
 
147
  try:
148
  response = requests.get(questions_url, timeout=20)
149
  response.raise_for_status()
150
  questions_data = response.json()
151
+ except Exception as e: return f"Error fetching questions: {e}", None
152
+
 
153
  results_log, answers_payload = [], []
154
  for item in questions_data:
155
+ task_id, q_text = item.get("task_id"), item.get("question")
156
+ if not task_id or not q_text: continue
157
+ answer = agent(question=q_text, task_id=task_id)
158
+ answers_payload.append({"task_id": task_id, "submitted_answer": answer})
159
+ results_log.append({"Task ID": task_id, "Question": q_text, "Submitted Answer": answer})
160
+
161
  agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
162
  submission_data = {"username": username, "agent_code": agent_code, "answers": answers_payload}
163
  submit_url = f"{DEFAULT_API_URL}/submit"
 
164
  try:
165
+ response = requests.post(submit_url, json=submission_data, timeout=240) # Increased timeout for video processing
166
  response.raise_for_status()
167
  result_data = response.json()
168
  final_status = (f"Submission Successful!\nUser: {result_data.get('username')}\n"
 
170
  f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
171
  f"Message: {result_data.get('message', 'No message received.')}")
172
  return final_status, pd.DataFrame(results_log)
173
+ except Exception as e: return f"Submission Failed: {e}", pd.DataFrame(results_log)
 
 
174
 
175
+ # --- Gradio Interface (Unchanged) ---
176
  with gr.Blocks() as demo:
177
+ gr.Markdown("# Ultimate Agent Runner (Search + Audio + Video)")
178
+ gr.Markdown("This agent can search, transcribe audio files, and transcribe YouTube videos.")
179
  gr.LoginButton()
180
  run_button = gr.Button("Run Evaluation & Submit All Answers")
181
  status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)