|
import gradio as gr |
|
import spaces |
|
import numpy as np |
|
from PIL import Image, ImageDraw, ImageFont |
|
import json |
|
import os |
|
import io |
|
import requests |
|
import matplotlib.pyplot as plt |
|
import matplotlib |
|
from huggingface_hub import hf_hub_download |
|
from dataclasses import dataclass |
|
from typing import List, Dict, Optional, Tuple |
|
import time |
|
|
|
import torch |
|
import timm |
|
from safetensors.torch import load_file as safe_load_file |
|
|
|
|
|
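# Render matplotlib figures off-screen; there is no display in the Space container.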
matplotlib.use('Agg') |
|
|
|
|
|
|
|
@dataclass
class LabelData:
    """Per-category index arrays into `names`, as built by load_tag_mapping()."""
    names: list[str]
    rating: np.ndarray
    general: np.ndarray
    artist: np.ndarray
    character: np.ndarray
    copyright: np.ndarray
    meta: np.ndarray
    quality: np.ndarray
|
|
|
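# Convert any input image to RGB, compositing transparency onto a white background.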
def pil_ensure_rgb(image: Image.Image) -> Image.Image: |
|
if image.mode not in ["RGB", "RGBA"]: |
|
image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB") |
|
if image.mode == "RGBA": |
|
background = Image.new("RGB", image.size, (255, 255, 255)) |
|
background.paste(image, mask=image.split()[3]) |
|
image = background |
|
return image |
|
|
|
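# Pad a rectangular image to a centered square on a white canvas.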
def pil_pad_square(image: Image.Image) -> Image.Image: |
|
width, height = image.size |
|
if width == height: |
|
return image |
|
new_size = max(width, height) |
|
new_image = Image.new("RGB", (new_size, new_size), (255, 255, 255)) |
|
paste_position = ((new_size - width) // 2, (new_size - height) // 2) |
|
new_image.paste(image, paste_position) |
|
return new_image |
|
|
|
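# Parse tag_mapping.json into a LabelData bundle plus raw lookup dicts.
# Two on-disk layouts are supported (see the branches below).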
def load_tag_mapping(mapping_path): |
|
with open(mapping_path, 'r', encoding='utf-8') as f: |
|
tag_mapping_data = json.load(f) |
|
|
|
|
|
    # Newer layout: {"idx_to_tag": {...}, "tag_to_category": {...}}
    if isinstance(tag_mapping_data, dict) and "idx_to_tag" in tag_mapping_data:
        idx_to_tag_dict = tag_mapping_data["idx_to_tag"]
        tag_to_category_dict = tag_mapping_data["tag_to_category"]
        idx_to_tag = {int(k): v for k, v in idx_to_tag_dict.items()}
        tag_to_category = tag_to_category_dict
    # Flat layout: {idx: {"tag": ..., "category": ...}}
    elif isinstance(tag_mapping_data, dict):
        tag_mapping_data = {int(k): v for k, v in tag_mapping_data.items()}
|
idx_to_tag = {} |
|
tag_to_category = {} |
|
for idx, data in tag_mapping_data.items(): |
|
tag = data['tag'] |
|
category = data['category'] |
|
idx_to_tag[idx] = tag |
|
tag_to_category[tag] = category |
|
else: |
|
raise ValueError("Unsupported tag mapping format") |
|
|
|
|
|
names = [None] * (max(idx_to_tag.keys()) + 1) |
|
rating = [] |
|
general = [] |
|
artist = [] |
|
character = [] |
|
copyright = [] |
|
meta = [] |
|
quality = [] |
|
|
|
    # Dispatch each index into its category bucket.
    category_buckets = {
        'Rating': rating, 'General': general, 'Artist': artist,
        'Character': character, 'Copyright': copyright,
        'Meta': meta, 'Quality': quality,
    }
    for idx, tag in idx_to_tag.items():
        if idx >= len(names):
            names.extend([None] * (idx - len(names) + 1))
        names[idx] = tag
        category = tag_to_category.get(tag, 'Unknown')
        bucket = category_buckets.get(category)
        if bucket is not None:
            bucket.append(idx)
|
|
|
|
|
label_data = LabelData( |
|
names=names, |
|
rating=np.array(rating, dtype=np.int64), |
|
general=np.array(general, dtype=np.int64), |
|
artist=np.array(artist, dtype=np.int64), |
|
character=np.array(character, dtype=np.int64), |
|
copyright=np.array(copyright, dtype=np.int64), |
|
meta=np.array(meta, dtype=np.int64), |
|
quality=np.array(quality, dtype=np.int64) |
|
) |
|
|
|
return label_data, idx_to_tag, tag_to_category |
|
|
|
|
|
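# Resize, pad, and normalize a PIL image into a (1, 3, H, W) float32 array.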
def preprocess_image(image: Image.Image, target_size=(448, 448)): |
|
image = pil_ensure_rgb(image) |
|
image = pil_pad_square(image) |
|
    image_resized = image.resize(target_size, Image.BICUBIC)
    img_array = np.array(image_resized, dtype=np.float32) / 255.0
    img_array = img_array.transpose(2, 0, 1)  # HWC -> CHW
    # RGB -> BGR channel order (the checkpoint presumably expects BGR input).
    img_array = img_array[::-1, :, :]
    # Normalize to [-1, 1] with mean = std = 0.5 per channel.
    mean = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
    std = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
    img_array = (img_array - mean) / std
    img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension.
|
return image, img_array |
|
|
|
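# Convert raw sigmoid probabilities into per-category (tag, confidence) lists.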
def get_tags(probs, labels: LabelData, gen_threshold, char_threshold): |
|
result = { |
|
"rating": [], "general": [], "character": [], |
|
"copyright": [], "artist": [], "meta": [], "quality": [] |
|
} |
|
|
|
|
|
    # Rating is treated as an exclusive category: take the argmax, not a threshold.
    if labels.rating.size > 0:
|
rating_probs = probs[labels.rating] |
|
if rating_probs.size > 0: |
|
rating_idx = np.argmax(rating_probs) |
|
|
|
if labels.rating[rating_idx] < len(labels.names): |
|
rating_name = labels.names[labels.rating[rating_idx]] |
|
rating_conf = float(rating_probs[rating_idx]) |
|
result["rating"].append((rating_name, rating_conf)) |
|
else: |
|
print(f"Warning: Rating index {labels.rating[rating_idx]} out of bounds for names list (size {len(labels.names)}).") |
|
|
|
|
|
|
|
    # Quality is likewise exclusive: take the argmax.
    if labels.quality.size > 0:
|
quality_probs = probs[labels.quality] |
|
if quality_probs.size > 0: |
|
quality_idx = np.argmax(quality_probs) |
|
if labels.quality[quality_idx] < len(labels.names): |
|
quality_name = labels.names[labels.quality[quality_idx]] |
|
quality_conf = float(quality_probs[quality_idx]) |
|
result["quality"].append((quality_name, quality_conf)) |
|
else: |
|
print(f"Warning: Quality index {labels.quality[quality_idx]} out of bounds for names list (size {len(labels.names)}).") |
|
|
|
|
|
    # Thresholded multi-label categories.
    category_map = {
|
"general": (labels.general, gen_threshold), |
|
"character": (labels.character, char_threshold), |
|
"copyright": (labels.copyright, char_threshold), |
|
"artist": (labels.artist, char_threshold), |
|
"meta": (labels.meta, gen_threshold) |
|
} |
|
|
|
for category, (indices, threshold) in category_map.items(): |
|
if indices.size > 0: |
|
|
|
valid_indices = indices[(indices < len(probs)) & (indices < len(labels.names))] |
|
if valid_indices.size > 0: |
|
category_probs = probs[valid_indices] |
|
mask = category_probs >= threshold |
|
selected_indices = valid_indices[mask] |
|
selected_probs = category_probs[mask] |
|
for idx, prob in zip(selected_indices, selected_probs): |
|
result[category].append((labels.names[idx], float(prob))) |
|
|
|
|
|
|
|
for k in result: |
|
result[k] = sorted(result[k], key=lambda x: x[1], reverse=True) |
|
|
|
return result |
|
|
|
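# Render the image side by side with a horizontal bar chart of predicted tags.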
def visualize_predictions(image: Image.Image, predictions, threshold=0.45): |
|
|
|
    # Drop bookkeeping meta tags (IDs, commentary flags, etc.) from the chart.
    filtered_meta = []
    excluded_meta_patterns = ['id', 'commentary', 'request', 'mismatch']
|
for tag, prob in predictions["meta"]: |
|
if not any(pattern in tag.lower() for pattern in excluded_meta_patterns): |
|
filtered_meta.append((tag, prob)) |
|
predictions["meta"] = filtered_meta |
|
|
|
|
|
fig = plt.figure(figsize=(20, 12), dpi=100) |
|
gs = fig.add_gridspec(1, 2, width_ratios=[1.2, 1]) |
|
ax_img = fig.add_subplot(gs[0, 0]) |
|
ax_img.imshow(image) |
|
ax_img.set_title("Original Image") |
|
ax_img.axis('off') |
|
ax_tags = fig.add_subplot(gs[0, 1]) |
|
|
|
all_tags = [] |
|
all_probs = [] |
|
all_colors = [] |
|
color_map = {'rating': 'red', 'character': 'blue', 'copyright': 'purple', |
|
'artist': 'orange', 'general': 'green', 'meta': 'gray', 'quality': 'yellow'} |
|
|
|
for cat, prefix, color in [('rating', 'R', 'red'), ('character', 'C', 'blue'), |
|
('copyright', '©', 'purple'), ('artist', 'A', 'orange'), |
|
('general', 'G', 'green'), ('meta', 'M', 'gray'), ('quality', 'Q', 'yellow')]: |
|
for tag, prob in predictions[cat]: |
|
all_tags.append(f"[{prefix}] {tag}") |
|
all_probs.append(prob) |
|
all_colors.append(color) |
|
|
|
if not all_tags: |
|
ax_tags.text(0.5, 0.5, "No tags found above threshold", ha='center', va='center') |
|
ax_tags.set_title(f"Tags (threshold={threshold})") |
|
ax_tags.axis('off') |
|
plt.tight_layout() |
|
|
|
buf = io.BytesIO() |
|
plt.savefig(buf, format='png', dpi=100) |
|
plt.close(fig) |
|
buf.seek(0) |
|
return Image.open(buf) |
|
|
|
|
|
sorted_indices = sorted(range(len(all_probs)), key=lambda i: all_probs[i], reverse=True) |
|
all_tags = [all_tags[i] for i in sorted_indices] |
|
all_probs = [all_probs[i] for i in sorted_indices] |
|
all_colors = [all_colors[i] for i in sorted_indices] |
|
|
|
    # barh draws from the bottom up, so reverse to put the top tag at the top.
    all_tags.reverse()
|
all_probs.reverse() |
|
all_colors.reverse() |
|
|
|
    num_tags = len(all_tags)
    bar_height = 0.8
    if num_tags > 30:
        bar_height = 0.8 * (30 / num_tags)  # Shrink bars so long lists still fit.
    y_positions = np.arange(num_tags)

    bars = ax_tags.barh(y_positions, all_probs, height=bar_height, color=all_colors)
    ax_tags.set_yticks(y_positions)
    ax_tags.set_yticklabels(all_tags)

    # Scale tick-label font down for long tag lists (largest count checked first).
    fontsize = 10
    if num_tags > 60:
        fontsize = 6
    elif num_tags > 40:
        fontsize = 8
    for label in ax_tags.get_yticklabels():
        label.set_fontsize(fontsize)
|
|
|
for i, (bar, prob) in enumerate(zip(bars, all_probs)): |
|
ax_tags.text(min(prob + 0.02, 0.98), y_positions[i], f"{prob:.3f}", |
|
va='center', fontsize=fontsize) |
|
|
|
ax_tags.set_xlim(0, 1) |
|
ax_tags.set_title(f"Tags (threshold={threshold})") |
|
|
|
from matplotlib.patches import Patch |
|
legend_elements = [Patch(facecolor=color, label=cat.capitalize()) for cat, color in color_map.items()] |
|
ax_tags.legend(handles=legend_elements, loc='lower right', fontsize=8) |
|
|
|
plt.tight_layout() |
|
plt.subplots_adjust(bottom=0.05) |
|
|
|
|
|
buf = io.BytesIO() |
|
plt.savefig(buf, format='png', dpi=100) |
|
plt.close(fig) |
|
buf.seek(0) |
|
return Image.open(buf) |
|
|
|
|
|
|
|
|
|
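# Model repository configuration and module-level caches for downloaded paths.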
REPO_ID = "cella110n/cl_tagger" |
|
SAFETENSORS_FILENAME = "lora_model_0426/checkpoint_epoch_4.safetensors" |
|
METADATA_FILENAME = "lora_model_0426/checkpoint_epoch_4_metadata.json" |
|
TAG_MAPPING_FILENAME = "lora_model_0426/tag_mapping.json" |
|
CACHE_DIR = "./model_cache" |
|
|
|
safetensors_path_global = None |
|
metadata_path_global = None |
|
tag_mapping_path_global = None |
|
labels_data = None |
|
tag_to_category_map = None |
|
|
|
def download_model_files(): |
|
"""Hugging Face Hubからモデル、メタデータ、タグマッピングをダウンロード""" |
|
global safetensors_path_global, metadata_path_global, tag_mapping_path_global |
|
|
|
if safetensors_path_global and tag_mapping_path_global and os.path.exists(safetensors_path_global) and os.path.exists(tag_mapping_path_global): |
|
print("Files seem already downloaded.") |
|
return |
|
|
|
print("Downloading model files...") |
|
hf_token = os.environ.get("HF_TOKEN") |
|
try: |
|
safetensors_path_global = hf_hub_download(repo_id=REPO_ID, filename=SAFETENSORS_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=True) |
|
tag_mapping_path_global = hf_hub_download(repo_id=REPO_ID, filename=TAG_MAPPING_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=True) |
|
print(f"Safetensors downloaded to: {safetensors_path_global}") |
|
print(f"Tag mapping downloaded to: {tag_mapping_path_global}") |
|
try: |
|
metadata_path_global = hf_hub_download(repo_id=REPO_ID, filename=METADATA_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=True) |
|
print(f"Metadata downloaded to: {metadata_path_global}") |
|
except Exception: |
|
print(f"Metadata file ({METADATA_FILENAME}) not found or download failed. Proceeding without it.") |
|
metadata_path_global = None |
|
except Exception as e: |
|
print(f"Error downloading files: {e}") |
|
if "401 Client Error" in str(e) or "Repository not found" in str(e): |
|
raise gr.Error(f"Could not download files from {REPO_ID}. Check HF_TOKEN secret.") |
|
else: |
|
raise gr.Error(f"Error downloading files: {e}") |
|
|
|
def initialize_labels_and_paths(): |
|
"""ラベルデータとファイルパスを準備(キャッシュ)""" |
|
global labels_data, tag_to_category_map, tag_mapping_path_global |
|
if labels_data is None: |
|
download_model_files() |
|
print("Loading labels from tag_mapping.json...") |
|
if tag_mapping_path_global and os.path.exists(tag_mapping_path_global): |
|
try: |
|
labels_data, _, tag_to_category_map = load_tag_mapping(tag_mapping_path_global) |
|
print(f"Labels loaded successfully. Number of labels: {len(labels_data.names)}") |
|
except Exception as e: |
|
print(f"Error loading tag mapping from {tag_mapping_path_global}: {e}") |
|
raise gr.Error(f"Error loading tag mapping file: {e}") |
|
else: |
|
print(f"Tag mapping file not found after download attempt: {tag_mapping_path_global}") |
|
raise gr.Error("Tag mapping file could not be downloaded or found.") |
|
|
|
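# Runs inside a ZeroGPU worker; this implementation rebuilds the model and
# reloads the checkpoint on every request rather than caching it.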
@spaces.GPU() |
|
def predict(image_input, gen_threshold, char_threshold, output_mode): |
|
print("--- predict function started (GPU worker) ---") |
|
initialize_labels_and_paths() |
|
print("Loading PyTorch model...") |
|
global safetensors_path_global, labels_data |
|
if safetensors_path_global is None or labels_data is None: |
|
initialize_labels_and_paths() |
|
if safetensors_path_global is None or labels_data is None: |
|
return "Error: Model/Labels paths could not be initialized.", None |
|
try: |
|
print(f"Creating base model: eva02_large_patch14_448.mim_m38m_ft_in1k") |
|
num_classes = len(labels_data.names) |
|
|
|
if num_classes <= 0: |
|
raise ValueError(f"Invalid number of classes loaded from tag mapping: {num_classes}") |
|
print(f"Setting num_classes: {num_classes}") |
|
model = timm.create_model( |
|
'eva02_large_patch14_448.mim_m38m_ft_in1k', |
|
pretrained=True, |
|
num_classes=num_classes |
|
) |
|
print(f"Loading state dict from: {safetensors_path_global}") |
|
if not os.path.exists(safetensors_path_global): |
|
raise FileNotFoundError(f"Safetensors file not found at: {safetensors_path_global}") |
|
state_dict = safe_load_file(safetensors_path_global) |
|
        # Checkpoint keys are expected to match the timm model directly; this
        # copy is kept as a hook for any future key renaming.
        adapted_state_dict = dict(state_dict)
|
|
|
missing_keys, unexpected_keys = model.load_state_dict(adapted_state_dict, strict=False) |
|
print(f"State dict loaded. Missing keys: {missing_keys}") |
|
print(f"State dict loaded. Unexpected keys: {unexpected_keys}") |
|
|
|
if any(k.startswith('head.') for k in missing_keys): |
|
print("Warning: Classification head weights might be missing or mismatched!") |
|
|
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
print(f"Moving model to device: {device}") |
|
model.to(device) |
|
model.eval() |
|
print("Model loaded and moved to device.") |
|
except Exception as e: |
|
print(f"(Worker) Error loading PyTorch model: {e}") |
|
import traceback |
|
print(traceback.format_exc()) |
|
return f"Error loading PyTorch model: {e}", None |
|
|
|
if image_input is None: |
|
return "Please upload an image.", None |
|
print(f"(Worker) Processing image with thresholds: gen={gen_threshold}, char={char_threshold}") |
|
    if not isinstance(image_input, Image.Image):
        try:
            if isinstance(image_input, str) and image_input.startswith("http"):
                response = requests.get(image_input)
                response.raise_for_status()
                image = Image.open(io.BytesIO(response.content))
            elif isinstance(image_input, str) and os.path.exists(image_input):
                image = Image.open(image_input)
            elif isinstance(image_input, np.ndarray):
                image = Image.fromarray(image_input)
            else:
                raise ValueError("Unsupported image input type")
        except Exception as e:
            print(f"(Worker) Error loading image: {e}")
            return f"Error loading image: {e}", None
    else:
        image = image_input
|
|
|
    original_pil_image, input_array = preprocess_image(image)
    # preprocess_image returns a NumPy array; convert it to a torch tensor.
    input_tensor = torch.from_numpy(input_array).to(device)
|
try: |
|
print("(Worker) Running inference...") |
|
start_time = time.time() |
|
        with torch.no_grad():
            outputs = model(input_tensor)
|
inference_time = time.time() - start_time |
|
print(f"(Worker) Inference completed in {inference_time:.3f} seconds") |
|
probs = torch.sigmoid(outputs)[0].cpu().numpy() |
|
    except Exception as e:
        print(f"(Worker) Error during PyTorch inference: {e}")
        import traceback
        print(traceback.format_exc())
        return f"Error during inference: {e}", None
|
|
|
predictions = get_tags(probs, labels_data, gen_threshold, char_threshold) |
|
output_tags = [] |
|
if predictions.get("rating"): output_tags.append(predictions["rating"][0][0].replace("_", " ")) |
|
if predictions.get("quality"): output_tags.append(predictions["quality"][0][0].replace("_", " ")) |
|
for category in ["artist", "character", "copyright", "general", "meta"]: |
|
tags = [tag.replace("_", " ") for tag, prob in predictions.get(category, []) |
|
if not (category == "meta" and any(p in tag.lower() for p in ['id', 'commentary','mismatch']))] |
|
output_tags.extend(tags) |
|
output_text = ", ".join(output_tags) |
|
|
|
if output_mode == "Tags Only": return output_text, None |
|
else: viz_image = visualize_predictions(original_pil_image, predictions, gen_threshold); return output_text, viz_image |
|
|
|
|
|
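# Custom CSS and a paste-handling JS snippet injected into the Gradio page.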
css = """ |
|
.gradio-container { font-family: 'IBM Plex Sans', sans-serif; } |
|
footer { display: none !important; } |
|
.gr-prose { max-width: 100% !important; } |
|
""" |
|
js = """ |
|
async function paste_image(blob, gen_thresh, char_thresh, out_mode) { |
|
const data = await fetch(blob) |
|
const image_data = await data.blob() |
|
const file = new File([image_data], "pasted_image.png",{ type: image_data.type }) |
|
const dt = new DataTransfer() |
|
dt.items.add(file) |
|
const element = document.querySelector('#input-image input[type="file"]') |
|
element.files = dt.files |
|
// Trigger the change event manually |
|
const event = new Event('change', { bubbles: true }) |
|
element.dispatchEvent(event) |
|
// Wait a bit for Gradio to process the change, then trigger predict if needed |
|
// await new Promise(resolve => setTimeout(resolve, 100)); // Optional delay |
|
// You might need to manually trigger the prediction or rely on Gradio's auto-triggering |
|
return [file, gen_thresh, char_thresh, out_mode]; // Return input for Gradio function |
|
} |
|
|
|
async function paste_update(evt){ |
|
if (!evt.clipboardData || !evt.clipboardData.items) return; |
|
var url = evt.clipboardData.getData('text'); |
|
if (url) { |
|
// Basic check for image URL (you might want a more robust check) |
|
if (/\.(jpg|jpeg|png|webp|bmp)$/i.test(url)) { |
|
// Create a button or link to load the URL |
|
const url_container = document.getElementById('url-input-container'); |
|
url_container.innerHTML = `<p>Detected URL: <button id="load-url-btn" class="gr-button gr-button-sm gr-button-secondary">${url}</button></p>`; |
|
|
|
document.getElementById('load-url-btn').onclick = async () => { |
|
// Simulate file upload from URL - Gradio's Image component handles URLs directly |
|
const element = document.querySelector('#input-image input[type="file"]'); |
|
// Can't directly set URL to file input, so we pass it to Gradio fn |
|
// Or maybe update the image display src directly if possible? |
|
|
|
// Let Gradio handle the URL - user needs to click predict |
|
// We can pre-fill the image component if Gradio supports it via JS, |
|
// but it's simpler to just let the user click predict after pasting URL. |
|
alert("URL detected. Please ensure the image input is cleared and then press 'Predict' or re-upload the image."); |
|
// Clear current image preview if possible? |
|
|
|
// A workaround: display the URL and let the user manually trigger prediction |
|
// Or, try to use Gradio's JS API if available to update the Image component value |
|
// For now, just inform the user. |
|
}; |
|
return; // Don't process as image paste if URL is found |
|
} |
|
} |
|
|
|
var items = evt.clipboardData.items; |
|
for (var i = 0; i < items.length; i++) { |
|
if (items[i].type.indexOf("image") === 0) { |
|
var blob = items[i].getAsFile(); |
|
var reader = new FileReader(); |
|
reader.onload = function(event){ |
|
// Update the Gradio Image component source directly |
|
const imgElement = document.querySelector('#input-image img'); // Find the img tag inside the component |
|
if (imgElement) { |
|
imgElement.src = event.target.result; |
|
// We still need to pass the blob to the Gradio function |
|
// Use Gradio's JS API or hidden components if possible |
|
// For now, let's use a simple alert and rely on manual trigger |
|
alert("Image pasted. The preview should update. Please press 'Predict'."); |
|
// Trigger paste_image function - requires Gradio JS interaction |
|
// This part is tricky without official Gradio JS API for updates |
|
} |
|
}; |
|
reader.readAsDataURL(blob); |
|
// Prevent default paste handling |
|
evt.preventDefault(); |
|
break; |
|
} |
|
} |
|
} |
|
|
|
document.addEventListener('paste', paste_update); |
|
""" |
|
|
|
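# Gradio UI definition.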
with gr.Blocks(css=css, js=js) as demo: |
|
gr.Markdown("# WD EVA02 LoRA PyTorch Tagger") |
|
gr.Markdown("Upload an image or paste an image URL to predict tags using the fine-tuned WD EVA02 Tagger model (PyTorch/Safetensors).") |
|
gr.Markdown(f"Model Repository: [{REPO_ID}](https://huggingface.co/{REPO_ID})") |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
image_input = gr.Image(type="pil", label="Input Image", elem_id="input-image") |
|
gr.HTML("<div id='url-input-container'></div>") |
|
gen_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.55, label="General Tag Threshold") |
|
char_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.60, label="Character/Copyright/Artist Tag Threshold") |
|
output_mode = gr.Radio(choices=["Tags Only", "Tags + Visualization"], value="Tags + Visualization", label="Output Mode") |
|
predict_button = gr.Button("Predict", variant="primary") |
|
|
|
with gr.Column(scale=1): |
|
output_tags = gr.Textbox(label="Predicted Tags", lines=10) |
|
output_visualization = gr.Image(type="pil", label="Prediction Visualization") |
|
|
|
gr.Examples( |
|
examples=[ |
|
["https://pbs.twimg.com/media/GXBXsRvbQAAg1kp.jpg", 0.55, 0.5, "Tags + Visualization"], |
|
["https://pbs.twimg.com/media/GjlX0gibcAA4EJ4.jpg", 0.5, 0.5, "Tags Only"], |
|
["https://pbs.twimg.com/media/Gj4nQbjbEAATeoH.jpg", 0.55, 0.5, "Tags + Visualization"], |
|
["https://pbs.twimg.com/media/GkbtX0GaoAMlUZt.jpg", 0.45, 0.45, "Tags + Visualization"] |
|
], |
|
inputs=[image_input, gen_threshold, char_threshold, output_mode], |
|
outputs=[output_tags, output_visualization], |
|
fn=predict, |
|
cache_examples=False |
|
) |
|
|
|
predict_button.click( |
|
fn=predict, |
|
inputs=[image_input, gen_threshold, char_threshold, output_mode], |
|
outputs=[output_tags, output_visualization] |
|
) |
|
|
|
if __name__ == "__main__": |
|
if not os.environ.get("HF_TOKEN"): |
|
print("Warning: HF_TOKEN environment variable not set.") |
|
    # Note: share=True is ignored when running inside a Hugging Face Space.
    demo.launch(share=True)