Spaces: Running on Zero
Upload app.py
app.py CHANGED
@@ -1,333 +1,32 @@
-# --- Imports ---
 import gradio as gr
-import numpy as np
-from PIL import Image, ImageDraw, ImageFont
-import json
-import os
-import io
-import requests
-import matplotlib.pyplot as plt
-import matplotlib
-from huggingface_hub import hf_hub_download
-from dataclasses import dataclass
-from typing import List, Dict, Optional, Tuple
+import spaces
 import time
+import os
-import spaces
-
-import torch
-import timm
-from safetensors.torch import load_file as safe_load_file
-
-# Set the Matplotlib backend to Agg
-matplotlib.use('Agg')
-
-# --- Data Classes and Helper Functions ---
-@dataclass
-class LabelData:
-    names: list[str]
-    rating: list[np.int64]
-    general: list[np.int64]
-    artist: list[np.int64]
-    character: list[np.int64]
-    copyright: list[np.int64]
-    meta: list[np.int64]
-    quality: list[np.int64]
-
-def pil_ensure_rgb(image: Image.Image) -> Image.Image:
-    if image.mode not in ["RGB", "RGBA"]:
-        image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
-    if image.mode == "RGBA":
-        background = Image.new("RGB", image.size, (255, 255, 255))
-        background.paste(image, mask=image.split()[3])
-        image = background
-    return image
-
-def pil_pad_square(image: Image.Image) -> Image.Image:
-    width, height = image.size
-    if width == height: return image
-    new_size = max(width, height)
-    new_image = Image.new("RGB", (new_size, new_size), (255, 255, 255))
-    paste_position = ((new_size - width) // 2, (new_size - height) // 2)
-    new_image.paste(image, paste_position)
-    return new_image
-
-def load_tag_mapping(mapping_path):
-    with open(mapping_path, 'r', encoding='utf-8') as f: tag_mapping_data = json.load(f)
-    if isinstance(tag_mapping_data, dict) and "idx_to_tag" in tag_mapping_data:
-        idx_to_tag = {int(k): v for k, v in tag_mapping_data["idx_to_tag"].items()}
-        tag_to_category = tag_mapping_data["tag_to_category"]
-    elif isinstance(tag_mapping_data, dict):
-        tag_mapping_data = {int(k): v for k, v in tag_mapping_data.items()}
-        idx_to_tag = {idx: data['tag'] for idx, data in tag_mapping_data.items()}
-        tag_to_category = {data['tag']: data['category'] for data in tag_mapping_data.values()}
-    else: raise ValueError("Unsupported tag mapping format")
-    names = [None] * (max(idx_to_tag.keys()) + 1)
-    rating, general, artist, character, copyright, meta, quality = [], [], [], [], [], [], []
-    for idx, tag in idx_to_tag.items():
-        if idx >= len(names): names.extend([None] * (idx - len(names) + 1))
-        names[idx] = tag
-        category = tag_to_category.get(tag, 'Unknown')
-        idx_int = int(idx)
-        if category == 'Rating': rating.append(idx_int)
-        elif category == 'General': general.append(idx_int)
-        elif category == 'Artist': artist.append(idx_int)
-        elif category == 'Character': character.append(idx_int)
-        elif category == 'Copyright': copyright.append(idx_int)
-        elif category == 'Meta': meta.append(idx_int)
-        elif category == 'Quality': quality.append(idx_int)
-    return LabelData(names=names, rating=np.array(rating), general=np.array(general), artist=np.array(artist),
-                     character=np.array(character), copyright=np.array(copyright), meta=np.array(meta), quality=np.array(quality)), tag_to_category
-
-def get_tags(probs, labels: LabelData, gen_threshold, char_threshold):
-    result = {"rating": [], "general": [], "character": [], "copyright": [], "artist": [], "meta": [], "quality": []}
-    if labels.rating.size > 0:
-        valid_indices = labels.rating[labels.rating < len(probs)]
-        if valid_indices.size > 0:
-            rating_probs = probs[valid_indices]
-            if rating_probs.size > 0:
-                rating_idx = np.argmax(rating_probs); original_idx = valid_indices[rating_idx]
-                if original_idx < len(labels.names): result["rating"].append((labels.names[original_idx], float(rating_probs[rating_idx])))
-    if labels.quality.size > 0:
-        valid_indices = labels.quality[labels.quality < len(probs)]
-        if valid_indices.size > 0:
-            quality_probs = probs[valid_indices]
-            if quality_probs.size > 0:
-                quality_idx = np.argmax(quality_probs); original_idx = valid_indices[quality_idx]
-                if original_idx < len(labels.names): result["quality"].append((labels.names[original_idx], float(quality_probs[quality_idx])))
-    category_map = {"general": (labels.general, gen_threshold), "character": (labels.character, char_threshold),
-                    "copyright": (labels.copyright, char_threshold), "artist": (labels.artist, char_threshold), "meta": (labels.meta, gen_threshold)}
-    for category, (indices, threshold) in category_map.items():
-        if indices.size > 0:
-            valid_indices = indices[(indices < len(probs)) & (indices < len(labels.names))]
-            if valid_indices.size > 0:
-                category_probs = probs[valid_indices]; mask = category_probs >= threshold
-                selected_indices = valid_indices[mask]; selected_probs = category_probs[mask]
-                for idx, prob in zip(selected_indices, selected_probs): result[category].append((labels.names[idx], float(prob)))
-    for k in result: result[k] = sorted(result[k], key=lambda x: x[1], reverse=True)
-    return result
-
-def visualize_predictions(image: Image.Image, predictions, threshold=0.45):
-    filtered_meta = [(tag, prob) for tag, prob in predictions.get("meta", []) if not any(p in tag.lower() for p in ['id', 'commentary', 'request', 'mismatch'])]
-    predictions["meta"] = filtered_meta
-    fig = plt.figure(figsize=(20, 12), dpi=100); gs = fig.add_gridspec(1, 2, width_ratios=[1.2, 1])
-    ax_img = fig.add_subplot(gs[0, 0]); ax_img.imshow(image); ax_img.set_title("Original Image"); ax_img.axis('off')
-    ax_tags = fig.add_subplot(gs[0, 1])
-    all_tags, all_probs, all_colors = [], [], []
-    color_map = {'rating': 'red', 'character': 'blue', 'copyright': 'purple', 'artist': 'orange', 'general': 'green', 'meta': 'gray', 'quality': 'yellow'}
-    for cat, prefix, color in [('rating', 'R', 'red'), ('character', 'C', 'blue'), ('copyright', '©', 'purple'), ('artist', 'A', 'orange'), ('general', 'G', 'green'), ('meta', 'M', 'gray'), ('quality', 'Q', 'yellow')]:
-        for tag, prob in predictions.get(cat, []): all_tags.append(f"[{prefix}] {tag}"); all_probs.append(prob); all_colors.append(color)
-    if not all_tags:
-        ax_tags.text(0.5, 0.5, "No tags found", ha='center', va='center'); ax_tags.set_title(f"Tags (threshold={threshold})"); ax_tags.axis('off')
-    else:
-        sorted_indices = sorted(range(len(all_probs)), key=lambda i: all_probs[i]); all_tags = [all_tags[i] for i in sorted_indices]; all_probs = [all_probs[i] for i in sorted_indices]; all_colors = [all_colors[i] for i in sorted_indices]
-        num_tags = len(all_tags); bar_height = min(0.8, 0.8 * (30 / num_tags) if num_tags > 30 else 0.8); y_positions = np.arange(num_tags)
-        bars = ax_tags.barh(y_positions, all_probs, height=bar_height, color=all_colors); ax_tags.set_yticks(y_positions); ax_tags.set_yticklabels(all_tags)
-        fontsize = 10 if num_tags <= 40 else 8 if num_tags <= 60 else 6; [label.set_fontsize(fontsize) for label in ax_tags.get_yticklabels()]
-        for i, (bar, prob) in enumerate(zip(bars, all_probs)): ax_tags.text(min(prob + 0.02, 0.98), y_positions[i], f"{prob:.3f}", va='center', fontsize=fontsize)
-        ax_tags.set_xlim(0, 1); ax_tags.set_title(f"Tags (threshold={threshold})")
-        from matplotlib.patches import Patch; legend_elements = [Patch(facecolor=color, label=cat.capitalize()) for cat, color in color_map.items() if cat in predictions and predictions[cat]]; ax_tags.legend(handles=legend_elements, loc='lower right', fontsize=8)
-    plt.tight_layout(); plt.subplots_adjust(bottom=0.05); buf = io.BytesIO(); plt.savefig(buf, format='png', dpi=100); plt.close(fig); buf.seek(0)
-    return Image.open(buf)
-
-def preprocess_image(image: Image.Image, target_size=(448, 448)):
-    image = pil_ensure_rgb(image)
-    image = pil_pad_square(image)
-    image_resized = image.resize(target_size, Image.BICUBIC)
-    img_array = np.array(image_resized, dtype=np.float32) / 255.0
-    img_array = img_array.transpose(2, 0, 1)
-    img_array = img_array[::-1, :, :]  # RGB -> BGR
-    mean = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
-    std = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
-    img_array = (img_array - mean) / std
-    img_tensor = torch.from_numpy(img_array).unsqueeze(0)
-    return image, img_tensor
-
-# --- Constants ---
-REPO_ID = "cella110n/cl_tagger"
-SAFETENSORS_FILENAME = "lora_model_0426/checkpoint_epoch_4.safetensors"
-METADATA_FILENAME = "lora_model_0426/checkpoint_epoch_4_metadata.json"
-TAG_MAPPING_FILENAME = "lora_model_0426/tag_mapping.json"
-CACHE_DIR = "./model_cache"
-BASE_MODEL_NAME = 'eva02_large_patch14_448.mim_m38m_ft_in1k'  # Define base model name
-
-# --- Tagger Class ---
-class Tagger:
-    def __init__(self):
-        print("Initializing Tagger...")
-        self.safetensors_path = None
-        self.metadata_path = None
-        self.tag_mapping_path = None
-        self.labels_data = None
-        self.tag_to_category = None
-        self.model = None  # Model will be loaded on first predict call
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self._initialize_paths_and_labels()
-
-    def _download_files(self):
-        if self.safetensors_path and self.tag_mapping_path and os.path.exists(self.safetensors_path) and os.path.exists(self.tag_mapping_path):
-            print("Files seem already downloaded.")
-            return
-        print("Downloading model files...")
-        hf_token = os.environ.get("HF_TOKEN")
-        try:
-            self.safetensors_path = hf_hub_download(repo_id=REPO_ID, filename=SAFETENSORS_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=True)
-            self.tag_mapping_path = hf_hub_download(repo_id=REPO_ID, filename=TAG_MAPPING_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=True)
-            print(f"Safetensors: {self.safetensors_path}")
-            print(f"Tag mapping: {self.tag_mapping_path}")
-            try:
-                self.metadata_path = hf_hub_download(repo_id=REPO_ID, filename=METADATA_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=True)
-                print(f"Metadata: {self.metadata_path}")
-            except Exception: print(f"Metadata ({METADATA_FILENAME}) not found/download failed."); self.metadata_path = None
-        except Exception as e:
-            print(f"Error downloading files: {e}")
-            if "401 Client Error" in str(e) or "Repository not found" in str(e): raise gr.Error(f"Could not download files from {REPO_ID}. Check HF_TOKEN.")
-            else: raise gr.Error(f"Error downloading files: {e}")
-
-    def _initialize_paths_and_labels(self):
-        if self.labels_data is None:
-            self._download_files()
-            print("Loading labels...")
-            if self.tag_mapping_path and os.path.exists(self.tag_mapping_path):
-                try:
-                    self.labels_data, self.tag_to_category = load_tag_mapping(self.tag_mapping_path)
-                    print(f"Labels loaded. Count: {len(self.labels_data.names)}")
-                except Exception as e: raise gr.Error(f"Error loading tag mapping: {e}")
-            else: raise gr.Error("Tag mapping file not found.")
-
-    def _load_model_on_gpu(self):
-        print("Loading PyTorch model for GPU worker...")
-        if not self.safetensors_path or not self.labels_data:
-            raise gr.Error("Model paths or labels not initialized before loading.")
-        try:
-            num_classes = len(self.labels_data.names)
-            if num_classes <= 0: raise ValueError(f"Invalid num_classes: {num_classes}")
-            print(f"Creating base model: {BASE_MODEL_NAME} with {num_classes} classes")
-            model = timm.create_model(BASE_MODEL_NAME, pretrained=True, num_classes=num_classes)
-
-            print(f"Loading state dict from: {self.safetensors_path}")
-            if not os.path.exists(self.safetensors_path): raise FileNotFoundError(f"File not found: {self.safetensors_path}")
-            state_dict = safe_load_file(self.safetensors_path)
-            # --- Key Adaptation Logic (Important!) ---
-            # This needs to match how the safetensors were saved by the lora.py merge.
-            # If the merge script saved the full model state_dict directly, no adaptation is needed.
-            # If it saved only LoRA weights or prefixed keys, adaptation is required.
-            # Assuming a direct match for now:
-            adapted_state_dict = state_dict
-            # Example if keys were prefixed with 'base_model.':
-            # adapted_state_dict = {k.replace('base_model.', ''): v for k, v in state_dict.items()}
-            # -----------------------------------------
-            missing_keys, unexpected_keys = model.load_state_dict(adapted_state_dict, strict=False)
-            print(f"State dict loaded. Missing: {missing_keys}")
-            print(f"State dict loaded. Unexpected: {unexpected_keys}")
-            if any(k.startswith('head.') for k in missing_keys): print("Warning: Head weights missing/mismatched!")
-
-            print(f"Moving model to device: {self.device}")
-            model.to(self.device)
-            model.eval()
-            self.model = model  # Store loaded model
-            print("Model loaded successfully on GPU worker.")
-        except Exception as e:
-            print(f"(Worker) Error loading PyTorch model: {e}")
-            import traceback; print(traceback.format_exc())
-            raise gr.Error(f"Error loading PyTorch model: {e}")
-
-    @spaces.GPU()
-    def predict_on_gpu(self, image_input, gen_threshold, char_threshold, output_mode):
-        print("--- predict_on_gpu function started (GPU worker) ---")
-        # Ensure the model is loaded on the *current* worker process/device
-        if self.model is None or next(self.model.parameters()).device != self.device:
-            # This check might be needed if workers don't share memory
-            self._load_model_on_gpu()
-
-        if self.model is None:
-            return "Error: Model could not be loaded on GPU worker.", None
-
-        if image_input is None: return "Please upload an image.", None
-        print(f"(Worker) Processing image with thresholds: gen={gen_threshold}, char={char_threshold}")
-
-        # Image loading (same as before)
-        if not isinstance(image_input, Image.Image):
-            try:
-                if isinstance(image_input, str) and image_input.startswith("http"): response = requests.get(image_input); response.raise_for_status(); image = Image.open(io.BytesIO(response.content))
-                elif isinstance(image_input, str) and os.path.exists(image_input): image = Image.open(image_input)
-                elif isinstance(image_input, np.ndarray): image = Image.fromarray(image_input)
-                else: raise ValueError("Unsupported image input type")
-            except Exception as e: print(f"(Worker) Error loading image: {e}"); return f"Error loading image: {e}", None
-        else: image = image_input
-
-        original_pil_image, input_tensor = preprocess_image(image)
-        input_tensor = input_tensor.to(self.device)
-
-        # Inference
-        try:
-            print("(Worker) Running inference...")
-            start_time = time.time()
-            with torch.no_grad(): outputs = self.model(input_tensor)
-            inference_time = time.time() - start_time
-            print(f"(Worker) Inference completed in {inference_time:.3f} seconds")
-            probs = torch.sigmoid(outputs)[0].cpu().numpy()
-        except Exception as e:
-            print(f"(Worker) Error during PyTorch inference: {e}"); import traceback; print(traceback.format_exc()); return f"Error during inference: {e}", None
-
-        # Post-processing (same as before)
-        predictions = get_tags(probs, self.labels_data, gen_threshold, char_threshold)
-        output_tags = []
-        if predictions.get("rating"): output_tags.append(predictions["rating"][0][0].replace("_", " "))
-        if predictions.get("quality"): output_tags.append(predictions["quality"][0][0].replace("_", " "))
-        for category in ["artist", "character", "copyright", "general", "meta"]:
-            tags = [tag.replace("_", " ") for tag, prob in predictions.get(category, []) if not (category == "meta" and any(p in tag.lower() for p in ['id', 'commentary', 'mismatch']))]
-            output_tags.extend(tags)
-        output_text = ", ".join(output_tags)
-
-        if output_mode == "Tags Only": return output_text, None
-        else: viz_image = visualize_predictions(original_pil_image, predictions, gen_threshold); return output_text, viz_image
-
-# --- Gradio Interface Definition ---
-css = """
-.gradio-container { font-family: 'IBM Plex Sans', sans-serif; }
-footer { display: none !important; }
-.gr-prose { max-width: 100% !important; }
-"""
-js = """ /* Keep existing JS */ """
-
-# Instantiate the tagger class (this will download files/load labels)
-tagger = Tagger()

⋮ [old lines 296-316 render as blank in this view; presumably the removed gr.Blocks UI layout] ⋮
-        ],
-        inputs=[image_input, gen_threshold, char_threshold, output_mode],
-        outputs=[output_tags, output_visualization],
-        fn=tagger.predict_on_gpu,  # Call class method
-        cache_examples=False
-    )
-    predict_button.click(
-        tagger.predict_on_gpu,  # Call class method
-        inputs=[image_input, gen_threshold, char_threshold, output_mode],
-        outputs=[output_tags, output_visualization]
+# --- Simple Test Function ---
+@spaces.GPU()
+def test_button_click():
+    current_time = time.time()
+    print(f"--- Test button clicked on GPU worker at {current_time} ---")
+    return f"Test button clicked at {current_time}"
+
+# --- Gradio Interface Definition (Minimal) ---
+with gr.Blocks() as demo:
+    gr.Markdown("""
+    # Minimal Button Test for ZeroGPU Environment
+    Click the button below to check if the `@spaces.GPU` decorated function is triggered.
+    """)
+    with gr.Column():
+        test_button = gr.Button("Test GPU Button")
+        output_text = gr.Textbox(label="Output")
+
+    test_button.click(
+        fn=test_button_click,
+        inputs=[],
+        outputs=[output_text]
     )
 
 # --- Main Block ---
 if __name__ == "__main__":
     if not os.environ.get("HF_TOKEN"): print("Warning: HF_TOKEN environment variable not set.")
-    # Tagger instance is created above, no need to call anything here explicitly
     demo.launch(share=True)
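If the test button responds, a natural follow-up (not part of this commit) is to check whether CUDA is actually visible inside the ZeroGPU worker; `test_gpu_visibility` below is a hypothetical sketch, assuming `torch` is installed in the Space:

import time

import spaces
import torch

@spaces.GPU()
def test_gpu_visibility():
    # On ZeroGPU, CUDA is expected to become visible only inside a
    # @spaces.GPU-decorated function, once a GPU worker is attached.
    available = torch.cuda.is_available()
    device_name = torch.cuda.get_device_name(0) if available else "none"
    return f"cuda_available={available}, device={device_name}, t={time.time()}"

Wired to a second gr.Button alongside the existing one, this would distinguish "the decorated function never runs" from "it runs but without a GPU attached".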