jackkuo committed
Commit 7bab2b1 · verified · 1 Parent(s): 6364266

Update app.py

Files changed (1)
  1. app.py +139 -43
app.py CHANGED
@@ -12,17 +12,9 @@ import doi
 import requests
 from datetime import datetime, timedelta
 
-
 API_URL = (
     "https://api-inference.huggingface.co/models/mixedbread-ai/mxbai-embed-large-v1"
 )
-summarization_API_URL = (
-    "https://api-inference.huggingface.co/models/Falconsai/text_summarization"
-)
-
-LLM_API_URL = (
-    "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
-)
 
 from openai import OpenAI
 
@@ -33,6 +25,13 @@ client_openai = OpenAI(
     api_key=api_key,
     base_url=base_url,
 )
+api_key_kimi = os.getenv('API_KEY_KIMI')
+base_url_kimi = os.getenv("BASE_URL_KIMI")
+
+client_openai_kimi = OpenAI(
+    api_key=api_key_kimi,
+    base_url=base_url_kimi,
+)
 
 API_TOKEN = os.getenv('hf_token')
 
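Note: client_openai_kimi reuses the OpenAI SDK against Moonshot's OpenAI-compatible endpoint. As a hedged sketch (not part of the commit), this is how such a client is typically exercised, assuming API_KEY_KIMI and BASE_URL_KIMI are set in the environment; the base URL value itself never appears in this diff:

    import os
    from openai import OpenAI

    # Assumed env vars, mirroring the diff above.
    client = OpenAI(
        api_key=os.getenv("API_KEY_KIMI"),
        base_url=os.getenv("BASE_URL_KIMI"),
    )
    # "moonshot-v1-32k" is the model id used later in this commit.
    resp = client.chat.completions.create(
        model="moonshot-v1-32k",
        messages=[{"role": "user", "content": "Say hello."}],
    )
    print(resp.choices[0].message.content)
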
@@ -40,7 +39,6 @@ headers = {"Authorization": f"Bearer {API_TOKEN}"}
 
 
 def query_hf_api(text, api=API_URL, parameters=None):
-
     if not parameters:
         payload = {"inputs": text}
     else:
@@ -64,8 +62,9 @@ def query_hf_api(text, api=API_URL, parameters=None):
     if "error" in response_data and "loading" in response_data["error"]:
         estimated_time = response_data.get("estimated_time", 30)  # Default wait time to 30 seconds if not provided
         with progress_placeholder.container():
-            st.warning(f"Model from :hugging_face: is currently loading. Estimated wait time: {estimated_time:.1f} seconds. Please wait...")
-
+            st.warning(
+                f"Model from :hugging_face: is currently loading. Estimated wait time: {estimated_time:.1f} seconds. Please wait...")
+
             # Create a progress bar within the container
             progress_bar = st.progress(0)
             for i in range(int(estimated_time) + 5):  # Adding a buffer time to ensure the model is loaded
@@ -76,7 +75,7 @@ def query_hf_api(text, api=API_URL, parameters=None):
 
             # Clear the placeholder once loading is complete
             progress_placeholder.empty()
-            
+
             st.rerun()  # Rerun the app after waiting
 
     return response_data
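Note: the two hunks above handle the Hugging Face Inference API's "model is loading" response by showing a progress bar and rerunning the app. The same retry pattern, reduced to a minimal non-Streamlit sketch (hypothetical helper, same API_URL and hf_token as in the file):

    import os
    import time
    import requests

    API_URL = "https://api-inference.huggingface.co/models/mixedbread-ai/mxbai-embed-large-v1"
    headers = {"Authorization": f"Bearer {os.getenv('hf_token')}"}

    def embed_with_retry(text, max_attempts=3):
        # Retry while the hosted model is still being loaded onto a worker.
        for _ in range(max_attempts):
            data = requests.post(API_URL, headers=headers, json={"inputs": text}).json()
            if isinstance(data, dict) and "error" in data and "loading" in data["error"]:
                # The API reports roughly how long the warm-up should take.
                time.sleep(data.get("estimated_time", 30))
                continue
            return data
        raise RuntimeError("Model did not load in time")
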
@@ -99,7 +98,7 @@ def normalize_embeddings(embeddings):
 
 
 def quantize_embeddings(
-    embeddings, precision="ubinary", ranges=None, calibration_embeddings=None
+        embeddings, precision="ubinary", ranges=None, calibration_embeddings=None
 ):
     """
     Quantizes embeddings to a specified precision using PyTorch and numpy.
@@ -184,7 +183,7 @@ def process_embeddings(embeddings, precision="ubinary", calibration_embeddings=N
             calibration_embeddings, dtype=torch.float32
         )
     elif calibration_embeddings is not None and not isinstance(
-        calibration_embeddings, torch.Tensor
+            calibration_embeddings, torch.Tensor
     ):
         raise TypeError(
             "Calibration embeddings must be a list or a torch.Tensor if provided. "
@@ -209,7 +208,7 @@ def load_data_embeddings():
 
     new_data_directory_med = "db_update_med"
     updated_embeddings_directory_med = "embed_update_med"
-    
+
     # Load existing database and embeddings
     df_existing = pd.read_parquet(existing_data_path)
     embeddings_existing = np.load(existing_embeddings_path, allow_pickle=True)
@@ -229,18 +228,19 @@ def load_data_embeddings():
     print(new_data_files)
     for data_file in new_data_files:
         corresponding_embedding_file = Path(updated_embeddings_directory) / (
-            data_file.stem + ".npy"
+                data_file.stem + ".npy"
         )
 
         if corresponding_embedding_file.exists():
             df = pd.read_parquet(data_file)
             new_embeddings = np.load(corresponding_embedding_file, allow_pickle=True)
-            
+
             # Check if the number of rows in the DataFrame matches the number of rows in the embeddings
             if df.shape[0] != new_embeddings.shape[0]:
-                print(f"Shape mismatch for {data_file.name}: DataFrame has {df.shape[0]} rows, embeddings have {new_embeddings.shape[0]} rows. Skipping.")
+                print(
+                    f"Shape mismatch for {data_file.name}: DataFrame has {df.shape[0]} rows, embeddings have {new_embeddings.shape[0]} rows. Skipping.")
                 continue
-            
+
             # Check embedding size and adjust if necessary
             if new_embeddings.shape[1] != embedding_size:
                 print(f"Skipping {data_file.name} due to embedding size mismatch.")
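Note: the loop above pairs each update parquet with a same-stem .npy file and skips any pair whose shapes disagree, rather than corrupting the combined index. The invariant it enforces, as a condensed sketch with hypothetical file names:

    from pathlib import Path
    import numpy as np
    import pandas as pd

    # Hypothetical paired update files, following the stem convention in the diff.
    data_file = Path("db_update/2024-06-01.parquet")
    embedding_file = Path("embed_update") / (data_file.stem + ".npy")

    df = pd.read_parquet(data_file)
    emb = np.load(embedding_file, allow_pickle=True)

    # One embedding row per DataFrame row; width must match the existing index.
    assert df.shape[0] == emb.shape[0]
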
@@ -288,14 +288,15 @@
 
 LLM_prompt = "Review the abstracts listed above and create a list and summary that captures their main themes and findings. Identify any commonalities across the abstracts and highlight these in your summary. Ensure your response is concise, avoids external links, and is formatted in markdown.\n\n"
 
+
 def summarize_abstract(abstract, llm_model="llama-3.1-70b-versatile", instructions=LLM_prompt, api_key=""):
     """
     Summarizes the provided abstract using a specified LLM model.
-    
+
     Parameters:
     - abstract (str): The abstract text to be summarized.
     - llm_model (str): The LLM model used for summarization. Defaults to "llama-3.1-70b-versatile".
-    
+
     Returns:
     - str: A summary of the abstract, condensed into one to two sentences.
     """
@@ -311,8 +312,90 @@
     except Exception as e:  # Catch the exception
         print(f"An error occurred: {e}")  # Print the error
         return 'LLM API not available or above the usage limit.'
-
-
+
+    # Return the summarized content
+    return chat_completion.choices[0].message.content
+
+
+def summarize_abstract_kimi(title, link):
+    """
+    Summarizes the paper behind a given title and link using the Moonshot (Kimi) LLM.
+
+    Parameters:
+    - title (str): The title of the paper to be summarized.
+    - link (str): A DOI or full-text link to the paper.
+
+    Returns:
+    - str: A structured, multi-section summary of the paper.
+    """
+    print("use openai api: moonshot-v1-32k")
+    print(title, link)
+    client = client_openai_kimi
+    formatted_text = "The paper we are going to discuss is " + title + ". The link is " + link + """ .
+    Please use this as a basis to answer my questions. Please output your answers according to the following format. Please pay attention to the logic of subheading stratification and ensure that each layer includes 4-10 points.
+    **Q: What problem does this paper try to solve?**
+
+    A: [Use one sentence to summarize what problem this paper tries to solve]
+    1. Subheading 1: [Content under subheading 1]
+    2. Subheading 2: [Content under subheading 2]
+    3. Subheading 3: […]
+    […]
+
+    **Q: What are the related studies?**
+
+    A: [Use one sentence to summarize the relevant research]
+    1. Subheading 1: [Subheading 1]
+    2. Subheading 2: [Subheading 2]
+    3. Subheading 3: […]
+    […]
+
+    **Q: How does the paper solve this problem?**
+
+    A: [Use one sentence to summarize how the paper solves this problem]
+    1. Subheading 1: [Subheading 1]
+    2. Subheading 2: [Subheading 2]
+    3. Subheading 3: […]
+    […]
+
+    **Q: What experiments were done in the paper?**
+
+    A: [Use one sentence to summarize the experiments done in the paper]
+    1. Subheading 1: [Subheading 1]
+    2. Subheading 2: [Subheading 2]
+    3. Subheading 3: […]
+    […]
+
+    **Q: Is there anything that can be further explored?**
+
+    A: [Use one sentence here to summarize what can be further explored]
+    1. Subheading 1: [Content under subheading 1]
+    2. Subheading 2: [Content under subheading 2]
+    3. Subheading 3: […]
+    […]
+
+    **Q: Summarize the main content of the paper**
+
+    A: [Use one sentence here to summarize the main content of the paper]
+    1. Research background: […]
+    2. Research methods: […]
+    3. Experimental design: […]
+    4. Main findings: […]
+    5. Research contributions: […]
+    6. Future research directions: […]
+    7. Methods and tools: […]
+    8. Dataset: […]
+    9. Conclusion: […]
+    """
+    try:
+        # Create a chat completion with the paper prompt and the Moonshot model
+        chat_completion = client.chat.completions.create(
+            messages=[{"role": "user", "content": f'"{formatted_text}"'}],
+            model="moonshot-v1-32k",
+        )
+    except Exception as e:  # Catch the exception
+        print(f"An error occurred: {e}")  # Print the error
+        return 'LLM API not available or above the usage limit.'
+
     # Return the summarized content
     return chat_completion.choices[0].message.content
 
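Note: unlike summarize_abstract, which receives abstract text, the new summarize_abstract_kimi sends only a title and a link and relies on the prompt template to elicit a structured Q&A-style summary. A minimal call, with a hypothetical paper:

    # Hypothetical title/link; in the app these come from row['title'] and doi_link.
    summary = summarize_abstract_kimi(
        title="A single-cell atlas of the human heart",
        link="https://doi.org/10.1101/2024.01.01.573999",
    )
    print(summary)  # markdown with the Q/A sections defined in formatted_text
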
@@ -381,13 +464,13 @@ def logo(db_update_date, db_size_bio, db_size_med):
 
 
 st.set_page_config(
-    page_title="bMSS",
+    page_title="BioRxiv Search",
     page_icon=":scroll:",
 )
 define_style()
 
 df, embeddings_unique = load_data_embeddings()
-logo(df["date"].max(), df[df['server']=='biorxiv'].shape[0], df[df['server']=='medrxiv'].shape[0])
+logo(df["date"].max(), df[df['server'] == 'biorxiv'].shape[0], df[df['server'] == 'medrxiv'].shape[0])
 
 # model = model_to_device()
 
@@ -471,7 +554,7 @@ if query or search_button:
         search_df["score"] = abs(search_df["score"] - max_score) + min_score
 
         abstracts = []
-        
+
         # Iterate over each row in the search_df DataFrame
         for index, entry in search_df.iterrows():
             row = df.iloc[int(entry["corpus_id"])]
@@ -480,7 +563,7 @@ if query or search_button:
             try:
                 doi_link = f"{doi.get_real_url_from_doi(row['doi'])}"
             except:
-                doi_link = f'https://www.doi.org/'+row['doi']
+                doi_link = f'https://www.doi.org/' + row['doi']
 
             # Append information to plot_data for visualization
             plot_data["Date"].append(row["date"])
@@ -490,9 +573,7 @@ if query or search_button:
             plot_data["category"].append(row["category"])
             plot_data["server"].append(row["server"])
 
-            #summary_text = summarize_abstract(row['abstract'])
-
-            with st.expander(f"{index+1}\. {row['title']}"):  # type: ignore
+            with st.expander(f"{index + 1}\. {row['title']}"):  # type: ignore
                 col1, col2 = st.columns(2)
                 col1.markdown(f"**Score:** {entry['score']:.1f}")
                 col2.markdown(f"**Server:** [{row['server']}]")
@@ -500,7 +581,7 @@ if query or search_button:
                 col1, col2 = st.columns(2)
                 col2.markdown(f"**Category:** {row['category']}")
                 col1.markdown(f"**Date:** {row['date']}")
-                #st.markdown(f"**Summary:**\n{summary_text}", unsafe_allow_html=False)
+                # st.markdown(f"**Summary:**\n{summary_text}", unsafe_allow_html=False)
                 abstracts.append(row['abstract'])
                 st.markdown(
                     f"**Abstract:**\n{row['abstract']}", unsafe_allow_html=False
@@ -508,17 +589,33 @@ if query or search_button:
                 st.markdown(
                     f"**[Full Text Read]({doi_link})** 🔗", unsafe_allow_html=True
                 )
+                summary_button_one_paper = st.button("AI summary of this Paper", key="b_" + str(index + 1))
+                if summary_button_one_paper:
+                    with st.spinner("AI summary of this Paper..."):
+                        ai_gen_start = time.time()
+                        st.markdown('**AI summary of this Paper:**')
+                        summary_of_this_Paper = summarize_abstract_kimi(title=row['title'], link=doi_link)
+                        st.markdown(summary_of_this_Paper)
+                        new_link = f"https://kimi.moonshot.cn/_prefill_chat?prefill_prompt=The paper we are going to discuss is {row['title']}, the link is {str(doi_link)} or https://www.biorxiv.org/content/{str(row['doi'])}v1 " \
+                                   f" Please use this as a basis to continue summarize this article and answer my follow-up questions &send_immediately=true&force_search=false"
+
+                        total_ai_time = time.time() - ai_gen_start
+                        st.markdown(f'**Time to generate summary:** {total_ai_time:.2f} s')
+
+                        # Make sure the HTML link is formatted correctly
+                        st.markdown(f'<a href="{new_link}" target="_blank">**Full Text Dialogue** 🔗</a>',
+                                    unsafe_allow_html=True)
+
     if plot_data:
         with st.spinner("Under statistics..."):
             plot_df = pd.DataFrame(plot_data)
-            
+
             # Convert 'Date' to datetime if it's not already in that format
             plot_df["Date"] = pd.to_datetime(plot_df["Date"])
-            
+
             # Sort the DataFrame based on the Date to make sure it's ordered
             plot_df = plot_df.sort_values(by="Date")
-
-            
+
             # Create a Plotly figure
             fig = px.scatter(
                 plot_df,
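Note: new_link above embeds the raw title and prompt in a query string; spaces survive in practice, but a title containing "&" or "#" would truncate prefill_prompt. A hardening sketch using urllib.parse.quote (a suggestion, not what the commit does):

    from urllib.parse import quote

    # Hypothetical values; in the app these come from row['title'] and doi_link.
    title = "Example title with spaces & symbols"
    doi_link = "https://doi.org/10.1101/2024.01.01.573999"

    prompt = (
        f"The paper we are going to discuss is {title}, the link is {doi_link}. "
        "Please use this as a basis to continue summarizing this article and "
        "answer my follow-up questions."
    )
    # quote() percent-encodes spaces, '&' and '#' so they stay inside prefill_prompt.
    new_link = (
        "https://kimi.moonshot.cn/_prefill_chat"
        f"?prefill_prompt={quote(prompt)}"
        "&send_immediately=true&force_search=false"
    )
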
@@ -534,14 +631,14 @@ if query or search_button:
                 hovertemplate="<b>%{hovertext}</b>",
                 hovertext=plot_df.apply(lambda row: f"{row['Title']}", axis=1),
             )
-            
+
             # Show the figure in the Streamlit app
             st.plotly_chart(fig, use_container_width=True)
-            
+
             # Generate category counts for the pie chart
             category_counts = plot_df["category"].value_counts().reset_index()
             category_counts.columns = ["category", "count"]
-            
+
             # Create a pie chart with Plotly Express
             fig = px.pie(
                 category_counts,
@@ -549,18 +646,17 @@ if query or search_button:
                 names="category",
                 title="Category Distribution",
             )
-            
+
             # Show the pie chart in the Streamlit app
             st.plotly_chart(fig, use_container_width=True)
-            
+
     if abstracts:
         with st.spinner("LLM is summarizing..."):
             prompt = st.text_area("Enter your summary prompt", value=LLM_prompt)
-            summary_button = st.button("AI summary")
+            summary_button = st.button("AI summary", key="b2")
             if summary_button:
                 ai_gen_start = time.time()
                 st.markdown('**AI Summary of 10 abstracts:**')
                 st.markdown(summarize_abstract(abstracts[:9], instructions=prompt))
-                total_ai_time = time.time()-ai_gen_start
+                total_ai_time = time.time() - ai_gen_start
                 st.markdown(f'**Time to generate summary:** {total_ai_time:.2f} s')
-