{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "72678f69-46b9-4908-b301-85ad5d4a6055",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torchvision import models, transforms\n",
"from PIL import Image\n",
"import numpy as np\n",
"import torch.nn.functional as F\n",
"from torchvision.models import resnet50, ResNet50_Weights"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "88eb14fe-a198-4378-8817-13924bb328e3",
"metadata": {},
"outputs": [],
"source": [
"class MalariaResNet50(nn.Module):\n",
" def __init__(self, num_classes=2):\n",
" super(MalariaResNet50, self).__init__()\n",
" # Load pretrained ResNet50\n",
" self.backbone = models.resnet50(weights=ResNet50_Weights.DEFAULT)\n",
"\n",
" # Replace final fully connected layer for binary classification\n",
" num_ftrs = self.backbone.fc.in_features\n",
" self.backbone.fc = nn.Linear(num_ftrs, num_classes)\n",
"\n",
" def forward(self, x):\n",
" return self.backbone(x)\n",
"\n",
" def predict(self, image_path, device='cpu', show_image=False):\n",
" \"\"\"\n",
" Predict class of a single image.\n",
"\n",
" Args:\n",
" image_path (str): Path to input image\n",
" device (torch.device): 'cuda' or 'cpu'\n",
" show_image (bool): Whether to display the image\n",
"\n",
" Returns:\n",
" pred_label (str): \"Infected\" or \"Uninfected\"\n",
" confidence (float): Confidence score (softmax output)\n",
" \"\"\"\n",
" from torchvision import transforms\n",
" from PIL import Image\n",
" import matplotlib.pyplot as plt\n",
"\n",
" transform = transforms.Compose([\n",
" transforms.Resize((224, 224)),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
" ])\n",
"\n",
" # Load and preprocess image\n",
" img = Image.open(image_path).convert('RGB')\n",
" img_tensor = transform(img).unsqueeze(0).to(device)\n",
"\n",
" # Inference\n",
" self.eval()\n",
" with torch.no_grad():\n",
" output = self(img_tensor)\n",
" probs = F.softmax(output, dim=1)\n",
" _, preds = torch.max(output, 1)\n",
"\n",
" pred_idx = preds.item()\n",
" confidence = probs[0][pred_idx].item()\n",
"\n",
" classes = ['Uninfected', 'Infected']\n",
" pred_label = classes[pred_idx]\n",
"\n",
" if show_image:\n",
" plt.imshow(img)\n",
" plt.title(f\"Predicted: {pred_label} ({confidence:.2%})\")\n",
" plt.axis(\"off\")\n",
" plt.show()\n",
"\n",
" return pred_label, confidence\n",
"\n",
" def save(self, path):\n",
" \"\"\"Save model state dict\"\"\"\n",
" torch.save(self.state_dict(), path)\n",
" print(f\"Model saved to {path}\")\n",
"\n",
" def load(self, path):\n",
" \"\"\"Load model state dict from file\"\"\"\n",
" state_dict = torch.load(path, map_location=torch.device('cpu'))\n",
" self.load_state_dict(state_dict)\n",
" print(f\"Model loaded from {path}\")"
]
},
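{
"cell_type": "markdown",
"id": "3f2a1b0c-7d4e-4a2b-9c3d-0e1f2a3b4c5d",
"metadata": {},
"source": [
"A minimal usage sketch for the class above. The checkpoint path `malaria_resnet50.pth` and the sample image path `cell.png` are placeholders, not files shipped with this notebook; substitute your own trained weights and input image."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b8c7d6e-5f4a-4b3c-8d2e-1a0b9c8d7e6f",
"metadata": {},
"outputs": [],
"source": [
"# Usage sketch: the paths below are placeholders, not part of this notebook\n",
"model = MalariaResNet50(num_classes=2)\n",
"model.load('malaria_resnet50.pth')  # hypothetical checkpoint path\n",
"\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"label, conf = model.predict('cell.png', device=device, show_image=True)\n",
"print(f\"{label} ({conf:.2%})\")"
]
},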
{
"cell_type": "code",
"execution_count": null,
"id": "70b8f814-f126-4a12-afe8-051b9b9d4c2a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.17"
}
},
"nbformat": 4,
"nbformat_minor": 5
}