import gradio as gr
from transformers import pipeline, AutoModelForImageClassification, AutoFeatureExtractor
import requests
import asyncio
import httpx
import io
from PIL import Image
import PIL
from toolz import pluck
from piffle.image import IIIFImageClient
# Hugging Face Hub id of the illustration classifier
# (LeViT-192 fine-tuned on Internet Archive pages with Snorkel weak labels).
HF_MODEL_PATH = (
    "ImageIN/levit-192_finetuned_on_unlabelled_IA_with_snorkel_labels"
)
# Download the weights and the matching preprocessing config once at startup,
# then wrap them in a ready-to-use image-classification pipeline.
classif_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_PATH)
feature_extractor = AutoFeatureExtractor.from_pretrained(HF_MODEL_PATH)
classif_pipeline = pipeline(
    "image-classification", model=classif_model, feature_extractor=feature_extractor
)
def load_manifest(inputs):
    """Fetch a IIIF manifest from the URL *inputs* and return the parsed JSON.

    Args:
        inputs: URL of a IIIF Presentation manifest.

    Returns:
        The manifest as a dict.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.Timeout: if the request takes longer than 30 seconds.
    """
    # A timeout prevents the Gradio worker from hanging forever on an
    # unresponsive server; raise_for_status surfaces HTTP errors instead of
    # attempting to JSON-decode an error page.
    with requests.get(inputs, timeout=30) as r:
        r.raise_for_status()
        return r.json()
def get_image_urls_from_manifest(data):
    """Collect every image resource ``@id`` from a IIIF Presentation manifest.

    Walks sequences -> canvases -> images and returns the image URLs as a
    list, preserving document order.
    """
    return [
        image['resource']['@id']
        for sequence in data['sequences']
        for canvas in sequence['canvases']
        for image in canvas['images']
    ]
def resize_iiif_urls(im_url, size='224'):
    """Rewrite a IIIF Image API URL so its size segment is ``{size},{size}``.

    IIIF image URLs have the shape
    ``{scheme}://{server}/{prefix}/{identifier}/{region}/{size}/{rotation}/{quality}.{format}``;
    with the single-component prefix used here, splitting on '/' puts the
    size segment at index 6.

    Args:
        im_url: full IIIF image URL.
        size: pixel value used for both width and height (default '224',
            matching the classifier's input resolution).

    Returns:
        The URL with its size segment replaced.
    """
    parts = im_url.split("/")
    # Fix: the previous f"{size}, {size}" put a space after the comma, which
    # is not a valid IIIF "w,h" size parameter and servers may reject it.
    parts[6] = f"{size},{size}"
    return "/".join(parts)
async def get_image(client, url):
    """Download *url* with the shared httpx *client* and decode it as a PIL image.

    Returns None when the request times out, fails to connect, or the body
    is not a decodable image, so callers can simply filter out failures.
    """
    try:
        response = await client.get(url, timeout=30)
    except (httpx.ReadTimeout, httpx.ConnectError):
        return None
    try:
        return Image.open(io.BytesIO(response.content))
    except PIL.UnidentifiedImageError:
        return None
async def get_images(urls):
    """Fetch all *urls* concurrently over a single HTTP client.

    Returns a list of ``(url, PIL.Image)`` pairs in input order, dropping
    every URL whose download or decode failed (get_image returned None).
    """
    async with httpx.AsyncClient() as client:
        tasks = [asyncio.ensure_future(get_image(client, url)) for url in urls]
        images = await asyncio.gather(*tasks)
    # gather preserves order and length, so urls and images stay aligned.
    assert len(images) == len(urls)
    return [(url, image) for url, image in zip(urls, images) if image is not None]
def predict(inputs):
    """Given a IIIF manifest URL, return gallery entries for illustrated pages.

    Downloads every page image referenced by the manifest (resized to the
    classifier's input resolution), classifies each one, and returns
    ``(image_url, caption)`` pairs for pages the model labels 'illustrated'.
    """
    manifest = load_manifest(inputs)
    page_urls = get_image_urls_from_manifest(manifest)
    resized = [resize_iiif_urls(url) for url in page_urls]
    # Pairs of (url, image) for every page that downloaded successfully.
    images_urls = asyncio.run(get_images(resized))
    urls = list(pluck(0, images_urls))
    images = list(pluck(1, images_urls))
    predictions = classif_pipeline(images, top_k=1)
    predicted_images = []
    for url, pred in zip(urls, predictions):
        top_pred = pred[0]
        if top_pred['label'] != 'illustrated':
            continue
        canonical = IIIFImageClient.init_from_url(url).canonicalize()
        # NOTE(review): the page number is hard-coded to 10 in this caption —
        # looks like a leftover placeholder; confirm the intended value.
        predicted_images.append(
            (str(canonical), f"Confidence: {top_pred['score']}, page: {10}")
        )
    return predicted_images
# Build the Gradio UI: a text box for the manifest URL as input, and a
# 3-column image gallery of the pages predicted as illustrated as output.
gallery = gr.Gallery()
gallery.style(grid=3)
demo = gr.Interface(
    fn=predict,
    inputs=gr.Text(),
    outputs=gallery,
    title="ImageIN",
    description="Identify illustrations in pages of historical books!",
)
# debug=True keeps the process attached and surfaces errors in the console.
demo.launch(debug=True)