File size: 2,085 Bytes
32e703d
 
 
448703c
32e703d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6a5948
 
 
 
32e703d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448703c
32e703d
 
 
 
 
 
 
 
 
448703c
32e703d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import gradio as gr
import numpy as np
import onnxruntime as ort
import os
import torchvision as tv
from huggingface_hub import hf_hub_download

# Class labels, paired positionally (via zip) with the model's output scores
# in Classifier.predict, so the order of this list must match the order of
# the scores the model emits.
# NOTE(review): these look like the 21 UC Merced land-use classes in
# alphabetical order -- confirm the order against the label encoding used
# when the model was trained.
CATEGORIES = [
    "agricultural",
    "airplane",
    "baseballdiamond",
    "beach",
    "buildings",
    "chaparral",
    "denseresidential",
    "forest",
    "freeway",
    "golfcourse",
    "harbor",
    "intersection",
    "mediumresidential",
    "mobilehomepark",
    "overpass",
    "parkinglot",
    "river",
    "runway",
    "sparseresidential",
    "storagetanks",
    "tenniscourt",
]


class Classifier:
    """ONNX-backed image classifier that scores an image against ``CATEGORIES``.

    The ONNX model is loaded once at construction time. ``predict`` then
    preprocesses a single image, runs inference and returns one probability
    per category.
    """

    def __init__(self, model_path):
        """Load the ONNX model and build the preprocessing pipeline.

        Args:
            model_path: Filesystem path of the ``.onnx`` model file.
        """
        self.model_path = model_path
        self.session = ort.InferenceSession(
            self.model_path,
            providers=["AzureExecutionProvider", "CPUExecutionProvider"],
        )

        # Resize to 256x256, convert to a float tensor, then normalise each
        # of the three channels with the given (mean, std) pairs.
        # NOTE(review): the mean/std values are presumably the statistics of
        # the training set -- confirm before changing them.
        self.img_transforms = tv.transforms.Compose(
            [
                tv.transforms.Resize((256, 256)),
                tv.transforms.ToTensor(),
                tv.transforms.Normalize(
                    (0.48422758, 0.49005175, 0.45050276),
                    (0.17348297, 0.16352356, 0.15547496),
                ),
            ]
        )

    @staticmethod
    def _softmax(logits):
        """Return the softmax of *logits* along the last axis.

        The row maximum is subtracted before exponentiating so that large
        logits cannot overflow ``np.exp`` and turn the result into NaN.
        """
        logits = np.asarray(logits, dtype=np.float64)
        shifted = logits - np.max(logits, axis=-1, keepdims=True)
        exps = np.exp(shifted)
        return exps / np.sum(exps, axis=-1, keepdims=True)

    def predict(self, image):
        """Classify a single image.

        Args:
            image: PIL image to classify. Images that are not in RGB mode
                (e.g. grayscale or RGBA) are converted to RGB first, because
                the normalisation step is defined for exactly three channels.

        Returns:
            A dict mapping each name in ``CATEGORIES`` to its predicted
            probability as a Python ``float``.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Add a leading batch axis of size 1 and hand ONNX Runtime a NumPy array.
        inp = self.img_transforms(image).unsqueeze(0).numpy()
        # Passing None as the output list requests every model output; the
        # first output is treated as the class logits.
        logits = self.session.run(
            None,
            {self.session.get_inputs()[0].name: inp},
        )[0]
        probs = self._softmax(logits)
        # probs[0] is the row for the single image in the batch.
        return {
            category: float(prob)
            for category, prob in zip(
                CATEGORIES,
                probs[0],
            )
        }


# Example images offered in the UI. os.listdir returns entries in arbitrary
# order, so sort them to keep the example list stable across runs.
test_images = sorted(os.listdir("UCMercedTestImages"))

# Fetch the ONNX model file from the Hugging Face Hub and use the path
# returned by hf_hub_download to build the classifier.
model_path = hf_hub_download(
    repo_id="SatwikKambham/land_use_classifier",
    filename="model.onnx",
)
classifier = Classifier(model_path)

# Image in (delivered to predict as a PIL image), top-3 class labels out.
interface = gr.Interface(
    fn=classifier.predict,
    inputs=gr.components.Image(label="Input image", type="pil"),
    outputs=gr.components.Label(label="Predicted class", num_top_classes=3),
    examples=[[f"UCMercedTestImages/{test_image}"] for test_image in test_images],
)
interface.launch()