File size: 2,264 Bytes
a8221b0
 
 
 
 
 
 
 
 
255ecb4
a8221b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96baeb1
a8221b0
 
 
 
 
 
 
afe6a3f
a8221b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afe6a3f
a8221b0
 
 
 
 
 
 
 
 
96baeb1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gradio as gr
from transformers import pipeline, AutoModelForImageClassification, AutoFeatureExtractor
import requests
import asyncio
import httpx
import io
from PIL import Image
import PIL


# Hugging Face Hub id of the image classifier used by this app (per its name,
# a LeViT-192 fine-tuned on Internet Archive pages with Snorkel labels).
HF_MODEL_PATH = (
    "ImageIN/levit-192_finetuned_on_unlabelled_IA_with_snorkel_labels"
)

# Load the classifier weights and their preprocessor once at import time so
# every request reuses the same objects.
classif_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_PATH)
# NOTE(review): recent transformers releases deprecate AutoFeatureExtractor for
# vision models in favour of AutoImageProcessor -- confirm against the pinned
# transformers version before upgrading.
feature_extractor = AutoFeatureExtractor.from_pretrained(HF_MODEL_PATH)

# Image-classification pipeline combining the model and preprocessor above;
# predict() calls it once per downloaded page image.
classif_pipeline = pipeline(
    "image-classification", model=classif_model, feature_extractor=feature_extractor
)

# NOTE(review): not referenced anywhere in the visible code -- possibly a
# leftover; confirm before removing.
OUTPUT_SENTENCE = "This image is {result}."


def load_manifest(inputs, timeout=30):
    """Download a manifest and return its decoded JSON body.

    Args:
        inputs: URL of the manifest (the text entered in the app's textbox).
        timeout: Seconds to wait on the server before giving up. Without a
            timeout, ``requests`` can block indefinitely on an unresponsive host.

    Returns:
        The parsed JSON document.

    Raises:
        requests.HTTPError: If the server responds with a 4xx/5xx status.
        requests.RequestException: On connection failures or timeouts.
        ValueError: If the body is not valid JSON.
    """
    with requests.get(inputs, timeout=timeout) as r:
        # Fail with a clear HTTP error instead of trying to parse an error page
        # as a manifest (which would surface later as a confusing KeyError).
        r.raise_for_status()
        return r.json()


def get_image_urls_from_manifest(data):
    """Collect every image resource URL from a parsed manifest.

    *data* is expected to follow the IIIF Presentation 2.x layout
    (sequences -> canvases -> images -> resource -> '@id'). URLs are returned
    in manifest order. A missing key raises ``KeyError``.
    """
    return [
        annotation['resource']['@id']
        for sequence in data['sequences']
        for canvas in sequence['canvases']
        for annotation in canvas['images']
    ]


def resize_iiif_urls(im_url, size='224'):
    """Rewrite the size segment of a IIIF Image API request URL.

    The IIIF Image API URL ends with
    ``/{region}/{size}/{rotation}/{quality}.{format}``, so the size is always
    the third path segment from the end, whatever the server prefix depth.

    Args:
        im_url: Full IIIF image request URL.
        size: Target width and height (str or int). The ``w,h`` form asks the
            server for exactly ``size x size`` pixels.

    Returns:
        *im_url* with its size segment replaced by ``"{size},{size}"``.
    """
    parts = im_url.split("/")
    # IIIF size syntax is "w,h" with no whitespace; a space in this segment is
    # not a valid size parameter.
    parts[-3] = f"{size},{size}"
    return "/".join(parts)


async def get_image(client, url):
    """Fetch a single image and decode it, returning ``None`` on any failure.

    Args:
        client: An open ``httpx.AsyncClient`` shared across requests.
        url: Image URL to fetch.

    Returns:
        A fully decoded ``PIL.Image.Image``, or ``None`` if the request failed,
        the server returned an error status, or the payload is not a
        decodable image.
    """
    try:
        resp = await client.get(url, timeout=30)
        # An error page (404, 503, ...) is not an image; report it as an httpx
        # error rather than relying on PIL failing to parse the HTML body.
        resp.raise_for_status()
    except httpx.HTTPError:
        # Base class of all httpx transport, timeout and status errors, so
        # connect/pool/write timeouts are skipped instead of escaping to the
        # caller.
        return None
    try:
        image = Image.open(io.BytesIO(resp.content))
        # Image.open only parses the header; decode now so truncated payloads
        # are skipped here rather than failing later during classification.
        image.load()
    except OSError:
        # PIL.UnidentifiedImageError is a subclass of OSError.
        return None
    return image


async def get_images(urls, max_concurrency=16):
    """Download many images concurrently, dropping any that fail.

    Args:
        urls: Image URLs to fetch.
        max_concurrency: Maximum number of requests in flight at once. httpx
            also applies the request timeout to waiting for a pooled
            connection. If every request of a long book were fired at once,
            queued requests could time out before they even started.

    Returns:
        The successfully decoded images, in the same order as *urls*. Failed
        downloads are omitted.

    Raises:
        ValueError: If *max_concurrency* is less than 1.
    """
    if max_concurrency < 1:
        raise ValueError("max_concurrency must be at least 1")
    semaphore = asyncio.Semaphore(max_concurrency)

    async with httpx.AsyncClient() as client:

        async def fetch(url):
            # Bound the number of simultaneous requests to the image server.
            async with semaphore:
                return await get_image(client, url)

        # gather() accepts coroutines directly and preserves input order.
        images = await asyncio.gather(*(fetch(url) for url in urls))
    return [image for image in images if image is not None]


def predict(inputs):
    """Return the pages of a manifest that the classifier labels 'illustrated'.

    Args:
        inputs: URL of a IIIF manifest.

    Returns:
        A list of ``(image, caption)`` pairs, one per page whose top prediction
        is ``'illustrated'``. *caption* is the model's score formatted as text.
        The list is empty if no page qualifies or no image could be downloaded.
    """
    data = load_manifest(inputs)
    urls = get_image_urls_from_manifest(data)
    resized_urls = [resize_iiif_urls(url) for url in urls]
    images = asyncio.run(get_images(resized_urls))
    predicted_images = []
    for image in images:
        top_pred = classif_pipeline(image, top_k=1)[0]
        if top_pred['label'] == 'illustrated':
            # Gallery captions are documented as strings, so format the score
            # as text rather than passing a raw float.
            predicted_images.append((image, f"{top_pred['score']:.3f}"))
    # Must stay outside the loop so every page is classified, not just the first.
    return predicted_images


# Web UI: the user pastes a manifest URL; predict() returns a gallery of the
# pages classified as illustrated.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Text(),
    outputs=gr.Gallery(),
    title="ImageIN",
    description="Identify illustrations in pages of historical books!",
)

if __name__ == "__main__":
    # launch(debug=True) blocks the main thread; only start the server when
    # this file is run as a script, not when it is imported.
    demo.launch(debug=True)