Upload app.py
app.py
CHANGED
@@ -1,6 +1,5 @@
 import gradio as gr
-
-import onnxruntime as ort
+import spaces
 import numpy as np
 from PIL import Image, ImageDraw, ImageFont
 import json
@@ -13,7 +12,10 @@ from huggingface_hub import hf_hub_download
 from dataclasses import dataclass
 from typing import List, Dict, Optional, Tuple
 import time
-
+
+import torch
+import timm
+from safetensors.torch import load_file as safe_load_file
 
 # Set the Matplotlib backend to Agg (for environments without a GUI)
 matplotlib.use('Agg')
@@ -289,169 +291,165 @@ def visualize_predictions(image: Image.Image, predictions, threshold=0.45):
 
 # Constants
 REPO_ID = "cella110n/cl_tagger"
-
-
-TAG_MAPPING_FILENAME = "
+SAFETENSORS_FILENAME = "lora_model_0426/checkpoint_epoch_4.safetensors"
+METADATA_FILENAME = "lora_model_0426/checkpoint_epoch_4_metadata.json"
+TAG_MAPPING_FILENAME = "lora_model_0426/tag_mapping.json"
 CACHE_DIR = "./model_cache"
 
-
-
-
+safetensors_path_global = None
+metadata_path_global = None
+tag_mapping_path_global = None
 labels_data = None
 tag_to_category_map = None
 
 def download_model_files():
-    """Hugging Face Hub
+    """Download the model, metadata, and tag mapping from the Hugging Face Hub"""
+    global safetensors_path_global, metadata_path_global, tag_mapping_path_global
+    # Check if files seem to be downloaded already
+    if safetensors_path_global and tag_mapping_path_global and os.path.exists(safetensors_path_global) and os.path.exists(tag_mapping_path_global):
+        print("Files seem already downloaded.")
+        return
+
     print("Downloading model files...")
-    # Get the HF token from an environment variable (for private repositories)
     hf_token = os.environ.get("HF_TOKEN")
     try:
-
-
-        print(f"
-        print(f"Tag mapping downloaded to: {
-
+        safetensors_path_global = hf_hub_download(repo_id=REPO_ID, filename=SAFETENSORS_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=True) # Force download to ensure latest
+        tag_mapping_path_global = hf_hub_download(repo_id=REPO_ID, filename=TAG_MAPPING_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=True)
+        print(f"Safetensors downloaded to: {safetensors_path_global}")
+        print(f"Tag mapping downloaded to: {tag_mapping_path_global}")
+        try:
+            metadata_path_global = hf_hub_download(repo_id=REPO_ID, filename=METADATA_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=True)
+            print(f"Metadata downloaded to: {metadata_path_global}")
+        except Exception:
+            print(f"Metadata file ({METADATA_FILENAME}) not found or download failed. Proceeding without it.")
+            metadata_path_global = None
     except Exception as e:
         print(f"Error downloading files: {e}")
-        # Improve the error message when no token is available
         if "401 Client Error" in str(e) or "Repository not found" in str(e):
-            raise gr.Error(f"Could not download files from {REPO_ID}. "
-                           f"If this is a private repository, make sure to set the HF_TOKEN secret in your Space settings.")
+            raise gr.Error(f"Could not download files from {REPO_ID}. Check HF_TOKEN secret.")
         else:
             raise gr.Error(f"Error downloading files: {e}")
 
-
-
-
-    global model_path_global, labels_data, tag_to_category_map
-    # Only initialize once
+def initialize_labels_and_paths():
+    """Prepare label data and file paths (cached)"""
+    global labels_data, tag_to_category_map, tag_mapping_path_global
     if labels_data is None:
-
-
-
-
-
-
-
+        download_model_files() # Ensure files are downloaded
+        print("Loading labels from tag_mapping.json...")
+        if tag_mapping_path_global and os.path.exists(tag_mapping_path_global):
+            try:
+                labels_data, _, tag_to_category_map = load_tag_mapping(tag_mapping_path_global)
+                print(f"Labels loaded successfully. Number of labels: {len(labels_data.names)}")
+            except Exception as e:
+                print(f"Error loading tag mapping from {tag_mapping_path_global}: {e}")
+                raise gr.Error(f"Error loading tag mapping file: {e}")
+        else:
+            print(f"Tag mapping file not found after download attempt: {tag_mapping_path_global}")
+            raise gr.Error("Tag mapping file could not be downloaded or found.")
 
 @spaces.GPU()
 def predict(image_input, gen_threshold, char_threshold, output_mode):
     print("--- predict function started (GPU worker) ---")
-
-
-
-
-
-
-
-    # Attempt initialization again if model path is missing (e.g., after restart)
-    initialize_model()
-    if model_path_global is None:
-        return "Error: Model path could not be initialized.", None
-
-    available_providers = ort.get_available_providers()
-    print(f"(Worker) Available ONNX Runtime providers: {available_providers}")
-    providers = []
-    if 'CUDAExecutionProvider' in available_providers:
-        providers.append('CUDAExecutionProvider')
-    providers.append('CPUExecutionProvider') # Always include CPU as fallback
-
+    initialize_labels_and_paths()
+    print("Loading PyTorch model...")
+    global safetensors_path_global, labels_data
+    if safetensors_path_global is None or labels_data is None:
+        initialize_labels_and_paths()
+        if safetensors_path_global is None or labels_data is None:
+            return "Error: Model/Labels paths could not be initialized.", None
     try:
-
-
-
+        print(f"Creating base model: eva02_large_patch14_448.mim_m38m_ft_in1k")
+        num_classes = len(labels_data.names)
+        # Validate num_classes (should be > 0)
+        if num_classes <= 0:
+            raise ValueError(f"Invalid number of classes loaded from tag mapping: {num_classes}")
+        print(f"Setting num_classes: {num_classes}")
+        model = timm.create_model(
+            'eva02_large_patch14_448.mim_m38m_ft_in1k',
+            pretrained=True,
+            num_classes=num_classes
+        )
+        print(f"Loading state dict from: {safetensors_path_global}")
+        if not os.path.exists(safetensors_path_global):
+            raise FileNotFoundError(f"Safetensors file not found at: {safetensors_path_global}")
+        state_dict = safe_load_file(safetensors_path_global)
+        adapted_state_dict = {}
+        for k, v in state_dict.items():
+            # Adjust key names if needed based on how lora.py saved the merged model
+            # Example: If saved with 'base_model.' prefix
+            # if k.startswith('base_model.'):
+            #     adapted_state_dict[k[len('base_model.'):]] = v
+            # else:
+            adapted_state_dict[k] = v # Assuming direct key match for now
+
+        missing_keys, unexpected_keys = model.load_state_dict(adapted_state_dict, strict=False)
+        print(f"State dict loaded. Missing keys: {missing_keys}")
+        print(f"State dict loaded. Unexpected keys: {unexpected_keys}")
+        # Handle critical missing keys (like the head) if necessary
+        if any(k.startswith('head.') for k in missing_keys):
+            print("Warning: Classification head weights might be missing or mismatched!")
+        # if unexpected_keys:
+        #     print("Warning: Unexpected keys found in state_dict.")
+
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(f"Moving model to device: {device}")
+        model.to(device)
+        model.eval()
+        print("Model loaded and moved to device.")
     except Exception as e:
-
-
-
-
-            session = ort.InferenceSession(model_path_global, providers=['CPUExecutionProvider'])
-        except Exception as e_cpu:
-            print(f"(Worker) Error initializing ONNX session even with CPU: {e_cpu}")
-            return f"Error initializing ONNX session: {e_cpu}", None
-    # --- Session created ---
+        print(f"(Worker) Error loading PyTorch model: {e}")
+        import traceback
+        print(traceback.format_exc())
+        return f"Error loading PyTorch model: {e}", None
 
     if image_input is None:
         return "Please upload an image.", None
-
     print(f"(Worker) Processing image with thresholds: gen={gen_threshold}, char={char_threshold}")
-
-    # Make sure this is a PIL Image object
     if not isinstance(image_input, Image.Image):
         try:
-            # Case: URL
             if isinstance(image_input, str) and image_input.startswith("http"):
-                response = requests.get(image_input)
-                response.raise_for_status()
+                response = requests.get(image_input); response.raise_for_status()
                 image = Image.open(io.BytesIO(response.content))
-            # Case: file path (does not normally happen with Gradio, but just in case)
             elif isinstance(image_input, str) and os.path.exists(image_input):
                 image = Image.open(image_input)
-            # Case: NumPy array (input from the Gradio Image component)
             elif isinstance(image_input, np.ndarray):
                 image = Image.fromarray(image_input)
-            else:
-                raise ValueError("Unsupported image input type")
+            else: raise ValueError("Unsupported image input type")
         except Exception as e:
-            print(f"(Worker) Error loading image: {e}")
-            return f"Error loading image: {e}", None
+            print(f"(Worker) Error loading image: {e}"); return f"Error loading image: {e}", None
-    else:
-        image = image_input
+    else: image = image_input
-
-    # Preprocessing
-    original_pil_image, input_data = preprocess_image(image)
 
-
-
-
-
-
-
-
-
-
-
-
-    inference_time = time.time() - start_time
-    print(f"(Worker) Inference completed in {inference_time:.3f} seconds")
-
-    # Convert to probabilities with the sigmoid function
-    probs = 1 / (1 + np.exp(-outputs[0])) # Apply sigmoid to the first batch item
+    original_pil_image, input_tensor = preprocess_image(image)
+    input_tensor = input_tensor.to(device)
+    try:
+        print("(Worker) Running inference...")
+        start_time = time.time()
+        with torch.no_grad(): outputs = model(input_tensor)
+        inference_time = time.time() - start_time
+        print(f"(Worker) Inference completed in {inference_time:.3f} seconds")
+        probs = torch.sigmoid(outputs)[0].cpu().numpy()
+    except Exception as e:
+        print(f"(Worker) Error during PyTorch inference: {e}"); import traceback; print(traceback.format_exc()); return f"Error during inference: {e}", None
 
-    # Get tags
     predictions = get_tags(probs, labels_data, gen_threshold, char_threshold)
-
-    # Format tags
     output_tags = []
-
-    if predictions["rating"]:
-        output_tags.append(predictions["rating"][0][0].replace("_", " "))
-    if predictions["quality"]:
-        output_tags.append(predictions["quality"][0][0].replace("_", " "))
-
-    # Add the remaining categories in alphabetical order (optional)
+    if predictions.get("rating"): output_tags.append(predictions["rating"][0][0].replace("_", " "))
+    if predictions.get("quality"): output_tags.append(predictions["quality"][0][0].replace("_", " "))
     for category in ["artist", "character", "copyright", "general", "meta"]:
-        tags = [tag.replace("_", " ") for tag, prob in predictions[
-
+        tags = [tag.replace("_", " ") for tag, prob in predictions.get(category, [])
+                if not (category == "meta" and any(p in tag.lower() for p in ['id', 'commentary','mismatch']))]
         output_tags.extend(tags)
-
     output_text = ", ".join(output_tags)
 
-    if output_mode == "Tags Only":
-        return output_text, None
-    else: # Visualization
-        viz_image = visualize_predictions(original_pil_image, predictions, gen_threshold)
-        return output_text, viz_image
+    if output_mode == "Tags Only": return output_text, None
+    else: viz_image = visualize_predictions(original_pil_image, predictions, gen_threshold); return output_text, viz_image
 
 # --- Gradio Interface Definition ---
-
-# CSS for styling
 css = """
 .gradio-container { font-family: 'IBM Plex Sans', sans-serif; }
 footer { display: none !important; }
 .gr-prose { max-width: 100% !important; }
 """
-# Custom JS for image pasting and URL handling
 js = """
 async function paste_image(blob, gen_thresh, char_thresh, out_mode) {
     const data = await fetch(blob)
@@ -530,17 +528,14 @@ document.addEventListener('paste', paste_update);
 """
 
 with gr.Blocks(css=css, js=js) as demo:
-    gr.Markdown("# WD EVA02 LoRA
-    gr.Markdown("Upload an image or paste an image URL to predict tags using the fine-tuned WD EVA02 Tagger model (
+    gr.Markdown("# WD EVA02 LoRA PyTorch Tagger")
+    gr.Markdown("Upload an image or paste an image URL to predict tags using the fine-tuned WD EVA02 Tagger model (PyTorch/Safetensors).")
     gr.Markdown(f"Model Repository: [{REPO_ID}](https://huggingface.co/{REPO_ID})")
 
     with gr.Row():
         with gr.Column(scale=1):
-            # Use elem_id for JS targeting
             image_input = gr.Image(type="pil", label="Input Image", elem_id="input-image")
-            # Container for URL paste message
             gr.HTML("<div id='url-input-container'></div>")
-
             gen_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.55, label="General Tag Threshold")
             char_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.60, label="Character/Copyright/Artist Tag Threshold")
            output_mode = gr.Radio(choices=["Tags Only", "Tags + Visualization"], value="Tags + Visualization", label="Output Mode")
@@ -550,7 +545,6 @@ with gr.Blocks(css=css, js=js) as demo:
             output_tags = gr.Textbox(label="Predicted Tags", lines=10)
             output_visualization = gr.Image(type="pil", label="Prediction Visualization")
 
-    # Examples
     gr.Examples(
         examples=[
             ["https://pbs.twimg.com/media/GXBXsRvbQAAg1kp.jpg", 0.55, 0.5, "Tags + Visualization"],
@@ -561,7 +555,7 @@ with gr.Blocks(css=css, js=js) as demo:
         inputs=[image_input, gen_threshold, char_threshold, output_mode],
         outputs=[output_tags, output_visualization],
         fn=predict,
-        cache_examples=False
+        cache_examples=False
     )
 
     predict_button.click(
@@ -570,18 +564,7 @@ with gr.Blocks(css=css, js=js) as demo:
         outputs=[output_tags, output_visualization]
     )
 
-    # Add listener for image input changes (e.g., from pasting)
-    # This might trigger prediction automatically or require the button click
-    # image_input.change(
-    #     fn=predict,
-    #     inputs=[image_input, gen_threshold, char_threshold, output_mode],
-    #     outputs=[output_tags, output_visualization]
-    # )
-
-
 if __name__ == "__main__":
-    # Warn if the HF_TOKEN environment variable is not set (for private repositories)
     if not os.environ.get("HF_TOKEN"):
-        print("Warning: HF_TOKEN environment variable not set.
-    # initialize_model() # Removed startup initialization (model loaded in predict)
+        print("Warning: HF_TOKEN environment variable not set.")
     demo.launch(share=True)
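Note: preprocess_image is defined elsewhere in app.py and is not part of this diff, but the new code consumes its second return value as a torch tensor (input_tensor.to(device)), where the removed ONNX path used a NumPy array. A minimal sketch of what it presumably does for eva02_large_patch14_448, assuming 448x448 input and CLIP-style normalization (the mean/std values below are assumptions; the authoritative values come from the model's pretrained_cfg, e.g. via timm.data.resolve_model_data_config(model)):

import numpy as np
import torch
from PIL import Image

def preprocess_image_sketch(image: Image.Image):
    # Sketch only; the real preprocess_image in app.py may differ.
    original = image.copy()  # kept for visualize_predictions
    resized = image.convert("RGB").resize((448, 448), Image.BICUBIC)
    arr = np.asarray(resized, dtype=np.float32) / 255.0
    mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)  # assumed
    std = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)  # assumed
    arr = (arr - mean) / std
    # HWC -> CHW, plus a batch dimension: shape (1, 3, 448, 448)
    tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).float()
    return original, tensor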
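get_tags is also defined outside this hunk. From how predict uses its result (a dict keyed by rating, quality, artist, character, copyright, general, and meta, each holding (tag, probability) pairs sorted best-first), the thresholding step presumably looks roughly like this; the category lookup and the treatment of rating/quality are assumptions, not code from this commit:

def get_tags_sketch(probs, labels, tag_to_category, gen_threshold, char_threshold):
    # probs: per-tag sigmoid probabilities; labels.names[i] is the tag for index i.
    categories = ["rating", "quality", "artist", "character", "copyright", "general", "meta"]
    results = {c: [] for c in categories}
    for idx, prob in enumerate(probs):
        tag = labels.names[idx]
        category = tag_to_category.get(tag, "general")
        # Artist/character/copyright use the stricter char_threshold.
        threshold = char_threshold if category in ("artist", "character", "copyright") else gen_threshold
        if prob >= threshold:
            results[category].append((tag, float(prob)))
    for c in categories:
        # Sort so that predictions["rating"][0][0] is the top-scoring tag.
        results[c].sort(key=lambda item: item[1], reverse=True)
    return results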
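Because load_state_dict runs with strict=False, a mismatched checkpoint only shows up as printed missing/unexpected keys at runtime. A small offline sanity check against the same timm architecture can catch that earlier; this is a local sketch (the 'head.' prefix is taken from the diff's own warning check, and the filename is a local copy of SAFETENSORS_FILENAME):

import timm
from safetensors.torch import load_file

state_dict = load_file("checkpoint_epoch_4.safetensors")  # local copy of the checkpoint
# Infer the class count from the checkpoint's classification head, if present.
num_classes = state_dict["head.weight"].shape[0] if "head.weight" in state_dict else 0
model = timm.create_model("eva02_large_patch14_448.mim_m38m_ft_in1k", pretrained=False, num_classes=num_classes)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("num_classes from checkpoint:", num_classes)
print("missing keys:", missing[:10])
print("unexpected keys:", unexpected[:10])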
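Since model loading and inference now happen entirely inside predict (the @spaces.GPU worker), the function can also be smoke-tested outside the Space, for example on a local CUDA machine, before pushing. A hypothetical call using the signature from this diff; it assumes a local sample.jpg and, for a private repo, an exported HF_TOKEN:

from PIL import Image

tags, viz = predict(
    Image.open("sample.jpg"),
    gen_threshold=0.55,
    char_threshold=0.60,
    output_mode="Tags + Visualization",
)
print(tags)
if viz is not None:
    viz.save("viz_preview.png")  # viz is None in "Tags Only" mode or on error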