Spaces:
Running
on
Zero
Running
on
Zero
Upload app.py
Browse files
app.py
CHANGED
@@ -1,27 +1,192 @@
|
|
1 |
import gradio as gr
|
2 |
-
import
|
3 |
-
import
|
|
|
|
|
4 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
# --- Gradio Interface Definition (Minimal) ---
|
14 |
with gr.Blocks() as demo:
|
15 |
gr.Markdown("""
|
16 |
-
# Minimal Button Test
|
17 |
-
|
|
|
18 |
""")
|
19 |
with gr.Column():
|
20 |
-
test_button = gr.Button("Test GPU
|
21 |
output_text = gr.Textbox(label="Output")
|
22 |
|
23 |
test_button.click(
|
24 |
-
fn=
|
25 |
inputs=[],
|
26 |
outputs=[output_text]
|
27 |
)
|
@@ -29,4 +194,43 @@ with gr.Blocks() as demo:
|
|
29 |
# --- Main Block ---
|
30 |
if __name__ == "__main__":
|
31 |
if not os.environ.get("HF_TOKEN"): print("Warning: HF_TOKEN environment variable not set.")
|
|
|
32 |
demo.launch(share=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image # Keep PIL for now, might be needed by helpers implicitly
|
4 |
+
# from PIL import Image, ImageDraw, ImageFont # No drawing yet
|
5 |
+
import json
|
6 |
import os
|
7 |
+
import io
|
8 |
+
import requests
|
9 |
+
# import matplotlib.pyplot as plt # No plotting yet
|
10 |
+
# import matplotlib # No plotting yet
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
+
from dataclasses import dataclass
|
13 |
+
from typing import List, Dict, Optional, Tuple
|
14 |
+
import time
|
15 |
+
import spaces # Required for @spaces.GPU
|
16 |
+
|
17 |
+
import torch # Keep torch for device check in Tagger
|
18 |
+
# import timm # No model loading yet
|
19 |
+
# from safetensors.torch import load_file as safe_load_file # No model loading yet
|
20 |
+
|
21 |
+
# MatplotlibのバックエンドをAggに設定 (Keep commented out for now)
|
22 |
+
# matplotlib.use('Agg')
|
23 |
+
|
24 |
+
# --- Data Classes and Helper Functions ---
|
25 |
+
@dataclass
|
26 |
+
class LabelData:
|
27 |
+
names: list[str]
|
28 |
+
rating: list[np.int64]
|
29 |
+
general: list[np.int64]
|
30 |
+
artist: list[np.int64]
|
31 |
+
character: list[np.int64]
|
32 |
+
copyright: list[np.int64]
|
33 |
+
meta: list[np.int64]
|
34 |
+
quality: list[np.int64]
|
35 |
+
|
36 |
+
# Keep helpers needed for initialization
|
37 |
+
def load_tag_mapping(mapping_path):
|
38 |
+
with open(mapping_path, 'r', encoding='utf-8') as f: tag_mapping_data = json.load(f)
|
39 |
+
if isinstance(tag_mapping_data, dict) and "idx_to_tag" in tag_mapping_data:
|
40 |
+
idx_to_tag = {int(k): v for k, v in tag_mapping_data["idx_to_tag"].items()}
|
41 |
+
tag_to_category = tag_mapping_data["tag_to_category"]
|
42 |
+
elif isinstance(tag_mapping_data, dict):
|
43 |
+
tag_mapping_data = {int(k): v for k, v in tag_mapping_data.items()}
|
44 |
+
idx_to_tag = {idx: data['tag'] for idx, data in tag_mapping_data.items()}
|
45 |
+
tag_to_category = {data['tag']: data['category'] for data in tag_mapping_data.values()}
|
46 |
+
else: raise ValueError("Unsupported tag mapping format")
|
47 |
+
names = [None] * (max(idx_to_tag.keys()) + 1)
|
48 |
+
rating, general, artist, character, copyright, meta, quality = [], [], [], [], [], [], []
|
49 |
+
for idx, tag in idx_to_tag.items():
|
50 |
+
if idx >= len(names): names.extend([None] * (idx - len(names) + 1))
|
51 |
+
names[idx] = tag
|
52 |
+
category = tag_to_category.get(tag, 'Unknown')
|
53 |
+
idx_int = int(idx)
|
54 |
+
if category == 'Rating': rating.append(idx_int)
|
55 |
+
elif category == 'General': general.append(idx_int)
|
56 |
+
elif category == 'Artist': artist.append(idx_int)
|
57 |
+
elif category == 'Character': character.append(idx_int)
|
58 |
+
elif category == 'Copyright': copyright.append(idx_int)
|
59 |
+
elif category == 'Meta': meta.append(idx_int)
|
60 |
+
elif category == 'Quality': quality.append(idx_int)
|
61 |
+
return LabelData(names=names, rating=np.array(rating), general=np.array(general), artist=np.array(artist),
|
62 |
+
character=np.array(character), copyright=np.array(copyright), meta=np.array(meta), quality=np.array(quality)), tag_to_category
|
63 |
+
|
64 |
+
# --- Constants ---
|
65 |
+
REPO_ID = "cella110n/cl_tagger"
|
66 |
+
SAFETENSORS_FILENAME = "lora_model_0426/checkpoint_epoch_4.safetensors"
|
67 |
+
METADATA_FILENAME = "lora_model_0426/checkpoint_epoch_4_metadata.json"
|
68 |
+
TAG_MAPPING_FILENAME = "lora_model_0426/tag_mapping.json"
|
69 |
+
CACHE_DIR = "./model_cache"
|
70 |
+
# BASE_MODEL_NAME = 'eva02_large_patch14_448.mim_m38m_ft_in1k' # No model loading yet
|
71 |
+
|
72 |
+
# --- Tagger Class ---
|
73 |
+
class Tagger:
|
74 |
+
def __init__(self):
|
75 |
+
print("Initializing Tagger...")
|
76 |
+
self.safetensors_path = None
|
77 |
+
self.metadata_path = None
|
78 |
+
self.tag_mapping_path = None
|
79 |
+
self.labels_data = None
|
80 |
+
self.tag_to_category = None
|
81 |
+
self.model = None # Model will be loaded later
|
82 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
83 |
+
self._initialize_paths_and_labels()
|
84 |
+
print("Tagger Initialized.") # Add confirmation
|
85 |
+
|
86 |
+
def _download_files(self):
|
87 |
+
# Check if paths are already set and files exist (useful for restarts)
|
88 |
+
local_safetensors = os.path.join(CACHE_DIR, 'models--cella110n--cl_tagger', 'snapshots', '21e237f0ae461b8d9ebf7472ae8de003e5effe5b', SAFETENSORS_FILENAME)
|
89 |
+
local_tag_mapping = os.path.join(CACHE_DIR, 'models--cella110n--cl_tagger', 'snapshots', '21e237f0ae461b8d9ebf7472ae8de003e5effe5b', TAG_MAPPING_FILENAME)
|
90 |
+
local_metadata = os.path.join(CACHE_DIR, 'models--cella110n--cl_tagger', 'snapshots', '21e237f0ae461b8d9ebf7472ae8de003e5effe5b', METADATA_FILENAME)
|
91 |
+
|
92 |
+
needs_download = False
|
93 |
+
if not (self.safetensors_path and os.path.exists(self.safetensors_path)):
|
94 |
+
if os.path.exists(local_safetensors):
|
95 |
+
self.safetensors_path = local_safetensors
|
96 |
+
print(f"Found existing safetensors: {self.safetensors_path}")
|
97 |
+
else:
|
98 |
+
needs_download = True
|
99 |
+
if not (self.tag_mapping_path and os.path.exists(self.tag_mapping_path)):
|
100 |
+
if os.path.exists(local_tag_mapping):
|
101 |
+
self.tag_mapping_path = local_tag_mapping
|
102 |
+
print(f"Found existing tag mapping: {self.tag_mapping_path}")
|
103 |
+
else:
|
104 |
+
needs_download = True
|
105 |
+
# Metadata is optional, check similarly
|
106 |
+
if not (self.metadata_path and os.path.exists(self.metadata_path)):
|
107 |
+
if os.path.exists(local_metadata):
|
108 |
+
self.metadata_path = local_metadata
|
109 |
+
print(f"Found existing metadata: {self.metadata_path}")
|
110 |
+
# Don't trigger download just for metadata if others exist
|
111 |
+
|
112 |
+
if not needs_download and self.safetensors_path and self.tag_mapping_path:
|
113 |
+
print("Required files already exist or paths set.")
|
114 |
+
return
|
115 |
|
116 |
+
print("Downloading model files...")
|
117 |
+
hf_token = os.environ.get("HF_TOKEN")
|
118 |
+
try:
|
119 |
+
# Only download if not found locally
|
120 |
+
if not self.safetensors_path:
|
121 |
+
self.safetensors_path = hf_hub_download(repo_id=REPO_ID, filename=SAFETENSORS_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=False) # Use force_download=False
|
122 |
+
if not self.tag_mapping_path:
|
123 |
+
self.tag_mapping_path = hf_hub_download(repo_id=REPO_ID, filename=TAG_MAPPING_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
|
124 |
+
print(f"Safetensors: {self.safetensors_path}")
|
125 |
+
print(f"Tag mapping: {self.tag_mapping_path}")
|
126 |
+
try:
|
127 |
+
# Only download if not found locally
|
128 |
+
if not self.metadata_path:
|
129 |
+
self.metadata_path = hf_hub_download(repo_id=REPO_ID, filename=METADATA_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
|
130 |
+
print(f"Metadata: {self.metadata_path}")
|
131 |
+
except Exception as e_meta:
|
132 |
+
# Handle case where metadata genuinely doesn't exist or download fails
|
133 |
+
print(f"Metadata ({METADATA_FILENAME}) not found/download failed. Error: {e_meta}")
|
134 |
+
self.metadata_path = None
|
135 |
+
|
136 |
+
except Exception as e:
|
137 |
+
print(f"Error downloading files: {e}")
|
138 |
+
if "401 Client Error" in str(e) or "Repository not found" in str(e): raise gr.Error(f"Could not download files from {REPO_ID}. Check HF_TOKEN or repository status.")
|
139 |
+
else: raise gr.Error(f"Error downloading files: {e}")
|
140 |
+
|
141 |
+
def _initialize_paths_and_labels(self):
|
142 |
+
# Call download first (it now checks existence)
|
143 |
+
self._download_files()
|
144 |
+
# Only load labels if not already loaded
|
145 |
+
if self.labels_data is None:
|
146 |
+
print("Loading labels...")
|
147 |
+
if self.tag_mapping_path and os.path.exists(self.tag_mapping_path):
|
148 |
+
try:
|
149 |
+
self.labels_data, self.tag_to_category = load_tag_mapping(self.tag_mapping_path)
|
150 |
+
print(f"Labels loaded. Count: {len(self.labels_data.names)}")
|
151 |
+
except Exception as e: raise gr.Error(f"Error loading tag mapping: {e}")
|
152 |
+
else:
|
153 |
+
# This should ideally not happen if download worked
|
154 |
+
raise gr.Error(f"Tag mapping file not found at expected path: {self.tag_mapping_path}")
|
155 |
+
else:
|
156 |
+
print("Labels already loaded.")
|
157 |
+
|
158 |
+
# Add a simple test method decorated with GPU
|
159 |
+
@spaces.GPU()
|
160 |
+
def test_gpu_method(self):
|
161 |
+
current_time = time.time()
|
162 |
+
print(f"--- Tagger.test_gpu_method called on GPU worker at {current_time} ---")
|
163 |
+
# Check if labels are accessible from the GPU worker context
|
164 |
+
label_count = len(self.labels_data.names) if self.labels_data else -1
|
165 |
+
print(f"--- (Worker) Label count: {label_count} ---")
|
166 |
+
return f"Tagger method called at {current_time}. Label count: {label_count}"
|
167 |
+
|
168 |
+
# --- Original predict_on_gpu (Keep commented out for this test) ---
|
169 |
+
# @spaces.GPU()
|
170 |
+
# def predict_on_gpu(self, image_input, gen_threshold, char_threshold, output_mode):
|
171 |
+
# # ... (original prediction logic including model loading) ...
|
172 |
+
# pass
|
173 |
+
|
174 |
+
# Instantiate the tagger class (this will download files/load labels)
|
175 |
+
tagger = Tagger()
|
176 |
|
177 |
# --- Gradio Interface Definition (Minimal) ---
|
178 |
with gr.Blocks() as demo:
|
179 |
gr.Markdown("""
|
180 |
+
# Tagger Initialization + Minimal Button Test
|
181 |
+
Instantiates Tagger, then click the button below to check if a simple `@spaces.GPU` decorated *method* is triggered.
|
182 |
+
Check logs for Tagger initialization messages.
|
183 |
""")
|
184 |
with gr.Column():
|
185 |
+
test_button = gr.Button("Test Tagger GPU Method")
|
186 |
output_text = gr.Textbox(label="Output")
|
187 |
|
188 |
test_button.click(
|
189 |
+
fn=tagger.test_gpu_method, # Call the simple method on the instance
|
190 |
inputs=[],
|
191 |
outputs=[output_text]
|
192 |
)
|
|
|
194 |
# --- Main Block ---
|
195 |
if __name__ == "__main__":
|
196 |
if not os.environ.get("HF_TOKEN"): print("Warning: HF_TOKEN environment variable not set.")
|
197 |
+
# Tagger instance is created above
|
198 |
demo.launch(share=True)
|
199 |
+
|
200 |
+
# --- Commented out original UI and helpers/constants not needed for init/simple test ---
|
201 |
+
"""
|
202 |
+
# import matplotlib.pyplot as plt
|
203 |
+
# import matplotlib
|
204 |
+
# matplotlib.use('Agg')
|
205 |
+
# from PIL import Image, ImageDraw, ImageFont
|
206 |
+
# import timm
|
207 |
+
# from safetensors.torch import load_file as safe_load_file
|
208 |
+
|
209 |
+
# def pil_ensure_rgb(...)
|
210 |
+
# def pil_pad_square(...)
|
211 |
+
# def get_tags(...)
|
212 |
+
# def visualize_predictions(...)
|
213 |
+
# def preprocess_image(...)
|
214 |
+
|
215 |
+
# BASE_MODEL_NAME = 'eva02_large_patch14_448.mim_m38m_ft_in1k'
|
216 |
+
|
217 |
+
# class Tagger:
|
218 |
+
# ... (methods related to prediction, model loading) ...
|
219 |
+
# def _load_model_on_gpu(self):
|
220 |
+
# ...
|
221 |
+
# @spaces.GPU()
|
222 |
+
# def predict_on_gpu(self, image_input, gen_threshold, char_threshold, output_mode):
|
223 |
+
# ...
|
224 |
+
|
225 |
+
# --- Original Full Gradio UI ---
|
226 |
+
# css = ...
|
227 |
+
# js = ...
|
228 |
+
# with gr.Blocks(css=css, js=js) as demo:
|
229 |
+
# gr.Markdown("# WD EVA02 LoRA PyTorch Tagger")
|
230 |
+
# ...
|
231 |
+
# predict_button.click(
|
232 |
+
# fn=tagger.predict_on_gpu,
|
233 |
+
# ...
|
234 |
+
# )
|
235 |
+
# gr.Examples(...)
|
236 |
+
"""
|