import gradio as gr
from transformers import pipeline, AutoModelForImageClassification, AutoFeatureExtractor
import requests
import asyncio
import httpx
import io
from PIL import Image
import PIL
from functools import lru_cache
from toolz import pluck
from piffle.image import IIIFImageClient

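# Page-classification model: a LeViT checkpoint from the ImageIN project that
# labels book pages as illustrated or not.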
HF_MODEL_PATH = (
    "ImageIN/levit-192_finetuned_on_unlabelled_IA_with_snorkel_labels"
)

classif_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_PATH)
feature_extractor = AutoFeatureExtractor.from_pretrained(HF_MODEL_PATH)

classif_pipeline = pipeline(
    "image-classification", model=classif_model, feature_extractor=feature_extractor
)

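# Download a IIIF manifest and parse it as JSON.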
def load_manifest(inputs):
    # 30s timeout mirrors the one used for image downloads below
    with requests.get(inputs, timeout=30) as r:
        return r.json()

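# Collect every page image URL from the manifest's sequences -> canvases -> images structure.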
def get_image_urls_from_manifest(data):
    image_urls = []
    for sequence in data['sequences']:
        for canvas in sequence['canvases']:
            image_urls.extend(image['resource']['@id'] for image in canvas['images'])
    return image_urls

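# Rewrite a IIIF image URL so the server returns a small (default 224x224) rendering for classification.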
def resize_iiif_urls(image_url, size='224'):
    image_url = IIIFImageClient.init_from_url(image_url)
    image_url = image_url.size(width=size, height=size)
    return str(image_url)

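# Fetch one image, returning None if the request times out, fails to connect,
# or the response is not a valid image.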
async def get_image(client, url):
    try:
        resp = await client.get(url, timeout=30)
        return Image.open(io.BytesIO(resp.content))
    except (PIL.UnidentifiedImageError, httpx.ReadTimeout, httpx.ConnectError):
        return None

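# Download all page images concurrently and drop any that could not be loaded.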
async def get_images(urls):
    async with httpx.AsyncClient() as client:
        tasks = [asyncio.ensure_future(get_image(client, url)) for url in urls]
        images = await asyncio.gather(*tasks)
        assert len(images) == len(urls)
        return [(url, image) for url, image in zip(urls, images) if image is not None]

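# Gradio entry point: normalise the input to a plain string so results can be
# memoised per manifest URL by the cached helper below.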
def predict(inputs):
    return _predict(str(inputs))

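# Cache predictions for the 100 most recently requested manifests.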
@lru_cache(maxsize=100)
def _predict(inputs):
    data = load_manifest(inputs)
    urls = get_image_urls_from_manifest(data)
    resized_urls = [resize_iiif_urls(url) for url in urls]
    images_urls = asyncio.run(get_images(resized_urls))
    predicted_images = []
    # split the (url, image) pairs back into parallel lists
    images = list(pluck(1, images_urls))
    urls = list(pluck(0, images_urls))
    # classify all pages in one batched pipeline call
    predictions = classif_pipeline(images, top_k=1, num_workers=2)
    for url, pred in zip(urls, predictions):
        top_pred = pred[0]
        if top_pred['label'] == 'illustrated':
            # link back to a larger (500px-wide) rendering for display in the gallery
            image_url = IIIFImageClient.init_from_url(url)
            image_url = image_url.size(width=500, height='')
            predicted_images.append(
                (str(image_url), f"Confidence: {top_pred['score']}, \n image url: {image_url}")
            )
    return predicted_images

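# Gradio UI: a text box for the manifest URL and a gallery of the pages
# predicted to contain illustrations.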
gallery = gr.Gallery()
gallery.style(grid=3)

demo = gr.Interface(
    fn=predict,
    inputs=gr.Text(label="IIIF manifest url"),
    outputs=gallery,
    title="ImageIN",
    description="Identify illustrations in pages of historical books!",
    examples=[
        "https://iiif.lib.harvard.edu/manifests/drs:427603172",
        "https://iiif.lib.harvard.edu/manifests/drs:427342693",
        "https://iiif.archivelab.org/iiif/holylandbiblebok0000cunn/manifest.json",
    ],
).queue()
demo.launch()