File size: 4,002 Bytes
7e103cf
763ec84
7e103cf
 
 
763ec84
 
7e103cf
 
 
 
 
763ec84
 
7e103cf
 
763ec84
 
 
 
 
 
 
 
7e103cf
 
763ec84
 
 
 
 
 
 
 
 
 
 
 
 
 
7e103cf
763ec84
 
9204b9a
763ec84
 
 
 
7e103cf
763ec84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e103cf
763ec84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e103cf
 
 
 
 
 
763ec84
 
 
 
 
 
 
7e103cf
 
 
 
 
763ec84
7e103cf
763ec84
 
 
 
7e103cf
 
 
763ec84
7e103cf
763ec84
 
 
 
 
 
 
7e103cf
 
763ec84
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import copy
import json
from datetime import datetime, timedelta

import pandas as pd
import requests
import streamlit as st
from constants import (
    DEFAULT_EVALUATION_CRITERIA,
    DEFAULT_EVALUATION_MODEL,
    MODEL_OPTIONS,
)
from pydantic import BaseModel, ConfigDict

from any_agent import AgentFramework


class UserInputs(BaseModel):
    """Validated bundle of the values collected by `get_user_inputs`.

    Unknown fields are rejected (`extra="forbid"`), so passing a misspelled
    keyword raises a validation error instead of being silently dropped.
    """

    model_config = ConfigDict(extra="forbid")
    # Entry of MODEL_OPTIONS selected for the agent.
    model_id: str
    # Free-text location typed by the user.
    location: str
    # Upper bound on driving time, in hours (the widget enforces >= 1).
    max_driving_hours: int
    # Selected calendar date combined with the selected hour of the day.
    date: datetime
    # Agent framework selected from `AgentFramework`.
    framework: str
    # Entry of MODEL_OPTIONS selected for the LLM-as-a-Judge evaluation.
    evaluation_model: str
    # One {"criteria": <text>} dict per evaluation criterion.
    evaluation_criteria: list[dict[str, str]]
    # Whether the "Run Evaluation" checkbox is ticked.
    run_evaluation: bool


@st.cache_resource
def get_area(area_name: str) -> list[dict]:
    """Look up an area by name with the Nominatim search API.

    Uses the [Nominatim API](https://nominatim.org/release-docs/develop/api/Search/).

    Args:
        area_name (str): The name of the area.

    Returns:
        list[dict]: The decoded JSON body of the response, i.e. the places
            matching `area_name`. It is empty when nothing matches.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an error status.

    """
    response = requests.get(
        "https://nominatim.openstreetmap.org/search",
        # Let requests build the query string: interpolating the raw text into
        # the URL lets characters such as "&", "#", "=" or "+" change or
        # truncate the query instead of being searched for.
        params={"q": area_name, "format": "jsonv2"},
        headers={"User-Agent": "Mozilla/5.0"},
        timeout=5,
    )
    response.raise_for_status()
    return response.json()


# Limits applied to the user-editable evaluation criteria.
_MAX_CRITERIA = 20
_MAX_CRITERION_WORDS = 100


def _clean_criteria(
    raw_criteria: list[object],
) -> tuple[list[dict[str, str]], list[str]]:
    """Split edited criteria cells into usable criteria and error messages.

    Args:
        raw_criteria (list[object]): The cell values of the "criteria" column.

    Returns:
        tuple[list[dict[str, str]], list[str]]: The accepted criteria, each as
            a `{"criteria": text}` dict, and one error message per rejected
            cell.

    """
    accepted: list[dict[str, str]] = []
    problems: list[str] = []
    for text in raw_criteria:
        # Cells of rows added in the editor may hold None/NaN rather than "";
        # treat those, and whitespace-only text, as empty rows to ignore.
        if not isinstance(text, str) or not text.strip():
            continue
        # Split on any whitespace so newlines/tabs cannot bypass the limit and
        # repeated spaces are not counted as extra words.
        if len(text.split()) > _MAX_CRITERION_WORDS:
            problems.append("Error creating criterion: Criteria is too long")
            continue
        accepted.append({"criteria": text})
    return accepted, problems


def get_user_inputs() -> UserInputs:
    """Render the input widgets and collect their current values.

    The widgets are created in this order: location, maximum driving hours,
    date and time, agent framework, model, the "Custom Evaluation" expander
    (judge model and editable criteria) and the "Run Evaluation" checkbox.

    Problems with the inputs (location not found, location lookup failure,
    too many or too long criteria) are reported with `st.error`; they do not
    stop the function from returning.

    Returns:
        UserInputs: The values currently selected in the widgets.

    """
    default_val = "Los Angeles California, US"

    location = st.text_input("Enter a location", value=default_val)
    if location:
        try:
            location_check = get_area(location)
        except requests.RequestException as e:
            # A failed lookup (timeout, HTTP error, no network) should be
            # reported like any other input problem, not abort the whole page.
            st.error(f"❌ Could not validate location: {e}")
        else:
            if not location_check:
                st.error("❌ Invalid location")

    max_driving_hours = st.number_input(
        "Enter the maximum driving hours", min_value=1, value=2
    )

    col_date, col_time = st.columns([2, 1])
    with col_date:
        selected_date = st.date_input(
            "Select a date in the future", value=datetime.now() + timedelta(days=1)
        )
    with col_time:
        # One option per full hour of the day; index 9 preselects 09:00.
        selected_time = st.selectbox(
            "Select a time",
            [datetime.strptime(f"{i:02d}:00", "%H:%M").time() for i in range(24)],
            index=9,
        )
    selected_datetime = datetime.combine(selected_date, selected_time)

    supported_frameworks = list(AgentFramework)

    framework = st.selectbox(
        "Select the agent framework to use",
        supported_frameworks,
        index=2,
        format_func=lambda x: x.name,
    )

    model_id = st.selectbox(
        "Select the model to use",
        MODEL_OPTIONS,
        index=1,
        # Show only the last three "/"-separated parts of the model id.
        format_func=lambda x: "/".join(x.split("/")[-3:]),
    )

    with st.expander("Custom Evaluation"):
        evaluation_model_id = st.selectbox(
            "Select the model to use for LLM-as-a-Judge evaluation",
            MODEL_OPTIONS,
            index=2,
            format_func=lambda x: "/".join(x.split("/")[-3:]),
        )

        # Work on a copy so the module-level defaults are never mutated.
        evaluation_criteria = copy.deepcopy(DEFAULT_EVALUATION_CRITERIA)

        criteria_df = pd.DataFrame(evaluation_criteria)
        criteria_df = st.data_editor(
            criteria_df,
            column_config={
                "criteria": st.column_config.TextColumn(label="Criteria"),
            },
            hide_index=True,
            num_rows="dynamic",
        )

        if len(criteria_df) > _MAX_CRITERIA:
            st.error("You can only add up to 20 criteria for the purpose of this demo.")
            criteria_df = criteria_df.head(_MAX_CRITERIA)

        new_criteria, criteria_errors = _clean_criteria(
            criteria_df["criteria"].tolist()
        )
        for message in criteria_errors:
            st.error(message)

    # Created last so the checkbox is rendered below the expander.
    run_evaluation = st.checkbox("Run Evaluation", value=True)

    return UserInputs(
        model_id=model_id,
        location=location,
        max_driving_hours=max_driving_hours,
        date=selected_datetime,
        framework=framework,
        evaluation_model=evaluation_model_id,
        evaluation_criteria=new_criteria,
        run_evaluation=run_evaluation,
    )