File size: 4,567 Bytes
9ab9c77
 
 
 
 
8911626
9ab9c77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b97944f
9ab9c77
7c04c37
 
9ab9c77
 
 
 
 
 
a302d1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ab9c77
 
7c04c37
9ab9c77
 
 
 
 
 
 
a302d1e
 
 
 
 
 
 
 
 
 
 
 
 
9ab9c77
 
7c04c37
9ab9c77
7c04c37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ab9c77
a302d1e
 
7c04c37
 
8911626
7c04c37
8911626
7c04c37
ce422d8
a302d1e
7c04c37
a302d1e
7c04c37
9ab9c77
a302d1e
 
 
 
 
 
 
8911626
 
 
 
 
 
 
 
a302d1e
8911626
 
 
 
 
 
 
 
 
a302d1e
 
8911626
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
import sys
import faiss
import numpy as np
import streamlit as st
import pandas as pd
from text2vec import SentenceModel
from src.jsonl_Indexer import JSONLIndexer

def get_cli_args(argv=None):
    """Parse ``key=value`` tokens from the command line into a dict.

    Args:
        argv: Tokens to parse. When omitted, ``sys.argv[2:]`` is used, i.e.
            everything after the first argument that follows the script name.

    Returns:
        A dict mapping each stripped key to its stripped value (always a
        string). Tokens without ``=`` or with an empty key are ignored. Only
        the first ``=`` splits a token, so values may themselves contain
        ``=``. A later duplicate key overrides an earlier one.
    """
    # NOTE(review): sys.argv[1] is deliberately skipped; presumably the
    # launcher passes a positional argument there — confirm against how the
    # app is started (e.g. `streamlit run app.py -- ...`).
    tokens = sys.argv[2:] if argv is None else argv
    args = {}
    for token in tokens:
        key, sep, value = token.partition('=')
        key = key.strip()
        # An empty key can never match a config setting, so drop it.
        if sep and key:
            args[key] = value.strip()
    return args

cli_args = get_cli_args()

# Built-in settings; any matching `key=value` pair from the command line wins.
DEFAULT_CONFIG = {
    'model_path': 'BAAI/bge-base-en-v1.5',
    'dataset_path': 'tool-embedding.jsonl',
    'vector_size': 768,
    'embedding_field': 'embedding',
    'id_field': 'id'
}

# Later mapping takes precedence, so CLI overrides replace the defaults.
config = {**DEFAULT_CONFIG, **cli_args}
# Command-line values arrive as strings; the vector size must be an int.
config['vector_size'] = int(config['vector_size'])


# Tool metadata used to enrich retrieval hits (joined on the `tool` column).
@st.cache_data
def _load_tool_table(subsets=("code", "customized", "web")):
    """Return the ``tools`` split of the given ToolRet-Tools subsets as one DataFrame.

    The dataset's ``id`` column is renamed to ``tool`` so it can be joined
    with the ids returned by the retriever. Wrapped in ``st.cache_data`` so
    the load, concatenation and pandas conversion happen once rather than on
    every Streamlit rerun (each widget interaction re-executes this script).

    Args:
        subsets: ToolRet-Tools configuration names to load and concatenate,
            in order.
    """
    from datasets import concatenate_datasets, load_dataset

    parts = [load_dataset("mangopy/ToolRet-Tools", name)['tools'] for name in subsets]
    ds = concatenate_datasets(parts)
    ds = ds.rename_columns({'id': 'tool'})
    return ds.to_pandas()


df2 = _load_tool_table()



@st.cache_resource
def get_model(model_path: str = config['model_path']):
    """Load and return the ``SentenceModel`` stored at ``model_path``.

    ``st.cache_resource`` keeps one instance per distinct path, so the model is
    not reloaded on every Streamlit rerun. The default path is read from
    ``config`` once, when this function is defined.
    """
    embedder = SentenceModel(model_path)
    return embedder

@st.cache_resource
def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id_field: str, _model):
    """Build a ``JSONLIndexer`` and populate it from the JSONL file at ``dataset_path``.

    Args:
        vector_sz: Dimensionality passed to the indexer.
        dataset_path: JSONL file handed to ``load_jsonl``.
        embedding_field: Record field that ``load_jsonl`` reads vectors from.
        id_field: Record field that ``load_jsonl`` reads ids from.
        _model: Encoder handed to the indexer. The leading underscore tells
            Streamlit not to hash this argument, so the cache is keyed on the
            other four arguments only.
    """
    indexer = JSONLIndexer(vector_sz=vector_sz, model=_model)
    indexer.load_jsonl(dataset_path, embedding_field=embedding_field, id_field=id_field)
    return indexer

# Sidebar heading for the model configuration panel.
st.sidebar.markdown("<div style='text-align: center;'><h3>📄 Model Configuration</h3></div>", unsafe_allow_html=True)


# Only the configured model is offered. Presumably the stored embeddings were
# produced by that model, so listing it (rather than a hard-coded name) keeps
# the sidebar truthful when model_path is overridden on the command line.
model_options = [config['model_path']]
selected_model = st.sidebar.selectbox("Select Model", model_options)
st.sidebar.write("Selected model:", selected_model)
st.sidebar.write(f"Embedding length: {config['vector_size']}")

# Load the selected model (cached) and build the retriever with that same
# instance, so the sidebar choice is the model actually used for search.
model = get_model(selected_model)
retriever = create_retriever(config['vector_size'], config['dataset_path'], config['embedding_field'], config['id_field'], _model=model)

# Page styling injected as raw HTML.
# NOTE(review): no element rendered by this script carries the classes
# .search-container / .search-box / .search-btn, so these rules may have no
# visible effect — confirm before relying on them.
_SEARCH_CSS = """
    <style>
    .search-container {
        display: flex;
        justify-content: center;
        align-items: center;
        gap: 10px;
        margin-top: 20px;
    }
    .search-box input {
        width: 500px !important;
        height: 45px;
        font-size: 16px;
        border-radius: 25px;
        padding-left: 15px;
    }
    .search-btn button {
        height: 45px;
        font-size: 16px;
        border-radius: 25px;
    }
    </style>
"""
st.markdown(_SEARCH_CSS, unsafe_allow_html=True)

# Centered page title.
_TITLE_HTML = "<h1 style='text-align: center;'>🔍 Tool Retrieval</h1>"
st.markdown(_TITLE_HTML, unsafe_allow_html=True)


# Search row: a wide query box beside a narrower button (4:1 width ratio).
col1, col2 = st.columns([4, 1])
with col1:
    # Streamlit warns on an empty label (accessibility), so give it a real one
    # and hide it visually with label_visibility="collapsed".
    query = st.text_input("Search query", placeholder="Enter your search query...", key="search_query", label_visibility="collapsed")
with col2:
    search_clicked = st.button("🔎 Search", use_container_width=True)

# Number of results to request: slider over 1..100, default 50.
top_k = st.slider("Top-K tools", 1, 100, 50, help="Choose the number of results to display")

styled_results = None
if search_clicked and query:
    # Ask the retriever for the top_k tool ids and their relevance scores.
    rec_ids, scores = retriever.search_return_id(query, top_k)
    df1 = pd.DataFrame({"relevance": scores, "tool": rec_ids})
    # Left-join tool metadata onto the hits. Every retrieved id keeps its row,
    # in retrieval order; ids missing from df2 get NaN metadata.
    results_df = pd.merge(df1, df2, on='tool', how='left')

    st.subheader("🗂️ Retrieval results")

    # Zebra striping. The function is applied column-wise (axis=0), so position
    # i in each column is the row index.
    styled_results = results_df.style.apply(
        lambda col: [
            "background-color: #F7F7F7" if i % 2 == 0 else "background-color: #FFFFFF"
            for i in range(len(col))
        ],
        axis=0,
    ).format({"relevance": "{:.4f}"})

    # The progress bar spans [0, progress_max]. That range is degenerate when
    # there are no scores or the best score is <= 0, so fall back to 1.0 then.
    best_score = float(max(scores)) if len(scores) > 0 else 0.0
    progress_max = best_score if best_score > 0 else 1.0

    st.dataframe(
        styled_results,
        column_config={
            "relevance": st.column_config.ProgressColumn(
                "relevance",
                help="记录与查询的匹配程度",
                format="%.4f",
                min_value=0,
                max_value=progress_max,
            ),
            "tool": st.column_config.TextColumn("tool", help="tool help text", width="medium")
        },
        hide_index=True,
        use_container_width=True,
    )