Spaces:
Running
Running
# /// script | |
# requires-python = ">=3.12" | |
# dependencies = [ | |
# "chromadb==1.0.4", | |
# "datasets==3.5.0", | |
# "marimo", | |
# "matplotlib==3.10.1", | |
# "numpy==2.2.4", | |
# "open-clip-torch==2.32.0", | |
# "pillow==11.1.0", | |
# ] | |
# /// | |
import marimo | |
__generated_with = "0.12.8" | |
app = marimo.App(width="medium") | |
def _(): | |
import marimo as mo | |
return (mo,) | |
def _(mo): | |
mo.md( | |
r""" | |
# Multimodal Retrieval | |
Chroma supports multimodal collections, i.e. collections which contain, and can be queried by, multiple modalities of data. | |
This notebook shows an example of how to create and query a collection with both text and images, using Chroma's built-in features. | |
""" | |
) | |
return | |
def _(mo): | |
mo.md( | |
r""" | |
## Dataset | |
We us a small subset of the [coco object detection dataset](https://huggingface.co/datasets/detection-datasets/coco), hosted on HuggingFace. | |
We download a small fraction of all the images in the dataset locally, and use it to create a multimodal collection. | |
""" | |
) | |
return | |
def _(): | |
import os | |
from datasets import load_dataset | |
from matplotlib import pyplot as plt | |
return load_dataset, os | |
def _(load_dataset, mo): | |
with mo.status.spinner(title="Loading dataset"): | |
dataset = load_dataset( | |
path="detection-datasets/coco", | |
name="default", | |
split="train", | |
streaming=True, | |
) | |
N_IMAGES = 20 | |
return N_IMAGES, dataset | |
def _(N_IMAGES, dataset, mo, os): | |
# Write the images to a folder | |
IMAGE_FOLDER = "images" | |
os.makedirs(IMAGE_FOLDER, exist_ok=True) | |
i = 0 | |
all_images = [] | |
with mo.status.spinner(title="Loading images"): | |
for row in dataset.take(N_IMAGES): | |
image = row["image"] | |
all_images.append(image) | |
image.save(f"images/{i}.jpg") | |
i += 1 | |
return IMAGE_FOLDER, all_images | |
def _(mo): | |
img_width = mo.ui.slider( | |
label="Image width", start=100, stop=300, step=10, debounce=True | |
) | |
img_width | |
return (img_width,) | |
def _(all_images, img_width, mo): | |
import io | |
def as_image(src): | |
img_byte_arr = io.BytesIO() | |
src.save(img_byte_arr, format=src.format or "PNG") | |
img_byte_arr.seek(0) | |
return mo.image(img_byte_arr, width=img_width.value) | |
mo.hstack( | |
[as_image(_img) for _img in all_images[10:]], | |
wrap=True, | |
) | |
return | |
def _(mo): | |
mo.md( | |
r""" | |
## Ingesting multimodal data | |
Chroma supports multimodal collections by referencing external URIs for data types other than text. | |
All you have to do is specify a data loader when creating the collection, and then provide the URI for each entry. | |
For this example, we are only adding images, though you can also add text. | |
""" | |
) | |
return | |
def _(mo): | |
mo.md( | |
r""" | |
### Creating a multi-modal collection | |
First we create the default Chroma client. | |
""" | |
) | |
return | |
def _(): | |
import chromadb | |
client = chromadb.Client() | |
return (client,) | |
def _(mo): | |
mo.md( | |
r""" | |
Next we specify an embedding function and a data loader. | |
The built-in `OpenCLIPEmbeddingFunction` works with both text and image data. The `ImageLoader` is a simple data loader that loads images from a local directory. | |
""" | |
) | |
return | |
def _(): | |
from chromadb.utils.data_loaders import ImageLoader | |
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction | |
embedding_function = OpenCLIPEmbeddingFunction() | |
image_loader = ImageLoader() | |
return embedding_function, image_loader | |
def _(mo): | |
mo.md(r"""We create a collection with the embedding function and data loader.""") | |
return | |
def _(IMAGE_FOLDER, client, embedding_function, image_loader, os): | |
collection = client.create_collection( | |
name="multimodal_collection", | |
embedding_function=embedding_function, | |
data_loader=image_loader, | |
get_or_create=True, | |
) | |
# Get the uris to the images | |
image_uris = sorted( | |
[ | |
os.path.join(IMAGE_FOLDER, image_name) | |
for image_name in os.listdir(IMAGE_FOLDER) | |
] | |
) | |
ids = [str(i) for i in range(len(image_uris))] | |
collection.add(ids=ids, uris=image_uris) | |
return (collection,) | |
def _(mo): | |
mo.md( | |
r""" | |
### Adding multi-modal data | |
We add image data to the collection using the image URIs. The data loader and embedding functions we specified earlier will ingest data from the provided URIs automatically. | |
""" | |
) | |
return | |
def _(mo): | |
mo.md( | |
r""" | |
## Querying a multi-modal collection | |
We can query the collection using text as normal, since the `OpenCLIPEmbeddingFunction` works with both text and images. | |
""" | |
) | |
return | |
def _(mo): | |
query = mo.ui.text_area(label="Query with text", full_width=True).form( | |
bordered=False | |
) | |
mo.vstack([query, mo.md("Try: *animal* or *vehicle*")]) | |
return (query,) | |
def _(collection, mo, query): | |
mo.stop(not query.value) | |
_retrieved = collection.query( | |
query_texts=[query.value], include=["data"], n_results=3 | |
) | |
[mo.image(img, height=200) for img in _retrieved["data"][0]] | |
return | |
def _(mo): | |
mo.md( | |
r""" | |
/// admonition | One more thing! | |
We can also query by images directly, by using the `query_images` field in the `collection.query` method. | |
/// | |
""" | |
) | |
return | |
def _(collection, mo, selected_image): | |
mo.stop(not selected_image.value) | |
import numpy as np | |
from PIL import Image | |
query_image = np.array(Image.open(selected_image.path())) | |
selected = mo.as_html(mo.image(query_image)) | |
_retrieved = collection.query( | |
query_images=[query_image], include=["data"], n_results=5 | |
) | |
results = [mo.image(_img) for _img in _retrieved["data"][0][1:]] | |
return results, selected | |
def _(IMAGE_FOLDER, mo): | |
selected_image = mo.ui.file_browser(IMAGE_FOLDER, multiple=False) | |
selected_image | |
return (selected_image,) | |
def _(mo, results, selected): | |
mo.hstack( | |
[ | |
mo.vstack([mo.md("## Selected"), selected]), | |
mo.vstack([mo.md("## Similar"), *results]), | |
], | |
widths="equal", | |
gap=4, | |
) | |
return | |
def _(mo): | |
mo.md(r"""This example was adapted from [multimodal_retrieval.ipynb](https://github.com/chroma-core/chroma/blob/main/examples/multimodal/multimodal_retrieval.ipynb), using `marimo convert`.""") | |
return | |
if __name__ == "__main__": | |
app.run() |