miki5799 committed on
Commit
d8d5586
·
1 Parent(s): 3c1805a

Refactor app.py for improved readability and organization; rearranged imports, added spacing, and formatted code blocks.

Browse files
Files changed (1) hide show
  1. app.py +78 -28
app.py CHANGED
@@ -1,12 +1,15 @@
1
- from dataclasses import dataclass
2
- import pickle
3
  import os
4
- from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar
5
- from nlp4web_codebase.ir.data_loaders.dm import Document
6
- from collections import Counter
7
- import tqdm
8
  import re
 
 
 
 
9
  import nltk
 
 
 
 
10
  nltk.download("stopwords", quiet=True)
11
  from nltk.corpus import stopwords as nltk_stopwords
12
 
@@ -18,22 +21,30 @@ stopwords = set(nltk_stopwords.words(LANGUAGE))
18
  def word_splitting(text: str) -> List[str]:
19
  return word_splitter(text.lower())
20
 
 
21
  def lemmatization(words: List[str]) -> List[str]:
22
  return words # We ignore lemmatization here for simplicity
23
 
 
24
  def simple_tokenize(text: str) -> List[str]:
25
  words = word_splitting(text)
26
  tokenized = list(filter(lambda w: w not in stopwords, words))
27
  tokenized = lemmatization(tokenized)
28
  return tokenized
29
 
 
30
  T = TypeVar("T", bound="InvertedIndex")
31
 
 
32
  @dataclass
33
  class PostingList:
34
  term: str # The term
35
- docid_postings: List[int] # docid_postings[i] means the docid (int) of the i-th associated posting
36
- tweight_postings: List[float] # tweight_postings[i] means the term weight (float) of the i-th associated posting
 
 
 
 
37
 
38
 
39
  @dataclass
@@ -72,6 +83,7 @@ class Counting:
72
  nterms: int
73
  doc_texts: Optional[List[str]] = None
74
 
 
75
  def run_counting(
76
  documents: Iterable[Document],
77
  tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
@@ -131,22 +143,23 @@ def run_counting(
131
  doc_texts=doc_texts,
132
  )
133
 
 
134
  from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
 
135
  sciq = load_sciq()
136
  counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
137
 
138
  from __future__ import annotations
139
- from dataclasses import asdict, dataclass
140
  import math
141
- import os
142
  from typing import Iterable, List, Optional, Type
143
- import tqdm
144
  from nlp4web_codebase.ir.data_loaders.dm import Document
145
 
146
 
147
  @dataclass
148
  class BM25Index(InvertedIndex):
149
-
150
  @staticmethod
151
  def tokenize(text: str) -> List[str]:
152
  return simple_tokenize(text)
@@ -230,6 +243,7 @@ class BM25Index(InvertedIndex):
230
  )
231
  return index
232
 
 
233
  bm25_index = BM25Index.build_from_documents(
234
  documents=iter(sciq.corpus),
235
  ndocs=12160,
@@ -237,13 +251,13 @@ bm25_index = BM25Index.build_from_documents(
237
  )
238
  bm25_index.save("output/bm25_index")
239
 
240
- from nlp4web_codebase.ir.models import BaseRetriever
241
- from typing import Type
242
  from abc import abstractmethod
 
243
 
 
244
 
245
- class BaseInvertedIndexRetriever(BaseRetriever):
246
 
 
247
  @property
248
  @abstractmethod
249
  def index_class(self) -> Type[InvertedIndex]:
@@ -295,16 +309,48 @@ class BaseInvertedIndexRetriever(BaseRetriever):
295
 
296
 
297
  class BM25Retriever(BaseInvertedIndexRetriever):
298
-
299
  @property
300
  def index_class(self) -> Type[BM25Index]:
301
  return BM25Index
302
 
 
303
  bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
304
- bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?")
 
 
305
 
306
- plots_b = {'X': [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], 'Y': [0.694980045351474, 0.8126195011337869, 0.821528798185941, 0.8218562358276644, 0.8222244897959182, 0.8195024943310657, 0.8182163265306123, 0.8174734693877551, 0.8139020408163266, 0.8116893424036281, 0.8083002267573697]} #TODO: Replace
307
- plots_k1 = {'X': [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], 'Y': [0.7345419501133786, 0.7668607709750567, 0.779508843537415, 0.7900947845804988, 0.8015931972789115, 0.8103560090702948, 0.812374149659864, 0.8156743764172336, 0.8194036281179138, 0.8222244897959182, 0.8221800453514739]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
 
309
  best_b = plots_b["X"][np.argmax(plots_b["Y"])]
310
  best_k1 = plots_k1["X"][np.argmax(plots_k1["Y"])]
@@ -313,23 +359,26 @@ bm25_index = BM25Index.build_from_documents(
313
  ndocs=12160,
314
  show_progress_bar=True,
315
  k1=best_k1,
316
- b=best_b
317
  )
318
 
319
- import gradio as gr
320
  from typing import TypedDict
321
- import pandas as pd
 
 
322
 
323
  class Hit(TypedDict):
324
- cid: str
325
- score: float
326
- text: str
 
327
 
328
  demo: Optional[gr.Interface] = None # Assign your gradio demo to this variable
329
  return_type = List[Hit]
330
 
 
331
  ## YOUR_CODE_STARTS_HERE
332
- def retrieve(query: str, topk: int=10) -> return_type:
333
  ranking = bm25_retriever.retrieve(query=query, topk=3)
334
  hits = []
335
  for cid, score in ranking.items():
@@ -337,6 +386,7 @@ def retrieve(query: str, topk: int=10) -> return_type:
337
  hits.append({"cid": cid, "score": score, "text": text})
338
  return hits
339
 
 
340
  demo = gr.Interface(
341
  fn=retrieve,
342
  inputs=gr.Textbox(lines=3, placeholder="Enter your query here..."),
@@ -347,7 +397,7 @@ demo = gr.Interface(
347
  ["What are the differences between immunodeficiency and autoimmune diseases?"],
348
  ["What are the causes of immunodeficiency?"],
349
  ["What are the symptoms of immunodeficiency?"],
350
- ]
351
  )
352
  ## YOUR_CODE_ENDS_HERE
353
- demo.launch()
 
 
 
1
  import os
2
+ import pickle
 
 
 
3
  import re
4
+ from collections import Counter
5
+ from dataclasses import dataclass
6
+ from typing import Callable, Dict, Iterable, List, Optional, Type, TypeVar
7
+
8
  import nltk
9
+ import tqdm
10
+
11
+ from nlp4web_codebase.ir.data_loaders.dm import Document
12
+
13
  nltk.download("stopwords", quiet=True)
14
  from nltk.corpus import stopwords as nltk_stopwords
15
 
 
21
  def word_splitting(text: str) -> List[str]:
22
  return word_splitter(text.lower())
23
 
24
+
25
  def lemmatization(words: List[str]) -> List[str]:
26
  return words # We ignore lemmatization here for simplicity
27
 
28
+
29
  def simple_tokenize(text: str) -> List[str]:
30
  words = word_splitting(text)
31
  tokenized = list(filter(lambda w: w not in stopwords, words))
32
  tokenized = lemmatization(tokenized)
33
  return tokenized
34
 
35
+
36
  T = TypeVar("T", bound="InvertedIndex")
37
 
38
+
39
  @dataclass
40
  class PostingList:
41
  term: str # The term
42
+ docid_postings: List[
43
+ int
44
+ ] # docid_postings[i] means the docid (int) of the i-th associated posting
45
+ tweight_postings: List[
46
+ float
47
+ ] # tweight_postings[i] means the term weight (float) of the i-th associated posting
48
 
49
 
50
  @dataclass
 
83
  nterms: int
84
  doc_texts: Optional[List[str]] = None
85
 
86
+
87
  def run_counting(
88
  documents: Iterable[Document],
89
  tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
 
143
  doc_texts=doc_texts,
144
  )
145
 
146
+
147
  from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
148
+
149
  sciq = load_sciq()
150
  counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
151
 
152
  from __future__ import annotations
153
+
154
  import math
155
+ from dataclasses import dataclass
156
  from typing import Iterable, List, Optional, Type
157
+
158
  from nlp4web_codebase.ir.data_loaders.dm import Document
159
 
160
 
161
  @dataclass
162
  class BM25Index(InvertedIndex):
 
163
  @staticmethod
164
  def tokenize(text: str) -> List[str]:
165
  return simple_tokenize(text)
 
243
  )
244
  return index
245
 
246
+
247
  bm25_index = BM25Index.build_from_documents(
248
  documents=iter(sciq.corpus),
249
  ndocs=12160,
 
251
  )
252
  bm25_index.save("output/bm25_index")
253
 
 
 
254
  from abc import abstractmethod
255
+ from typing import Type
256
 
257
+ from nlp4web_codebase.ir.models import BaseRetriever
258
 
 
259
 
260
+ class BaseInvertedIndexRetriever(BaseRetriever):
261
  @property
262
  @abstractmethod
263
  def index_class(self) -> Type[InvertedIndex]:
 
309
 
310
 
311
  class BM25Retriever(BaseInvertedIndexRetriever):
 
312
  @property
313
  def index_class(self) -> Type[BM25Index]:
314
  return BM25Index
315
 
316
+
317
  bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
318
+ bm25_retriever.retrieve(
319
+ "What type of diseases occur when the immune system attacks normal body cells?"
320
+ )
321
 
322
+ plots_b = {
323
+ "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
324
+ "Y": [
325
+ 0.694980045351474,
326
+ 0.8126195011337869,
327
+ 0.821528798185941,
328
+ 0.8218562358276644,
329
+ 0.8222244897959182,
330
+ 0.8195024943310657,
331
+ 0.8182163265306123,
332
+ 0.8174734693877551,
333
+ 0.8139020408163266,
334
+ 0.8116893424036281,
335
+ 0.8083002267573697,
336
+ ],
337
+ } # TODO: Replace
338
+ plots_k1 = {
339
+ "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
340
+ "Y": [
341
+ 0.7345419501133786,
342
+ 0.7668607709750567,
343
+ 0.779508843537415,
344
+ 0.7900947845804988,
345
+ 0.8015931972789115,
346
+ 0.8103560090702948,
347
+ 0.812374149659864,
348
+ 0.8156743764172336,
349
+ 0.8194036281179138,
350
+ 0.8222244897959182,
351
+ 0.8221800453514739,
352
+ ],
353
+ }
354
 
355
  best_b = plots_b["X"][np.argmax(plots_b["Y"])]
356
  best_k1 = plots_k1["X"][np.argmax(plots_k1["Y"])]
 
359
  ndocs=12160,
360
  show_progress_bar=True,
361
  k1=best_k1,
362
+ b=best_b,
363
  )
364
 
 
365
  from typing import TypedDict
366
+
367
+ import gradio as gr
368
+
369
 
370
  class Hit(TypedDict):
371
+ cid: str
372
+ score: float
373
+ text: str
374
+
375
 
376
  demo: Optional[gr.Interface] = None # Assign your gradio demo to this variable
377
  return_type = List[Hit]
378
 
379
+
380
  ## YOUR_CODE_STARTS_HERE
381
+ def retrieve(query: str, topk: int = 10) -> return_type:
382
  ranking = bm25_retriever.retrieve(query=query, topk=3)
383
  hits = []
384
  for cid, score in ranking.items():
 
386
  hits.append({"cid": cid, "score": score, "text": text})
387
  return hits
388
 
389
+
390
  demo = gr.Interface(
391
  fn=retrieve,
392
  inputs=gr.Textbox(lines=3, placeholder="Enter your query here..."),
 
397
  ["What are the differences between immunodeficiency and autoimmune diseases?"],
398
  ["What are the causes of immunodeficiency?"],
399
  ["What are the symptoms of immunodeficiency?"],
400
+ ],
401
  )
402
  ## YOUR_CODE_ENDS_HERE
403
+ demo.launch()