davidgturner committed on
Commit e305927 · 1 Parent(s): e7f4f55

Add an LLM-backed agent with response cleaning and a rule-based fallback

Files changed (2)
  1. app.py +118 -3
  2. requirements.txt +2 -1
app.py CHANGED
@@ -3,6 +3,10 @@ import gradio as gr
 import requests
 import inspect
 import pandas as pd
+import os  # needed for os.getenv() in BasicAgent.__init__
+import time
+import json
+from typing import Dict, List, Union, Optional
 
 # (Keep Constants as is)
 # --- Constants ---
@@ -13,11 +17,122 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 class BasicAgent:
     def __init__(self):
         print("BasicAgent initialized.")
+        # Initialize the Hugging Face Inference API client
+        self.hf_api_url = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
+        self.hf_api_token = os.getenv("HF_API_TOKEN")
+        if not self.hf_api_token:
+            print("WARNING: HF_API_TOKEN not found. Using default fallback methods.")
+        self.headers = {"Authorization": f"Bearer {self.hf_api_token}"} if self.hf_api_token else {}
+        self.max_retries = 3
+        self.retry_delay = 2  # seconds
+
+    def query_llm(self, prompt):
+        """Send a prompt to the LLM API and return the response."""
+        if not self.hf_api_token:
+            # Fall back to a rule-based approach if there is no API token
+            return self.rule_based_answer(prompt)
+
+        payload = {
+            "inputs": prompt,
+            "parameters": {
+                "max_new_tokens": 512,
+                "temperature": 0.7,
+                "top_p": 0.9,
+                "do_sample": True
+            }
+        }
+
+        for attempt in range(self.max_retries):
+            try:
+                response = requests.post(self.hf_api_url, headers=self.headers, json=payload, timeout=30)
+                response.raise_for_status()
+                result = response.json()
+
+                # Extract the generated text from the response
+                if isinstance(result, list) and len(result) > 0:
+                    generated_text = result[0].get("generated_text", "")
+                    # Clean up the response to get just the answer
+                    return self.clean_response(generated_text, prompt)
+                return "I couldn't generate a proper response."
+
+            except Exception as e:
+                print(f"Attempt {attempt+1}/{self.max_retries} failed: {str(e)}")
+                if attempt < self.max_retries - 1:
+                    time.sleep(self.retry_delay)
+                else:
+                    # Fall back to the rule-based method on failure
+                    return self.rule_based_answer(prompt)
+
+    def clean_response(self, response, prompt):
+        """Clean up the LLM response to extract the answer."""
+        # Remove the prompt from the beginning if it's included
+        if response.startswith(prompt):
+            response = response[len(prompt):]
+
+        # Try to find where the model's actual answer begins.
+        # This is model-specific and may need adjustments.
+        markers = ["<answer>", "<response>", "Answer:", "Response:"]
+        for marker in markers:
+            if marker.lower() in response.lower():
+                # Slice at the marker position so the original casing is kept
+                idx = response.lower().index(marker.lower()) + len(marker)
+                response = response[idx:].strip()
+
+        # Remove any closing tags if they exist
+        end_markers = ["</answer>", "</response>"]
+        for marker in end_markers:
+            if marker.lower() in response.lower():
+                response = response[:response.lower().index(marker.lower())].strip()
+
+        return response.strip()
+
+    def rule_based_answer(self, question):
+        """Fallback method using rule-based answers for common question types."""
+        question_lower = question.lower()
+
+        # Simple pattern matching for common question types
+        if "what is" in question_lower or "define" in question_lower:
+            if "agent" in question_lower:
+                return "An agent is an autonomous entity that observes and acts upon an environment using sensors and actuators, usually to achieve specific goals."
+            if "gaia" in question_lower:
+                return "GAIA (General AI Assistants) is a benchmark for evaluating AI assistants on real-world questions that require reasoning, multimodal understanding, and tool use."
+
+        if "how to" in question_lower:
+            return "To accomplish this task, you should first understand the requirements, then implement a solution step by step, and finally test your implementation."
+
+        if "example" in question_lower:
+            return "Here's an example implementation that demonstrates the concept in a practical manner."
+
+        # Default response for unmatched questions
+        return "Based on my understanding, the answer involves analyzing the context carefully and applying the relevant principles to arrive at a solution."
+
+    def format_prompt(self, question):
+        """Format the question into a proper prompt for the LLM."""
+        return f"""You are an intelligent AI assistant. Please answer the following question accurately and concisely:
+
+Question: {question}
+
+Answer:"""
+
     def __call__(self, question: str) -> str:
         print(f"Agent received question (first 50 chars): {question[:50]}...")
-        fixed_answer = "This is a default answer."
-        print(f"Agent returning fixed answer: {fixed_answer}")
-        return fixed_answer
+
+        try:
+            # Format the question as a prompt
+            prompt = self.format_prompt(question)
+
+            # Query the LLM
+            answer = self.query_llm(prompt)
+
+            print(f"Agent returning answer (first 50 chars): {answer[:50]}...")
+            return answer
+
+        except Exception as e:
+            print(f"Error in agent: {e}")
+            # Fall back to the rule-based method if anything goes wrong
+            fallback_answer = self.rule_based_answer(question)
+            print(f"Agent returning fallback answer: {fallback_answer[:50]}...")
+            return fallback_answer
 
 def run_and_submit_all( profile: gr.OAuthProfile | None):
     """
requirements.txt CHANGED
@@ -1,2 +1,3 @@
 gradio
-requests
+requests
+pandas
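
The modules added at the top of app.py (os, time, json, typing) are all standard library, so the only new pin is pandas, which app.py already imported. To debug the API side in isolation, a standalone sketch like the one below (assuming HF_API_TOKEN is exported in the environment, and reusing the same endpoint and payload shape as query_llm()) shows the raw response format that clean_response() receives:

# api_check.py: standalone sketch; endpoint and payload mirror app.py.
import os
import requests

url = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
headers = {"Authorization": f"Bearer {os.environ['HF_API_TOKEN']}"}
payload = {
    "inputs": "Question: What is an agent?\n\nAnswer:",
    "parameters": {"max_new_tokens": 64, "temperature": 0.7, "do_sample": True},
}

resp = requests.post(url, headers=headers, json=payload, timeout=30)
resp.raise_for_status()
# The serverless Inference API returns a list of {"generated_text": ...} dicts,
# which is the shape query_llm() unpacks before calling clean_response().
print(resp.json()[0]["generated_text"])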