#!/usr/bin/env python
# coding: utf-8

# In[1]:


import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import numpy as np
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights


# In[4]:


class MalariaResNet50(nn.Module):
    """ResNet50 image classifier for malaria cell images.

    Wraps a torchvision ResNet50 whose final fully connected layer is
    replaced by an ``nn.Linear`` producing ``num_classes`` logits.

    Args:
        num_classes (int): Number of output logits. ``predict`` has
            human-readable names only for the indices in ``CLASS_NAMES``;
            any other index is reported as its decimal string.
        pretrained (bool): If True (the default, matching the original
            behaviour), initialise the backbone from torchvision's default
            pretrained ResNet50 weights, which may require a download. Pass
            False when the weights are about to be replaced via ``load``.
    """

    # Label for each output index; predict() maps the argmax index through this.
    CLASS_NAMES = ('Uninfected', 'Infected')

    # Inference preprocessing: fixed 224x224 resize, conversion to a [0, 1]
    # float tensor, then per-channel normalisation with the ImageNet mean/std
    # used by torchvision's pretrained ResNet50.
    # NOTE(review): assumes training used the same preprocessing — confirm.
    # Built once here instead of on every predict() call.
    _TRANSFORM = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    def __init__(self, num_classes=2, pretrained=True):
        super().__init__()
        weights = ResNet50_Weights.DEFAULT if pretrained else None
        self.backbone = models.resnet50(weights=weights)
        # Replace the final fully connected layer so the head emits
        # `num_classes` logits instead of the backbone's original outputs.
        num_ftrs = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        """Return the raw (pre-softmax) logits of the backbone for batch ``x``."""
        return self.backbone(x)

    def predict(self, image_path, device=None, show_image=False):
        """Predict the class of a single image file.

        Args:
            image_path (str): Path to the input image (any mode PIL can
                open; it is converted to RGB).
            device (str | torch.device | None): Device to run inference on.
                Defaults to the device holding the model's parameters, so a
                model already moved to CUDA works without passing this.
            show_image (bool): Whether to display the image with the
                prediction as its title (requires matplotlib).

        Returns:
            tuple[str, float]: ``(pred_label, confidence)`` where
            ``pred_label`` is an entry of ``CLASS_NAMES`` (e.g. "Infected"
            or "Uninfected") and ``confidence`` is the softmax probability
            of the predicted class.

        The model's train/eval mode is restored before returning, so this
        is safe to call in the middle of a training loop.
        """
        if device is None:
            device = next(self.parameters()).device

        # Close the underlying file promptly; convert() returns an
        # independent in-memory image that remains usable afterwards.
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')
        img_tensor = self._TRANSFORM(img).unsqueeze(0).to(device)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                output = self(img_tensor)
                probs = F.softmax(output, dim=1)
        finally:
            # Don't leave the model stuck in eval mode as a side effect.
            self.train(was_training)

        pred_idx = int(output[0].argmax())
        confidence = probs[0, pred_idx].item()
        if pred_idx < len(self.CLASS_NAMES):
            pred_label = self.CLASS_NAMES[pred_idx]
        else:
            pred_label = str(pred_idx)

        if show_image:
            # Imported lazily so matplotlib is only required for display.
            import matplotlib.pyplot as plt

            plt.imshow(img)
            plt.title(f"Predicted: {pred_label} ({confidence:.2%})")
            plt.axis("off")
            plt.show()

        return pred_label, confidence

    def save(self, path):
        """Save the model's state dict to ``path`` and print a confirmation."""
        torch.save(self.state_dict(), path)
        print(f"Model saved to {path}")

    def load(self, path):
        """Load a state dict from ``path`` into this model in place.

        Tensors are first mapped to CPU; ``load_state_dict`` then copies them
        into the existing parameters on whatever device the model is on.
        """
        # NOTE(security): torch.load unpickles the file and can execute
        # arbitrary code. Only load checkpoints from trusted sources; on
        # torch >= 1.13 consider passing weights_only=True.
        state_dict = torch.load(path, map_location=torch.device('cpu'))
        self.load_state_dict(state_dict)
        print(f"Model loaded from {path}")


# In[ ]: