|
import gradio as gr |
|
import os |
|
import random |
|
import csv |
|
from pathlib import Path |
|
from datetime import datetime, timedelta |
|
import tempfile |
|
from huggingface_hub import HfApi, hf_hub_download, login |
|
from huggingface_hub.utils import RepositoryNotFoundError, EntryNotFoundError |
|
from apscheduler.schedulers.background import BackgroundScheduler |
|
import atexit |
|
import threading |
|
import time |
|
import shutil |
|
|
|
|
|
# Hugging Face dataset repo where collected preferences are persisted.
DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "matsant01/user-study-collected-preferences")

# Write token from the environment; when absent, Hub uploads are disabled.
HF_TOKEN = os.getenv("HF_TOKEN")

RESULTS_FILENAME_IN_REPO = "preferences.csv"

# Scratch directory that holds the local copy of the results CSV between uploads.
TEMP_DIR = tempfile.mkdtemp()

LOCAL_RESULTS_FILE = Path(TEMP_DIR) / RESULTS_FILENAME_IN_REPO

# How often the scheduler pushes the CSV to the Hub (0.1 h = every 6 minutes).
UPLOAD_INTERVAL_HOURS = 0.1

# Local directory containing one subfolder per study sample.
DATA_DIR = Path("data")

# Accepted image file extensions, tried in this order by find_image().
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp"]

# Set by initialize_hub_and_results(); None means Hub saving is disabled.
hf_api = None

scheduler = BackgroundScheduler(daemon=True)

# Serializes access to LOCAL_RESULTS_FILE between request handlers and the
# background upload job.
upload_lock = threading.Lock()

# Set when a new row has been written since the last successful upload, so
# idle periods skip the upload entirely.
new_preferences_recorded_since_last_upload = threading.Event()
|
|
|
|
|
def initialize_hub_and_results():
    """Log into the HF Hub and fetch any existing results CSV into TEMP_DIR.

    Sets the module-global ``hf_api`` to an HfApi instance on success, or to
    None when the token is missing or the dataset repo is unreachable (which
    disables result uploads).
    """
    global hf_api

    if not HF_TOKEN:
        print("Warning: HF_TOKEN secret not found. Results will not be saved to the Hub.")
        hf_api = None
        return

    print("Logging into Hugging Face Hub...")
    try:
        login(token=HF_TOKEN)
        hf_api = HfApi()
        print(f"Attempting initial download of {RESULTS_FILENAME_IN_REPO} from {DATASET_REPO_ID}")
        # Pull the current preferences file so new rows are appended to it
        # rather than overwriting previously collected data on upload.
        hf_hub_download(
            repo_id=DATASET_REPO_ID,
            filename=RESULTS_FILENAME_IN_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
            local_dir=TEMP_DIR,
            local_dir_use_symlinks=False,
        )
        print(f"Successfully downloaded existing {RESULTS_FILENAME_IN_REPO} to {LOCAL_RESULTS_FILE}")
    except EntryNotFoundError:
        # Repo exists but has no results file yet: start fresh locally.
        print(f"{RESULTS_FILENAME_IN_REPO} not found in repo. Will create locally.")
    except RepositoryNotFoundError:
        print(f"Error: Dataset repository {DATASET_REPO_ID} not found or token lacks permissions.")
        print("Results saving will be disabled.")
        hf_api = None
    except Exception as e:
        print(f"Error during initial download/login: {e}")
        print("Proceeding without initial download. File will be created locally.")
|
|
|
|
|
|
|
def find_image(folder_path: Path, base_name: str) -> Path | None: |
|
for ext in IMAGE_EXTENSIONS: |
|
file_path = folder_path / f"{base_name}{ext}" |
|
if file_path.exists(): |
|
return file_path |
|
return None |
|
|
|
def get_sample_ids() -> list[str]:
    """Return names of all sample subdirectories under DATA_DIR that contain
    a prompt file plus the four required images (any supported extension)."""
    if not DATA_DIR.is_dir():
        return []

    def _is_complete(folder: Path) -> bool:
        # A usable sample needs prompt.txt and all four images.
        return (
            (folder / "prompt.txt").exists()
            and find_image(folder, "input_bg") is not None
            and find_image(folder, "input_fg") is not None
            and find_image(folder, "baseline") is not None
            and find_image(folder, "tf-icon") is not None
        )

    return [entry.name for entry in DATA_DIR.iterdir() if entry.is_dir() and _is_complete(entry)]
|
|
|
def load_sample_data(sample_id: str) -> dict | None:
    """Load one sample's prompt and image paths from ``DATA_DIR/<sample_id>``.

    Args:
        sample_id: Name of the sample subdirectory.

    Returns:
        A dict with keys ``id``, ``prompt``, ``input_bg``, ``input_fg``,
        ``output_baseline`` and ``output_tficon`` (image values are string
        paths), or None when the directory, any required file, or the prompt
        is missing/unreadable.
    """
    sample_path = DATA_DIR / sample_id
    if not sample_path.is_dir():
        return None

    prompt_file = sample_path / "prompt.txt"
    input_bg_path = find_image(sample_path, "input_bg")
    input_fg_path = find_image(sample_path, "input_fg")
    output_baseline_path = find_image(sample_path, "baseline")
    output_tficon_path = find_image(sample_path, "tf-icon")

    if not all([prompt_file.exists(), input_bg_path, input_fg_path, output_baseline_path, output_tficon_path]):
        print(f"Warning: Missing files in sample {sample_id}")
        return None

    try:
        # Fix: read explicitly as UTF-8 so prompts decode identically on
        # every host instead of depending on the platform locale encoding.
        prompt = prompt_file.read_text(encoding="utf-8").strip()
    except Exception as e:
        print(f"Error reading prompt for {sample_id}: {e}")
        return None

    return {
        "id": sample_id,
        "prompt": prompt,
        "input_bg": str(input_bg_path),
        "input_fg": str(input_fg_path),
        "output_baseline": str(output_baseline_path),
        "output_tficon": str(output_tficon_path),
    }
|
|
|
|
|
|
|
# Snapshot of valid sample ids taken once at import time; each browser
# session starts from this full list via available_samples_state.
INITIAL_SAMPLE_IDS = get_sample_ids()
|
|
|
def get_next_sample(available_ids: list[str]) -> tuple[dict | None, list[str]]: |
|
if not available_ids: |
|
return None, [] |
|
chosen_id = random.choice(available_ids) |
|
remaining_ids = [id for id in available_ids if id != chosen_id] |
|
sample_data = load_sample_data(chosen_id) |
|
return sample_data, remaining_ids |
|
|
|
def display_new_sample(state: dict, available_ids: list[str]):
    """Select the next sample and return gr.update values for every component.

    Returns a dict keyed by the module-level Gradio components (defined in the
    gr.Blocks section below), also refreshing app_state and
    available_samples_state. Which model shows as Output A vs B is randomized
    per sample.
    """
    sample_data, remaining_ids = get_next_sample(available_ids)

    if not sample_data:
        # Pool exhausted (or the chosen sample failed to load): show the
        # completion message and hide all study controls.
        return {
            prompt_display: gr.update(value="**Prompt:** No more samples available. Thank you!"),
            input_bg_display: gr.update(value=None, visible=False),
            input_fg_display: gr.update(value=None, visible=False),
            output_a_display: gr.update(value=None, visible=False),
            output_b_display: gr.update(value=None, visible=False),
            choice_button_a: gr.update(visible=False),
            choice_button_b: gr.update(visible=False),
            next_button: gr.update(visible=False),
            status_display: gr.update(value="**Status:** Completed!"),
            app_state: state,
            available_samples_state: remaining_ids
        }

    # Shuffle which model appears in the A and B slots so raters cannot
    # learn a fixed position for either model.
    outputs = [
        {"model_name": "baseline", "path": sample_data["output_baseline"]},
        {"model_name": "tf-icon", "path": sample_data["output_tficon"]},
    ]
    random.shuffle(outputs)
    output_a = outputs[0]
    output_b = outputs[1]

    # Remember the slot->model assignment so record_preference() can map the
    # clicked button back to the underlying model name.
    state = {
        "current_sample_id": sample_data["id"],
        "output_a_model_name": output_a["model_name"],
        "output_b_model_name": output_b["model_name"],
    }

    return {
        prompt_display: gr.update(value=f"**Prompt:** {sample_data['prompt']}"),
        input_bg_display: gr.update(value=sample_data["input_bg"], visible=True),
        input_fg_display: gr.update(value=sample_data["input_fg"], visible=True),
        output_a_display: gr.update(value=output_a["path"], visible=True),
        output_b_display: gr.update(value=output_b["path"], visible=True),
        choice_button_a: gr.update(visible=True, interactive=True),
        choice_button_b: gr.update(visible=True, interactive=True),
        next_button: gr.update(visible=False),
        status_display: gr.update(value="**Status:** Please choose the image you prefer."),
        app_state: state,
        available_samples_state: remaining_ids
    }
|
|
|
def record_preference(choice: str, state: dict, request: gr.Request):
    """Append the user's A/B choice to the local CSV and update the UI.

    `choice` is the display slot the user clicked ("A" or "B"); the slot ->
    model mapping stored in `state` by display_new_sample() resolves the
    actual model name. Returns gr.update values for the two choice buttons,
    the next button, the status line, and app_state.
    """
    # NOTE(review): "session_id" is really the client host/IP taken from the
    # request, not a per-browser session identifier — confirm this is the
    # intended granularity for the study.
    if not request:
        print("Error: Request object is None. Cannot get session ID.")
        session_id = "unknown_session"
    else:
        try:
            session_id = request.client.host
        except AttributeError:
            print("Error: request.client is None or has no 'host' attribute.")
            session_id = "unknown_client"

    if not state or "current_sample_id" not in state:
        # State can be lost (e.g. page reload); ask the user to move on.
        print("Warning: State missing, cannot record preference.")
        return {
            choice_button_a: gr.update(interactive=False),
            choice_button_b: gr.update(interactive=False),
            next_button: gr.update(visible=True, interactive=True),
            status_display: gr.update(value="**Status:** Error: Session state lost. Click Next Sample."),
            app_state: state
        }

    # Resolve which model the clicked slot corresponds to, plus where each
    # model was displayed, so the CSV row is self-describing.
    chosen_model_name = state["output_a_model_name"] if choice == "A" else state["output_b_model_name"]
    baseline_display = "A" if state["output_a_model_name"] == "baseline" else "B"
    tficon_display = "B" if state["output_a_model_name"] == "baseline" else "A"

    new_row = {
        "timestamp": datetime.now().isoformat(),
        "session_id": session_id,
        "sample_id": state["current_sample_id"],
        "baseline_displayed_as": baseline_display,
        "tficon_displayed_as": tficon_display,
        "chosen_display": choice,
        "chosen_model_name": chosen_model_name
    }
    # CSV column order follows the dict's insertion order above.
    header = list(new_row.keys())

    try:
        # The lock serializes writes against the periodic upload job, which
        # reads the same file.
        with upload_lock:
            file_exists = LOCAL_RESULTS_FILE.exists()
            mode = 'a' if file_exists else 'w'
            with open(LOCAL_RESULTS_FILE, mode, newline='', encoding='utf-8') as f:
                writer = csv.DictWriter(f, fieldnames=header)
                # Write the header only for a brand-new or empty file.
                if not file_exists or os.path.getsize(LOCAL_RESULTS_FILE) == 0:
                    writer.writeheader()
                    print(f"Created or wrote header to {LOCAL_RESULTS_FILE}")
                writer.writerow(new_row)
            print(f"Appended preference for {state['current_sample_id']} to local file.")
            # Tell the background uploader there is fresh data worth pushing.
            new_preferences_recorded_since_last_upload.set()

    except Exception as e:
        print(f"Error writing local results file {LOCAL_RESULTS_FILE}: {e}")
        return {
            choice_button_a: gr.update(interactive=False),
            choice_button_b: gr.update(interactive=False),
            next_button: gr.update(visible=True, interactive=True),
            status_display: gr.update(value=f"**Status:** Error saving preference locally: {e}. Click Next."),
            app_state: state
        }

    # Success: freeze the choice buttons and reveal "Next Sample".
    return {
        choice_button_a: gr.update(interactive=False),
        choice_button_b: gr.update(interactive=False),
        next_button: gr.update(visible=True, interactive=True),
        status_display: gr.update(value=f"**Status:** Preference recorded (Chose {choice}). Click Next Sample."),
        app_state: state
    }
|
|
|
def upload_preferences_to_hub():
    """Push the local results CSV to the Hub if new rows were recorded.

    Runs on the background scheduler and once more at shutdown. No-ops when
    the Hub API is unavailable or nothing changed since the last successful
    upload.
    """
    print("Periodic upload check triggered.")
    if not hf_api:
        print("Upload check skipped: Hugging Face API not available.")
        return

    # Cheap unlocked check first so the common idle case never takes the lock.
    if not new_preferences_recorded_since_last_upload.is_set():
        print("Upload check skipped: No new preferences recorded since last upload.")
        return

    with upload_lock:
        # Re-check under the lock in case another upload ran in between
        # (double-checked pattern on the Event flag).
        if not new_preferences_recorded_since_last_upload.is_set():
            print("Upload check skipped (race condition avoided): No new preferences.")
            return

        if not LOCAL_RESULTS_FILE.exists() or os.path.getsize(LOCAL_RESULTS_FILE) == 0:
            # Flag was set but there is nothing usable on disk; reset it.
            print("Upload check skipped: Local results file is missing or empty.")
            new_preferences_recorded_since_last_upload.clear()
            return

        try:
            print(f"Attempting to upload {LOCAL_RESULTS_FILE} to {DATASET_REPO_ID}/{RESULTS_FILENAME_IN_REPO}")
            start_time = time.time()
            hf_api.upload_file(
                path_or_fileobj=str(LOCAL_RESULTS_FILE),
                path_in_repo=RESULTS_FILENAME_IN_REPO,
                repo_id=DATASET_REPO_ID,
                repo_type="dataset",
                commit_message=f"Periodic upload of preferences - {datetime.now().isoformat()}"
            )
            end_time = time.time()
            print(f"Successfully uploaded preferences. Took {end_time - start_time:.2f} seconds.")
            # Clear only after a confirmed upload so failures retry next run.
            new_preferences_recorded_since_last_upload.clear()
        except Exception as e:
            print(f"Error uploading results file: {e}")
|
|
|
def handle_choice_a(state: dict, request: gr.Request):
    """Click handler for the 'Choose Output A' button."""
    return record_preference(choice="A", state=state, request=request)
|
|
|
def handle_choice_b(state: dict, request: gr.Request):
    """Click handler for the 'Choose Output B' button."""
    return record_preference(choice="B", state=state, request=request)
|
|
|
# --- Gradio UI definition: layout plus event wiring ---
with gr.Blocks(title="Image Composition User Study") as demo:
    gr.Markdown("# Image Composition User Study")
    gr.Markdown(
        "> Please look at the input images and the prompt below. "
        "Then, compare the two output images (Output A and Output B) and click the button below the one you prefer."
    )

    # Per-session state: the A/B -> model mapping for the current sample, and
    # the pool of sample ids this session has not yet seen.
    app_state = gr.State({})
    available_samples_state = gr.State(INITIAL_SAMPLE_IDS)

    status_display = gr.Markdown("**Status:** Loading first sample...")

    # Inputs section: the prompt plus background/foreground source images.
    gr.Markdown("## Inputs")
    with gr.Row():
        prompt_display = gr.Markdown("**Prompt:** Loading...")
    with gr.Row():
        with gr.Column():
            gr.Markdown("<div style='text-align: center;'>Input Background</div>")
            input_bg_display = gr.Image(type="filepath", height=250, width=250, interactive=False, show_label=False)
        with gr.Column():
            gr.Markdown("<div style='text-align: center;'>Input Foreground</div>")
            input_fg_display = gr.Image(type="filepath", height=250, width=250, interactive=False, show_label=False)

    gr.Markdown("---")
    gr.Markdown("## Choose your preferred output")

    # The two anonymized model outputs, each with its own choice button.
    with gr.Row():
        with gr.Column():
            output_a_display = gr.Image(label="Output A", type="filepath", height=400, width=400, interactive=False)
            choice_button_a = gr.Button("Choose Output A", variant="primary")
        with gr.Column():
            output_b_display = gr.Image(label="Output B", type="filepath", height=400, width=400, interactive=False)
            choice_button_b = gr.Button("Choose Output B", variant="primary")

    next_button = gr.Button("🔁 Next Sample 🔁", visible=False)

    # Show the first sample as soon as the page loads.
    demo.load(
        fn=display_new_sample,
        inputs=[app_state, available_samples_state],
        outputs=[
            prompt_display, input_bg_display, input_fg_display,
            output_a_display, output_b_display,
            choice_button_a, choice_button_b, next_button, status_display,
            app_state, available_samples_state
        ]
    )

    # Record the preference for whichever side was clicked. The handlers'
    # gr.Request parameter is injected by Gradio, so it is not listed in
    # `inputs`.
    choice_button_a.click(
        fn=handle_choice_a,
        inputs=[app_state],
        outputs=[choice_button_a, choice_button_b, next_button, status_display, app_state],
        api_name=False,
    )

    choice_button_b.click(
        fn=handle_choice_b,
        inputs=[app_state],
        outputs=[choice_button_a, choice_button_b, next_button, status_display, app_state],
        api_name=False,
    )

    # Advance to the next randomly chosen sample.
    next_button.click(
        fn=display_new_sample,
        inputs=[app_state, available_samples_state],
        outputs=[
            prompt_display, input_bg_display, input_fg_display,
            output_a_display, output_b_display,
            choice_button_a, choice_button_b, next_button, status_display,
            app_state, available_samples_state
        ],
        api_name=False,
    )
|
|
|
def cleanup_temp_dir():
    """Best-effort removal of the scratch directory holding the local CSV."""
    temp_path = Path(TEMP_DIR)
    if not temp_path.exists():
        return
    print(f"Cleaning up temporary directory: {TEMP_DIR}")
    shutil.rmtree(TEMP_DIR, ignore_errors=True)
|
|
|
def shutdown_hook():
    """atexit hook: flush results to the Hub, stop the scheduler, clean up.

    Order matters: upload first (while the temp file still exists), then stop
    the scheduler, then remove the temp directory.
    """
    print("Application shutting down. Performing final upload check...")
    upload_preferences_to_hub()
    if scheduler.running:
        print("Shutting down scheduler...")
        scheduler.shutdown(wait=False)
    cleanup_temp_dir()
    print("Shutdown complete.")


# Run the hook on normal interpreter exit.
atexit.register(shutdown_hook)
|
|
|
if __name__ == "__main__":
    # Log into the Hub and pull any previously collected results first.
    initialize_hub_and_results()

    if not INITIAL_SAMPLE_IDS:
        # Without data there is nothing to rate; print setup instructions
        # and exit without launching the app.
        print("Error: No valid samples found in the 'data' directory.")
        print("Please ensure the 'data' directory exists and contains subdirectories")
        print("named like 'sample_id', each with 'prompt.txt', 'input_bg.*',")
        print("'input_fg.*', 'baseline.*', and 'tf-icon.*' files.")
    elif not DATASET_REPO_ID:
        # NOTE(review): DATASET_REPO_ID is read via os.getenv(...) with a
        # non-empty default, so this branch only triggers if the env var is
        # set to an empty string; no "default placeholder" comparison is
        # actually performed — confirm the intended check.
        print("Error: DATASET_REPO_ID environment variable is not set or is set to the default placeholder.")
        print("Please set the DATASET_REPO_ID environment variable or update the script.")
    elif hf_api:
        # Hub integration active: start the periodic uploader, then serve.
        print(f"Starting periodic upload scheduler (every {UPLOAD_INTERVAL_HOURS} hours)...")
        scheduler.add_job(upload_preferences_to_hub, 'interval', hours=UPLOAD_INTERVAL_HOURS)
        scheduler.start()
        print(f"Found {len(INITIAL_SAMPLE_IDS)} samples.")
        print(f"Configured to save results periodically to Hugging Face Dataset: {DATASET_REPO_ID}")
        print("Starting Gradio app...")
        demo.launch(server_name="0.0.0.0")
    else:
        # No token / repo access: run the study without persisting results.
        print("Warning: Running without Hugging Face Hub integration (HF_TOKEN or DATASET_REPO_ID missing/invalid).")
        print(f"Found {len(INITIAL_SAMPLE_IDS)} samples.")
        print("Starting Gradio app...")
        demo.launch(server_name="0.0.0.0")