# Spaces:
# Runtime error
import gradio as gr | |
from transformers import pipeline, AutoModelForImageClassification, AutoFeatureExtractor | |
import requests | |
import asyncio | |
import httpx | |
import io | |
from PIL import Image | |
import PIL | |
from functools import lru_cache | |
from toolz import pluck | |
from piffle.image import IIIFImageClient | |
# Hugging Face Hub checkpoint of the fine-tuned LeViT image classifier whose
# labels are used by _predict (which keeps pages labelled "illustrated").
HF_MODEL_PATH = "ImageIN/levit-192_finetuned_on_unlabelled_IA_with_snorkel_labels"

# Load the classifier weights and its preprocessing config once, at import
# time, so every request served by the app reuses the same objects.
classif_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_PATH)
feature_extractor = AutoFeatureExtractor.from_pretrained(HF_MODEL_PATH)

# Bundle model + preprocessor into a callable image-classification pipeline.
classif_pipeline = pipeline(
    task="image-classification",
    model=classif_model,
    feature_extractor=feature_extractor,
)
def load_manifest(inputs, timeout=30):
    """Fetch a IIIF manifest over HTTP and return its decoded JSON.

    Args:
        inputs: URL of the manifest.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely on an unresponsive
            server, tying up the worker handling the request.

    Returns:
        The manifest decoded from JSON.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status,
            rather than attempting to JSON-decode the error body.
        requests.RequestException: on connection problems or timeouts.
    """
    with requests.get(inputs, timeout=timeout) as response:
        response.raise_for_status()
        return response.json()
def get_image_urls_from_manifest(data):
    """Collect the image resource URLs from a parsed IIIF manifest.

    Both IIIF Presentation API versions are understood:

    * 2.x: ``sequences`` -> ``canvases`` -> ``images`` -> ``resource['@id']``
    * 3.x: ``items`` (canvases) -> ``items`` (annotation pages) ->
      ``items`` (painting annotations) -> ``body['id']``, where ``body`` may
      be a single resource or a list of resources.

    Args:
        data: The manifest, as decoded from JSON.

    Returns:
        The image URLs in manifest order.

    Raises:
        KeyError: if the manifest has neither ``sequences`` nor ``items``, or
            lacks a field required by its version.
    """
    image_urls = []
    if 'sequences' in data:
        # Presentation API 2.x
        for sequence in data['sequences']:
            for canvas in sequence['canvases']:
                image_urls.extend(
                    image['resource']['@id'] for image in canvas['images']
                )
    elif 'items' in data:
        # Presentation API 3.x: canvases hold annotation pages, whose
        # annotations carry the image resource(s) in their body.
        for canvas in data['items']:
            for annotation_page in canvas.get('items', []):
                for annotation in annotation_page.get('items', []):
                    body = annotation['body']
                    bodies = body if isinstance(body, list) else [body]
                    image_urls.extend(resource['id'] for resource in bodies)
    else:
        raise KeyError(
            "manifest has neither 'sequences' (IIIF Presentation 2) "
            "nor 'items' (IIIF Presentation 3)"
        )
    return image_urls
def resize_iiif_urls(image_url, size='224'):
    """Rewrite a IIIF Image API URL so it requests a ``size`` x ``size`` image.

    Args:
        image_url: A IIIF Image API URL, as found in a manifest.
        size: Value used for both the requested width and height. Per the
            IIIF Image API, an explicit "w,h" size may distort the aspect
            ratio.

    Returns:
        The rewritten URL as a string.
    """
    # Let piffle parse the URL and rebuild the size segment instead of
    # splitting the path by hand.
    iiif_image = IIIFImageClient.init_from_url(image_url)
    return str(iiif_image.size(width=size, height=size))
async def get_image(client, url):
    """Download ``url`` with ``client`` and decode it as a PIL image.

    Any per-image failure yields ``None`` instead of an exception. This way a
    single bad page cannot abort the whole batch that ``get_images`` awaits
    with ``asyncio.gather``.

    Args:
        client: An open ``httpx.AsyncClient``.
        url: Image URL to fetch.

    Returns:
        The decoded ``PIL.Image.Image``, or ``None`` if the download or
        decoding failed.
    """
    try:
        resp = await client.get(url, timeout=30)
        # Treat 4xx/5xx as a failed download instead of trying to decode an
        # error body as an image.
        resp.raise_for_status()
        image = Image.open(io.BytesIO(resp.content))
        # Image.open is lazy; decode now so truncated or corrupt data fails
        # here (and is skipped) rather than later inside the classifier.
        image.load()
    except (httpx.HTTPError, OSError):
        # httpx.HTTPError covers all transport failures (connect/read/pool
        # timeouts, connection and protocol errors) plus HTTPStatusError.
        # OSError covers PIL.UnidentifiedImageError and truncated images.
        return None
    return image
async def get_images(urls):
    """Fetch every URL in ``urls`` concurrently over one shared HTTP client.

    Returns:
        ``(url, image)`` tuples in input order, only for the URLs where
        ``get_image`` produced an image; failed downloads (``None``) are
        dropped.
    """
    async with httpx.AsyncClient() as client:
        # gather wraps each coroutine in a task, runs them concurrently, and
        # returns the results in the same order as its arguments.
        fetched = await asyncio.gather(*(get_image(client, u) for u in urls))
        assert len(fetched) == len(urls)
        return [pair for pair in zip(urls, fetched) if pair[1] is not None]
def predict(inputs):
    """Gradio callback: find the illustrated pages of the IIIF manifest URL.

    Surrounding whitespace, e.g. from a pasted URL, is stripped before the
    manifest is fetched.

    Args:
        inputs: The manifest URL entered in the text box.

    Returns:
        Whatever ``_predict`` returns for the cleaned URL (gallery items).

    Raises:
        ValueError: if no manifest URL was provided.
    """
    url = "" if inputs is None else str(inputs).strip()
    if not url:
        raise ValueError("Please enter a IIIF manifest URL.")
    return _predict(url)
def _predict(inputs, target_label='illustrated', display_width=500):
    """Classify every page of a IIIF manifest and keep the matching pages.

    Args:
        inputs: URL of the IIIF manifest.
        target_label: Classifier label a page's top-1 prediction must have
            for the page to be returned.
        display_width: Width in pixels requested for the gallery image URLs.

    Returns:
        A list of ``(image_url, caption)`` tuples for the Gradio gallery, one
        per page whose top prediction is ``target_label``.
    """
    manifest = load_manifest(inputs)
    page_urls = get_image_urls_from_manifest(manifest)
    # Request small thumbnails (resize_iiif_urls' default size) so the
    # downloads fed to the classifier stay light.
    thumbnail_urls = [resize_iiif_urls(url) for url in page_urls]
    downloaded = asyncio.run(get_images(thumbnail_urls))
    if not downloaded:
        # Nothing could be fetched; skip the model call entirely.
        return []

    fetched_urls = list(pluck(0, downloaded))
    images = list(pluck(1, downloaded))
    predictions = classif_pipeline(images, top_k=1, num_workers=2)

    predicted_images = []
    for url, pred in zip(fetched_urls, predictions):
        top_pred = pred[0]  # top_k=1: the best-scoring prediction comes first
        if top_pred['label'] != target_label:
            continue
        # Width-only size (blank height): per the IIIF Image API the server
        # scales the height to keep the aspect ratio.
        image_url = IIIFImageClient.init_from_url(url).size(
            width=display_width, height=''
        )
        predicted_images.append(
            (str(image_url), f"Confidence: {top_pred['score']}, \n image url: {image_url}")
        )
    return predicted_images
# Gallery.style() was deprecated in Gradio 3.x and removed in Gradio 4, where
# calling it raises AttributeError at import time. Prefer the `columns`
# constructor argument; fall back to .style() only on older Gradio releases
# whose Gallery rejects that keyword.
try:
    gallery = gr.Gallery(columns=3)
except TypeError:
    gallery = gr.Gallery()
    gallery.style(grid=3)

demo = gr.Interface(
    fn=predict,
    inputs=gr.Text(label="IIIF manifest url"),
    outputs=gallery,
    title="ImageIN",
    description="Identify illustrations in pages of historical books!",
    examples=[
        'https://iiif.lib.harvard.edu/manifests/drs:427603172',
        "https://iiif.lib.harvard.edu/manifests/drs:427342693",
        "https://iiif.archivelab.org/iiif/holylandbiblebok0000cunn/manifest.json",
    ],
).queue()

# Launch only when run as a script (as `python app.py` does), so importing
# this module to reuse `demo` does not start a server.
if __name__ == "__main__":
    demo.launch()