Commit 46305a2 · refactor
Parent: d0c355b

Files changed:
- LICENSE                   +21  -0
- README.md                 +29  -13
- app.py                    +25  -14
- datasets/__init__.py       +0  -1
- datasets/cifar10.py        +0  -32
- datasets/generic.py        +0  -111
- models/custom_resnet.py    +2  -98
- utils/metrics.py           +0  -19
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) [year] [fullname]
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,14 +1,30 @@
- sdk: gradio
- sdk_version: 3.39.0
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # CIFAR10 demo with GradCAM
+ ## How to Use the App
+ 1. The app has two tabs:
+    - **Examples**: In this tab, you can upload your own 32x32 pixel image or choose one of the provided example images to classify and visualize its class activation maps using GradCAM. You can adjust the number of top predicted classes, show/hide the GradCAM overlay, select a target layer, and control the transparency of the overlay.
+    - **Misclassified Examples**: In this tab, the app displays a gallery of misclassified images from the CIFAR-10 test dataset. You can control the number of examples shown, show/hide the GradCAM overlay, select a target layer, and control the transparency of the overlay.
+
+ 2. **Examples Tab**:
+    - **Input Image**: Upload your own 32x32 pixel image or select one of the example images from the dropdown list.
+    - **Number of Top Classes**: Choose the number of top predicted classes to display along with their respective confidence scores.
+    - **Show GradCAM?**: Check this box to display the GradCAM overlay on the input image. Uncheck it to view only the original image.
+    - **Which Layer?**: Adjust the target layer for GradCAM visualization. The default value is -2.
+    - **Transparency**: Control the transparency of the GradCAM overlay. The default value is 0.7.
+
+ 3. **Misclassified Examples Tab**:
+    - **No. of Misclassified Examples**: Control the number of misclassified examples displayed in the gallery. The default value is 20.
+    - **Show GradCAM?**: Check this box to display the GradCAM overlay on the misclassified images. Uncheck it to view only the original images.
+    - **Which Layer?**: Adjust the target layer for GradCAM visualization. The default value is -2.
+    - **Transparency**: Control the transparency of the GradCAM overlay. The default value is 0.7.
+
+ 4. After adjusting the settings, click the "Submit" button to see the results.
+
+ ## Credits
+
+ - This app is built using the Gradio library ([https://www.gradio.app/](https://www.gradio.app/)) for interactive model interfaces.
+ - The PyTorch library ([https://pytorch.org/](https://pytorch.org/)) is used for the deep learning model and GradCAM visualization.
+ - The CIFAR-10 dataset ([https://www.cs.toronto.edu/~kriz/cifar.html](https://www.cs.toronto.edu/~kriz/cifar.html)) is used for training and evaluation.
+
+ ## License
+
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
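The README above describes a two-tab UI, while the visible part of the app.py diff below only shows the first gr.Interface (demo1). The following is a hedged, self-contained sketch of how such a tabbed layout is typically assembled in Gradio; the callback lambdas and component choices here are placeholders for illustration, not code from this commit.

```python
import gradio as gr

# Hedged sketch of the two-tab layout the README describes. The real callback
# functions live in app.py; the lambdas below are placeholders so this snippet
# runs on its own.
examples_tab = gr.Interface(
    fn=lambda img, k, cam, layer, alpha: (img, {"cat": 1.0}),  # placeholder classifier
    inputs=[
        gr.Image(label="Input Image"),
        gr.Number(value=3, label="Number of Top Classes"),
        gr.Checkbox(value=True, label="Show GradCAM?"),
        gr.Slider(-4, -1, value=-2, step=1, label="Which Layer?"),
        gr.Slider(0, 1, value=0.7, step=0.1, label="Transparency"),
    ],
    outputs=[gr.Image(label="Output Image"), gr.Label(label="Top Classes")],
)

missed_tab = gr.Interface(
    fn=lambda n, cam, layer, alpha: [],  # placeholder gallery function
    inputs=[
        gr.Number(value=20, label="No. of Misclassified Examples"),
        gr.Checkbox(value=True, label="Show GradCAM?"),
        gr.Slider(-4, -1, value=-2, step=1, label="Which Layer?"),
        gr.Slider(0, 1, value=0.7, step=0.1, label="Transparency"),
    ],
    outputs=gr.Gallery(label="Misclassified Images"),
)

# Combine the two interfaces into the tab names the README uses.
gr.TabbedInterface([examples_tab, missed_tab],
                   ["Examples", "Misclassified Examples"]).launch()
```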
app.py CHANGED
@@ -11,18 +11,29 @@ from pytorch_grad_cam.utils.image import show_cam_on_image
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

  from models.custom_resnet import Model
- from datasets import CIFAR10

+ from utils import get_device
+
+ DEVICE = get_device()
+
+ classes = {0: 'airplane',
+            1: 'automobile',
+            2: 'bird',
+            3: 'cat',
+            4: 'deer',
+            5: 'dog',
+            6: 'frog',
+            7: 'horse',
+            8: 'ship',
+            9: 'truck'}

  missed_df = pd.read_csv('S12_incorrect.csv')
- missed_df['ground_truths'] = missed_df['ground_truths'].map( …
- missed_df['predicted_vals'] = missed_df['predicted_vals'].map( …
+ missed_df['ground_truths'] = missed_df['ground_truths'].map(classes)
+ missed_df['predicted_vals'] = missed_df['predicted_vals'].map(classes)
  missed_df = missed_df.sample(frac=1)

- model = Model( …
- model.load_state_dict(torch.load('S12_model.pth', map_location= …
+ model = Model()
+ model.load_state_dict(torch.load('S12_model.pth', map_location=DEVICE), strict=False)
  model.eval()

  transform = transforms.Compose([
@@ -34,20 +45,20 @@ inv_transform = transforms.Normalize(mean=[-2, -2, -2], std=[4, 4, 4])


  def image_classifier(input_image, top_classes=3, show_cam=True, target_layer=-2, transparency=0.7):
+     input_ = transform(input_image).unsqueeze(0)

-     output = model( …
+     output = model(input_)
      output = F.softmax(output.flatten(), dim=-1)

-     confidences = [( …
+     confidences = [(classes[i], float(output[i])) for i in range(10)]
      confidences.sort(key=lambda x: x[1], reverse=True)
      confidences = OrderedDict(confidences[:top_classes])

      label = torch.argmax(output).item()
      target_layer = [model.network[4 + target_layer]]
-     grad_cam = GradCAM(model=model, target_layers=target_layer, use_cuda= …
+     grad_cam = GradCAM(model=model, target_layers=target_layer, use_cuda=(DEVICE == 'cuda'))
      targets = [ClassifierOutputTarget(label)]
-     grayscale_cam = grad_cam(input_tensor= …
+     grayscale_cam = grad_cam(input_tensor=input_, targets=targets)
      grayscale_cam = grayscale_cam[0, :]
      output_image = show_cam_on_image(input_image / 255, grayscale_cam, use_rgb=True, image_weight=transparency)

@@ -63,9 +74,9 @@ demo1 = gr.Interface(
          gr.Slider(-4, -1, value=-2, step=1, label="Which Layer?"),
          gr.Slider(0, 1, value=0.7, label="Transparency", step=0.1)
      ],
-     outputs=[gr.Image(shape=(32, 32), label="Output Image"),
+     outputs=[gr.Image(shape=(32, 32), label="Output Image", height=256, width=256),
              gr.Label(label='Top Classes')],
-     examples=[[f'examples/{k}.jpg'] for k in …
+     examples=[[f'examples/{k}.jpg'] for k in classes.values()]
  )

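For readers who want to see the GradCAM flow of image_classifier() end to end, here is a minimal, self-contained sketch. It mirrors the library calls visible in app.py (GradCAM, ClassifierOutputTarget, show_cam_on_image) but substitutes a torchvision ResNet, a random 32x32 image, and guessed normalization constants for the Space's custom Model and uploaded input, so everything besides those calls is an illustrative assumption rather than this repo's code.

```python
import numpy as np
import torch
from torchvision import transforms
from torchvision.models import resnet18
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# Stand-in model for models.custom_resnet.Model (assumption for illustration).
model = resnet18(weights=None)
model.eval()

# Fake 32x32 RGB input standing in for an uploaded CIFAR-10 image.
rgb = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)

# mean=0.5, std=0.25 is a guess consistent with app.py's inv_transform
# Normalize(mean=[-2, -2, -2], std=[4, 4, 4]); the forward transform is not shown.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.25, 0.25, 0.25)),
])
input_ = transform(rgb).unsqueeze(0)

# Predicted class drives the GradCAM target, as in image_classifier().
with torch.no_grad():
    probs = torch.softmax(model(input_).flatten(), dim=-1)
label = int(probs.argmax())

# A late convolutional block is a typical GradCAM target layer.
cam = GradCAM(model=model, target_layers=[model.layer4[-1]])
grayscale = cam(input_tensor=input_, targets=[ClassifierOutputTarget(label)])[0, :]

# Blend heatmap and image; image_weight plays the role of the Transparency slider.
overlay = show_cam_on_image((rgb / 255.0).astype(np.float32), grayscale,
                            use_rgb=True, image_weight=0.7)
print(overlay.shape, label)
```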
datasets/__init__.py DELETED
@@ -1 +0,0 @@
- from .cifar10 import CIFAR10
datasets/cifar10.py DELETED
@@ -1,32 +0,0 @@
- import numpy as np
- import cv2
- from torchvision import datasets
- import albumentations as A
-
- from .generic import MyDataSet
-
-
- class AlbCIFAR10(datasets.CIFAR10):
-     def __init__(self, root, alb_transform=None, **kwargs):
-         super(AlbCIFAR10, self).__init__(root, **kwargs)
-         self.alb_transform = alb_transform
-
-     def __getitem__(self, index):
-         image, label = super(AlbCIFAR10, self).__getitem__(index)
-         if self.alb_transform is not None:
-             image = self.alb_transform(image=np.array(image))['image']
-         return image, label
-
-
- class CIFAR10(MyDataSet):
-     DataSet = AlbCIFAR10
-     mean = (0.49139968, 0.48215827, 0.44653124)
-     std = (0.24703233, 0.24348505, 0.26158768)
-     default_alb_transforms = [
-         A.ToGray(p=0.2),
-         A.PadIfNeeded(40, 40, p=1),
-         A.RandomCrop(32, 32, p=1),
-         A.HorizontalFlip(p=0.5),
-         # Since normalisation was the first step, mean is already 0, so cutout fill_value = 0
-         A.CoarseDropout(max_holes=1, max_height=8, max_width=8, fill_value=0, p=1)
-     ]
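For context on what this removal drops: the CIFAR10 class above fed its default_alb_transforms through MyDataSet.get_train_transforms() (see datasets/generic.py below), which normalizes first and converts to a tensor last. The following is a hedged, self-contained reconstruction of that training pipeline applied to a stand-in image; the random input is only for illustration.

```python
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2

# CIFAR-10 statistics from the deleted CIFAR10 class.
mean = (0.49139968, 0.48215827, 0.44653124)
std = (0.24703233, 0.24348505, 0.26158768)

# get_train_transforms() order: Normalize, then default_alb_transforms, then ToTensorV2.
train_tf = A.Compose([
    A.Normalize(mean, std),
    A.ToGray(p=0.2),
    A.PadIfNeeded(40, 40, p=1),
    A.RandomCrop(32, 32, p=1),
    A.HorizontalFlip(p=0.5),
    # Cutout with fill_value=0, i.e. the post-normalization mean.
    A.CoarseDropout(max_holes=1, max_height=8, max_width=8, fill_value=0, p=1),
    ToTensorV2(),
])

img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)  # stand-in CIFAR-10 image
tensor = train_tf(image=img)['image']
print(tensor.shape)  # torch.Size([3, 32, 32])
```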
datasets/generic.py DELETED
@@ -1,111 +0,0 @@
- import os
- from abc import ABC
- from functools import cached_property
-
- import torch
- import albumentations as A
- from albumentations.pytorch import ToTensorV2
-
- try:
-     from epoch.utils import plot_examples
- except ModuleNotFoundError:
-     from utils import plot_examples
-
-
- class MyDataSet(ABC):
-     DataSet = None
-     mean = None
-     std = None
-     classes = None
-     default_alb_transforms = None
-
-     def __init__(self, batch_size=1, normalize=True, shuffle=True, augment=True, alb_transforms=None):
-         self.batch_size = batch_size
-         self.normalize = normalize
-         self.shuffle = shuffle
-         self.augment = augment
-         self.alb_transforms = alb_transforms or self.default_alb_transforms
-
-         self.loader_kwargs = {'batch_size': batch_size, 'num_workers': os.cpu_count(), 'pin_memory': True}
-
-     @classmethod
-     def set_classes(cls, data):
-         if cls.classes is None:
-             cls.classes = {i: c for i, c in enumerate(data.classes)}
-
-     @cached_property
-     def train_data(self):
-         res = self.DataSet('../data', train=True, download=True, alb_transform=self.get_train_transforms())
-         self.set_classes(res)
-         return res
-
-     @cached_property
-     def test_data(self):
-         res = self.DataSet('../data', train=False, download=True, alb_transform=self.get_test_transforms())
-         self.set_classes(res)
-         return res
-
-     @cached_property
-     def train_loader(self):
-         return torch.utils.data.DataLoader(self.train_data, shuffle=self.shuffle, **self.loader_kwargs)
-
-     @cached_property
-     def test_loader(self):
-         return torch.utils.data.DataLoader(self.test_data, shuffle=False, **self.loader_kwargs)
-
-     @cached_property
-     def example_iter(self):
-         return iter(self.train_loader)
-
-     def get_train_transforms(self):
-         all_transforms = list()
-         if self.normalize:
-             all_transforms.append(A.Normalize(self.mean, self.std))
-         if self.augment and self.alb_transforms is not None:
-             all_transforms.extend(self.alb_transforms)
-         all_transforms.append(ToTensorV2())
-         return A.Compose(all_transforms)
-
-     def get_test_transforms(self):
-         all_transforms = list()
-         if self.normalize:
-             all_transforms.append(A.Normalize(self.mean, self.std))
-         all_transforms.append(ToTensorV2())
-         return A.Compose(all_transforms)
-
-     def download(self):
-         self.DataSet('../data', train=True, download=True)
-         self.DataSet('../data', train=False, download=True)
-
-     def denormalise(self, tensor):
-         result = tensor.clone().detach().requires_grad_(False)
-         if self.normalize:
-             for t, m, s in zip(result, self.mean, self.std):
-                 t.mul_(s).add_(m)
-         return result
-
-     def show_transform(self, img):
-         if self.normalize:
-             img = self.denormalise(img)
-         if len(self.mean) == 3:
-             return img.permute(1, 2, 0)
-         else:
-             return img.squeeze(0)
-
-     def show_examples(self, figsize=(8, 6)):
-         batch_data, batch_label = next(self.example_iter)
-         images = list()
-         labels = list()
-
-         for i in range(len(batch_data)):
-             image = batch_data[i]
-             image = self.show_transform(image)
-
-             label = batch_label[i].item()
-             if self.classes is not None:
-                 label = f'{label}:{self.classes[label]}'
-
-             images.append(image)
-             labels.append(label)
-
-         plot_examples(images, labels, figsize=figsize)
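Before this commit, training-side code could pull loaders straight from this abstraction. A short usage sketch follows; it is only valid against the parent revision d0c355b, since both dataset files are deleted here, and the batch size is an arbitrary choice for illustration.

```python
# Usage sketch for the removed dataset abstraction (parent revision only).
from datasets import CIFAR10            # package deleted in this commit

data = CIFAR10(batch_size=512)          # downloads CIFAR-10 into ../data on first access
images, labels = next(iter(data.train_loader))
print(images.shape, labels.shape)       # e.g. torch.Size([512, 3, 32, 32]) torch.Size([512])
data.show_examples(figsize=(8, 6))      # grid of augmented samples via plot_examples
```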
models/custom_resnet.py CHANGED
@@ -1,10 +1,4 @@
  from torch import nn
- from torch import optim
- from pytorch_lightning import LightningModule
- from torchmetrics import MeanMetric
- from torch_lr_finder import LRFinder
-
- from utils.metrics import RunningAccuracy


  class ConvLayer(nn.Module):
@@ -51,12 +45,10 @@ class CustomLayer(nn.Module):
          return x


- class Model(LightningModule):
-     def __init__(self, …
+ class Model(nn.Module):
+     def __init__(self, dropout=0.05):
          super(Model, self).__init__()

-         self.dataset = dataset
-
          self.network = nn.Sequential(
              CustomLayer(3, 64, pool=False, residue=0, dropout=dropout),
              CustomLayer(64, 128, pool=True, residue=2, dropout=dropout),
@@ -67,93 +59,5 @@ class Model(LightningModule):
              nn.Linear(512, 10)
          )

-         self.criterion = nn.CrossEntropyLoss()
-         self.train_accuracy = RunningAccuracy()
-         self.val_accuracy = RunningAccuracy()
-         self.train_loss = MeanMetric()
-         self.val_loss = MeanMetric()
-
-         self.max_epochs = max_epochs
-         self.epoch_counter = 1
-
      def forward(self, x):
          return self.network(x)
-
-     def common_step(self, batch, loss_metric, acc_metric):
-         x, y = batch
-         batch_len = y.numel()
-         logits = self.forward(x)
-         loss = self.criterion(logits, y)
-         loss_metric.update(loss, batch_len)
-         acc_metric.update(logits, y)
-         return loss
-
-     def training_step(self, batch, batch_idx):
-         return self.common_step(batch, self.train_loss, self.train_accuracy)
-
-     def on_train_epoch_end(self):
-         print(f"Epoch: {self.epoch_counter}, Train: Loss: {self.train_loss.compute():0.4f}, Accuracy: "
-               f"{self.train_accuracy.compute():0.2f}")
-         self.train_loss.reset()
-         self.train_accuracy.reset()
-         self.epoch_counter += 1
-
-     def validation_step(self, batch, batch_idx):
-         loss = self.common_step(batch, self.val_loss, self.val_accuracy)
-         self.log("val_step_loss", self.val_loss, prog_bar=True, logger=True)
-         self.log("val_step_acc", self.val_accuracy, prog_bar=True, logger=True)
-         return loss
-
-     def on_validation_epoch_end(self):
-         print(f"Epoch: {self.epoch_counter}, Valid: Loss: {self.val_loss.compute():0.4f}, Accuracy: "
-               f"{self.val_accuracy.compute():0.2f}")
-         self.val_loss.reset()
-         self.val_accuracy.reset()
-
-     def predict_step(self, batch, batch_idx, dataloader_idx=0):
-         if isinstance(batch, list):
-             x, _ = batch
-         else:
-             x = batch
-         return self.forward(x)
-
-     def find_lr(self, optimizer):
-         lr_finder = LRFinder(self, optimizer, self.criterion)
-         lr_finder.range_test(self.dataset.train_loader, end_lr=0.1, num_iter=100, step_mode='exp')
-         _, best_lr = lr_finder.plot()
-         lr_finder.reset()
-         return best_lr
-
-     def configure_optimizers(self):
-         optimizer = optim.Adam(self.parameters(), lr=1e-7, weight_decay=1e-2)
-         best_lr = self.find_lr(optimizer)
-         scheduler = optim.lr_scheduler.OneCycleLR(
-             optimizer,
-             max_lr=best_lr,
-             steps_per_epoch=len(self.dataset.train_loader),
-             epochs=self.max_epochs,
-             pct_start=5/self.max_epochs,
-             div_factor=100,
-             three_phase=False,
-             final_div_factor=100,
-             anneal_strategy='linear'
-         )
-         return {
-             'optimizer': optimizer,
-             'lr_scheduler': {
-                 "scheduler": scheduler,
-                 "interval": "step",
-             }
-         }
-
-     def prepare_data(self):
-         self.dataset.download()
-
-     def train_dataloader(self):
-         return self.dataset.train_loader
-
-     def val_dataloader(self):
-         return self.dataset.test_loader
-
-     def predict_dataloader(self):
-         return self.val_dataloader()
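With the Lightning plumbing removed, Model is now a plain nn.Module, and app.py loads the existing checkpoint with strict=False. A hedged sketch of that load, including how to inspect which keys get skipped; the reason for strict=False is presumably leftover keys from the removed Lightning-era attributes, though this diff does not say so explicitly.

```python
import torch
from models.custom_resnet import Model   # the slimmed-down nn.Module above

model = Model(dropout=0.05)
state = torch.load('S12_model.pth', map_location='cpu')

# strict=False mirrors app.py; the returned tuple reports anything skipped
# instead of raising on mismatched keys.
missing, unexpected = model.load_state_dict(state, strict=False)
print('missing keys:', missing)
print('unexpected keys:', unexpected)
model.eval()
```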
utils/metrics.py DELETED
@@ -1,19 +0,0 @@
- import torch
- from torch import Tensor
- from torchmetrics import Metric
-
-
- class RunningAccuracy(Metric):
-     def __init__(self, **kwargs):
-         super().__init__(**kwargs)
-         self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
-         self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
-
-     def update(self, preds: Tensor, target: Tensor):
-         preds = preds.argmax(dim=1)
-         total = target.numel()
-         self.correct += preds.eq(target).sum()
-         self.total += total
-
-     def compute(self):
-         return 100 * self.correct.float() / self.total
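The removed RunningAccuracy simply accumulates correct predictions and element counts across batches. If something equivalent is still needed without the torchmetrics dependency, a plain-PyTorch version looks like the sketch below; the random batches are stand-in data for illustration only.

```python
import torch

# Accumulate accuracy the way the removed RunningAccuracy did:
# argmax over logits, count matches, divide by total elements seen.
correct, total = 0, 0
batches = [(torch.randn(8, 10), torch.randint(0, 10, (8,))) for _ in range(3)]  # stand-in data
for logits, target in batches:
    correct += logits.argmax(dim=1).eq(target).sum().item()
    total += target.numel()
print(f"accuracy: {100 * correct / total:.2f}%")
```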