Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -53,17 +53,19 @@ document_texts = {} # Stores {doc_id: doc_text}
|
|
53 |
initial_doc_model_for_indexing = "SPLADE-cocondenser-distil" # Fixed for initial demo index
|
54 |
|
55 |
|
56 |
-
# --- Load
|
57 |
-
|
|
|
58 |
global document_texts
|
59 |
-
print("Loading
|
60 |
try:
|
61 |
-
dataset
|
62 |
-
|
|
|
63 |
document_texts[doc.doc_id] = doc.text.strip()
|
64 |
-
print(f"Loaded {len(document_texts)} documents from
|
65 |
except Exception as e:
|
66 |
-
print(f"Error loading
|
67 |
print("Please ensure 'ir_datasets' is installed and your internet connection is stable.")
|
68 |
|
69 |
|
@@ -88,8 +90,6 @@ def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
|
|
88 |
|
89 |
|
90 |
# --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
|
91 |
-
# These are your original functions, re-added.
|
92 |
-
|
93 |
def get_splade_cocondenser_representation(text):
|
94 |
if tokenizer_splade is None or model_splade is None:
|
95 |
return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."
|
@@ -254,8 +254,6 @@ def predict_representation_explorer(model_choice, text):
|
|
254 |
|
255 |
|
256 |
# --- Internal Core Representation Functions (Return Raw Vectors - for Retrieval Tab) ---
|
257 |
-
# These are the ones ending with _internal, as previously defined.
|
258 |
-
|
259 |
def get_splade_cocondenser_representation_internal(text, tokenizer, model):
|
260 |
if tokenizer is None or model is None: return None
|
261 |
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
|
@@ -400,7 +398,8 @@ def predict_retrieval_gradio(query_text, query_model_choice, selected_doc_model_
|
|
400 |
|
401 |
# --- Initial Load and Indexing Calls ---
|
402 |
# This part runs once when the app starts.
|
403 |
-
|
|
|
404 |
|
405 |
if initial_doc_model_for_indexing == "SPLADE-cocondenser-distil" and model_splade is not None:
|
406 |
index_documents(initial_doc_model_for_indexing)
|
@@ -440,13 +439,11 @@ with gr.Blocks(title="SPLADE Demos") as demo:
|
|
440 |
],
|
441 |
outputs=gr.Markdown(),
|
442 |
allow_flagging="never",
|
443 |
-
#
|
444 |
-
# Setting live=True might be slow for complex models on every keystroke
|
445 |
-
# live=True
|
446 |
)
|
447 |
|
448 |
with gr.TabItem("Document Retrieval Demo"):
|
449 |
-
gr.Markdown("### Retrieve Documents from
|
450 |
gr.Interface(
|
451 |
fn=predict_retrieval_gradio,
|
452 |
inputs=[
|
|
|
53 |
initial_doc_model_for_indexing = "SPLADE-cocondenser-distil" # Fixed for initial demo index
|
54 |
|
55 |
|
56 |
+
# --- Load Cranfield Corpus using ir_datasets ---
|
57 |
+
# Renamed function for clarity, but kept original name for call consistency
|
58 |
+
def load_cranfield_corpus_ir_datasets():
|
59 |
global document_texts
|
60 |
+
print("Loading Cranfield corpus using ir_datasets...")
|
61 |
try:
|
62 |
+
# --- IMPORTANT CHANGE: Loading 'cranfield' dataset ---
|
63 |
+
dataset = ir_datasets.load("cranfield")
|
64 |
+
for doc in tqdm(dataset.docs_iter(), desc="Loading Cranfield documents"):
|
65 |
document_texts[doc.doc_id] = doc.text.strip()
|
66 |
+
print(f"Loaded {len(document_texts)} documents from Cranfield corpus.")
|
67 |
except Exception as e:
|
68 |
+
print(f"Error loading Cranfield corpus with ir_datasets: {e}")
|
69 |
print("Please ensure 'ir_datasets' is installed and your internet connection is stable.")
|
70 |
|
71 |
|
|
|
90 |
|
91 |
|
92 |
# --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
|
|
|
|
|
93 |
def get_splade_cocondenser_representation(text):
|
94 |
if tokenizer_splade is None or model_splade is None:
|
95 |
return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."
|
|
|
254 |
|
255 |
|
256 |
# --- Internal Core Representation Functions (Return Raw Vectors - for Retrieval Tab) ---
|
|
|
|
|
257 |
def get_splade_cocondenser_representation_internal(text, tokenizer, model):
|
258 |
if tokenizer is None or model is None: return None
|
259 |
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
|
|
|
398 |
|
399 |
# --- Initial Load and Indexing Calls ---
|
400 |
# This part runs once when the app starts.
|
401 |
+
# --- IMPORTANT CHANGE: Calling the function that loads Cranfield ---
|
402 |
+
load_cranfield_corpus_ir_datasets()
|
403 |
|
404 |
if initial_doc_model_for_indexing == "SPLADE-cocondenser-distil" and model_splade is not None:
|
405 |
index_documents(initial_doc_model_for_indexing)
|
|
|
439 |
],
|
440 |
outputs=gr.Markdown(),
|
441 |
allow_flagging="never",
|
442 |
+
# live=True # Setting live=True might be slow for complex models on every keystroke
|
|
|
|
|
443 |
)
|
444 |
|
445 |
with gr.TabItem("Document Retrieval Demo"):
|
446 |
+
gr.Markdown("### Retrieve Documents from Cranfield Collection") # Changed title
|
447 |
gr.Interface(
|
448 |
fn=predict_retrieval_gradio,
|
449 |
inputs=[
|