import gradio as gr
from transformers import pipeline, AutoModelForImageClassification, AutoFeatureExtractor
import requests
import asyncio
import httpx
import io
from PIL import Image
import PIL
from functools import lru_cache
from toolz import pluck
from piffle.image import IIIFImageClient

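# Page-level illustration classifier published by the ImageIN project
# (LeViT-192 fine-tuned with Snorkel-generated labels)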
HF_MODEL_PATH = (
    "ImageIN/levit-192_finetuned_on_unlabelled_IA_with_snorkel_labels"
)

classif_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_PATH)
feature_extractor = AutoFeatureExtractor.from_pretrained(HF_MODEL_PATH)

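# Wrap the model and feature extractor in a Transformers image-classification pipeline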
classif_pipeline = pipeline(
    "image-classification", model=classif_model, feature_extractor=feature_extractor
)


def load_manifest(inputs):
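    """Download a IIIF manifest from `inputs` (a URL) and return it as parsed JSON."""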
    with requests.get(inputs, timeout=30) as r:
        r.raise_for_status()
        return r.json()


def get_image_urls_from_manifest(data):
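    """Collect the image URL of every canvas listed in the manifest's sequences."""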
    image_urls = []
    for sequence in data['sequences']:
        for canvas in sequence['canvases']:
            image_urls.extend(image['resource']['@id'] for image in canvas['images'])
    return image_urls


def resize_iiif_urls(image_url, size='224'):
    """Rewrite a IIIF image URL so the server returns a `size` x `size` rendition for the classifier."""
    image_url = IIIFImageClient.init_from_url(image_url)
    image_url = image_url.size(width=size, height=size)
    return str(image_url)


async def get_image(client, url):
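    """Fetch a single image; return a PIL.Image, or None if the download or decoding fails."""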
    try:
        resp = await client.get(url, timeout=30)
        return Image.open(io.BytesIO(resp.content))
    except (PIL.UnidentifiedImageError, httpx.ReadTimeout, httpx.ConnectError):
        return None


async def get_images(urls):
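    """Download all images concurrently and return (url, PIL.Image) pairs for the successful fetches."""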
    async with httpx.AsyncClient() as client:
        tasks = [asyncio.ensure_future(get_image(client, url)) for url in urls]
        images = await asyncio.gather(*tasks)
        # gather preserves input order, so results line up one-to-one with the URLs
        assert len(images) == len(urls)
        return [(url, image) for url, image in zip(urls, images) if image is not None]


def predict(inputs):
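    # Cast to str so the lru_cache key is hashable and stable across Gradio input types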
    return _predict(str(inputs))


@lru_cache(maxsize=100)
def _predict(inputs):
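    """Classify every page image in the manifest and return (image_url, caption) pairs for pages predicted as illustrated."""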
    data = load_manifest(inputs)
    urls = get_image_urls_from_manifest(data)
    resized_urls = [resize_iiif_urls(url) for url in urls]
    images_urls = asyncio.run(get_images(resized_urls))
    predicted_images = []
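    # images_urls is a list of (url, PIL.Image) pairs; split it into parallel lists for batch classification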
    images = list(pluck(1, images_urls))
    urls = list(pluck(0, images_urls))
    predictions = classif_pipeline(images, top_k=1, num_workers=2)
    for url, pred in zip(urls, predictions):
        top_pred = pred[0]
        if top_pred['label'] == 'illustrated':
            # Request a larger (500px wide) rendition of the page for display in the gallery
            image_url = IIIFImageClient.init_from_url(url)
            image_url = image_url.size(width=500)
            predicted_images.append(
                (str(image_url), f"Confidence: {top_pred['score']:.3f}\nImage URL: {image_url}")
            )

    return predicted_images


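# Show the matched pages in a three-column image gallery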
gallery = gr.Gallery()
gallery.style(grid=3)

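# A manifest URL goes in; a gallery of pages predicted as illustrated comes out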
demo = gr.Interface(
    fn=predict,
    inputs=gr.Text(label="IIIF manifest url"),
    outputs=gallery,
    title="ImageIN",
    description="Identify illustrations in pages of historical books!",
    examples=[
        "https://iiif.lib.harvard.edu/manifests/drs:427603172",
        "https://iiif.lib.harvard.edu/manifests/drs:427342693",
        "https://iiif.archivelab.org/iiif/holylandbiblebok0000cunn/manifest.json",
    ],
).queue()
demo.launch()