import gradio as gr
from huggingface_hub import hf_hub_download
import torch
import torch.nn as nn
from torchvision import transforms


def _block_stride(in_channels, out_channels, set_stride):
    """Return 2 when the block changes width and striding is enabled, else 1."""
    if set_stride and in_channels != out_channels:
        return 2
    return 1


def _projection_shortcut(out_channels, stride):
    """Build the 1x1 convolution + batch-norm used on the skip path."""
    return nn.Sequential(
        nn.LazyConv2d(out_channels, kernel_size=1, stride=stride),
        nn.LazyBatchNorm2d(),
    )


class SimpleResidualBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    Args:
        in_channels: Nominal input width. The layers are lazy, so this value
            is only compared with ``out_channels`` to choose the stride and
            the kind of skip connection.
        out_channels: Number of channels produced by the block.
        set_stride: When true and the widths differ, the first convolution
            and the skip projection use stride 2.
    """

    def __init__(self, in_channels, out_channels, set_stride=False):
        super().__init__()
        stride = _block_stride(in_channels, out_channels, set_stride)
        # "same" padding is only requested for stride 1; the strided case
        # pads by one pixel explicitly.
        first_padding = 1 if stride != 1 else "same"

        self.conv1 = nn.LazyConv2d(
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=first_padding,
        )
        self.conv2 = nn.LazyConv2d(out_channels, kernel_size=3, padding="same")
        self.bn1 = nn.LazyBatchNorm2d()
        self.bn2 = nn.LazyBatchNorm2d()
        self.relu = nn.ReLU()

        if in_channels == out_channels:
            self.residual = nn.Identity()
        else:
            self.residual = _projection_shortcut(out_channels, stride)

    def forward(self, x):
        """Run the two-convolution path, add the skip path, apply ReLU."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # The skip path is evaluated after the main path, as in the
        # original, so lazy layers materialise in the same order.
        out += self.residual(x)
        return self.relu(out)


class BottleneckResidualBlock(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack.

    The last convolution emits ``out_channels * 4`` channels.

    Args:
        in_channels: Nominal input width, used only for the stride and
            skip-connection decisions (the layers themselves are lazy).
        out_channels: Width of the first two convolutions.
        identity_mapping: When true and the widths are equal, the skip path
            is ``nn.Identity``; otherwise it is a 1x1 projection.
        set_stride: When true and the widths differ, the first convolution
            and the skip projection use stride 2.
    """

    def __init__(
        self, in_channels, out_channels, identity_mapping=False, set_stride=False
    ):
        super().__init__()
        stride = _block_stride(in_channels, out_channels, set_stride)
        expanded_channels = out_channels * 4
        first_padding = 0 if stride != 1 else "same"

        self.conv1 = nn.LazyConv2d(
            out_channels,
            kernel_size=1,
            stride=stride,
            padding=first_padding,
        )
        self.conv2 = nn.LazyConv2d(out_channels, kernel_size=3, padding="same")
        self.conv3 = nn.LazyConv2d(expanded_channels, kernel_size=1, padding="same")
        self.bn1 = nn.LazyBatchNorm2d()
        self.bn2 = nn.LazyBatchNorm2d()
        self.bn3 = nn.LazyBatchNorm2d()
        self.relu = nn.ReLU()

        # NOTE(review): the identity branch passes x through unchanged while
        # the main path emits out_channels * 4 channels, so the addition in
        # forward looks shape-incompatible whenever in_channels is the real
        # input width -- confirm before relying on identity_mapping=True.
        if identity_mapping and in_channels == out_channels:
            self.residual = nn.Identity()
        else:
            self.residual = _projection_shortcut(expanded_channels, stride)

    def forward(self, x):
        """Run the three-convolution path, add the skip path, apply ReLU."""
        out = self.conv1(x)
        out = self.relu(self.bn1(out))
        out = self.conv2(out)
        out = self.relu(self.bn2(out))
        out = self.conv3(out)
        out = self.bn3(out)
        out += self.residual(x)
        return self.relu(out)


# Number of residual blocks in each of the four stages.
RESNET_18 = [2, 2, 2, 2]
RESNET_34 = [3, 4, 6, 3]
RESNET_50 = [3, 4, 6, 3]
RESNET_101 = [3, 4, 23, 3]
RESNET_152 = [3, 8, 36, 3]


class ResNet(nn.Module):
    """ResNet-style image classifier assembled from lazily-shaped layers.

    Args:
        arch: Number of residual blocks in each of the four stages.
        block: ``"simple"`` selects ``SimpleResidualBlock``; any other value
            selects ``BottleneckResidualBlock``.
        num_classes: Output size of the final linear layer.
    """

    def __init__(self, arch=RESNET_18, block="simple", num_classes=256):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.LazyConv2d(64, kernel_size=7, stride=2, padding=3),
            nn.LazyBatchNorm2d(),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        self.conv2 = self._make_layer(64, 64, arch[0], set_stride=False, block=block)
        self.conv3 = self._make_layer(64, 128, arch[1], block=block)
        self.conv4 = self._make_layer(128, 256, arch[2], block=block)
        self.conv5 = self._make_layer(256, 512, arch[3], block=block)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.fc = nn.LazyLinear(num_classes)

    def _make_layer(
        self, in_channels, out_channels, num_blocks, set_stride=True, block="simple"
    ):
        """Stack ``num_blocks`` residual blocks into one ``nn.Sequential`` stage.

        Block is either 'simple' or 'bottleneck'. Only the first block of a
        stage receives ``set_stride``; it is cleared for the remaining ones.
        """
        layers = []
        for i in range(num_blocks):
            # NOTE(review): simple blocks are all constructed with the
            # stage's nominal in_channels, so every block of a widening stage
            # gets a projection shortcut. Kept as-is because it determines
            # the state-dict keys that load_state_dict expects.
            layers.append(
                SimpleResidualBlock(in_channels, out_channels, set_stride=set_stride)
                if block == "simple"
                else BottleneckResidualBlock(
                    in_channels if i == 0 else out_channels * 4,
                    out_channels,
                    set_stride=set_stride,
                )
            )
            set_stride = False
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the ``num_classes`` logits for a batch of images."""
        out = self.conv1(x)
        # NOTE(review): max-pooling is applied after the first residual
        # stage, not before it. Left unchanged; presumably the published
        # checkpoint was trained with this ordering -- confirm before moving.
        out = self.maxpool(self.conv2(out))
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.avgpool(out)
        out = self.flatten(out)
        out = self.fc(out)
        return out


def _init_weights(module):
    """Initialize conv/linear weights with Glorot uniform and zero the biases.

    Modules of any other type are left untouched.
    """
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(module.weight)
        # Layers created with bias=False expose bias as None; zeroing it
        # unconditionally would raise.
        if module.bias is not None:
            nn.init.zeros_(module.bias)


class ImageClassifier:
    """CPU inference wrapper around a ResNet-18 checkpoint with ten labels.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load`` that holds
            the model's state dict.
    """

    def __init__(self, checkpoint_path):
        self.checkpoint_path = checkpoint_path
        self.model = self.load_model(checkpoint_path)
        # NOTE(review): 244 may be a typo for the conventional 224. Left
        # unchanged because the resize determines what the model sees.
        self.transform = self.get_transform((244, 244))
        self.labels = [
            "airplane",
            "automobile",
            "bird",
            "cat",
            "deer",
            "dog",
            "frog",
            "horse",
            "ship",
            "truck",
        ]

    def load_model(self, checkpoint_path):
        """Build a ResNet-18, load its weights on the CPU, return it in eval mode."""
        classifier = ResNet(
            arch=RESNET_18,
            block="simple",
            num_classes=10,
        )
        # map_location lets a checkpoint that was saved from a GPU load on a
        # CPU-only host; the model is moved to the CPU right after anyway.
        # NOTE(security): torch.load unpickles the file, and this checkpoint
        # is downloaded. Consider weights_only=True where the installed torch
        # version supports it.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        classifier.load_state_dict(state_dict)
        classifier = classifier.cpu()
        classifier.eval()
        return classifier

    def get_transform(self, img_shape):
        """Return a transform that resizes to ``img_shape`` and converts to a tensor."""
        preprocess_transform = transforms.Compose(
            [
                transforms.Resize(img_shape),
                transforms.ToTensor(),
            ]
        )
        return preprocess_transform

    def predict(self, image):
        """Return a ``{label: probability}`` dict for one PIL image.

        Images in a mode other than RGB (for example grayscale or RGBA) are
        converted first, since the tensor's channel count follows the mode.
        """
        if getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")
        image_tensor = self.transform(image).unsqueeze(0)
        with torch.no_grad():
            logits = self.model(image_tensor)
        probs = logits.softmax(dim=1)[0]
        return {label: prob.item() for label, prob in zip(self.labels, probs)}

    def classify(self, input_image):
        """Alias for :meth:`predict`."""
        return self.predict(input_image)


def classify(input_image):
    """Gradio callback: classify with the module-level ``classifier``."""
    return classifier.classify(input_image)


checkpoint_path = hf_hub_download(
    repo_id="SatwikKambham/resnet18-cifar10",
    filename="model.pt",
)
classifier = ImageClassifier(checkpoint_path)

iface = gr.Interface(
    classify,
    inputs=[
        gr.Image(label="Input Image", type="pil"),
    ],
    outputs=gr.Label(num_top_classes=3),
)
iface.launch()