# gradio_field_app / create_argilla_dataset.py
# Author: burtenshaw (HF Staff)
# Commit 547d83c: add create dataset script
import random
import urllib.parse
from http import client
import argilla as rg
import requests
from jinja2 import Template
def generate_html(image_urls, app_url):
    """Render an HTML page that embeds the app in an iframe.

    Every entry of ``image_urls`` is percent-encoded with no characters
    left unescaped and passed to the app as the query parameters ``img1``,
    ``img2``, ... (in input order), appended to ``{app_url}/?``.

    Args:
        image_urls: Iterable of image URL strings to hand to the app.
        app_url: Base URL of the app; it becomes the iframe source prefix.

    Returns:
        The rendered HTML document as a string.
    """
    query_parts = []
    for position, image_url in enumerate(image_urls, start=1):
        quoted_url = urllib.parse.quote(image_url, safe="")
        query_parts.append(f"img{position}={quoted_url}")
    iframe_src = f"{app_url}/?" + "&".join(query_parts)

    # NOTE(review): the template is created without any autoescape option,
    # so iframe_src is presumably inserted verbatim -- confirm that app_url
    # always comes from a trusted source.
    page_template = Template(
        """
<html>
<body>
<iframe src="{{ full_url }}" width="400" height="400" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>
</body>
</html>
"""
    )
    return page_template.render(full_url=iframe_src)
def fetch_data(max_images, name="horse", timeout=30):
    """Fetch image URLs from the ``gigant/horse2zebra`` dataset.

    Queries the Hugging Face datasets-server ``/rows`` endpoint for the
    ``train`` split of the config called ``name``, starting at offset 0.

    Args:
        max_images: Number of rows to request (sent as ``length``).
        name: Dataset config to read, e.g. ``"horse"`` or ``"zebra"``.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` can block indefinitely on a stalled
            connection.

    Returns:
        A list with the ``src`` URL of each returned row's image, or
        ``None`` when the server answers with a status other than 200.

    Raises:
        requests.exceptions.RequestException: If the request cannot be
            completed (connection failure, timeout, ...).
    """
    base_url = "https://datasets-server.huggingface.co/rows"
    params = {
        "dataset": "gigant/horse2zebra",
        "config": name,
        "split": "train",
        "offset": 0,
        "length": max_images,
    }
    response = requests.get(base_url, params=params, timeout=timeout)

    # Guard clause: report the failure and signal it to the caller with None.
    if response.status_code != 200:
        print(f"Failed to fetch data. Status code: {response.status_code}")
        return None

    data = response.json()
    return [row["row"]["image"]["src"] for row in data["rows"]]
def log_records(dataset, app_url, max_images=10):
    """Build one record per horse/zebra image pair and log them to ``dataset``.

    Image URLs for the ``horse`` and ``zebra`` configs are fetched with
    ``fetch_data`` and paired positionally; each pair is rendered with
    ``generate_html`` and stored in the record's ``gradio_app`` field.

    Args:
        dataset: Object exposing ``records.log(list_of_records)``.
        app_url: Base URL of the app, forwarded to ``generate_html``.
        max_images: Number of images requested for each config.

    Returns:
        None. When either fetch fails or yields no URLs, nothing is logged.
    """
    horse_urls = fetch_data(max_images=max_images, name="horse")
    zebra_urls = fetch_data(max_images=max_images, name="zebra")

    # fetch_data returns None on a failed request; zipping None would raise
    # an opaque TypeError, so stop here with an explicit message instead.
    if not horse_urls or not zebra_urls:
        print("No image URLs available; skipping record logging.")
        return

    # zip pairs positionally and stops at the shorter of the two lists.
    records = [
        rg.Record(
            fields={
                "gradio_app": generate_html(
                    image_urls=[horse_url, zebra_url], app_url=app_url
                )
            }
        )
        for horse_url, zebra_url in zip(horse_urls, zebra_urls)
    ]
    dataset.records.log(records)
def create_dataset(api_url="http://localhost:6900", api_key="owner.apikey"):
    """Create a horse2zebra labelling dataset on an Argilla server.

    The dataset has a single text field, ``gradio_app``, and a single label
    question, ``is_zebra``, with the labels ``"true"`` and ``"false"``. Its
    name is ``horse2zebra_<k>`` where ``k`` is a random integer from 0 to
    100 inclusive.

    Args:
        api_url: URL of the Argilla server.
        api_key: API key used to authenticate against the server.

    Returns:
        The ``rg.Dataset`` object after ``create()`` has been called on it.
    """
    argilla_client = rg.Argilla(api_url=api_url, api_key=api_key)

    # Schema: one field holding the embedded app, one yes/no style question.
    app_field = rg.TextField(name="gradio_app")
    zebra_question = rg.LabelQuestion(
        name="is_zebra",
        description="Is this a zebra?",
        labels=["true", "false"],
    )
    dataset_settings = rg.Settings(
        fields=[app_field],
        questions=[zebra_question],
    )

    # NOTE(review): only 101 distinct names are possible; presumably creation
    # fails when the name already exists on the server -- verify.
    dataset_name = f"horse2zebra_{random.randint(0,100)}"
    new_dataset = rg.Dataset(
        name=dataset_name,
        settings=dataset_settings,
        client=argilla_client,
    )
    new_dataset.create()
    return new_dataset
if __name__ == "__main__":
    # Script entry point: create the dataset on the hosted Argilla Space,
    # then fill it with records that embed the hosted app.
    argilla_space_url = "https://burtenshaw-gradio-field.hf.space"
    gradio_app_url = "https://burtenshaw-gradio-field-app.hf.space"

    dataset = create_dataset(api_url=argilla_space_url)
    log_records(dataset, app_url=gradio_app_url, max_images=10)