{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "malaria-resnet50-overview",
   "metadata": {},
   "source": [
    "# MalariaResNet50\n",
    "\n",
    "Defines `MalariaResNet50`, a ResNet-50 based image classifier that labels an image as *Infected* or *Uninfected*, with helpers for single-image prediction and for saving/loading the model weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "72678f69-46b9-4908-b301-85ad5d4a6055",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchvision import models, transforms\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import torch.nn.functional as F\n",
    "from torchvision.models import resnet50, ResNet50_Weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88eb14fe-a198-4378-8817-13924bb328e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MalariaResNet50(nn.Module):\n",
    "    \"\"\"ResNet-50 classifier whose final layer is resized to `num_classes` outputs.\n",
    "\n",
    "    Args:\n",
    "        num_classes (int): Number of outputs of the replaced `fc` layer.\n",
    "        pretrained (bool): If True (default), initialise the backbone with\n",
    "            `ResNet50_Weights.DEFAULT`. Pass False to build the network\n",
    "            without loading pretrained weights, e.g. when a checkpoint is\n",
    "            restored with `load()` right afterwards.\n",
    "    \"\"\"\n",
    "\n",
    "    # Index -> label mapping used by `predict`.\n",
    "    # NOTE(review): the order must match the label indices used when the\n",
    "    # model was trained - confirm against the training dataset.\n",
    "    CLASS_NAMES = ('Uninfected', 'Infected')\n",
    "\n",
    "    # Preprocessing applied by `predict`: target image size and the\n",
    "    # per-channel mean / std passed to `transforms.Normalize`.\n",
    "    IMAGE_SIZE = (224, 224)\n",
    "    NORM_MEAN = (0.485, 0.456, 0.406)\n",
    "    NORM_STD = (0.229, 0.224, 0.225)\n",
    "\n",
    "    def __init__(self, num_classes=2, pretrained=True):\n",
    "        super().__init__()\n",
    "        # weights=None builds the architecture without loading pretrained weights\n",
    "        weights = ResNet50_Weights.DEFAULT if pretrained else None\n",
    "        self.backbone = models.resnet50(weights=weights)\n",
    "\n",
    "        # Replace the final fully connected layer with a `num_classes`-way head\n",
    "        num_ftrs = self.backbone.fc.in_features\n",
    "        self.backbone.fc = nn.Linear(num_ftrs, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the backbone's output for the input batch `x`.\"\"\"\n",
    "        return self.backbone(x)\n",
    "\n",
    "    def predict(self, image_path, device='cpu', show_image=False):\n",
    "        \"\"\"\n",
    "        Predict the class of a single image.\n",
    "\n",
    "        The model is moved to `device` before inference so that the weights\n",
    "        and the input tensor are always on the same device. The model is put\n",
    "        in eval mode for the forward pass and its previous train/eval mode is\n",
    "        restored afterwards.\n",
    "\n",
    "        Args:\n",
    "            image_path (str): Path to input image\n",
    "            device (str or torch.device): 'cuda' or 'cpu'\n",
    "            show_image (bool): Whether to display the image (needs matplotlib)\n",
    "\n",
    "        Returns:\n",
    "            pred_label (str): \"Infected\" or \"Uninfected\"; if the predicted index\n",
    "                has no entry in `CLASS_NAMES`, \"class_<index>\" is returned\n",
    "            confidence (float): Softmax probability of the predicted class\n",
    "        \"\"\"\n",
    "        transform = transforms.Compose([\n",
    "            transforms.Resize(self.IMAGE_SIZE),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(self.NORM_MEAN, self.NORM_STD),\n",
    "        ])\n",
    "\n",
    "        # Open inside a context manager so the file handle is closed promptly;\n",
    "        # convert() returns an independent in-memory copy.\n",
    "        with Image.open(image_path) as raw_img:\n",
    "            img = raw_img.convert('RGB')\n",
    "\n",
    "        # Keep model and input on the same device\n",
    "        self.to(device)\n",
    "        img_tensor = transform(img).unsqueeze(0).to(device)\n",
    "\n",
    "        # Inference in eval mode, restoring the caller's mode afterwards\n",
    "        was_training = self.training\n",
    "        self.eval()\n",
    "        try:\n",
    "            with torch.no_grad():\n",
    "                probs = F.softmax(self(img_tensor), dim=1)\n",
    "        finally:\n",
    "            self.train(was_training)\n",
    "\n",
    "        # Softmax is monotonic, so its argmax is the argmax of the raw outputs\n",
    "        top_prob, top_idx = probs.max(dim=1)\n",
    "        pred_idx = top_idx.item()\n",
    "        confidence = top_prob.item()\n",
    "\n",
    "        # Fall back to a generic name when num_classes exceeds the known labels\n",
    "        if pred_idx < len(self.CLASS_NAMES):\n",
    "            pred_label = self.CLASS_NAMES[pred_idx]\n",
    "        else:\n",
    "            pred_label = f\"class_{pred_idx}\"\n",
    "\n",
    "        if show_image:\n",
    "            # Imported lazily so matplotlib is only required when plotting\n",
    "            import matplotlib.pyplot as plt\n",
    "\n",
    "            plt.imshow(img)\n",
    "            plt.title(f\"Predicted: {pred_label} ({confidence:.2%})\")\n",
    "            plt.axis(\"off\")\n",
    "            plt.show()\n",
    "\n",
    "        return pred_label, confidence\n",
    "\n",
    "    def save(self, path):\n",
    "        \"\"\"Save the model's state dict to `path`.\"\"\"\n",
    "        torch.save(self.state_dict(), path)\n",
    "        print(f\"Model saved to {path}\")\n",
    "\n",
    "    def load(self, path):\n",
    "        \"\"\"Load a state dict from `path` into this model (tensors mapped to CPU).\"\"\"\n",
    "        # NOTE(security): torch.load unpickles the file, so only load\n",
    "        # checkpoints from trusted sources; consider weights_only=True where\n",
    "        # the installed PyTorch version supports it.\n",
    "        state_dict = torch.load(path, map_location=torch.device('cpu'))\n",
    "        self.load_state_dict(state_dict)\n",
    "        print(f\"Model loaded from {path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70b8f814-f126-4a12-afe8-051b9b9d4c2a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}