Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -10,10 +10,10 @@ import nltk
|
|
10 |
nltk.download('stopwords')
|
11 |
from nltk.corpus import stopwords
|
12 |
|
13 |
-
#
|
14 |
-
model_path = "
|
15 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
16 |
-
product_encoder = joblib.load(
|
17 |
base_model_name = "DataScienceWFSR/bert-food-product-category-cw"
|
18 |
stop_words = set(stopwords.words('english'))
|
19 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
@@ -32,7 +32,7 @@ def clean_text(text):
|
|
32 |
def template_(day, month, year, country, title, text):
|
33 |
return f"Date: day {day}, month {month}, year {year}. Country: {country}. Title: {title}. Text: {text}"
|
34 |
|
35 |
-
# Model
|
36 |
class ProductCategoryClassifier(nn.Module):
|
37 |
def __init__(self, model_name, num_categories):
|
38 |
super().__init__()
|
@@ -50,7 +50,7 @@ class ProductCategoryClassifier(nn.Module):
|
|
50 |
# Load model
|
51 |
num_categories = len(product_encoder.classes_)
|
52 |
model = ProductCategoryClassifier(model_name=base_model_name, num_categories=num_categories).to(device)
|
53 |
-
model.load_state_dict(torch.load(
|
54 |
model.eval()
|
55 |
|
56 |
# Inference function
|
@@ -58,19 +58,19 @@ def predict_category(day, month, year, title, text, country="Unknown"):
|
|
58 |
title_clean = clean_text(title)
|
59 |
text_clean = clean_text(text)
|
60 |
input_text = template_(day, month, year, country, title_clean, text_clean)
|
61 |
-
|
62 |
inputs = tokenizer([input_text], padding=True, truncation=True, max_length=512, return_tensors="pt")
|
63 |
input_ids = inputs['input_ids'].to(device)
|
64 |
attention_mask = inputs['attention_mask'].to(device)
|
65 |
-
|
66 |
with torch.no_grad():
|
67 |
logits = model(input_ids=input_ids, attention_mask=attention_mask)
|
68 |
pred = torch.argmax(logits, dim=1).cpu().numpy()[0]
|
69 |
-
|
70 |
category = product_encoder.inverse_transform([pred])[0]
|
71 |
return category
|
72 |
|
73 |
-
# Gradio
|
74 |
iface = gr.Interface(
|
75 |
fn=predict_category,
|
76 |
inputs=[
|
@@ -82,9 +82,9 @@ iface = gr.Interface(
|
|
82 |
],
|
83 |
outputs="text",
|
84 |
title="Product Category Predictor",
|
85 |
-
description="Enter
|
86 |
)
|
87 |
|
88 |
-
#
|
89 |
if __name__ == "__main__":
|
90 |
iface.launch()
|
|
|
10 |
nltk.download('stopwords')
|
11 |
from nltk.corpus import stopwords
|
12 |
|
13 |
+
# Setup
|
14 |
+
model_path = "." # All files are in the root directory
|
15 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
16 |
+
product_encoder = joblib.load("category_encoder.pkl")
|
17 |
base_model_name = "DataScienceWFSR/bert-food-product-category-cw"
|
18 |
stop_words = set(stopwords.words('english'))
|
19 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
32 |
def template_(day, month, year, country, title, text):
|
33 |
return f"Date: day {day}, month {month}, year {year}. Country: {country}. Title: {title}. Text: {text}"
|
34 |
|
35 |
+
# Model definition
|
36 |
class ProductCategoryClassifier(nn.Module):
|
37 |
def __init__(self, model_name, num_categories):
|
38 |
super().__init__()
|
|
|
50 |
# Load model
|
51 |
num_categories = len(product_encoder.classes_)
|
52 |
model = ProductCategoryClassifier(model_name=base_model_name, num_categories=num_categories).to(device)
|
53 |
+
model.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
|
54 |
model.eval()
|
55 |
|
56 |
# Inference function
|
|
|
58 |
title_clean = clean_text(title)
|
59 |
text_clean = clean_text(text)
|
60 |
input_text = template_(day, month, year, country, title_clean, text_clean)
|
61 |
+
|
62 |
inputs = tokenizer([input_text], padding=True, truncation=True, max_length=512, return_tensors="pt")
|
63 |
input_ids = inputs['input_ids'].to(device)
|
64 |
attention_mask = inputs['attention_mask'].to(device)
|
65 |
+
|
66 |
with torch.no_grad():
|
67 |
logits = model(input_ids=input_ids, attention_mask=attention_mask)
|
68 |
pred = torch.argmax(logits, dim=1).cpu().numpy()[0]
|
69 |
+
|
70 |
category = product_encoder.inverse_transform([pred])[0]
|
71 |
return category
|
72 |
|
73 |
+
# Gradio interface
|
74 |
iface = gr.Interface(
|
75 |
fn=predict_category,
|
76 |
inputs=[
|
|
|
82 |
],
|
83 |
outputs="text",
|
84 |
title="Product Category Predictor",
|
85 |
+
description="Enter date and text details to predict the product category.",
|
86 |
)
|
87 |
|
88 |
+
# Run the app
|
89 |
if __name__ == "__main__":
|
90 |
iface.launch()
|