#!/usr/bin/env python
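"""Gradio demo of the tagger model from ShuhongChen/bizarre-pose-estimator."""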

from __future__ import annotations

import functools
import os
import pathlib
import sys
import tarfile

import gradio as gr
import huggingface_hub
import PIL.Image
import torch
import torchvision

# Make the vendored bizarre-pose-estimator repository importable.
sys.path.insert(0, 'bizarre-pose-estimator')

from _util.twodee_v0 import I as ImageWrapper

TITLE = 'ShuhongChen/bizarre-pose-estimator (tagger)'
DESCRIPTION = 'This is an unofficial demo for https://github.com/ShuhongChen/bizarre-pose-estimator.'

HF_TOKEN = os.getenv('HF_TOKEN')
MODEL_REPO = 'hysts/bizarre-pose-estimator-models'
MODEL_FILENAME = 'tagger.pth'
LABEL_FILENAME = 'tags.txt'


def load_sample_image_paths() -> list[pathlib.Path]:
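    """Return sorted paths to the TADNE sample images, downloading and
    extracting the dataset archive from the Hugging Face Hub on first use."""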
    image_dir = pathlib.Path('images')
    if not image_dir.exists():
        dataset_repo = 'hysts/sample-images-TADNE'
        path = huggingface_hub.hf_hub_download(dataset_repo,
                                               'images.tar.gz',
                                               repo_type='dataset',
                                               use_auth_token=HF_TOKEN)
        with tarfile.open(path) as f:
            f.extractall()
    return sorted(image_dir.glob('*'))


def load_model(device: torch.device) -> torch.nn.Module:
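    """Download the tagger checkpoint and return a ResNet-50 with 1062 tag
    classes, weights loaded, moved to `device`, and set to eval mode."""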
    path = huggingface_hub.hf_hub_download(MODEL_REPO,
                                           MODEL_FILENAME,
                                           use_auth_token=HF_TOKEN)
    # Map the weights to the target device so a CPU-only host can load a
    # checkpoint that was saved on GPU.
    state_dict = torch.load(path, map_location=device)
    model = torchvision.models.resnet50(num_classes=1062)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def load_labels() -> list[str]:
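    """Download the tag list and return one label string per line."""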
    label_path = huggingface_hub.hf_hub_download(MODEL_REPO,
                                                 LABEL_FILENAME,
                                                 use_auth_token=HF_TOKEN)
    with open(label_path) as f:
        labels = [line.strip() for line in f]
    return labels


@torch.inference_mode()
def predict(image: PIL.Image.Image, score_threshold: float,
            device: torch.device, model: torch.nn.Module,
            labels: list[str]) -> dict[str, float]:
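    """Preprocess the image, run the tagger, and return a mapping from tag
    to sigmoid score for every tag whose score meets the threshold."""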
    data = ImageWrapper(image).resize_square(256).alpha_bg(
        c='w').convert('RGB').tensor()
    data = data.to(device).unsqueeze(0)

    preds = model(data)[0]
    preds = torch.sigmoid(preds)
    preds = preds.cpu().numpy().astype(float)

    # Keep only the tags whose score meets the threshold.
    return {
        label: prob
        for prob, label in zip(preds.tolist(), labels)
        if prob >= score_threshold
    }


image_paths = load_sample_image_paths()
examples = [[path.as_posix(), 0.5] for path in image_paths]

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = load_model(device)
labels = load_labels()

func = functools.partial(predict, device=device, model=model, labels=labels)

gr.Interface(
    fn=func,
    inputs=[
        gr.Image(label='Input', type='pil'),
        gr.Slider(label='Score Threshold',
                  minimum=0,
                  maximum=1,
                  step=0.05,
                  value=0.5),
    ],
    outputs=gr.Label(label='Output'),
    examples=examples,
    title=TITLE,
    description=DESCRIPTION,
).queue().launch(show_api=False)