davanstrien's picture
davanstrien HF Staff
enable queue
cbbdb8f
raw
history blame
3.32 kB
import gradio as gr
from transformers import pipeline, AutoModelForImageClassification, AutoFeatureExtractor
import requests
import asyncio
import httpx
import io
from PIL import Image
import PIL
from functools import lru_cache
from toolz import pluck
from piffle.image import IIIFImageClient
# Hugging Face Hub id of the page classifier (per its name, a fine-tuned LeViT-192).
HF_MODEL_PATH = "ImageIN/levit-192_finetuned_on_unlabelled_IA_with_snorkel_labels"

# Load the model and its preprocessing config once, at import time, so every
# request served by the app reuses the same objects.
classif_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_PATH)
feature_extractor = AutoFeatureExtractor.from_pretrained(HF_MODEL_PATH)
classif_pipeline = pipeline(
    task="image-classification",
    model=classif_model,
    feature_extractor=feature_extractor,
)
def load_manifest(inputs, *, timeout=30):
    """Fetch a IIIF manifest over HTTP and return its decoded JSON body.

    Args:
        inputs: URL of the manifest.
        timeout: seconds to wait for the server before giving up. Without a
            timeout, requests can block indefinitely on a stalled server.

    Returns:
        The parsed JSON document.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within ``timeout``.
    """
    with requests.get(inputs, timeout=timeout) as response:
        # Fail loudly on error pages instead of parsing them as a manifest
        # (which would surface later as a confusing KeyError / JSON error).
        response.raise_for_status()
        return response.json()
def get_image_urls_from_manifest(data):
    """Collect every image resource URL from a parsed IIIF manifest.

    Walks the ``sequences -> canvases -> images -> resource['@id']`` layout
    (the IIIF Presentation 2.x structure) and returns the URLs in document
    order. Missing keys propagate as ``KeyError``.
    """
    return [
        image['resource']['@id']
        for sequence in data['sequences']
        for canvas in sequence['canvases']
        for image in canvas['images']
    ]
def resize_iiif_urls(image_url, size='224'):
    """Rewrite a IIIF Image API URL so it requests a ``size`` x ``size`` image.

    Args:
        image_url: a IIIF Image API URL, parsed with piffle's IIIFImageClient.
        size: target width and height. The default of 224 presumably matches
            the classifier's input resolution — TODO confirm.

    Returns:
        The rewritten URL as a string.
    """
    iiif_image = IIIFImageClient.init_from_url(image_url)
    return str(iiif_image.size(width=size, height=size))
async def get_image(client, url):
    """Download and decode one image, returning None if anything goes wrong.

    Failures are swallowed on purpose so that one bad page does not abort the
    whole manifest. This covers any httpx error (every timeout kind, connection
    and protocol errors, and non-2xx statuses) as well as any PIL decoding
    error.

    Args:
        client: an open ``httpx.AsyncClient``.
        url: the image URL to fetch.

    Returns:
        A fully loaded ``PIL.Image.Image``, or None on failure.
    """
    try:
        resp = await client.get(url, timeout=30)
        # An error page is not an image; treat it like any other failed download.
        resp.raise_for_status()
        image = Image.open(io.BytesIO(resp.content))
        # Image.open only reads the header. Decode now so truncated or corrupt
        # data is caught here rather than later inside the classifier.
        image.load()
    except (OSError, httpx.HTTPError):
        # OSError includes PIL.UnidentifiedImageError and truncated-image errors.
        # httpx.HTTPError is the base of ReadTimeout, ConnectError, ConnectTimeout,
        # PoolTimeout, protocol errors and HTTPStatusError.
        return None
    return image
async def get_images(urls):
    """Fetch all ``urls`` concurrently with one shared httpx client.

    Returns:
        A list of ``(url, image)`` tuples, in the same order as ``urls``, for
        every URL whose download succeeded (i.e. ``get_image`` did not return
        None).
    """
    async with httpx.AsyncClient() as client:
        pending = [asyncio.ensure_future(get_image(client, u)) for u in urls]
        fetched = await asyncio.gather(*pending)
        # gather yields exactly one result per awaitable, in submission order.
        assert len(fetched) == len(urls)
        return [pair for pair in zip(urls, fetched) if pair[1] is not None]
def predict(inputs):
    """Gradio entry point: find the illustrated pages of the manifest at ``inputs``.

    The value is coerced to ``str`` before being handed to the lru_cache-wrapped
    ``_predict``. Presumably this guarantees a hashable, normalized cache key.
    """
    manifest_url = str(inputs)
    return _predict(manifest_url)
@lru_cache(maxsize=100)
def _predict(inputs):
    """Classify every page of a IIIF manifest and keep the illustrated ones.

    Results are memoized per manifest URL (up to 100 manifests) via lru_cache.

    Args:
        inputs: the manifest URL.

    Returns:
        A list of ``(image_url, caption)`` tuples for the pages whose top label
        is ``'illustrated'``. ``image_url`` requests a 500px-wide rendition of
        the page; ``caption`` shows the classifier's score and that URL.
    """
    data = load_manifest(inputs)
    urls = get_image_urls_from_manifest(data)
    # Download small renditions for classification rather than full pages.
    resized_urls = [resize_iiif_urls(url) for url in urls]
    images_urls = asyncio.run(get_images(resized_urls))
    if not images_urls:
        # Nothing could be downloaded; skip the classifier entirely.
        return []
    urls = list(pluck(0, images_urls))
    images = list(pluck(1, images_urls))
    predictions = classif_pipeline(images, top_k=1, num_workers=2)

    predicted_images = []
    for url, pred in zip(urls, predictions):
        top_pred = pred[0]
        if top_pred['label'] != 'illustrated':
            continue
        # Width 500 with an empty height. Presumably this is the IIIF "500," size,
        # which lets the server keep the aspect ratio.
        image_url = str(IIIFImageClient.init_from_url(url).size(width=500, height=''))
        predicted_images.append(
            (image_url, f"Confidence: {top_pred['score']}, \n image url: {image_url}")
        )
    return predicted_images
# Output component: gallery of the pages classified as illustrated
# (grid=3 presumably lays them out three per row).
gallery = gr.Gallery()
gallery.style(grid=3)

demo = gr.Interface(
    fn=predict,
    inputs=gr.Text(label="IIIF manifest url"),
    outputs=gallery,
    title="ImageIN",
    description="Identify illustrations in pages of historical books!",
    examples=[
        'https://iiif.lib.harvard.edu/manifests/drs:427603172',
        "https://iiif.lib.harvard.edu/manifests/drs:427342693",
        "https://iiif.archivelab.org/iiif/holylandbiblebok0000cunn/manifest.json",
    ],
).queue()  # enable Gradio's request queue before launching

demo.launch()