import gradio as gr
from transformers import pipeline, AutoModelForImageClassification, AutoFeatureExtractor
import requests
import asyncio
import httpx
import io
from PIL import Image
import PIL

# Image-classification checkpoint fine-tuned (with Snorkel-generated labels) to
# spot illustrated pages in Internet Archive book scans.
HF_MODEL_PATH = (
    "ImageIN/levit-192_finetuned_on_unlabelled_IA_with_snorkel_labels"
)

classif_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_PATH)
feature_extractor = AutoFeatureExtractor.from_pretrained(HF_MODEL_PATH)

classif_pipeline = pipeline(
    "image-classification", model=classif_model, feature_extractor=feature_extractor
)
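
# Note: with top_k=1 the pipeline returns a single-element list of dicts such as
# [{'label': 'illustrated', 'score': 0.97}] (the score here is illustrative).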
OUTPUT_SENTENCE = "This image is {result}."


def load_manifest(inputs):
    # Fetch the IIIF Presentation manifest (JSON) describing the book.
    with requests.get(inputs) as r:
        return r.json()


def get_image_urls_from_manifest(data):
    # Walk the manifest's sequences -> canvases -> images and collect the IIIF
    # image URL (the 'resource' '@id') for every page.
    image_urls = []
    for sequence in data['sequences']:
        for canvas in sequence['canvases']:
            image_urls.extend(image['resource']['@id'] for image in canvas['images'])
    return image_urls


def resize_iiif_urls(im_url, size='224'):
    # Rewrite the size segment of a IIIF Image API URL to request a small
    # thumbnail; IIIF sizes are written "w,h" with no space.
    parts = im_url.split("/")
    parts[6] = f"{size},{size}"
    return "/".join(parts)


async def get_image(client, url):
    # Download a single image; return None if it cannot be fetched or decoded.
    try:
        resp = await client.get(url, timeout=30)
        return Image.open(io.BytesIO(resp.content))
    except (PIL.UnidentifiedImageError, httpx.ReadTimeout, httpx.ConnectError):
        return None


async def get_images(urls):
    # Fetch all page images concurrently and drop any that failed to download.
    async with httpx.AsyncClient() as client:
        tasks = [asyncio.ensure_future(get_image(client, url)) for url in urls]
        images = await asyncio.gather(*tasks)
    return [image for image in images if image is not None]


def predict(inputs):
    # Given a IIIF manifest URL: download every page image and return the ones
    # the model labels 'illustrated', paired with the prediction score.
    data = load_manifest(inputs)
    urls = get_image_urls_from_manifest(data)
    resized_urls = [resize_iiif_urls(url) for url in urls]
    images = asyncio.run(get_images(resized_urls))
    predicted_images = []
    for image in images:
        top_pred = classif_pipeline(image, top_k=1)[0]
        if top_pred['label'] == 'illustrated':
            predicted_images.append((image, top_pred['score']))
    return predicted_images


demo = gr.Interface(
    fn=predict,
    inputs=gr.Text(),
    outputs=gr.Gallery(),
    title="ImageIN",
    description="Identify illustrations in pages of historical books!",
)

demo.launch(debug=True)
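
# Usage sketch (hypothetical manifest URL, shown for illustration only): predict()
# can also be called directly, without the Gradio UI, on any manifest that follows
# the IIIF Presentation layout assumed above:
#
#   results = predict("https://example.org/iiif/some-book/manifest.json")
#   for image, score in results:
#       print(score, image.size)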