Upload app.py
Browse files
app.py
CHANGED
@@ -33,6 +33,7 @@ class LabelData:
|
|
33 |
copyright: list[np.int64]
|
34 |
meta: list[np.int64]
|
35 |
quality: list[np.int64]
|
|
|
36 |
|
37 |
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
|
38 |
if image.mode not in ["RGB", "RGBA"]:
|
@@ -71,7 +72,7 @@ def load_tag_mapping(mapping_path):
|
|
71 |
raise ValueError("Unsupported tag mapping format: Expected a dictionary.")
|
72 |
|
73 |
names = [None] * (max(idx_to_tag.keys()) + 1)
|
74 |
-
rating, general, artist, character, copyright, meta, quality = [], [], [], [], [], [], []
|
75 |
for idx, tag in idx_to_tag.items():
|
76 |
if idx >= len(names): names.extend([None] * (idx - len(names) + 1))
|
77 |
names[idx] = tag
|
@@ -84,9 +85,10 @@ def load_tag_mapping(mapping_path):
|
|
84 |
elif category == 'Copyright': copyright.append(idx_int)
|
85 |
elif category == 'Meta': meta.append(idx_int)
|
86 |
elif category == 'Quality': quality.append(idx_int)
|
|
|
87 |
|
88 |
return LabelData(names=names, rating=np.array(rating, dtype=np.int64), general=np.array(general, dtype=np.int64), artist=np.array(artist, dtype=np.int64),
|
89 |
-
character=np.array(character, dtype=np.int64), copyright=np.array(copyright, dtype=np.int64), meta=np.array(meta, dtype=np.int64), quality=np.array(quality, dtype=np.int64)), idx_to_tag, tag_to_category
|
90 |
|
91 |
def preprocess_image(image: Image.Image, target_size=(448, 448)):
|
92 |
# Adapted from onnx_predict.py's version
|
@@ -112,7 +114,8 @@ def get_tags(probs, labels: LabelData, gen_threshold, char_threshold):
|
|
112 |
"copyright": [],
|
113 |
"artist": [],
|
114 |
"meta": [],
|
115 |
-
"quality": []
|
|
|
116 |
}
|
117 |
# Rating (select max)
|
118 |
if len(labels.rating) > 0:
|
@@ -160,7 +163,9 @@ def get_tags(probs, labels: LabelData, gen_threshold, char_threshold):
|
|
160 |
"character": (labels.character, char_threshold),
|
161 |
"copyright": (labels.copyright, char_threshold),
|
162 |
"artist": (labels.artist, char_threshold),
|
163 |
-
"meta": (labels.meta, gen_threshold)
|
|
|
|
|
164 |
}
|
165 |
for category, (indices, threshold) in category_map.items():
|
166 |
if len(indices) > 0:
|
@@ -205,7 +210,7 @@ def visualize_predictions(image: Image.Image, predictions: Dict, threshold: floa
|
|
205 |
all_tags, all_probs, all_colors = [], [], []
|
206 |
color_map = {
|
207 |
'rating': 'red', 'character': 'blue', 'copyright': 'purple',
|
208 |
-
'artist': 'orange', 'general': 'green', 'meta': 'gray', 'quality': 'yellow'
|
209 |
}
|
210 |
|
211 |
# Aggregate tags from predictions dictionary
|
@@ -213,7 +218,7 @@ def visualize_predictions(image: Image.Image, predictions: Dict, threshold: floa
|
|
213 |
('rating', 'R', color_map['rating']), ('quality', 'Q', color_map['quality']),
|
214 |
('character', 'C', color_map['character']), ('copyright', '©', color_map['copyright']),
|
215 |
('artist', 'A', color_map['artist']), ('general', 'G', color_map['general']),
|
216 |
-
('meta', 'M', color_map['meta'])
|
217 |
]:
|
218 |
sorted_tags = sorted(predictions.get(cat, []), key=lambda x: x[1], reverse=True)
|
219 |
for tag, prob in sorted_tags:
|
@@ -279,9 +284,10 @@ MODEL_OPTIONS = {
|
|
279 |
"cl_eva02_tagger_v1_250517": "cl_eva02_tagger_v1_250517/model.onnx",
|
280 |
"cl_eva02_tagger_v1_250518": "cl_eva02_tagger_v1_250518/model.onnx",
|
281 |
"cl_eva02_tagger_v1_250520": "cl_eva02_tagger_v1_250520/model.onnx",
|
282 |
-
"cl_eva02_tagger_v1_250522": "cl_eva02_tagger_v1_250522/model.onnx"
|
|
|
283 |
}
|
284 |
-
DEFAULT_MODEL = "
|
285 |
CACHE_DIR = "./model_cache"
|
286 |
|
287 |
# --- Global variables for paths (initialized at startup) ---
|
@@ -461,7 +467,7 @@ def predict_onnx(image_input, model_choice, gen_threshold, char_threshold, outpu
|
|
461 |
if predictions.get("rating"): output_tags.append(predictions["rating"][0][0].replace("_", " "))
|
462 |
if predictions.get("quality"): output_tags.append(predictions["quality"][0][0].replace("_", " "))
|
463 |
# Add other categories, respecting order and filtering meta if needed
|
464 |
-
for category in ["artist", "character", "copyright", "general", "meta"]:
|
465 |
tags_in_category = predictions.get(category, [])
|
466 |
for tag, prob in tags_in_category:
|
467 |
# Basic meta tag filtering for text output
|
@@ -511,7 +517,7 @@ with gr.Blocks(css=css) as demo:
|
|
511 |
label="Model Version",
|
512 |
interactive=True
|
513 |
)
|
514 |
-
gen_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.55, label="General/Meta Tag Threshold")
|
515 |
char_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.60, label="Character/Copyright/Artist Tag Threshold")
|
516 |
output_mode = gr.Radio(choices=["Tags Only", "Tags + Visualization"], value="Tags + Visualization", label="Output Mode")
|
517 |
predict_button = gr.Button("Predict", variant="primary")
|
|
|
33 |
copyright: list[np.int64]
|
34 |
meta: list[np.int64]
|
35 |
quality: list[np.int64]
|
36 |
+
model: list[np.int64]
|
37 |
|
38 |
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
|
39 |
if image.mode not in ["RGB", "RGBA"]:
|
|
|
72 |
raise ValueError("Unsupported tag mapping format: Expected a dictionary.")
|
73 |
|
74 |
names = [None] * (max(idx_to_tag.keys()) + 1)
|
75 |
+
rating, general, artist, character, copyright, meta, quality, model_name = [], [], [], [], [], [], [], []
|
76 |
for idx, tag in idx_to_tag.items():
|
77 |
if idx >= len(names): names.extend([None] * (idx - len(names) + 1))
|
78 |
names[idx] = tag
|
|
|
85 |
elif category == 'Copyright': copyright.append(idx_int)
|
86 |
elif category == 'Meta': meta.append(idx_int)
|
87 |
elif category == 'Quality': quality.append(idx_int)
|
88 |
+
elif category == 'Model': model_name.append(idx_int)
|
89 |
|
90 |
return LabelData(names=names, rating=np.array(rating, dtype=np.int64), general=np.array(general, dtype=np.int64), artist=np.array(artist, dtype=np.int64),
|
91 |
+
character=np.array(character, dtype=np.int64), copyright=np.array(copyright, dtype=np.int64), meta=np.array(meta, dtype=np.int64), quality=np.array(quality, dtype=np.int64), model=np.array(model_name, dtype=np.int64)), idx_to_tag, tag_to_category
|
92 |
|
93 |
def preprocess_image(image: Image.Image, target_size=(448, 448)):
|
94 |
# Adapted from onnx_predict.py's version
|
|
|
114 |
"copyright": [],
|
115 |
"artist": [],
|
116 |
"meta": [],
|
117 |
+
"quality": [],
|
118 |
+
"model": []
|
119 |
}
|
120 |
# Rating (select max)
|
121 |
if len(labels.rating) > 0:
|
|
|
163 |
"character": (labels.character, char_threshold),
|
164 |
"copyright": (labels.copyright, char_threshold),
|
165 |
"artist": (labels.artist, char_threshold),
|
166 |
+
"meta": (labels.meta, gen_threshold),
|
167 |
+
"quality": (labels.quality, gen_threshold),
|
168 |
+
"model": (labels.model, gen_threshold)
|
169 |
}
|
170 |
for category, (indices, threshold) in category_map.items():
|
171 |
if len(indices) > 0:
|
|
|
210 |
all_tags, all_probs, all_colors = [], [], []
|
211 |
color_map = {
|
212 |
'rating': 'red', 'character': 'blue', 'copyright': 'purple',
|
213 |
+
'artist': 'orange', 'general': 'green', 'meta': 'gray', 'quality': 'yellow', 'model': 'cyan'
|
214 |
}
|
215 |
|
216 |
# Aggregate tags from predictions dictionary
|
|
|
218 |
('rating', 'R', color_map['rating']), ('quality', 'Q', color_map['quality']),
|
219 |
('character', 'C', color_map['character']), ('copyright', '©', color_map['copyright']),
|
220 |
('artist', 'A', color_map['artist']), ('general', 'G', color_map['general']),
|
221 |
+
('meta', 'M', color_map['meta']), ('model', 'M', color_map['model'])
|
222 |
]:
|
223 |
sorted_tags = sorted(predictions.get(cat, []), key=lambda x: x[1], reverse=True)
|
224 |
for tag, prob in sorted_tags:
|
|
|
284 |
"cl_eva02_tagger_v1_250517": "cl_eva02_tagger_v1_250517/model.onnx",
|
285 |
"cl_eva02_tagger_v1_250518": "cl_eva02_tagger_v1_250518/model.onnx",
|
286 |
"cl_eva02_tagger_v1_250520": "cl_eva02_tagger_v1_250520/model.onnx",
|
287 |
+
"cl_eva02_tagger_v1_250522": "cl_eva02_tagger_v1_250522/model.onnx",
|
288 |
+
"cl_eva02_tagger_v1_250523": "cl_eva02_tagger_v1_250523/model.onnx"
|
289 |
}
|
290 |
+
DEFAULT_MODEL = "cl_eva02_tagger_v1_250523"
|
291 |
CACHE_DIR = "./model_cache"
|
292 |
|
293 |
# --- Global variables for paths (initialized at startup) ---
|
|
|
467 |
if predictions.get("rating"): output_tags.append(predictions["rating"][0][0].replace("_", " "))
|
468 |
if predictions.get("quality"): output_tags.append(predictions["quality"][0][0].replace("_", " "))
|
469 |
# Add other categories, respecting order and filtering meta if needed
|
470 |
+
for category in ["artist", "character", "copyright", "general", "meta", "model"]:
|
471 |
tags_in_category = predictions.get(category, [])
|
472 |
for tag, prob in tags_in_category:
|
473 |
# Basic meta tag filtering for text output
|
|
|
517 |
label="Model Version",
|
518 |
interactive=True
|
519 |
)
|
520 |
+
gen_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.55, label="General/Meta/Model Tag Threshold")
|
521 |
char_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.60, label="Character/Copyright/Artist Tag Threshold")
|
522 |
output_mode = gr.Radio(choices=["Tags Only", "Tags + Visualization"], value="Tags + Visualization", label="Output Mode")
|
523 |
predict_button = gr.Button("Predict", variant="primary")
|