cella110n committed on
Commit 87e05c0 · verified · 1 Parent(s): d9eba7a

Upload app.py

Files changed (1)
  1. app.py +118 -135
app.py CHANGED
@@ -1,6 +1,5 @@
  import gradio as gr
- # import spaces # Removed
- import onnxruntime as ort
+ import spaces
  import numpy as np
  from PIL import Image, ImageDraw, ImageFont
  import json
@@ -13,7 +12,10 @@ from huggingface_hub import hf_hub_download
  from dataclasses import dataclass
  from typing import List, Dict, Optional, Tuple
  import time
- import spaces
+
+ import torch
+ import timm
+ from safetensors.torch import load_file as safe_load_file

  # Set the Matplotlib backend to Agg (for headless environments)
  matplotlib.use('Agg')
@@ -289,169 +291,165 @@ def visualize_predictions(image: Image.Image, predictions, threshold=0.45):

  # Constants
  REPO_ID = "cella110n/cl_tagger"
- # MODEL_FILENAME = "cl_eva02_tagger_v1_250426/model_optimized.onnx"
- MODEL_FILENAME = "cl_eva02_tagger_v1_250426/model.onnx" # Use non-optimized if needed
- TAG_MAPPING_FILENAME = "cl_eva02_tagger_v1_250426/tag_mapping.json"
+ SAFETENSORS_FILENAME = "lora_model_0426/checkpoint_epoch_4.safetensors"
+ METADATA_FILENAME = "lora_model_0426/checkpoint_epoch_4_metadata.json"
+ TAG_MAPPING_FILENAME = "lora_model_0426/tag_mapping.json"
  CACHE_DIR = "./model_cache"

- # Global variables (cache the model and labels)
- # onnx_session = None # Removed global session
- model_path_global = None # Store model path globally
+ safetensors_path_global = None
+ metadata_path_global = None
+ tag_mapping_path_global = None
  labels_data = None
  tag_to_category_map = None

  def download_model_files():
-     """Download the model and tag mapping from the Hugging Face Hub"""
+     """Download the model, metadata, and tag mapping from the Hugging Face Hub"""
+     global safetensors_path_global, metadata_path_global, tag_mapping_path_global
+     # Check if files seem to be downloaded already
+     if safetensors_path_global and tag_mapping_path_global and os.path.exists(safetensors_path_global) and os.path.exists(tag_mapping_path_global):
+         print("Files seem already downloaded.")
+         return
+
      print("Downloading model files...")
-     # Get the HF token from the environment (for private repositories)
      hf_token = os.environ.get("HF_TOKEN")
      try:
-         model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME, cache_dir=CACHE_DIR, token=hf_token)
-         tag_mapping_path = hf_hub_download(repo_id=REPO_ID, filename=TAG_MAPPING_FILENAME, cache_dir=CACHE_DIR, token=hf_token)
-         print(f"Model downloaded to: {model_path}")
-         print(f"Tag mapping downloaded to: {tag_mapping_path}")
-         return model_path, tag_mapping_path
+         safetensors_path_global = hf_hub_download(repo_id=REPO_ID, filename=SAFETENSORS_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=True) # Force download to ensure latest
+         tag_mapping_path_global = hf_hub_download(repo_id=REPO_ID, filename=TAG_MAPPING_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=True)
+         print(f"Safetensors downloaded to: {safetensors_path_global}")
+         print(f"Tag mapping downloaded to: {tag_mapping_path_global}")
+         try:
+             metadata_path_global = hf_hub_download(repo_id=REPO_ID, filename=METADATA_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=True)
+             print(f"Metadata downloaded to: {metadata_path_global}")
+         except Exception:
+             print(f"Metadata file ({METADATA_FILENAME}) not found or download failed. Proceeding without it.")
+             metadata_path_global = None
      except Exception as e:
          print(f"Error downloading files: {e}")
-         # Improve the error message when the token is missing
          if "401 Client Error" in str(e) or "Repository not found" in str(e):
-             raise gr.Error(f"Could not download files from {REPO_ID}. "
-                            f"If this is a private repository, make sure to set the HF_TOKEN secret in your Space settings.")
+             raise gr.Error(f"Could not download files from {REPO_ID}. Check HF_TOKEN secret.")
          else:
              raise gr.Error(f"Error downloading files: {e}")

-
- def initialize_model():
-     """Prepare the model file and label data (cached)"""
-     global model_path_global, labels_data, tag_to_category_map
-     # Only initialize once
+ def initialize_labels_and_paths():
+     """Prepare the label data and file paths (cached)"""
+     global labels_data, tag_to_category_map, tag_mapping_path_global
      if labels_data is None:
-         print("Downloading model files...") # Moved print here
-         model_path, tag_mapping_path = download_model_files()
-         model_path_global = model_path # Store the path
-         print("Loading labels...")
-         labels_data, _, tag_to_category_map = load_tag_mapping(tag_mapping_path)
-         print("Labels loaded.")
-         # --- Removed ONNX Session Initialization ---
+         download_model_files() # Ensure files are downloaded
+         print("Loading labels from tag_mapping.json...")
+         if tag_mapping_path_global and os.path.exists(tag_mapping_path_global):
+             try:
+                 labels_data, _, tag_to_category_map = load_tag_mapping(tag_mapping_path_global)
+                 print(f"Labels loaded successfully. Number of labels: {len(labels_data.names)}")
+             except Exception as e:
+                 print(f"Error loading tag mapping from {tag_mapping_path_global}: {e}")
+                 raise gr.Error(f"Error loading tag mapping file: {e}")
+         else:
+             print(f"Tag mapping file not found after download attempt: {tag_mapping_path_global}")
+             raise gr.Error("Tag mapping file could not be downloaded or found.")

  @spaces.GPU()
  def predict(image_input, gen_threshold, char_threshold, output_mode):
      print("--- predict function started (GPU worker) ---")
-     """Prediction function for the Gradio interface (inside the GPU worker)"""
-     initialize_model() # Ensure files/labels are ready
-
-     # --- Create ONNX session inside the GPU function ---
-     print("Creating ONNX session for prediction...")
-     global model_path_global # Access the global model path
-     if model_path_global is None:
-         # Attempt initialization again if model path is missing (e.g., after restart)
-         initialize_model()
-         if model_path_global is None:
-             return "Error: Model path could not be initialized.", None
-
-     available_providers = ort.get_available_providers()
-     print(f"(Worker) Available ONNX Runtime providers: {available_providers}")
-     providers = []
-     if 'CUDAExecutionProvider' in available_providers:
-         providers.append('CUDAExecutionProvider')
-     providers.append('CPUExecutionProvider') # Always include CPU as fallback
-
+     initialize_labels_and_paths()
+     print("Loading PyTorch model...")
+     global safetensors_path_global, labels_data
+     if safetensors_path_global is None or labels_data is None:
+         initialize_labels_and_paths()
+         if safetensors_path_global is None or labels_data is None:
+             return "Error: Model/Labels paths could not be initialized.", None
      try:
-         # Create session with GPU preference inside the worker
-         session = ort.InferenceSession(model_path_global, providers=providers)
-         print(f"(Worker) Using ONNX Runtime provider: {session.get_providers()[0]}")
+         print(f"Creating base model: eva02_large_patch14_448.mim_m38m_ft_in1k")
+         num_classes = len(labels_data.names)
+         # Validate num_classes (should be > 0)
+         if num_classes <= 0:
+             raise ValueError(f"Invalid number of classes loaded from tag mapping: {num_classes}")
+         print(f"Setting num_classes: {num_classes}")
+         model = timm.create_model(
+             'eva02_large_patch14_448.mim_m38m_ft_in1k',
+             pretrained=True,
+             num_classes=num_classes
+         )
+         print(f"Loading state dict from: {safetensors_path_global}")
+         if not os.path.exists(safetensors_path_global):
+             raise FileNotFoundError(f"Safetensors file not found at: {safetensors_path_global}")
+         state_dict = safe_load_file(safetensors_path_global)
+         adapted_state_dict = {}
+         for k, v in state_dict.items():
+             # Adjust key names if needed based on how lora.py saved the merged model
+             # Example: If saved with 'base_model.' prefix
+             # if k.startswith('base_model.'):
+             #     adapted_state_dict[k[len('base_model.'):]] = v
+             # else:
+             adapted_state_dict[k] = v # Assuming direct key match for now
+
+         missing_keys, unexpected_keys = model.load_state_dict(adapted_state_dict, strict=False)
+         print(f"State dict loaded. Missing keys: {missing_keys}")
+         print(f"State dict loaded. Unexpected keys: {unexpected_keys}")
+         # Handle critical missing keys (like the head) if necessary
+         if any(k.startswith('head.') for k in missing_keys):
+             print("Warning: Classification head weights might be missing or mismatched!")
+         # if unexpected_keys:
+         #     print("Warning: Unexpected keys found in state_dict.")
+
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         print(f"Moving model to device: {device}")
+         model.to(device)
+         model.eval()
+         print("Model loaded and moved to device.")
      except Exception as e:
-         print(f"(Worker) Error initializing ONNX session with providers {providers}: {e}")
-         # Fallback explicitly to CPU if GPU fails inside worker
-         try:
-             print("(Worker) Falling back to CPUExecutionProvider only.")
-             session = ort.InferenceSession(model_path_global, providers=['CPUExecutionProvider'])
-         except Exception as e_cpu:
-             print(f"(Worker) Error initializing ONNX session even with CPU: {e_cpu}")
-             return f"Error initializing ONNX session: {e_cpu}", None
-     # --- Session created ---
+         print(f"(Worker) Error loading PyTorch model: {e}")
+         import traceback
+         print(traceback.format_exc())
+         return f"Error loading PyTorch model: {e}", None

      if image_input is None:
          return "Please upload an image.", None
-
      print(f"(Worker) Processing image with thresholds: gen={gen_threshold}, char={char_threshold}")
-
-     # Make sure the input is a PIL Image object
      if not isinstance(image_input, Image.Image):
          try:
-             # If it is a URL
              if isinstance(image_input, str) and image_input.startswith("http"):
-                 response = requests.get(image_input)
-                 response.raise_for_status()
+                 response = requests.get(image_input); response.raise_for_status()
                  image = Image.open(io.BytesIO(response.content))
-             # If it is a file path (does not normally happen with Gradio, but just in case)
              elif isinstance(image_input, str) and os.path.exists(image_input):
                  image = Image.open(image_input)
-             # If it is a numpy array (input from the Gradio Image component)
              elif isinstance(image_input, np.ndarray):
                  image = Image.fromarray(image_input)
-             else:
-                 raise ValueError("Unsupported image input type")
+             else: raise ValueError("Unsupported image input type")
          except Exception as e:
-             print(f"(Worker) Error loading image: {e}")
-             return f"Error loading image: {e}", None
-     else:
-         image = image_input
-
-     # Preprocessing
-     original_pil_image, input_data = preprocess_image(image)
+             print(f"(Worker) Error loading image: {e}"); return f"Error loading image: {e}", None
+     else: image = image_input

-     # Match the data type to what the model expects (normally float32)
-     input_name = session.get_inputs()[0].name
-     expected_type = session.get_inputs()[0].type
-     if expected_type == 'tensor(float16)':
-         input_data = input_data.astype(np.float16)
-     else:
-         input_data = input_data.astype(np.float32) # Default to float32
-
-     # Inference (using the created session)
-     start_time = time.time()
-     outputs = session.run(None, {input_name: input_data})[0]
-     inference_time = time.time() - start_time
-     print(f"(Worker) Inference completed in {inference_time:.3f} seconds")
-
-     # Convert to probabilities with the sigmoid function
-     probs = 1 / (1 + np.exp(-outputs[0])) # Apply sigmoid to the first batch item
+     original_pil_image, input_tensor = preprocess_image(image)
+     input_tensor = input_tensor.to(device)
+     try:
+         print("(Worker) Running inference...")
+         start_time = time.time()
+         with torch.no_grad(): outputs = model(input_tensor)
+         inference_time = time.time() - start_time
+         print(f"(Worker) Inference completed in {inference_time:.3f} seconds")
+         probs = torch.sigmoid(outputs)[0].cpu().numpy()
+     except Exception as e:
+         print(f"(Worker) Error during PyTorch inference: {e}"); import traceback; print(traceback.format_exc()); return f"Error during inference: {e}", None

-     # Get the tags
      predictions = get_tags(probs, labels_data, gen_threshold, char_threshold)
-
-     # Format the tags
      output_tags = []
-     # Add Rating and Quality first
-     if predictions["rating"]:
-         output_tags.append(predictions["rating"][0][0].replace("_", " "))
-     if predictions["quality"]:
-         output_tags.append(predictions["quality"][0][0].replace("_", " "))
-
-     # Add the remaining categories in alphabetical order (optional)
+     if predictions.get("rating"): output_tags.append(predictions["rating"][0][0].replace("_", " "))
+     if predictions.get("quality"): output_tags.append(predictions["quality"][0][0].replace("_", " "))
      for category in ["artist", "character", "copyright", "general", "meta"]:
-         tags = [tag.replace("_", " ") for tag, prob in predictions[category]
-                 if not (category == "meta" and any(p in tag.lower() for p in ['id', 'commentary','mismatch']))] # Meta tag filtering
+         tags = [tag.replace("_", " ") for tag, prob in predictions.get(category, [])
+                 if not (category == "meta" and any(p in tag.lower() for p in ['id', 'commentary','mismatch']))]
          output_tags.extend(tags)
-
      output_text = ", ".join(output_tags)

-     if output_mode == "Tags Only":
-         return output_text, None
-     else: # Visualization
-         viz_image = visualize_predictions(original_pil_image, predictions, gen_threshold)
-         return output_text, viz_image
+     if output_mode == "Tags Only": return output_text, None
+     else: viz_image = visualize_predictions(original_pil_image, predictions, gen_threshold); return output_text, viz_image

  # --- Gradio Interface Definition ---
-
- # CSS for styling
  css = """
  .gradio-container { font-family: 'IBM Plex Sans', sans-serif; }
  footer { display: none !important; }
  .gr-prose { max-width: 100% !important; }
  """
- # Custom JS for image pasting and URL handling
  js = """
  async function paste_image(blob, gen_thresh, char_thresh, out_mode) {
      const data = await fetch(blob)
@@ -530,17 +528,14 @@ document.addEventListener('paste', paste_update);
  """

  with gr.Blocks(css=css, js=js) as demo:
-     gr.Markdown("# WD EVA02 LoRA ONNX Tagger")
-     gr.Markdown("Upload an image or paste an image URL to predict tags using the fine-tuned WD EVA02 Tagger model (ONNX format).")
+     gr.Markdown("# WD EVA02 LoRA PyTorch Tagger")
+     gr.Markdown("Upload an image or paste an image URL to predict tags using the fine-tuned WD EVA02 Tagger model (PyTorch/Safetensors).")
      gr.Markdown(f"Model Repository: [{REPO_ID}](https://huggingface.co/{REPO_ID})")

      with gr.Row():
          with gr.Column(scale=1):
-             # Use elem_id for JS targeting
              image_input = gr.Image(type="pil", label="Input Image", elem_id="input-image")
-             # Container for URL paste message
              gr.HTML("<div id='url-input-container'></div>")
-
              gen_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.55, label="General Tag Threshold")
              char_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.60, label="Character/Copyright/Artist Tag Threshold")
              output_mode = gr.Radio(choices=["Tags Only", "Tags + Visualization"], value="Tags + Visualization", label="Output Mode")
@@ -550,7 +545,6 @@ with gr.Blocks(css=css, js=js) as demo:
              output_tags = gr.Textbox(label="Predicted Tags", lines=10)
              output_visualization = gr.Image(type="pil", label="Prediction Visualization")

-     # Examples
      gr.Examples(
          examples=[
              ["https://pbs.twimg.com/media/GXBXsRvbQAAg1kp.jpg", 0.55, 0.5, "Tags + Visualization"],
@@ -561,7 +555,7 @@ with gr.Blocks(css=css, js=js) as demo:
          inputs=[image_input, gen_threshold, char_threshold, output_mode],
          outputs=[output_tags, output_visualization],
          fn=predict,
-         cache_examples=False # Slows down startup if True and large examples
+         cache_examples=False
      )

      predict_button.click(
@@ -570,18 +564,7 @@ with gr.Blocks(css=css, js=js) as demo:
          outputs=[output_tags, output_visualization]
      )

-     # Add listener for image input changes (e.g., from pasting)
-     # This might trigger prediction automatically or require the button click
-     # image_input.change(
-     #     fn=predict,
-     #     inputs=[image_input, gen_threshold, char_threshold, output_mode],
-     #     outputs=[output_tags, output_visualization]
-     # )
-
-
  if __name__ == "__main__":
-     # Warn when the HF_TOKEN environment variable is not set (for private repositories)
      if not os.environ.get("HF_TOKEN"):
-         print("Warning: HF_TOKEN environment variable not set. Downloads from private repositories may fail.")
-     # initialize_model() # Removed startup initialization (model loaded in predict)
+         print("Warning: HF_TOKEN environment variable not set.")
      demo.launch(share=True)
 
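For reference, the new PyTorch inference path introduced in this diff can be exercised outside the Space with a minimal sketch along the following lines. It assumes the checkpoint keys match the timm module names directly (as the `strict=False` load above suggests) and that the classifier weights are stored under `head.weight`; the Space's own `preprocess_image` helper and the `tag_mapping.json` schema are not shown here, so timm's default eval transform and raw class indices are used instead, and `example.jpg` is a hypothetical local test image.

```python
import torch
import timm
from PIL import Image
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as safe_load_file

REPO_ID = "cella110n/cl_tagger"
SAFETENSORS_FILENAME = "lora_model_0426/checkpoint_epoch_4.safetensors"

# Download the fine-tuned checkpoint (set HF_TOKEN in the environment if the repo is private).
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=SAFETENSORS_FILENAME)
state_dict = safe_load_file(ckpt_path)

# Derive the number of classes from the classifier weights; "head.weight" is the usual
# timm key for the EVA02 classification head, but this is an assumption about the checkpoint.
num_classes = state_dict["head.weight"].shape[0]

model = timm.create_model(
    "eva02_large_patch14_448.mim_m38m_ft_in1k",
    pretrained=False,  # the checkpoint is expected to supply all weights
    num_classes=num_classes,
)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing:", missing, "unexpected:", unexpected)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()

# Use timm's own eval transform for the 448x448 input; the Space's preprocess_image
# helper may differ (e.g. padding to square instead of cropping).
data_cfg = timm.data.resolve_model_data_config(model)
transform = timm.data.create_transform(**data_cfg, is_training=False)

image = Image.open("example.jpg").convert("RGB")  # hypothetical local test image
batch = transform(image).unsqueeze(0).to(device)

with torch.no_grad():
    probs = torch.sigmoid(model(batch))[0].cpu()  # multi-label head: sigmoid, not softmax

# Print the top indices; mapping indices to tag names requires tag_mapping.json.
for idx in probs.argsort(descending=True)[:15]:
    print(int(idx), round(float(probs[idx]), 3))
```

If the printed `missing`/`unexpected` lists are non-empty, the key-adaptation step in the diff (the `adapted_state_dict` loop) is the place to adjust.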