# Page-header residue captured from the Hugging Face Spaces file viewer ("Spaces: Running"); not part of the program.
import os | |
import json | |
import time | |
from huggingface_hub import HfApi, create_repo, CommitScheduler | |
import bcrypt | |
import shutil | |
import uuid | |
import gradio as gr | |
from PIL import Image | |
import numpy as np | |
def _load_country_lang_json(path):
    """Load a two-level JSON mapping (country -> language -> value), sorted.

    Countries, and within each country its languages, are re-inserted in
    sorted key order so callers iterating the result see a stable
    alphabetical order. The innermost values are returned unchanged.

    Args:
        path: Path of the UTF-8 encoded JSON file.

    Returns:
        dict: ``{country: {language: value}}`` with both levels sorted.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return {
        country: {lang: data[country][lang] for lang in sorted(data[country])}
        for country in sorted(data)
    }


def load_concepts(path="data/concepts.json"):
    """Load the concepts file, with countries and languages sorted by key."""
    return _load_country_lang_json(path)


def load_metadata(path="data/metadata.json"):
    """Load the metadata file, with countries and languages sorted by key."""
    return _load_country_lang_json(path)
class CustomHFDatasetSaver:
    """Store form submissions locally and sync them to a Hugging Face dataset repo.

    Each successful :meth:`save` creates one sample directory under
    ``local_ds_folder``. The directory holds an ``image.png`` (when an image
    was supplied) and a ``sample_<ms-timestamp>.json`` metadata file. The
    ``CommitScheduler`` created in :meth:`setup` periodically commits that
    folder to the Hub.
    """

    # Fields that must be non-empty. The image is checked separately:
    # either an uploaded image OR an image URL is enough.
    REQUIRED_FIELDS = ("country", "language", "category", "concept", "caption")
    # Number of per-category concept selections stored as category_{i}_concepts.
    NUM_CATEGORY_FIELDS = 5

    def __init__(self, api_token, dataset_name, private=False):
        """Remember Hub credentials; no network or disk work happens until setup()."""
        self.api_token = api_token
        self.dataset_name = dataset_name
        self.private = private
        self.api = HfApi()
        # Populated by setup(); save() refuses to run before that.
        self.local_ds_folder = None
        self.data_outputs = None
        self.scheduler = None

    def setup(self, data_outputs, local_ds_folder, every=1):
        """Create the dataset repo (if missing), the local folder and the commit scheduler.

        Args:
            data_outputs: Field keys that ``save(*values)`` pairs positionally
                with the submitted values. Iterating it must yield field names
                such as ``"country"``. NOTE(review): confirm against the caller.
            local_ds_folder: Local directory that is mirrored to the Hub.
            every: Commit interval passed to ``CommitScheduler`` (minutes).
        """
        # Keep the repo id returned by create_repo so the scheduler targets
        # exactly the repo that was created or found.
        self.dataset_name = create_repo(
            repo_id=self.dataset_name,
            token=self.api_token,
            private=self.private,
            repo_type="dataset",
            exist_ok=True,
        ).repo_id
        self.local_ds_folder = local_ds_folder
        os.makedirs(self.local_ds_folder, exist_ok=True)
        self.data_outputs = data_outputs
        self.scheduler = CommitScheduler(
            repo_id=self.dataset_name,
            repo_type="dataset",
            folder_path=self.local_ds_folder,
            every=every,
            token=self.api_token,
        )

    @staticmethod
    def _has_image(image):
        """Return True if *image* is an upload dict, a non-empty array, or a PIL image."""
        if image is None:
            return False
        if isinstance(image, dict):
            return True
        if hasattr(image, "shape"):
            shape = image.shape
            return len(shape) > 0 and shape[0] > 0
        # PIL images have no ``shape`` but save() knows how to write them.
        return isinstance(image, Image.Image)

    def validate_data(self, values_dic):
        """Check that a submission has an image source and all required fields.

        Args:
            values_dic (dict): Submitted ``{field: value}`` mapping.

        Returns:
            tuple[bool, str]: ``(True, "")`` if valid, otherwise ``(False, message)``.
        """
        image = values_dic.get("image")
        image_url = values_dic.get("image_url")

        has_image = self._has_image(image)
        has_url = isinstance(image_url, str) and image_url.strip() != ""
        if not has_image and not has_url:
            return False, "Either an image or image URL must be provided"

        for field in self.REQUIRED_FIELDS:
            value = values_dic.get(field)
            if value is None or (isinstance(value, str) and value.strip() == ""):
                return False, f"Required field '{field}' cannot be empty"

        # An upload dict must point at a file that still exists; a None path
        # is treated as missing rather than crashing os.path.exists.
        if has_image and isinstance(image, dict):
            if not os.path.exists(image.get("path") or ""):
                return False, "Image file not found"
        return True, ""

    # TODO: add a function to check if the user is logged in
    def is_logged_in(self):
        """Placeholder; not implemented and always returns None."""
        pass

    @staticmethod
    def _safe_path_component(value):
        """Turn an untrusted value into a single, non-escaping path component.

        Path separators and ':' are replaced so that values such as
        ``"/etc"`` or ``"../x"`` cannot make ``os.path.join`` leave the
        dataset folder. ``""``, ``"."`` and ``".."`` become ``"_"``.
        """
        text = str(value)
        for separator in ("/", "\\", ":"):
            text = text.replace(separator, "_")
        if text in ("", ".", ".."):
            text = "_"
        return text

    def _build_sample_dir(self, values_dic, timestamp):
        """Return the sample directory, relative to ``local_ds_folder``.

        Logged-in users (non-empty ``username``) get a per-user, per-timestamp
        directory. Anonymous submissions get a random UUID directory.
        """
        country = self._safe_path_component(values_dic["country"])
        language = self._safe_path_component(values_dic["language"])
        username = values_dic.get("username")
        if username:
            return os.path.join(
                "logged_in_users",
                country,
                language,
                self._safe_path_component(username),
                str(timestamp),
            )
        return os.path.join("anonymous_users", country, language, str(uuid.uuid4()))

    @staticmethod
    def _write_image(image_data, full_dest_path):
        """Write *image_data* to *full_dest_path*; return False if nothing was written.

        Handles an upload dict with a ``path`` (copied), a numpy array
        (converted via PIL) and a PIL image. Any other value, e.g. None for
        URL-only submissions, writes nothing.
        """
        if isinstance(image_data, dict) and "path" in image_data:
            shutil.copy(image_data["path"], full_dest_path)
        elif isinstance(image_data, np.ndarray):
            Image.fromarray(image_data).save(full_dest_path)
        elif image_data is not None and isinstance(image_data, Image.Image):
            image_data.save(full_dest_path)
        else:
            return False
        return True

    @staticmethod
    def _concept_list(values_dic, key):
        """Return the concept selection for *key*; missing, None or [] become [""]."""
        value = values_dic.get(key)
        if value is None or (isinstance(value, list) and not value):
            return [""]
        return value

    def _build_metadata(self, values_dic, image_rel_path, hashed_password, timestamp):
        """Build the JSON metadata dict for one sample.

        Args:
            values_dic (dict): Validated submission (with ``id`` filled in).
            image_rel_path (str | None): Image path relative to the dataset
                root, or None if no image file was written.
            hashed_password (str): Already-hashed password, or "".
            timestamp (int): Millisecond timestamp of the submission.
        """
        if image_rel_path is None:
            image_path = image_file = ""
        else:
            # Normalise Windows separators so Hub URLs and paths use '/'.
            image_path = image_rel_path.replace("\\", "/")
            image_file = (
                f"https://huggingface.co/datasets/{self.dataset_name}/resolve/main/{image_path}"
            )

        metadata = {
            "image": image_path,
            "image_file": image_file,
            "image_url": values_dic.get("image_url") or "",
            "caption": values_dic.get("caption") or "",
            "country": values_dic.get("country") or "",
            "language": values_dic.get("language") or "",
            "category": values_dic.get("category") or "",
            "concept": values_dic.get("concept") or "",
        }
        for index in range(1, self.NUM_CATEGORY_FIELDS + 1):
            key = f"category_{index}_concepts"
            metadata[key] = self._concept_list(values_dic, key)
        metadata.update(
            {
                "timestamp": timestamp,
                "username": values_dic.get("username") or "",
                "password": hashed_password or "",
                "id": values_dic["id"],
                "excluded": bool(values_dic.get("excluded")),
            }
        )
        return metadata

    # TODO: check if the user is logged in (add a decorator to the save function)
    def save(self, *values):
        """Validate one submission and write its image and metadata locally.

        Args:
            *values: Submitted values, paired positionally with
                ``self.data_outputs`` to form a ``{field: value}`` dict.

        Raises:
            RuntimeError: If :meth:`setup` has not been called.
            gr.Error: If :meth:`validate_data` rejects the submission.
        """
        if getattr(self, "scheduler", None) is None:
            raise RuntimeError("setup() must be called before save()")

        values_dic = dict(zip(self.data_outputs, values))

        is_valid, error_msg = self.validate_data(values_dic)
        if not is_valid:
            raise gr.Error(error_msg)

        # Anonymous submissions may carry no password. Only hash real values
        # so that a missing one does not crash on None.encode().
        # NOTE(review): the bcrypt hash is written into the dataset JSON; with
        # private=False it is publicly downloadable. Reconsider storing it.
        raw_password = values_dic.get("password")
        hashed_password = "" if raw_password is None else self.hash_password(raw_password)

        # TODO: fix saving additional concepts if not displayed in English.
        # The category_{i}_concepts fields are stored without the category
        # name they belong to. An earlier draft resolved the names through
        # values_dic["concepts_dict"] and values_dic["country_lang_map"].

        current_timestamp = int(time.time() * 1000)

        # Create a unique ID for the sample if none was provided.
        if not values_dic.get("id"):
            country, language, category, concept = (
                values_dic.get("country"),
                values_dic.get("language"),
                values_dic.get("category"),
                values_dic.get("concept"),
            )
            values_dic["id"] = f'{country}_{language}_{category}_{concept}_{current_timestamp}'

        sample_dir = self._build_sample_dir(values_dic, current_timestamp)
        sample_path = os.path.join(self.local_ds_folder, sample_dir)
        dest_image_path = os.path.join(sample_dir, "image.png")
        json_file_path = os.path.join(sample_path, f"sample_{current_timestamp}.json")

        # A single lock section keeps the image and its metadata together,
        # so the scheduler cannot pick up one without the other.
        with self.scheduler.lock:
            os.makedirs(sample_path, exist_ok=True)
            image_written = self._write_image(
                values_dic.get("image"),
                os.path.join(self.local_ds_folder, dest_image_path),
            )
            # Only reference image.png in the metadata if it actually exists.
            data_dict = self._build_metadata(
                values_dic,
                dest_image_path if image_written else None,
                hashed_password,
                current_timestamp,
            )
            with open(json_file_path, "w", encoding="utf-8") as f:
                json.dump(data_dict, f, indent=2)

        # Keep the password hash out of the logs.
        log_view = {**data_dict, "password": "<redacted>" if data_dict["password"] else ""}
        print(f"Data dictionary: {log_view}")
        print("Data saved successfully")

    def hash_password(self, raw_password):
        """Hash a plain-text password with bcrypt and a freshly generated salt.

        Args:
            raw_password (str): The plain-text password to hash.

        Returns:
            str: The bcrypt hash, decoded to text.

        Note:
            bcrypt only takes the first 72 bytes of the encoded password into account.
        """
        return bcrypt.hashpw(raw_password.encode(), bcrypt.gensalt()).decode()