Spaces:
Sleeping
Sleeping
File size: 4,567 Bytes
9ab9c77 8911626 9ab9c77 b97944f 9ab9c77 7c04c37 9ab9c77 a302d1e 9ab9c77 7c04c37 9ab9c77 a302d1e 9ab9c77 7c04c37 9ab9c77 7c04c37 9ab9c77 a302d1e 7c04c37 8911626 7c04c37 8911626 7c04c37 ce422d8 a302d1e 7c04c37 a302d1e 7c04c37 9ab9c77 a302d1e 8911626 a302d1e 8911626 a302d1e 8911626 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import os
import sys
import faiss
import numpy as np
import streamlit as st
import pandas as pd
from text2vec import SentenceModel
from src.jsonl_Indexer import JSONLIndexer
def get_cli_args(argv=None):
    """Collect ``key=value`` pairs from the command line into a dict.

    Args:
        argv: Tokens to scan. When omitted, ``sys.argv[1:]`` is used, i.e.
            every argument after the script path.

    Returns:
        A dict mapping each key to its value, both with surrounding
        whitespace stripped. Values stay strings. Tokens without ``=`` are
        ignored, only the first ``=`` separates key from value, and a repeated
        key keeps its last value.
    """
    # `streamlit run app.py -- k=v` leaves sys.argv as [script_path, *user_args],
    # so user arguments start at index 1; the previous sys.argv[2:] silently
    # dropped the first one. Starting at 1 is safe for any launch layout
    # because tokens that contain no "=" are skipped below anyway.
    tokens = sys.argv[1:] if argv is None else argv
    args = {}
    for token in tokens:
        if "=" in token:
            key, value = token.split("=", 1)
            args[key.strip()] = value.strip()
    return args
# key=value overrides supplied on the command line (all values are strings).
cli_args = get_cli_args()

# Built-in settings used when no override is given.
DEFAULT_CONFIG = {
    "model_path": "BAAI/bge-base-en-v1.5",
    "dataset_path": "tool-embedding.jsonl",
    "vector_size": 768,
    "embedding_field": "embedding",
    "id_field": "id",
}

# Later mapping wins, so CLI values replace the defaults key by key.
config = {**DEFAULT_CONFIG, **cli_args}
# A CLI override arrives as text; coerce the one numeric setting back to int.
config["vector_size"] = int(config["vector_size"])
# Load the tool metadata corpus from the Hugging Face Hub.
from datasets import load_dataset
from datasets import concatenate_datasets

TOOL_DATASET = "mangopy/ToolRet-Tools"
TOOL_SUBSETS = ("code", "customized", "web")


@st.cache_data
def load_tool_table(dataset_name: str = TOOL_DATASET, subsets: tuple = TOOL_SUBSETS) -> pd.DataFrame:
    """Return the ``tools`` split of every subset, concatenated into one DataFrame.

    Subsets are concatenated in the order given. The dataset's ``id`` column
    is renamed to ``tool`` so it can be merged with the search-hit table,
    which stores retrieved ids in a ``tool`` column.

    Cached with ``st.cache_data`` because Streamlit re-executes this script on
    every widget interaction. Without the cache, the datasets would be
    reloaded, concatenated and converted to pandas on each rerun.
    """
    parts = [load_dataset(dataset_name, subset)["tools"] for subset in subsets]
    return concatenate_datasets(parts).rename_columns({"id": "tool"}).to_pandas()


# Tool metadata that search hits are joined against when rendering results.
df2 = load_tool_table()
@st.cache_resource
def get_model(model_path: str = config['model_path']):
    """Construct a text2vec ``SentenceModel`` from ``model_path``.

    ``st.cache_resource`` keeps one model instance per distinct
    ``model_path``, so reruns reuse it instead of reloading the weights.
    The default is read from ``config`` once, when this function is defined.
    """
    sentence_model = SentenceModel(model_path)
    return sentence_model
@st.cache_resource
def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id_field: str, _model):
    """Build a ``JSONLIndexer`` and fill it from the JSONL file at ``dataset_path``.

    ``embedding_field`` / ``id_field`` name the per-record keys passed on to
    ``load_jsonl`` (presumably the stored vector and the record id; confirm
    against ``JSONLIndexer``).

    The leading underscore on ``_model`` tells ``st.cache_resource`` not to
    hash that argument. The cache key is therefore built from the other four
    parameters only.
    """
    indexer = JSONLIndexer(vector_sz=vector_sz, model=_model)
    indexer.load_jsonl(
        dataset_path,
        embedding_field=embedding_field,
        id_field=id_field,
    )
    return indexer
# Sidebar heading for the model configuration section.
st.sidebar.markdown("<div style='text-align: center;'><h3>📄 Model Configuration</h3></div>", unsafe_allow_html=True)

# Offer only the configured model. The retriever below is built from
# config['dataset_path'], whose stored embeddings are presumably produced by
# config['model_path'], so queries must be encoded with that same model.
# (The previous code loaded the selected model and then overwrote it with
# config['model_path'], so the sidebar could show a model that was not used.)
model_options = [config['model_path']]
selected_model = st.sidebar.selectbox("Select Model", model_options)
st.sidebar.write("Selected model:", selected_model)
st.sidebar.write(f"Embedding length: {config['vector_size']}")

# Load the selected model once (cached) and build the retriever on top of it.
model = get_model(selected_model)
retriever = create_retriever(config['vector_size'], config['dataset_path'], config['embedding_field'], config['id_field'], _model=model)
# Custom CSS for the search row, injected as raw HTML.
# NOTE(review): no element in this file is rendered with the
# .search-container / .search-box / .search-btn classes (the query box and
# button are plain Streamlit widgets), so these rules appear to have no
# visible effect. Confirm before relying on them.
_PAGE_CSS = """
<style>
.search-container {
display: flex;
justify-content: center;
align-items: center;
gap: 10px;
margin-top: 20px;
}
.search-box input {
width: 500px !important;
height: 45px;
font-size: 16px;
border-radius: 25px;
padding-left: 15px;
}
.search-btn button {
height: 45px;
font-size: 16px;
border-radius: 25px;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
st.markdown("<h1 style='text-align: center;'>🔍 Tool Retrieval</h1>", unsafe_allow_html=True)

# Wide query box next to a narrow search button (4:1 column ratio).
col1, col2 = st.columns([4, 1])
with col1:
    # Streamlit warns on an empty widget label (accessibility); provide a real
    # one and keep it hidden with label_visibility="collapsed".
    query = st.text_input(
        "Search query",
        placeholder="Enter your search query...",
        key="search_query",
        label_visibility="collapsed",
    )
with col2:
    search_clicked = st.button("🔎 Search", use_container_width=True)

# Number of hits requested from the retriever: 1-100, default 50.
top_k = st.slider("Top-K tools", 1, 100, 50, help="Choose the number of results to display")
# Results are computed and shown only on the rerun triggered by the Search
# button with a non-empty query.
# NOTE(review): st.button is True only for that one rerun, so moving the
# slider afterwards clears the table until Search is clicked again.
styled_results = None
if search_clicked and query:
    rec_ids, scores = retriever.search_return_id(query, top_k)
    df1 = pd.DataFrame({"relevance": scores, "tool": rec_ids})

    # A left merge keeps every retrieved id in retrieval order and attaches the
    # tool metadata columns (NaN where an id has no metadata row).
    # NOTE(review): if df2 holds repeated tool ids, a hit is duplicated here.
    results_df = pd.merge(df1, df2, on="tool", how="left")

    st.subheader("🗂️ Retrieval results")

    def _zebra_stripes(column):
        """Return alternating light-grey / white background styles per row."""
        return [
            "background-color: #F7F7F7" if row % 2 == 0 else "background-color: #FFFFFF"
            for row in range(len(column))
        ]

    styled_results = results_df.style.apply(_zebra_stripes, axis=0).format({"relevance": "{:.4f}"})

    # Scale the progress bars to the best score. Fall back to 1.0 when there
    # are no hits or no positive score, so max_value stays above min_value (0).
    top_score = float(df1["relevance"].max()) if len(df1) > 0 else 1.0
    progress_max = top_score if top_score > 0 else 1.0

    st.dataframe(
        styled_results,
        column_config={
            "relevance": st.column_config.ProgressColumn(
                "relevance",
                help="记录与查询的匹配程度",
                format="%.4f",
                min_value=0,
                max_value=progress_max,
            ),
            "tool": st.column_config.TextColumn("tool", help="tool help text", width="medium"),
        },
        hide_index=True,
        use_container_width=True,
    )
|