# koclip / image2text.py
# Last commit 5dce03a by jaketae: "feature: extract feature from user input image"
import streamlit as st
import numpy as np
import jax.numpy as jnp
from PIL import Image
from utils import load_model
def app(model_name):
    """Render the image-query page of the KoCLIP Streamlit demo.

    The user uploads a JPG/JPEG/PNG image and presses the query button.
    The image is preprocessed with the checkpoint's processor and encoded
    into an image feature vector via ``model.get_image_features``.
    Nothing is displayed from the vector yet (see the TODO at the end).

    Args:
        model_name: Checkpoint name under the ``koclip/`` namespace; it is
            loaded with ``load_model(f"koclip/{model_name}")``.
    """
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Zero-shot Image Classification")
    st.markdown(
        """
        Some text goes in here.
        """
    )

    query = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if not st.button("질문 (Query)"):
        return
    if query is None:
        st.error("Please upload an image query.")
        return

    # PNG uploads are often RGBA, palette or grayscale. Force 3-channel RGB,
    # since the image preprocessing presumably expects RGB input.
    image = Image.open(query).convert("RGB")

    # Only `pixel_values` is used. NOTE(review): a dummy text input is passed
    # alongside the image, presumably because this processor call needs text
    # too. Confirm before removing it.
    pixel_values = processor(
        text=[""], images=image, return_tensors="jax", padding=True
    ).pixel_values
    # Move axis 1 (channels in the processor's output) to the end.
    # NOTE(review): `get_image_features` presumably expects channels-last
    # input. Confirm against the model definition.
    pixel_values = jnp.transpose(pixel_values, axes=[0, 2, 3, 1])
    vec = np.asarray(model.get_image_features(pixel_values))

    # TODO: nearest-neighbour retrieval is not implemented yet. The planned
    # flow queries an ANN index with `vec` (`index.knnQuery(vec, k=10)`),
    # maps the returned ids to image files, and shows each match with the
    # caption "<file> (유사도: <1 - dist>)", where 유사도 means "similarity".
    # `index`, `files` and `images_directory` do not exist in this module yet.