File size: 2,629 Bytes
faf90bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#!/usr/bin/env python
# coding: utf-8

# In[1]:


import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import numpy as np
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights


# In[4]:


class MalariaResNet50(nn.Module):
    """ResNet50-based image classifier for malaria cell images.

    Wraps a torchvision ResNet50 whose final fully connected layer is replaced
    by a ``num_classes``-way linear head, and adds helpers for single-image
    prediction and for saving/loading the model's state dict.
    """

    # Index -> human-readable label mapping used by ``predict``.
    CLASS_NAMES = ('Uninfected', 'Infected')

    def __init__(self, num_classes=2, pretrained=True):
        """
        Build the network.

        Args:
            num_classes (int): Number of outputs of the replacement head.
            pretrained (bool): If True (default), initialise the backbone with
                ``ResNet50_Weights.DEFAULT``. Pass False to skip loading those
                weights, e.g. when a checkpoint is restored right afterwards
                with ``load``.
        """
        super().__init__()
        weights = ResNet50_Weights.DEFAULT if pretrained else None
        self.backbone = models.resnet50(weights=weights)

        # Replace the final fully connected layer with a head of the
        # requested size.
        num_ftrs = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        """Return the backbone's raw (pre-softmax) outputs for batch ``x``."""
        return self.backbone(x)

    @classmethod
    def _class_name(cls, index):
        """Map a class index to its label, or ``"Class <index>"`` if unnamed."""
        if 0 <= index < len(cls.CLASS_NAMES):
            return cls.CLASS_NAMES[index]
        return f"Class {index}"

    def predict(self, image_path, device='cpu', show_image=False):
        """
        Predict class of a single image.

        The model is moved to ``device`` and switched to eval mode as a side
        effect, so that the model and the input tensor always live on the
        same device.

        Args:
            image_path (str): Path to input image
            device (str or torch.device): Device to run inference on,
                e.g. 'cuda' or 'cpu'
            show_image (bool): Whether to display the image (requires
                matplotlib, which is imported only when this is True)

        Returns:
            pred_label (str): "Infected" or "Uninfected"; indices without an
                entry in ``CLASS_NAMES`` are reported as "Class <index>"
            confidence (float): Confidence score (softmax output)
        """
        preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # Load and preprocess image. The context manager closes the file
        # handle; convert() returns a separate in-memory image that stays
        # usable afterwards.
        with Image.open(image_path) as raw_img:
            img = raw_img.convert('RGB')
        img_tensor = preprocess(img).unsqueeze(0).to(device)

        # Inference. Keep the model on the same device as the input to avoid
        # a device-mismatch error in the forward pass.
        self.to(device)
        self.eval()
        with torch.no_grad():
            output = self(img_tensor)
            probs = F.softmax(output, dim=1)

        pred_idx = output.argmax(dim=1).item()
        confidence = probs[0, pred_idx].item()
        pred_label = self._class_name(pred_idx)

        if show_image:
            # Imported lazily so prediction works without matplotlib installed.
            import matplotlib.pyplot as plt

            plt.imshow(img)
            plt.title(f"Predicted: {pred_label} ({confidence:.2%})")
            plt.axis("off")
            plt.show()

        return pred_label, confidence

    def save(self, path):
        """Save model state dict to ``path``."""
        torch.save(self.state_dict(), path)
        print(f"Model saved to {path}")

    def load(self, path):
        """Load model state dict from ``path`` (tensors are mapped to CPU)."""
        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints from trusted sources. On PyTorch versions that support
        # it, consider passing weights_only=True.
        state_dict = torch.load(path, map_location=torch.device('cpu'))
        self.load_state_dict(state_dict)
        print(f"Model loaded from {path}")


# In[ ]: