import gradio as gr
from transformers import pipeline, AutoModelForImageClassification, AutoFeatureExtractor
import requests
import asyncio
import httpx
import io
from PIL import Image
import PIL
from functools import lru_cache
from toolz import pluck
from piffle.image import IIIFImageClient

# LeViT-192 fine-tuned (with Snorkel labels) to flag illustrated book pages.
HF_MODEL_PATH = (
    "ImageIN/levit-192_finetuned_on_unlabelled_IA_with_snorkel_labels"
)

classif_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_PATH)
feature_extractor = AutoFeatureExtractor.from_pretrained(HF_MODEL_PATH)
classif_pipeline = pipeline(
    "image-classification", model=classif_model, feature_extractor=feature_extractor
)


def load_manifest(inputs):
    """Fetch a IIIF manifest URL and return its decoded JSON payload."""
    with requests.get(inputs) as r:
        # Fail loudly on HTTP errors instead of failing later on JSON decode.
        r.raise_for_status()
        return r.json()


def get_image_urls_from_manifest(data):
    """Collect every image resource ``@id`` from a IIIF Presentation v2 manifest."""
    image_urls = []
    for sequence in data['sequences']:
        for canvas in sequence['canvases']:
            image_urls.extend(image['resource']['@id'] for image in canvas['images'])
    return image_urls


def resize_iiif_urls(image_url, size='224'):
    """Rewrite a IIIF image URL to request a ``size`` x ``size`` rendition.

    The 224x224 default matches the input resolution the LeViT classifier
    was trained on, so the IIIF server does the downscaling for us.
    """
    iiif_image = IIIFImageClient.init_from_url(image_url)
    return str(iiif_image.size(width=size, height=size))


async def get_image(client, url):
    """Download one image; return a PIL image, or None when it can't be fetched/decoded."""
    try:
        resp = await client.get(url, timeout=30)
        return Image.open(io.BytesIO(resp.content))
    except (PIL.UnidentifiedImageError, httpx.ReadTimeout, httpx.ConnectError):
        # Best effort: skip unreadable or unreachable images rather than abort the run.
        return None


async def get_images(urls):
    """Fetch all ``urls`` concurrently; return (url, image) pairs for successful downloads."""
    async with httpx.AsyncClient() as client:
        tasks = [asyncio.ensure_future(get_image(client, url)) for url in urls]
        images = await asyncio.gather(*tasks)
    # gather preserves order, so urls and images stay aligned; drop failed fetches.
    return [(url, image) for url, image in zip(urls, images) if image is not None]


def predict(inputs):
    """Gradio entry point: normalise the input to a hashable str for the lru_cache."""
    return _predict(str(inputs))


@lru_cache(maxsize=100)
def _predict(inputs):
    """Classify every page image of the manifest at ``inputs``.

    Returns a list of (image_url, caption) tuples — the format gr.Gallery
    expects — for pages the model predicts as 'illustrated'.
    """
    data = load_manifest(inputs)
    urls = get_image_urls_from_manifest(data)
    resized_urls = [resize_iiif_urls(url) for url in urls]
    images_urls = asyncio.run(get_images(resized_urls))
    if not images_urls:
        # Nothing downloadable: show an empty gallery instead of a pipeline error.
        return []
    images = list(pluck(1, images_urls))
    urls = list(pluck(0, images_urls))
    predictions = classif_pipeline(images, top_k=1, num_workers=2)
    predicted_images = []
    for url, pred in zip(urls, predictions):
        top_pred = pred[0]
        if top_pred['label'] == 'illustrated':
            # Request a 500px-wide rendition; empty height keeps the aspect ratio.
            # (The original called .size(width=500) first, whose result was
            # immediately overwritten — a single call is equivalent.)
            image_url = IIIFImageClient.init_from_url(url).size(width=500, height='')
            predicted_images.append(
                (str(image_url), f"Confidence: {top_pred['score']}, \n image url: {image_url}")
            )
    return predicted_images


gallery = gr.Gallery()
gallery.style(grid=3)
demo = gr.Interface(
    fn=predict,
    inputs=gr.Text(label="IIIF manifest url"),
    outputs=gallery,
    title="ImageIN",
    description="Identify illustrations in pages of historical books!",
    examples=[
        'https://iiif.lib.harvard.edu/manifests/drs:427603172',
        "https://iiif.lib.harvard.edu/manifests/drs:427342693",
        "https://iiif.archivelab.org/iiif/holylandbiblebok0000cunn/manifest.json",
    ],
).queue()
demo.launch()