Spaces:
Running
Running
import os.path | |
import datetime | |
import io | |
import PIL | |
import requests | |
from datasets import load_dataset, concatenate_datasets, Image | |
from data.lang2eng_map import lang2eng_mapping | |
from data.words_map import words_mapping | |
import gradio as gr | |
import bcrypt | |
from config.settings import HF_API_TOKEN | |
from huggingface_hub import snapshot_download | |
# from .blur import blur_faces, detect_faces | |
from retinaface import RetinaFace | |
from gradio_modal import Modal | |
import numpy as np | |
import cv2 | |
import time | |
import re | |
import os | |
import glob | |
def update_image(image_url):
    """Download the image at ``image_url`` for the image input.

    Always returns a 2-tuple ``(image_output, modal_update)``:

    * success: the downloaded picture converted to RGB (``PIL.Image.Image``)
      and a hidden modal;
    * empty or ``None`` URL: a reset image component and a hidden modal;
    * any failure (HTTP/network error, non-image ``Content-Type``, undecodable
      bytes): an emptied image component and a visible modal.
    """
    # ``import PIL`` at module level does not by itself load the ``PIL.Image``
    # submodule; import it explicitly instead of relying on other imports.
    import PIL.Image

    if image_url is None or image_url == "":
        return gr.Image(label="Image", elem_id="image_inp"), Modal(visible=False)

    try:
        headers = {"User-Agent": "Mozilla/5.0"}
        response = requests.get(image_url, headers=headers, timeout=10)
        response.raise_for_status()
        content_type = response.headers.get("Content-Type", "")
        if "image" not in content_type:
            # Notify the user and return the same two outputs as the other
            # failure path (Gradio expects two values from this handler).
            gr.Warning("⚠️ URL does not point to a valid image.", duration=5)
            return gr.Image(label="Image", value=None, elem_id="image_inp"), Modal(visible=True)
        img = PIL.Image.open(io.BytesIO(response.content))
        return img.convert("RGB"), Modal(visible=False)
    except Exception as e:
        # Any download/decoding failure empties the image and shows the modal.
        print(f"Image URL load error: {e}")
        return gr.Image(label="Image", value=None, elem_id="image_inp"), Modal(visible=True)
def update_timestamp():
    """Build a hidden "Timestamp" textbox whose value is the current POSIX timestamp."""
    current_ts = datetime.datetime.now().timestamp()
    return gr.Textbox(current_ts, label="Timestamp", visible=False)
def clear_data():
    """Reset the submission form.

    Returns 11 values: five ``None`` resets, one ``gr.update(value=None)`` and
    five ``gr.update(value=[])`` resets (presumably the five per-category
    concept selectors).
    """
    blank_values = [None] * 5
    single_reset = gr.update(value=None)
    multi_resets = [gr.update(value=[]) for _ in range(5)]
    return (*blank_values, single_reset, *multi_resets)
def exit():
    """Reset the UI state (presumably wired to the Exit button).

    Returns 16 values, in order: three ``None`` resets, an empty examples
    ``gr.Dataset``, a "loading" Markdown, two ``value=None`` updates, the list
    ``[None, None, "", ""]`` and eight more ``value=None`` updates.
    """
    def _cleared():
        return gr.update(value=None)

    leading_blanks = (None, None, None)
    return (
        *leading_blanks,
        gr.Dataset(samples=[]),
        gr.Markdown("**Loading your data, please wait ...**"),
        _cleared(),
        _cleared(),
        [None, None, "", ""],
        *[_cleared() for _ in range(8)],
    )
def validate_inputs(image, ori_img):  # is_blurred
    """Enable the Submit button once an image is present and seed the original-image state.

    Args:
        image: Current image value from the image input, or ``None`` when empty.
        ori_img: Stored original-image state, or ``None`` if not captured yet.

    Returns:
        ``(submit_button, image_to_display, ori_img_state)``. With no image the
        button is disabled and both other outputs are ``None``.
    """
    if image is None:
        return gr.Button("Submit", variant="primary", interactive=False), None, None

    # NOTE: images are intentionally kept at full resolution; downscaling to a
    # 1024x1024 bound was considered here and is disabled for now.
    result_image = image
    if ori_img is None:
        # Snapshot a copy of the first image so later in-place edits of the
        # displayed image cannot alter it (presumably restored by unhide_faces).
        ori_img = gr.State(result_image.copy())
    return gr.Button("Submit", variant="primary", interactive=True), result_image, ori_img
def add_prefix(example, column_name, prefix):
    """Prepend ``prefix`` and a "/" separator to ``example[column_name]``.

    Mutates ``example`` in place and returns it, the shape expected by a
    ``Dataset.map`` callback.
    """
    relative_value = example[column_name]
    example[column_name] = str(prefix) + "/" + relative_value
    return example
def update_user_data(username, password, country, language_choice, HF_DATASET_NAME, local_ds_directory_path):
    """Load a user's saved submissions and turn them into example rows.

    Steps:
      1. Best-effort ``snapshot_download`` of the user's files from the Hub
         dataset ``HF_DATASET_NAME`` into ``local_ds_directory_path``.
      2. If JSON records exist under
         ``<local_ds_directory_path>/logged_in_users/<country>/<language_choice>/<username>/``,
         load them and replace ``image_file`` with the ``image`` path, prefixed
         with the local directory and cast to an ``Image`` feature.
      3. Keep rows whose ``username`` matches and whose stored bcrypt hash
         accepts ``password``, newest first. Emit one row per example ``id``
         (its most recent version) and skip examples whose latest version is
         excluded.

    Returns:
        ``(gr.Dataset, message)``. ``message`` is a red "no data" Markdown when
        nothing could be loaded, otherwise ``None``.
    """
    datasets_list = []
    try:
        snapshot_download(
            repo_id=HF_DATASET_NAME,
            repo_type="dataset",
            local_dir=local_ds_directory_path,
            # NOTE(review): this pattern has no "logged_in_users/" prefix, while
            # the loader below reads "logged_in_users/..."; confirm the Hub layout.
            allow_patterns=f"{country}/{language_choice}/{username}/*",
            token=HF_API_TOKEN,
        )
    except Exception as e:
        # Best effort: continue with whatever is already on disk.
        print(f"Snapshot download error: {e}")

    if has_user_json(username, country, language_choice, local_ds_directory_path):
        try:
            ds_local = load_dataset(
                local_ds_directory_path,
                data_files=f'logged_in_users/{country}/{language_choice}/{username}/**/*.json',
            )
            ds_local = ds_local.remove_columns("image_file")
            ds_local = ds_local.rename_column("image", "image_file")
            ds_local = ds_local.map(add_prefix, fn_kwargs={"column_name": "image_file", "prefix": local_ds_directory_path})
            ds_local = ds_local.cast_column("image_file", Image())
            datasets_list.append(list(ds_local.values())[0])
        except Exception as e:
            print(f"Local dataset load error: {e}")

    if not datasets_list:
        return gr.Dataset(samples=[]), gr.Markdown("<p style='color: red;'>No data available for this user. Please upload an image.</p>")

    dataset = concatenate_datasets(datasets_list)
    # TODO: we should link username with password and language and country, otherwise there will be an error when loading with different language and clicking on the example
    if not (username and password):
        # TODO: should we show the entire dataset instead? What about "other data" tab?
        return gr.Dataset(samples=[]), None

    # bcrypt verification is deliberately slow. Memoize it per stored hash so
    # rows sharing a hash are verified once rather than once per row.
    password_checks = {}

    def _is_users_row(row):
        if row['username'] != username:
            return False
        stored_hash = row['password']
        if stored_hash not in password_checks:
            password_checks[stored_hash] = is_password_correct(stored_hash, password)
        return password_checks[stored_hash]

    user_dataset = dataset.filter(_is_users_row)
    user_dataset = user_dataset.sort('timestamp', reverse=True)

    seen_ids = set()
    samples = []
    for d in user_dataset:
        # Rows are newest first, so the first row seen per id is its latest
        # version. If that version is excluded, the whole example is hidden.
        if d['id'] in seen_ids:
            continue
        seen_ids.add(d['id'])
        if d['excluded']:
            continue
        additional_concepts_by_category = [
            d.get(f"category_{i}_concepts", [""]) for i in range(1, 6)
        ]
        samples.append(
            [
                d['image_file'], d['image_url'], d['caption'] or "", d['country'],
                d['language'], d['category'], d['concept'], additional_concepts_by_category, d['id'],
            ]
        )
    return gr.Dataset(samples=samples), None
def _read_localized_text(metadata_dict, country, language, key):
    """Return the text of the file named by ``metadata_dict[country][language][key]``.

    Falls back to the file ``metadata_dict["USA"]["English"][key]`` when the
    localized file does not exist.
    """
    fn = metadata_dict[country][language][key]
    if not os.path.exists(fn):
        fn = metadata_dict["USA"]["English"][key]
    with open(fn, "r", encoding="utf-8") as f:
        return f.read()


def update_language(local_storage, metadata_dict, concepts_dict):
    """Relabel the whole UI for the stored country/language.

    Args:
        local_storage: ``(country, language, email, password)``.
        metadata_dict: ``metadata_dict[country][language]`` maps UI keys
            (e.g. ``"Country"``, ``"Submit_btn"``) to localized labels. For
            ``"Task"`` and ``"Instructions"`` the values are text-file paths.
        concepts_dict: ``concepts_dict[country][english_language]`` maps each
            category to an iterable of concept names.

    Returns:
        29 Gradio updates: the four login fields, the category and concept
        dropdowns, the image/URL/description labels, the task and instruction
        Markdown, buttons and status texts, the face/exclude buttons, and
        finally the five alphabetically-first category dropdowns.
    """
    country, language, email, password = local_storage
    categories = concepts_dict[country][lang2eng_mapping.get(language, language)]
    translation = words_mapping[language] if language in words_mapping else None

    def _translate(category):
        # Show the localized category name when a mapping exists.
        if translation is None:
            return category
        return translation.get(category, category)

    categories_keys_translated = [_translate(cat) for cat in categories.keys()]
    # The five per-category dropdowns use the first five categories alphabetically.
    categories_list = sorted(categories.keys())[:5]
    translated_categories = [_translate(cat) for cat in categories_list]

    labels = metadata_dict[country][language]
    TASK_TEXT = _read_localized_text(metadata_dict, country, language, "Task")
    INST_TEXT = _read_localized_text(metadata_dict, country, language, "Instructions")

    # Exactly five dropdown updates are emitted; fewer than five categories
    # raises IndexError here.
    category_dropdowns = [
        gr.update(label=translated_categories[i], choices=sorted(categories[categories_list[i]]))
        for i in range(5)
    ]

    return (
        gr.update(label=labels["Country"], value=country),
        gr.update(label=labels["Language"], value=language),
        gr.update(label=labels["Email"], value=email),
        gr.update(label=labels["Password"], value=password),
        gr.update(choices=categories_keys_translated, interactive=True, label=labels["Category"], allow_custom_value=False, elem_id="category_btn"),
        gr.update(choices=[], interactive=True, label=labels["Concept"], allow_custom_value=True, elem_id="concept_btn"),
        gr.update(label=labels["Image"]),
        gr.update(label=labels["Image_URL"]),
        gr.update(label=labels["Description"]),
        gr.Markdown(TASK_TEXT),
        gr.Markdown(INST_TEXT),
        gr.update(value=labels["Instructs_btn"]),
        gr.update(value=labels["Clear_btn"]),
        gr.update(value=labels["Submit_btn"]),
        gr.Markdown(labels["Saving_text"]),
        gr.Markdown(labels["Saved_text"]),
        gr.update(label=labels["Timestamp"]),
        gr.update(value=labels["Exit_btn"]),
        gr.Markdown(labels["Browse_text"]),
        gr.Markdown(labels["Loading_msg"]),
        gr.update(value=labels.get("Hide_all_btn", "👤 Hide All Faces")),
        gr.update(value=labels.get("Hide_btn", "👤 Hide Specific Faces")),
        gr.update(value=labels.get("Unhide_btn", "👀 Unhide Faces")),
        gr.update(value=labels.get("Exclude_btn", "Exclude Selected Example")),
        *category_dropdowns,
    )
def update_intro_language(selected_country, selected_language, intro_markdown, metadata):
    """Return the intro Markdown for the selected country/language.

    Falls back to ``intro_markdown`` unchanged when no language is selected,
    when the country/language pair has no ``"Intro"`` entry in ``metadata``
    (for example, a language that does not belong to the selected country, or
    no country selected), or when the intro file does not exist.
    """
    if selected_language is None:
        return intro_markdown
    try:
        fn = metadata[selected_country][selected_language]["Intro"]
    except KeyError:
        return intro_markdown
    if not os.path.exists(fn):
        return intro_markdown
    with open(fn, "r", encoding="utf-8") as f:
        INTRO_TEXT = f.read()
    return gr.Markdown(INTRO_TEXT)
def handle_click_example(user_examples, concepts_dict):
    """Unpack a clicked example row into the values for the editing form.

    Args:
        user_examples: One sample row in the layout built by ``update_user_data``:
            ``[image, image_url, caption, country, language, category, concept,
            additional_concepts_by_category, example_id]``.
        concepts_dict: Currently unused; kept for the event-handler signature.

    Returns:
        ``[image, image_url, caption, example_id, category, concept,
        *additional_concepts_by_category, True]``. Each additional-concept
        entry that is ``None`` or the ``[""]`` placeholder becomes ``[]``.
    """
    print("handle_click_example")
    print(user_examples)
    ex = list(user_examples)
    image_inp, image_url_inp, long_caption_inp = ex[0], ex[1], ex[2]
    # ex[3] (country) and ex[4] (language) are not used yet.
    category_btn, concept_btn = ex[5], ex[6]
    exampleid_btn = ex[8]

    def _as_selection(cat_concept):
        # None (presumably a record saved without that field) and the [""]
        # placeholder both mean "nothing selected".
        if cat_concept is None:
            return []
        if len(cat_concept) == 1 and cat_concept[0] == '':
            return []
        return cat_concept

    additional_concepts_by_category = [_as_selection(c) for c in ex[7]]

    ### TODO: fix additional concepts not saving if categories in other language than English
    return [image_inp, image_url_inp, long_caption_inp, exampleid_btn, category_btn, concept_btn] + additional_concepts_by_category + [True]
def is_password_correct(hashed_password, entered_password):
    """Check ``entered_password`` against a stored bcrypt hash.

    Fails closed: returns ``False`` instead of raising when either value is
    ``None`` or the stored hash is malformed. A single bad record therefore
    cannot break callers that filter many rows with this check.
    """
    if hashed_password is None or entered_password is None:
        return False
    try:
        return bcrypt.checkpw(entered_password.encode(), hashed_password.encode())
    except ValueError:
        # bcrypt signals an unparsable stored hash (e.g. "Invalid salt") with ValueError.
        return False
## Face blurring functions
def detect_faces(image, threshold=0.8):
    """
    Detect faces in an image using RetinaFace.

    Args:
        image (numpy.ndarray): Input image. NOTE(review): documented as BGR, but
            callers pass the UI image array; confirm the channel order.
        threshold (float): Minimum detection confidence passed to RetinaFace.

    Returns:
        tuple: ``(faces, detection_time)``. ``faces`` maps detection keys (such
        as ``"face_1"``) to data whose ``'facial_area'`` is ``[x1, y1, x2, y2]``;
        it is an empty dict when nothing is detected. ``detection_time`` is the
        detection duration in seconds.
    """
    detection_start = time.time()
    faces = RetinaFace.detect_faces(image, threshold=threshold)
    detection_time = time.time() - detection_start
    # RetinaFace may return a non-dict (e.g. an empty tuple) when no face is
    # found; normalize so callers can always use dict methods.
    if not isinstance(faces, dict):
        faces = {}
    return faces, detection_time
# Hide Faces Button
def select_faces_to_hide(image, blur_faces_ids):
    """Detect faces and prepare the face-selection step of "Hide Specific Faces".

    Args:
        image: Current image array, or ``None``.
        blur_faces_ids: Current value of the face-ID selector. It is not read;
            the selector is refreshed with new choices instead.

    Returns:
        7-tuple ``(image, selection_modal, no_faces_modal, annotated_image,
        face_count_text, faces_state, face_id_selector_update)``. When faces
        are found, the first modal is shown together with a copy of the image
        annotated with boxes and "ID: n" labels. When none are found, only the
        second modal is shown.
    """
    if image is None:
        return None, Modal(visible=False), Modal(visible=False), None, "", None, gr.update(value=[])

    # Detect on, and draw onto, a copy so the caller's image is never modified.
    annotated = image.copy()
    faces, detection_time = detect_faces(annotated)
    print(f"Detection time: {detection_time:.2f} seconds")
    # RetinaFace may return a non-dict when nothing is detected.
    if not isinstance(faces, dict):
        faces = {}

    face_count = len(faces)
    if face_count == 0:
        return image, Modal(visible=False), Modal(visible=True), None, "", None, gr.update(value=[])

    # IDs are assigned 1..n in detection order; blur_selected_faces maps
    # "Face ID: n" back to the key "face_n".
    # NOTE(review): this relies on RetinaFace keys being face_1..face_n in order.
    for face_id, face_data in enumerate(faces.values(), start=1):
        x1, y1, x2, y2 = (int(v) for v in face_data['facial_area'])
        cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 0, 255), 2)
        # Keep the label on-canvas for faces touching the top edge.
        label_y = max(y1 - 10, 15)
        cv2.putText(annotated, f"ID: {face_id}", (x1, label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)

    face_id_choices = gr.update(choices=[f"Face ID: {i}" for i in range(1, face_count + 1)])
    return image, Modal(visible=True), Modal(visible=False), annotated, str(face_count), gr.State(faces), face_id_choices
def blur_selected_faces(image, blur_faces_ids, faces_info, face_img, faces_count):  # is_blurred
    """Blur the faces the user ticked in the face-selection modal.

    Args:
        image: Image array to blur. It is not modified; a blurred copy is returned.
        blur_faces_ids: Selected labels of the form ``"Face ID: n"``.
        faces_info: Detection state from ``select_faces_to_hide`` (a ``gr.State``
            wrapping a dict keyed ``"face_n"``). A bare dict is also accepted.
        face_img: Annotated preview image, echoed back while the modal stays open.
        faces_count: Face-count text, echoed back while the modal stays open.

    Returns:
        5-tuple ``(image, modal, face_img, faces_count, blur_faces_ids)``.
        If nothing is selected, no detection state is available, or a selected
        ID is unknown, the inputs are echoed back unchanged with the modal still
        visible. Otherwise the blurred copy is returned, the modal is hidden,
        and the preview, count and selection are cleared.
    """
    if not blur_faces_ids:
        return image, Modal(visible=True), face_img, faces_count, blur_faces_ids

    faces = getattr(faces_info, "value", faces_info)
    if image is None or not faces or not isinstance(faces, dict):
        return image, Modal(visible=True), face_img, faces_count, blur_faces_ids

    # Kernel-size bounds (must be odd); the kernel scales with face size.
    MIN_BLUR = 31
    MAX_BLUR = 131

    blurring_start = time.time()
    # Work on a copy so an unknown ID part-way through leaves the input untouched.
    output_image = image.copy()
    ih, iw = image.shape[:2]

    for selection in blur_faces_ids:
        face_key = f"face_{selection.split(':')[-1].strip()}"
        face_data = faces.get(face_key)
        if face_data is None:
            gr.Warning(f"⚠️ Face ID {face_key.split('_')[-1]} not found in detected faces.", duration=5)
            return image, Modal(visible=True), face_img, faces_count, blur_faces_ids

        x1, y1, x2, y2 = face_data['facial_area']
        # Scale the kernel between MIN_BLUR and MAX_BLUR relative to image width.
        face_size = max(x2 - x1, y2 - y1)
        blur_amount = int(MIN_BLUR + (MAX_BLUR - MIN_BLUR) * (face_size / iw))
        if blur_amount % 2 == 0:
            blur_amount += 1
        blur_amount = max(MIN_BLUR, min(MAX_BLUR, blur_amount))

        # Clip the box to the image. Skip boxes that end up empty, because
        # GaussianBlur cannot process an empty region.
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(iw, x2), min(ih, y2)
        if x2 <= x1 or y2 <= y1:
            continue
        face_region = output_image[y1:y2, x1:x2]
        output_image[y1:y2, x1:x2] = cv2.GaussianBlur(face_region, (blur_amount, blur_amount), 0)

    blurring_time = time.time() - blurring_start
    print("Face blurring performance metrics:")
    print(f"Face blurring time: {blurring_time:.4f} seconds")
    return output_image, Modal(visible=False), None, None, gr.update(value=[])
def blur_all_faces(image):
    """Detect every face in ``image`` and blur it ("Hide All Faces").

    Returns:
        ``(image, modal)``. For ``None`` input: ``None`` and a hidden modal.
        When no face is found: the unchanged input and a visible modal.
        Otherwise: a blurred copy and a hidden modal. The input array itself
        is never modified.
    """
    if image is None:
        return None, Modal(visible=False)

    # Kernel-size bounds (must be odd); the kernel scales with face size.
    MIN_BLUR = 31
    MAX_BLUR = 131

    start_time = time.time()
    # Reuse the shared detector so "Hide All" applies the same confidence
    # threshold as "Hide Specific Faces" instead of RetinaFace's default,
    # so both buttons see the same set of faces.
    faces, detection_time = detect_faces(image)

    output_image = image.copy()
    ih, iw = image.shape[:2]
    face_count = 0
    blurring_start = time.time()
    if faces and isinstance(faces, dict):
        for face_data in faces.values():
            face_count += 1
            x1, y1, x2, y2 = face_data['facial_area']
            # Scale the kernel between MIN_BLUR and MAX_BLUR relative to image width.
            face_size = max(x2 - x1, y2 - y1)
            blur_amount = int(MIN_BLUR + (MAX_BLUR - MIN_BLUR) * (face_size / iw))
            if blur_amount % 2 == 0:
                blur_amount += 1
            blur_amount = max(MIN_BLUR, min(MAX_BLUR, blur_amount))

            # Clip the box to the image. Skip boxes that end up empty, because
            # GaussianBlur cannot process an empty region.
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(iw, x2), min(ih, y2)
            if x2 <= x1 or y2 <= y1:
                continue
            face_region = output_image[y1:y2, x1:x2]
            output_image[y1:y2, x1:x2] = cv2.GaussianBlur(face_region, (blur_amount, blur_amount), 0)

    blurring_time = time.time() - blurring_start
    total_time = time.time() - start_time
    print("Face blurring performance metrics:")
    print(f"Total faces detected: {face_count}")
    print(f"Face detection time: {detection_time:.4f} seconds")
    print(f"Face blurring time: {blurring_time:.4f} seconds")
    print(f"Total processing time: {total_time:.4f} seconds")
    print(f"Average time per face: {(total_time/max(1, face_count)):.4f} seconds")

    if face_count == 0:
        return image, Modal(visible=True)
    return output_image, Modal(visible=False)
def unhide_faces(img, ori_img):  # is_blurred
    """Restore the original (un-blurred) image.

    Args:
        img: Currently displayed image array, or ``None``.
        ori_img: Stored original image. This can be a wrapper exposing
            ``.value`` (as created by ``validate_inputs``), a bare array, or ``None``.

    Returns:
        ``None`` when ``img`` is ``None``. ``img`` itself when no original is
        stored or it already equals the original. Otherwise the stored original.
    """
    if img is None:
        return None
    original = getattr(ori_img, "value", ori_img)
    # Never replace a displayed image with "nothing" when no original was captured.
    if original is None or np.array_equal(img, original):
        return img
    return original
def check_exclude_fn(image):
    """Show the exclude control only when an image is loaded; warn otherwise."""
    has_image = image is not None
    if not has_image:
        gr.Warning("⚠️ No image to exclude.")
    return gr.update(visible=has_image)
def has_user_json(username, country, language_choice, local_ds_directory_path):
    """Return True if any ``*.json`` file exists for this user.

    Searches ``<local_ds_directory_path>/logged_in_users/<country>/<language_choice>/<username>/``
    recursively, including the user directory itself. Each caller-supplied path
    component is glob-escaped, so characters such as ``[``, ``*`` or ``?`` in a
    username or directory name are matched literally rather than as wildcards.
    """
    user_dir = os.path.join(
        glob.escape(local_ds_directory_path),
        "logged_in_users",
        glob.escape(country),
        glob.escape(language_choice),
        glob.escape(username),
    )
    pattern = os.path.join(user_dir, "**", "*.json")
    # Stop at the first match instead of listing every file.
    return next(glob.iglob(pattern, recursive=True), None) is not None