"""Gradio demo serving the RAM image tagger through ONNX Runtime.

Importing/running this module downloads ``ram_swin_large_14m_b1.onnx`` from the
Hugging Face Hub into ``resources/``, creates an ONNX Runtime session, loads the
tag list and word vocabulary from ``resources/``, and launches a Gradio UI.
"""
import gradio as gr
import gradio.components as grc
import numpy as np
import onnxruntime
from huggingface_hub import hf_hub_download
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

# NOTE(review): not referenced anywhere in this file; kept in case other code imports it.
batch_size = 1


def convert_to_rgb(image):
    """Return *image* converted to RGB (assumes a PIL image, as Gradio supplies with type='pil')."""
    return image.convert("RGB")


def get_transform(image_size=384):
    """Build the preprocessing pipeline applied before inference.

    Args:
        image_size: Side length in pixels; the image is resized to a square
            ``image_size x image_size`` (aspect ratio is not preserved).

    Returns:
        A torchvision ``Compose`` that converts to RGB, resizes, converts to a
        float tensor of shape (3, image_size, image_size), and normalizes it.
    """
    return Compose([
        convert_to_rgb,
        Resize((image_size, image_size)),
        ToTensor(),
        # Standard ImageNet channel mean/std.
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def load_tag_list(tag_list_file):
    """Read one tag per line from *tag_list_file* (UTF-8) into a NumPy string array."""
    with open(tag_list_file, "r", encoding="utf-8") as f:
        tag_list = f.read().splitlines()
    return np.array(tag_list)


def load_word_vocabulary(word_vocabulary_file):
    """Map every word in the vocabulary file to the 0-based line it appears on.

    Each line of the file holds one or more comma-separated words; all words on
    line ``i`` map to ``i``. If a word appears on several lines, the last
    occurrence wins.

    Args:
        word_vocabulary_file: Path to the UTF-8 vocabulary file.

    Returns:
        dict mapping word -> line index.
    """
    with open(word_vocabulary_file, "r", encoding="utf-8") as f:
        lines = f.read().splitlines()
    return {
        word: line_idx
        for line_idx, line in enumerate(lines)
        for word in line.split(",")
    }


# hf_hub_download returns the local path of the downloaded file; use it rather
# than re-spelling the path so the two cannot drift apart.
# NOTE(review): local_dir_use_symlinks is deprecated (ignored) in recent
# huggingface_hub releases; kept for compatibility with older versions.
model_path = hf_hub_download(
    repo_id="Inf009/ram-tagger",
    repo_type="model",
    local_dir="resources",
    filename="ram_swin_large_14m_b1.onnx",
    local_dir_use_symlinks=False,
)
# Providers are tried in order; listing CPU explicitly gives a fallback on hosts
# without a usable CUDA runtime.
ort_session = onnxruntime.InferenceSession(
    model_path,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
transform = get_transform()
tag_list = load_tag_list("resources/ram_tag_list.txt")
word_index = load_word_vocabulary("resources/word_vocabulary_english.txt")


def inference_by_image_pil(image):
    """Tag *image* and return the predicted tags as a comma-separated string.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submits without providing an image.

    Returns:
        Tags joined with ``","`` in the order produced by ``rerank_tags``;
        an empty string when *image* is ``None`` or nothing is predicted.
    """
    if image is None:
        return ""
    # Add a leading batch dimension: (3, H, W) -> (1, 3, H, W).
    image_arrays = transform(image).unsqueeze(0).numpy()
    # compute ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: image_arrays}
    ort_outs = ort_session.run(None, ort_inputs)
    # NOTE(review): assumes output 0 is a (batch, num_tags) 0/1 mask aligned
    # with tag_list — confirm against the exported model.
    tag_indices = np.flatnonzero(ort_outs[0][0] == 1)
    tokens = tag_list[tag_indices].tolist()
    return ",".join(rerank_tags(tokens))


def rerank_tags(tags):
    """Reorder *tags* by their line index in the word vocabulary.

    Tags with a smaller vocabulary line index come first. Tags sharing an index
    keep their input order (``sorted`` is stable). Tags that are not in the
    vocabulary are kept, placed after all known tags in their input order,
    instead of raising ``KeyError``.

    Args:
        tags: Iterable of tag strings.

    Returns:
        A new list with the reordered tags.
    """
    unknown_rank = float("inf")
    return sorted(tags, key=lambda tag: word_index.get(tag, unknown_rank))


app = gr.Interface(
    fn=inference_by_image_pil,
    inputs=grc.Image(type='pil'),
    outputs=grc.Text(),
    title="RAM Tagger",
    description="A tagger for images.",
)
app.launch()