Spaces:
Runtime error
Runtime error
import random | |
from datasets import load_dataset | |
import pandas as pd | |
import os | |
from pathlib import Path | |
import requests | |
from PIL import Image | |
from io import BytesIO | |
# Load the experimental battle-results dataset (train split).
dataset = load_dataset("taesiri/IERv2-BattleResults_exp", split="train")

# Unique post_ids present in the (still unfiltered) dataset. Taken from the
# split already in memory rather than calling load_dataset a second time.
dataset_post_ids = list(dataset.unique("post_id"))

# Load the filtered photoexp table; its rows supply the second edit per post.
photoexp = pd.read_csv("./photoexp_filtered.csv")

# Keep only post_ids that appear in BOTH sources. get_random_sample() draws
# from this set, and a post_id that is absent from the dataset would produce
# zero candidate edits. Intersecting does not change the filter below, since
# every row's post_id is by construction in dataset_post_ids.
valid_post_ids = set(photoexp.post_id.tolist()) & set(dataset_post_ids)

# Filter dataset to include only rows whose post_id is valid.
dataset = dataset.filter(
    lambda xs: [x in valid_post_ids for x in xs["post_id"]],
    batched=True,
    batch_size=256,
)
def download_and_save_image(url, save_path, timeout=30):
    """Download an image from *url* and save it to *save_path*.

    Args:
        url: URL of the image to fetch with ``requests.get``.
        save_path: Destination path; PIL picks the output format from the
            file extension.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` can block indefinitely on a stalled server.

    Returns:
        True on success. False if downloading, decoding or saving failed; the
        error is printed (this helper is deliberately best-effort).
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        # JPEG cannot store alpha/palette modes (e.g. RGBA, LA, P); Pillow raises
        # OSError for those, so convert to RGB before writing a .jpg/.jpeg file.
        is_jpeg_target = str(save_path).lower().endswith((".jpg", ".jpeg"))
        if is_jpeg_target and img.mode not in ("1", "L", "RGB", "CMYK"):
            img = img.convert("RGB")
        img.save(save_path)
        return True
    except Exception as e:
        print(f"Error downloading image {url}: {e}")
        return False
def get_random_sample():
    """Get a random sample by first selecting a post_id then picking random edits for that post.

    A post_id is drawn uniformly from the module-level ``valid_post_ids``.
    One edit for it is picked from ``dataset`` (edit 1), and one row for it
    is picked from ``photoexp`` (edit 2).

    Returns:
        A dict with keys ``post_id``, ``instruction``,
        ``simplified_instruction``, ``source_image``, ``edit1_image``,
        ``edit1_model``, ``edit2_image`` and ``edit2_model``.
        Returns None when no sample can be built: there are no valid post_ids,
        or the chosen post has no rows in ``dataset`` or in ``photoexp``.
    """
    if not valid_post_ids:
        return None

    # First randomly select a post_id from valid posts.
    random_post_id = random.choice(list(valid_post_ids))

    # Filter dataset for this post_id (note: a full pass over the dataset).
    post_edits = dataset.filter(
        lambda xs: [x == random_post_id for x in xs["post_id"]],
        batched=True,
        batch_size=256,
    )

    # Get matching photoexp entries for this post_id.
    matching_photoexp_entries = photoexp[photoexp.post_id == random_post_id]

    # A post known to photoexp may have no dataset rows; random.randint(0, -1)
    # would raise ValueError, so report "no sample" instead.
    if len(post_edits) == 0:
        return None

    # Randomly select one edit from the dataset.
    sample = post_edits[random.randrange(len(post_edits))]

    if matching_photoexp_entries.empty:
        return None

    # Randomly select one entry from the matching photoexp entries.
    random_photoexp_entry = matching_photoexp_entries.sample(n=1).iloc[0]
    additional_edited_image = random_photoexp_entry["edited_image"]

    # Series.get returns None only if the "model" column is absent. An existing
    # column with an empty cell yields NaN, so pd.isna covers both cases.
    model_b = random_photoexp_entry.get("model")
    if pd.isna(model_b):
        model_b = f"REDDIT_{random_photoexp_entry['comment_id']}"

    return {
        "post_id": sample["post_id"],
        "instruction": sample["instruction"],
        "simplified_instruction": sample["simplified_instruction"],
        "source_image": sample["source_image"],
        "edit1_image": sample["edited_image"],
        "edit1_model": sample["model"],
        "edit2_image": additional_edited_image,
        "edit2_model": model_b,
    }
def save_sample(sample, output_dir):
    """Save a sample to disk with all its components.

    The sample is written to ``<output_dir>/<post_id>``, which contains
    ``metadata.txt``, ``source.jpg``, ``edit1.jpg`` and ``edit2.jpg``.
    Several samples may share a post_id. If the directory already exists, a
    numeric suffix is appended (``<post_id>_1``, ``<post_id>_2``, ...)
    instead of overwriting the earlier sample.

    Args:
        sample: dict as returned by get_random_sample(), or None.
        output_dir: Root directory for extracted samples.

    Returns:
        True if the metadata and all three images were written. False if
        *sample* is None or any image failed; in that case the incomplete
        sample directory is removed.
    """
    import shutil

    if sample is None:
        return False

    # Pick a directory that does not already hold another sample.
    post_id = str(sample["post_id"])
    root = Path(output_dir)
    sample_dir = root / post_id
    suffix = 1
    while sample_dir.exists():
        sample_dir = root / f"{post_id}_{suffix}"
        suffix += 1
    sample_dir.mkdir(parents=True, exist_ok=True)

    # Save instruction and metadata. Use explicit UTF-8 because instructions are
    # free text and the platform default encoding may not represent them.
    with open(sample_dir / "metadata.txt", "w", encoding="utf-8") as f:
        f.write(f"Post ID: {sample['post_id']}\n")
        f.write(f"Original Instruction: {sample['instruction']}\n")
        f.write(f"Simplified Instruction: {sample['simplified_instruction']}\n")
        f.write(f"Edit 1 Model: {sample['edit1_model']}\n")
        f.write(f"Edit 2 Model: {sample['edit2_model']}\n")

    # Save images; stop at the first failure, since the sample is discarded anyway.
    images = (
        (sample["source_image"], "source.jpg"),
        (sample["edit1_image"], "edit1.jpg"),
        (sample["edit2_image"], "edit2.jpg"),
    )
    for url, filename in images:
        if not download_and_save_image(url, sample_dir / filename):
            shutil.rmtree(sample_dir, ignore_errors=True)
            return False
    return True
def main(num_samples=100, output_dir="extracted_samples", max_attempts=None):
    """Extract *num_samples* random samples into *output_dir*.

    Args:
        num_samples: Number of complete samples to save.
        output_dir: Directory to write samples into (created if missing).
        max_attempts: Upper bound on sampling attempts, defaulting to
            ``10 * num_samples``. Without a bound the loop would never end if
            every attempt fails (e.g. no network, or no usable posts).

    Returns:
        The number of samples actually saved.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    if max_attempts is None:
        max_attempts = 10 * num_samples

    successful_samples = 0
    attempts = 0
    print(f"Extracting {num_samples} samples...")
    while successful_samples < num_samples and attempts < max_attempts:
        attempts += 1
        sample = get_random_sample()
        if sample and save_sample(sample, output_dir):
            successful_samples += 1
            print(f"Successfully saved sample {successful_samples}/{num_samples}")
        else:
            print("Failed to save sample, trying next...")

    if successful_samples < num_samples:
        print(f"Giving up after {attempts} attempts.")
    print(f"Successfully extracted {successful_samples} samples to {output_dir}")
    return successful_samples


if __name__ == "__main__":
    main()