|
import gradio as gr |
|
from transformers import pipeline, AutoModelForImageClassification, AutoFeatureExtractor |
|
import requests |
|
import asyncio |
|
import httpx |
|
import io |
|
from PIL import Image |
|
import PIL |
|
|
|
|
|
# Hugging Face Hub repo id: a LeViT-192 image classifier fine-tuned (using
# Snorkel weak-supervision labels) to detect illustrations on book pages.
HF_MODEL_PATH = (

    "ImageIN/levit-192_finetuned_on_unlabelled_IA_with_snorkel_labels"

)


# Download (or load from local cache) the model weights and the matching
# feature extractor at import time, then wrap them in a transformers pipeline
# that is reused for every request.
classif_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_PATH)

feature_extractor = AutoFeatureExtractor.from_pretrained(HF_MODEL_PATH)



classif_pipeline = pipeline(

    "image-classification", model=classif_model, feature_extractor=feature_extractor

)


# Template for a human-readable result sentence.
# NOTE(review): not referenced anywhere in this file — possibly used elsewhere
# or leftover; confirm before removing.
OUTPUT_SENTENCE = "This image is {result}."
|
|
|
|
|
def load_manifest(inputs):
    """Fetch a IIIF manifest from the URL *inputs* and return the parsed JSON.

    Raises:
        requests.HTTPError: on a non-2xx response.
        ValueError: if the response body is not valid JSON.
    """
    # Timeout guards against hanging indefinitely on an unresponsive host
    # (the original call had no timeout at all).
    with requests.get(inputs, timeout=30) as r:
        # Fail fast on error responses instead of handing an HTML error page
        # to the JSON parser.
        r.raise_for_status()
        return r.json()
|
|
|
|
|
def get_image_urls_from_manifest(data):
    """Collect every image resource URL ('@id') from a IIIF v2 manifest dict.

    Walks sequences -> canvases -> images and returns the URLs in document
    order as a flat list.
    """
    return [
        image["resource"]["@id"]
        for sequence in data["sequences"]
        for canvas in sequence["canvases"]
        for image in canvas["images"]
    ]
|
|
|
|
|
def resize_iiif_urls(im_url, size="224"):
    """Rewrite a IIIF Image API URL so the server returns a *size* x *size* image.

    IIIF image URLs have the shape
    ``{scheme}://{server}/{prefix}/{identifier}/{region}/{size}/{rotation}/{quality}.{format}``;
    after splitting on "/" the size segment sits at index 6 (assumes a
    single-segment prefix — TODO confirm against the manifests being processed).
    """
    parts = im_url.split("/")
    # IIIF size syntax is "w,h" with NO whitespace; the original
    # f"{size}, {size}" inserted a space, yielding an invalid size parameter
    # (and a space character in the request URL).
    parts[6] = f"{size},{size}"
    return "/".join(parts)
|
|
|
|
|
async def get_image(client, url):
    """Download *url* via the shared httpx *client* and decode it as a PIL Image.

    Best-effort: returns None when the request times out, the connection
    fails, or the payload cannot be decoded as an image.
    """
    try:
        response = await client.get(url, timeout=30)
    except (httpx.ReadTimeout, httpx.ConnectError):
        return None
    try:
        return Image.open(io.BytesIO(response.content))
    except PIL.UnidentifiedImageError:
        return None
|
|
|
|
|
async def get_images(urls):
    """Concurrently fetch all *urls*; return only the successfully decoded images.

    Failures (reported by get_image as None) are silently dropped.
    """
    async with httpx.AsyncClient() as client:
        downloads = (get_image(client, url) for url in urls)
        results = await asyncio.gather(*downloads)
    return [img for img in results if img is not None]
|
|
|
|
|
def predict(inputs):
    """Gradio handler: given a IIIF manifest URL, return the illustrated pages.

    Fetches the manifest, downloads every page image (rewritten to 224x224
    via resize_iiif_urls), classifies each one, and returns a list of
    (image, confidence) pairs — the format gr.Gallery expects — for pages
    whose top predicted label is 'illustrated'.
    """
    manifest = load_manifest(inputs)
    page_urls = get_image_urls_from_manifest(manifest)
    thumbnails = asyncio.run(get_images([resize_iiif_urls(u) for u in page_urls]))

    illustrated = []
    for img in thumbnails:
        prediction = classif_pipeline(img, top_k=1)[0]
        if prediction['label'] == 'illustrated':
            illustrated.append((img, prediction['score']))
    return illustrated
|
|
|
|
|
# Wire the predict handler into a minimal Gradio UI: a single text box for
# the IIIF manifest URL in, a gallery of (image, score) pairs out.
demo = gr.Interface(

    fn=predict,

    inputs=gr.Text(),

    outputs=gr.Gallery(),

    title="ImageIN",

    description="Identify illustrations in pages of historical books!",

)

# debug=True keeps the process attached and prints errors to the console.
demo.launch(debug=True)
|
|