SiddharthAK commited on
Commit
44519b1
·
verified ·
1 Parent(s): c3ddf27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -16
app.py CHANGED
@@ -53,17 +53,19 @@ document_texts = {} # Stores {doc_id: doc_text}
53
  initial_doc_model_for_indexing = "SPLADE-cocondenser-distil" # Fixed for initial demo index
54
 
55
 
56
- # --- Load SciFact Corpus using ir_datasets ---
57
- def load_scifact_corpus_ir_datasets():
 
58
  global document_texts
59
- print("Loading SciFact corpus using ir_datasets...")
60
  try:
61
- dataset = ir_datasets.load("scifact")
62
- for doc in tqdm(dataset.docs_iter(), desc="Loading SciFact documents"):
 
63
  document_texts[doc.doc_id] = doc.text.strip()
64
- print(f"Loaded {len(document_texts)} documents from SciFact corpus.")
65
  except Exception as e:
66
- print(f"Error loading SciFact corpus with ir_datasets: {e}")
67
  print("Please ensure 'ir_datasets' is installed and your internet connection is stable.")
68
 
69
 
@@ -88,8 +90,6 @@ def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
88
 
89
 
90
  # --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
91
- # These are your original functions, re-added.
92
-
93
  def get_splade_cocondenser_representation(text):
94
  if tokenizer_splade is None or model_splade is None:
95
  return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."
@@ -254,8 +254,6 @@ def predict_representation_explorer(model_choice, text):
254
 
255
 
256
  # --- Internal Core Representation Functions (Return Raw Vectors - for Retrieval Tab) ---
257
- # These are the ones ending with _internal, as previously defined.
258
-
259
  def get_splade_cocondenser_representation_internal(text, tokenizer, model):
260
  if tokenizer is None or model is None: return None
261
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
@@ -400,7 +398,8 @@ def predict_retrieval_gradio(query_text, query_model_choice, selected_doc_model_
400
 
401
  # --- Initial Load and Indexing Calls ---
402
  # This part runs once when the app starts.
403
- load_scifact_corpus_ir_datasets() # Or load_cranfield_corpus_ir_datasets() if you switch back
 
404
 
405
  if initial_doc_model_for_indexing == "SPLADE-cocondenser-distil" and model_splade is not None:
406
  index_documents(initial_doc_model_for_indexing)
@@ -440,13 +439,11 @@ with gr.Blocks(title="SPLADE Demos") as demo:
440
  ],
441
  outputs=gr.Markdown(),
442
  allow_flagging="never",
443
- # Don't show redundant title/description within the tab, as it's above
444
- # Setting live=True might be slow for complex models on every keystroke
445
- # live=True
446
  )
447
 
448
  with gr.TabItem("Document Retrieval Demo"):
449
- gr.Markdown("### Retrieve Documents from SciFact Collection")
450
  gr.Interface(
451
  fn=predict_retrieval_gradio,
452
  inputs=[
 
53
  initial_doc_model_for_indexing = "SPLADE-cocondenser-distil" # Fixed for initial demo index
54
 
55
 
56
+ # --- Load Cranfield Corpus using ir_datasets ---
57
+ # Renamed function for clarity, but kept original name for call consistency
58
+ def load_cranfield_corpus_ir_datasets():
59
  global document_texts
60
+ print("Loading Cranfield corpus using ir_datasets...")
61
  try:
62
+ # --- IMPORTANT CHANGE: Loading 'cranfield' dataset ---
63
+ dataset = ir_datasets.load("cranfield")
64
+ for doc in tqdm(dataset.docs_iter(), desc="Loading Cranfield documents"):
65
  document_texts[doc.doc_id] = doc.text.strip()
66
+ print(f"Loaded {len(document_texts)} documents from Cranfield corpus.")
67
  except Exception as e:
68
+ print(f"Error loading Cranfield corpus with ir_datasets: {e}")
69
  print("Please ensure 'ir_datasets' is installed and your internet connection is stable.")
70
 
71
 
 
90
 
91
 
92
  # --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
 
 
93
  def get_splade_cocondenser_representation(text):
94
  if tokenizer_splade is None or model_splade is None:
95
  return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."
 
254
 
255
 
256
  # --- Internal Core Representation Functions (Return Raw Vectors - for Retrieval Tab) ---
 
 
257
  def get_splade_cocondenser_representation_internal(text, tokenizer, model):
258
  if tokenizer is None or model is None: return None
259
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
 
398
 
399
  # --- Initial Load and Indexing Calls ---
400
  # This part runs once when the app starts.
401
+ # --- IMPORTANT CHANGE: Calling the function that loads Cranfield ---
402
+ load_cranfield_corpus_ir_datasets()
403
 
404
  if initial_doc_model_for_indexing == "SPLADE-cocondenser-distil" and model_splade is not None:
405
  index_documents(initial_doc_model_for_indexing)
 
439
  ],
440
  outputs=gr.Markdown(),
441
  allow_flagging="never",
442
+ # live=True # Setting live=True might be slow for complex models on every keystroke
 
 
443
  )
444
 
445
  with gr.TabItem("Document Retrieval Demo"):
446
+ gr.Markdown("### Retrieve Documents from Cranfield Collection") # Changed title
447
  gr.Interface(
448
  fn=predict_retrieval_gradio,
449
  inputs=[