cella110n committed
Commit d9eba7a · verified · 1 Parent(s): 7997743

Upload app.py

Files changed (1): app.py (+74, -86)
app.py CHANGED
@@ -1,5 +1,5 @@
 import gradio as gr
-import spaces
+# import spaces # Removed
 import onnxruntime as ort
 import numpy as np
 from PIL import Image, ImageDraw, ImageFont
@@ -12,6 +12,8 @@ import matplotlib
 from huggingface_hub import hf_hub_download
 from dataclasses import dataclass
 from typing import List, Dict, Optional, Tuple
+import time
+import spaces

 # Set the Matplotlib backend to Agg (for environments without a GUI)
 matplotlib.use('Agg')
@@ -293,7 +295,8 @@ TAG_MAPPING_FILENAME = "cl_eva02_tagger_v1_250426/tag_mapping.json"
 CACHE_DIR = "./model_cache"

 # Global variables (cache for the model and labels)
-onnx_session = None
+# onnx_session = None # Removed global session
+model_path_global = None # Store model path globally
 labels_data = None
 tag_to_category_map = None

@@ -319,111 +322,98 @@ def download_model_files():


 def initialize_model():
-    """Initialize the model and label data (cached)"""
-    global onnx_session, labels_data, tag_to_category_map
-    if onnx_session is None:
+    """Prepare the model files and label data (cached)"""
+    global model_path_global, labels_data, tag_to_category_map
+    # Only initialize once
+    if labels_data is None:
+        print("Downloading model files...") # Moved print here
         model_path, tag_mapping_path = download_model_files()
-        print("Loading model and labels...")
-
-        # --- Added Logging ---
-        print("--- Environment Check ---")
-        try:
-            import torch
-            print(f"PyTorch version: {torch.__version__}")
-            if torch.cuda.is_available():
-                print(f"PyTorch CUDA available: True")
-                print(f"PyTorch CUDA version: {torch.version.cuda}")
-                print(f"Detected GPU: {torch.cuda.get_device_name(0)}")
-                if torch.backends.cudnn.is_available():
-                    print(f"PyTorch cuDNN available: True")
-                    print(f"PyTorch cuDNN version: {torch.backends.cudnn.version()}")
-                else:
-                    print("PyTorch cuDNN available: False")
-            else:
-                print("PyTorch CUDA available: False")
-        except ImportError:
-            print("PyTorch not found.")
-        except Exception as e:
-            print(f"Error during PyTorch check: {e}")
-
-        try:
-            print(f"ONNX Runtime build info: {ort.get_buildinfo()}")
-        except Exception as e:
-            print(f"Error getting ONNX Runtime build info: {e}")
-        print("-------------------------")
-        # --- End Added Logging ---
-
-        # Initialize the ONNX session (GPU preferred)
-        available_providers = ort.get_available_providers()
-        print(f"Available ONNX Runtime providers: {available_providers}")
-        providers = []
-        if 'CUDAExecutionProvider' in available_providers:
-            providers.append('CUDAExecutionProvider')
-        # elif 'DmlExecutionProvider' in available_providers: # DirectML (Windows)
-        #     providers.append('DmlExecutionProvider')
-        providers.append('CPUExecutionProvider') # Always include CPU as fallback
-
-        try:
-            onnx_session = ort.InferenceSession(model_path, providers=providers)
-            print(f"Using ONNX Runtime provider: {onnx_session.get_providers()[0]}")
-        except Exception as e:
-            print(f"Error initializing ONNX session with providers {providers}: {e}")
-            print("Falling back to CPUExecutionProvider only.")
-            onnx_session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
-
+        model_path_global = model_path # Store the path
+        print("Loading labels...")
         labels_data, _, tag_to_category_map = load_tag_mapping(tag_mapping_path)
-        print("Model and labels loaded.")
+        print("Labels loaded.")
+        # --- Removed ONNX Session Initialization ---

 @spaces.GPU()
 def predict(image_input, gen_threshold, char_threshold, output_mode):
-    print("--- predict function started ---") # Add log here
-    """Prediction function for the Gradio interface"""
-    initialize_model() # Load the model if it has not been loaded yet
+    print("--- predict function started (GPU worker) ---")
+    """Prediction function for the Gradio interface (inside the GPU worker)"""
+    initialize_model() # Ensure files/labels are ready
+
+    # --- Create ONNX session inside the GPU function ---
+    print("Creating ONNX session for prediction...")
+    global model_path_global # Access the global model path
+    if model_path_global is None:
+        # Attempt initialization again if model path is missing (e.g., after restart)
+        initialize_model()
+        if model_path_global is None:
+            return "Error: Model path could not be initialized.", None
+
+    available_providers = ort.get_available_providers()
+    print(f"(Worker) Available ONNX Runtime providers: {available_providers}")
+    providers = []
+    if 'CUDAExecutionProvider' in available_providers:
+        providers.append('CUDAExecutionProvider')
+    providers.append('CPUExecutionProvider') # Always include CPU as fallback
+
+    try:
+        # Create session with GPU preference inside the worker
+        session = ort.InferenceSession(model_path_global, providers=providers)
+        print(f"(Worker) Using ONNX Runtime provider: {session.get_providers()[0]}")
+    except Exception as e:
+        print(f"(Worker) Error initializing ONNX session with providers {providers}: {e}")
+        # Fallback explicitly to CPU if GPU fails inside worker
+        try:
+            print("(Worker) Falling back to CPUExecutionProvider only.")
+            session = ort.InferenceSession(model_path_global, providers=['CPUExecutionProvider'])
+        except Exception as e_cpu:
+            print(f"(Worker) Error initializing ONNX session even with CPU: {e_cpu}")
+            return f"Error initializing ONNX session: {e_cpu}", None
+    # --- Session created ---

     if image_input is None:
         return "Please upload an image.", None

-    print(f"Processing image with thresholds: gen={gen_threshold}, char={char_threshold}")
+    print(f"(Worker) Processing image with thresholds: gen={gen_threshold}, char={char_threshold}")

     # Make sure the input is a PIL Image object
     if not isinstance(image_input, Image.Image):
-        try:
-            # If the input is a URL
-            if isinstance(image_input, str) and image_input.startswith("http"):
-                response = requests.get(image_input)
-                response.raise_for_status()
-                image = Image.open(io.BytesIO(response.content))
-            # If it is a file path (does not normally happen with Gradio, but just in case)
-            elif isinstance(image_input, str) and os.path.exists(image_input):
-                image = Image.open(image_input)
-            # If it is a numpy array (input from the Gradio Image component)
-            elif isinstance(image_input, np.ndarray):
-                image = Image.fromarray(image_input)
-            else:
-                raise ValueError("Unsupported image input type")
-        except Exception as e:
-            print(f"Error loading image: {e}")
-            return f"Error loading image: {e}", None
+        try:
+            # If the input is a URL
+            if isinstance(image_input, str) and image_input.startswith("http"):
+                response = requests.get(image_input)
+                response.raise_for_status()
+                image = Image.open(io.BytesIO(response.content))
+            # If it is a file path (does not normally happen with Gradio, but just in case)
+            elif isinstance(image_input, str) and os.path.exists(image_input):
+                image = Image.open(image_input)
+            # If it is a numpy array (input from the Gradio Image component)
+            elif isinstance(image_input, np.ndarray):
+                image = Image.fromarray(image_input)
+            else:
+                raise ValueError("Unsupported image input type")
+        except Exception as e:
+            print(f"(Worker) Error loading image: {e}")
+            return f"Error loading image: {e}", None
     else:
         image = image_input

-
     # Preprocessing
     original_pil_image, input_data = preprocess_image(image)

     # Match the dtype the model expects (usually float32)
-    input_name = onnx_session.get_inputs()[0].name
-    expected_type = onnx_session.get_inputs()[0].type
+    input_name = session.get_inputs()[0].name
+    expected_type = session.get_inputs()[0].type
     if expected_type == 'tensor(float16)':
         input_data = input_data.astype(np.float16)
     else:
         input_data = input_data.astype(np.float32) # Default to float32

-    # Inference
+    # Inference (using the session created above)
     start_time = time.time()
-    outputs = onnx_session.run(None, {input_name: input_data})[0]
+    outputs = session.run(None, {input_name: input_data})[0]
     inference_time = time.time() - start_time
-    print(f"Inference completed in {inference_time:.3f} seconds")
+    print(f"(Worker) Inference completed in {inference_time:.3f} seconds")

     # Convert logits to probabilities with the sigmoid function
     probs = 1 / (1 + np.exp(-outputs[0])) # Apply sigmoid to the first batch item
@@ -437,12 +427,12 @@ def predict(image_input, gen_threshold, char_threshold, output_mode):
     if predictions["rating"]:
         output_tags.append(predictions["rating"][0][0].replace("_", " "))
     if predictions["quality"]:
-        output_tags.append(predictions["quality"][0][0].replace("_", " "))
+        output_tags.append(predictions["quality"][0][0].replace("_", " "))

     # Add the remaining categories in alphabetical order (optional)
     for category in ["artist", "character", "copyright", "general", "meta"]:
         tags = [tag.replace("_", " ") for tag, prob in predictions[category]
-                if not (category == "meta" and any(p in tag.lower() for p in ['id', 'commentary','mismatch']))] # Meta tag filtering
+                if not (category == "meta" and any(p in tag.lower() for p in ['id', 'commentary','mismatch']))] # Meta tag filtering
         output_tags.extend(tags)

     output_text = ", ".join(output_tags)
@@ -454,7 +444,6 @@ def predict(image_input, gen_threshold, char_threshold, output_mode):
     return output_text, viz_image

 # --- Gradio Interface Definition ---
-import time

 # CSS for styling
 css = """
@@ -594,6 +583,5 @@ if __name__ == "__main__":
     # Warn if the HF_TOKEN environment variable is not set (for private repositories)
     if not os.environ.get("HF_TOKEN"):
         print("Warning: HF_TOKEN environment variable not set. Downloads from private repositories may fail.")
-    # Initialize model on startup to avoid delay on first prediction
-    initialize_model() # Removed startup initialization
+    # initialize_model() # Removed startup initialization (model loaded in predict)
     demo.launch(share=True)
 
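What the change amounts to: on ZeroGPU Spaces, the GPU is only attached while a `@spaces.GPU()`-decorated function is running, so a CUDA-backed `ort.InferenceSession` created at module import time never sees the device. Below is a minimal sketch of that pattern in isolation, not the app's code; `MODEL_PATH` is a hypothetical placeholder for the .onnx file the app downloads via `hf_hub_download`.

# Minimal sketch (assumptions: a ZeroGPU Space with onnxruntime-gpu installed;
# MODEL_PATH is a hypothetical placeholder for a real .onnx file).
import numpy as np
import onnxruntime as ort
import spaces

MODEL_PATH = "model.onnx"  # hypothetical path

@spaces.GPU()
def run_onnx(input_data: np.ndarray) -> np.ndarray:
    # On ZeroGPU the CUDA device exists only inside this decorated call,
    # so the session must be created here rather than at import time.
    providers = []
    if "CUDAExecutionProvider" in ort.get_available_providers():
        providers.append("CUDAExecutionProvider")
    providers.append("CPUExecutionProvider")  # CPU fallback keeps the call usable

    session = ort.InferenceSession(MODEL_PATH, providers=providers)
    input_name = session.get_inputs()[0].name
    return session.run(None, {input_name: input_data.astype(np.float32)})[0]

Rebuilding the session on every request adds per-call setup cost; the commit appears to accept that trade-off, since a module-level session cannot carry GPU state across ZeroGPU invocations.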