"""Gradio demo that classifies a clothing photo into one of 14 Clothing1M classes."""

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image

from resnet import SupCEResNet  # Ensure the correct import path

# Resolved once so the checkpoint, the model and every input tensor agree on
# where they live.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Class labels (from Clothing1M). The list position is the class index used to
# decode the model's argmax in predict_clothing().
class_labels = [
    "T-Shirt",
    "Shirt",
    "Knitwear",
    "Chiffon",
    "Sweater",
    "Hoodie",
    "Windbreaker",
    "Jacket",
    "Downcoat",
    "Suit",
    "Shawl",
    "Dress",
    "Vest",
    "Underwear",
]

# Built once at import time instead of on every request: resize the short side
# to 256, take the central 224x224 crop, convert to a tensor and normalize each
# channel with the fixed mean/std below.
_TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

_DATA_PARALLEL_PREFIX = 'module.'


def _strip_module_prefix(key):
    """Return *key* without a leading 'module.' (as added by DataParallel)."""
    # Only a leading prefix is removed; a 'module.' occurring in the middle of
    # a parameter name must be left untouched.
    if key.startswith(_DATA_PARALLEL_PREFIX):
        return key[len(_DATA_PARALLEL_PREFIX):]
    return key


def create_model_selfsup(net='resnet50', num_class=14, checkpoint_path='/content/ckpt_clothing_resnet50.pth'):
    """Load a pretrained SupCEResNet for Clothing1M classification.

    Args:
        net: Backbone name forwarded to ``SupCEResNet``.
        num_class: Number of output classes.
        checkpoint_path: Path to a checkpoint whose ``'model'`` entry holds
            the state dict.

    Returns:
        The model, moved to ``DEVICE`` and switched to evaluation mode.
    """
    print(f"🔄 Loading model from: {checkpoint_path}")

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects. Only point this at checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)

    # Remove the 'module.' prefix left behind when the model was saved from DataParallel.
    state_dict = {_strip_module_prefix(k): v for k, v in checkpoint['model'].items()}

    # Initialize and load the model.
    model = SupCEResNet(net, num_classes=num_class, pool=True)
    load_result = model.load_state_dict(state_dict, strict=False)

    # strict=False tolerates mismatched keys; report them so that a wrong
    # checkpoint does not silently yield a partially initialized model.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"⚠️ Checkpoint/model key mismatch - "
            f"missing: {list(load_result.missing_keys)}, "
            f"unexpected: {list(load_result.unexpected_keys)}"
        )

    model = model.to(DEVICE)
    model.eval()  # Set model to evaluation mode
    print("✅ Model loaded successfully!")
    return model


# Load the model once; every request reuses this instance.
model = create_model_selfsup()


def preprocess_image(image):
    """Transform a PIL image into a single-image batch tensor on ``DEVICE``."""
    return _TRANSFORM(image).unsqueeze(0).to(DEVICE)


def predict_clothing(image):
    """Run inference on an uploaded image and return the predicted class name.

    Args:
        image: Numpy array supplied by the Gradio image input, or ``None``
            when the user submitted without providing an image.

    Returns:
        The predicted entry of ``class_labels``, or a short message when no
        image was provided.
    """
    if image is None:
        return "No image provided. Please upload an image."

    # Force three channels so grayscale or RGBA inputs do not break the
    # 3-channel normalization step.
    pil_image = Image.fromarray(image).convert("RGB")
    batch = preprocess_image(pil_image)

    with torch.no_grad():
        output = model(batch)

    predicted_class = torch.argmax(output, dim=1).item()  # Get class index
    return class_labels[predicted_class]  # Return class name


# Create and launch the Gradio interface.
demo = gr.Interface(
    fn=predict_clothing,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Textbox(label="Predicted Clothing Type"),
    title="Clothing1M Classification",
    description="Upload an image to classify clothing into one of 14 categories.",
)
demo.launch()