"""Gradio demo that classifies a clothing image with a trained CC_model."""

import traceback

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

from ResNet_for_CC import CC_model  # Import the model

# Set device (CPU/GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Clothing1M class labels. The list order defines the index -> label mapping,
# and its length is the single source of truth for the number of classes.
class_labels = [
    "T-Shirt", "Shirt", "Knitwear", "Chiffon", "Sweater", "Hoodie",
    "Windbreaker", "Jacket", "Downcoat", "Suit", "Shawl", "Dress",
    "Vest", "Underwear",
]

# Load the trained CC_model
model_path = "CC_net.pt"
model = CC_model(num_classes1=len(class_labels))

# Load model weights.
# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source (or pass weights_only=True on PyTorch versions that support it).
state_dict = torch.load(model_path, map_location=device)

# strict=False tolerates key mismatches, which would otherwise leave layers
# randomly initialised without any hint. Report them so a bad checkpoint is
# visible in the console.
_load_result = model.load_state_dict(state_dict, strict=False)
if _load_result.missing_keys:
    print(f"[WARN] Keys missing from '{model_path}': {_load_result.missing_keys}")
if _load_result.unexpected_keys:
    print(f"[WARN] Unexpected keys in '{model_path}': {_load_result.unexpected_keys}")

model.to(device)
model.eval()

# Preprocessing pipeline, built once instead of on every request.
# The mean/std values are the standard ImageNet channel statistics.
_TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Apply the inference transforms to a PIL image.

    Args:
        image: A PIL image. It is expected to have 3 channels, because the
            normalisation step uses 3-channel statistics.

    Returns:
        The transformed image as a tensor with a leading batch dimension of 1,
        moved to ``device``.
    """
    return _TRANSFORM(image).unsqueeze(0).to(device)


def classify_image(image):
    """Classify an uploaded image and return a human-readable result string.

    Args:
        image: The image as a numpy array (as delivered by ``gr.Image`` with
            ``type="numpy"``), or ``None`` when nothing was uploaded.

    Returns:
        A string with the predicted label and its confidence, or an error
        message. Exceptions are not propagated; they are printed to the
        console and reported as a generic error string.
    """
    print("\n[INFO] Received image for classification.")

    # Gradio passes None when the user submits without uploading an image.
    if image is None:
        return "[ERROR] No image provided. Please upload an image."

    try:
        # Force 3 channels so grayscale / RGBA inputs match the normalisation.
        pil_image = Image.fromarray(image).convert("RGB")
        batch = preprocess_image(pil_image)
        print("[INFO] Image transformed and moved to device.")

        with torch.no_grad():
            output = model(batch)

        # The model may return a tuple; element 1 is used as the class scores.
        # NOTE(review): presumably (features, logits) - confirm in ResNet_for_CC.
        if isinstance(output, tuple):
            output = output[1]

        print(f"[DEBUG] Model output shape: {output.shape}")
        print(f"[DEBUG] Model output values: {output}")

        num_classes = len(class_labels)
        if output.shape[1] != num_classes:
            return (
                f"[ERROR] Model output mismatch! Expected {num_classes} "
                f"but got {output.shape[1]}."
            )

        # Convert logits to probabilities
        probabilities = F.softmax(output, dim=1)
        print(f"[DEBUG] Softmax probabilities: {probabilities}")

        # The shape check above guarantees the argmax is a valid label index,
        # so no separate range validation is needed before indexing.
        predicted_class = torch.argmax(probabilities, dim=1).item()
        predicted_label = class_labels[predicted_class]
        print(
            f"[INFO] Predicted class index: {predicted_class} "
            f"(Class: {predicted_label})"
        )

        confidence = probabilities[0][predicted_class].item() * 100
        return f"Predicted Class: {predicted_label} (Confidence: {confidence:.2f}%)"

    except Exception as e:
        # UI boundary: never crash the request. The returned message points to
        # the console, so print the full traceback there, not just the message.
        print(f"[ERROR] Exception during classification: {e}")
        traceback.print_exc()
        return "Error in classification. Check console for details."


# Gradio interface
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="numpy"),
    outputs="text",
    title="Clothing1M Image Classifier",
    description="Upload a clothing image, and the model will classify it into one of the 14 categories.",
)

# Run the interface
if __name__ == "__main__":
    print("[INFO] Launching Gradio interface...")
    interface.launch()