# tool_retriever / app.py
# Yyy0530's picture
# Update app.py
# b97944f verified
# raw
# history blame
# 4.57 kB
import os
import sys
import faiss
import numpy as np
import streamlit as st
import pandas as pd
from text2vec import SentenceModel
from src.jsonl_Indexer import JSONLIndexer
def get_cli_args():
    """Collect ``key=value`` command-line arguments into a dict.

    Only the tokens in ``sys.argv[2:]`` are inspected, and tokens without
    an ``=`` are skipped. Each remaining token is split on its first ``=``
    and both halves are whitespace-stripped. If a key is given more than
    once, the last occurrence wins.

    Returns:
        dict: stripped keys mapped to stripped string values.
    """
    # Slicing past the end of a list yields [], so no length guard is needed.
    pairs = (token.split('=', 1) for token in sys.argv[2:] if '=' in token)
    return {name.strip(): raw.strip() for name, raw in pairs}
# Command-line overrides, parsed once when the script starts.
cli_args = get_cli_args()

# Built-in settings used for any key that is not overridden on the command line.
DEFAULT_CONFIG = dict(
    model_path='BAAI/bge-base-en-v1.5',
    dataset_path='tool-embedding.jsonl',
    vector_size=768,
    embedding_field='embedding',
    id_field='id',
)

# Later entries win, so command-line values replace the defaults.
config = {**DEFAULT_CONFIG, **cli_args}
# Command-line values arrive as strings; the vector size is coerced to int.
config['vector_size'] = int(config['vector_size'])
# Load data
from datasets import concatenate_datasets, load_dataset

# Fetch the "code", "customized" and "web" configurations of the tool corpus.
ds1, ds2, ds3 = (
    load_dataset("mangopy/ToolRet-Tools", subset)
    for subset in ("code", "customized", "web")
)

# Merge: stack the 'tools' split of every configuration into one dataset and
# rename its 'id' column to 'tool'.
ds = concatenate_datasets([part['tools'] for part in (ds1, ds2, ds3)])
ds = ds.rename_columns({'id': 'tool'})

# Convert the merged dataset to a pandas DataFrame.
# NOTE(review): the original comment describes it as having an id column and
# a text column; confirm against the dataset schema.
import pandas as pd
df2 = ds.to_pandas()
@st.cache_resource
def get_model(model_path: str = config['model_path']):
    """Build the sentence-embedding model identified by ``model_path``.

    Decorated with ``st.cache_resource`` so Streamlit can reuse the returned
    object instead of rebuilding it for the same ``model_path``.

    Args:
        model_path: Model name or path handed to ``SentenceModel``. Defaults
            to the configured ``model_path``.

    Returns:
        The constructed ``SentenceModel`` instance.
    """
    embedder = SentenceModel(model_path)
    return embedder
@st.cache_resource
def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id_field: str, _model):
    """Build a ``JSONLIndexer`` and populate it from a JSONL file.

    Args:
        vector_sz: Vector size passed to ``JSONLIndexer``.
        dataset_path: Path of the JSONL file to load.
        embedding_field: Name of the embedding field, forwarded to ``load_jsonl``.
        id_field: Name of the id field, forwarded to ``load_jsonl``.
        _model: Embedding model passed to ``JSONLIndexer``. The leading
            underscore tells ``st.cache_resource`` not to hash this argument.

    Returns:
        The ``JSONLIndexer`` after ``load_jsonl`` has been called on it.
    """
    indexer = JSONLIndexer(vector_sz=vector_sz, model=_model)
    load_options = {'embedding_field': embedding_field, 'id_field': id_field}
    indexer.load_jsonl(dataset_path, **load_options)
    return indexer
# Sidebar heading for the model configuration section.
st.sidebar.markdown("<div style='text-align: center;'><h3>📄 Model Configuration</h3></div>", unsafe_allow_html=True)

# Model dropdown; only one model is selectable for now. The option comes from
# the effective configuration, so the sidebar always names the model that is
# actually loaded, even when `model_path` is overridden on the command line.
model_options = [config['model_path']]
selected_model = st.sidebar.selectbox("Select Model", model_options)
st.sidebar.write("Selected model:", selected_model)
# Report the configured vector size rather than a hard-coded number.
st.sidebar.write(f"Embedding length: {config['vector_size']}")

# Load the model once. Previously a model was loaded for the dropdown value
# and then immediately replaced by the configured one, wasting a full model
# load whenever the two differed.
model = get_model(selected_model)
retriever = create_retriever(config['vector_size'], config['dataset_path'], config['embedding_field'], config['id_field'], _model=model)
# Page styling: CSS rules for the search container, the search input and the
# search button, injected into the page as raw HTML.
_SEARCH_CSS = """
<style>
.search-container {
display: flex;
justify-content: center;
align-items: center;
gap: 10px;
margin-top: 20px;
}
.search-box input {
width: 500px !important;
height: 45px;
font-size: 16px;
border-radius: 25px;
padding-left: 15px;
}
.search-btn button {
height: 45px;
font-size: 16px;
border-radius: 25px;
}
</style>
"""
st.markdown(_SEARCH_CSS, unsafe_allow_html=True)
# Centered page title.
st.markdown("<h1 style='text-align: center;'>🔍 Tool Retrieval</h1>", unsafe_allow_html=True)

# Search row: a wide column for the query box and a narrow one for the button.
col1, col2 = st.columns([4, 1])
with col1:
    # Streamlit discourages empty widget labels for accessibility reasons, so
    # give the input a real label and keep it hidden with
    # label_visibility="collapsed".
    query = st.text_input("Search query", placeholder="Enter your search query...", key="search_query", label_visibility="collapsed")
with col2:
    search_clicked = st.button("🔎 Search", use_container_width=True)

# Number of results to retrieve: 1 to 100, defaulting to 50.
top_k = st.slider("Top-K tools", 1, 100, 50, help="Choose the number of results to display")
# Stays None until a search has produced a styled results table.
styled_results = None


def _zebra_stripes(column):
    """Return one background style per cell of ``column``, alternating by position."""
    shades = ("background-color: #F7F7F7", "background-color: #FFFFFF")
    return [shades[position % 2] for position in range(len(column))]


# Run a search only when the button was clicked and the query is non-empty.
if search_clicked and query:
    rec_ids, scores = retriever.search_return_id(query, top_k)
    df1 = pd.DataFrame({"relevance": scores, "tool": rec_ids})

    # Left-join the hits with the tool table on 'tool', so every hit is kept
    # even when the tool table has no matching row.
    results_df = pd.merge(df1, df2, on='tool', how='left')

    st.subheader("🗂️ Retrieval results")

    # Alternate cell backgrounds and show relevance with four decimals.
    striped = results_df.style.apply(_zebra_stripes, axis=0)
    styled_results = striped.format({"relevance": "{:.4f}"})

    # Scale the progress bar to the best score; fall back to 1 without scores.
    relevance_ceiling = float(max(scores)) if len(scores) > 0 else 1
    relevance_column = st.column_config.ProgressColumn(
        "relevance",
        help="记录与查询的匹配程度",
        format="%.4f",
        min_value=0,
        max_value=relevance_ceiling,
    )
    tool_column = st.column_config.TextColumn("tool", help="tool help text", width="medium")

    st.dataframe(
        styled_results,
        column_config={"relevance": relevance_column, "tool": tool_column},
        hide_index=True,
        use_container_width=True,
    )