Spaces:
Running
Running
github-actions[bot]
commited on
Commit
Β·
7e103cf
1
Parent(s):
d78bdc9
Sync with https://github.com/mozilla-ai/any-agent-demo
Browse files- README.md +0 -7
- app.py +7 -14
- components/agent_status.py +70 -35
- components/inputs.py +34 -46
- components/sidebar.py +3 -3
- config.py +43 -0
- constants.py +28 -46
- requirements.txt +3 -4
- services/agent.py +62 -94
- tools/__init__.py +9 -0
- tools/openmeteo.py +117 -0
- tools/openstreetmap.py +62 -0
README.md
CHANGED
@@ -11,10 +11,3 @@ pinned: false
|
|
11 |
short_description: Find a surf spot near you
|
12 |
license: apache-2.0
|
13 |
---
|
14 |
-
|
15 |
-
# Welcome to Streamlit!
|
16 |
-
|
17 |
-
Edit `/src/app.py` to customize this app to your heart's desire. :heart:
|
18 |
-
|
19 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
20 |
-
forums](https://discuss.streamlit.io).
|
|
|
11 |
short_description: Find a surf spot near you
|
12 |
license: apache-2.0
|
13 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
-
from components.sidebar import ssf_sidebar
|
2 |
-
from constants import DEFAULT_TOOLS
|
3 |
-
import streamlit as st
|
4 |
import asyncio
|
|
|
5 |
import nest_asyncio
|
|
|
|
|
|
|
6 |
from services.agent import (
|
7 |
configure_agent,
|
8 |
display_evaluation_results,
|
@@ -13,14 +14,11 @@ from services.agent import (
|
|
13 |
|
14 |
nest_asyncio.apply()
|
15 |
|
16 |
-
# Set page config
|
17 |
st.set_page_config(page_title="Surf Spot Finder", page_icon="π", layout="wide")
|
18 |
|
19 |
-
# Allow a user to resize the sidebar to take up most of the screen to make editing eval cases easier
|
20 |
st.markdown(
|
21 |
"""
|
22 |
<style>
|
23 |
-
/* When sidebar is expanded, adjust main content */
|
24 |
section[data-testid="stSidebar"][aria-expanded="true"] {
|
25 |
max-width: 99% !important;
|
26 |
}
|
@@ -35,18 +33,16 @@ with st.sidebar:
|
|
35 |
run_button = st.button("Run Agent π€", disabled=not is_valid, type="primary")
|
36 |
|
37 |
|
38 |
-
# Main content
|
39 |
async def main():
|
40 |
-
# Handle agent execution button click
|
41 |
if run_button:
|
42 |
agent, agent_config = await configure_agent(user_inputs)
|
43 |
agent_trace = await run_agent(agent, agent_config)
|
44 |
|
45 |
await display_output(agent_trace)
|
46 |
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
else:
|
51 |
st.title("π Surf Spot Finder")
|
52 |
st.markdown(
|
@@ -56,7 +52,6 @@ async def main():
|
|
56 |
"π Configure your search parameters in the sidebar and click Run to start!"
|
57 |
)
|
58 |
|
59 |
-
# Display tools in a more organized way
|
60 |
st.markdown("### π οΈ Available Tools")
|
61 |
|
62 |
st.markdown("""
|
@@ -92,7 +87,6 @@ async def main():
|
|
92 |
with st.expander(f"π {tool.__name__}"):
|
93 |
st.markdown(tool.__doc__ or "No description available")
|
94 |
|
95 |
-
# add a check that all tools were listed
|
96 |
if len(weather_tools) + len(location_tools) + len(web_tools) != len(
|
97 |
DEFAULT_TOOLS
|
98 |
):
|
@@ -100,7 +94,6 @@ async def main():
|
|
100 |
"Some tools are not listed. Please check the code for more details."
|
101 |
)
|
102 |
|
103 |
-
# Add Custom Evaluation explanation section
|
104 |
st.markdown("### π Custom Evaluation")
|
105 |
st.markdown("""
|
106 |
The Surf Spot Finder includes a powerful evaluation system that allows you to customize how the agent's performance is assessed.
|
|
|
|
|
|
|
|
|
1 |
import asyncio
|
2 |
+
|
3 |
import nest_asyncio
|
4 |
+
import streamlit as st
|
5 |
+
from components.sidebar import ssf_sidebar
|
6 |
+
from constants import DEFAULT_TOOLS
|
7 |
from services.agent import (
|
8 |
configure_agent,
|
9 |
display_evaluation_results,
|
|
|
14 |
|
15 |
nest_asyncio.apply()
|
16 |
|
|
|
17 |
st.set_page_config(page_title="Surf Spot Finder", page_icon="π", layout="wide")
|
18 |
|
|
|
19 |
st.markdown(
|
20 |
"""
|
21 |
<style>
|
|
|
22 |
section[data-testid="stSidebar"][aria-expanded="true"] {
|
23 |
max-width: 99% !important;
|
24 |
}
|
|
|
33 |
run_button = st.button("Run Agent π€", disabled=not is_valid, type="primary")
|
34 |
|
35 |
|
|
|
36 |
async def main():
|
|
|
37 |
if run_button:
|
38 |
agent, agent_config = await configure_agent(user_inputs)
|
39 |
agent_trace = await run_agent(agent, agent_config)
|
40 |
|
41 |
await display_output(agent_trace)
|
42 |
|
43 |
+
if user_inputs.run_evaluation:
|
44 |
+
evaluation_results = await evaluate_agent(agent_config, agent_trace)
|
45 |
+
await display_evaluation_results(evaluation_results)
|
46 |
else:
|
47 |
st.title("π Surf Spot Finder")
|
48 |
st.markdown(
|
|
|
52 |
"π Configure your search parameters in the sidebar and click Run to start!"
|
53 |
)
|
54 |
|
|
|
55 |
st.markdown("### π οΈ Available Tools")
|
56 |
|
57 |
st.markdown("""
|
|
|
87 |
with st.expander(f"π {tool.__name__}"):
|
88 |
st.markdown(tool.__doc__ or "No description available")
|
89 |
|
|
|
90 |
if len(weather_tools) + len(location_tools) + len(web_tools) != len(
|
91 |
DEFAULT_TOOLS
|
92 |
):
|
|
|
94 |
"Some tools are not listed. Please check the code for more details."
|
95 |
)
|
96 |
|
|
|
97 |
st.markdown("### π Custom Evaluation")
|
98 |
st.markdown("""
|
99 |
The Surf Spot Finder includes a powerful evaluation system that allows you to customize how the agent's performance is assessed.
|
components/agent_status.py
CHANGED
@@ -1,47 +1,82 @@
|
|
1 |
-
from
|
2 |
-
from
|
3 |
-
from collections.abc import Sequence
|
4 |
-
from typing import TYPE_CHECKING, Callable
|
5 |
|
6 |
-
from
|
7 |
-
|
8 |
-
SpanExportResult,
|
9 |
-
)
|
10 |
|
11 |
-
from any_agent import AgentFramework
|
12 |
|
13 |
-
|
14 |
-
|
15 |
|
16 |
-
|
17 |
-
|
18 |
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
-
|
21 |
-
|
22 |
|
23 |
-
def
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
agent_framework
|
29 |
-
)
|
30 |
-
self.callback = callback
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
return SpanExportResult.SUCCESS
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
40 |
|
41 |
-
|
|
|
|
|
|
|
|
|
42 |
|
|
|
|
|
|
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections.abc import Callable
|
2 |
+
from typing import Any
|
|
|
|
|
3 |
|
4 |
+
from any_agent.callbacks import Callback, Context
|
5 |
+
from any_agent.tracing.attributes import GenAI
|
|
|
|
|
6 |
|
|
|
7 |
|
8 |
+
class StreamlitStatusCallback(Callback):
|
9 |
+
"""Callback to update Streamlit status with agent progress."""
|
10 |
|
11 |
+
def __init__(self, status_callback: Callable[[str], None]):
|
12 |
+
self.status_callback = status_callback
|
13 |
|
14 |
+
def after_llm_call(self, context: Context, *args, **kwargs) -> Context:
|
15 |
+
"""Update status after LLM calls."""
|
16 |
+
span = context.current_span
|
17 |
+
input_value = span.attributes.get(GenAI.INPUT_MESSAGES, "")
|
18 |
+
output_value = span.attributes.get(GenAI.OUTPUT, "")
|
19 |
|
20 |
+
self._update_status(span.name, input_value, output_value)
|
21 |
+
return context
|
22 |
|
23 |
+
def after_tool_execution(self, context: Context, *args, **kwargs) -> Context:
|
24 |
+
"""Update status after tool executions."""
|
25 |
+
span = context.current_span
|
26 |
+
input_value = span.attributes.get(GenAI.TOOL_ARGS, "")
|
27 |
+
output_value = span.attributes.get(GenAI.OUTPUT, "")
|
|
|
|
|
|
|
28 |
|
29 |
+
self._update_status(span.name, input_value, output_value)
|
30 |
+
return context
|
|
|
31 |
|
32 |
+
def _update_status(self, step_name: str, input_value: str, output_value: str):
|
33 |
+
"""Update the Streamlit status with formatted information."""
|
34 |
+
if input_value:
|
35 |
+
try:
|
36 |
+
import json
|
37 |
|
38 |
+
parsed_input = json.loads(input_value)
|
39 |
+
if isinstance(parsed_input, list) and len(parsed_input) > 0:
|
40 |
+
input_value = str(parsed_input[-1])
|
41 |
+
except Exception:
|
42 |
+
pass
|
43 |
|
44 |
+
if output_value:
|
45 |
+
try:
|
46 |
+
import json
|
47 |
|
48 |
+
parsed_output = json.loads(output_value)
|
49 |
+
if isinstance(parsed_output, list) and len(parsed_output) > 0:
|
50 |
+
output_value = str(parsed_output[-1])
|
51 |
+
except Exception:
|
52 |
+
pass
|
53 |
+
|
54 |
+
max_length = 800
|
55 |
+
if len(input_value) > max_length:
|
56 |
+
input_value = f"[Truncated]...{input_value[-max_length:]}"
|
57 |
+
if len(output_value) > max_length:
|
58 |
+
output_value = f"[Truncated]...{output_value[-max_length:]}"
|
59 |
+
|
60 |
+
if input_value or output_value:
|
61 |
+
message = f"Step: {step_name}\n"
|
62 |
+
if input_value:
|
63 |
+
message += f"Input: {input_value}\n"
|
64 |
+
if output_value:
|
65 |
+
message += f"Output: {output_value}"
|
66 |
+
else:
|
67 |
+
message = f"Step: {step_name}"
|
68 |
+
|
69 |
+
self.status_callback(message)
|
70 |
+
|
71 |
+
|
72 |
+
def export_logs(agent: Any, callback: Callable[[str], None]) -> None:
|
73 |
+
"""Add a Streamlit status callback to the agent.
|
74 |
+
|
75 |
+
This function adds a custom callback to the agent that will update
|
76 |
+
the Streamlit status with progress information during agent execution.
|
77 |
+
"""
|
78 |
+
status_callback = StreamlitStatusCallback(callback)
|
79 |
+
|
80 |
+
if agent.config.callbacks is None:
|
81 |
+
agent.config.callbacks = []
|
82 |
+
agent.config.callbacks.append(status_callback)
|
components/inputs.py
CHANGED
@@ -1,17 +1,19 @@
|
|
1 |
-
|
2 |
import json
|
|
|
|
|
|
|
3 |
import requests
|
4 |
import streamlit as st
|
5 |
-
from
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
from constants import DEFAULT_EVALUATION_CASE, MODEL_OPTIONS
|
11 |
-
import copy
|
12 |
-
|
13 |
from pydantic import BaseModel, ConfigDict
|
14 |
|
|
|
|
|
15 |
|
16 |
class UserInputs(BaseModel):
|
17 |
model_config = ConfigDict(extra="forbid")
|
@@ -20,7 +22,8 @@ class UserInputs(BaseModel):
|
|
20 |
max_driving_hours: int
|
21 |
date: datetime
|
22 |
framework: str
|
23 |
-
|
|
|
24 |
run_evaluation: bool
|
25 |
|
26 |
|
@@ -35,6 +38,7 @@ def get_area(area_name: str) -> dict:
|
|
35 |
|
36 |
Returns:
|
37 |
dict: The area found.
|
|
|
38 |
"""
|
39 |
response = requests.get(
|
40 |
f"https://nominatim.openstreetmap.org/search?q={area_name}&format=jsonv2",
|
@@ -42,8 +46,7 @@ def get_area(area_name: str) -> dict:
|
|
42 |
timeout=5,
|
43 |
)
|
44 |
response.raise_for_status()
|
45 |
-
|
46 |
-
return response_json
|
47 |
|
48 |
|
49 |
def get_user_inputs() -> UserInputs:
|
@@ -65,7 +68,6 @@ def get_user_inputs() -> UserInputs:
|
|
65 |
"Select a date in the future", value=datetime.now() + timedelta(days=1)
|
66 |
)
|
67 |
with col_time:
|
68 |
-
# default to 9am
|
69 |
time = st.selectbox(
|
70 |
"Select a time",
|
71 |
[datetime.strptime(f"{i:02d}:00", "%H:%M").time() for i in range(24)],
|
@@ -73,9 +75,7 @@ def get_user_inputs() -> UserInputs:
|
|
73 |
)
|
74 |
date = datetime.combine(date, time)
|
75 |
|
76 |
-
supported_frameworks = [
|
77 |
-
framework for framework in AgentFramework if _is_tracing_supported(framework)
|
78 |
-
]
|
79 |
|
80 |
framework = st.selectbox(
|
81 |
"Select the agent framework to use",
|
@@ -91,7 +91,6 @@ def get_user_inputs() -> UserInputs:
|
|
91 |
format_func=lambda x: "/".join(x.split("/")[-3:]),
|
92 |
)
|
93 |
|
94 |
-
# Add evaluation case section
|
95 |
with st.expander("Custom Evaluation"):
|
96 |
evaluation_model_id = st.selectbox(
|
97 |
"Select the model to use for LLM-as-a-Judge evaluation",
|
@@ -99,47 +98,35 @@ def get_user_inputs() -> UserInputs:
|
|
99 |
index=2,
|
100 |
format_func=lambda x: "/".join(x.split("/")[-3:]),
|
101 |
)
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
[checkpoint.model_dump() for checkpoint in checkpoints]
|
109 |
-
)
|
110 |
-
checkpoints_df = st.data_editor(
|
111 |
-
checkpoints_df,
|
112 |
column_config={
|
113 |
-
"points": st.column_config.NumberColumn(label="Points"),
|
114 |
"criteria": st.column_config.TextColumn(label="Criteria"),
|
115 |
},
|
116 |
hide_index=True,
|
117 |
num_rows="dynamic",
|
118 |
)
|
119 |
-
# for each checkpoint, convert it back to a CheckpointCriteria object
|
120 |
-
new_ckpts = []
|
121 |
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
checkpoints_df = checkpoints_df[:20]
|
128 |
|
129 |
-
for _, row in
|
130 |
if row["criteria"] == "":
|
131 |
continue
|
132 |
try:
|
133 |
-
# Don't let people write essays for criteria in this demo
|
134 |
if len(row["criteria"].split(" ")) > 100:
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
)
|
139 |
-
new_ckpts.append(new_crit)
|
140 |
except Exception as e:
|
141 |
-
st.error(f"Error creating
|
142 |
-
evaluation_case.checkpoints = new_ckpts
|
143 |
|
144 |
return UserInputs(
|
145 |
model_id=model_id,
|
@@ -147,6 +134,7 @@ def get_user_inputs() -> UserInputs:
|
|
147 |
max_driving_hours=max_driving_hours,
|
148 |
date=date,
|
149 |
framework=framework,
|
150 |
-
|
|
|
151 |
run_evaluation=st.checkbox("Run Evaluation", value=True),
|
152 |
)
|
|
|
1 |
+
import copy
|
2 |
import json
|
3 |
+
from datetime import datetime, timedelta
|
4 |
+
|
5 |
+
import pandas as pd
|
6 |
import requests
|
7 |
import streamlit as st
|
8 |
+
from constants import (
|
9 |
+
DEFAULT_EVALUATION_CRITERIA,
|
10 |
+
DEFAULT_EVALUATION_MODEL,
|
11 |
+
MODEL_OPTIONS,
|
12 |
+
)
|
|
|
|
|
|
|
13 |
from pydantic import BaseModel, ConfigDict
|
14 |
|
15 |
+
from any_agent import AgentFramework
|
16 |
+
|
17 |
|
18 |
class UserInputs(BaseModel):
|
19 |
model_config = ConfigDict(extra="forbid")
|
|
|
22 |
max_driving_hours: int
|
23 |
date: datetime
|
24 |
framework: str
|
25 |
+
evaluation_model: str
|
26 |
+
evaluation_criteria: list[dict[str, str]]
|
27 |
run_evaluation: bool
|
28 |
|
29 |
|
|
|
38 |
|
39 |
Returns:
|
40 |
dict: The area found.
|
41 |
+
|
42 |
"""
|
43 |
response = requests.get(
|
44 |
f"https://nominatim.openstreetmap.org/search?q={area_name}&format=jsonv2",
|
|
|
46 |
timeout=5,
|
47 |
)
|
48 |
response.raise_for_status()
|
49 |
+
return json.loads(response.content.decode())
|
|
|
50 |
|
51 |
|
52 |
def get_user_inputs() -> UserInputs:
|
|
|
68 |
"Select a date in the future", value=datetime.now() + timedelta(days=1)
|
69 |
)
|
70 |
with col_time:
|
|
|
71 |
time = st.selectbox(
|
72 |
"Select a time",
|
73 |
[datetime.strptime(f"{i:02d}:00", "%H:%M").time() for i in range(24)],
|
|
|
75 |
)
|
76 |
date = datetime.combine(date, time)
|
77 |
|
78 |
+
supported_frameworks = [framework for framework in AgentFramework]
|
|
|
|
|
79 |
|
80 |
framework = st.selectbox(
|
81 |
"Select the agent framework to use",
|
|
|
91 |
format_func=lambda x: "/".join(x.split("/")[-3:]),
|
92 |
)
|
93 |
|
|
|
94 |
with st.expander("Custom Evaluation"):
|
95 |
evaluation_model_id = st.selectbox(
|
96 |
"Select the model to use for LLM-as-a-Judge evaluation",
|
|
|
98 |
index=2,
|
99 |
format_func=lambda x: "/".join(x.split("/")[-3:]),
|
100 |
)
|
101 |
+
|
102 |
+
evaluation_criteria = copy.deepcopy(DEFAULT_EVALUATION_CRITERIA)
|
103 |
+
|
104 |
+
criteria_df = pd.DataFrame(evaluation_criteria)
|
105 |
+
criteria_df = st.data_editor(
|
106 |
+
criteria_df,
|
|
|
|
|
|
|
|
|
107 |
column_config={
|
|
|
108 |
"criteria": st.column_config.TextColumn(label="Criteria"),
|
109 |
},
|
110 |
hide_index=True,
|
111 |
num_rows="dynamic",
|
112 |
)
|
|
|
|
|
113 |
|
114 |
+
new_criteria = []
|
115 |
+
|
116 |
+
if len(criteria_df) > 20:
|
117 |
+
st.error("You can only add up to 20 criteria for the purpose of this demo.")
|
118 |
+
criteria_df = criteria_df[:20]
|
|
|
119 |
|
120 |
+
for _, row in criteria_df.iterrows():
|
121 |
if row["criteria"] == "":
|
122 |
continue
|
123 |
try:
|
|
|
124 |
if len(row["criteria"].split(" ")) > 100:
|
125 |
+
msg = "Criteria is too long"
|
126 |
+
raise ValueError(msg)
|
127 |
+
new_criteria.append({"criteria": row["criteria"]})
|
|
|
|
|
128 |
except Exception as e:
|
129 |
+
st.error(f"Error creating criterion: {e}")
|
|
|
130 |
|
131 |
return UserInputs(
|
132 |
model_id=model_id,
|
|
|
134 |
max_driving_hours=max_driving_hours,
|
135 |
date=date,
|
136 |
framework=framework,
|
137 |
+
evaluation_model=evaluation_model_id,
|
138 |
+
evaluation_criteria=new_criteria,
|
139 |
run_evaluation=st.checkbox("Run Evaluation", value=True),
|
140 |
)
|
components/sidebar.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
-
from components.inputs import UserInputs, get_user_inputs
|
2 |
import streamlit as st
|
3 |
|
|
|
|
|
4 |
|
5 |
def ssf_sidebar() -> UserInputs:
|
6 |
st.markdown("### Configuration")
|
7 |
st.markdown("Built using [Any-Agent](https://github.com/mozilla-ai/any-agent)")
|
8 |
-
|
9 |
-
return user_inputs
|
|
|
|
|
1 |
import streamlit as st
|
2 |
|
3 |
+
from components.inputs import UserInputs, get_user_inputs
|
4 |
+
|
5 |
|
6 |
def ssf_sidebar() -> UserInputs:
|
7 |
st.markdown("### Configuration")
|
8 |
st.markdown("Built using [Any-Agent](https://github.com/mozilla-ai/any-agent)")
|
9 |
+
return get_user_inputs()
|
|
config.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
from datetime import datetime, timedelta
|
4 |
+
from typing import Annotated
|
5 |
+
|
6 |
+
import geocoder
|
7 |
+
from pydantic import AfterValidator, BaseModel, ConfigDict, FutureDatetime, PositiveInt
|
8 |
+
from rich.prompt import Prompt
|
9 |
+
|
10 |
+
from any_agent import AgentFramework
|
11 |
+
from any_agent.config import AgentConfig
|
12 |
+
from any_agent.logging import logger
|
13 |
+
|
14 |
+
INPUT_PROMPT_TEMPLATE = """
|
15 |
+
According to the forecast, what will be the best spot to surf around {LOCATION},
|
16 |
+
in a {MAX_DRIVING_HOURS} hour driving radius,
|
17 |
+
at {DATE}?"
|
18 |
+
""".strip()
|
19 |
+
|
20 |
+
|
21 |
+
def validate_prompt(value) -> str:
|
22 |
+
for placeholder in ("{LOCATION}", "{MAX_DRIVING_HOURS}", "{DATE}"):
|
23 |
+
if placeholder not in value:
|
24 |
+
raise ValueError(f"prompt must contain {placeholder}")
|
25 |
+
return value
|
26 |
+
|
27 |
+
|
28 |
+
class Config(BaseModel):
|
29 |
+
model_config = ConfigDict(extra="forbid")
|
30 |
+
|
31 |
+
location: str
|
32 |
+
max_driving_hours: PositiveInt
|
33 |
+
date: FutureDatetime
|
34 |
+
input_prompt_template: Annotated[str, AfterValidator(validate_prompt)] = (
|
35 |
+
INPUT_PROMPT_TEMPLATE
|
36 |
+
)
|
37 |
+
|
38 |
+
framework: AgentFramework
|
39 |
+
|
40 |
+
main_agent: AgentConfig
|
41 |
+
|
42 |
+
evaluation_model: str | None = None
|
43 |
+
evaluation_criteria: list[dict[str, str]] | None = None
|
constants.py
CHANGED
@@ -1,65 +1,47 @@
|
|
1 |
import os
|
2 |
|
3 |
-
from
|
4 |
-
from surf_spot_finder.tools import (
|
5 |
get_area_lat_lon,
|
6 |
get_wave_forecast,
|
7 |
get_wind_forecast,
|
8 |
)
|
|
|
9 |
from any_agent.logging import logger
|
10 |
-
from any_agent.tools.web_browsing import search_web, visit_webpage
|
11 |
|
12 |
MODEL_OPTIONS = [
|
13 |
-
# "huggingface/novita/deepseek-ai/DeepSeek-V3",
|
14 |
-
# "huggingface/novita/meta-llama/Llama-3.3-70B-Instruct",
|
15 |
"openai/gpt-4.1-nano",
|
16 |
"openai/gpt-4.1-mini",
|
17 |
"openai/gpt-4o",
|
18 |
"gemini/gemini-2.0-flash-lite",
|
19 |
"gemini/gemini-2.0-flash",
|
20 |
-
# "huggingface/Qwen/Qwen3-32B", # right now throwing an internal error, but novita qwen isn't supporting tool calling
|
21 |
]
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
{
|
49 |
-
"criteria": "Check if the final answer contains any description about the weather (air temp, chance of rain, etc) at the chosen location",
|
50 |
-
"points": 1,
|
51 |
-
},
|
52 |
-
{
|
53 |
-
"criteria": "Check if the final answer includes one of the surf spots evaluated by tools",
|
54 |
-
"points": 1,
|
55 |
-
},
|
56 |
-
{
|
57 |
-
"criteria": "Check if the final answer includes information about some alternative surf spots if the user is not satisfied with the chosen one",
|
58 |
-
"points": 1,
|
59 |
-
},
|
60 |
-
],
|
61 |
-
)
|
62 |
-
|
63 |
|
64 |
DEFAULT_TOOLS = [
|
65 |
get_wind_forecast,
|
|
|
1 |
import os
|
2 |
|
3 |
+
from tools import (
|
|
|
4 |
get_area_lat_lon,
|
5 |
get_wave_forecast,
|
6 |
get_wind_forecast,
|
7 |
)
|
8 |
+
|
9 |
from any_agent.logging import logger
|
10 |
+
from any_agent.tools.web_browsing import search_tavily, search_web, visit_webpage
|
11 |
|
12 |
MODEL_OPTIONS = [
|
|
|
|
|
13 |
"openai/gpt-4.1-nano",
|
14 |
"openai/gpt-4.1-mini",
|
15 |
"openai/gpt-4o",
|
16 |
"gemini/gemini-2.0-flash-lite",
|
17 |
"gemini/gemini-2.0-flash",
|
|
|
18 |
]
|
19 |
|
20 |
+
DEFAULT_EVALUATION_MODEL = MODEL_OPTIONS[0]
|
21 |
+
|
22 |
+
DEFAULT_EVALUATION_CRITERIA = [
|
23 |
+
{
|
24 |
+
"criteria": "Check if the agent considered at least three surf spot options",
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"criteria": "Check if the agent gathered wind forecasts for each surf spot being evaluated.",
|
28 |
+
},
|
29 |
+
{
|
30 |
+
"criteria": "Check if the agent gathered wave forecasts for each surf spot being evaluated.",
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"criteria": "Check if the agent used any web search tools to explore which surf spots should be considered",
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"criteria": "Check if the final answer contains any description about the weather (air temp, chance of rain, etc) at the chosen location",
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"criteria": "Check if the final answer includes one of the surf spots evaluated by tools",
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"criteria": "Check if the final answer includes information about some alternative surf spots if the user is not satisfied with the chosen one",
|
43 |
+
},
|
44 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
DEFAULT_TOOLS = [
|
47 |
get_wind_forecast,
|
requirements.txt
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
any-agent[all]==0.15.0
|
4 |
-
surf-spot-finder @ git+https://github.com/mozilla-ai/surf-spot-finder
|
5 |
nest_asyncio
|
|
|
|
1 |
+
any-agent[all]>=1.4.0
|
2 |
+
geocoder
|
|
|
|
|
3 |
nest_asyncio
|
4 |
+
streamlit
|
services/agent.py
CHANGED
@@ -1,63 +1,65 @@
|
|
1 |
import json
|
|
|
|
|
|
|
2 |
from components.inputs import UserInputs
|
3 |
from constants import DEFAULT_TOOLS
|
4 |
-
from
|
5 |
-
import streamlit as st
|
6 |
-
from surf_spot_finder.config import Config
|
7 |
-
from any_agent import AgentConfig, AnyAgent, TracingConfig, AgentFramework
|
8 |
-
from any_agent.tracing.trace import AgentTrace, AgentSpan
|
9 |
-
from any_agent.tracing.otel_types import StatusCode
|
10 |
-
from any_agent.evaluation import evaluate, TraceEvaluationResult
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
-
async def display_evaluation_results(result: TraceEvaluationResult):
|
14 |
-
if result.ground_truth_result is not None:
|
15 |
-
all_results = [*result.checkpoint_results, result.ground_truth_result]
|
16 |
-
else:
|
17 |
-
all_results = result.checkpoint_results
|
18 |
|
19 |
-
|
20 |
col1, col2 = st.columns(2)
|
21 |
|
22 |
with col1:
|
23 |
st.markdown("#### Criteria Results")
|
24 |
-
for
|
25 |
-
if
|
26 |
-
st.success(f"β
{
|
27 |
else:
|
28 |
-
st.error(f"β {
|
|
|
29 |
|
30 |
with col2:
|
31 |
st.markdown("#### Overall Score")
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
st.markdown(f"### {passed_points}/{total_points}")
|
40 |
-
percentage = (passed_points / total_points) * 100
|
41 |
st.progress(percentage / 100)
|
42 |
st.markdown(f"**{percentage:.1f}%**")
|
43 |
|
44 |
|
45 |
async def evaluate_agent(
|
46 |
config: Config, agent_trace: AgentTrace
|
47 |
-
) ->
|
48 |
-
assert (
|
49 |
-
len(config.evaluation_cases) == 1
|
50 |
-
), "Only one evaluation case is supported in the demo"
|
51 |
st.markdown("### π Evaluation Results")
|
52 |
|
53 |
with st.spinner("Evaluating results..."):
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
|
63 |
async def configure_agent(user_inputs: UserInputs) -> tuple[AnyAgent, Config]:
|
@@ -87,47 +89,52 @@ async def configure_agent(user_inputs: UserInputs) -> tuple[AnyAgent, Config]:
|
|
87 |
date=user_inputs.date,
|
88 |
framework=user_inputs.framework,
|
89 |
main_agent=agent_config,
|
90 |
-
|
91 |
-
|
92 |
)
|
93 |
|
94 |
agent = await AnyAgent.create_async(
|
95 |
agent_framework=config.framework,
|
96 |
agent_config=config.main_agent,
|
97 |
-
managed_agents=config.managed_agents,
|
98 |
-
tracing=TracingConfig(console=True, cost_info=True),
|
99 |
)
|
100 |
return agent, config
|
101 |
|
102 |
|
103 |
async def display_output(agent_trace: AgentTrace):
|
104 |
-
# Display the agent trace in a more organized way
|
105 |
with st.expander("### π§© Agent Trace"):
|
106 |
for span in agent_trace.spans:
|
107 |
-
# Header with name and status
|
108 |
col1, col2 = st.columns([4, 1])
|
109 |
with col1:
|
110 |
st.markdown(f"**{span.name}**")
|
111 |
if span.attributes:
|
112 |
-
|
113 |
-
if "input.value" in span.attributes:
|
114 |
try:
|
115 |
-
input_value = json.loads(
|
|
|
|
|
116 |
if isinstance(input_value, list) and len(input_value) > 0:
|
117 |
st.write(f"Input: {input_value[-1]}")
|
118 |
else:
|
119 |
st.write(f"Input: {input_value}")
|
120 |
-
except Exception:
|
121 |
-
st.write(f"Input: {span.attributes[
|
122 |
-
|
|
|
123 |
try:
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
if isinstance(output_value, list) and len(output_value) > 0:
|
126 |
st.write(f"Output: {output_value[-1]}")
|
127 |
else:
|
128 |
st.write(f"Output: {output_value}")
|
129 |
-
except Exception:
|
130 |
-
st.write(f"Output: {span.attributes[
|
131 |
with col2:
|
132 |
status_color = (
|
133 |
"green" if span.status.status_code == StatusCode.OK else "red"
|
@@ -145,7 +152,7 @@ async def display_output(agent_trace: AgentTrace):
|
|
145 |
with cost_col:
|
146 |
st.info(f"π° Estimated Cost: ${agent_trace.cost.total_cost:.6f}")
|
147 |
with tokens_col:
|
148 |
-
st.info(f"π¦ Total Tokens: {agent_trace.
|
149 |
st.markdown("#### Final Output")
|
150 |
st.info(agent_trace.final_output)
|
151 |
|
@@ -179,49 +186,10 @@ async def run_agent(agent, config) -> AgentTrace:
|
|
179 |
|
180 |
with st.status("Agent is running...", expanded=False, state="running") as status:
|
181 |
|
182 |
-
def
|
183 |
-
# Process input value
|
184 |
-
input_value = span.attributes.get("input.value", "")
|
185 |
-
if input_value:
|
186 |
-
try:
|
187 |
-
parsed_input = json.loads(input_value)
|
188 |
-
if isinstance(parsed_input, list) and len(parsed_input) > 0:
|
189 |
-
input_value = str(parsed_input[-1])
|
190 |
-
except Exception:
|
191 |
-
pass
|
192 |
-
|
193 |
-
# Process output value
|
194 |
-
output_value = span.attributes.get("output.value", "")
|
195 |
-
if output_value:
|
196 |
-
try:
|
197 |
-
parsed_output = json.loads(output_value)
|
198 |
-
if isinstance(parsed_output, list) and len(parsed_output) > 0:
|
199 |
-
output_value = str(parsed_output[-1])
|
200 |
-
except Exception:
|
201 |
-
pass
|
202 |
-
|
203 |
-
# Truncate long values
|
204 |
-
max_length = 800
|
205 |
-
if len(input_value) > max_length:
|
206 |
-
input_value = f"[Truncated]...{input_value[-max_length:]}"
|
207 |
-
if len(output_value) > max_length:
|
208 |
-
output_value = f"[Truncated]...{output_value[-max_length:]}"
|
209 |
-
|
210 |
-
# Create a cleaner message format
|
211 |
-
if input_value or output_value:
|
212 |
-
message = f"Step: {span.name}\n"
|
213 |
-
if input_value:
|
214 |
-
message += f"Input: {input_value}\n"
|
215 |
-
if output_value:
|
216 |
-
message += f"Output: {output_value}"
|
217 |
-
else:
|
218 |
-
message = f"Step: {span.name}\n{span}"
|
219 |
-
|
220 |
status.update(label=message, expanded=False, state="running")
|
221 |
|
222 |
-
export_logs(agent,
|
223 |
agent_trace: AgentTrace = await agent.run_async(query, **kwargs)
|
224 |
status.update(label="Finished!", expanded=False, state="complete")
|
225 |
-
|
226 |
-
agent.exit()
|
227 |
return agent_trace
|
|
|
1 |
import json
|
2 |
+
|
3 |
+
import streamlit as st
|
4 |
+
from components.agent_status import export_logs
|
5 |
from components.inputs import UserInputs
|
6 |
from constants import DEFAULT_TOOLS
|
7 |
+
from config import Config
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
+
from any_agent import AgentConfig, AgentFramework, AnyAgent
|
10 |
+
from any_agent.evaluation import LlmJudge
|
11 |
+
from any_agent.evaluation.schemas import EvaluationOutput
|
12 |
+
from any_agent.tracing.agent_trace import AgentTrace
|
13 |
+
from any_agent.tracing.attributes import GenAI
|
14 |
+
from any_agent.tracing.otel_types import StatusCode
|
15 |
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
+
async def display_evaluation_results(results: list[EvaluationOutput]):
|
18 |
col1, col2 = st.columns(2)
|
19 |
|
20 |
with col1:
|
21 |
st.markdown("#### Criteria Results")
|
22 |
+
for i, result in enumerate(results):
|
23 |
+
if result.passed:
|
24 |
+
st.success(f"β
Criterion {i + 1}")
|
25 |
else:
|
26 |
+
st.error(f"β Criterion {i + 1}")
|
27 |
+
st.write(f"**Reasoning:** {result.reasoning}")
|
28 |
|
29 |
with col2:
|
30 |
st.markdown("#### Overall Score")
|
31 |
+
total_criteria = len(results)
|
32 |
+
passed_criteria = sum(1 for result in results if result.passed)
|
33 |
+
|
34 |
+
st.markdown(f"### {passed_criteria}/{total_criteria}")
|
35 |
+
percentage = (
|
36 |
+
(passed_criteria / total_criteria) * 100 if total_criteria > 0 else 0
|
37 |
+
)
|
|
|
|
|
38 |
st.progress(percentage / 100)
|
39 |
st.markdown(f"**{percentage:.1f}%**")
|
40 |
|
41 |
|
42 |
async def evaluate_agent(
|
43 |
config: Config, agent_trace: AgentTrace
|
44 |
+
) -> list[EvaluationOutput]:
|
|
|
|
|
|
|
45 |
st.markdown("### π Evaluation Results")
|
46 |
|
47 |
with st.spinner("Evaluating results..."):
|
48 |
+
results = []
|
49 |
+
|
50 |
+
judge = LlmJudge(model_id=config.evaluation_model, framework=config.framework)
|
51 |
+
|
52 |
+
for i, criterion in enumerate(config.evaluation_criteria):
|
53 |
+
context = f"Agent Trace:\n{agent_trace.model_dump_json(indent=2)}"
|
54 |
+
|
55 |
+
result = await judge.run_async(
|
56 |
+
context=context, question=criterion["criteria"]
|
57 |
+
)
|
58 |
+
results.append(result)
|
59 |
+
|
60 |
+
st.write(f"Evaluated criterion {i + 1}/{len(config.evaluation_criteria)}")
|
61 |
+
|
62 |
+
return results
|
63 |
|
64 |
|
65 |
async def configure_agent(user_inputs: UserInputs) -> tuple[AnyAgent, Config]:
|
|
|
89 |
date=user_inputs.date,
|
90 |
framework=user_inputs.framework,
|
91 |
main_agent=agent_config,
|
92 |
+
evaluation_model=user_inputs.evaluation_model,
|
93 |
+
evaluation_criteria=user_inputs.evaluation_criteria,
|
94 |
)
|
95 |
|
96 |
agent = await AnyAgent.create_async(
|
97 |
agent_framework=config.framework,
|
98 |
agent_config=config.main_agent,
|
|
|
|
|
99 |
)
|
100 |
return agent, config
|
101 |
|
102 |
|
103 |
async def display_output(agent_trace: AgentTrace):
|
|
|
104 |
with st.expander("### π§© Agent Trace"):
|
105 |
for span in agent_trace.spans:
|
|
|
106 |
col1, col2 = st.columns([4, 1])
|
107 |
with col1:
|
108 |
st.markdown(f"**{span.name}**")
|
109 |
if span.attributes:
|
110 |
+
if GenAI.INPUT_MESSAGES in span.attributes:
|
|
|
111 |
try:
|
112 |
+
input_value = json.loads(
|
113 |
+
span.attributes[GenAI.INPUT_MESSAGES]
|
114 |
+
)
|
115 |
if isinstance(input_value, list) and len(input_value) > 0:
|
116 |
st.write(f"Input: {input_value[-1]}")
|
117 |
else:
|
118 |
st.write(f"Input: {input_value}")
|
119 |
+
except Exception:
|
120 |
+
st.write(f"Input: {span.attributes[GenAI.INPUT_MESSAGES]}")
|
121 |
+
|
122 |
+
if GenAI.TOOL_ARGS in span.attributes:
|
123 |
try:
|
124 |
+
tool_args = json.loads(span.attributes[GenAI.TOOL_ARGS])
|
125 |
+
st.write(f"Tool Args: {tool_args}")
|
126 |
+
except Exception:
|
127 |
+
st.write(f"Tool Args: {span.attributes[GenAI.TOOL_ARGS]}")
|
128 |
+
|
129 |
+
if GenAI.OUTPUT in span.attributes:
|
130 |
+
try:
|
131 |
+
output_value = json.loads(span.attributes[GenAI.OUTPUT])
|
132 |
if isinstance(output_value, list) and len(output_value) > 0:
|
133 |
st.write(f"Output: {output_value[-1]}")
|
134 |
else:
|
135 |
st.write(f"Output: {output_value}")
|
136 |
+
except Exception:
|
137 |
+
st.write(f"Output: {span.attributes[GenAI.OUTPUT]}")
|
138 |
with col2:
|
139 |
status_color = (
|
140 |
"green" if span.status.status_code == StatusCode.OK else "red"
|
|
|
152 |
with cost_col:
|
153 |
st.info(f"π° Estimated Cost: ${agent_trace.cost.total_cost:.6f}")
|
154 |
with tokens_col:
|
155 |
+
st.info(f"π¦ Total Tokens: {agent_trace.tokens.total_tokens:,}")
|
156 |
st.markdown("#### Final Output")
|
157 |
st.info(agent_trace.final_output)
|
158 |
|
|
|
186 |
|
187 |
with st.status("Agent is running...", expanded=False, state="running") as status:
|
188 |
|
189 |
+
def update_status(message: str):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
status.update(label=message, expanded=False, state="running")
|
191 |
|
192 |
+
export_logs(agent, update_status)
|
193 |
agent_trace: AgentTrace = await agent.run_async(query, **kwargs)
|
194 |
status.update(label="Finished!", expanded=False, state="complete")
|
|
|
|
|
195 |
return agent_trace
|
tools/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .openmeteo import get_wave_forecast, get_wind_forecast
|
2 |
+
from .openstreetmap import driving_hours_to_meters, get_area_lat_lon
|
3 |
+
|
4 |
+
__all__ = [
|
5 |
+
"driving_hours_to_meters",
|
6 |
+
"get_area_lat_lon",
|
7 |
+
"get_wave_forecast",
|
8 |
+
"get_wind_forecast",
|
9 |
+
]
|
tools/openmeteo.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from datetime import datetime, timedelta
|
3 |
+
|
4 |
+
import requests
|
5 |
+
|
6 |
+
|
7 |
+
def _extract_hourly_data(data: dict) -> list[dict]:
|
8 |
+
hourly_data = data["hourly"]
|
9 |
+
result = [
|
10 |
+
{k: v for k, v in zip(hourly_data.keys(), values, strict=False)}
|
11 |
+
for values in zip(*hourly_data.values(), strict=False)
|
12 |
+
]
|
13 |
+
return result
|
14 |
+
|
15 |
+
|
16 |
+
def _filter_by_date(
|
17 |
+
date: datetime, hourly_data: list[dict], timedelta: timedelta = timedelta(hours=1)
|
18 |
+
):
|
19 |
+
start_date = date - timedelta
|
20 |
+
end_date = date + timedelta
|
21 |
+
return [
|
22 |
+
item
|
23 |
+
for item in hourly_data
|
24 |
+
if start_date <= datetime.fromisoformat(item["time"]) <= end_date
|
25 |
+
]
|
26 |
+
|
27 |
+
|
28 |
+
def get_wave_forecast(lat: float, lon: float, date: str) -> list[dict]:
|
29 |
+
"""Get wave forecast for given location.
|
30 |
+
|
31 |
+
Forecast will include:
|
32 |
+
|
33 |
+
- wave_direction (degrees)
|
34 |
+
- wave_height (meters)
|
35 |
+
- wave_period (seconds)
|
36 |
+
- sea_level_height_msl (meters)
|
37 |
+
|
38 |
+
Args:
|
39 |
+
lat: Latitude of the location.
|
40 |
+
lon: Longitude of the location.
|
41 |
+
date: Date to filter by in any valid ISO 8601 format.
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
Hourly data for wave forecast.
|
45 |
+
Example output:
|
46 |
+
|
47 |
+
```json
|
48 |
+
[
|
49 |
+
{'time': '2025-03-19T09:00', 'winddirection_10m': 140, 'windspeed_10m': 24.5}, {'time': '2025-03-19T10:00', 'winddirection_10m': 140, 'windspeed_10m': 27.1},
|
50 |
+
{'time': '2025-03-19T10:00', 'winddirection_10m': 140, 'windspeed_10m': 27.1}, {'time': '2025-03-19T11:00', 'winddirection_10m': 141, 'windspeed_10m': 29.2}
|
51 |
+
]
|
52 |
+
```
|
53 |
+
|
54 |
+
"""
|
55 |
+
url = "https://marine-api.open-meteo.com/v1/marine"
|
56 |
+
params = {
|
57 |
+
"latitude": lat,
|
58 |
+
"longitude": lon,
|
59 |
+
"hourly": [
|
60 |
+
"wave_direction",
|
61 |
+
"wave_height",
|
62 |
+
"wave_period",
|
63 |
+
"sea_level_height_msl",
|
64 |
+
],
|
65 |
+
}
|
66 |
+
response = requests.get(url, params=params)
|
67 |
+
response.raise_for_status()
|
68 |
+
data = json.loads(response.content.decode())
|
69 |
+
hourly_data = _extract_hourly_data(data)
|
70 |
+
if date is not None:
|
71 |
+
date = datetime.fromisoformat(date)
|
72 |
+
hourly_data = _filter_by_date(date, hourly_data)
|
73 |
+
if len(hourly_data) == 0:
|
74 |
+
raise ValueError("No data found for the given date")
|
75 |
+
return hourly_data
|
76 |
+
|
77 |
+
|
78 |
+
def get_wind_forecast(lat: float, lon: float, date: str) -> list[dict]:
|
79 |
+
"""Get wind forecast for given location.
|
80 |
+
|
81 |
+
Forecast will include:
|
82 |
+
|
83 |
+
- wind_direction (degrees)
|
84 |
+
- wind_speed (meters per second)
|
85 |
+
|
86 |
+
Args:
|
87 |
+
lat: Latitude of the location.
|
88 |
+
lon: Longitude of the location.
|
89 |
+
date: Date to filter by in any valid ISO 8601 format.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
Hourly data for wind forecast.
|
93 |
+
Example output:
|
94 |
+
|
95 |
+
```json
|
96 |
+
[
|
97 |
+
{"time": "2025-03-18T22:00", "wind_direction": 196, "wind_speed": 9.6},
|
98 |
+
{"time": "2025-03-18T23:00", "wind_direction": 183, "wind_speed": 7.9},
|
99 |
+
]
|
100 |
+
```
|
101 |
+
|
102 |
+
"""
|
103 |
+
url = "https://api.open-meteo.com/v1/forecast"
|
104 |
+
params = {
|
105 |
+
"latitude": lat,
|
106 |
+
"longitude": lon,
|
107 |
+
"hourly": ["winddirection_10m", "windspeed_10m"],
|
108 |
+
}
|
109 |
+
response = requests.get(url, params=params)
|
110 |
+
response.raise_for_status()
|
111 |
+
data = json.loads(response.content.decode())
|
112 |
+
hourly_data = _extract_hourly_data(data)
|
113 |
+
date = datetime.fromisoformat(date)
|
114 |
+
hourly_data = _filter_by_date(date, hourly_data)
|
115 |
+
if len(hourly_data) == 0:
|
116 |
+
raise ValueError("No data found for the given date")
|
117 |
+
return hourly_data
|
tools/openstreetmap.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import requests
|
4 |
+
|
5 |
+
|
6 |
+
def get_area_lat_lon(area_name: str) -> tuple[float, float]:
|
7 |
+
"""Get the latitude and longitude of an area from Nominatim.
|
8 |
+
|
9 |
+
Uses the [Nominatim API](https://nominatim.org/release-docs/develop/api/Search/).
|
10 |
+
|
11 |
+
Args:
|
12 |
+
area_name: The name of the area.
|
13 |
+
|
14 |
+
Returns:
|
15 |
+
The area found.
|
16 |
+
|
17 |
+
"""
|
18 |
+
response = requests.get(
|
19 |
+
f"https://nominatim.openstreetmap.org/search?q={area_name}&format=jsonv2",
|
20 |
+
headers={"User-Agent": "Mozilla/5.0"},
|
21 |
+
)
|
22 |
+
response.raise_for_status()
|
23 |
+
area = json.loads(response.content.decode())
|
24 |
+
return area[0]["lat"], area[0]["lon"]
|
25 |
+
|
26 |
+
|
27 |
+
def driving_hours_to_meters(driving_hours: int) -> int:
|
28 |
+
"""Convert driving hours to meters assuming a 70 km/h average speed.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
driving_hours: The driving hours.
|
32 |
+
|
33 |
+
Returns:
|
34 |
+
The distance in meters.
|
35 |
+
|
36 |
+
"""
|
37 |
+
return driving_hours * 70 * 1000
|
38 |
+
|
39 |
+
|
40 |
+
def get_lat_lon_center(bounds: dict) -> tuple[float, float]:
|
41 |
+
"""Get the latitude and longitude of the center of a bounding box.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
bounds: The bounding box.
|
45 |
+
|
46 |
+
```json
|
47 |
+
{
|
48 |
+
"minlat": float,
|
49 |
+
"minlon": float,
|
50 |
+
"maxlat": float,
|
51 |
+
"maxlon": float,
|
52 |
+
}
|
53 |
+
```
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
The latitude and longitude of the center.
|
57 |
+
|
58 |
+
"""
|
59 |
+
return (
|
60 |
+
(bounds["minlat"] + bounds["maxlat"]) / 2,
|
61 |
+
(bounds["minlon"] + bounds["maxlon"]) / 2,
|
62 |
+
)
|