# internal-v0/logic/handlers.py
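"""Gradio event handlers for the data-collection app: loading images from URLs,
per-user dataset loading and filtering, UI language updates, example browsing,
and RetinaFace-based face detection and blurring."""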
import os.path
import datetime
import io
import PIL
import requests
from datasets import load_dataset, concatenate_datasets, Image
from data.lang2eng_map import lang2eng_mapping
from data.words_map import words_mapping
import gradio as gr
import bcrypt
from config.settings import HF_API_TOKEN
from huggingface_hub import snapshot_download
# from .blur import blur_faces, detect_faces
from retinaface import RetinaFace
from gradio_modal import Modal
import numpy as np
import cv2
import time
import re
import os
import glob


def update_image(image_url):
    try:
        headers = {"User-Agent": "Mozilla/5.0"}
        response = requests.get(image_url, headers=headers, timeout=10)
        response.raise_for_status()
        content_type = response.headers.get("Content-Type", "")
        if "image" not in content_type:
            gr.Warning("⚠️ URL does not point to a valid image.", duration=5)
            # Keep the same output shape as the other branches: reset the image and show the modal
            return gr.Image(label="Image", value=None, elem_id="image_inp"), Modal(visible=True)
        img = PIL.Image.open(io.BytesIO(response.content))
        img = img.convert("RGB")
        return img, Modal(visible=False)
    except Exception as e:
        # print(f"Error: {str(e)}")
        if image_url is None or image_url == "":
            return gr.Image(label="Image", elem_id="image_inp"), Modal(visible=False)
        else:
            return gr.Image(label="Image", value=None, elem_id="image_inp"), Modal(visible=True)


def update_timestamp():
    return gr.Textbox(datetime.datetime.now().timestamp(), label="Timestamp", visible=False)  # FIXME: visible=False


def clear_data():
    return (None, None, None, None, None, gr.update(value=None),
            gr.update(value=[]), gr.update(value=[]), gr.update(value=[]),
            gr.update(value=[]), gr.update(value=[]))


def exit():
    return (None, None, None, gr.Dataset(samples=[]), gr.Markdown("**Loading your data, please wait ...**"),
            gr.update(value=None), gr.update(value=None), [None, None, "", ""], gr.update(value=None),
            gr.update(value=None), gr.update(value=None),
            gr.update(value=None), gr.update(value=None), gr.update(value=None),
            gr.update(value=None), gr.update(value=None))


def validate_inputs(image, ori_img):  # is_blurred
    # Perform your validation logic here
    if image is None:
        return gr.Button("Submit", variant="primary", interactive=False), None, None  # False
    # Define maximum dimensions
    MAX_WIDTH = 1024
    MAX_HEIGHT = 1024
    # Get current dimensions
    height, width = image.shape[:2]
    # Check if resizing is needed
    # NOTE: for now, let's keep the full image resolution
    # if width > MAX_WIDTH or height > MAX_HEIGHT:
    #     # Calculate scaling factor
    #     scale = min(MAX_WIDTH / width, MAX_HEIGHT / height)
    #     # Calculate new dimensions
    #     new_width = int(width * scale)
    #     new_height = int(height * scale)
    #     # Resize image while maintaining aspect ratio
    #     result_image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
    # else:
    #     result_image = image
    result_image = image
    if ori_img is None:
        # If the original image is None, set it to the resized image
        ori_img = gr.State(result_image.copy())
    return gr.Button("Submit", variant="primary", interactive=True), result_image, ori_img  # is_blurred


def add_prefix(example, column_name, prefix):
    example[column_name] = f"{prefix}/" + example[column_name]
    return example
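# Usage (see update_user_data below): prepend the local dataset directory to each
# relative image path so that cast_column("image_file", Image()) can resolve the files, e.g.
#   ds = ds.map(add_prefix, fn_kwargs={"column_name": "image_file", "prefix": local_ds_directory_path})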


def update_user_data(username, password, country, language_choice, HF_DATASET_NAME, local_ds_directory_path):
    datasets_list = []
    # Try loading the local dataset (mirrored from the Hub)
    try:
        snapshot_download(
            repo_id=HF_DATASET_NAME,
            repo_type="dataset",
            local_dir=local_ds_directory_path,  # target local directory
            allow_patterns=f"{country}/{language_choice}/{username}/*",  # f"**/{username}/*"
            token=HF_API_TOKEN,
        )
    except Exception as e:
        print(f"Snapshot download error: {e}")
    # import pdb; pdb.set_trace()
    if has_user_json(username, country, language_choice, local_ds_directory_path):
        try:
            # ds_local = load_dataset(local_ds_directory_path, data_files=f'logged_in_users/**/{username}/**/*.json')  # does not filter by country and language
            ds_local = load_dataset(local_ds_directory_path, data_files=f'logged_in_users/{country}/{language_choice}/{username}/**/*.json')
            ds_local = ds_local.remove_columns("image_file")
            ds_local = ds_local.rename_column("image", "image_file")
            ds_local = ds_local.map(add_prefix, fn_kwargs={"column_name": "image_file", "prefix": local_ds_directory_path})
            ds_local = ds_local.cast_column("image_file", Image())
            datasets_list.append(list(ds_local.values())[0])
        except Exception as e:
            print(f"Local dataset load error: {e}")
    # # Try loading hub dataset
    # try:
    #     ds_hub = load_dataset(HF_DATASET_NAME, data_files=f'**/{username}/**/*.json', token=HF_API_TOKEN)
    #     ds_hub = ds_hub.cast_column("image_file", Image())
    #     datasets_list.append(list(ds_hub.values())[0])
    # except Exception as e:
    #     print(f"Hub dataset load error: {e}")
    # Handle the case where nothing could be loaded
    if not datasets_list:
        return gr.Dataset(samples=[]), gr.Markdown("<p style='color: red;'>No data available for this user. Please upload an image.</p>")
    dataset = concatenate_datasets(datasets_list)
    # TODO: we should link username with password, language and country; otherwise there will be an error
    # when loading with a different language and clicking on an example
    if username and password:
        user_dataset = dataset.filter(lambda x: x['username'] == username and is_password_correct(x['password'], password))
        user_dataset = user_dataset.sort('timestamp', reverse=True)
        # Show only unique entries (most recent first)
        user_ids = set()
        samples = []
        for d in user_dataset:
            if d['id'] in user_ids:
                continue
            user_ids.add(d['id'])
            if d['excluded']:
                continue
            # Get additional concepts by category, or an empty-string placeholder if not present
            # additional_concepts_by_category = {
            #     "category1": d.get("category_1_concepts", []),
            #     "category2": d.get("category_2_concepts", []),
            #     "category3": d.get("category_3_concepts", []),
            #     "category4": d.get("category_4_concepts", []),
            #     "category5": d.get("category_5_concepts", [])
            # }
            additional_concepts_by_category = [
                d.get("category_1_concepts", [""]),
                d.get("category_2_concepts", [""]),
                d.get("category_3_concepts", [""]),
                d.get("category_4_concepts", [""]),
                d.get("category_5_concepts", [""]),
            ]
            samples.append(
                [
                    d['image_file'], d['image_url'], d['caption'] or "", d['country'],
                    d['language'], d['category'], d['concept'], additional_concepts_by_category, d['id'],  # d['is_blurred']
                ]
            )
        return gr.Dataset(samples=samples), None
    else:
        # TODO: should we show the entire dataset instead? What about the "other data" tab?
        return gr.Dataset(samples=[]), None
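# Each per-user JSON record read above is expected to provide at least: username,
# password (bcrypt hash), timestamp, id, excluded, image / image_file, image_url,
# caption, country, language, category, concept, plus the optional
# category_1_concepts ... category_5_concepts lists.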


def update_language(local_storage, metadata_dict, concepts_dict):
    country, language, email, password = local_storage
    # my_translator = GoogleTranslator(source='english', target=metadata_dict[country][language])
    categories = concepts_dict[country][lang2eng_mapping.get(language, language)]
    if language in words_mapping:
        categories_keys_translated = [words_mapping[language].get(cat, cat) for cat in categories.keys()]
    else:
        categories_keys_translated = list(categories.keys())
    # Get the 5 categories in alphabetical order
    categories_list = sorted(list(categories.keys()))[:5]
    # Create translated labels for the 5 categories
    translated_categories = []
    for cat in categories_list:
        if language in words_mapping:
            translated_cat = words_mapping[language].get(cat, cat)
        else:
            translated_cat = cat
        translated_categories.append(translated_cat)
    fn = metadata_dict[country][language]["Task"]
    if os.path.exists(fn):
        with open(fn, "r", encoding="utf-8") as f:
            TASK_TEXT = f.read()
    else:
        fn = metadata_dict["USA"]["English"]["Task"]
        with open(fn, "r", encoding="utf-8") as f:
            TASK_TEXT = f.read()
    fn = metadata_dict[country][language]["Instructions"]
    if os.path.exists(fn):
        with open(fn, "r", encoding="utf-8") as f:
            INST_TEXT = f.read()
    else:
        fn = metadata_dict["USA"]["English"]["Instructions"]
        with open(fn, "r", encoding="utf-8") as f:
            INST_TEXT = f.read()
    return (
        gr.update(label=metadata_dict[country][language]["Country"], value=country),
        gr.update(label=metadata_dict[country][language]["Language"], value=language),
        gr.update(label=metadata_dict[country][language]["Email"], value=email),
        gr.update(label=metadata_dict[country][language]["Password"], value=password),
        gr.update(choices=categories_keys_translated, interactive=True, label=metadata_dict[country][language]["Category"], allow_custom_value=False, elem_id="category_btn"),
        gr.update(choices=[], interactive=True, label=metadata_dict[country][language]["Concept"], allow_custom_value=True, elem_id="concept_btn"),
        gr.update(label=metadata_dict[country][language]["Image"]),
        gr.update(label=metadata_dict[country][language]["Image_URL"]),
        gr.update(label=metadata_dict[country][language]["Description"]),
        gr.Markdown(TASK_TEXT),
        gr.Markdown(INST_TEXT),
        gr.update(value=metadata_dict[country][language]["Instructs_btn"]),
        gr.update(value=metadata_dict[country][language]["Clear_btn"]),
        gr.update(value=metadata_dict[country][language]["Submit_btn"]),
        gr.Markdown(metadata_dict[country][language]["Saving_text"]),
        gr.Markdown(metadata_dict[country][language]["Saved_text"]),
        gr.update(label=metadata_dict[country][language]["Timestamp"]),
        gr.update(value=metadata_dict[country][language]["Exit_btn"]),
        gr.Markdown(metadata_dict[country][language]["Browse_text"]),
        gr.Markdown(metadata_dict[country][language]["Loading_msg"]),
        # gr.update(choices=categories_keys_translated, interactive=True, label=metadata_dict[country][language].get("Add_Category", "Additional Categories (Optional)"), allow_custom_value=False, elem_id="additional_category_btn"),
        # gr.update(choices=[], interactive=True, label=metadata_dict[country][language].get("Add_Concept", "Additional Concepts (Optional)"), allow_custom_value=True, elem_id="additional_concept_btn"),
        gr.update(value=metadata_dict[country][language].get("Hide_all_btn", "👤 Hide All Faces")),
        gr.update(value=metadata_dict[country][language].get("Hide_btn", "👤 Hide Specific Faces")),
        gr.update(value=metadata_dict[country][language].get("Unhide_btn", "👀 Unhide Faces")),
        gr.update(value=metadata_dict[country][language].get("Exclude_btn", "Exclude Selected Example")),
        gr.update(label=translated_categories[0], choices=sorted(concepts_dict[country][lang2eng_mapping.get(language, language)][categories_list[0]])),
        gr.update(label=translated_categories[1], choices=sorted(concepts_dict[country][lang2eng_mapping.get(language, language)][categories_list[1]])),
        gr.update(label=translated_categories[2], choices=sorted(concepts_dict[country][lang2eng_mapping.get(language, language)][categories_list[2]])),
        gr.update(label=translated_categories[3], choices=sorted(concepts_dict[country][lang2eng_mapping.get(language, language)][categories_list[3]])),
        gr.update(label=translated_categories[4], choices=sorted(concepts_dict[country][lang2eng_mapping.get(language, language)][categories_list[4]])),
    )
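# Assumed structure of the lookup tables used above: metadata_dict[country][language]
# maps UI keys ("Country", "Task", "Instructions", "Submit_btn", ...) to translated
# labels or text-file paths, with metadata_dict["USA"]["English"] as the fallback;
# concepts_dict[country][<English language name>][<category>] is a list of concept
# strings, and lang2eng_mapping converts a displayed language name to its English key.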


def update_intro_language(selected_country, selected_language, intro_markdown, metadata):
    if selected_language is None:
        return intro_markdown
    fn = metadata[selected_country][selected_language]["Intro"]
    if not os.path.exists(fn):
        return intro_markdown
    with open(fn, "r", encoding="utf-8") as f:
        INTRO_TEXT = f.read()
    return gr.Markdown(INTRO_TEXT)


def handle_click_example(user_examples, concepts_dict):
    print("handle_click_example")
    print(user_examples)
    ex = [item for item in user_examples]
    # print(ex)
    image_inp = ex[0]
    image_url_inp = ex[1]
    long_caption_inp = ex[2]
    country_btn = ex[3]
    language_btn = ex[4]
    category_btn = ex[5]
    concept_btn = ex[6]
    additional_concepts_by_category = ex[7]
    exampleid_btn = ex[8]
    additional_concepts_by_category = [
        [] if (len(cat_concept) == 1 and cat_concept[0] == '') else cat_concept
        for cat_concept in additional_concepts_by_category
    ]
    # import pdb; pdb.set_trace()
    # excluded_btn = ex[10]  # TODO: if True, the "Exclude" button should change to "Excluded"
    # is_blurred = ex[11]
    # TODO: fix additional concepts not saving when the categories are in a language other than English.
    # Earlier draft of per-category dropdown handling, kept for reference:
    # # Get the English version of the language
    # eng_lang = lang2eng_mapping.get(language_btn, language_btn)
    # # Get predefined categories in the correct order
    # predefined_categories = sorted(list(concepts_dict[country_btn][eng_lang].keys()))[:5]
    # # Create dropdown values for each category
    # dropdown_values = []
    # for category in predefined_categories:
    #     if additional_concepts_by_category and category in additional_concepts_by_category:
    #         dropdown_values.append(additional_concepts_by_category[category])
    #     else:
    #         dropdown_values.append(None)
    # Return values for each category dropdown
    return [image_inp, image_url_inp, long_caption_inp, exampleid_btn, category_btn, concept_btn] + additional_concepts_by_category + [True]


def is_password_correct(hashed_password, entered_password):
    is_valid = bcrypt.checkpw(entered_password.encode(), hashed_password.encode())
    # print("password_check: ", entered_password, " ", hashed_password, " ", is_valid)
    return is_valid
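# The stored hash is assumed to have been created with bcrypt at signup (outside this
# module), along the lines of:
#   hashed = bcrypt.hashpw(plain_password.encode(), bcrypt.gensalt()).decode()
# bcrypt.checkpw() re-hashes the entered password with the salt embedded in that
# stored hash and compares the results.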


## Face blurring functions
def detect_faces(image):
    """
    Detect faces in an image using RetinaFace.

    Args:
        image (numpy.ndarray): Input image in BGR order.

    Returns:
        tuple: (faces, detection_time), the raw RetinaFace detection result and
        the elapsed detection time in seconds.
    """
    # Start timer
    start_time = time.time()
    # Detect faces using RetinaFace
    detection_start = time.time()
    faces = RetinaFace.detect_faces(image, threshold=0.8)
    detection_time = time.time() - detection_start
    return faces, detection_time
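# Note on the RetinaFace result (relied upon by the handlers below): when faces are
# found, `faces` is a dict keyed "face_1", "face_2", ... whose values include a
# "facial_area" [x1, y1, x2, y2] bounding box (and, depending on the library version,
# a confidence score and landmarks); when nothing is detected the result may not be
# a dict, which callers treat as zero faces.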


# Hide Faces Button
def select_faces_to_hide(image, blur_faces_ids):
    if image is None:
        return None, Modal(visible=False), Modal(visible=False), None, "", None, gr.update(value=[])
    else:
        # Detect faces
        # import pdb; pdb.set_trace()
        face_images = image.copy()
        faces, detection_time = detect_faces(face_images)
        print(f"Detection time: {detection_time:.2f} seconds")
        # pdb.set_trace()
        # RetinaFace may return a non-dict result when nothing is detected; guard here
        # the same way the handlers below do before iterating over it.
        if not isinstance(faces, dict):
            faces = {}
        # Draw detections with IDs
        for face_id, face_data in enumerate(faces.values(), start=1):
            # Get face coordinates
            facial_area = face_data['facial_area']
            x1, y1, x2, y2 = facial_area
            # Draw rectangle around face
            cv2.rectangle(face_images, (x1, y1), (x2, y2), (0, 0, 255), 2)
            # Add ID text
            cv2.putText(face_images, f"ID: {face_id}", (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
        # Update face count
        face_count = len(faces)
        blur_faces_ids = gr.update(choices=[f"Face ID: {i}" for i in range(1, face_count + 1)])
        current_faces_info = gr.State(faces)
        if face_count == 0:
            return image, Modal(visible=False), Modal(visible=True), None, "", None, gr.update(value=[])
        else:
            return image, Modal(visible=True), Modal(visible=False), face_images, str(face_count), current_faces_info, blur_faces_ids
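# All return paths above share the same output order: input image, face-selection
# modal, "no faces detected" modal, annotated preview, face-count text,
# detected-faces state, and the face-ID selector update.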


def blur_selected_faces(image, blur_faces_ids, faces_info, face_img, faces_count):  # is_blurred
    if not blur_faces_ids:
        return image, Modal(visible=True), face_img, faces_count, blur_faces_ids  # is_blurred
    faces = faces_info.value
    # Map "Face ID: N" labels back to RetinaFace keys ("face_N")
    parsed_faces_ids = [f"face_{val.split(':')[-1].strip()}" for val in blur_faces_ids]
    # Base blur amount and bounds
    MIN_BLUR = 31   # Minimum blur amount (must be odd)
    MAX_BLUR = 131  # Maximum blur amount (must be odd)
    blurring_start = time.time()
    # Process each selected face
    face_count = 0
    if faces and isinstance(faces, dict):
        # Blur by ID
        for face_key in parsed_faces_ids:
            face_count += 1
            try:
                face_data = faces[face_key]
            except KeyError:
                gr.Warning(f"⚠️ Face ID {face_key.split('_')[-1]} not found in detected faces.", duration=5)
                return image, Modal(visible=True), face_img, faces_count, blur_faces_ids  # is_blurred
            # Get bounding box coordinates
            x1, y1, x2, y2 = face_data['facial_area']
            # Calculate face region size
            face_width = x2 - x1
            face_height = y2 - y1
            face_size = max(face_width, face_height)
            # Calculate adaptive blur amount based on face size,
            # scaled between MIN_BLUR and MAX_BLUR using the image width as reference
            img_width = image.shape[1]
            blur_amount = int(MIN_BLUR + (MAX_BLUR - MIN_BLUR) * (face_size / img_width))
            # Ensure blur amount is odd
            blur_amount = blur_amount if blur_amount % 2 == 1 else blur_amount + 1
            # Ensure within bounds
            blur_amount = max(MIN_BLUR, min(MAX_BLUR, blur_amount))
            # Ensure the coordinates are within the image boundaries
            ih, iw = image.shape[:2]
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(iw, x2), min(ih, y2)
            # Extract face region
            face_region = image[y1:y2, x1:x2]
            # Apply blur
            blurred_face = cv2.GaussianBlur(face_region, (blur_amount, blur_amount), 0)
            # Replace face region with blurred version
            image[y1:y2, x1:x2] = blurred_face
    blurring_time = time.time() - blurring_start
    # Print timing information
    print("Face blurring performance metrics:")
    print(f"Face blurring time: {blurring_time:.4f} seconds")
    if face_count == 0:
        return image, Modal(visible=True), face_img, faces_count, blur_faces_ids
    else:
        return image, Modal(visible=False), None, None, gr.update(value=[])
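# Worked example of the adaptive kernel size used above: for a 1024 px wide image and
# a 256 px face, blur_amount = 31 + (131 - 31) * (256 / 1024) = 56, bumped to the next
# odd value 57, which already lies within the [31, 131] bounds.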


def blur_all_faces(image):
    if image is None:
        return None, Modal(visible=False)
    else:
        # Base blur amount and bounds
        MIN_BLUR = 31   # Minimum blur amount (must be odd)
        MAX_BLUR = 131  # Maximum blur amount (must be odd)
        # Start timer
        start_time = time.time()
        # Detect faces using RetinaFace
        detection_start = time.time()
        faces = RetinaFace.detect_faces(image)
        detection_time = time.time() - detection_start
        # Create a copy of the image
        output_image = image.copy()
        face_count = 0
        blurring_start = time.time()
        # Process each face
        if faces and isinstance(faces, dict):
            for face_key in faces:
                face_count += 1
                face_data = faces[face_key]
                # Get bounding box coordinates
                x1, y1, x2, y2 = face_data['facial_area']
                # Calculate face region size
                face_width = x2 - x1
                face_height = y2 - y1
                face_size = max(face_width, face_height)
                # Calculate adaptive blur amount based on face size,
                # scaled between MIN_BLUR and MAX_BLUR using the image width as reference
                img_width = image.shape[1]
                blur_amount = int(MIN_BLUR + (MAX_BLUR - MIN_BLUR) * (face_size / img_width))
                # Ensure blur amount is odd
                blur_amount = blur_amount if blur_amount % 2 == 1 else blur_amount + 1
                # Ensure within bounds
                blur_amount = max(MIN_BLUR, min(MAX_BLUR, blur_amount))
                # Ensure the coordinates are within the image boundaries
                ih, iw = image.shape[:2]
                x1, y1 = max(0, x1), max(0, y1)
                x2, y2 = min(iw, x2), min(ih, y2)
                # Extract face region
                face_region = output_image[y1:y2, x1:x2]
                # Apply blur
                blurred_face = cv2.GaussianBlur(face_region, (blur_amount, blur_amount), 0)
                # Replace face region with blurred version
                output_image[y1:y2, x1:x2] = blurred_face
        blurring_time = time.time() - blurring_start
        total_time = time.time() - start_time
        # Print timing information
        print("Face blurring performance metrics:")
        print(f"Total faces detected: {face_count}")
        print(f"Face detection time: {detection_time:.4f} seconds")
        print(f"Face blurring time: {blurring_time:.4f} seconds")
        print(f"Total processing time: {total_time:.4f} seconds")
        print(f"Average time per face: {(total_time / max(1, face_count)):.4f} seconds")
        if face_count == 0:
            return image, Modal(visible=True)
        else:
            return output_image, Modal(visible=False)


def unhide_faces(img, ori_img):  # is_blurred
    if img is None:
        return None
    elif np.array_equal(img, ori_img.value):
        return img  # is_blurred
    else:
        return ori_img.value


def check_exclude_fn(image):
    if image is None:
        gr.Warning("⚠️ No image to exclude.")
        return gr.update(visible=False)
    else:
        return gr.update(visible=True)


def has_user_json(username, country, language_choice, local_ds_directory_path):
    """Return True if any JSON records exist locally for this username, country and language."""
    return bool(glob.glob(
        os.path.join(local_ds_directory_path, "logged_in_users", country, language_choice, username, "**", "*.json"),
        recursive=True,
    ))
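# Expected on-disk layout for per-user records, mirroring the Hub dataset repo
# (inferred from the glob and load_dataset patterns above):
#   <local_ds_directory_path>/logged_in_users/<country>/<language>/<username>/**/<record>.json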