# Hugging Face Space (status: Running)
import re

import gradio as gr
import joblib
import nltk
import pandas as pd
import torch
from bs4 import BeautifulSoup
from nltk.corpus import stopwords
from torch import nn
from transformers import AutoModel, AutoTokenizer

# Make sure the NLTK stop-word corpus is available. ``stopwords`` above is a
# lazy corpus loader, so the data only needs to exist once ``words()`` runs.
nltk.download('stopwords')

# Setup.
# Tokenizer files and the pickled label encoder are read from the current
# working directory; base_model_name is the identifier handed to
# AutoModel.from_pretrained when the classifier is built.
model_path = "."
tokenizer = AutoTokenizer.from_pretrained(model_path)
product_encoder = joblib.load("category_encoder.pkl")
base_model_name = "DataScienceWFSR/bert-food-product-category-cw"
stop_words = set(stopwords.words('english'))
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Clean text function | |
def clean_text(text):
    """Normalize free text before it is inserted into the model prompt.

    The stages run in this order: lowercase, strip HTML markup, drop
    ``http...`` URLs, delete every character that is neither a word
    character nor whitespace, then remove English NLTK stop words.

    Args:
        text: raw input string.

    Returns:
        The surviving tokens joined by single spaces.
    """
    lowered = text.lower()
    plain = BeautifulSoup(lowered, "html.parser").get_text()
    without_urls = re.sub(r"http\S+", "", plain)
    word_chars_only = re.sub(r"[^\w\s]", "", without_urls)
    kept = (word for word in word_chars_only.split() if word not in stop_words)
    return " ".join(kept)
# Template function | |
def _as_int_if_integral(value):
    """Return ``int(value)`` for a whole-number float, otherwise ``value`` unchanged."""
    if isinstance(value, float) and value.is_integer():
        return int(value)
    return value


def template_(day, month, year, country, title, text):
    """Build the single prompt string the classifier is run on.

    Whole-number floats among the date parts (e.g. ``3.0``, which a numeric
    UI widget may supply) are rendered as integers, so the prompt reads
    ``day 3`` rather than ``day 3.0``. NOTE(review): this assumes the
    training prompts used integer date parts -- confirm against the training
    pipeline. Non-integral numbers, ints and ``None`` are formatted as-is.

    Args:
        day, month, year: date components.
        country: country name.
        title: (cleaned) title text.
        text: (cleaned) body text.

    Returns:
        The formatted prompt string.
    """
    day, month, year = (_as_int_if_integral(part) for part in (day, month, year))
    return f"Date: day {day}, month {month}, year {year}. Country: {country}. Title: {title}. Text: {text}"
# Model definition | |
class ProductCategoryClassifier(nn.Module):
    """Pretrained transformer encoder plus a linear head on the first token.

    The submodule attribute names ``bert``, ``dropout`` and ``classifier``
    are part of the state-dict keys, so they must not be renamed while
    weights are loaded with ``load_state_dict``.
    """

    def __init__(self, model_name, num_categories):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.4)
        # Head width follows the encoder's configured hidden size.
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_categories)

    def forward(self, input_ids, attention_mask=None):
        """Return unnormalized class scores (logits), one row per sequence in the batch."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of every sequence: the [CLS] token for BERT-style tokenizers.
        first_token_state = encoded.last_hidden_state[:, 0, :]
        return self.classifier(self.dropout(first_token_state))
# Load model | |
# The output layer is sized by the number of labels the encoder knows.
num_categories = len(product_encoder.classes_)
model = ProductCategoryClassifier(
    model_name=base_model_name,
    num_categories=num_categories,
)
model.to(device)  # nn.Module.to moves parameters in place and returns self
model.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
model.eval()  # evaluation mode: dropout is inactive during prediction
# Inference function | |
def predict_category(day, month, year, title, text, country="Unknown"):
    """Predict the product category for a single record.

    Args:
        day, month, year: date components, interpolated into the prompt.
        title: title text; ``None`` is treated as an empty string.
        text: body text; ``None`` is treated as an empty string.
        country: country name placed in the prompt; ``None`` falls back to
            ``"Unknown"`` (same as omitting it).

    Returns:
        The category label decoded by ``product_encoder``.
    """
    # API callers may submit an empty field as None; clean_text needs a str.
    if title is None:
        title = ""
    if text is None:
        text = ""
    if country is None:
        country = "Unknown"

    title_clean = clean_text(title)
    text_clean = clean_text(text)
    input_text = template_(day, month, year, country, title_clean, text_clean)

    inputs = tokenizer([input_text], padding=True, truncation=True, max_length=512, return_tensors="pt")
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)

    # Batch of one: take row 0's argmax as a plain Python int.
    pred = torch.argmax(logits, dim=1)[0].item()
    return product_encoder.inverse_transform([pred])[0]
# Gradio interface | |
# The five inputs bind positionally to predict_category(day, month, year,
# title, text); country is not exposed and keeps its default.
# precision=0 makes each Number widget return an int. With the default
# (precision=None) the value arrives as a float and the prompt would read
# "day 3.0" instead of "day 3".
iface = gr.Interface(
    fn=predict_category,
    inputs=[
        gr.Number(label="Day", precision=0),
        gr.Number(label="Month", precision=0),
        gr.Number(label="Year", precision=0),
        gr.Textbox(label="Title"),
        gr.Textbox(label="Text", lines=5),
    ],
    outputs="text",
    title="Product Category Predictor",
    description="Enter date and text details to predict the product category.",
)

# Run the app
if __name__ == "__main__":
    iface.launch()