hysts HF Staff commited on
Commit
7ad735d
·
1 Parent(s): 77bce91
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +33 -35
  3. style.css +3 -0
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🐠
4
  colorFrom: green
5
  colorTo: blue
6
  sdk: gradio
7
- sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: green
5
  colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 3.34.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -15,13 +15,9 @@ import PIL.Image
15
  import torch
16
  import torchvision.transforms as T
17
 
18
- TITLE = 'RF5/danbooru-pretrained'
19
- DESCRIPTION = 'This is an unofficial demo for https://github.com/RF5/danbooru-pretrained.'
20
 
21
- HF_TOKEN = os.getenv('HF_TOKEN')
22
- MODEL_REPO = 'hysts/danbooru-pretrained'
23
- MODEL_FILENAME = 'resnet50-13306192.pth'
24
- LABEL_FILENAME = 'class_names_6000.json'
25
 
26
 
27
  def load_sample_image_paths() -> list[pathlib.Path]:
@@ -30,17 +26,14 @@ def load_sample_image_paths() -> list[pathlib.Path]:
30
  dataset_repo = 'hysts/sample-images-TADNE'
31
  path = huggingface_hub.hf_hub_download(dataset_repo,
32
  'images.tar.gz',
33
- repo_type='dataset',
34
- use_auth_token=HF_TOKEN)
35
  with tarfile.open(path) as f:
36
  f.extractall()
37
  return sorted(image_dir.glob('*'))
38
 
39
 
40
  def load_model(device: torch.device) -> torch.nn.Module:
41
- path = huggingface_hub.hf_hub_download(MODEL_REPO,
42
- MODEL_FILENAME,
43
- use_auth_token=HF_TOKEN)
44
  state_dict = torch.load(path)
45
  model = torch.hub.load('RF5/danbooru-pretrained',
46
  'resnet50',
@@ -52,9 +45,7 @@ def load_model(device: torch.device) -> torch.nn.Module:
52
 
53
 
54
  def load_labels() -> list[str]:
55
- path = huggingface_hub.hf_hub_download(MODEL_REPO,
56
- LABEL_FILENAME,
57
- use_auth_token=HF_TOKEN)
58
  with open(path) as f:
59
  labels = json.load(f)
60
  return labels
@@ -91,24 +82,31 @@ transform = T.Compose([
91
  T.Normalize(mean=[0.7137, 0.6628, 0.6519], std=[0.2970, 0.3017, 0.2979]),
92
  ])
93
 
94
- func = functools.partial(predict,
95
- transform=transform,
96
- device=device,
97
- model=model,
98
- labels=labels)
99
-
100
- gr.Interface(
101
- fn=func,
102
- inputs=[
103
- gr.Image(label='Input', type='pil'),
104
- gr.Slider(label='Score Threshold',
105
- minimum=0,
106
- maximum=1,
107
- step=0.05,
108
- value=0.4),
109
- ],
110
- outputs=gr.Label(label='Output'),
111
- examples=examples,
112
- title=TITLE,
113
- description=DESCRIPTION,
114
- ).queue().launch(show_api=False)
 
 
 
 
 
 
 
 
15
  import torch
16
  import torchvision.transforms as T
17
 
18
+ DESCRIPTION = '# [RF5/danbooru-pretrained](https://github.com/RF5/danbooru-pretrained)'
 
19
 
20
+ MODEL_REPO = 'public-data/danbooru-pretrained'
 
 
 
21
 
22
 
23
  def load_sample_image_paths() -> list[pathlib.Path]:
 
26
  dataset_repo = 'hysts/sample-images-TADNE'
27
  path = huggingface_hub.hf_hub_download(dataset_repo,
28
  'images.tar.gz',
29
+ repo_type='dataset')
 
30
  with tarfile.open(path) as f:
31
  f.extractall()
32
  return sorted(image_dir.glob('*'))
33
 
34
 
35
  def load_model(device: torch.device) -> torch.nn.Module:
36
+ path = huggingface_hub.hf_hub_download(MODEL_REPO, 'resnet50-13306192.pth')
 
 
37
  state_dict = torch.load(path)
38
  model = torch.hub.load('RF5/danbooru-pretrained',
39
  'resnet50',
 
45
 
46
 
47
  def load_labels() -> list[str]:
48
+ path = huggingface_hub.hf_hub_download(MODEL_REPO, 'class_names_6000.json')
 
 
49
  with open(path) as f:
50
  labels = json.load(f)
51
  return labels
 
82
  T.Normalize(mean=[0.7137, 0.6628, 0.6519], std=[0.2970, 0.3017, 0.2979]),
83
  ])
84
 
85
+ fn = functools.partial(predict,
86
+ transform=transform,
87
+ device=device,
88
+ model=model,
89
+ labels=labels)
90
+
91
+ with gr.Blocks(css='style.css') as demo:
92
+ gr.Markdown(DESCRIPTION)
93
+ with gr.Row():
94
+ with gr.Column():
95
+ image = gr.Image(label='Input', type='pil')
96
+ threshold = gr.Slider(label='Score Threshold',
97
+ minimum=0,
98
+ maximum=1,
99
+ step=0.05,
100
+ value=0.4)
101
+ run_button = gr.Button('Run')
102
+ with gr.Column():
103
+ result = gr.Label(label='Output')
104
+
105
+ inputs = [image, threshold]
106
+ gr.Examples(examples=examples,
107
+ inputs=inputs,
108
+ outputs=result,
109
+ fn=fn,
110
+ cache_examples=os.getenv('CACHE_EXAMPLES') == '1')
111
+ run_button.click(fn=fn, inputs=inputs, outputs=result, api_name='predict')
112
+ demo.queue(max_size=15).launch()
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }