# Provenance (Hugging Face file-viewer header captured with this file):
# author: taesiri | last commit: a167ff0 "backup" | raw · history · blame | 16.9 kB
import csv
import functools
import json
import os
import random
import shutil
import uuid
from datetime import datetime
from io import BytesIO
from pathlib import Path

import git
import gradio as gr
import pandas as pd
import PIL
from datasets import load_dataset
from huggingface_hub import CommitScheduler, HfApi, snapshot_download
api = HfApi(token=os.environ["HF_TOKEN"])
RESULTS_BACKUP_REPO = "taesiri/PhotoEditBattleResults"

# Load the experimental dataset
dataset = load_dataset("taesiri/IERv2-BattleResults_exp", split="train")

# Distinct post ids of the unfiltered dataset. They are read from the dataset
# that is already loaded rather than calling `load_dataset` a second time with
# a column subset, which is a separate (presumably separately cached) build.
dataset_post_ids = dataset.unique("post_id")

# Posts that have photoexp entries; only these can be turned into an A/B pair.
photoexp = pd.read_csv("./photoexp_filtered.csv")
valid_post_ids = set(photoexp.post_id.tolist())

# Keep only edits whose post also appears in photoexp_filtered.csv.
# `input_columns` hands the predicate just the post_id column (a list per batch),
# so no other columns need to be formatted for the check.
dataset = dataset.filter(
    lambda post_ids: [post_id in valid_post_ids for post_id in post_ids],
    input_columns="post_id",
    batched=True,
    batch_size=256,
)
print(f"Dataset size after filtering: {len(dataset)}")
def _merge_hub_folder(hub_data_source, data_dir, local_csv_path, local_data):
    """Merge the hub checkout's ``data/`` folder into the local data directory.

    Args:
        hub_data_source: ``data/`` folder inside the hub checkout.
        data_dir: Local data directory (created if missing).
        local_csv_path: Local results CSV to merge into.
        local_data: DataFrame read from ``local_csv_path`` before the sync, or
            None if there was no local CSV.
    """
    data_dir.mkdir(exist_ok=True)

    hub_csv_path = hub_data_source / "evaluation_results_exp.csv"
    if hub_csv_path.exists():
        hub_data = pd.read_csv(hub_csv_path)
        print(f"Found hub data with {len(hub_data)} entries")

        if local_data is not None:
            # Merge data, keeping all entries and removing exact duplicates
            merged_data = pd.concat([local_data, hub_data]).drop_duplicates()
            print(f"Merged data has {len(merged_data)} entries")
            merged_data.to_csv(local_csv_path, index=False)
        else:
            # If no local data exists, just copy hub data
            shutil.copy2(hub_csv_path, local_csv_path)

    # Copy any other files from hub: plain files overwrite local copies,
    # directories are only copied when they do not exist locally yet.
    for item in hub_data_source.glob("*"):
        if item.is_file() and item.name != "evaluation_results_exp.csv":
            shutil.copy2(item, data_dir / item.name)
        elif item.is_dir():
            dest = data_dir / item.name
            if not dest.exists():
                shutil.copytree(item, dest)


# Download existing data from the hub and merge it into ./data
def sync_with_hub():
    """Merge the results stored in the hub dataset repo into the local ./data folder.

    Clones RESULTS_BACKUP_REPO into ./hub_data using the HF_TOKEN credentials,
    or pulls if a checkout is already there. It then merges the hub's
    ``data/evaluation_results_exp.csv`` with the local copy: both are
    concatenated and exact duplicate rows dropped, or the hub file is copied if
    there is no local copy. The other files and folders under the hub's
    ``data/`` are copied as well.

    The temporary checkout is always deleted afterwards, even when the merge
    fails, because its git remote URL embeds the access token.

    Raises:
        KeyError: if the HF_TOKEN environment variable is not set.
        Errors from GitPython (e.g. ``git.GitCommandError``) or pandas propagate
        after the checkout has been removed.
    """
    print("Starting sync with hub...")
    data_dir = Path("./data")
    local_csv_path = data_dir / "evaluation_results_exp.csv"

    # Read existing local data if it exists
    local_data = None
    if local_csv_path.exists():
        local_data = pd.read_csv(local_csv_path)
        print(f"Found local data with {len(local_data)} entries")

    # Clone/pull latest data from hub. The credentials are embedded in the URL,
    # so never print `repo_url`.
    token = os.environ["HF_TOKEN"]
    username = "taesiri"
    repo_url = (
        f"https://{username}:{token}@huggingface.co/datasets/{RESULTS_BACKUP_REPO}"
    )
    hub_data_dir = Path("hub_data")

    try:
        if hub_data_dir.exists():
            # Normally absent, since this function deletes the checkout when done.
            print("Pulling latest changes...")
            repo = git.Repo(hub_data_dir)
            origin = repo.remotes.origin
            if "https://" in origin.url:
                origin.set_url(repo_url)
            origin.pull()
        else:
            print("Cloning repository...")
            git.Repo.clone_from(repo_url, hub_data_dir)

        # Merge hub data with local data
        hub_data_source = hub_data_dir / "data"
        if hub_data_source.exists():
            _merge_hub_folder(hub_data_source, data_dir, local_csv_path, local_data)
    finally:
        # Always clean up the cloned repo: its .git config holds the token URL.
        if hub_data_dir.exists():
            shutil.rmtree(hub_data_dir)

    print("Finished syncing with hub!")
# Background uploader for the results folder: commits the contents of ./data
# into the `data/` folder of the RESULTS_BACKUP_REPO dataset repo every
# `every` minutes (here: once a minute). Per the huggingface_hub docs, the
# scheduler starts its background thread as soon as it is constructed.
scheduler = CommitScheduler(
    every=1,
    repo_type="dataset",
    repo_id=RESULTS_BACKUP_REPO,
    path_in_repo="data",
    folder_path="./data",
)
def save_evaluation(post_id, model_a, model_b, verdict):
    """Append one evaluation row to ``data/evaluation_results_exp.csv``.

    Multiple evaluations per image/model are allowed: every call appends a new
    row ``timestamp, post_id, model_a, model_b, verdict``. The header row is
    written first when the file is missing or empty.

    The write is done while holding ``scheduler.lock``, as the CommitScheduler
    documentation recommends for files inside the watched folder, so a
    scheduled commit never captures a partially written row.

    Args:
        post_id: Id of the evaluated post.
        model_a: Name of the model whose edit was shown as image A.
        model_b: Name of the model whose edit was shown as image B.
        verdict: The confirmed verdict label.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Create data directory if it doesn't exist
    os.makedirs("data", exist_ok=True)
    filename = "data/evaluation_results_exp.csv"

    with scheduler.lock:
        # Decide on the header inside the lock as well, so the existence check
        # and the write cannot be interleaved with another writer.
        needs_header = not os.path.exists(filename) or os.path.getsize(filename) == 0
        with open(filename, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            if needs_header:
                writer.writerow(["timestamp", "post_id", "model_a", "model_b", "verdict"])
            writer.writerow([timestamp, post_id, model_a, model_b, verdict])

    print(
        f"Saved evaluation: {post_id} - Model A: {model_a} - Model B: {model_b} - Verdict: {verdict}"
    )
@functools.lru_cache(maxsize=None)
def _rows_by_post_id():
    """Map each post_id in the filtered ``dataset`` to the list of its row indices.

    Built once, on first use, from the post_id column only. Each request can
    then find a post's edits with a dict lookup instead of re-filtering the
    whole dataset.
    """
    rows = {}
    for row_index, post_id in enumerate(dataset["post_id"]):
        rows.setdefault(post_id, []).append(row_index)
    return rows


def get_random_sample():
    """Pick a random post and build an A/B pair of edits for it.

    Selection works as follows:
    - A post is drawn uniformly from the posts that have at least one edit in
      the filtered ``dataset``.
    - One of that post's dataset edits is drawn at random.
    - One matching ``photoexp`` row, if any, supplies the competing edit.
    - The two edits are randomly assigned to sides A and B.

    Returns:
        dict with keys ``post_id``, ``instruction``, ``simplified_instruction``
        (both instruction fields wrapped as a Markdown "## Edit Request" line),
        ``source_image``, ``image_a``, ``image_b``, ``model_a`` and ``model_b``.
        If ``photoexp`` has no row for the post, the photoexp side's image and
        model are None.

    Raises:
        RuntimeError: if the filtered dataset contains no rows.
    """
    rows_by_post = _rows_by_post_id()
    if not rows_by_post:
        raise RuntimeError("No evaluable posts: the filtered dataset is empty.")

    # Choose only among posts that actually have dataset rows; an id that is in
    # photoexp but not in the dataset would leave no edit to sample.
    random_post_id = random.choice(list(rows_by_post))
    sample = dataset[random.choice(rows_by_post[random_post_id])]

    # Randomly select one entry from the matching photoexp entries
    matching_photoexp_entries = photoexp[photoexp.post_id == random_post_id]
    if not matching_photoexp_entries.empty:
        random_photoexp_entry = matching_photoexp_entries.sample(n=1).iloc[0]
        additional_edited_image = random_photoexp_entry["edited_image"]
        # Use the model name when present. Otherwise fall back to the Reddit
        # comment id with a REDDIT_ prefix. "Otherwise" covers a missing column
        # (.get returns None) and an empty CSV cell, which pandas reads as NaN,
        # not None.
        model_b = random_photoexp_entry.get("model")
        if pd.isna(model_b):
            model_b = f"REDDIT_{random_photoexp_entry['comment_id']}"
    else:
        additional_edited_image = None
        model_b = None

    # Randomly assign images to A and B
    if random.choice([True, False]):
        image_a = sample["edited_image"]
        model_a = sample["model"]
        image_b = additional_edited_image
    else:
        image_a = additional_edited_image
        model_a = model_b
        image_b = sample["edited_image"]
        model_b = sample["model"]

    print(f"Selected post_id: {random_post_id}")
    print(f"Selected edit from model: {sample['model']}")

    return {
        "post_id": sample["post_id"],
        "instruction": '## Edit Request: "' + sample["instruction"] + '"',
        "simplified_instruction": '## Edit Request: "'
        + sample["simplified_instruction"]
        + '"',
        "source_image": sample["source_image"],
        "image_a": image_a,
        "image_b": image_b,
        "model_a": model_a,
        "model_b": model_b,
    }
def evaluate(verdict, state):
    """Save the confirmed verdict for the current pair, then load a new pair.

    Handler for the "Confirm" button. The returned tuple maps one-to-one onto
    that event's 18 outputs, in this order:
    - source image, image A, image B
    - instruction, simplified instruction, model info
    - state, selected verdict
    - the four selection checkboxes (A, B, neither, tie)
    - the four verdict buttons (same order)
    - post id display, debug copy of the simplified instruction

    Args:
        verdict: The selected verdict label, e.g. "A is better" or "Tie".
        state: The sample dict from ``get_random_sample`` for the pair currently
            shown, or None when no sample has been loaded.

    Returns:
        An 18-tuple of component values and ``gr.update`` objects.
    """
    # Un-highlight all four verdict buttons (A, B, neither, tie).
    button_resets = tuple(gr.update(variant="secondary") for _ in range(4))

    if state is None:
        # Nothing on screen to rate: save nothing and clear the display and the
        # selection. The tuple must still have exactly one entry per output (18).
        return (
            (None,) * 6  # images, instructions, model info
            + (None, None)  # state, selected_verdict
            + (False,) * 4  # selection checkboxes
            + button_resets
            + (None, None)  # post id display, debug simplified instruction
        )

    # Save the evaluation
    save_evaluation(state["post_id"], state["model_a"], state["model_b"], verdict)

    # Get next sample
    next_sample = get_random_sample()

    return (
        next_sample["source_image"],
        next_sample["image_a"],
        next_sample["image_b"],
        next_sample["instruction"],
        next_sample["simplified_instruction"],
        f"Model A: {next_sample['model_a']} | Model B: {next_sample['model_b']}",
        next_sample,
        None,  # selected_verdict
        False,  # a_better_selected
        False,  # b_better_selected
        False,  # neither_selected
        False,  # tie_selected
        *button_resets,  # A / B / neither / tie button styles
        next_sample["post_id"],
        next_sample["simplified_instruction"],
    )
def select_verdict(verdict, state):
    """Record a first-step verdict choice, before it is confirmed.

    Returns ``(verdict, is_a, is_b, is_neither, is_tie)``: the chosen label,
    then one flag per option in the order "A is better", "B is better",
    "Neither is good", "Tie". When no sample is loaded (``state is None``),
    nothing is selected and ``(None, False, False, False, False)`` is returned.
    """
    options = ("A is better", "B is better", "Neither is good", "Tie")
    if state is not None:
        return (verdict, *(verdict == option for option in options))
    return (None,) + (False,) * len(options)
def initialize():
    """Load the first random pair when the page opens.

    Returns the 14 values for ``demo.load``'s outputs, in this order:
    - the three images
    - both instruction fields
    - the model-info line
    - the sample itself (stored as state)
    - a cleared verdict
    - four cleared selection flags
    - the post id and simplified instruction for the debug panel
    """
    first = get_random_sample()
    shown = tuple(
        first[key]
        for key in (
            "source_image",
            "image_a",
            "image_b",
            "instruction",
            "simplified_instruction",
        )
    )
    model_line = f"Model A: {first['model_a']} | Model B: {first['model_b']}"
    no_selection = (None, False, False, False, False)
    debug_values = (first["post_id"], first["simplified_instruction"])
    return shown + (model_line, first) + no_selection + debug_values
def update_button_styles(verdict):
    """Highlight the verdict button matching ``verdict`` and un-highlight the rest.

    Returns four ``gr.update`` objects for the A-better, B-better, both-bad and
    tie buttons, in that order. Labels are re-asserted unchanged. Only the
    variant depends on the selection: "primary" for the chosen verdict,
    "secondary" otherwise, which is the style ``evaluate`` resets the buttons
    to after a verdict is confirmed.

    Args:
        verdict: Selected verdict label ("A is better", "B is better",
            "Neither is good" or "Tie"), or None for no selection.
    """
    # (verdict label, emoji button text) in output order.
    buttons = (
        ("A is better", "👈 A is better"),
        ("B is better", "👉 B is better"),
        ("Neither is good", "👎 Both are bad"),
        ("Tie", "🤝 Tie"),
    )
    return tuple(
        gr.update(
            value=label,
            variant="primary" if verdict == option else "secondary",
        )
        for option, label in buttons
    )
# Create Gradio interface
with gr.Blocks() as demo:
    # Instruction panel at the top.
    # NOTE(review): the panel uses `color: white` without setting a background
    # colour, so the text may be hard to read on a light theme; confirm against
    # the deployed theme.
    gr.HTML(
        """
<div style="padding: 0.8rem; margin-bottom: 0.8rem; border-radius: 0.5rem; color: white; text-align: center;">
<div style="font-size: 1.2rem; margin-bottom: 0.5rem;">Read the user instruction, look at the source image, then evaluate which edit (A or B) better satisfies the request.</div>
<div style="font-size: 1rem;">
<strong>🤝 Tie</strong> &nbsp;&nbsp;|&nbsp;&nbsp;
<strong>👈 A is better</strong> &nbsp;&nbsp;|&nbsp;&nbsp;
<strong>👉 B is better</strong>
</div>
<div style="color: #ff4444; font-size: 0.9rem; margin-top: 0.5rem;">
Please ignore any watermark on the image. Your rating should not be affected by any watermark on the image.
</div>
</div>
"""
    )

    with gr.Row():
        simplified_instruction = gr.Textbox(
            label="Simplified Instruction", show_label=True, visible=False
        )
        instruction = gr.Markdown(label="Original Instruction", show_label=True)

    with gr.Row():
        with gr.Column():
            source_image = gr.Image(label="Source Image", show_label=True, height=500)
            gr.HTML("<h2 style='text-align: center;'>Source Image</h2>")
            tie_btn = gr.Button("🤝 Tie", variant="secondary")
        with gr.Column():
            image_a = gr.Image(label="Image A", show_label=True, height=500)
            gr.HTML("<h2 style='text-align: center;'>Image A</h2>")
            a_better_btn = gr.Button("👈 A is better", variant="secondary")
        with gr.Column():
            image_b = gr.Image(label="Image B", show_label=True, height=500)
            gr.HTML("<h2 style='text-align: center;'>Image B</h2>")
            b_better_btn = gr.Button("👉 B is better", variant="secondary")

    # Confirmation button in its own row; hidden until a verdict is selected.
    with gr.Row():
        confirm_btn = gr.Button("Confirm Selection", variant="primary", visible=False)

    with gr.Row():
        neither_btn = gr.Button("👎 Both are bad", variant="secondary", visible=False)

    with gr.Accordion("DEBUG", open=False):
        with gr.Column():
            post_id_display = gr.Textbox(
                label="Post ID", show_label=True, interactive=False
            )
            model_info = gr.Textbox(label="Model Information", show_label=True)
            simplified_instruction_debug = gr.Textbox(
                label="Simplified Instruction", show_label=True, interactive=False
            )

    state = gr.State()
    selected_verdict = gr.State()

    # Hidden checkboxes holding which verdict is currently selected.
    a_better_selected = gr.Checkbox(visible=False)
    b_better_selected = gr.Checkbox(visible=False)
    neither_selected = gr.Checkbox(visible=False)
    tie_selected = gr.Checkbox(visible=False)

    # Shared orderings: A, B, neither, tie. These match the flags returned by
    # select_verdict and the updates returned by update_button_styles.
    selection_flags = [a_better_selected, b_better_selected, neither_selected, tie_selected]
    verdict_buttons = [a_better_btn, b_better_btn, neither_btn, tie_btn]
    verdict_labels = ["A is better", "B is better", "Neither is good", "Tie"]

    # Outputs shared by page load and confirmation (everything up to the flags).
    display_outputs = [
        source_image,
        image_a,
        image_b,
        instruction,
        simplified_instruction,
        model_info,
        state,
        selected_verdict,
        *selection_flags,
    ]
    debug_outputs = [post_id_display, simplified_instruction_debug]

    def update_confirm_visibility(a_better, b_better, neither, tie):
        """Show the confirm button, labelled with the selected verdict, or hide it."""
        if a_better:
            return gr.update(visible=True, value="Confirm A is better")
        elif b_better:
            return gr.update(visible=True, value="Confirm B is better")
        elif neither:
            return gr.update(visible=True, value="Confirm Neither is good")
        elif tie:
            return gr.update(visible=True, value="Confirm Tie")
        return gr.update(visible=False)

    def _make_selector(verdict_label):
        """Build a one-input click handler that selects ``verdict_label``."""
        # A factory binds the label per button, avoiding the late-binding pitfall
        # of defining the lambda directly inside the loop below.
        return lambda current_state: select_verdict(verdict_label, current_state)

    # Initialize the interface (14 outputs, matching initialize()).
    demo.load(initialize, outputs=display_outputs + debug_outputs)

    # First step: clicking a verdict button selects it and restyles the buttons.
    for button, verdict_label in zip(verdict_buttons, verdict_labels):
        button.click(
            _make_selector(verdict_label),
            inputs=[state],
            outputs=[selected_verdict, *selection_flags],
        ).then(
            update_button_styles,
            inputs=[selected_verdict],
            outputs=verdict_buttons,
        )

    # Update confirm button visibility when selection changes
    for checkbox in selection_flags:
        checkbox.change(
            update_confirm_visibility,
            inputs=selection_flags,
            outputs=[confirm_btn],
        )

    # Second step: confirming saves the verdict and loads the next pair
    # (18 outputs, matching evaluate()).
    confirm_btn.click(
        evaluate,
        inputs=[selected_verdict, state],
        outputs=display_outputs + verdict_buttons + debug_outputs,
    )
def _run_app():
    """Merge the latest results from the hub into ./data, then start the UI."""
    # NOTE(review): `scheduler` is constructed at import time, before this sync
    # runs, so its first scheduled upload may happen before the merge; confirm
    # this ordering is intended.
    sync_with_hub()
    demo.launch()


if __name__ == "__main__":
    _run_app()