ItsNikolor committed on
Commit
0efe602
·
verified ·
1 Parent(s): 45b90e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -24
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import streamlit as st
2
  import torch
3
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
@@ -8,10 +9,10 @@ def combine_title_summary(title, summary):
8
 
9
 
10
  tag2ind = {
11
- "bio": 0,
12
- "physics": 1,
13
- "math": 2,
14
- "cs": 3,
15
  }
16
 
17
 
@@ -19,12 +20,10 @@ tag2ind = {
19
  def load_model():
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
- # assert torch.cuda.is_available()
23
- save_dir = "./distilbert/distilbert-base-cased/checkpoint-738"
24
- tokenizer = AutoTokenizer.from_pretrained(save_dir)
25
- model = AutoModelForSequenceClassification.from_pretrained(
26
- save_dir
27
- ).to(device)
28
 
29
  return tokenizer, model
30
 
@@ -48,20 +47,36 @@ def run_model(model, tokenizer, title, summary):
48
  out = model(**tokens_info)
49
  probs = torch.nn.functional.softmax(out.logits, dim=-1)[0]
50
 
51
- result = f"Text: `{text}`\nPrediction (prob): \n" + "\n".join(
52
- [f"{tag}={tag_prob}" for tag, tag_prob in zip(tag2ind, probs)]
53
- )
54
- return result
 
55
 
 
 
56
 
57
- title = st.text_input(label="Title", value="")
58
- abstract = st.text_input(label="Abstract", value="")
59
- if st.button("Submit"):
60
- if title == "" and abstract == "":
61
- st.error("At least one of title or abstract must be provided")
62
- else:
63
- result = combine_title_summary(title, abstract)
64
- st.success(result)
65
 
66
- result = run_model(model, tokenizer, title, abstract)
67
- st.success(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
  import streamlit as st
3
  import torch
4
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
9
 
10
 
11
  # Theme display name -> class index. run_model looks names up by position in
  # key order (list(tag2ind.keys())[ind]), so the insertion order here must
  # stay aligned with the index values.
  tag2ind = {
12
+ "Biology": 0,
13
+ "Physics": 1,
14
+ "Math": 2,
15
+ "Computer Science": 3,
16
  }
17
 
18
 
 
20
  def load_model():
21
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
 
23
+ # dir_name = "./distilbert/distilbert-base-cased/checkpoint-738"
24
+ dir_name = "./microsoft/deberta-v3-small/checkpoint-4915"
25
+ tokenizer = AutoTokenizer.from_pretrained(dir_name, use_fast=False)
26
+ model = AutoModelForSequenceClassification.from_pretrained(dir_name).to(device)
 
 
27
 
28
  return tokenizer, model
29
 
 
47
  out = model(**tokens_info)
48
  probs = torch.nn.functional.softmax(out.logits, dim=-1)[0]
49
 
50
+ ids = torch.argsort(probs, descending=True)
51
+ p = 0
52
+ best_tags, best_probs = [], []
53
+ for ind in ids:
54
+ p += probs[ind]
55
 
56
+ best_tags.append(list(tag2ind.keys())[ind])
57
+ best_probs.append(probs[ind])
58
 
59
+ if p >= 0.95:
60
+ break
 
 
 
 
 
 
61
 
62
+ return best_tags, best_probs
63
+
64
+
65
def main():
    """Render the input form and display the predicted themes.

    Shows a title field, an abstract field and a "Classify" button. When the
    button is pressed with at least one non-blank field, the inputs are passed
    to ``run_model`` and the returned tags and probabilities are shown as a
    two-column table. ``model`` and ``tokenizer`` are not defined here; they
    are expected to be module-level names set up elsewhere in the file.
    """
    title = st.text_input(label="Title", value="")
    abstract = st.text_area(label="Abstract", value="", height=200)

    if not st.button("Classify"):
        return

    # Whitespace-only fields carry no text to classify, so treat them as empty.
    if not title.strip() and not abstract.strip():
        st.error("At least one of title or abstract must be provided")
        return

    best_tags, best_probs = run_model(model, tokenizer, title, abstract)

    # run_model collects probs[ind] values, i.e. 0-dim tensors; convert them to
    # plain floats so the table shows numbers rather than tensor objects.
    rows = [(tag, float(prob)) for tag, prob in zip(best_tags, best_probs)]
    df = pd.DataFrame(rows, columns=["Theme", "Probability"])
    st.table(df)


if __name__ == "__main__":
    main()