# Author: davanstrien (HF Staff)
# Last commit: afe6a3f - "catch httpx.ConnectError"
# File size: 2.26 kB
import gradio as gr
from transformers import pipeline, AutoModelForImageClassification, AutoFeatureExtractor
import requests
import asyncio
import httpx
import io
from PIL import Image
import PIL
# Hugging Face Hub checkpoint used to classify book pages. Judging by the repo
# name it is a LeViT-192 fine-tuned with Snorkel-generated labels -- confirm on
# the model card.
HF_MODEL_PATH = "ImageIN/levit-192_finetuned_on_unlabelled_IA_with_snorkel_labels"

# Weights and preprocessing config are downloaded once, at import time; the
# resulting pipeline is a module-level global reused by predict().
classif_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_PATH)
feature_extractor = AutoFeatureExtractor.from_pretrained(HF_MODEL_PATH)
classif_pipeline = pipeline(
    task="image-classification",
    model=classif_model,
    feature_extractor=feature_extractor,
)

# Message template with a ``{result}`` placeholder.
OUTPUT_SENTENCE = "This image is {result}."
def load_manifest(inputs, timeout=30):
    """Download a IIIF manifest and return its decoded JSON.

    Args:
        inputs: URL of the manifest. Surrounding whitespace (e.g. a trailing
            newline pasted into the textbox) is stripped before the request.
        timeout: Seconds to wait on the server before giving up.

    Returns:
        The parsed JSON document.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection failure or timeout.
        ValueError: (or a subclass) if the body is not valid JSON.
    """
    # requests applies no timeout by default, so an unresponsive server would
    # otherwise block the worker indefinitely.
    with requests.get(inputs.strip(), timeout=timeout) as r:
        # Fail with a clear HTTP error instead of trying to JSON-decode an
        # error page (or silently returning a JSON error body).
        r.raise_for_status()
        return r.json()
def get_image_urls_from_manifest(data):
    """Return the page image URLs listed in a IIIF manifest, in manifest order.

    IIIF Presentation 2.x manifests (``sequences`` -> ``canvases`` ->
    ``images`` -> ``resource['@id']``) and 3.0 manifests (``items`` ->
    canvas ``items`` -> annotation ``body['id']``) are both supported.

    Args:
        data: The decoded manifest JSON.

    Returns:
        A list of image URL strings.

    Raises:
        KeyError: if the manifest has neither a 2.x nor a 3.0 structure, or a
            2.x manifest is missing a required key.
    """
    # Presentation 3.0 dropped ``sequences`` in favour of top-level ``items``.
    if "sequences" not in data and "items" in data:
        return _get_image_urls_from_v3_manifest(data)
    image_urls = []
    for sequence in data["sequences"]:
        for canvas in sequence["canvases"]:
            image_urls.extend(image["resource"]["@id"] for image in canvas["images"])
    return image_urls


def _get_image_urls_from_v3_manifest(data):
    """Collect the ``Image`` body ids painted onto each canvas of a 3.0 manifest."""
    image_urls = []
    for canvas in data["items"]:
        for page in canvas.get("items", []):
            for annotation in page.get("items", []):
                body = annotation.get("body")
                # ``body`` may be a single resource or a list of resources;
                # non-image bodies (and ``Choice`` wrappers) are skipped.
                bodies = body if isinstance(body, list) else [body]
                image_urls.extend(
                    resource["id"]
                    for resource in bodies
                    if isinstance(resource, dict)
                    and resource.get("type") == "Image"
                    and "id" in resource
                )
    return image_urls
def resize_iiif_urls(im_url, size='224'):
    """Rewrite a IIIF Image API URL so the server returns a ``size`` x ``size`` image.

    A IIIF image request URL ends with
    ``{region}/{size}/{rotation}/{quality}.{format}``, so the size segment is
    the third from the end regardless of how deep the server prefix and
    identifier are.

    Args:
        im_url: Full IIIF image request URL, e.g.
            ``https://host/iiif/ident/full/full/0/default.jpg``.
        size: Target width and height in pixels (str or int).

    Returns:
        The URL with its size segment replaced by ``"{size},{size}"``.

    Raises:
        IndexError: if the URL has fewer than three ``/``-separated parts.
    """
    parts = im_url.split("/")
    # IIIF size syntax is "w,h" with no whitespace.
    parts[-3] = f"{size},{size}"
    return "/".join(parts)
async def get_image(client, url, timeout=30):
    """Fetch ``url`` with ``client`` and decode the response as a PIL image.

    This is best-effort: any network, HTTP-status or decode failure yields
    ``None``, so one bad page does not abort the whole manifest.

    Args:
        client: An open ``httpx.AsyncClient``.
        url: Image URL to fetch.
        timeout: httpx timeout in seconds for the request.

    Returns:
        A fully decoded ``PIL.Image.Image``, or ``None`` on failure.
    """
    try:
        resp = await client.get(url, timeout=timeout)
        # Error pages and unfollowed redirects are not images.
        resp.raise_for_status()
        image = Image.open(io.BytesIO(resp.content))
        # Image.open only parses the header; force the full decode now so a
        # truncated or corrupt file fails here rather than in the classifier.
        image.load()
    except (httpx.HTTPError, OSError):
        # httpx.HTTPError is the base of every httpx transport error
        # (connect/read/write/pool timeouts, connection resets, ...) and of
        # HTTPStatusError; OSError covers PIL.UnidentifiedImageError and
        # truncated-image errors raised by load().
        return None
    return image
async def get_images(urls, max_concurrency=20):
    """Download ``urls`` concurrently and return the images that decoded.

    Args:
        urls: Iterable of image URLs.
        max_concurrency: Maximum number of downloads in flight at once.

    Returns:
        PIL images in the same order as ``urls``; URLs for which
        ``get_image`` returned ``None`` are dropped.

    Raises:
        ValueError: if ``max_concurrency`` is less than 1.
    """
    if max_concurrency < 1:
        raise ValueError("max_concurrency must be at least 1")
    # Without a cap every URL starts at once; requests beyond httpx's
    # connection-pool limit then wait for a slot inside get_image's timeout
    # window and can fail with httpx.PoolTimeout on large manifests. Waiting
    # on the semaphore is not subject to that timeout.
    semaphore = asyncio.Semaphore(max_concurrency)

    async def _bounded_get(client, url):
        # Fetch one URL while holding a concurrency slot.
        async with semaphore:
            return await get_image(client, url)

    async with httpx.AsyncClient() as client:
        # gather accepts coroutines directly and preserves input order.
        images = await asyncio.gather(*(_bounded_get(client, url) for url in urls))
    return [image for image in images if image is not None]
def predict(inputs):
    """Return the manifest's pages that the classifier labels 'illustrated'.

    Args:
        inputs: IIIF manifest URL entered in the Gradio textbox.

    Returns:
        A list of ``(image, score)`` tuples, in manifest order, for every page
        whose top-1 label is ``'illustrated'``; shown by the Gradio gallery.
    """
    manifest = load_manifest(inputs)
    thumbnail_urls = list(map(resize_iiif_urls, get_image_urls_from_manifest(manifest)))
    downloaded = asyncio.run(get_images(thumbnail_urls))

    illustrated = []
    for img in downloaded:
        best = classif_pipeline(img, top_k=1)[0]
        if best['label'] != 'illustrated':
            continue
        illustrated.append((img, best['score']))
    return illustrated
# Gradio UI: a textbox for the IIIF manifest URL feeding predict(), whose
# (image, score) results are shown in a gallery.
demo = gr.Interface(
    predict,
    gr.Text(),
    gr.Gallery(),
    title="ImageIN",
    description="Identify illustrations in pages of historical books!",
)
# Launched at import time; per Gradio's docs, debug=True keeps launch()
# blocking the main thread and surfaces errors in the console.
demo.launch(debug=True)