github-actions[bot] committed
Commit 7e103cf · 1 Parent(s): d78bdc9

Sync with https://github.com/mozilla-ai/any-agent-demo

README.md CHANGED
@@ -11,10 +11,3 @@ pinned: false
 short_description: Find a surf spot near you
 license: apache-2.0
 ---
-
-# Welcome to Streamlit!
-
-Edit `/src/app.py` to customize this app to your heart's desire. :heart:
-
-If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
-forums](https://discuss.streamlit.io).
app.py CHANGED
@@ -1,8 +1,9 @@
-from components.sidebar import ssf_sidebar
-from constants import DEFAULT_TOOLS
-import streamlit as st
 import asyncio
+
 import nest_asyncio
+import streamlit as st
+from components.sidebar import ssf_sidebar
+from constants import DEFAULT_TOOLS
 from services.agent import (
     configure_agent,
     display_evaluation_results,
@@ -13,14 +14,11 @@ from services.agent import (
 
 nest_asyncio.apply()
 
-# Set page config
 st.set_page_config(page_title="Surf Spot Finder", page_icon="🏄", layout="wide")
 
-# Allow a user to resize the sidebar to take up most of the screen to make editing eval cases easier
 st.markdown(
     """
     <style>
-    /* When sidebar is expanded, adjust main content */
     section[data-testid="stSidebar"][aria-expanded="true"] {
         max-width: 99% !important;
     }
@@ -35,18 +33,16 @@ with st.sidebar:
     run_button = st.button("Run Agent 🤖", disabled=not is_valid, type="primary")
 
 
-# Main content
 async def main():
-    # Handle agent execution button click
     if run_button:
         agent, agent_config = await configure_agent(user_inputs)
         agent_trace = await run_agent(agent, agent_config)
 
         await display_output(agent_trace)
 
-        evaluation_result = await evaluate_agent(agent_config, agent_trace)
-
-        await display_evaluation_results(evaluation_result)
+        if user_inputs.run_evaluation:
+            evaluation_results = await evaluate_agent(agent_config, agent_trace)
+            await display_evaluation_results(evaluation_results)
     else:
         st.title("🏄 Surf Spot Finder")
         st.markdown(
@@ -56,7 +52,6 @@ async def main():
             "👈 Configure your search parameters in the sidebar and click Run to start!"
         )
 
-        # Display tools in a more organized way
         st.markdown("### 🛠️ Available Tools")
 
         st.markdown("""
@@ -92,7 +87,6 @@ async def main():
            with st.expander(f"🌐 {tool.__name__}"):
                st.markdown(tool.__doc__ or "No description available")
 
-        # add a check that all tools were listed
        if len(weather_tools) + len(location_tools) + len(web_tools) != len(
            DEFAULT_TOOLS
        ):
@@ -100,7 +94,6 @@ async def main():
            "Some tools are not listed. Please check the code for more details."
        )
 
-        # Add Custom Evaluation explanation section
        st.markdown("### 📊 Custom Evaluation")
        st.markdown("""
        The Surf Spot Finder includes a powerful evaluation system that allows you to customize how the agent's performance is assessed.
components/agent_status.py CHANGED
@@ -1,47 +1,82 @@
-from any_agent import AnyAgent
-from opentelemetry.sdk.trace.export import SimpleSpanProcessor
-from collections.abc import Sequence
-from typing import TYPE_CHECKING, Callable
-
-from opentelemetry.sdk.trace.export import (
-    SpanExporter,
-    SpanExportResult,
-)
-
-from any_agent import AgentFramework
-
-from any_agent.tracing import TracingProcessor
-from any_agent.tracing.trace import AgentSpan
-
-if TYPE_CHECKING:
-    from opentelemetry.sdk.trace import ReadableSpan
-
-
-class StreamlitExporter(SpanExporter):
-    """Build an `AgentTrace` and export to the different outputs."""
-
-    def __init__(  # noqa: D107
-        self, agent_framework: AgentFramework, callback: Callable
-    ):
-        self.agent_framework = agent_framework
-        self.processor: TracingProcessor | None = TracingProcessor.create(
-            agent_framework
-        )
-        self.callback = callback
-
-    def export(self, spans: Sequence["ReadableSpan"]) -> SpanExportResult:  # noqa: D102
-        if not self.processor:
-            return SpanExportResult.SUCCESS
-
-        for readable_span in spans:
-            # Check if this span belongs to our run
-            span = AgentSpan.from_readable_span(readable_span)
-            self.callback(span)
-
-        return SpanExportResult.SUCCESS
-
-
-def export_logs(agent: AnyAgent, callback: Callable) -> None:
-    exporter = StreamlitExporter(agent.framework, callback)
-    span_processor = SimpleSpanProcessor(exporter)
-    agent._tracer_provider.add_span_processor(span_processor)
+from collections.abc import Callable
+from typing import Any
+
+from any_agent.callbacks import Callback, Context
+from any_agent.tracing.attributes import GenAI
+
+
+class StreamlitStatusCallback(Callback):
+    """Callback to update Streamlit status with agent progress."""
+
+    def __init__(self, status_callback: Callable[[str], None]):
+        self.status_callback = status_callback
+
+    def after_llm_call(self, context: Context, *args, **kwargs) -> Context:
+        """Update status after LLM calls."""
+        span = context.current_span
+        input_value = span.attributes.get(GenAI.INPUT_MESSAGES, "")
+        output_value = span.attributes.get(GenAI.OUTPUT, "")
+
+        self._update_status(span.name, input_value, output_value)
+        return context
+
+    def after_tool_execution(self, context: Context, *args, **kwargs) -> Context:
+        """Update status after tool executions."""
+        span = context.current_span
+        input_value = span.attributes.get(GenAI.TOOL_ARGS, "")
+        output_value = span.attributes.get(GenAI.OUTPUT, "")
+
+        self._update_status(span.name, input_value, output_value)
+        return context
+
+    def _update_status(self, step_name: str, input_value: str, output_value: str):
+        """Update the Streamlit status with formatted information."""
+        if input_value:
+            try:
+                import json
+
+                parsed_input = json.loads(input_value)
+                if isinstance(parsed_input, list) and len(parsed_input) > 0:
+                    input_value = str(parsed_input[-1])
+            except Exception:
+                pass
+
+        if output_value:
+            try:
+                import json
+
+                parsed_output = json.loads(output_value)
+                if isinstance(parsed_output, list) and len(parsed_output) > 0:
+                    output_value = str(parsed_output[-1])
+            except Exception:
+                pass
+
+        max_length = 800
+        if len(input_value) > max_length:
+            input_value = f"[Truncated]...{input_value[-max_length:]}"
+        if len(output_value) > max_length:
+            output_value = f"[Truncated]...{output_value[-max_length:]}"
+
+        if input_value or output_value:
+            message = f"Step: {step_name}\n"
+            if input_value:
+                message += f"Input: {input_value}\n"
+            if output_value:
+                message += f"Output: {output_value}"
+        else:
+            message = f"Step: {step_name}"
+
+        self.status_callback(message)
+
+
+def export_logs(agent: Any, callback: Callable[[str], None]) -> None:
+    """Add a Streamlit status callback to the agent.
+
+    This function adds a custom callback to the agent that will update
+    the Streamlit status with progress information during agent execution.
+    """
+    status_callback = StreamlitStatusCallback(callback)
+
+    if agent.config.callbacks is None:
+        agent.config.callbacks = []
+    agent.config.callbacks.append(status_callback)
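
Note: a minimal sketch of how this callback is wired up from the Streamlit side. It mirrors the `export_logs(agent, update_status)` usage in `services/agent.py` further down in this commit; the `run_with_status` wrapper and its arguments are illustrative, not part of the commit itself.

```python
import streamlit as st

from components.agent_status import export_logs


async def run_with_status(agent, query: str):
    """Run an any-agent instance while streaming step updates to st.status."""
    with st.status("Agent is running...", expanded=False, state="running") as status:

        def update_status(message: str) -> None:
            # Each message built by StreamlitStatusCallback (one per LLM call or
            # tool execution) becomes the label of the running status widget.
            status.update(label=message, expanded=False, state="running")

        # Appends a StreamlitStatusCallback to agent.config.callbacks.
        export_logs(agent, update_status)
        agent_trace = await agent.run_async(query)
        status.update(label="Finished!", expanded=False, state="complete")
    return agent_trace
```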
components/inputs.py CHANGED
@@ -1,17 +1,19 @@
-from datetime import datetime, timedelta
+import copy
 import json
+from datetime import datetime, timedelta
+
+import pandas as pd
 import requests
 import streamlit as st
-from any_agent import AgentFramework
-from any_agent.tracing.trace import _is_tracing_supported
-from any_agent.evaluation import EvaluationCase
-from any_agent.evaluation.schemas import CheckpointCriteria
-import pandas as pd
-from constants import DEFAULT_EVALUATION_CASE, MODEL_OPTIONS
-import copy
-
+from constants import (
+    DEFAULT_EVALUATION_CRITERIA,
+    DEFAULT_EVALUATION_MODEL,
+    MODEL_OPTIONS,
+)
 from pydantic import BaseModel, ConfigDict
 
+from any_agent import AgentFramework
+
 
 class UserInputs(BaseModel):
     model_config = ConfigDict(extra="forbid")
@@ -20,7 +22,8 @@ class UserInputs(BaseModel):
     max_driving_hours: int
     date: datetime
     framework: str
-    evaluation_case: EvaluationCase
+    evaluation_model: str
+    evaluation_criteria: list[dict[str, str]]
     run_evaluation: bool
 
 
@@ -35,6 +38,7 @@ def get_area(area_name: str) -> dict:
 
     Returns:
         dict: The area found.
+
     """
     response = requests.get(
         f"https://nominatim.openstreetmap.org/search?q={area_name}&format=jsonv2",
@@ -42,8 +46,7 @@
         timeout=5,
     )
     response.raise_for_status()
-    response_json = json.loads(response.content.decode())
-    return response_json
+    return json.loads(response.content.decode())
 
 
 def get_user_inputs() -> UserInputs:
@@ -65,7 +68,6 @@ def get_user_inputs() -> UserInputs:
         "Select a date in the future", value=datetime.now() + timedelta(days=1)
     )
     with col_time:
-        # default to 9am
         time = st.selectbox(
             "Select a time",
             [datetime.strptime(f"{i:02d}:00", "%H:%M").time() for i in range(24)],
@@ -73,9 +75,7 @@
         )
     date = datetime.combine(date, time)
 
-    supported_frameworks = [
-        framework for framework in AgentFramework if _is_tracing_supported(framework)
-    ]
+    supported_frameworks = [framework for framework in AgentFramework]
 
     framework = st.selectbox(
         "Select the agent framework to use",
@@ -91,7 +91,6 @@ def get_user_inputs() -> UserInputs:
         format_func=lambda x: "/".join(x.split("/")[-3:]),
     )
 
-    # Add evaluation case section
     with st.expander("Custom Evaluation"):
         evaluation_model_id = st.selectbox(
             "Select the model to use for LLM-as-a-Judge evaluation",
@@ -99,47 +98,35 @@ def get_user_inputs() -> UserInputs:
             index=2,
             format_func=lambda x: "/".join(x.split("/")[-3:]),
         )
-        evaluation_case = copy.deepcopy(DEFAULT_EVALUATION_CASE)
-        evaluation_case.llm_judge = evaluation_model_id
-        # make this an editable json section
-        # convert the checkpoints to a df series so that it can be edited
-        checkpoints = evaluation_case.checkpoints
-        checkpoints_df = pd.DataFrame(
-            [checkpoint.model_dump() for checkpoint in checkpoints]
-        )
-        checkpoints_df = st.data_editor(
-            checkpoints_df,
+
+        evaluation_criteria = copy.deepcopy(DEFAULT_EVALUATION_CRITERIA)
+
+        criteria_df = pd.DataFrame(evaluation_criteria)
+        criteria_df = st.data_editor(
+            criteria_df,
             column_config={
-                "points": st.column_config.NumberColumn(label="Points"),
                 "criteria": st.column_config.TextColumn(label="Criteria"),
             },
             hide_index=True,
             num_rows="dynamic",
        )
-        # for each checkpoint, convert it back to a CheckpointCriteria object
-        new_ckpts = []
 
-        # don't let a user add more than 20 checkpoints
-        if len(checkpoints_df) > 20:
-            st.error(
-                "You can only add up to 20 checkpoints for the purpose of this demo."
-            )
-            checkpoints_df = checkpoints_df[:20]
+        new_criteria = []
+
+        if len(criteria_df) > 20:
+            st.error("You can only add up to 20 criteria for the purpose of this demo.")
+            criteria_df = criteria_df[:20]
 
-        for _, row in checkpoints_df.iterrows():
+        for _, row in criteria_df.iterrows():
            if row["criteria"] == "":
                continue
            try:
-                # Don't let people write essays for criteria in this demo
                if len(row["criteria"].split(" ")) > 100:
-                    raise ValueError("Criteria is too long")
-                new_crit = CheckpointCriteria(
-                    criteria=row["criteria"], points=row["points"]
-                )
-                new_ckpts.append(new_crit)
+                    msg = "Criteria is too long"
+                    raise ValueError(msg)
+                new_criteria.append({"criteria": row["criteria"]})
            except Exception as e:
-                st.error(f"Error creating checkpoint: {e}")
-        evaluation_case.checkpoints = new_ckpts
+                st.error(f"Error creating criterion: {e}")
 
    return UserInputs(
        model_id=model_id,
@@ -147,6 +134,7 @@ def get_user_inputs() -> UserInputs:
        max_driving_hours=max_driving_hours,
        date=date,
        framework=framework,
-        evaluation_case=evaluation_case,
+        evaluation_model=evaluation_model_id,
+        evaluation_criteria=new_criteria,
        run_evaluation=st.checkbox("Run Evaluation", value=True),
    )
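
Note: the criteria round-trip above (default criteria → editable DataFrame → list of dicts) can be sanity-checked in isolation. A small self-contained sketch, with a single made-up criterion standing in for `DEFAULT_EVALUATION_CRITERIA`:

```python
import pandas as pd

# Stand-in for constants.DEFAULT_EVALUATION_CRITERIA (same {"criteria": ...} shape).
default_criteria = [
    {"criteria": "Check if the agent considered at least three surf spot options"},
]

# In the app this DataFrame is passed through st.data_editor for editing.
criteria_df = pd.DataFrame(default_criteria)

# Convert the (possibly edited) rows back to the list-of-dicts shape stored on
# UserInputs, skipping empty rows exactly as get_user_inputs() does.
new_criteria = [
    {"criteria": row["criteria"]}
    for _, row in criteria_df.iterrows()
    if row["criteria"] != ""
]
assert new_criteria == default_criteria
```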
components/sidebar.py CHANGED
@@ -1,9 +1,9 @@
-from components.inputs import UserInputs, get_user_inputs
 import streamlit as st
 
+from components.inputs import UserInputs, get_user_inputs
+
 
 def ssf_sidebar() -> UserInputs:
     st.markdown("### Configuration")
     st.markdown("Built using [Any-Agent](https://github.com/mozilla-ai/any-agent)")
-    user_inputs = get_user_inputs()
-    return user_inputs
+    return get_user_inputs()
config.py ADDED
@@ -0,0 +1,43 @@
+import os
+import tempfile
+from datetime import datetime, timedelta
+from typing import Annotated
+
+import geocoder
+from pydantic import AfterValidator, BaseModel, ConfigDict, FutureDatetime, PositiveInt
+from rich.prompt import Prompt
+
+from any_agent import AgentFramework
+from any_agent.config import AgentConfig
+from any_agent.logging import logger
+
+INPUT_PROMPT_TEMPLATE = """
+According to the forecast, what will be the best spot to surf around {LOCATION},
+in a {MAX_DRIVING_HOURS} hour driving radius,
+at {DATE}?"
+""".strip()
+
+
+def validate_prompt(value) -> str:
+    for placeholder in ("{LOCATION}", "{MAX_DRIVING_HOURS}", "{DATE}"):
+        if placeholder not in value:
+            raise ValueError(f"prompt must contain {placeholder}")
+    return value
+
+
+class Config(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+    location: str
+    max_driving_hours: PositiveInt
+    date: FutureDatetime
+    input_prompt_template: Annotated[str, AfterValidator(validate_prompt)] = (
+        INPUT_PROMPT_TEMPLATE
+    )
+
+    framework: AgentFramework
+
+    main_agent: AgentConfig
+
+    evaluation_model: str | None = None
+    evaluation_criteria: list[dict[str, str]] | None = None
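
Note: a hedged sketch of building a `Config` the way `configure_agent()` in `services/agent.py` does. The location, model ids, and the `AgentFramework.OPENAI` member are assumed example values, not taken from the commit.

```python
from datetime import datetime, timedelta

from any_agent import AgentFramework
from any_agent.config import AgentConfig

from config import Config

config = Config(
    location="San Sebastian",                 # hypothetical user input
    max_driving_hours=2,
    date=datetime.now() + timedelta(days=1),  # must satisfy FutureDatetime
    framework=AgentFramework.OPENAI,          # assumed enum member
    main_agent=AgentConfig(model_id="openai/gpt-4.1-mini"),
    evaluation_model="openai/gpt-4.1-nano",
    evaluation_criteria=[
        {"criteria": "Check if the agent considered at least three surf spot options"},
    ],
)

# The AfterValidator on input_prompt_template rejects templates that drop a
# placeholder, e.g. omitting {MAX_DRIVING_HOURS} raises a ValueError.
query = config.input_prompt_template.format(
    LOCATION=config.location,
    MAX_DRIVING_HOURS=config.max_driving_hours,
    DATE=config.date,
)
```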
constants.py CHANGED
@@ -1,65 +1,47 @@
 import os
 
-from any_agent.evaluation import EvaluationCase
-from surf_spot_finder.tools import (
+from tools import (
     get_area_lat_lon,
     get_wave_forecast,
     get_wind_forecast,
 )
+
 from any_agent.logging import logger
-from any_agent.tools.web_browsing import search_web, visit_webpage, search_tavily
+from any_agent.tools.web_browsing import search_tavily, search_web, visit_webpage
 
 MODEL_OPTIONS = [
-    # "huggingface/novita/deepseek-ai/DeepSeek-V3",
-    # "huggingface/novita/meta-llama/Llama-3.3-70B-Instruct",
     "openai/gpt-4.1-nano",
     "openai/gpt-4.1-mini",
     "openai/gpt-4o",
     "gemini/gemini-2.0-flash-lite",
     "gemini/gemini-2.0-flash",
-    # "huggingface/Qwen/Qwen3-32B",  # right now throwing an internal error, but novita qwen isn't supporting tool calling
 ]
 
-# Novita was the only HF based provider that worked.
-
-# Hugginface API Provider Error:
-# Must alternate between assistant/user, which meant that the 'tool' role made it puke
-
-
-DEFAULT_EVALUATION_CASE = EvaluationCase(
-    llm_judge=MODEL_OPTIONS[0],
-    checkpoints=[
-        {
-            "criteria": "Check if the agent considered at least three surf spot options",
-            "points": 1,
-        },
-        {
-            "criteria": "Check if the agent gathered wind forecasts for each surf spot being evaluated.",
-            "points": 1,
-        },
-        {
-            "criteria": "Check if the agent gathered wave forecasts for each surf spot being evaluated.",
-            "points": 1,
-        },
-        {
-            "criteria": "Check if the agent used any web search tools to explore which surf spots should be considered",
-            "points": 1,
-        },
-        {
-            "criteria": "Check if the final answer contains any description about the weather (air temp, chance of rain, etc) at the chosen location",
-            "points": 1,
-        },
-        {
-            "criteria": "Check if the final answer includes one of the surf spots evaluated by tools",
-            "points": 1,
-        },
-        {
-            "criteria": "Check if the final answer includes information about some alternative surf spots if the user is not satisfied with the chosen one",
-            "points": 1,
-        },
-    ],
-)
-
+DEFAULT_EVALUATION_MODEL = MODEL_OPTIONS[0]
+
+DEFAULT_EVALUATION_CRITERIA = [
+    {
+        "criteria": "Check if the agent considered at least three surf spot options",
+    },
+    {
+        "criteria": "Check if the agent gathered wind forecasts for each surf spot being evaluated.",
+    },
+    {
+        "criteria": "Check if the agent gathered wave forecasts for each surf spot being evaluated.",
+    },
+    {
+        "criteria": "Check if the agent used any web search tools to explore which surf spots should be considered",
+    },
+    {
+        "criteria": "Check if the final answer contains any description about the weather (air temp, chance of rain, etc) at the chosen location",
+    },
+    {
+        "criteria": "Check if the final answer includes one of the surf spots evaluated by tools",
+    },
+    {
+        "criteria": "Check if the final answer includes information about some alternative surf spots if the user is not satisfied with the chosen one",
+    },
+]
 
 DEFAULT_TOOLS = [
     get_wind_forecast,
requirements.txt CHANGED
@@ -1,5 +1,4 @@
-streamlit
-openai-agents>=0.0.14
-any-agent[all]==0.15.0
-surf-spot-finder @ git+https://github.com/mozilla-ai/surf-spot-finder
+any-agent[all]>=1.4.0
+geocoder
 nest_asyncio
+streamlit
services/agent.py CHANGED
@@ -1,63 +1,65 @@
 import json
+
+import streamlit as st
+from components.agent_status import export_logs
 from components.inputs import UserInputs
 from constants import DEFAULT_TOOLS
-from components.agent_status import export_logs
-import streamlit as st
-from surf_spot_finder.config import Config
-from any_agent import AgentConfig, AnyAgent, TracingConfig, AgentFramework
-from any_agent.tracing.trace import AgentTrace, AgentSpan
-from any_agent.tracing.otel_types import StatusCode
-from any_agent.evaluation import evaluate, TraceEvaluationResult
+from config import Config
 
+from any_agent import AgentConfig, AgentFramework, AnyAgent
+from any_agent.evaluation import LlmJudge
+from any_agent.evaluation.schemas import EvaluationOutput
+from any_agent.tracing.agent_trace import AgentTrace
+from any_agent.tracing.attributes import GenAI
+from any_agent.tracing.otel_types import StatusCode
+
 
-async def display_evaluation_results(result: TraceEvaluationResult):
-    if result.ground_truth_result is not None:
-        all_results = [*result.checkpoint_results, result.ground_truth_result]
-    else:
-        all_results = result.checkpoint_results
-
-    # Create columns for better layout
+async def display_evaluation_results(results: list[EvaluationOutput]):
     col1, col2 = st.columns(2)
 
     with col1:
         st.markdown("#### Criteria Results")
-        for checkpoint in all_results:
-            if checkpoint.passed:
-                st.success(f"✅ {checkpoint.criteria}")
+        for i, result in enumerate(results):
+            if result.passed:
+                st.success(f"✅ Criterion {i + 1}")
             else:
-                st.error(f"❌ {checkpoint.criteria}")
+                st.error(f"❌ Criterion {i + 1}")
+            st.write(f"**Reasoning:** {result.reasoning}")
 
     with col2:
         st.markdown("#### Overall Score")
-        total_points = sum([result.points for result in all_results])
-        if total_points == 0:
-            msg = "Total points is 0, cannot calculate score."
-            raise ValueError(msg)
-        passed_points = sum([result.points for result in all_results if result.passed])
-
-        # Create a nice score display
-        st.markdown(f"### {passed_points}/{total_points}")
-        percentage = (passed_points / total_points) * 100
+        total_criteria = len(results)
+        passed_criteria = sum(1 for result in results if result.passed)
+
+        st.markdown(f"### {passed_criteria}/{total_criteria}")
+        percentage = (
+            (passed_criteria / total_criteria) * 100 if total_criteria > 0 else 0
+        )
         st.progress(percentage / 100)
         st.markdown(f"**{percentage:.1f}%**")
 
 
 async def evaluate_agent(
     config: Config, agent_trace: AgentTrace
-) -> TraceEvaluationResult:
-    assert (
-        len(config.evaluation_cases) == 1
-    ), "Only one evaluation case is supported in the demo"
+) -> list[EvaluationOutput]:
     st.markdown("### 📊 Evaluation Results")
 
     with st.spinner("Evaluating results..."):
-        case = config.evaluation_cases[0]
-        result: TraceEvaluationResult = evaluate(
-            evaluation_case=case,
-            trace=agent_trace,
-            agent_framework=config.framework,
-        )
-        return result
+        results = []
+
+        judge = LlmJudge(model_id=config.evaluation_model, framework=config.framework)
+
+        for i, criterion in enumerate(config.evaluation_criteria):
+            context = f"Agent Trace:\n{agent_trace.model_dump_json(indent=2)}"
+
+            result = await judge.run_async(
+                context=context, question=criterion["criteria"]
+            )
+            results.append(result)
+
+            st.write(f"Evaluated criterion {i + 1}/{len(config.evaluation_criteria)}")
+
+        return results
 
 
 async def configure_agent(user_inputs: UserInputs) -> tuple[AnyAgent, Config]:
@@ -87,47 +89,52 @@ async def configure_agent(user_inputs: UserInputs) -> tuple[AnyAgent, Config]:
         date=user_inputs.date,
         framework=user_inputs.framework,
         main_agent=agent_config,
-        managed_agents=[],
-        evaluation_cases=[user_inputs.evaluation_case],
+        evaluation_model=user_inputs.evaluation_model,
+        evaluation_criteria=user_inputs.evaluation_criteria,
     )
 
     agent = await AnyAgent.create_async(
         agent_framework=config.framework,
         agent_config=config.main_agent,
-        managed_agents=config.managed_agents,
-        tracing=TracingConfig(console=True, cost_info=True),
     )
     return agent, config
 
 
 async def display_output(agent_trace: AgentTrace):
-    # Display the agent trace in a more organized way
     with st.expander("### 🧩 Agent Trace"):
         for span in agent_trace.spans:
-            # Header with name and status
             col1, col2 = st.columns([4, 1])
             with col1:
                 st.markdown(f"**{span.name}**")
                 if span.attributes:
-                    # st.json(span.attributes, expanded=False)
-                    if "input.value" in span.attributes:
+                    if GenAI.INPUT_MESSAGES in span.attributes:
                         try:
-                            input_value = json.loads(span.attributes["input.value"])
+                            input_value = json.loads(
+                                span.attributes[GenAI.INPUT_MESSAGES]
+                            )
                             if isinstance(input_value, list) and len(input_value) > 0:
                                 st.write(f"Input: {input_value[-1]}")
                             else:
                                 st.write(f"Input: {input_value}")
-                        except Exception:  # noqa: E722
-                            st.write(f"Input: {span.attributes['input.value']}")
-                    if "output.value" in span.attributes:
+                        except Exception:
+                            st.write(f"Input: {span.attributes[GenAI.INPUT_MESSAGES]}")
+
+                    if GenAI.TOOL_ARGS in span.attributes:
                         try:
-                            output_value = json.loads(span.attributes["output.value"])
+                            tool_args = json.loads(span.attributes[GenAI.TOOL_ARGS])
+                            st.write(f"Tool Args: {tool_args}")
+                        except Exception:
+                            st.write(f"Tool Args: {span.attributes[GenAI.TOOL_ARGS]}")
+
+                    if GenAI.OUTPUT in span.attributes:
+                        try:
+                            output_value = json.loads(span.attributes[GenAI.OUTPUT])
                             if isinstance(output_value, list) and len(output_value) > 0:
                                 st.write(f"Output: {output_value[-1]}")
                             else:
                                 st.write(f"Output: {output_value}")
-                        except Exception:  # noqa: E722
-                            st.write(f"Output: {span.attributes['output.value']}")
+                        except Exception:
+                            st.write(f"Output: {span.attributes[GenAI.OUTPUT]}")
             with col2:
                 status_color = (
                     "green" if span.status.status_code == StatusCode.OK else "red"
@@ -145,7 +152,7 @@ async def display_output(agent_trace: AgentTrace):
     with cost_col:
         st.info(f"💰 Estimated Cost: ${agent_trace.cost.total_cost:.6f}")
     with tokens_col:
-        st.info(f"📦 Total Tokens: {agent_trace.usage.total_tokens:,}")
+        st.info(f"📦 Total Tokens: {agent_trace.tokens.total_tokens:,}")
     st.markdown("#### Final Output")
     st.info(agent_trace.final_output)
 
@@ -179,49 +186,10 @@ async def run_agent(agent, config) -> AgentTrace:
 
     with st.status("Agent is running...", expanded=False, state="running") as status:
 
-        def update_span(span: AgentSpan):
-            # Process input value
-            input_value = span.attributes.get("input.value", "")
-            if input_value:
-                try:
-                    parsed_input = json.loads(input_value)
-                    if isinstance(parsed_input, list) and len(parsed_input) > 0:
-                        input_value = str(parsed_input[-1])
-                except Exception:
-                    pass
-
-            # Process output value
-            output_value = span.attributes.get("output.value", "")
-            if output_value:
-                try:
-                    parsed_output = json.loads(output_value)
-                    if isinstance(parsed_output, list) and len(parsed_output) > 0:
-                        output_value = str(parsed_output[-1])
-                except Exception:
-                    pass
-
-            # Truncate long values
-            max_length = 800
-            if len(input_value) > max_length:
-                input_value = f"[Truncated]...{input_value[-max_length:]}"
-            if len(output_value) > max_length:
-                output_value = f"[Truncated]...{output_value[-max_length:]}"
-
-            # Create a cleaner message format
-            if input_value or output_value:
-                message = f"Step: {span.name}\n"
-                if input_value:
-                    message += f"Input: {input_value}\n"
-                if output_value:
-                    message += f"Output: {output_value}"
-            else:
-                message = f"Step: {span.name}\n{span}"
-
+        def update_status(message: str):
             status.update(label=message, expanded=False, state="running")
 
-        export_logs(agent, update_span)
+        export_logs(agent, update_status)
         agent_trace: AgentTrace = await agent.run_async(query, **kwargs)
         status.update(label="Finished!", expanded=False, state="complete")
-
-        agent.exit()
     return agent_trace
tools/__init__.py ADDED
@@ -0,0 +1,9 @@
+from .openmeteo import get_wave_forecast, get_wind_forecast
+from .openstreetmap import driving_hours_to_meters, get_area_lat_lon
+
+__all__ = [
+    "driving_hours_to_meters",
+    "get_area_lat_lon",
+    "get_wave_forecast",
+    "get_wind_forecast",
+]
tools/openmeteo.py ADDED
@@ -0,0 +1,117 @@
+import json
+from datetime import datetime, timedelta
+
+import requests
+
+
+def _extract_hourly_data(data: dict) -> list[dict]:
+    hourly_data = data["hourly"]
+    result = [
+        {k: v for k, v in zip(hourly_data.keys(), values, strict=False)}
+        for values in zip(*hourly_data.values(), strict=False)
+    ]
+    return result
+
+
+def _filter_by_date(
+    date: datetime, hourly_data: list[dict], timedelta: timedelta = timedelta(hours=1)
+):
+    start_date = date - timedelta
+    end_date = date + timedelta
+    return [
+        item
+        for item in hourly_data
+        if start_date <= datetime.fromisoformat(item["time"]) <= end_date
+    ]
+
+
+def get_wave_forecast(lat: float, lon: float, date: str) -> list[dict]:
+    """Get wave forecast for given location.
+
+    Forecast will include:
+
+    - wave_direction (degrees)
+    - wave_height (meters)
+    - wave_period (seconds)
+    - sea_level_height_msl (meters)
+
+    Args:
+        lat: Latitude of the location.
+        lon: Longitude of the location.
+        date: Date to filter by in any valid ISO 8601 format.
+
+    Returns:
+        Hourly data for wave forecast.
+        Example output:
+
+        ```json
+        [
+            {'time': '2025-03-19T09:00', 'winddirection_10m': 140, 'windspeed_10m': 24.5}, {'time': '2025-03-19T10:00', 'winddirection_10m': 140, 'windspeed_10m': 27.1},
+            {'time': '2025-03-19T10:00', 'winddirection_10m': 140, 'windspeed_10m': 27.1}, {'time': '2025-03-19T11:00', 'winddirection_10m': 141, 'windspeed_10m': 29.2}
+        ]
+        ```
+
+    """
+    url = "https://marine-api.open-meteo.com/v1/marine"
+    params = {
+        "latitude": lat,
+        "longitude": lon,
+        "hourly": [
+            "wave_direction",
+            "wave_height",
+            "wave_period",
+            "sea_level_height_msl",
+        ],
+    }
+    response = requests.get(url, params=params)
+    response.raise_for_status()
+    data = json.loads(response.content.decode())
+    hourly_data = _extract_hourly_data(data)
+    if date is not None:
+        date = datetime.fromisoformat(date)
+        hourly_data = _filter_by_date(date, hourly_data)
+    if len(hourly_data) == 0:
+        raise ValueError("No data found for the given date")
+    return hourly_data
+
+
+def get_wind_forecast(lat: float, lon: float, date: str) -> list[dict]:
+    """Get wind forecast for given location.
+
+    Forecast will include:
+
+    - wind_direction (degrees)
+    - wind_speed (meters per second)
+
+    Args:
+        lat: Latitude of the location.
+        lon: Longitude of the location.
+        date: Date to filter by in any valid ISO 8601 format.
+
+    Returns:
+        Hourly data for wind forecast.
+        Example output:
+
+        ```json
+        [
+            {"time": "2025-03-18T22:00", "wind_direction": 196, "wind_speed": 9.6},
+            {"time": "2025-03-18T23:00", "wind_direction": 183, "wind_speed": 7.9},
+        ]
+        ```
+
+    """
+    url = "https://api.open-meteo.com/v1/forecast"
+    params = {
+        "latitude": lat,
+        "longitude": lon,
+        "hourly": ["winddirection_10m", "windspeed_10m"],
+    }
+    response = requests.get(url, params=params)
+    response.raise_for_status()
+    data = json.loads(response.content.decode())
+    hourly_data = _extract_hourly_data(data)
+    date = datetime.fromisoformat(date)
+    hourly_data = _filter_by_date(date, hourly_data)
+    if len(hourly_data) == 0:
+        raise ValueError("No data found for the given date")
+    return hourly_data
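
Note: a short usage sketch for the two forecast tools. Both hit the public Open-Meteo APIs; the spot name and the keys shown in the comments are illustrative.

```python
from datetime import datetime, timedelta

from tools.openmeteo import get_wave_forecast, get_wind_forecast
from tools.openstreetmap import get_area_lat_lon

# Nominatim returns lat/lon as strings, so cast before calling the forecasts.
lat, lon = get_area_lat_lon("Mundaka")  # hypothetical surf spot
tomorrow = (datetime.now() + timedelta(days=1)).isoformat()

waves = get_wave_forecast(lat=float(lat), lon=float(lon), date=tomorrow)
wind = get_wind_forecast(lat=float(lat), lon=float(lon), date=tomorrow)

# Each entry covers one hour within ±1 hour of the requested date.
print(waves[0])  # e.g. {"time": ..., "wave_direction": ..., "wave_height": ..., ...}
print(wind[0])   # e.g. {"time": ..., "winddirection_10m": ..., "windspeed_10m": ...}
```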
tools/openstreetmap.py ADDED
@@ -0,0 +1,62 @@
+import json
+
+import requests
+
+
+def get_area_lat_lon(area_name: str) -> tuple[float, float]:
+    """Get the latitude and longitude of an area from Nominatim.
+
+    Uses the [Nominatim API](https://nominatim.org/release-docs/develop/api/Search/).
+
+    Args:
+        area_name: The name of the area.
+
+    Returns:
+        The area found.
+
+    """
+    response = requests.get(
+        f"https://nominatim.openstreetmap.org/search?q={area_name}&format=jsonv2",
+        headers={"User-Agent": "Mozilla/5.0"},
+    )
+    response.raise_for_status()
+    area = json.loads(response.content.decode())
+    return area[0]["lat"], area[0]["lon"]
+
+
+def driving_hours_to_meters(driving_hours: int) -> int:
+    """Convert driving hours to meters assuming a 70 km/h average speed.
+
+    Args:
+        driving_hours: The driving hours.
+
+    Returns:
+        The distance in meters.
+
+    """
+    return driving_hours * 70 * 1000
+
+
+def get_lat_lon_center(bounds: dict) -> tuple[float, float]:
+    """Get the latitude and longitude of the center of a bounding box.
+
+    Args:
+        bounds: The bounding box.
+
+        ```json
+        {
+            "minlat": float,
+            "minlon": float,
+            "maxlat": float,
+            "maxlon": float,
+        }
+        ```
+
+    Returns:
+        The latitude and longitude of the center.
+
+    """
+    return (
+        (bounds["minlat"] + bounds["maxlat"]) / 2,
+        (bounds["minlon"] + bounds["maxlon"]) / 2,
+    )
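
Note: the two pure helpers can be checked without any network access; the bounding box below is made up.

```python
from tools.openstreetmap import driving_hours_to_meters, get_lat_lon_center

# 2 driving hours at the assumed 70 km/h average speed -> a 140 km radius.
assert driving_hours_to_meters(2) == 140_000

bounds = {"minlat": 43.0, "minlon": -2.0, "maxlat": 43.5, "maxlon": -1.5}
assert get_lat_lon_center(bounds) == (43.25, -1.75)
```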