Spaces:
Runtime error
Runtime error
Commit
·
bbf7aff
1
Parent(s):
fd97dee
inference code
Browse files
app.py
CHANGED
@@ -1,7 +1,11 @@
|
|
1 |
import streamlit as st
|
2 |
-
|
|
|
3 |
from datasets import load_dataset
|
4 |
|
|
|
|
|
|
|
5 |
# load the dataset and
|
6 |
# use the patent number, abstract and claim columns for UI
|
7 |
with st.spinner("Setting up the app..."):
|
@@ -16,11 +20,6 @@ with st.spinner("Setting up the app..."):
|
|
16 |
val_filing_end_date="2016-01-31",
|
17 |
)
|
18 |
|
19 |
-
# widget for selecting our finetuned langugae model
|
20 |
-
language_model_path = "juliaannjose/finetuned_model"
|
21 |
-
|
22 |
-
# pass the model to transformers pipeline - model selection component.
|
23 |
-
classifier_model = pipeline(model=language_model_path)
|
24 |
|
25 |
# drop down menu with patent numbers
|
26 |
_patent_id = st.selectbox(
|
@@ -28,19 +27,39 @@ _patent_id = st.selectbox(
|
|
28 |
dataset_dict["train"]["patent_number"],
|
29 |
)
|
30 |
|
|
|
31 |
# display abstract and claim
|
32 |
@st.cache(persist=True)
|
33 |
def get_abs_claim(_patent_id):
|
34 |
# get abstract and claim corresponding to this patent id
|
35 |
_abstract = dataset_dict["train"][["patent_number"] == _patent_id]["abstract"]
|
36 |
_claim = dataset_dict["train"][["patent_number"] == _patent_id]["claims"]
|
37 |
-
return _abstract,_claim
|
|
|
38 |
|
39 |
-
_abstract,_claim = get_abs_claim(_patent_id)
|
40 |
st.write(_abstract)
|
41 |
st.write(_claim)
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
# when submit button clicked, run the model and get result
|
44 |
if st.button("Submit"):
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
+
import torch
|
3 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
4 |
from datasets import load_dataset
|
5 |
|
6 |
+
# finetuned model
|
7 |
+
language_model_path = "juliaannjose/finetuned_model"
|
8 |
+
|
9 |
# load the dataset and
|
10 |
# use the patent number, abstract and claim columns for UI
|
11 |
with st.spinner("Setting up the app..."):
|
|
|
20 |
val_filing_end_date="2016-01-31",
|
21 |
)
|
22 |
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
# drop down menu with patent numbers
|
25 |
_patent_id = st.selectbox(
|
|
|
27 |
dataset_dict["train"]["patent_number"],
|
28 |
)
|
29 |
|
30 |
+
|
31 |
# display abstract and claim
|
32 |
@st.cache(persist=True)
|
33 |
def get_abs_claim(_patent_id):
|
34 |
# get abstract and claim corresponding to this patent id
|
35 |
_abstract = dataset_dict["train"][["patent_number"] == _patent_id]["abstract"]
|
36 |
_claim = dataset_dict["train"][["patent_number"] == _patent_id]["claims"]
|
37 |
+
return _abstract, _claim
|
38 |
+
|
39 |
|
40 |
+
_abstract, _claim = get_abs_claim(_patent_id)
|
41 |
st.write(_abstract)
|
42 |
st.write(_claim)
|
43 |
|
44 |
+
input_text = _abstract + _claim
|
45 |
+
|
46 |
+
# model and tokenizer initialization
|
47 |
+
tokenizer = AutoTokenizer.from_pretrained(language_model_path)
|
48 |
+
inputs = tokenizer(
|
49 |
+
input_text,
|
50 |
+
truncation=True,
|
51 |
+
padding=True,
|
52 |
+
return_tensors="pt",
|
53 |
+
)
|
54 |
+
model = AutoModelForSequenceClassification.from_pretrained(language_model_path)
|
55 |
+
|
56 |
+
# get predictions
|
57 |
+
id2label = {0: "REJECTED", 1: "ACCEPTED"}
|
58 |
# when submit button clicked, run the model and get result
|
59 |
if st.button("Submit"):
|
60 |
+
with torch.no_grad():
|
61 |
+
logits = model(**inputs).logits
|
62 |
+
|
63 |
+
predicted_class_id = logits.argmax().item()
|
64 |
+
pred_label = id2label[predicted_class_id]
|
65 |
+
st.write(pred_label)
|