# Pash1986's picture
# Upload 2 files
# daae077 verified
# raw
# history blame
# 3.05 kB
import gradio as gr
from pymongo import MongoClient
from PIL import Image
import base64
import os
import io
import boto3
import json
# Amazon Bedrock runtime client, used to invoke the Titan multimodal embedding
# model. The access key pair is read from the AWS_ACCESS_KEY / AWS_SECRET_KEY
# environment variables (None is passed through when a variable is unset).
bedrock_runtime = boto3.client(
    "bedrock-runtime",
    region_name="us-east-1",
    aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"),
    aws_secret_access_key=os.environ.get("AWS_SECRET_KEY"),
)
def construct_bedrock_body(base64_string, text, *, output_embedding_length=1024):
    """Build the JSON request body for the Titan multimodal embedding model.

    Args:
        base64_string: The input image, already Base64-encoded as a str.
        text: Optional text to embed together with the image. The
            ``inputText`` field is omitted entirely when this is falsy
            (``None`` or ``""``).
        output_embedding_length: Value sent as
            ``embeddingConfig.outputEmbeddingLength``. Defaults to 1024.
            NOTE(review): this presumably has to match the dimensions of the
            vector index being queried — confirm before changing it.

    Returns:
        The request body serialized as a JSON string.
    """
    request = {
        "inputImage": base64_string,
        "embeddingConfig": {"outputEmbeddingLength": output_embedding_length},
    }
    # Only send inputText when the user actually typed something; key order
    # (inputImage, embeddingConfig, inputText) matches the original payload.
    if text:
        request["inputText"] = text
    return json.dumps(request)
def get_embedding_from_titan_multimodal(body):
    """Send *body* to Titan Multimodal Embeddings and return the embedding.

    Args:
        body: JSON request string, such as the one built by
            ``construct_bedrock_body``.

    Returns:
        The ``"embedding"`` field of the model's JSON response.
    """
    request_kwargs = {
        "body": body,
        "modelId": "amazon.titan-embed-image-v1",
        "accept": "application/json",
        "contentType": "application/json",
    }
    # The response's "body" entry is read as a stream and then parsed as JSON.
    raw_payload = bedrock_runtime.invoke_model(**request_kwargs).get("body").read()
    return json.loads(raw_payload)["embedding"]
# MongoDB Atlas connection. The connection string comes from the
# MONGODB_ATLAS_URI environment variable.
# NOTE(review): when the variable is unset, None is passed to MongoClient,
# which presumably falls back to a default host rather than failing — confirm.
uri = os.environ.get("MONGODB_ATLAS_URI")
client = MongoClient(uri)

# Database / collection holding the celebrity image documents.
db_name, collection_name = "celebrity_1000_embeddings", "celeb_images"
celeb_images = client.get_database(db_name).get_collection(collection_name)
def _encode_image_as_base64_jpeg(image):
    """Return *image* JPEG-encoded and then Base64-encoded, as a str."""
    # JPEG cannot store an alpha channel or a palette; without this conversion
    # uploads such as transparent PNGs (mode RGBA/P) make image.save() raise.
    if image.mode != "RGB":
        image = image.convert("RGB")
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    # b64encode yields bytes; decode to str so it can be JSON-serialized.
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def _decode_base64_image(image_base64):
    """Decode a Base64-encoded image (as stored in the DB) into a PIL image."""
    return Image.open(io.BytesIO(base64.b64decode(image_base64)))


def start_image_search(image, text):
    """Find the stored celebrity images most similar to an uploaded photo.

    Args:
        image: PIL image from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when nothing was uploaded.
        text: Optional text from the Gradio textbox. When non-empty it is
            embedded together with the image.

    Returns:
        A list of PIL images decoded from the ``image`` field of the matching
        documents (at most 3, per the ``limit`` below).

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Alert the user to upload an image.
        raise gr.Error("Please upload an image first, make sure to press the 'Submit' button after selecting the image.")

    body = construct_bedrock_body(_encode_image_as_base64_jpeg(image), text)
    embedding = get_embedding_from_titan_multimodal(body)

    # Atlas Vector Search on the "embeddings" field via the "vector_index"
    # index, keeping only the stored "image" field of each hit.
    pipeline = [
        {
            "$vectorSearch": {
                "index": "vector_index",
                "path": "embeddings",
                "queryVector": embedding,
                "numCandidates": 15,
                "limit": 3,
            }
        },
        {"$project": {"image": 1}},
    ]
    # Name the loop variable `doc` so it no longer shadows the `image` argument.
    return [_decode_base64_image(doc["image"]) for doc in celeb_images.aggregate(pipeline)]
with gr.Blocks() as demo:
    # Lines are written flush-left so Markdown never treats them as indented code.
    gr.Markdown(
        "# MongoDB's Vector Celeb Image matcher\n"
        "Upload an image and find the most similar celeb images from the database.\n"
        "💪 Make a great pose to impact the search! 🤯\n"
    )
    # Image (plus optional text) in -> gallery of the closest matches out.
    gr.Interface(
        fn=start_image_search,
        inputs=[
            gr.Image(type="pil", label="Upload an image"),
            gr.Textbox(label="Enter an adjustment to the image"),
        ],
        outputs=gr.Gallery(
            # The images are retrieved from MongoDB, not generated.
            label="Matching images",
            show_label=True,
            elem_id="gallery",
            columns=[3],
            rows=[1],
            object_fit="contain",
            height="auto",
        ),
    )

demo.launch()