Yyy0530 commited on
Commit
7c04c37
·
verified ·
1 Parent(s): c267b51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -45
app.py CHANGED
@@ -7,10 +7,8 @@ import pandas as pd
7
  from text2vec import SentenceModel
8
  from src.jsonl_Indexer import JSONLIndexer
9
 
10
- # 命令行参数处理函数
11
  def get_cli_args():
12
  args = {}
13
- # 跳过第一个参数(脚本名)和第二个参数(streamlit run)
14
  argv = sys.argv[2:] if len(sys.argv) > 2 else []
15
  for arg in argv:
16
  if '=' in arg:
@@ -18,27 +16,23 @@ def get_cli_args():
18
  args[key.strip()] = value.strip()
19
  return args
20
 
21
- # 获取命令行参数
22
  cli_args = get_cli_args()
23
 
24
- # 设置默认值(适用于 JSONL 文件)
25
  DEFAULT_CONFIG = {
26
  'model_path': 'BAAI/bge-base-en-v1.5',
27
- 'dataset_path': 'tool-embedding.jsonl', # JSONL 文件路径
28
  'vector_size': 768,
29
- 'embedding_field': 'embedding', # JSON中存储embedding的字段名
30
- 'id_field': 'id' # JSON中作为待检索文本的字段
31
  }
32
 
33
- # 合并默认配置和命令行参数
34
  config = DEFAULT_CONFIG.copy()
35
  config.update(cli_args)
36
  config['vector_size'] = int(config['vector_size'])
37
 
38
  @st.cache_resource
39
  def get_model(model_path: str = config['model_path']):
40
- model = SentenceModel(model_path)
41
- return model
42
 
43
  @st.cache_resource
44
  def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id_field: str, _model):
@@ -46,55 +40,55 @@ def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id
46
  retriever.load_jsonl(dataset_path, embedding_field=embedding_field, id_field=id_field)
47
  return retriever
48
 
49
- # 在侧边栏显示当前配置
50
  if st.sidebar.checkbox("Show Configuration"):
51
  st.sidebar.write("Current Configuration:")
52
  for key, value in config.items():
53
  st.sidebar.write(f"{key}: {value}")
54
 
55
- # 初始化模型和检索器
56
  model = get_model(config['model_path'])
57
- retriever = create_retriever(
58
- config['vector_size'],
59
- config['dataset_path'],
60
- config['embedding_field'],
61
- config['id_field'],
62
- _model=model
63
- )
64
 
65
- # Streamlit 应用界面
66
- st.title("Title")
67
- # st.write("该应用基于预计算的 JSONL 文件 embedding,输入查询后将检索相似记录。")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- # 查询输入
70
- # 创建两列布局
71
- col1, col2 = st.columns([2.5, 1])
72
- query,topk = None, None
73
  with col1:
74
- # 搜索输入框
75
- query = st.text_input(
76
- "query", placeholder="your query", help=""
77
- )
78
  with col2:
79
- # TopK选择滑块
80
- top_k = st.slider(
81
- "Top K", 1, 100, 50, help="choose the number of results to display"
82
- )
83
- # 检索并展示结果
84
 
85
- if st.button("搜索") and query:
86
- # 调用检索方法,返回JSON中id字段和对应的相似度得分
 
87
  rec_ids, scores = retriever.search_return_id(query, top_k)
88
-
89
- # 将检索结果构造成 DataFrame
90
- results_df = pd.DataFrame({
91
- "tool": rec_ids,
92
- "relevance": scores
93
- })
94
 
95
  st.subheader("🗂️ 结果详情")
96
 
97
- # 为 DataFrame 添加样式(交替行背景色)
98
  styled_results = results_df.style.apply(
99
  lambda x: [
100
  "background-color: #F7F7F7" if i % 2 == 0 else "background-color: #FFFFFF"
@@ -103,7 +97,6 @@ if st.button("搜索") and query:
103
  axis=0,
104
  ).format({"relevance": "{:.4f}"})
105
 
106
- # 使用交互式数据表格展示结果,并配置列样式
107
  st.dataframe(
108
  styled_results,
109
  column_config={
 
7
  from text2vec import SentenceModel
8
  from src.jsonl_Indexer import JSONLIndexer
9
 
 
10
  def get_cli_args():
11
  args = {}
 
12
  argv = sys.argv[2:] if len(sys.argv) > 2 else []
13
  for arg in argv:
14
  if '=' in arg:
 
16
  args[key.strip()] = value.strip()
17
  return args
18
 
 
19
  cli_args = get_cli_args()
20
 
 
21
  DEFAULT_CONFIG = {
22
  'model_path': 'BAAI/bge-base-en-v1.5',
23
+ 'dataset_path': 'tool-embedding.jsonl',
24
  'vector_size': 768,
25
+ 'embedding_field': 'embedding',
26
+ 'id_field': 'id'
27
  }
28
 
 
29
  config = DEFAULT_CONFIG.copy()
30
  config.update(cli_args)
31
  config['vector_size'] = int(config['vector_size'])
32
 
33
  @st.cache_resource
34
  def get_model(model_path: str = config['model_path']):
35
+ return SentenceModel(model_path)
 
36
 
37
  @st.cache_resource
38
  def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id_field: str, _model):
 
40
  retriever.load_jsonl(dataset_path, embedding_field=embedding_field, id_field=id_field)
41
  return retriever
42
 
 
43
  if st.sidebar.checkbox("Show Configuration"):
44
  st.sidebar.write("Current Configuration:")
45
  for key, value in config.items():
46
  st.sidebar.write(f"{key}: {value}")
47
 
 
48
  model = get_model(config['model_path'])
49
+ retriever = create_retriever(config['vector_size'], config['dataset_path'], config['embedding_field'], config['id_field'], _model=model)
 
 
 
 
 
 
50
 
51
+ # 美化界面
52
+ st.markdown("""
53
+ <style>
54
+ .search-container {
55
+ display: flex;
56
+ justify-content: center;
57
+ align-items: center;
58
+ gap: 10px;
59
+ margin-top: 20px;
60
+ }
61
+ .search-box input {
62
+ width: 500px !important;
63
+ height: 45px;
64
+ font-size: 16px;
65
+ border-radius: 25px;
66
+ padding-left: 15px;
67
+ }
68
+ .search-btn button {
69
+ height: 45px;
70
+ font-size: 16px;
71
+ border-radius: 25px;
72
+ }
73
+ </style>
74
+ """, unsafe_allow_html=True)
75
 
76
+ st.title("🔍 Tool Search")
77
+
78
+ col1, col2 = st.columns([4, 1])
 
79
  with col1:
80
+ query = st.text_input("", placeholder="Enter your search query...", key="search_query", label_visibility="collapsed")
 
 
 
81
  with col2:
82
+ search_clicked = st.button("🔎 Search", use_container_width=True)
 
 
 
 
83
 
84
+ top_k = st.slider("Number of Results", 1, 100, 50, help="Choose the number of results to display")
85
+
86
+ if search_clicked and query:
87
  rec_ids, scores = retriever.search_return_id(query, top_k)
88
+ results_df = pd.DataFrame({"tool": rec_ids, "relevance": scores})
 
 
 
 
 
89
 
90
  st.subheader("🗂️ 结果详情")
91
 
 
92
  styled_results = results_df.style.apply(
93
  lambda x: [
94
  "background-color: #F7F7F7" if i % 2 == 0 else "background-color: #FFFFFF"
 
97
  axis=0,
98
  ).format({"relevance": "{:.4f}"})
99
 
 
100
  st.dataframe(
101
  styled_results,
102
  column_config={