Spaces:
Build error
Build error
File size: 2,678 Bytes
84c806e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import os
import requests
import streamlit as st
from PIL import Image
import jax
import jax.numpy as jnp
import numpy as np
from utils import load_model
def split_image(im, grid_size=3):
    """Split an image into a ``grid_size`` x ``grid_size`` grid of tiles.

    The input is converted with ``np.array`` and divided along its first two
    axes (rows, then columns). When a dimension is not a multiple of
    ``grid_size``, the leftover pixels are spread over the leading tiles.
    The result therefore always has exactly ``grid_size ** 2`` non-empty tiles
    that together cover the whole image. For sizes that divide evenly this is
    the same as slicing in fixed steps of ``size // grid_size``.

    Args:
        im: A PIL image or array-like whose first two axes are height and
            width.
        grid_size: Number of tiles along each axis (default 3).

    Returns:
        A list of ``grid_size ** 2`` array tiles in row-major order
        (left-to-right within a row band, bands top-to-bottom).

    Raises:
        ValueError: If ``grid_size`` is < 1 or the image is smaller than
            ``grid_size`` pixels along either axis.
    """
    im = np.array(im)
    height, width = im.shape[:2]
    if grid_size < 1:
        raise ValueError(f"grid_size must be >= 1, got {grid_size}")
    if height < grid_size or width < grid_size:
        raise ValueError(
            f"Image of size {height}x{width} is too small for a "
            f"{grid_size}x{grid_size} grid")
    # array_split distributes the remainder instead of emitting an extra
    # sliver row/column, which fixed-step slicing does for odd sizes.
    return [
        tile
        for band in np.array_split(im, grid_size, axis=0)
        for tile in np.array_split(band, grid_size, axis=1)
    ]
# def split_image(X):
# num_rows = X.shape[0] // 224
# num_cols = X.shape[1] // 224
# Xc = X[0:num_rows * 224, 0:num_cols * 224, :]
# patches = []
# for j in range(num_rows):
# for i in range(num_cols):
# patches.append(Xc[j * 224:(j + 1) * 224, i * 224:(i + 1) * 224, :])
# return patches
def _load_image(url, uploaded_file, timeout=10):
    """Return the selected image as an RGB PIL image.

    An uploaded file takes precedence over the URL.

    Raises:
        requests.RequestException: If the URL cannot be fetched or the server
            answers with an HTTP error status.
        OSError: If the bytes cannot be decoded as an image.
    """
    import io

    if uploaded_file is not None:
        source = uploaded_file
    else:
        # Read the full (content-decoded) body rather than the raw socket
        # stream, and fail on 4xx/5xx instead of parsing an error page.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        source = io.BytesIO(response.content)
    # Normalize to 3 channels so RGBA/palette PNGs and grayscale images are
    # handled uniformly; the processor presumably expects RGB input.
    return Image.open(source).convert("RGB")


def _rank_tiles(model, processor, caption, tiles):
    """Score each tile against ``caption``; return (index, score) pairs.

    Scores are a softmax across tiles (they sum to 1 over ``tiles``). Pairs
    are ordered from most to least relevant, ties kept in tile order.
    """
    inputs = processor(text=caption,
                       images=tiles,
                       return_tensors="jax",
                       padding=True)
    # Move axis 1 to the end (presumably channels-first -> channels-last,
    # as the Flax model expects -- confirm against the model).
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"],
                                           axes=[0, 2, 3, 1])
    outputs = model(**inputs)
    # assumes logits_per_image is (num_tiles, num_texts): normalize across
    # tiles (axis 0) and keep the column for the single caption.
    probs = jax.nn.softmax(outputs.logits_per_image, axis=0)
    scores = np.asarray(probs[:, 0])
    order = np.argsort(-scores, kind="stable")
    return [(int(idx), float(scores[idx])) for idx in order]


def app(model_name):
    """Render the "Most Relevant Part of Image" Streamlit page.

    Loads ``koclip/{model_name}``, takes an image (URL or upload) and a text
    query, splits the image into tiles with ``split_image``, and lists the
    tiles from most to least relevant to the query together with their
    softmax scores. Input or loading problems are reported with
    ``st.error`` instead of crashing the page.
    """
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Most Relevant Part of Image")
    st.markdown("""
Given a piece of text, the CLIP model finds the part of an image that best explains the text.
To try it out, you can
1) Upload an image
2) Explain a part of the image in text
Which will yield the most relevant image tile from a 3x3 grid of the image
""")
    query1 = st.text_input(
        "Enter a URL to an image...",
        value="https://img.sbs.co.kr/newimg/news/20200823/201463830_1280.jpg")
    query2 = st.file_uploader("or upload an image...",
                              type=["jpg", "jpeg", "png"])
    captions = st.text_input(
        "Enter query to find most relevant part of image ",
        value="이건 서울의 경복궁 사진이다.",
    )

    if not st.button("질문 (Query)"):
        return
    if not any([query1, query2]):
        st.error("Please upload an image or paste an image URL.")
        return
    if not captions.strip():
        st.error("Please enter a text query.")
        return

    try:
        image = _load_image(query1, query2)
    except (requests.RequestException, OSError) as err:
        st.error(f"Could not load the image: {err}")
        return
    st.image(image)

    try:
        tiles = split_image(image)
    except ValueError as err:
        st.error(f"Could not split the image: {err}")
        return

    for idx, score in _rank_tiles(model, processor, captions, tiles):
        st.text(f"Score: {score:.3f}")
        st.image(tiles[idx])
|