import gradio as gr
import pandas as pd
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModel
import joblib
from bs4 import BeautifulSoup
import re
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
# Setup
model_path = "." # All files are in the root directory
tokenizer = AutoTokenizer.from_pretrained(model_path)
product_encoder = joblib.load("category_encoder.pkl")
base_model_name = "DataScienceWFSR/bert-food-product-category-cw"
stop_words = set(stopwords.words('english'))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Clean text function
def clean_text(text):
    text = text.lower()
    text = BeautifulSoup(text, "html.parser").get_text()  # strip HTML tags
    text = re.sub(r"http\S+", "", text)  # remove URLs
    text = re.sub(r"[^\w\s]", "", text)  # remove punctuation
    tokens = text.split()
    tokens = [w for w in tokens if w not in stop_words]  # drop English stopwords
    return " ".join(tokens)
# Template function
def template_(day, month, year, country, title, text):
    return f"Date: day {day}, month {month}, year {year}. Country: {country}. Title: {title}. Text: {text}"
# Model definition
class ProductCategoryClassifier(nn.Module):
    def __init__(self, model_name, num_categories):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.4)
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_categories)

    def forward(self, input_ids, attention_mask=None):
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Classify from the [CLS] token representation
        cls_token = self.dropout(output.last_hidden_state[:, 0, :])
        logits = self.classifier(cls_token)
        return logits
# Load model
num_categories = len(product_encoder.classes_)
model = ProductCategoryClassifier(model_name=base_model_name, num_categories=num_categories).to(device)
model.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
model.eval()
# Inference function
def predict_category(day, month, year, title, text, country="Unknown"):
    # Clean the free-text fields, then build the prompt used during training
    title_clean = clean_text(title)
    text_clean = clean_text(text)
    input_text = template_(day, month, year, country, title_clean, text_clean)
    inputs = tokenizer([input_text], padding=True, truncation=True, max_length=512, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
    pred = torch.argmax(logits, dim=1).cpu().numpy()[0]
    # Map the predicted class index back to its category label
    category = product_encoder.inverse_transform([pred])[0]
    return category
# Gradio interface
iface = gr.Interface(
    fn=predict_category,
    inputs=[
        gr.Number(label="Day", precision=0),  # precision=0 so dates are passed as integers, not floats
        gr.Number(label="Month", precision=0),
        gr.Number(label="Year", precision=0),
        gr.Textbox(label="Title"),
        gr.Textbox(label="Text", lines=5),
    ],
    outputs="text",
    title="Product Category Predictor",
    description="Enter date and text details to predict the product category.",
)
# Run the app
if __name__ == "__main__":
    iface.launch()
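    # Example direct call for quick local testing outside the Gradio UI
    # (hypothetical inputs, shown only to illustrate the expected argument order):
    # predict_category(5, 3, 2024, "Salmonella in peanut butter",
    #                  "Retailer recalls peanut butter after contamination was detected.")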