cella110n committed on
Commit
9f60fec
·
verified ·
1 Parent(s): 9d963ab

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -10
app.py CHANGED
@@ -33,6 +33,7 @@ class LabelData:
33
  copyright: list[np.int64]
34
  meta: list[np.int64]
35
  quality: list[np.int64]
 
36
 
37
  def pil_ensure_rgb(image: Image.Image) -> Image.Image:
38
  if image.mode not in ["RGB", "RGBA"]:
@@ -71,7 +72,7 @@ def load_tag_mapping(mapping_path):
71
  raise ValueError("Unsupported tag mapping format: Expected a dictionary.")
72
 
73
  names = [None] * (max(idx_to_tag.keys()) + 1)
74
- rating, general, artist, character, copyright, meta, quality = [], [], [], [], [], [], []
75
  for idx, tag in idx_to_tag.items():
76
  if idx >= len(names): names.extend([None] * (idx - len(names) + 1))
77
  names[idx] = tag
@@ -84,9 +85,10 @@ def load_tag_mapping(mapping_path):
84
  elif category == 'Copyright': copyright.append(idx_int)
85
  elif category == 'Meta': meta.append(idx_int)
86
  elif category == 'Quality': quality.append(idx_int)
 
87
 
88
  return LabelData(names=names, rating=np.array(rating, dtype=np.int64), general=np.array(general, dtype=np.int64), artist=np.array(artist, dtype=np.int64),
89
- character=np.array(character, dtype=np.int64), copyright=np.array(copyright, dtype=np.int64), meta=np.array(meta, dtype=np.int64), quality=np.array(quality, dtype=np.int64)), idx_to_tag, tag_to_category
90
 
91
  def preprocess_image(image: Image.Image, target_size=(448, 448)):
92
  # Adapted from onnx_predict.py's version
@@ -112,7 +114,8 @@ def get_tags(probs, labels: LabelData, gen_threshold, char_threshold):
112
  "copyright": [],
113
  "artist": [],
114
  "meta": [],
115
- "quality": []
 
116
  }
117
  # Rating (select max)
118
  if len(labels.rating) > 0:
@@ -160,7 +163,9 @@ def get_tags(probs, labels: LabelData, gen_threshold, char_threshold):
160
  "character": (labels.character, char_threshold),
161
  "copyright": (labels.copyright, char_threshold),
162
  "artist": (labels.artist, char_threshold),
163
- "meta": (labels.meta, gen_threshold) # Use gen_threshold for meta as per original code
 
 
164
  }
165
  for category, (indices, threshold) in category_map.items():
166
  if len(indices) > 0:
@@ -205,7 +210,7 @@ def visualize_predictions(image: Image.Image, predictions: Dict, threshold: floa
205
  all_tags, all_probs, all_colors = [], [], []
206
  color_map = {
207
  'rating': 'red', 'character': 'blue', 'copyright': 'purple',
208
- 'artist': 'orange', 'general': 'green', 'meta': 'gray', 'quality': 'yellow'
209
  }
210
 
211
  # Aggregate tags from predictions dictionary
@@ -213,7 +218,7 @@ def visualize_predictions(image: Image.Image, predictions: Dict, threshold: floa
213
  ('rating', 'R', color_map['rating']), ('quality', 'Q', color_map['quality']),
214
  ('character', 'C', color_map['character']), ('copyright', '©', color_map['copyright']),
215
  ('artist', 'A', color_map['artist']), ('general', 'G', color_map['general']),
216
- ('meta', 'M', color_map['meta'])
217
  ]:
218
  sorted_tags = sorted(predictions.get(cat, []), key=lambda x: x[1], reverse=True)
219
  for tag, prob in sorted_tags:
@@ -279,9 +284,10 @@ MODEL_OPTIONS = {
279
  "cl_eva02_tagger_v1_250517": "cl_eva02_tagger_v1_250517/model.onnx",
280
  "cl_eva02_tagger_v1_250518": "cl_eva02_tagger_v1_250518/model.onnx",
281
  "cl_eva02_tagger_v1_250520": "cl_eva02_tagger_v1_250520/model.onnx",
282
- "cl_eva02_tagger_v1_250522": "cl_eva02_tagger_v1_250522/model.onnx"
 
283
  }
284
- DEFAULT_MODEL = "cl_eva02_tagger_v1_250522"
285
  CACHE_DIR = "./model_cache"
286
 
287
  # --- Global variables for paths (initialized at startup) ---
@@ -461,7 +467,7 @@ def predict_onnx(image_input, model_choice, gen_threshold, char_threshold, outpu
461
  if predictions.get("rating"): output_tags.append(predictions["rating"][0][0].replace("_", " "))
462
  if predictions.get("quality"): output_tags.append(predictions["quality"][0][0].replace("_", " "))
463
  # Add other categories, respecting order and filtering meta if needed
464
- for category in ["artist", "character", "copyright", "general", "meta"]:
465
  tags_in_category = predictions.get(category, [])
466
  for tag, prob in tags_in_category:
467
  # Basic meta tag filtering for text output
@@ -511,7 +517,7 @@ with gr.Blocks(css=css) as demo:
511
  label="Model Version",
512
  interactive=True
513
  )
514
- gen_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.55, label="General/Meta Tag Threshold")
515
  char_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.60, label="Character/Copyright/Artist Tag Threshold")
516
  output_mode = gr.Radio(choices=["Tags Only", "Tags + Visualization"], value="Tags + Visualization", label="Output Mode")
517
  predict_button = gr.Button("Predict", variant="primary")
 
33
  copyright: list[np.int64]
34
  meta: list[np.int64]
35
  quality: list[np.int64]
36
+ model: list[np.int64]
37
 
38
  def pil_ensure_rgb(image: Image.Image) -> Image.Image:
39
  if image.mode not in ["RGB", "RGBA"]:
 
72
  raise ValueError("Unsupported tag mapping format: Expected a dictionary.")
73
 
74
  names = [None] * (max(idx_to_tag.keys()) + 1)
75
+ rating, general, artist, character, copyright, meta, quality, model_name = [], [], [], [], [], [], [], []
76
  for idx, tag in idx_to_tag.items():
77
  if idx >= len(names): names.extend([None] * (idx - len(names) + 1))
78
  names[idx] = tag
 
85
  elif category == 'Copyright': copyright.append(idx_int)
86
  elif category == 'Meta': meta.append(idx_int)
87
  elif category == 'Quality': quality.append(idx_int)
88
+ elif category == 'Model': model_name.append(idx_int)
89
 
90
  return LabelData(names=names, rating=np.array(rating, dtype=np.int64), general=np.array(general, dtype=np.int64), artist=np.array(artist, dtype=np.int64),
91
+ character=np.array(character, dtype=np.int64), copyright=np.array(copyright, dtype=np.int64), meta=np.array(meta, dtype=np.int64), quality=np.array(quality, dtype=np.int64), model=np.array(model_name, dtype=np.int64)), idx_to_tag, tag_to_category
92
 
93
  def preprocess_image(image: Image.Image, target_size=(448, 448)):
94
  # Adapted from onnx_predict.py's version
 
114
  "copyright": [],
115
  "artist": [],
116
  "meta": [],
117
+ "quality": [],
118
+ "model": []
119
  }
120
  # Rating (select max)
121
  if len(labels.rating) > 0:
 
163
  "character": (labels.character, char_threshold),
164
  "copyright": (labels.copyright, char_threshold),
165
  "artist": (labels.artist, char_threshold),
166
+ "meta": (labels.meta, gen_threshold),
167
+ "quality": (labels.quality, gen_threshold),
168
+ "model": (labels.model, gen_threshold)
169
  }
170
  for category, (indices, threshold) in category_map.items():
171
  if len(indices) > 0:
 
210
  all_tags, all_probs, all_colors = [], [], []
211
  color_map = {
212
  'rating': 'red', 'character': 'blue', 'copyright': 'purple',
213
+ 'artist': 'orange', 'general': 'green', 'meta': 'gray', 'quality': 'yellow', 'model': 'cyan'
214
  }
215
 
216
  # Aggregate tags from predictions dictionary
 
218
  ('rating', 'R', color_map['rating']), ('quality', 'Q', color_map['quality']),
219
  ('character', 'C', color_map['character']), ('copyright', '©', color_map['copyright']),
220
  ('artist', 'A', color_map['artist']), ('general', 'G', color_map['general']),
221
+ ('meta', 'M', color_map['meta']), ('model', 'M', color_map['model'])
222
  ]:
223
  sorted_tags = sorted(predictions.get(cat, []), key=lambda x: x[1], reverse=True)
224
  for tag, prob in sorted_tags:
 
284
  "cl_eva02_tagger_v1_250517": "cl_eva02_tagger_v1_250517/model.onnx",
285
  "cl_eva02_tagger_v1_250518": "cl_eva02_tagger_v1_250518/model.onnx",
286
  "cl_eva02_tagger_v1_250520": "cl_eva02_tagger_v1_250520/model.onnx",
287
+ "cl_eva02_tagger_v1_250522": "cl_eva02_tagger_v1_250522/model.onnx",
288
+ "cl_eva02_tagger_v1_250523": "cl_eva02_tagger_v1_250523/model.onnx"
289
  }
290
+ DEFAULT_MODEL = "cl_eva02_tagger_v1_250523"
291
  CACHE_DIR = "./model_cache"
292
 
293
  # --- Global variables for paths (initialized at startup) ---
 
467
  if predictions.get("rating"): output_tags.append(predictions["rating"][0][0].replace("_", " "))
468
  if predictions.get("quality"): output_tags.append(predictions["quality"][0][0].replace("_", " "))
469
  # Add other categories, respecting order and filtering meta if needed
470
+ for category in ["artist", "character", "copyright", "general", "meta", "model"]:
471
  tags_in_category = predictions.get(category, [])
472
  for tag, prob in tags_in_category:
473
  # Basic meta tag filtering for text output
 
517
  label="Model Version",
518
  interactive=True
519
  )
520
+ gen_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.55, label="General/Meta/Model Tag Threshold")
521
  char_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.60, label="Character/Copyright/Artist Tag Threshold")
522
  output_mode = gr.Radio(choices=["Tags Only", "Tags + Visualization"], value="Tags + Visualization", label="Output Mode")
523
  predict_button = gr.Button("Predict", variant="primary")