Spaces: Build error
# ripple_net library (https://github.com/kelechi-c/ripple_net)
import ripple
import streamlit as stl
from tqdm.auto import tqdm

# streamlit app
stl.set_page_config(
    page_title="Ripple",
)

stl.title("ripple search")
stl.write(
    "An app that searches for images matching a text description, using embeddings of a selected image dataset. "
    "Uses a contrastive image-text model (CLIP) and the sentence-transformers library."
)
stl.link_button(
    label="Full library code",
    url="https://github.com/kelechi-c/ripple_net",
)
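
# dataset selection: pick which Hugging Face image dataset to embed and search over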
dataset = stl.selectbox(
    "Choose a Hugging Face dataset (bigger datasets take more time to embed)",
    options=[
        "huggan/few-shot-art-painting",
        "huggan/wikiart",
        "zh-plus/tiny-imagenet",
        "huggan/flowers-102-categories",
        "lambdalabs/naruto-blip-captions",
        "detection-datasets/fashionpedia",
        "fantasyfish/laion-art",
        "Chris1/cityscapes",
    ],
)

# initialize global variables
embedded_data = None
embedder = None
finder = None
search_term = None
ret_images = []
scores = []
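
# create image embeddings for the chosen dataset with ripple's ImageEmbedder (device="cpu")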
# @stl.cache_data
def embed_data(dataset):
    # keep a module-level reference so init_search can reuse the embedding model
    global embedder
    embedder = ripple.ImageEmbedder(
        dataset, retrieval_type="text-image", dataset_type="huggingface"
    )
    embedded_data = embedder.create_embeddings(device="cpu")
    return embedded_data
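
# build the text-to-image search index from the embedded dataset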
def init_search(embedded_data):
    text_search = ripple.TextSearch(embedded_data, embedder.embed_model)
    stl.success("Initialized text search class")
    return text_search
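
# retrieve the k nearest images (and their similarity scores) for a text description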
def get_images_from_description(description):
    scores, ret_images = finder.get_nearest_examples(description, k_images=4)
    return scores, ret_images
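
# app flow: embed the selected dataset, build the search index, then retrieve images for a text query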
if dataset and stl.button("embed image dataset"):
    with stl.spinner("Initializing and creating image embeddings from dataset"):
        embedded_data = embed_data(dataset)
        stl.success("Successfully embedded and created image index")

if embedded_data:
    finder = init_search(embedded_data)
    search_term = stl.text_input("Text description/search for image")

if search_term:
    with stl.spinner("Retrieving images matching the description..."):
        scores, ret_images = get_images_from_description(search_term)
        stl.success(f"Successfully retrieved {len(scores)} images")

    try:
        # ret_images is expected to be a dict of columns; ret_images["image"] holds the retrieved images
        for count, score in tqdm(enumerate(scores)):
            stl.image(ret_images["image"][count])
            stl.write(score)
    except Exception as e:
        stl.error(e)