# Spaces:
# Sleeping
# Sleeping
import os | |
import sys | |
import faiss | |
import numpy as np | |
import streamlit as st | |
import pandas as pd | |
from text2vec import SentenceModel | |
from src.jsonl_Indexer import JSONLIndexer | |
def get_cli_args():
    """Collect ``key=value`` overrides from the command line.

    Only the arguments after the first two entries of ``sys.argv`` are
    inspected, and any argument without an ``=`` is ignored. Each remaining
    argument is split on its first ``=``; the key and the value are stripped
    of surrounding whitespace. If a key appears more than once, the last
    occurrence wins.

    Returns:
        A dict mapping each key to its value (values stay strings).
    """
    # Slicing past the end of the list already yields an empty list, so no
    # explicit length check is needed.
    pairs = (token.split('=', 1) for token in sys.argv[2:] if '=' in token)
    return {name.strip(): raw.strip() for name, raw in pairs}
# Built-in settings; each one can be overridden with a key=value CLI argument.
DEFAULT_CONFIG = {
    'model_path': 'BAAI/bge-base-en-v1.5',
    'dataset_path': 'tool-embedding.jsonl',
    'vector_size': 768,
    'embedding_field': 'embedding',
    'id_field': 'id',
}

cli_args = get_cli_args()

# Later mappings win, so CLI overrides replace the defaults key by key.
config = {**DEFAULT_CONFIG, **cli_args}

# CLI values arrive as strings; the vector size is needed as an int.
config['vector_size'] = int(config['vector_size'])
# Load the tool corpus.
from datasets import load_dataset
from datasets import concatenate_datasets
import pandas as pd


@st.cache_data
def _load_tool_table():
    """Build the tool metadata table that search results are joined against.

    Loads the "code", "customized" and "web" configurations of the
    ``mangopy/ToolRet-Tools`` dataset, concatenates their ``tools`` splits
    and renames the ``id`` column to ``tool`` so the table can be merged
    with retrieval hits on the ``tool`` column.

    Streamlit re-executes this script on every user interaction. Caching
    the result keeps the three dataset loads and the pandas conversion from
    being repeated on each rerun.

    Returns:
        The concatenated dataset converted with ``Dataset.to_pandas()``.
    """
    subsets = [
        load_dataset("mangopy/ToolRet-Tools", subset_name)['tools']
        for subset_name in ("code", "customized", "web")
    ]
    merged = concatenate_datasets(subsets)
    merged = merged.rename_columns({'id': 'tool'})
    return merged.to_pandas()


# Tool metadata keyed by the 'tool' column; merged with the hits at search time.
df2 = _load_tool_table()
@st.cache_resource
def get_model(model_path: str = config['model_path']):
    """Return the sentence-embedding model for ``model_path``.

    The model is cached as a Streamlit resource, so it is constructed once
    per distinct ``model_path`` and then reused across script reruns
    instead of being rebuilt on every widget interaction.

    Args:
        model_path: Value passed straight to ``SentenceModel``. Defaults to
            the configured ``model_path`` as it was when this function was
            defined.

    Returns:
        The ``SentenceModel`` instance (shared between reruns).
    """
    return SentenceModel(model_path)
@st.cache_resource
def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id_field: str, _model):
    """Create a ``JSONLIndexer`` and load the JSONL dataset into it.

    The retriever is cached as a Streamlit resource, so the JSONL file is
    loaded once per distinct combination of the hashable arguments and the
    same retriever is reused across script reruns.

    Args:
        vector_sz: Forwarded to ``JSONLIndexer`` as ``vector_sz``.
        dataset_path: Path of the JSONL file given to ``load_jsonl``.
        embedding_field: Name of the embedding field in each record.
        id_field: Name of the id field in each record.
        _model: Model handed to ``JSONLIndexer``. The leading underscore
            tells Streamlit not to hash this argument, so it is not part of
            the cache key: a different model passed with otherwise equal
            arguments returns the already cached retriever.

    Returns:
        The loaded ``JSONLIndexer`` instance.
    """
    retriever = JSONLIndexer(vector_sz=vector_sz, model=_model)
    retriever.load_jsonl(dataset_path, embedding_field=embedding_field, id_field=id_field)
    return retriever
# Sidebar: centered "Model Configuration" heading.
st.sidebar.markdown("<div style='text-align: center;'><h3>📄 Model Configuration</h3></div>", unsafe_allow_html=True)

# Model drop-down; only one model is offered for now.
model_options = ["BAAI/bge-base-en-v1.5"]
selected_model = st.sidebar.selectbox("Select Model", model_options)
st.sidebar.write("Selected model:", selected_model)
st.sidebar.write("Embedding length: 768")

# Load the model exactly once. The configured path (config['model_path']) is
# the one the retriever actually uses; the sidebar selection is display-only,
# so loading it first and then overwriting it would only waste a model load.
model = get_model(config['model_path'])
retriever = create_retriever(
    config['vector_size'],
    config['dataset_path'],
    config['embedding_field'],
    config['id_field'],
    _model=model,
)
# Page styling: CSS rules for a centered, rounded search box and button.
# NOTE(review): nothing visible in this file applies the .search-container,
# .search-box or .search-btn classes to an element - confirm they take effect.
_PAGE_CSS = """
<style>
.search-container {
display: flex;
justify-content: center;
align-items: center;
gap: 10px;
margin-top: 20px;
}
.search-box input {
width: 500px !important;
height: 45px;
font-size: 16px;
border-radius: 25px;
padding-left: 15px;
}
.search-btn button {
height: 45px;
font-size: 16px;
border-radius: 25px;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

# Centered page title.
_PAGE_TITLE = "<h1 style='text-align: center;'>🔍 Tool Retrieval</h1>"
st.markdown(_PAGE_TITLE, unsafe_allow_html=True)
# Search row: a wide query box (4 parts) beside a narrow button (1 part).
col1, col2 = st.columns([4, 1])
with col1:
    # The label must not be empty (Streamlit asks for a real label for
    # accessibility); it stays hidden through label_visibility="collapsed".
    query = st.text_input(
        "Search query",
        placeholder="Enter your search query...",
        key="search_query",
        label_visibility="collapsed",
    )
with col2:
    search_clicked = st.button("🔎 Search", use_container_width=True)

# Number of results to retrieve: 1..100, default 50.
top_k = st.slider("Top-K tools", 1, 100, 50, help="Choose the number of results to display")
styled_results = None

# Run a search only when the button was clicked and the query is non-empty.
if search_clicked and query:
    rec_ids, scores = retriever.search_return_id(query, top_k)

    # One row per hit, then attach the tool metadata by joining on 'tool'.
    df1 = pd.DataFrame({"relevance": scores, "tool": rec_ids})
    results_df = df1.merge(df2, on='tool', how='left')

    st.subheader("🗂️ Retrieval results")

    # Zebra striping: even row positions get the grey shade, odd ones white.
    stripe_css = ("background-color: #F7F7F7", "background-color: #FFFFFF")
    styled_results = results_df.style.apply(
        lambda column: [stripe_css[position % 2] for position in range(len(column))],
        axis=0,
    ).format({"relevance": "{:.4f}"})

    # Scale the progress bar to the best score; use 1 when there are no scores.
    if len(scores) > 0:
        relevance_cap = float(max(scores))
    else:
        relevance_cap = 1

    relevance_column = st.column_config.ProgressColumn(
        "relevance",
        help="记录与查询的匹配程度",
        format="%.4f",
        min_value=0,
        max_value=relevance_cap,
    )
    tool_column = st.column_config.TextColumn("tool", help="tool help text", width="medium")

    st.dataframe(
        styled_results,
        column_config={"relevance": relevance_column, "tool": tool_column},
        hide_index=True,
        use_container_width=True,
    )