Sid26Roy committed on
Commit
83635a5
·
verified ·
1 Parent(s): c1de0ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -10,10 +10,10 @@ import nltk
10
  nltk.download('stopwords')
11
  from nltk.corpus import stopwords
12
 
13
- # Load artifacts
14
- model_path = "./best_product_model"
15
  tokenizer = AutoTokenizer.from_pretrained(model_path)
16
- product_encoder = joblib.load(f"{model_path}/product_encoder.pkl")
17
  base_model_name = "DataScienceWFSR/bert-food-product-category-cw"
18
  stop_words = set(stopwords.words('english'))
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -32,7 +32,7 @@ def clean_text(text):
32
  def template_(day, month, year, country, title, text):
33
  return f"Date: day {day}, month {month}, year {year}. Country: {country}. Title: {title}. Text: {text}"
34
 
35
- # Model class
36
  class ProductCategoryClassifier(nn.Module):
37
  def __init__(self, model_name, num_categories):
38
  super().__init__()
@@ -50,7 +50,7 @@ class ProductCategoryClassifier(nn.Module):
50
  # Load model
51
  num_categories = len(product_encoder.classes_)
52
  model = ProductCategoryClassifier(model_name=base_model_name, num_categories=num_categories).to(device)
53
- model.load_state_dict(torch.load(f"{model_path}/pytorch_model.bin", map_location=device))
54
  model.eval()
55
 
56
  # Inference function
@@ -58,19 +58,19 @@ def predict_category(day, month, year, title, text, country="Unknown"):
58
  title_clean = clean_text(title)
59
  text_clean = clean_text(text)
60
  input_text = template_(day, month, year, country, title_clean, text_clean)
61
-
62
  inputs = tokenizer([input_text], padding=True, truncation=True, max_length=512, return_tensors="pt")
63
  input_ids = inputs['input_ids'].to(device)
64
  attention_mask = inputs['attention_mask'].to(device)
65
-
66
  with torch.no_grad():
67
  logits = model(input_ids=input_ids, attention_mask=attention_mask)
68
  pred = torch.argmax(logits, dim=1).cpu().numpy()[0]
69
-
70
  category = product_encoder.inverse_transform([pred])[0]
71
  return category
72
 
73
- # Gradio Interface
74
  iface = gr.Interface(
75
  fn=predict_category,
76
  inputs=[
@@ -82,9 +82,9 @@ iface = gr.Interface(
82
  ],
83
  outputs="text",
84
  title="Product Category Predictor",
85
- description="Enter the date and hazard text details to get the predicted product category."
86
  )
87
 
88
- # Launch app
89
  if __name__ == "__main__":
90
  iface.launch()
 
10
  nltk.download('stopwords')
11
  from nltk.corpus import stopwords
12
 
13
+ # Setup
14
+ model_path = "." # All files are in the root directory
15
  tokenizer = AutoTokenizer.from_pretrained(model_path)
16
+ product_encoder = joblib.load("category_encoder.pkl")
17
  base_model_name = "DataScienceWFSR/bert-food-product-category-cw"
18
  stop_words = set(stopwords.words('english'))
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
32
  def template_(day, month, year, country, title, text):
33
  return f"Date: day {day}, month {month}, year {year}. Country: {country}. Title: {title}. Text: {text}"
34
 
35
+ # Model definition
36
  class ProductCategoryClassifier(nn.Module):
37
  def __init__(self, model_name, num_categories):
38
  super().__init__()
 
50
  # Load model
51
  num_categories = len(product_encoder.classes_)
52
  model = ProductCategoryClassifier(model_name=base_model_name, num_categories=num_categories).to(device)
53
+ model.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
54
  model.eval()
55
 
56
  # Inference function
 
58
  title_clean = clean_text(title)
59
  text_clean = clean_text(text)
60
  input_text = template_(day, month, year, country, title_clean, text_clean)
61
+
62
  inputs = tokenizer([input_text], padding=True, truncation=True, max_length=512, return_tensors="pt")
63
  input_ids = inputs['input_ids'].to(device)
64
  attention_mask = inputs['attention_mask'].to(device)
65
+
66
  with torch.no_grad():
67
  logits = model(input_ids=input_ids, attention_mask=attention_mask)
68
  pred = torch.argmax(logits, dim=1).cpu().numpy()[0]
69
+
70
  category = product_encoder.inverse_transform([pred])[0]
71
  return category
72
 
73
+ # Gradio interface
74
  iface = gr.Interface(
75
  fn=predict_category,
76
  inputs=[
 
82
  ],
83
  outputs="text",
84
  title="Product Category Predictor",
85
+ description="Enter date and text details to predict the product category.",
86
  )
87
 
88
+ # Run the app
89
  if __name__ == "__main__":
90
  iface.launch()