tool_retriever / app.py
Yyy0530's picture
重构 Streamlit 应用,将 JSONL 文件检索器集成到 app.py 中
6ac301d
raw
history blame
3.16 kB
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import os
import sys
import faiss
import numpy as np
import streamlit as st
from text2vec import SentenceModel
# 请确保 JSONLIndexer 在 src 目录下或者已正确安装
from src.jsonl_Indexer import JSONLIndexer
# 命令行参数处理函数
def get_cli_args():
args = {}
# 跳过第一个参数(脚本名)和第二个参数(streamlit run)
argv = sys.argv[2:] if len(sys.argv) > 2 else []
for arg in argv:
if '=' in arg:
key, value = arg.split('=', 1)
args[key.strip()] = value.strip()
return args
# 获取命令行参数
cli_args = get_cli_args()
# 设置默认值(适用于 JSONL 文件)
DEFAULT_CONFIG = {
'model_path': 'BAAI/bge-base-en-v1.5',
'dataset_path': 'src/tool-embedding.jsonl', # JSONL 文件路径
'vector_size': 768,
'embedding_field': 'embedding', # JSON中存储embedding的字段名
'id_field': 'id' # JSON中作为待检索文本的字段
}
# 合并默认配置和命令行参数
config = DEFAULT_CONFIG.copy()
config.update(cli_args)
# 将 vector_size 转换为整数
config['vector_size'] = int(config['vector_size'])
@st.cache_resource
def get_model(model_path: str = config['model_path']):
model = SentenceModel(model_path)
return model
@st.cache_resource
def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id_field: str, _model):
retriever = JSONLIndexer(vector_sz=vector_sz, model=_model)
retriever.load_jsonl(dataset_path, embedding_field=embedding_field, id_field=id_field)
return retriever
# 在侧边栏显示当前配置
if st.sidebar.checkbox("Show Configuration"):
st.sidebar.write("Current Configuration:")
for key, value in config.items():
st.sidebar.write(f"{key}: {value}")
# 初始化模型和检索器
model = get_model(config['model_path'])
retriever = create_retriever(
config['vector_size'],
config['dataset_path'],
config['embedding_field'],
config['id_field'],
_model=model
)
# Streamlit 应用界面
st.title("JSONL Data Retrieval Visualization")
st.write("该应用基于预计算的 JSONL 文件 embedding,输入查询后将检索相似记录。")
# 查询输入
query = st.text_input("Enter a search query:")
top_k = st.slider("Select number of results to display", min_value=1, max_value=100, value=5)
# 检索并展示结果
if st.button("Search") and query:
# 注意:JSONLIndexer 提供的是 search_return_id 方法,返回的是 JSON 中 id 字段
rec_ids, scores = retriever.search_return_id(query, top_k)
st.write("### Results:")
with st.expander("Retrieval Results (click to expand)"):
for j, rec_id in enumerate(rec_ids):
st.markdown(
f"""
<div style="border:1px solid #ccc; padding:10px; border-radius:5px; margin-bottom:10px; background-color:#f9f9f9;">
<p><b>Record {j+1} ID:</b> {rec_id}</p>
<p><b>Score:</b> {scores[j]:.4f}</p>
</div>
""",
unsafe_allow_html=True
)