cella110n commited on
Commit
e0492a1
·
verified ·
1 Parent(s): 1acf660

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -324
app.py CHANGED
@@ -1,333 +1,32 @@
1
- # --- Imports ---
2
  import gradio as gr
3
- import numpy as np
4
- from PIL import Image, ImageDraw, ImageFont
5
- import json
6
- import os
7
- import io
8
- import requests
9
- import matplotlib.pyplot as plt
10
- import matplotlib
11
- from huggingface_hub import hf_hub_download
12
- from dataclasses import dataclass
13
- from typing import List, Dict, Optional, Tuple
14
  import time
15
- import spaces # Required for @spaces.GPU
16
-
17
- import torch
18
- import timm
19
- from safetensors.torch import load_file as safe_load_file
20
-
21
- # MatplotlibのバックエンドをAggに設定
22
- matplotlib.use('Agg')
23
-
24
- # --- Data Classes and Helper Functions ---
25
- @dataclass
26
- class LabelData:
27
- names: list[str]
28
- rating: list[np.int64]
29
- general: list[np.int64]
30
- artist: list[np.int64]
31
- character: list[np.int64]
32
- copyright: list[np.int64]
33
- meta: list[np.int64]
34
- quality: list[np.int64]
35
-
36
- def pil_ensure_rgb(image: Image.Image) -> Image.Image:
37
- if image.mode not in ["RGB", "RGBA"]:
38
- image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
39
- if image.mode == "RGBA":
40
- background = Image.new("RGB", image.size, (255, 255, 255))
41
- background.paste(image, mask=image.split()[3])
42
- image = background
43
- return image
44
-
45
- def pil_pad_square(image: Image.Image) -> Image.Image:
46
- width, height = image.size
47
- if width == height: return image
48
- new_size = max(width, height)
49
- new_image = Image.new("RGB", (new_size, new_size), (255, 255, 255))
50
- paste_position = ((new_size - width) // 2, (new_size - height) // 2)
51
- new_image.paste(image, paste_position)
52
- return new_image
53
-
54
- def load_tag_mapping(mapping_path):
55
- with open(mapping_path, 'r', encoding='utf-8') as f: tag_mapping_data = json.load(f)
56
- if isinstance(tag_mapping_data, dict) and "idx_to_tag" in tag_mapping_data:
57
- idx_to_tag = {int(k): v for k, v in tag_mapping_data["idx_to_tag"].items()}
58
- tag_to_category = tag_mapping_data["tag_to_category"]
59
- elif isinstance(tag_mapping_data, dict):
60
- tag_mapping_data = {int(k): v for k, v in tag_mapping_data.items()}
61
- idx_to_tag = {idx: data['tag'] for idx, data in tag_mapping_data.items()}
62
- tag_to_category = {data['tag']: data['category'] for data in tag_mapping_data.values()}
63
- else: raise ValueError("Unsupported tag mapping format")
64
- names = [None] * (max(idx_to_tag.keys()) + 1)
65
- rating, general, artist, character, copyright, meta, quality = [], [], [], [], [], [], []
66
- for idx, tag in idx_to_tag.items():
67
- if idx >= len(names): names.extend([None] * (idx - len(names) + 1))
68
- names[idx] = tag
69
- category = tag_to_category.get(tag, 'Unknown')
70
- idx_int = int(idx)
71
- if category == 'Rating': rating.append(idx_int)
72
- elif category == 'General': general.append(idx_int)
73
- elif category == 'Artist': artist.append(idx_int)
74
- elif category == 'Character': character.append(idx_int)
75
- elif category == 'Copyright': copyright.append(idx_int)
76
- elif category == 'Meta': meta.append(idx_int)
77
- elif category == 'Quality': quality.append(idx_int)
78
- return LabelData(names=names, rating=np.array(rating), general=np.array(general), artist=np.array(artist),
79
- character=np.array(character), copyright=np.array(copyright), meta=np.array(meta), quality=np.array(quality)), tag_to_category
80
-
81
- def get_tags(probs, labels: LabelData, gen_threshold, char_threshold):
82
- result = {"rating": [], "general": [], "character": [], "copyright": [], "artist": [], "meta": [], "quality": []}
83
- if labels.rating.size > 0:
84
- valid_indices = labels.rating[labels.rating < len(probs)]
85
- if valid_indices.size > 0:
86
- rating_probs = probs[valid_indices]
87
- if rating_probs.size > 0:
88
- rating_idx = np.argmax(rating_probs); original_idx = valid_indices[rating_idx]
89
- if original_idx < len(labels.names): result["rating"].append((labels.names[original_idx], float(rating_probs[rating_idx])))
90
- if labels.quality.size > 0:
91
- valid_indices = labels.quality[labels.quality < len(probs)]
92
- if valid_indices.size > 0:
93
- quality_probs = probs[valid_indices]
94
- if quality_probs.size > 0:
95
- quality_idx = np.argmax(quality_probs); original_idx = valid_indices[quality_idx]
96
- if original_idx < len(labels.names): result["quality"].append((labels.names[original_idx], float(quality_probs[quality_idx])))
97
- category_map = {"general": (labels.general, gen_threshold), "character": (labels.character, char_threshold),
98
- "copyright": (labels.copyright, char_threshold), "artist": (labels.artist, char_threshold), "meta": (labels.meta, gen_threshold)}
99
- for category, (indices, threshold) in category_map.items():
100
- if indices.size > 0:
101
- valid_indices = indices[(indices < len(probs)) & (indices < len(labels.names))]
102
- if valid_indices.size > 0:
103
- category_probs = probs[valid_indices]; mask = category_probs >= threshold
104
- selected_indices = valid_indices[mask]; selected_probs = category_probs[mask]
105
- for idx, prob in zip(selected_indices, selected_probs): result[category].append((labels.names[idx], float(prob)))
106
- for k in result: result[k] = sorted(result[k], key=lambda x: x[1], reverse=True)
107
- return result
108
-
109
- def visualize_predictions(image: Image.Image, predictions, threshold=0.45):
110
- filtered_meta = [(tag, prob) for tag, prob in predictions.get("meta", []) if not any(p in tag.lower() for p in ['id', 'commentary', 'request', 'mismatch'])]
111
- predictions["meta"] = filtered_meta
112
- fig = plt.figure(figsize=(20, 12), dpi=100); gs = fig.add_gridspec(1, 2, width_ratios=[1.2, 1])
113
- ax_img = fig.add_subplot(gs[0, 0]); ax_img.imshow(image); ax_img.set_title("Original Image"); ax_img.axis('off')
114
- ax_tags = fig.add_subplot(gs[0, 1])
115
- all_tags, all_probs, all_colors = [], [], []
116
- color_map = {'rating': 'red', 'character': 'blue', 'copyright': 'purple', 'artist': 'orange', 'general': 'green', 'meta': 'gray', 'quality': 'yellow'}
117
- for cat, prefix, color in [('rating', 'R', 'red'), ('character', 'C', 'blue'), ('copyright', '©', 'purple'), ('artist', 'A', 'orange'), ('general', 'G', 'green'), ('meta', 'M', 'gray'), ('quality', 'Q', 'yellow')]:
118
- for tag, prob in predictions.get(cat, []): all_tags.append(f"[{prefix}] {tag}"); all_probs.append(prob); all_colors.append(color)
119
- if not all_tags:
120
- ax_tags.text(0.5, 0.5, "No tags found", ha='center', va='center'); ax_tags.set_title(f"Tags (threshold={threshold})"); ax_tags.axis('off')
121
- else:
122
- sorted_indices = sorted(range(len(all_probs)), key=lambda i: all_probs[i]); all_tags = [all_tags[i] for i in sorted_indices]; all_probs = [all_probs[i] for i in sorted_indices]; all_colors = [all_colors[i] for i in sorted_indices]
123
- num_tags = len(all_tags); bar_height = min(0.8, 0.8 * (30 / num_tags) if num_tags > 30 else 0.8); y_positions = np.arange(num_tags)
124
- bars = ax_tags.barh(y_positions, all_probs, height=bar_height, color=all_colors); ax_tags.set_yticks(y_positions); ax_tags.set_yticklabels(all_tags)
125
- fontsize = 10 if num_tags <= 40 else 8 if num_tags <= 60 else 6; [label.set_fontsize(fontsize) for label in ax_tags.get_yticklabels()]
126
- for i, (bar, prob) in enumerate(zip(bars, all_probs)): ax_tags.text(min(prob + 0.02, 0.98), y_positions[i], f"{prob:.3f}", va='center', fontsize=fontsize)
127
- ax_tags.set_xlim(0, 1); ax_tags.set_title(f"Tags (threshold={threshold})")
128
- from matplotlib.patches import Patch; legend_elements = [Patch(facecolor=color, label=cat.capitalize()) for cat, color in color_map.items() if cat in predictions and predictions[cat]]; ax_tags.legend(handles=legend_elements, loc='lower right', fontsize=8)
129
- plt.tight_layout(); plt.subplots_adjust(bottom=0.05); buf = io.BytesIO(); plt.savefig(buf, format='png', dpi=100); plt.close(fig); buf.seek(0)
130
- return Image.open(buf)
131
-
132
- def preprocess_image(image: Image.Image, target_size=(448, 448)):
133
- image = pil_ensure_rgb(image)
134
- image = pil_pad_square(image)
135
- image_resized = image.resize(target_size, Image.BICUBIC)
136
- img_array = np.array(image_resized, dtype=np.float32) / 255.0
137
- img_array = img_array.transpose(2, 0, 1)
138
- img_array = img_array[::-1, :, :] # BGR
139
- mean = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
140
- std = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
141
- img_array = (img_array - mean) / std
142
- img_tensor = torch.from_numpy(img_array).unsqueeze(0)
143
- return image, img_tensor
144
-
145
- # --- Constants ---
146
- REPO_ID = "cella110n/cl_tagger"
147
- SAFETENSORS_FILENAME = "lora_model_0426/checkpoint_epoch_4.safetensors"
148
- METADATA_FILENAME = "lora_model_0426/checkpoint_epoch_4_metadata.json"
149
- TAG_MAPPING_FILENAME = "lora_model_0426/tag_mapping.json"
150
- CACHE_DIR = "./model_cache"
151
- BASE_MODEL_NAME = 'eva02_large_patch14_448.mim_m38m_ft_in1k' # Define base model name
152
-
153
- # --- Tagger Class ---
154
- class Tagger:
155
- def __init__(self):
156
- print("Initializing Tagger...")
157
- self.safetensors_path = None
158
- self.metadata_path = None
159
- self.tag_mapping_path = None
160
- self.labels_data = None
161
- self.tag_to_category = None
162
- self.model = None # Model will be loaded on first predict call
163
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
164
- self._initialize_paths_and_labels()
165
-
166
- def _download_files(self):
167
- if self.safetensors_path and self.tag_mapping_path and os.path.exists(self.safetensors_path) and os.path.exists(self.tag_mapping_path):
168
- print("Files seem already downloaded.")
169
- return
170
- print("Downloading model files...")
171
- hf_token = os.environ.get("HF_TOKEN")
172
- try:
173
- self.safetensors_path = hf_hub_download(repo_id=REPO_ID, filename=SAFETENSORS_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=True)
174
- self.tag_mapping_path = hf_hub_download(repo_id=REPO_ID, filename=TAG_MAPPING_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=True)
175
- print(f"Safetensors: {self.safetensors_path}")
176
- print(f"Tag mapping: {self.tag_mapping_path}")
177
- try:
178
- self.metadata_path = hf_hub_download(repo_id=REPO_ID, filename=METADATA_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=True)
179
- print(f"Metadata: {self.metadata_path}")
180
- except Exception: print(f"Metadata ({METADATA_FILENAME}) not found/download failed."); self.metadata_path = None
181
- except Exception as e:
182
- print(f"Error downloading files: {e}")
183
- if "401 Client Error" in str(e) or "Repository not found" in str(e): raise gr.Error(f"Could not download files from {REPO_ID}. Check HF_TOKEN.")
184
- else: raise gr.Error(f"Error downloading files: {e}")
185
-
186
- def _initialize_paths_and_labels(self):
187
- if self.labels_data is None:
188
- self._download_files()
189
- print("Loading labels...")
190
- if self.tag_mapping_path and os.path.exists(self.tag_mapping_path):
191
- try:
192
- self.labels_data, self.tag_to_category = load_tag_mapping(self.tag_mapping_path)
193
- print(f"Labels loaded. Count: {len(self.labels_data.names)}")
194
- except Exception as e: raise gr.Error(f"Error loading tag mapping: {e}")
195
- else: raise gr.Error("Tag mapping file not found.")
196
-
197
- def _load_model_on_gpu(self):
198
- print("Loading PyTorch model for GPU worker...")
199
- if not self.safetensors_path or not self.labels_data:
200
- raise gr.Error("Model paths or labels not initialized before loading.")
201
- try:
202
- num_classes = len(self.labels_data.names)
203
- if num_classes <= 0: raise ValueError(f"Invalid num_classes: {num_classes}")
204
- print(f"Creating base model: {BASE_MODEL_NAME} with {num_classes} classes")
205
- model = timm.create_model(BASE_MODEL_NAME, pretrained=True, num_classes=num_classes)
206
-
207
- print(f"Loading state dict from: {self.safetensors_path}")
208
- if not os.path.exists(self.safetensors_path): raise FileNotFoundError(f"File not found: {self.safetensors_path}")
209
- state_dict = safe_load_file(self.safetensors_path)
210
- # --- Key Adaptation Logic (Important!) ---
211
- # This needs to match how the safetensors were saved by lora.py merge
212
- # If the merge script saved the full model state_dict directly, no adaptation needed.
213
- # If it saved only LoRA weights or prefixed keys, adaptation is required.
214
- # Assuming direct match for now:
215
- adapted_state_dict = state_dict
216
- # Example if keys were prefixed with 'base_model.':
217
- # adapted_state_dict = {k.replace('base_model.', ''): v for k, v in state_dict.items()}
218
- # -----------------------------------------
219
- missing_keys, unexpected_keys = model.load_state_dict(adapted_state_dict, strict=False)
220
- print(f"State dict loaded. Missing: {missing_keys}")
221
- print(f"State dict loaded. Unexpected: {unexpected_keys}")
222
- if any(k.startswith('head.') for k in missing_keys): print("Warning: Head weights missing/mismatched!")
223
-
224
- print(f"Moving model to device: {self.device}")
225
- model.to(self.device)
226
- model.eval()
227
- self.model = model # Store loaded model
228
- print("Model loaded successfully on GPU worker.")
229
- except Exception as e:
230
- print(f"(Worker) Error loading PyTorch model: {e}")
231
- import traceback; print(traceback.format_exc())
232
- raise gr.Error(f"Error loading PyTorch model: {e}")
233
-
234
- @spaces.GPU()
235
- def predict_on_gpu(self, image_input, gen_threshold, char_threshold, output_mode):
236
- print("--- predict_on_gpu function started (GPU worker) ---")
237
- # Ensure model is loaded on the *current* worker process/device
238
- if self.model is None or next(self.model.parameters()).device != self.device:
239
- # This check might be needed if workers don't share memory
240
- self._load_model_on_gpu()
241
-
242
- if self.model is None:
243
- return "Error: Model could not be loaded on GPU worker.", None
244
-
245
- if image_input is None: return "Please upload an image.", None
246
- print(f"(Worker) Processing image with thresholds: gen={gen_threshold}, char={char_threshold}")
247
-
248
- # Image Loading (same as before)
249
- if not isinstance(image_input, Image.Image):
250
- try:
251
- if isinstance(image_input, str) and image_input.startswith("http"): response = requests.get(image_input); response.raise_for_status(); image = Image.open(io.BytesIO(response.content))
252
- elif isinstance(image_input, str) and os.path.exists(image_input): image = Image.open(image_input)
253
- elif isinstance(image_input, np.ndarray): image = Image.fromarray(image_input)
254
- else: raise ValueError("Unsupported image input type")
255
- except Exception as e: print(f"(Worker) Error loading image: {e}"); return f"Error loading image: {e}", None
256
- else: image = image_input
257
-
258
- original_pil_image, input_tensor = preprocess_image(image)
259
- input_tensor = input_tensor.to(self.device)
260
-
261
- # Inference
262
- try:
263
- print("(Worker) Running inference...")
264
- start_time = time.time()
265
- with torch.no_grad(): outputs = self.model(input_tensor)
266
- inference_time = time.time() - start_time
267
- print(f"(Worker) Inference completed in {inference_time:.3f} seconds")
268
- probs = torch.sigmoid(outputs)[0].cpu().numpy()
269
- except Exception as e:
270
- print(f"(Worker) Error during PyTorch inference: {e}"); import traceback; print(traceback.format_exc()); return f"Error during inference: {e}", None
271
-
272
- # Post-processing (same as before)
273
- predictions = get_tags(probs, self.labels_data, gen_threshold, char_threshold)
274
- output_tags = []
275
- if predictions.get("rating"): output_tags.append(predictions["rating"][0][0].replace("_", " "))
276
- if predictions.get("quality"): output_tags.append(predictions["quality"][0][0].replace("_", " "))
277
- for category in ["artist", "character", "copyright", "general", "meta"]:
278
- tags = [tag.replace("_", " ") for tag, prob in predictions.get(category, []) if not (category == "meta" and any(p in tag.lower() for p in ['id', 'commentary','mismatch']))]
279
- output_tags.extend(tags)
280
- output_text = ", ".join(output_tags)
281
-
282
- if output_mode == "Tags Only": return output_text, None
283
- else: viz_image = visualize_predictions(original_pil_image, predictions, gen_threshold); return output_text, viz_image
284
-
285
- # --- Gradio Interface Definition ---
286
- css = """
287
- .gradio-container { font-family: 'IBM Plex Sans', sans-serif; }
288
- footer { display: none !important; }
289
- .gr-prose { max-width: 100% !important; }
290
- """
291
- js = """ /* Keep existing JS */ """
292
-
293
- # Instantiate the tagger class (this will download files/load labels)
294
- tagger = Tagger()
295
 
296
- with gr.Blocks(css=css, js=js) as demo:
297
- gr.Markdown("# WD EVA02 LoRA PyTorch Tagger")
298
- gr.Markdown("Upload an image or paste an image URL to predict tags using the fine-tuned WD EVA02 Tagger model (PyTorch/Safetensors).")
299
- gr.Markdown(f"Model Repository: [{REPO_ID}](https://huggingface.co/{REPO_ID})")
300
- with gr.Row():
301
- with gr.Column(scale=1):
302
- image_input = gr.Image(type="pil", label="Input Image", elem_id="input-image")
303
- gr.HTML("<div id='url-input-container'></div>")
304
- gen_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.55, label="General Tag Threshold")
305
- char_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.60, label="Character/Copyright/Artist Tag Threshold")
306
- output_mode = gr.Radio(choices=["Tags Only", "Tags + Visualization"], value="Tags + Visualization", label="Output Mode")
307
- predict_button = gr.Button("Predict", variant="primary")
308
- with gr.Column(scale=1):
309
- output_tags = gr.Textbox(label="Predicted Tags", lines=10)
310
- output_visualization = gr.Image(type="pil", label="Prediction Visualization")
311
- gr.Examples(
312
- examples=[
313
- ["https://pbs.twimg.com/media/GXBXsRvbQAAg1kp.jpg", 0.55, 0.5, "Tags + Visualization"],
314
- ["https://pbs.twimg.com/media/GjlX0gibcAA4EJ4.jpg", 0.5, 0.5, "Tags Only"],
315
- ["https://pbs.twimg.com/media/Gj4nQbjbEAATeoH.jpg", 0.55, 0.5, "Tags + Visualization"],
316
- ["https://pbs.twimg.com/media/GkbtX0GaoAMlUZt.jpg", 0.45, 0.45, "Tags + Visualization"]
317
- ],
318
- inputs=[image_input, gen_threshold, char_threshold, output_mode],
319
- outputs=[output_tags, output_visualization],
320
- fn=tagger.predict_on_gpu, # Call class method
321
- cache_examples=False
322
- )
323
- predict_button.click(
324
- tagger.predict_on_gpu, # Call class method
325
- inputs=[image_input, gen_threshold, char_threshold, output_mode],
326
- outputs=[output_tags, output_visualization]
327
  )
328
 
329
  # --- Main Block ---
330
  if __name__ == "__main__":
331
  if not os.environ.get("HF_TOKEN"): print("Warning: HF_TOKEN environment variable not set.")
332
- # Tagger instance is created above, no need to call anything here explicitly
333
  demo.launch(share=True)
 
 
1
  import gradio as gr
2
+ import spaces
 
 
 
 
 
 
 
 
 
 
3
  import time
4
+ import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ # --- Simple Test Function ---
7
+ @spaces.GPU()
8
+ def test_button_click():
9
+ current_time = time.time()
10
+ print(f"--- Test button clicked on GPU worker at {current_time} ---")
11
+ return f"Test button clicked at {current_time}"
12
+
13
+ # --- Gradio Interface Definition (Minimal) ---
14
+ with gr.Blocks() as demo:
15
+ gr.Markdown("""
16
+ # Minimal Button Test for ZeroGPU Environment
17
+ Click the button below to check if the `@spaces.GPU` decorated function is triggered.
18
+ """)
19
+ with gr.Column():
20
+ test_button = gr.Button("Test GPU Button")
21
+ output_text = gr.Textbox(label="Output")
22
+
23
+ test_button.click(
24
+ fn=test_button_click,
25
+ inputs=[],
26
+ outputs=[output_text]
 
 
 
 
 
 
 
 
 
 
27
  )
28
 
29
  # --- Main Block ---
30
  if __name__ == "__main__":
31
  if not os.environ.get("HF_TOKEN"): print("Warning: HF_TOKEN environment variable not set.")
 
32
  demo.launch(share=True)