SiddharthAK committed
Commit b0796be · verified · 1 Parent(s): 44519b1

added indexing for 1-2 documents at a time from cranfield and a viewing feature

Files changed (1)
  1. app.py +202 -59
app.py CHANGED
@@ -5,6 +5,7 @@ import numpy as np
 from tqdm.auto import tqdm
 import os
 import ir_datasets
+import random # Added for random selection
 
 # --- Model Loading (Keep as is) ---
 tokenizer_splade = None
@@ -47,49 +48,77 @@ except Exception as e:
     print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
 
 
-# --- Global Variables for Document Index ---
+# --- Global Variables for Document Index and Qrels ---
 document_representations = {} # Stores {doc_id: sparse_vector}
 document_texts = {} # Stores {doc_id: doc_text}
+queries_texts = {} # Stores {query_id: query_text}
+qrels_data = {} # Stores {query_id: [{doc_id: str, relevance: int}, ...]}
 initial_doc_model_for_indexing = "SPLADE-cocondenser-distil" # Fixed for initial demo index
 
 
-# --- Load Cranfield Corpus using ir_datasets ---
-# Renamed function for clarity, but kept original name for call consistency
+# --- Load Cranfield Corpus, Queries, and Qrels using ir_datasets ---
 def load_cranfield_corpus_ir_datasets():
-    global document_texts
-    print("Loading Cranfield corpus using ir_datasets...")
+    global document_texts, queries_texts, qrels_data
+    print("Loading Cranfield corpus, queries, and qrels using ir_datasets...")
     try:
-        # --- IMPORTANT CHANGE: Loading 'cranfield' dataset ---
         dataset = ir_datasets.load("cranfield")
+
+        # Load documents
         for doc in tqdm(dataset.docs_iter(), desc="Loading Cranfield documents"):
             document_texts[doc.doc_id] = doc.text.strip()
         print(f"Loaded {len(document_texts)} documents from Cranfield corpus.")
+
+        # Load queries
+        for query in tqdm(dataset.queries_iter(), desc="Loading Cranfield queries"):
+            queries_texts[query.query_id] = query.text.strip()
+        print(f"Loaded {len(queries_texts)} queries from Cranfield corpus.")
+
+        # Load qrels
+        for qrel in tqdm(dataset.qrels_iter(), desc="Loading Cranfield qrels"):
+            if qrel.query_id not in qrels_data:
+                qrels_data[qrel.query_id] = []
+            qrels_data[qrel.query_id].append({"doc_id": qrel.doc_id, "relevance": qrel.relevance})
+        print(f"Loaded qrels for {len(qrels_data)} queries.")
+
     except Exception as e:
         print(f"Error loading Cranfield corpus with ir_datasets: {e}")
         print("Please ensure 'ir_datasets' is installed and your internet connection is stable.")
 
 
-# --- Helper function for lexical mask (Keep as is) ---
-def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
-    bow_mask = torch.zeros(vocab_size, device=input_ids.device)
-    meaningful_token_ids = []
-    for token_id in input_ids.squeeze().tolist():
-        if token_id not in [
-            tokenizer.pad_token_id,
-            tokenizer.cls_token_id,
-            tokenizer.sep_token_id,
-            tokenizer.mask_token_id,
-            tokenizer.unk_token_id
-        ]:
-            meaningful_token_ids.append(token_id)
-
-    if meaningful_token_ids:
-        bow_mask[list(set(meaningful_token_ids))] = 1
-
-    return bow_mask.unsqueeze(0)
+# --- Helper function for lexical mask (now handles batches) ---
+def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
+    """
+    Creates a batch of lexical BOW masks.
+    input_ids_batch: torch.Tensor of shape (batch_size, sequence_length)
+    vocab_size: int, size of the tokenizer vocabulary
+    tokenizer: the tokenizer object
+    Returns: torch.Tensor of shape (batch_size, vocab_size)
+    """
+    batch_size = input_ids_batch.shape[0]
+    bow_masks = torch.zeros(batch_size, vocab_size, device=input_ids_batch.device)
+
+    for i in range(batch_size):
+        input_ids = input_ids_batch[i] # Get input_ids for the current item in the batch
+        meaningful_token_ids = []
+        for token_id in input_ids.tolist():
+            if token_id not in [
+                tokenizer.pad_token_id,
+                tokenizer.cls_token_id,
+                tokenizer.sep_token_id,
+                tokenizer.mask_token_id,
+                tokenizer.unk_token_id
+            ]:
+                meaningful_token_ids.append(token_id)
+
+        if meaningful_token_ids:
+            # Apply mask to the current row in the batch
+            bow_masks[i, list(set(meaningful_token_ids))] = 1
+
+    return bow_masks
 
 
 # --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
+# These functions still take single text input for the Explorer tab
 def get_splade_cocondenser_representation(text):
     if tokenizer_splade is None or model_splade is None:
         return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."
@@ -104,7 +133,7 @@ def get_splade_cocondenser_representation(text):
         splade_vector = torch.max(
             torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
             dim=1
-        )[0].squeeze()
+        )[0].squeeze() # Squeeze is fine here as it's a single input
     else:
         return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found."
 
@@ -151,15 +180,16 @@ def get_splade_lexical_representation(text):
         splade_vector = torch.max(
             torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
             dim=1
-        )[0].squeeze()
+        )[0].squeeze() # Squeeze is fine here
     else:
         return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found."
 
     # Always apply lexical mask for this model's specific behavior
     vocab_size = tokenizer_splade_lexical.vocab_size
+    # Call with unsqueezed input_ids for single sample processing
     bow_mask = create_lexical_bow_mask(
         inputs['input_ids'], vocab_size, tokenizer_splade_lexical
-    ).squeeze()
+    ).squeeze() # Squeeze back for single output
     splade_vector = splade_vector * bow_mask
 
     indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
@@ -202,12 +232,13 @@ def get_splade_doc_representation(text):
         output = model_splade_doc(**inputs)
 
     if not hasattr(output, "logits"):
-        return "SPLADE-v3-Doc model output structure not as expected. 'logits' not found."
+        return "Model output structure not as expected. 'logits' not found."
 
     vocab_size = tokenizer_splade_doc.vocab_size
+    # Call with unsqueezed input_ids for single sample processing
     binary_splade_vector = create_lexical_bow_mask(
         inputs['input_ids'], vocab_size, tokenizer_splade_doc
-    ).squeeze()
+    ).squeeze() # Squeeze back for single output
 
     indices = torch.nonzero(binary_splade_vector).squeeze().cpu().tolist()
     if not isinstance(indices, list):
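The `isinstance(indices, list)` guard that follows `torch.nonzero(...).squeeze()` in these explorer functions exists because squeezing a single-match result produces a 0-d tensor, whose `.tolist()` returns a bare int rather than a list. A quick self-contained illustration:

```python
import torch

vec = torch.zeros(10)
vec[3] = 1.0
print(torch.nonzero(vec).squeeze().tolist())  # 3 (a bare int, not a list)

vec[7] = 1.0
print(torch.nonzero(vec).squeeze().tolist())  # [3, 7]
```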
@@ -253,44 +284,75 @@ def predict_representation_explorer(model_choice, text):
     return "Please select a model."
 
 
-# --- Internal Core Representation Functions (Return Raw Vectors - for Retrieval Tab) ---
-def get_splade_cocondenser_representation_internal(text, tokenizer, model):
+# --- Internal Core Representation Functions (now handle batches) ---
+def get_splade_cocondenser_representation_internal(texts, tokenizer, model):
+    """
+    Generates SPLADE representations for a batch of texts.
+    texts: list of strings
+    tokenizer: the tokenizer object
+    model: the SPLADE model
+    Returns: torch.Tensor of shape (batch_size, vocab_size) or None
+    """
     if tokenizer is None or model is None: return None
-    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(model.device) for k, v in inputs.items()}
-    with torch.no_grad(): output = model(**inputs)
+
+    with torch.no_grad():
+        output = model(**inputs)
+
     if hasattr(output, 'logits'):
-        splade_vector = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0].squeeze()
-        return splade_vector
+        # torch.max(..., dim=1)[0] reduces along sequence_length dimension,
+        # resulting in (batch_size, vocab_size)
+        splade_vectors = torch.max(
+            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
+            dim=1
+        )[0]
+        return splade_vectors
     else:
         print("Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found.")
         return None
 
-def get_splade_lexical_representation_internal(text, tokenizer, model):
+def get_splade_lexical_representation_internal(texts, tokenizer, model):
+    """
+    Generates SPLADE-Lexical representations for a batch of texts.
+    texts: list of strings
+    tokenizer: the tokenizer object
+    model: the SPLADE-Lexical model
+    Returns: torch.Tensor of shape (batch_size, vocab_size) or None
+    """
     if tokenizer is None or model is None: return None
-    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(model.device) for k, v in inputs.items()}
     with torch.no_grad(): output = model(**inputs)
     if hasattr(output, 'logits'):
-        splade_vector = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0].squeeze()
+        splade_vectors = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0]
         vocab_size = tokenizer.vocab_size
-        bow_mask = create_lexical_bow_mask(inputs['input_ids'], vocab_size, tokenizer).squeeze()
-        splade_vector = splade_vector * bow_mask
-        return splade_vector
+        # create_lexical_bow_mask now returns (batch_size, vocab_size)
+        bow_masks = create_lexical_bow_mask(inputs['input_ids'], vocab_size, tokenizer)
+        splade_vectors = splade_vectors * bow_masks # Element-wise multiplication, shapes (batch_size, vocab_size)
+        return splade_vectors
     else:
         print("Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found.")
        return None
 
-def get_splade_doc_representation_internal(text, tokenizer, model):
+def get_splade_doc_representation_internal(texts, tokenizer, model):
+    """
+    Generates SPLADE-Doc (binary) representations for a batch of texts.
+    texts: list of strings
+    tokenizer: the tokenizer object
+    model: the SPLADE-Doc model (not directly used for logits, but for device)
+    Returns: torch.Tensor of shape (batch_size, vocab_size) or None
+    """
     if tokenizer is None or model is None: return None
-    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
-    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
+    inputs = {k: v.to(model.device) for k, v in inputs.items()} # Ensure inputs are on the correct device
     vocab_size = tokenizer.vocab_size
-    binary_splade_vector = create_lexical_bow_mask(inputs['input_ids'], vocab_size, tokenizer).squeeze()
-    return binary_splade_vector
+    # create_lexical_bow_mask now returns (batch_size, vocab_size)
+    binary_splade_vectors = create_lexical_bow_mask(inputs['input_ids'], vocab_size, tokenizer)
+    return binary_splade_vectors
 
 
-# --- Document Indexing Function (for Retrieval Tab) ---
+# --- Document Indexing Function (now uses batching) ---
 def index_documents(doc_model_choice):
     global document_representations
     if document_representations:
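Both the single-text and batched paths pool the same way: SPLADE's max over token positions of `log(1 + relu(logits))`, with padded positions zeroed by the attention mask so they cannot win the max. A dummy-tensor shape check (all names are local to this sketch):

```python
import torch

batch_size, seq_len, vocab_size = 2, 8, 30522  # arbitrary dummy shapes
logits = torch.randn(batch_size, seq_len, vocab_size)
attention_mask = torch.ones(batch_size, seq_len)

# log(1 + relu(x)) keeps term weights non-negative; multiplying by the mask
# zeroes padded positions so they never win the max over dim=1 (seq_len).
weights = torch.log(1 + torch.relu(logits)) * attention_mask.unsqueeze(-1)
splade_vectors = torch.max(weights, dim=1)[0]
assert splade_vectors.shape == (batch_size, vocab_size)
```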
@@ -328,14 +390,28 @@ def index_documents(doc_model_choice):
 
     print(f"Indexing documents using {doc_model_choice}...")
 
-    doc_items = list(document_texts.items())
+    doc_ids_list = list(document_texts.keys())
+    doc_texts_list = list(document_texts.values())
 
-    for doc_id, doc_text in tqdm(doc_items, desc="Indexing Documents"):
-        sparse_vector = representation_func_to_use(doc_text, tokenizer_to_use, model_to_use)
-        if sparse_vector is not None:
-            document_representations[doc_id] = sparse_vector.cpu()
+    # --- BATCH SIZE FOR INDEXING ---
+    batch_size = 32 # You can adjust this value based on memory and performance
+
+    document_representations = {} # Ensure it's clear we're (re)building the index
+
+    # Iterate through documents in batches
+    for i in tqdm(range(0, len(doc_ids_list), batch_size), desc="Indexing Documents in Batches"):
+        batch_doc_ids = doc_ids_list[i:i + batch_size]
+        batch_doc_texts = doc_texts_list[i:i + batch_size]
+
+        sparse_vectors_batch = representation_func_to_use(batch_doc_texts, tokenizer_to_use, model_to_use)
+
+        if sparse_vectors_batch is not None:
+            # sparse_vectors_batch will have shape (batch_size, vocab_size)
+            for j, doc_id in enumerate(batch_doc_ids):
+                # Store each document's vector
+                document_representations[doc_id] = sparse_vectors_batch[j].cpu()
         else:
-            print(f"Warning: Failed to get representation for doc_id {doc_id}")
+            print(f"Warning: Failed to get representation for a batch starting with doc_id {batch_doc_ids[0]}")
 
     print(f"Finished indexing {len(document_representations)} documents.")
     return True
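One trade-off the batched indexer inherits from the original design: every document vector is stored as a dense `vocab_size` float tensor, roughly 120 KB each at a ~30K vocabulary, which is on the order of 170 MB for Cranfield's ~1,400 documents. If memory mattered, the pooled rows could be converted to sparse tensors before storing; a minimal sketch of the size difference, using made-up term ids and weights:

```python
import torch

vec = torch.zeros(30522)  # one dense SPLADE-style document vector
vec[torch.tensor([101, 2054, 2003])] = torch.tensor([1.2, 0.7, 0.3])

sparse = vec.to_sparse()  # COO: stores only nonzero indices + values
dense_bytes = vec.nelement() * vec.element_size()
sparse_bytes = (sparse.indices().nelement() * sparse.indices().element_size()
                + sparse.values().nelement() * sparse.values().element_size())
print(dense_bytes, sparse_bytes)  # ~122 KB dense vs. tens of bytes sparse
```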
@@ -349,25 +425,27 @@ def retrieve_documents(query_text, query_model_choice, indexed_doc_model_name, t
     query_tokenizer = None
     query_model = None
 
+    # These internal calls still use single text input for the query
     if query_model_choice == "SPLADE-cocondenser-distil (weighting and expansion)":
         query_tokenizer = tokenizer_splade
         query_model = model_splade
-        query_vector = get_splade_cocondenser_representation_internal(query_text, query_tokenizer, query_model)
+        query_vector = get_splade_cocondenser_representation_internal([query_text], query_tokenizer, query_model)
     elif query_model_choice == "SPLADE-v3-Lexical (weighting)":
         query_tokenizer = tokenizer_splade_lexical
         query_model = model_splade_lexical
-        query_vector = get_splade_lexical_representation_internal(query_text, query_tokenizer, query_model)
+        query_vector = get_splade_lexical_representation_internal([query_text], query_tokenizer, query_model)
     elif query_model_choice == "SPLADE-v3-Doc (binary)":
         query_tokenizer = tokenizer_splade_doc
         query_model = model_splade_doc
-        query_vector = get_splade_doc_representation_internal(query_text, query_tokenizer, query_model)
+        query_vector = get_splade_doc_representation_internal([query_text], query_tokenizer, query_model)
     else:
         return "Invalid query model choice.", []
 
     if query_vector is None:
         return "Failed to get query representation. Check console for model loading errors.", []
 
-    query_vector = query_vector.cpu()
+    # Since internal functions now return batches, take the first (and only) item for single query
+    query_vector = query_vector.squeeze(0).cpu()
 
     scores = {}
     for doc_id, doc_vec in document_representations.items():
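The `scores` loop at the end of this hunk takes one dot product per document. Stacking the index into a matrix yields identical scores from a single matmul plus `topk`, which is usually faster for a corpus this size. A sketch with dummy stand-ins for the app's index and query vector (nothing here is in the commit):

```python
import torch

# Dummy stand-ins shaped like the app's index and query vector.
document_representations = {f"doc{i}": torch.rand(30522) for i in range(100)}
query_vector = torch.rand(30522)

doc_ids = list(document_representations.keys())
doc_matrix = torch.stack([document_representations[d] for d in doc_ids])  # (docs, vocab)
scores = doc_matrix @ query_vector                                        # (docs,)
top_scores, top_idx = torch.topk(scores, k=5)
top_doc_ids = [doc_ids[i] for i in top_idx.tolist()]
print(list(zip(top_doc_ids, top_scores.tolist())))
```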
@@ -396,9 +474,64 @@ def predict_retrieval_gradio(query_text, query_model_choice, selected_doc_model_
     formatted_output, _ = retrieve_documents(query_text, query_model_choice, initial_doc_model_for_indexing, top_k=5)
     return formatted_output
 
+# --- New function to get specific retrieval examples ---
+def get_specific_retrieval_examples():
+    if not queries_texts or not qrels_data or not document_texts:
+        return "Queries, qrels, or documents not loaded. Please check initial loading."
+
+    high_qrel_threshold = 3 # Relevance score of 3 or 4 for Cranfield is generally considered high
+    low_qrel_threshold = 1 # Relevance score of 0 or 1 for Cranfield is generally considered low
+
+    eligible_query_ids = []
+    for qid, qrels in qrels_data.items():
+        has_high_qrel = any(item['relevance'] >= high_qrel_threshold for item in qrels)
+        has_low_qrel = any(item['relevance'] <= low_qrel_threshold for item in qrels)
+        if has_high_qrel and has_low_qrel:
+            eligible_query_ids.append(qid)
+
+    if not eligible_query_ids:
+        return "Could not find a query with both high and low relevance documents in the loaded qrels."
+
+    # Pick a random eligible query
+    random_query_id = random.choice(eligible_query_ids)
+    full_query_text = queries_texts.get(random_query_id, "Query text not found.")
+    query_snippet = full_query_text[:300] + "..." if len(full_query_text) > 300 else full_query_text
+
+    qrels_for_query = qrels_data[random_query_id]
+
+    high_qrel_docs = [item for item in qrels_for_query if item['relevance'] >= high_qrel_threshold]
+    low_qrel_docs = [item for item in qrels_for_query if item['relevance'] <= low_qrel_threshold]
+
+    selected_high_doc_id = random.choice(high_qrel_docs)['doc_id'] if high_qrel_docs else None
+    selected_low_doc_id = random.choice(low_qrel_docs)['doc_id'] if low_qrel_docs else None
+
+    output_str = f"### Random Query Example\n\n"
+    output_str += f"**Query ID:** {random_query_id}\n"
+    output_str += f"**Query Snippet:** {query_snippet}\n\n" # Changed to snippet
+
+    if selected_high_doc_id:
+        full_doc_text = document_texts.get(selected_high_doc_id, "Document text not available.")
+        doc_snippet = full_doc_text[:500] + "..." if len(full_doc_text) > 500 else full_doc_text
+        output_str += f"### Highly Relevant Document (Qrel >= {high_qrel_threshold})\n"
+        output_str += f"**Document ID:** {selected_high_doc_id}\n"
+        output_str += f"**Document Snippet:** {doc_snippet}\n\n" # Changed to snippet
+    else:
+        output_str += "No highly relevant document found for this query.\n\n"
+
+    if selected_low_doc_id:
+        full_doc_text = document_texts.get(selected_low_doc_id, "Document text not available.")
+        doc_snippet = full_doc_text[:500] + "..." if len(full_doc_text) > 500 else full_doc_text
+        output_str += f"### Lowly Relevant Document (Qrel <= {low_qrel_threshold})\n"
+        output_str += f"**Document ID:** {selected_low_doc_id}\n"
+        output_str += f"**Document Snippet:** {doc_snippet}\n\n" # Changed to snippet
+    else:
+        output_str += "No lowly relevant document found for this query.\n\n"
+
+    return output_str
+
+
 # --- Initial Load and Indexing Calls ---
 # This part runs once when the app starts.
-# --- IMPORTANT CHANGE: Calling the function that loads Cranfield ---
 load_cranfield_corpus_ir_datasets()
 
 if initial_doc_model_for_indexing == "SPLADE-cocondenser-distil" and model_splade is not None:
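The `>= 3` / `<= 1` cut-offs in `get_specific_retrieval_examples` are the commit's own choice of what counts as highly and lowly relevant. Whether they suit the loaded qrels can be checked by tallying the grades; a minimal sketch over a dummy `qrels_data` shaped like the app's:

```python
from collections import Counter

# Dummy stand-in shaped like the app's qrels_data.
qrels_data = {
    "1": [{"doc_id": "184", "relevance": 2}, {"doc_id": "29", "relevance": 4}],
    "2": [{"doc_id": "12", "relevance": 1}],
}

grade_counts = Counter(
    item["relevance"] for qrels in qrels_data.values() for item in qrels
)
print(sorted(grade_counts.items()))  # e.g. [(1, 1), (2, 1), (4, 1)]
```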
@@ -443,7 +576,7 @@ with gr.Blocks(title="SPLADE Demos") as demo:
         )
 
     with gr.TabItem("Document Retrieval Demo"):
-        gr.Markdown("### Retrieve Documents from Cranfield Collection") # Changed title
+        gr.Markdown("### Retrieve Documents from Cranfield Collection")
         gr.Interface(
             fn=predict_retrieval_gradio,
             inputs=[
@@ -476,5 +609,15 @@ with gr.Blocks(title="SPLADE Demos") as demo:
             allow_flagging="never",
             # live=True # retrieval is too heavy for live
         )
+
+        gr.Markdown("---") # Separator
+        gr.Markdown("### Get Specific Retrieval Examples")
+        specific_example_output = gr.Markdown()
+        specific_example_button = gr.Button("Get Random Query with High/Low Qrel Docs")
+        specific_example_button.click(
+            fn=get_specific_retrieval_examples,
+            inputs=[],
+            outputs=specific_example_output
+        )
 
 demo.launch()
 