import gradio as gr
import gradio.components as grc
import onnxruntime
import numpy as np
from torchvision.transforms import Normalize, Compose, Resize, ToTensor

batch_size = 1

def convert_to_rgb(image):
    return image.convert("RGB")


def get_transform(image_size=384):
    # Resize to the model's input size and normalize with ImageNet statistics.
    return Compose([
        convert_to_rgb,
        Resize((image_size, image_size)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])


def load_tag_list(tag_list_file):
    # One tag per line; return as an array so boolean indexing works later.
    with open(tag_list_file, 'r', encoding="utf-8") as f:
        tag_list = f.read().splitlines()
    return np.array(tag_list)

def load_word_vocabulary(word_vocabulary_file):
    # Each line holds comma-separated synonyms; map every synonym to its line index.
    with open(word_vocabulary_file, 'r', encoding="utf-8") as f:
        words = [line.split(',') for line in f.read().splitlines()]
    word2idx = {}
    for i, synonyms in enumerate(words):
        for word in synonyms:
            word2idx[word] = i
    return word2idx

from huggingface_hub import hf_hub_download

# Fetch the ONNX model from the Hub on first launch.
hf_hub_download(
    repo_id="Inf009/ram-tagger",
    repo_type="model",
    local_dir="resources",
    filename="ram_swin_large_14m_b1.onnx",
    local_dir_use_symlinks=False,
)
# Prefer CUDA, but fall back to CPU when no GPU is available.
ort_session = onnxruntime.InferenceSession(
    "resources/ram_swin_large_14m_b1.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
transform = get_transform()
tag_list = load_tag_list("resources/ram_tag_list.txt")
word_index = load_word_vocabulary("resources/word_vocabulary_english.txt")


def inference_by_image_pil(image):
    # Preprocess the PIL image into a (1, 3, 384, 384) float32 batch.
    image_arrays = transform(image).unsqueeze(0).numpy()
    # Compute the ONNX Runtime output prediction.
    ort_inputs = {ort_session.get_inputs()[0].name: image_arrays}
    ort_outs = ort_session.run(None, ort_inputs)
    # The model emits a binary indicator per tag; keep the active ones.
    index = np.argwhere(ort_outs[0][0] == 1)
    token = tag_list[index].squeeze(axis=1).tolist()
    token = rerank_tags(token)
    return ",".join(token)

def rerank_tags(tags):
    # Bucket tags by their vocabulary-group index, then flatten in group order
    # so tags from the same group appear next to each other.
    indexed_tags = [[] for _ in range(max(word_index.values()) + 1)]
    for tag in tags:
        indexed_tags[word_index[tag]].append(tag)
    return [tag for group in indexed_tags for tag in group]


app = gr.Interface(fn=inference_by_image_pil, inputs=grc.Image(type='pil'), 
                   outputs=grc.Text(), title="RAM Tagger", 
                   description="A tagger for images.")
app.launch()