Spaces:
Sleeping
Sleeping
import os

# Set before the remaining imports run.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import html
import os
import sys

import faiss
import numpy as np
import streamlit as st
from text2vec import SentenceModel

# Make sure JSONLIndexer lives under the src directory or is otherwise installed.
from src.jsonl_Indexer import JSONLIndexer
# Command-line argument parsing.
def get_cli_args():
    """Collect ``key=value`` overrides from the command line.

    Every argument after the script name is inspected. Arguments that
    contain ``=`` are split on the first ``=`` into a key and a value, both
    stripped of surrounding whitespace; arguments without ``=`` are ignored.

    Parsing starts at ``sys.argv[1]``: when the app is launched with
    ``streamlit run app.py -- key=value``, Streamlit hands the script an
    ``argv`` of ``[script, *script_args]``, so starting at index 2 would
    silently drop the first override.

    Returns:
        dict: Mapping of option name to its raw string value. Later
        duplicates of a key overwrite earlier ones.
    """
    args = {}
    for arg in sys.argv[1:]:
        key, sep, value = arg.partition('=')
        if sep:  # only tokens that actually contain '=' are overrides
            args[key.strip()] = value.strip()
    return args
# Read the key=value overrides supplied on the command line.
cli_args = get_cli_args()

# Default settings (suited to a JSONL file).
DEFAULT_CONFIG = {
    'model_path': 'BAAI/bge-base-en-v1.5',
    'dataset_path': 'src/tool-embedding.jsonl',  # path to the JSONL file
    'vector_size': 768,
    'embedding_field': 'embedding',  # JSON field that stores the embedding
    'id_field': 'id',  # JSON field used as the text to be retrieved
}

# Layer the command-line values on top of the defaults.
config = {**DEFAULT_CONFIG, **cli_args}

# Command-line values are strings, so coerce vector_size to an integer.
config['vector_size'] = int(config['vector_size'])
def get_model(model_path: str = config['model_path']):
    """Construct a ``SentenceModel`` from ``model_path``.

    Args:
        model_path: Value passed straight to ``SentenceModel``. The default
            is the configured ``model_path`` as it was when this function
            was defined.

    Returns:
        The newly created ``SentenceModel`` instance.
    """
    return SentenceModel(model_path)
def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id_field: str, _model):
    """Create a ``JSONLIndexer`` and load a JSONL dataset into it.

    Args:
        vector_sz: Forwarded to ``JSONLIndexer`` as ``vector_sz``.
        dataset_path: Path of the JSONL file handed to ``load_jsonl``.
        embedding_field: Name of the JSON field holding the embedding.
        id_field: Name of the JSON field used as the record identifier.
        _model: Model object forwarded to ``JSONLIndexer`` as ``model``.
            NOTE(review): the leading underscore looks like Streamlit's
            convention for arguments excluded from cache hashing, but no
            cache decorator is applied here -- confirm whether one was
            intended.

    Returns:
        The ``JSONLIndexer`` after ``load_jsonl`` has been called on it.
    """
    indexer = JSONLIndexer(vector_sz=vector_sz, model=_model)
    indexer.load_jsonl(
        dataset_path,
        embedding_field=embedding_field,
        id_field=id_field,
    )
    return indexer
# Show the active configuration in the sidebar when the box is ticked.
if st.sidebar.checkbox("Show Configuration"):
    st.sidebar.write("Current Configuration:")
    for key, value in config.items():
        st.sidebar.write(f"{key}: {value}")

# Build the model first, then the retriever that depends on it.
model = get_model(config['model_path'])
retriever = create_retriever(
    vector_sz=config['vector_size'],
    dataset_path=config['dataset_path'],
    embedding_field=config['embedding_field'],
    id_field=config['id_field'],
    _model=model,
)
# Streamlit application UI.
st.title("JSONL Data Retrieval Visualization")
st.write("该应用基于预计算的 JSONL 文件 embedding,输入查询后将检索相似记录。")

# Query input.
query = st.text_input("Enter a search query:")
top_k = st.slider("Select number of results to display", min_value=1, max_value=100, value=5)

# Run the search and render the results.
if st.button("Search") and query:
    # NOTE: JSONLIndexer exposes search_return_id, which returns the values
    # of the JSON id field together with their scores.
    rec_ids, scores = retriever.search_return_id(query, top_k)
    st.write("### Results:")
    with st.expander("Retrieval Results (click to expand)"):
        for j, rec_id in enumerate(rec_ids):
            # The id comes from the dataset and is rendered as raw HTML
            # below, so escape it to keep characters such as '<' and '&'
            # from breaking (or injecting into) the markup.
            safe_id = html.escape(str(rec_id))
            # The HTML is kept flush-left: Markdown would treat lines
            # indented by four or more spaces as a code block.
            st.markdown(
                f"""
<div style="border:1px solid #ccc; padding:10px; border-radius:5px; margin-bottom:10px; background-color:#f9f9f9;">
<p><b>Record {j+1} ID:</b> {safe_id}</p>
<p><b>Score:</b> {scores[j]:.4f}</p>
</div>
""",
                unsafe_allow_html=True
            )