Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -22,6 +22,7 @@ dataset = stl.selectbox(
|
|
| 22 |
"huggan/few-shot-art-painting",
|
| 23 |
"huggan/wikiart",
|
| 24 |
"zh-plus/tiny-imagenet",
|
|
|
|
| 25 |
"lambdalabs/naruto-blip-captions",
|
| 26 |
"detection-datasets/fashionpedia",
|
| 27 |
"fantasyfish/laion-art",
|
|
@@ -37,30 +38,43 @@ search_term = None
|
|
| 37 |
ret_images = []
|
| 38 |
scores = []
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
embedder = ripple.ImageEmbedder(
|
| 44 |
dataset, retrieval_type="text-image", dataset_type="huggingface"
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
stl.success("Sucessfully embedded and dcreated image index")
|
| 49 |
|
| 50 |
-
|
|
|
|
| 51 |
text_search = ripple.TextSearch(embedded_data, embedder.embed_model)
|
| 52 |
stl.success("Initialized text search class")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
|
|
|
|
|
|
|
|
|
| 54 |
search_term = stl.text_input("Text description/search for image")
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
|
| 61 |
try:
|
| 62 |
for count, score, image in tqdm(zip(range(len(ret_images)), scores, ret_images)):
|
| 63 |
stl.image(image["image"][count])
|
| 64 |
stl.write(score)
|
| 65 |
-
|
|
|
|
| 66 |
st.error(e)
|
|
|
|
| 22 |
"huggan/few-shot-art-painting",
|
| 23 |
"huggan/wikiart",
|
| 24 |
"zh-plus/tiny-imagenet",
|
| 25 |
+
"huggan/flowers-102-categories",
|
| 26 |
"lambdalabs/naruto-blip-captions",
|
| 27 |
"detection-datasets/fashionpedia",
|
| 28 |
"fantasyfish/laion-art",
|
|
|
|
| 38 |
ret_images = []
|
| 39 |
scores = []
|
| 40 |
|
| 41 |
+
#@stl.cache_data
|
| 42 |
+
def embed_data(dataset):
|
| 43 |
+
embedder = ripple.ImageEmbedder(
|
|
|
|
| 44 |
dataset, retrieval_type="text-image", dataset_type="huggingface"
|
| 45 |
+
)
|
| 46 |
+
embedded_data = embedder.create_embeddings(device="cpu")
|
| 47 |
+
return embedded_data
|
|
|
|
| 48 |
|
| 49 |
+
@stl.cache_resource
|
| 50 |
+
def init_search(embedded_data):
|
| 51 |
text_search = ripple.TextSearch(embedded_data, embedder.embed_model)
|
| 52 |
stl.success("Initialized text search class")
|
| 53 |
+
return text_search
|
| 54 |
+
|
| 55 |
+
def get_images_from_description(description):
|
| 56 |
+
scores, ret_images = description, k_images=4)
|
| 57 |
+
return scores, ret_images
|
| 58 |
+
|
| 59 |
+
if dataset and stl.button("embed image dataset"):
|
| 60 |
+
with stl.spinner("Initializing and creating image embeddings from dataset"):
|
| 61 |
+
embedded_data = embed_data(dataset)
|
| 62 |
+
stl.success("Successfully embedded and created image index")
|
| 63 |
|
| 64 |
+
if embedded_data:
|
| 65 |
+
finder = init_search(embedded_data)
|
| 66 |
+
|
| 67 |
search_term = stl.text_input("Text description/search for image")
|
| 68 |
|
| 69 |
+
if search_term:
|
| 70 |
+
with stl.spinner("retrieving images with description.."):
|
| 71 |
+
scores, ret_images = get_images_from_description(search_term)
|
| 72 |
+
stl.success(f"sucessfully retrieved {len(ret_images)} images")
|
| 73 |
|
| 74 |
try:
|
| 75 |
for count, score, image in tqdm(zip(range(len(ret_images)), scores, ret_images)):
|
| 76 |
stl.image(image["image"][count])
|
| 77 |
stl.write(score)
|
| 78 |
+
|
| 79 |
+
except Exception as e:
|
| 80 |
st.error(e)
|