Yyy0530 commited on
Commit
a302d1e
·
verified ·
1 Parent(s): 2deb0bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -13
app.py CHANGED
@@ -20,7 +20,7 @@ cli_args = get_cli_args()
20
 
21
  DEFAULT_CONFIG = {
22
  'model_path': 'BAAI/bge-base-en-v1.5',
23
- 'dataset_path': 'tool-embedding.jsonl',
24
  'vector_size': 768,
25
  'embedding_field': 'embedding',
26
  'id_field': 'id'
@@ -30,6 +30,24 @@ config = DEFAULT_CONFIG.copy()
30
  config.update(cli_args)
31
  config['vector_size'] = int(config['vector_size'])
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  @st.cache_resource
34
  def get_model(model_path: str = config['model_path']):
35
  return SentenceModel(model_path)
@@ -40,10 +58,19 @@ def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id
40
  retriever.load_jsonl(dataset_path, embedding_field=embedding_field, id_field=id_field)
41
  return retriever
42
 
43
- if st.sidebar.checkbox("Show Configuration"):
44
- st.sidebar.write("Current Configuration:")
45
- for key, value in config.items():
46
- st.sidebar.write(f"{key}: {value}")
 
 
 
 
 
 
 
 
 
47
 
48
  model = get_model(config['model_path'])
49
  retriever = create_retriever(config['vector_size'], config['dataset_path'], config['embedding_field'], config['id_field'], _model=model)
@@ -73,7 +100,8 @@ st.markdown("""
73
  </style>
74
  """, unsafe_allow_html=True)
75
 
76
- st.title("🔍 Tool Retrieval")
 
77
 
78
  col1, col2 = st.columns([4, 1])
79
  with col1:
@@ -81,13 +109,18 @@ with col1:
81
  with col2:
82
  search_clicked = st.button("🔎 Search", use_container_width=True)
83
 
84
- top_k = st.slider("Number of Results", 1, 4453, 4453, help="Choose the number of results to display")
85
 
 
86
  if search_clicked and query:
87
  rec_ids, scores = retriever.search_return_id(query, top_k)
88
- results_df = pd.DataFrame({"tool": rec_ids, "relevance": scores})
89
-
90
- st.subheader("🗂️ 结果详情")
 
 
 
 
91
 
92
  styled_results = results_df.style.apply(
93
  lambda x: [
@@ -96,18 +129,18 @@ if search_clicked and query:
96
  ],
97
  axis=0,
98
  ).format({"relevance": "{:.4f}"})
99
-
100
  st.dataframe(
101
  styled_results,
102
  column_config={
103
- "tool": st.column_config.TextColumn("tool", help="tool help text", width="medium"),
104
  "relevance": st.column_config.ProgressColumn(
105
  "relevance",
106
  help="记录与查询的匹配程度",
107
  format="%.4f",
108
  min_value=0,
109
  max_value=float(max(scores)) if len(scores) > 0 else 1,
110
- )
 
111
  },
112
  hide_index=True,
113
  use_container_width=True,
 
20
 
21
  DEFAULT_CONFIG = {
22
  'model_path': 'BAAI/bge-base-en-v1.5',
23
+ 'dataset_path': 'extracted_tool_embedding.jsonl',
24
  'vector_size': 768,
25
  'embedding_field': 'embedding',
26
  'id_field': 'id'
 
30
  config.update(cli_args)
31
  config['vector_size'] = int(config['vector_size'])
32
 
33
+
34
+ #加载数据
35
+ from datasets import load_dataset
36
+ from datasets import concatenate_datasets
37
+ ds1 = load_dataset("mangopy/ToolRet-Tools", "code")
38
+ ds2 = load_dataset("mangopy/ToolRet-Tools", "customized")
39
+ ds3 = load_dataset("mangopy/ToolRet-Tools", "web")
40
+ ds = concatenate_datasets([ds1['tools'], ds2['tools'], ds3['tools']])
41
+ ds = ds.rename_columns({'id':'tool'})
42
+
43
+ #merge
44
+
45
+ # 随便建立一个pd.DataFrame, 有两列,一列是id,一列是text
46
+ import pandas as pd
47
+ df2 = ds.to_pandas()
48
+
49
+
50
+
51
  @st.cache_resource
52
  def get_model(model_path: str = config['model_path']):
53
  return SentenceModel(model_path)
 
58
  retriever.load_jsonl(dataset_path, embedding_field=embedding_field, id_field=id_field)
59
  return retriever
60
 
61
+ # 在侧边栏中添加模型配置标题
62
+ st.sidebar.markdown("<div style='text-align: center;'><h3>📄 Model Configuration</h3></div>", unsafe_allow_html=True)
63
+
64
+
65
+ # 添加模型选项下拉框,目前只有一个模型可选
66
+ model_options = ["BAAI/bge-base-en-v1.5"]
67
+ selected_model = st.sidebar.selectbox("Select Model", model_options)
68
+ st.sidebar.write("Selected model:", selected_model)
69
+ st.sidebar.write("Embedding length: 768")
70
+
71
+ # 使用选中的模型加载
72
+ model = get_model(selected_model)
73
+
74
 
75
  model = get_model(config['model_path'])
76
  retriever = create_retriever(config['vector_size'], config['dataset_path'], config['embedding_field'], config['id_field'], _model=model)
 
100
  </style>
101
  """, unsafe_allow_html=True)
102
 
103
+ st.markdown("<h1 style='text-align: center;'>🔍 Tool Retrieval</h1>", unsafe_allow_html=True)
104
+
105
 
106
  col1, col2 = st.columns([4, 1])
107
  with col1:
 
109
  with col2:
110
  search_clicked = st.button("🔎 Search", use_container_width=True)
111
 
112
+ top_k = st.slider("Top-K tools", 1, 100, 50, help="Choose the number of results to display")
113
 
114
+ styled_results = None
115
  if search_clicked and query:
116
  rec_ids, scores = retriever.search_return_id(query, top_k)
117
+ df1 = pd.DataFrame({ "relevance": scores, "tool": rec_ids})
118
+ # print(df1)
119
+ # merge两个DataFrame
120
+ results_df = pd.merge(df1, df2, on='tool', how = 'left')
121
+
122
+ # results_df["interface"] = "asdasdadasdasdasdasdasdasdasasdasdasdasdasdasdasdasdasdasdasdasdasdasdasdasdasdassdasdasdasdasdasabababbabasdbabsdbasbdadabdbasdbasbdbasdbasdbasdb"
123
+ st.subheader("🗂️ Retrieval results")
124
 
125
  styled_results = results_df.style.apply(
126
  lambda x: [
 
129
  ],
130
  axis=0,
131
  ).format({"relevance": "{:.4f}"})
132
+
133
  st.dataframe(
134
  styled_results,
135
  column_config={
 
136
  "relevance": st.column_config.ProgressColumn(
137
  "relevance",
138
  help="记录与查询的匹配程度",
139
  format="%.4f",
140
  min_value=0,
141
  max_value=float(max(scores)) if len(scores) > 0 else 1,
142
+ ),
143
+ "tool": st.column_config.TextColumn("tool", help="tool help text", width="medium")
144
  },
145
  hide_index=True,
146
  use_container_width=True,