File size: 3,159 Bytes
9ab9c77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import os
import sys
import faiss
import numpy as np
import streamlit as st
from text2vec import SentenceModel
# 请确保 JSONLIndexer 在 src 目录下或者已正确安装
from src.jsonl_Indexer import JSONLIndexer

# 命令行参数处理函数
def get_cli_args():
    args = {}
    # 跳过第一个参数(脚本名)和第二个参数(streamlit run)
    argv = sys.argv[2:] if len(sys.argv) > 2 else []
    for arg in argv:
        if '=' in arg:
            key, value = arg.split('=', 1)
            args[key.strip()] = value.strip()
    return args

# 获取命令行参数
cli_args = get_cli_args()

# 设置默认值(适用于 JSONL 文件)
DEFAULT_CONFIG = {
    'model_path': 'BAAI/bge-base-en-v1.5',
    'dataset_path': 'src/tool-embedding.jsonl',  # JSONL 文件路径
    'vector_size': 768,
    'embedding_field': 'embedding',   # JSON中存储embedding的字段名
    'id_field': 'id'                  # JSON中作为待检索文本的字段
}

# 合并默认配置和命令行参数
config = DEFAULT_CONFIG.copy()
config.update(cli_args)

# 将 vector_size 转换为整数
config['vector_size'] = int(config['vector_size'])

@st.cache_resource
def get_model(model_path: str = config['model_path']):
    model = SentenceModel(model_path)
    return model

@st.cache_resource
def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id_field: str, _model):
    retriever = JSONLIndexer(vector_sz=vector_sz, model=_model)
    retriever.load_jsonl(dataset_path, embedding_field=embedding_field, id_field=id_field)
    return retriever

# 在侧边栏显示当前配置
if st.sidebar.checkbox("Show Configuration"):
    st.sidebar.write("Current Configuration:")
    for key, value in config.items():
        st.sidebar.write(f"{key}: {value}")

# 初始化模型和检索器
model = get_model(config['model_path'])
retriever = create_retriever(
    config['vector_size'],
    config['dataset_path'],
    config['embedding_field'],
    config['id_field'],
    _model=model
)

# Streamlit 应用界面
st.title("JSONL Data Retrieval Visualization")
st.write("该应用基于预计算的 JSONL 文件 embedding,输入查询后将检索相似记录。")

# 查询输入
query = st.text_input("Enter a search query:")
top_k = st.slider("Select number of results to display", min_value=1, max_value=100, value=5)

# 检索并展示结果
if st.button("Search") and query:
    # 注意:JSONLIndexer 提供的是 search_return_id 方法,返回的是 JSON 中 id 字段
    rec_ids, scores = retriever.search_return_id(query, top_k)

    st.write("### Results:")

    with st.expander("Retrieval Results (click to expand)"):
        for j, rec_id in enumerate(rec_ids):
            st.markdown(
                f"""
                <div style="border:1px solid #ccc; padding:10px; border-radius:5px; margin-bottom:10px; background-color:#f9f9f9;">
                    <p><b>Record {j+1} ID:</b> {rec_id}</p>
                    <p><b>Score:</b> {scores[j]:.4f}</p>
                </div>
                """,
                unsafe_allow_html=True
            )