bobs24 committed on
Commit
5303063
·
verified ·
1 Parent(s): ef9b1e8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import pickle
4
+ import joblib
5
+ import torch.nn.functional as F
6
+ from PIL import Image
7
+ import gradio as gr
8
+ from transformers import AutoModelForImageClassification
9
+ from torch import nn
10
+ from torchvision import transforms
11
+ from huggingface_hub import hf_hub_download
12
+
13
# Artifact filenames inside the Hugging Face model repository
MODEL_PATH = "DeiT_Model_Parameter.pth"
ENCODER_PATH = "label_encoder.pkl"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
19
+
20
def load_label_encoder():
    """Download the serialized label encoder from the Hub and load it.

    Returns:
        The object deserialized from the ``ENCODER_PATH`` artifact of the
        ``bobs24/DeiT-Classification-Apparel`` repository.
    """
    # NOTE(review): joblib.load unpickles the file, which can execute
    # arbitrary code -- only point this at a repository you trust.
    encoder_file = hf_hub_download(
        repo_id="bobs24/DeiT-Classification-Apparel",
        filename=ENCODER_PATH,
    )
    return joblib.load(encoder_file)
25
+
26
+ # Define the model class
27
class CustomModel(nn.Module):
    """DeiT image classifier whose forward pass returns raw logits.

    Wraps ``facebook/deit-base-patch16-224`` loaded through
    ``AutoModelForImageClassification`` with ``num_labels=num_classes``.
    """

    def __init__(self, num_classes):
        """Build the wrapped model with a ``num_classes``-way output."""
        super().__init__()
        # The attribute must stay named ``base_model`` so that saved
        # state-dict keys keep matching this module's parameters.
        self.base_model = AutoModelForImageClassification.from_pretrained(
            "facebook/deit-base-patch16-224",
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        """Run the wrapped model on ``x`` and return only its logits."""
        outputs = self.base_model(x)
        return outputs.logits
38
+
39
def load_model():
    """Download the fine-tuned weights and build the ready-to-use classifier.

    Returns:
        tuple: ``(model, label_encoder)`` where ``model`` is a
        ``CustomModel`` on ``device`` in eval mode and ``label_encoder`` is
        the object returned by ``load_label_encoder()``.
    """
    weights_file = hf_hub_download(
        repo_id="bobs24/DeiT-Classification-Apparel",
        filename=MODEL_PATH,
    )
    encoder = load_label_encoder()

    # The classifier output size follows the number of encoded classes.
    class_count = len(encoder.classes_)
    net = CustomModel(num_classes=class_count).to(device)

    state_dict = torch.load(weights_file, map_location=device)
    net.load_state_dict(state_dict)

    # Record the device on the module itself, then switch to inference mode.
    net.device = device
    net.eval()

    return net, encoder
49
+
50
# Load the model and label encoder once, at import time
model, label_encoder = load_model()

# Inference-time preprocessing pipeline applied to every input image
preprocess = transforms.Compose(
    [
        # Scale so the shorter side is 256 px (aspect ratio is kept)
        transforms.Resize(256),
        # Take the central 224x224 patch
        transforms.CenterCrop(224),
        # Convert the PIL image to a tensor
        transforms.ToTensor(),
        # Per-channel normalization with the mean/std used for DeiT
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)
60
+
61
+ # Function to perform predictions and show probabilities
62
def predict(image):
    """Classify an image and report the top class with its confidence.

    Args:
        image: Array handed over by the Gradio image input, or ``None``
            when no image is present (for example after the user clears
            the input while the interface runs in live mode).

    Returns:
        str: ``"Predicted class: <name>, Confidence: <p>"`` for a real
        image, or a short notice when ``image`` is ``None``.
    """
    # Live interfaces re-invoke the function when the input is cleared;
    # without this guard Image.fromarray(None) would raise.
    if image is None:
        return "No image provided."

    # Force three channels so grayscale / RGBA inputs match the pipeline.
    pil_image = Image.fromarray(image).convert("RGB")
    input_tensor = preprocess(pil_image).unsqueeze(0).to(device)

    # Inference only: no gradients are needed for the forward pass or softmax.
    with torch.no_grad():
        logits = model(input_tensor)
        probabilities = F.softmax(logits, dim=1)

    # Index and probability of the most likely class for the single input.
    predicted_label = torch.argmax(probabilities, dim=1).item()
    confidence = probabilities[0, predicted_label].item()

    # Map the numeric index back to its class name.
    class_name = label_encoder.inverse_transform([predicted_label])[0]

    return f"Predicted class: {class_name}, Confidence: {confidence:.4f}"
82
+
83
# Build the Gradio UI: a numpy image input, a text output, and live updates
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs="text",
    live=True,
)

# Start serving the interface
iface.launch()