Spaces:
Runtime error
Runtime error
import gradio as gr | |
import random | |
from datasets import load_dataset | |
import csv | |
from datetime import datetime | |
import os | |
import pandas as pd | |
import json | |
from huggingface_hub import CommitScheduler, HfApi, snapshot_download | |
import shutil | |
import uuid | |
import git | |
from pathlib import Path | |
from io import BytesIO | |
import PIL | |
# Hub client authenticated with the Space's token (HF_TOKEN must be set,
# otherwise this raises KeyError at import time).
api = HfApi(token=os.environ["HF_TOKEN"])
# Dataset repo that receives the evaluation CSV (see scheduler / sync_with_hub).
RESULTS_BACKUP_REPO = "taesiri/PhotoEditBattleResults"

# Load the experimental dataset
dataset = load_dataset("taesiri/IERv2-BattleResults_exp", split="train")
# Unique post_ids present in the (unfiltered) split. Taken from the copy we
# already loaded instead of loading the dataset a second time.
dataset_post_ids = dataset.unique("post_id")

# Per-post edits used as the second contender in each battle
# (get_random_sample reads its edited_image / model / comment_id columns).
photoexp = pd.read_csv("./photoexp_filtered.csv")
# Only post_ids present in BOTH sources are sampleable: get_random_sample
# picks a post_id from this set and then needs at least one dataset row for
# it (otherwise it would index an empty selection and crash).
valid_post_ids = set(photoexp.post_id.tolist()) & set(dataset_post_ids)
dataset = dataset.filter(
    lambda xs: [x in valid_post_ids for x in xs["post_id"]],
    batched=True,
    batch_size=256,
)
print(f"Dataset size after filtering: {len(dataset)}")
# Download existing data from hub
def sync_with_hub():
    """
    Synchronize local data with the hub by cloning the dataset repo.

    Clones ``RESULTS_BACKUP_REPO`` into ``./hub_data``. If a clone is left over
    from an interrupted run, it is pulled instead. Then the hub's
    ``data/evaluation_results_exp.csv`` is merged into ``./data``: local and hub
    rows are concatenated and exact duplicate rows dropped. If there is no local
    CSV, the hub copy is used as-is. Other files in the hub's ``data/`` folder
    are copied over, overwriting local files. Sub-directories are copied only
    if they do not already exist locally. The temporary clone is always
    removed afterwards, even when a step fails.

    Raises:
        KeyError: if the ``HF_TOKEN`` environment variable is not set.
        git.GitCommandError: if cloning or pulling fails.
    """
    print("Starting sync with hub...")
    data_dir = Path("./data")
    results_name = "evaluation_results_exp.csv"
    local_csv_path = data_dir / results_name

    # Read existing local data if it exists
    local_data = None
    if local_csv_path.exists():
        local_data = pd.read_csv(local_csv_path)
        print(f"Found local data with {len(local_data)} entries")

    # Clone/pull latest data from hub. The token is embedded in the URL so git
    # can authenticate. NOTE(review): git error messages echo the command
    # line, so a failing clone/pull may print this URL (token included).
    token = os.environ["HF_TOKEN"]
    username = "taesiri"
    repo_url = (
        f"https://{username}:{token}@huggingface.co/datasets/{RESULTS_BACKUP_REPO}"
    )
    hub_data_dir = Path("hub_data")
    try:
        if hub_data_dir.exists():
            # Leftover clone from an earlier run: refresh it in place.
            print("Pulling latest changes...")
            repo = git.Repo(hub_data_dir)
            origin = repo.remotes.origin
            if "https://" in origin.url:
                origin.set_url(repo_url)
            origin.pull()
        else:
            print("Cloning repository...")
            git.Repo.clone_from(repo_url, hub_data_dir)

        # Merge hub data with local data
        hub_data_source = hub_data_dir / "data"
        if hub_data_source.exists():
            # Hold the CommitScheduler lock while writing into ./data so the
            # background uploader never pushes a half-written file.
            with scheduler.lock:
                data_dir.mkdir(exist_ok=True)
                hub_csv_path = hub_data_source / results_name
                if hub_csv_path.exists():
                    hub_data = pd.read_csv(hub_csv_path)
                    print(f"Found hub data with {len(hub_data)} entries")
                    if local_data is not None:
                        # Keep all entries, removing only exact duplicate rows
                        merged_data = pd.concat([local_data, hub_data]).drop_duplicates()
                        print(f"Merged data has {len(merged_data)} entries")
                        merged_data.to_csv(local_csv_path, index=False)
                    else:
                        # If no local data exists, just copy hub data
                        shutil.copy2(hub_csv_path, local_csv_path)

                # Copy any other files from hub
                for item in hub_data_source.glob("*"):
                    if item.is_file() and item.name != results_name:
                        shutil.copy2(item, data_dir / item.name)
                    elif item.is_dir():
                        dest = data_dir / item.name
                        if not dest.exists():
                            shutil.copytree(item, dest)
    finally:
        # Clean up cloned repo even when clone/pull/merge failed, so a later
        # run starts from a fresh clone.
        if hub_data_dir.exists():
            shutil.rmtree(hub_data_dir)
    print("Finished syncing with hub!")
# Background uploader for the local ./data folder: it is mirrored into the
# "data" folder of the RESULTS_BACKUP_REPO dataset repo. `every` is the push
# interval (minutes, per the huggingface_hub CommitScheduler docs).
# Writers into ./data should hold `scheduler.lock`.
scheduler = CommitScheduler(
    folder_path="./data",
    path_in_repo="data",
    repo_id=RESULTS_BACKUP_REPO,
    repo_type="dataset",
    every=1,
)
def save_evaluation(post_id, model_a, model_b, verdict):
    """Append one evaluation row to ``data/evaluation_results_exp.csv``.

    Multiple evaluations per image/model are allowed; rows are only appended.
    The header row is written when the file is missing or empty.

    Args:
        post_id: id of the post whose edits were compared.
        model_a: label of the edit shown as image A.
        model_b: label of the edit shown as image B.
        verdict: the confirmed choice (e.g. "A is better", "Tie").
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    filename = "data/evaluation_results_exp.csv"
    # Hold the CommitScheduler lock while touching ./data so a background push
    # never uploads a partially written row.
    with scheduler.lock:
        os.makedirs("data", exist_ok=True)
        write_header = not os.path.exists(filename) or os.path.getsize(filename) == 0
        with open(filename, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            if write_header:
                writer.writerow(["timestamp", "post_id", "model_a", "model_b", "verdict"])
            writer.writerow([timestamp, post_id, model_a, model_b, verdict])
    print(
        f"Saved evaluation: {post_id} - Model A: {model_a} - Model B: {model_b} - Verdict: {verdict}"
    )
# Cache for _post_id_index(): {post_id: [row indices into `dataset`]}.
_POST_ID_INDEX = None


def _post_id_index():
    """Return a mapping of sampleable post_id -> row indices in ``dataset``.

    Built once on first call, replacing a full ``dataset.filter`` scan per
    sample. Relies on the module-level ``dataset`` not being reassigned after
    the startup filtering. Only post_ids in ``valid_post_ids`` that have at
    least one dataset row become keys, so every key is safe to sample.
    """
    global _POST_ID_INDEX
    if _POST_ID_INDEX is None:
        index = {}
        for row_idx, post_id in enumerate(dataset["post_id"]):
            if post_id in valid_post_ids:
                index.setdefault(post_id, []).append(row_idx)
        _POST_ID_INDEX = index
    return _POST_ID_INDEX


def get_random_sample():
    """Get a random battle: pick a post_id, then one edit from each source.

    One edit comes from ``dataset`` and one from ``photoexp`` for the same post.
    They are randomly assigned to sides A and B.

    Returns:
        dict with keys ``post_id``, ``instruction``, ``simplified_instruction``
        (both rendered as Markdown headings), ``source_image``, ``image_a``,
        ``image_b``, ``model_a`` and ``model_b``.

    Raises:
        RuntimeError: if no post_id has rows in both sources.
    """
    index = _post_id_index()
    if not index:
        raise RuntimeError(
            "No post_id appears in both the dataset and photoexp_filtered.csv"
        )
    # First randomly select a post_id that has at least one dataset edit
    random_post_id = random.choice(list(index))
    # Randomly select one edit for this post from the dataset
    sample = dataset[random.choice(index[random_post_id])]

    # Get matching photoexp entries for this post_id
    matching_photoexp_entries = photoexp[photoexp.post_id == random_post_id]
    # Randomly select one entry from the matching photoexp entries
    if not matching_photoexp_entries.empty:
        random_photoexp_entry = matching_photoexp_entries.sample(n=1).iloc[0]
        additional_edited_image = random_photoexp_entry["edited_image"]
        # Add REDDIT_ prefix when using comment_id instead of model. An empty
        # "model" cell in the CSV comes back as NaN (not None), so check both.
        model_b = random_photoexp_entry.get("model")
        if model_b is None or pd.isna(model_b):
            model_b = f"REDDIT_{random_photoexp_entry['comment_id']}"
    else:
        additional_edited_image = None
        model_b = None

    # Randomly assign images to A and B
    if random.choice([True, False]):
        image_a = sample["edited_image"]
        model_a = sample["model"]
        image_b = additional_edited_image
    else:
        image_a = additional_edited_image
        model_a = model_b
        image_b = sample["edited_image"]
        model_b = sample["model"]

    print(f"Selected post_id: {random_post_id}")
    print(f"Selected edit from model: {sample['model']}")
    return {
        "post_id": sample["post_id"],
        "instruction": '## Edit Request: "' + sample["instruction"] + '"',
        "simplified_instruction": '## Edit Request: "'
        + sample["simplified_instruction"]
        + '"',
        "source_image": sample["source_image"],
        "image_a": image_a,
        "image_b": image_b,
        "model_a": model_a,
        "model_b": model_b,
    }
def evaluate(verdict, state):
    """Record the confirmed verdict for the current pair and load the next one.

    Bound to the confirm button. The returned tuple maps one-to-one onto that
    handler's 18 outputs, in this order:
    - source_image, image_a, image_b, instruction, simplified_instruction,
      model_info, state, selected_verdict;
    - the four hidden *_selected checkboxes;
    - the four verdict buttons;
    - post_id_display, simplified_instruction_debug.

    Args:
        verdict: the currently selected verdict string.
        state: the sample dict from get_random_sample, or None if nothing is
            loaded yet.
    """
    # Reset every verdict button to the unselected style.
    a_better_reset = gr.update(variant="secondary")
    b_better_reset = gr.update(variant="secondary")
    neither_reset = gr.update(variant="secondary")
    tie_reset = gr.update(variant="secondary")

    if state is None:
        # Nothing to record: clear the view and selection. This must still
        # return exactly 18 values to match the outputs list.
        return (
            None,  # source_image
            None,  # image_a
            None,  # image_b
            None,  # instruction
            None,  # simplified_instruction
            None,  # model_info
            None,  # state
            None,  # selected_verdict
            False,  # a_better_selected
            False,  # b_better_selected
            False,  # neither_selected
            False,  # tie_selected
            a_better_reset,
            b_better_reset,
            neither_reset,
            tie_reset,
            None,  # post_id_display
            None,  # simplified_instruction_debug
        )

    # Save the evaluation
    save_evaluation(state["post_id"], state["model_a"], state["model_b"], verdict)
    # Get next sample
    next_sample = get_random_sample()
    return (
        next_sample["source_image"],
        next_sample["image_a"],
        next_sample["image_b"],
        next_sample["instruction"],
        next_sample["simplified_instruction"],
        f"Model A: {next_sample['model_a']} | Model B: {next_sample['model_b']}",
        next_sample,
        None,  # selected_verdict
        False,
        False,
        False,
        False,  # reset all button states
        a_better_reset,  # reset A is better button style
        b_better_reset,  # reset B is better button style
        neither_reset,  # reset neither is good button style
        tie_reset,  # reset tie button style
        next_sample["post_id"],
        next_sample["simplified_instruction"],
    )
def select_verdict(verdict, state):
    """Record a first-step verdict choice and flag which button it matches.

    Returns a 5-tuple: the verdict, then one boolean per option in the order
    (A is better, B is better, Neither is good, Tie). If no sample is loaded
    (``state is None``), returns ``None`` with all flags False.
    """
    options = ("A is better", "B is better", "Neither is good", "Tie")
    if state is None:
        return (None,) + (False,) * len(options)
    flags = tuple(verdict == option for option in options)
    return (verdict,) + flags
def initialize():
    """Load the first battle when the page opens.

    Returns a 14-tuple matching the demo.load outputs, in this order: the
    three images, instruction, simplified instruction, model info, the sample
    itself as state, an empty selection (verdict None, four flags False), and
    the two debug fields (post_id, simplified instruction).
    """
    first = get_random_sample()
    images = (first["source_image"], first["image_a"], first["image_b"])
    model_line = f"Model A: {first['model_a']} | Model B: {first['model_b']}"
    texts = (first["instruction"], first["simplified_instruction"], model_line)
    empty_selection = (None, False, False, False, False)
    debug_fields = (first["post_id"], first["simplified_instruction"])
    return (*images, *texts, first, *empty_selection, *debug_fields)
def update_button_styles(verdict):
    """Highlight the button matching *verdict* and reset the others.

    Returns four ``gr.update`` objects in the order (A is better, B is better,
    Both are bad, Tie). The button labels are re-set unchanged. The selected
    button gets ``variant="primary"``; all others get ``"secondary"``, the
    style ``evaluate`` resets them to after each confirmation.
    """
    # (verdict value, button label) in output order.
    buttons = (
        ("A is better", "👈 A is better"),
        ("B is better", "👉 B is better"),
        ("Neither is good", "👎 Both are bad"),
        ("Tie", "🤝 Tie"),
    )
    return tuple(
        gr.update(
            value=label,
            variant="primary" if verdict == key else "secondary",
        )
        for key, label in buttons
    )
# Create Gradio interface
with gr.Blocks() as demo:
    # Add instruction panel at the top
    gr.HTML(
        """
        <div style="padding: 0.8rem; margin-bottom: 0.8rem; border-radius: 0.5rem; color: white; text-align: center;">
            <div style="font-size: 1.2rem; margin-bottom: 0.5rem;">Read the user instruction, look at the source image, then evaluate which edit (A or B) better satisfies the request.</div>
            <div style="font-size: 1rem;">
                <strong>🤝 Tie</strong> |
                <strong>👈 A is better</strong> |
                <strong>👉 B is better</strong>
            </div>
            <div style="color: #ff4444; font-size: 0.9rem; margin-top: 0.5rem;">
                Please ignore any watermark on the image. Your rating should not be affected by any watermark on the image.
            </div>
        </div>
        """
    )
    with gr.Row():
        simplified_instruction = gr.Textbox(
            label="Simplified Instruction", show_label=True, visible=False
        )
        instruction = gr.Markdown(label="Original Instruction", show_label=True)
    with gr.Row():
        with gr.Column():
            source_image = gr.Image(label="Source Image", show_label=True, height=500)
            gr.HTML("<h2 style='text-align: center;'>Source Image</h2>")
            tie_btn = gr.Button("🤝 Tie", variant="secondary")
        with gr.Column():
            image_a = gr.Image(label="Image A", show_label=True, height=500)
            gr.HTML("<h2 style='text-align: center;'>Image A</h2>")
            a_better_btn = gr.Button("👈 A is better", variant="secondary")
        with gr.Column():
            image_b = gr.Image(label="Image B", show_label=True, height=500)
            gr.HTML("<h2 style='text-align: center;'>Image B</h2>")
            b_better_btn = gr.Button("👉 B is better", variant="secondary")
    # Add confirmation button in new row
    with gr.Row():
        confirm_btn = gr.Button("Confirm Selection", variant="primary", visible=False)
    with gr.Row():
        # Hidden, but still wired to the same handlers as the other verdicts.
        neither_btn = gr.Button("👎 Both are bad", variant="secondary", visible=False)
    with gr.Accordion("DEBUG", open=False):
        with gr.Column():
            post_id_display = gr.Textbox(
                label="Post ID", show_label=True, interactive=False
            )
            model_info = gr.Textbox(
                label="Model Information", show_label=True, interactive=False
            )
            simplified_instruction_debug = gr.Textbox(
                label="Simplified Instruction", show_label=True, interactive=False
            )
    state = gr.State()
    selected_verdict = gr.State()
    # Hidden checkboxes hold which verdict is selected; their change events
    # drive the confirm button's visibility.
    a_better_selected = gr.Checkbox(visible=False)
    b_better_selected = gr.Checkbox(visible=False)
    neither_selected = gr.Checkbox(visible=False)
    tie_selected = gr.Checkbox(visible=False)
    selection_flags = [a_better_selected, b_better_selected, neither_selected, tie_selected]
    verdict_buttons = [a_better_btn, b_better_btn, neither_btn, tie_btn]

    def update_confirm_visibility(a_better, b_better, neither, tie):
        """Show the confirm button, labelled with the selection, or hide it if none."""
        if a_better:
            return gr.update(visible=True, value="Confirm A is better")
        elif b_better:
            return gr.update(visible=True, value="Confirm B is better")
        elif neither:
            return gr.update(visible=True, value="Confirm Neither is good")
        elif tie:
            return gr.update(visible=True, value="Confirm Tie")
        return gr.update(visible=False)

    def _make_select_handler(verdict_label):
        """Return a click handler bound to *verdict_label*.

        A factory is used so each handler captures its own label. A lambda
        written directly in the loop below would late-bind to the last label.
        """
        return lambda current_state: select_verdict(verdict_label, current_state)

    # Initialize the interface
    demo.load(
        initialize,
        outputs=[
            source_image,
            image_a,
            image_b,
            instruction,
            simplified_instruction,
            model_info,
            state,
            selected_verdict,
            *selection_flags,
            post_id_display,
            simplified_instruction_debug,
        ],
    )

    # Handle first step button clicks: record the choice, then restyle buttons.
    for _button, _verdict_label in zip(
        verdict_buttons, ["A is better", "B is better", "Neither is good", "Tie"]
    ):
        _button.click(
            _make_select_handler(_verdict_label),
            inputs=[state],
            outputs=[selected_verdict, *selection_flags],
        ).then(
            update_button_styles,
            inputs=[selected_verdict],
            outputs=verdict_buttons,
        )

    # Update confirm button visibility when selection changes
    for checkbox in selection_flags:
        checkbox.change(
            update_confirm_visibility,
            inputs=selection_flags,
            outputs=[confirm_btn],
        )

    # Handle confirmation button click
    confirm_btn.click(
        evaluate,
        inputs=[selected_verdict, state],
        outputs=[
            source_image,
            image_a,
            image_b,
            instruction,
            simplified_instruction,
            model_info,
            state,
            selected_verdict,
            *selection_flags,
            *verdict_buttons,
            post_id_display,
            simplified_instruction_debug,
        ],
    )

if __name__ == "__main__":
    # Sync with hub before launching
    sync_with_hub()
    demo.launch()