cella110n commited on
Commit
e0ed6cc
·
verified ·
1 Parent(s): e0492a1

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -12
app.py CHANGED
@@ -1,27 +1,192 @@
1
  import gradio as gr
2
- import spaces
3
- import time
 
 
4
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- # --- Simple Test Function ---
7
- @spaces.GPU()
8
- def test_button_click():
9
- current_time = time.time()
10
- print(f"--- Test button clicked on GPU worker at {current_time} ---")
11
- return f"Test button clicked at {current_time}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # --- Gradio Interface Definition (Minimal) ---
14
  with gr.Blocks() as demo:
15
  gr.Markdown("""
16
- # Minimal Button Test for ZeroGPU Environment
17
- Click the button below to check if the `@spaces.GPU` decorated function is triggered.
 
18
  """)
19
  with gr.Column():
20
- test_button = gr.Button("Test GPU Button")
21
  output_text = gr.Textbox(label="Output")
22
 
23
  test_button.click(
24
- fn=test_button_click,
25
  inputs=[],
26
  outputs=[output_text]
27
  )
@@ -29,4 +194,43 @@ with gr.Blocks() as demo:
29
  # --- Main Block ---
30
  if __name__ == "__main__":
31
  if not os.environ.get("HF_TOKEN"): print("Warning: HF_TOKEN environment variable not set.")
 
32
  demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image # Keep PIL for now, might be needed by helpers implicitly
4
+ # from PIL import Image, ImageDraw, ImageFont # No drawing yet
5
+ import json
6
  import os
7
+ import io
8
+ import requests
9
+ # import matplotlib.pyplot as plt # No plotting yet
10
+ # import matplotlib # No plotting yet
11
+ from huggingface_hub import hf_hub_download
12
+ from dataclasses import dataclass
13
+ from typing import List, Dict, Optional, Tuple
14
+ import time
15
+ import spaces # Required for @spaces.GPU
16
+
17
+ import torch # Keep torch for device check in Tagger
18
+ # import timm # No model loading yet
19
+ # from safetensors.torch import load_file as safe_load_file # No model loading yet
20
+
21
+ # MatplotlibのバックエンドをAggに設定 (Keep commented out for now)
22
+ # matplotlib.use('Agg')
23
+
24
+ # --- Data Classes and Helper Functions ---
25
+ @dataclass
26
+ class LabelData:
27
+ names: list[str]
28
+ rating: list[np.int64]
29
+ general: list[np.int64]
30
+ artist: list[np.int64]
31
+ character: list[np.int64]
32
+ copyright: list[np.int64]
33
+ meta: list[np.int64]
34
+ quality: list[np.int64]
35
+
36
+ # Keep helpers needed for initialization
37
+ def load_tag_mapping(mapping_path):
38
+ with open(mapping_path, 'r', encoding='utf-8') as f: tag_mapping_data = json.load(f)
39
+ if isinstance(tag_mapping_data, dict) and "idx_to_tag" in tag_mapping_data:
40
+ idx_to_tag = {int(k): v for k, v in tag_mapping_data["idx_to_tag"].items()}
41
+ tag_to_category = tag_mapping_data["tag_to_category"]
42
+ elif isinstance(tag_mapping_data, dict):
43
+ tag_mapping_data = {int(k): v for k, v in tag_mapping_data.items()}
44
+ idx_to_tag = {idx: data['tag'] for idx, data in tag_mapping_data.items()}
45
+ tag_to_category = {data['tag']: data['category'] for data in tag_mapping_data.values()}
46
+ else: raise ValueError("Unsupported tag mapping format")
47
+ names = [None] * (max(idx_to_tag.keys()) + 1)
48
+ rating, general, artist, character, copyright, meta, quality = [], [], [], [], [], [], []
49
+ for idx, tag in idx_to_tag.items():
50
+ if idx >= len(names): names.extend([None] * (idx - len(names) + 1))
51
+ names[idx] = tag
52
+ category = tag_to_category.get(tag, 'Unknown')
53
+ idx_int = int(idx)
54
+ if category == 'Rating': rating.append(idx_int)
55
+ elif category == 'General': general.append(idx_int)
56
+ elif category == 'Artist': artist.append(idx_int)
57
+ elif category == 'Character': character.append(idx_int)
58
+ elif category == 'Copyright': copyright.append(idx_int)
59
+ elif category == 'Meta': meta.append(idx_int)
60
+ elif category == 'Quality': quality.append(idx_int)
61
+ return LabelData(names=names, rating=np.array(rating), general=np.array(general), artist=np.array(artist),
62
+ character=np.array(character), copyright=np.array(copyright), meta=np.array(meta), quality=np.array(quality)), tag_to_category
63
+
64
+ # --- Constants ---
65
+ REPO_ID = "cella110n/cl_tagger"
66
+ SAFETENSORS_FILENAME = "lora_model_0426/checkpoint_epoch_4.safetensors"
67
+ METADATA_FILENAME = "lora_model_0426/checkpoint_epoch_4_metadata.json"
68
+ TAG_MAPPING_FILENAME = "lora_model_0426/tag_mapping.json"
69
+ CACHE_DIR = "./model_cache"
70
+ # BASE_MODEL_NAME = 'eva02_large_patch14_448.mim_m38m_ft_in1k' # No model loading yet
71
+
72
+ # --- Tagger Class ---
73
+ class Tagger:
74
+ def __init__(self):
75
+ print("Initializing Tagger...")
76
+ self.safetensors_path = None
77
+ self.metadata_path = None
78
+ self.tag_mapping_path = None
79
+ self.labels_data = None
80
+ self.tag_to_category = None
81
+ self.model = None # Model will be loaded later
82
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
83
+ self._initialize_paths_and_labels()
84
+ print("Tagger Initialized.") # Add confirmation
85
+
86
+ def _download_files(self):
87
+ # Check if paths are already set and files exist (useful for restarts)
88
+ local_safetensors = os.path.join(CACHE_DIR, 'models--cella110n--cl_tagger', 'snapshots', '21e237f0ae461b8d9ebf7472ae8de003e5effe5b', SAFETENSORS_FILENAME)
89
+ local_tag_mapping = os.path.join(CACHE_DIR, 'models--cella110n--cl_tagger', 'snapshots', '21e237f0ae461b8d9ebf7472ae8de003e5effe5b', TAG_MAPPING_FILENAME)
90
+ local_metadata = os.path.join(CACHE_DIR, 'models--cella110n--cl_tagger', 'snapshots', '21e237f0ae461b8d9ebf7472ae8de003e5effe5b', METADATA_FILENAME)
91
+
92
+ needs_download = False
93
+ if not (self.safetensors_path and os.path.exists(self.safetensors_path)):
94
+ if os.path.exists(local_safetensors):
95
+ self.safetensors_path = local_safetensors
96
+ print(f"Found existing safetensors: {self.safetensors_path}")
97
+ else:
98
+ needs_download = True
99
+ if not (self.tag_mapping_path and os.path.exists(self.tag_mapping_path)):
100
+ if os.path.exists(local_tag_mapping):
101
+ self.tag_mapping_path = local_tag_mapping
102
+ print(f"Found existing tag mapping: {self.tag_mapping_path}")
103
+ else:
104
+ needs_download = True
105
+ # Metadata is optional, check similarly
106
+ if not (self.metadata_path and os.path.exists(self.metadata_path)):
107
+ if os.path.exists(local_metadata):
108
+ self.metadata_path = local_metadata
109
+ print(f"Found existing metadata: {self.metadata_path}")
110
+ # Don't trigger download just for metadata if others exist
111
+
112
+ if not needs_download and self.safetensors_path and self.tag_mapping_path:
113
+ print("Required files already exist or paths set.")
114
+ return
115
 
116
+ print("Downloading model files...")
117
+ hf_token = os.environ.get("HF_TOKEN")
118
+ try:
119
+ # Only download if not found locally
120
+ if not self.safetensors_path:
121
+ self.safetensors_path = hf_hub_download(repo_id=REPO_ID, filename=SAFETENSORS_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=False) # Use force_download=False
122
+ if not self.tag_mapping_path:
123
+ self.tag_mapping_path = hf_hub_download(repo_id=REPO_ID, filename=TAG_MAPPING_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
124
+ print(f"Safetensors: {self.safetensors_path}")
125
+ print(f"Tag mapping: {self.tag_mapping_path}")
126
+ try:
127
+ # Only download if not found locally
128
+ if not self.metadata_path:
129
+ self.metadata_path = hf_hub_download(repo_id=REPO_ID, filename=METADATA_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
130
+ print(f"Metadata: {self.metadata_path}")
131
+ except Exception as e_meta:
132
+ # Handle case where metadata genuinely doesn't exist or download fails
133
+ print(f"Metadata ({METADATA_FILENAME}) not found/download failed. Error: {e_meta}")
134
+ self.metadata_path = None
135
+
136
+ except Exception as e:
137
+ print(f"Error downloading files: {e}")
138
+ if "401 Client Error" in str(e) or "Repository not found" in str(e): raise gr.Error(f"Could not download files from {REPO_ID}. Check HF_TOKEN or repository status.")
139
+ else: raise gr.Error(f"Error downloading files: {e}")
140
+
141
+ def _initialize_paths_and_labels(self):
142
+ # Call download first (it now checks existence)
143
+ self._download_files()
144
+ # Only load labels if not already loaded
145
+ if self.labels_data is None:
146
+ print("Loading labels...")
147
+ if self.tag_mapping_path and os.path.exists(self.tag_mapping_path):
148
+ try:
149
+ self.labels_data, self.tag_to_category = load_tag_mapping(self.tag_mapping_path)
150
+ print(f"Labels loaded. Count: {len(self.labels_data.names)}")
151
+ except Exception as e: raise gr.Error(f"Error loading tag mapping: {e}")
152
+ else:
153
+ # This should ideally not happen if download worked
154
+ raise gr.Error(f"Tag mapping file not found at expected path: {self.tag_mapping_path}")
155
+ else:
156
+ print("Labels already loaded.")
157
+
158
+ # Add a simple test method decorated with GPU
159
+ @spaces.GPU()
160
+ def test_gpu_method(self):
161
+ current_time = time.time()
162
+ print(f"--- Tagger.test_gpu_method called on GPU worker at {current_time} ---")
163
+ # Check if labels are accessible from the GPU worker context
164
+ label_count = len(self.labels_data.names) if self.labels_data else -1
165
+ print(f"--- (Worker) Label count: {label_count} ---")
166
+ return f"Tagger method called at {current_time}. Label count: {label_count}"
167
+
168
+ # --- Original predict_on_gpu (Keep commented out for this test) ---
169
+ # @spaces.GPU()
170
+ # def predict_on_gpu(self, image_input, gen_threshold, char_threshold, output_mode):
171
+ # # ... (original prediction logic including model loading) ...
172
+ # pass
173
+
174
+ # Instantiate the tagger class (this will download files/load labels)
175
+ tagger = Tagger()
176
 
177
  # --- Gradio Interface Definition (Minimal) ---
178
  with gr.Blocks() as demo:
179
  gr.Markdown("""
180
+ # Tagger Initialization + Minimal Button Test
181
+ Instantiates Tagger, then click the button below to check if a simple `@spaces.GPU` decorated *method* is triggered.
182
+ Check logs for Tagger initialization messages.
183
  """)
184
  with gr.Column():
185
+ test_button = gr.Button("Test Tagger GPU Method")
186
  output_text = gr.Textbox(label="Output")
187
 
188
  test_button.click(
189
+ fn=tagger.test_gpu_method, # Call the simple method on the instance
190
  inputs=[],
191
  outputs=[output_text]
192
  )
 
194
  # --- Main Block ---
195
  if __name__ == "__main__":
196
  if not os.environ.get("HF_TOKEN"): print("Warning: HF_TOKEN environment variable not set.")
197
+ # Tagger instance is created above
198
  demo.launch(share=True)
199
+
200
+ # --- Commented out original UI and helpers/constants not needed for init/simple test ---
201
+ """
202
+ # import matplotlib.pyplot as plt
203
+ # import matplotlib
204
+ # matplotlib.use('Agg')
205
+ # from PIL import Image, ImageDraw, ImageFont
206
+ # import timm
207
+ # from safetensors.torch import load_file as safe_load_file
208
+
209
+ # def pil_ensure_rgb(...)
210
+ # def pil_pad_square(...)
211
+ # def get_tags(...)
212
+ # def visualize_predictions(...)
213
+ # def preprocess_image(...)
214
+
215
+ # BASE_MODEL_NAME = 'eva02_large_patch14_448.mim_m38m_ft_in1k'
216
+
217
+ # class Tagger:
218
+ # ... (methods related to prediction, model loading) ...
219
+ # def _load_model_on_gpu(self):
220
+ # ...
221
+ # @spaces.GPU()
222
+ # def predict_on_gpu(self, image_input, gen_threshold, char_threshold, output_mode):
223
+ # ...
224
+
225
+ # --- Original Full Gradio UI ---
226
+ # css = ...
227
+ # js = ...
228
+ # with gr.Blocks(css=css, js=js) as demo:
229
+ # gr.Markdown("# WD EVA02 LoRA PyTorch Tagger")
230
+ # ...
231
+ # predict_button.click(
232
+ # fn=tagger.predict_on_gpu,
233
+ # ...
234
+ # )
235
+ # gr.Examples(...)
236
+ """