File size: 4,526 Bytes
daae077 4195413 daae077 4195413 daae077 4195413 daae077 4195413 daae077 4195413 daae077 4195413 daae077 4195413 f047147 4195413 f047147 4195413 f047147 4195413 a8ca956 4195413 daae077 0ed6129 4195413 daae077 4195413 daae077 4195413 daae077 4195413 daae077 4195413 daae077 4195413 daae077 4195413 daae077 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import gradio as gr
from pymongo import MongoClient
from PIL import Image
import base64
import os
import io
import boto3
import json
# Shared AWS Bedrock runtime client. Credentials are read from the
# AWS_ACCESS_KEY / AWS_SECRET_KEY environment variables; the region is fixed.
bedrock_runtime = boto3.client(
    'bedrock-runtime',
    region_name="us-east-1",
    aws_access_key_id=os.getenv('AWS_ACCESS_KEY'),
    aws_secret_access_key=os.getenv('AWS_SECRET_KEY'),
)
# Function to construct the request body for Bedrock
def construct_bedrock_body(base64_string, text, output_embedding_length=1024):
    """Return the JSON request body for a Titan multimodal embedding call.

    Args:
        base64_string: Base64-encoded image, sent as ``inputImage``.
        text: Optional text sent as ``inputText``. It is left out of the
            body entirely when falsy (``None`` or ``""``).
        output_embedding_length: Value written to
            ``embeddingConfig.outputEmbeddingLength``. Defaults to 1024,
            the value that used to be hard-coded here.
            NOTE(review): the model presumably accepts only a fixed set of
            lengths, and the vector index must match -- confirm against
            the Titan documentation before passing another value.

    Returns:
        The request body serialized as a JSON string.
    """
    # Build the common part once; key order matches the original output.
    payload = {
        "inputImage": base64_string,
        "embeddingConfig": {"outputEmbeddingLength": output_embedding_length},
    }
    if text:
        payload["inputText"] = text
    return json.dumps(payload)
# Fetch an embedding from the Titan multimodal model on Bedrock
def get_embedding_from_titan_multimodal(body):
    """Invoke ``amazon.titan-embed-image-v1`` and return its embedding.

    ``body`` is forwarded unchanged as the request body. The response
    body is read, parsed as JSON, and its ``"embedding"`` entry returned.
    """
    invoke_kwargs = {
        "modelId": "amazon.titan-embed-image-v1",
        "contentType": "application/json",
        "accept": "application/json",
        "body": body,
    }
    raw_response = bedrock_runtime.invoke_model(**invoke_kwargs)
    payload_stream = raw_response.get("body")
    parsed = json.loads(payload_stream.read())
    return parsed["embedding"]
# MongoDB connection and the collection queried by the image search.
# The connection string is read from the MONGODB_ATLAS_URI environment variable.
db_name = 'celebrity_1000_embeddings'
collection_name = 'celeb_images'
uri = os.getenv('MONGODB_ATLAS_URI')
client = MongoClient(uri)
celeb_images = client[db_name][collection_name]
# Function to generate image description using Claude 3 Sonnet
def generate_image_description_with_claude(image_base64):
    """Ask Claude 3 Sonnet on Bedrock who is in the given image.

    Args:
        image_base64: Base64-encoded image data as ``str``. The request
            labels it ``image/jpeg``, so callers should pass JPEG data.

    Returns:
        The ``"text"`` of the first response content block that has one,
        or ``"No description available"`` when the response carries no
        such block (missing, empty, or text-less ``"content"``).
    """
    claude_body = json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 1000,
        "system": "Please respond as a celebrity reporter.",
        "messages": [{
            "role": "user",
            "content": [
                {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": image_base64}},
                {"type": "text", "text": "Who is in this image?"}
            ]
        }]
    })
    claude_response = bedrock_runtime.invoke_model(
        body=claude_body,
        modelId="anthropic.claude-3-sonnet-20240229-v1:0",
        accept="application/json",
        contentType="application/json",
    )
    response_body = json.loads(claude_response.get("body").read())
    # Indexing content[0] directly would raise KeyError/IndexError on a
    # missing or empty "content" list before the fallback could apply, so
    # scan the blocks and fall back only when none carries text.
    for block in response_body.get("content") or []:
        if isinstance(block, dict) and "text" in block:
            return block["text"]
    return "No description available"
# Pillow modes the JPEG writer accepts; any other mode (e.g. RGBA, P, LA)
# is converted to RGB before encoding so that save() does not raise OSError.
_JPEG_MODES = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")


def _encode_jpeg_base64(image, **save_options):
    """Serialize a PIL image as JPEG and return the bytes as a base64 str."""
    if image.mode not in _JPEG_MODES:
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", **save_options)
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


# Main function to start image search
def start_image_search(image, text):
    """Find the stored images most similar to ``image`` and describe them.

    The uploaded image is resized to 800x600 (aspect ratio is not
    preserved), JPEG-encoded at quality 85 and embedded together with the
    optional ``text``. The embedding is used for a ``$vectorSearch`` on
    the ``celeb_images`` collection (15 candidates, top 3 results), and
    each hit is sent to Claude for a description.

    Args:
        image: PIL image from the Gradio image input, or ``None`` when
            nothing was uploaded.
        text: Optional text adjustment passed along to the embedding call.

    Returns:
        A list of ``(PIL image, description)`` tuples, one per match.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first, make sure to press the 'Submit' button after selecting the image.")

    query_base64 = _encode_jpeg_base64(image.resize((800, 600)), quality=85)
    body = construct_bedrock_body(query_base64, text)
    embedding = get_embedding_from_titan_multimodal(body)

    pipeline = [
        {
            "$vectorSearch": {
                "index": "vector_index",
                "path": "embeddings",
                "queryVector": embedding,
                "numCandidates": 15,
                "limit": 3
            }
        },
        {"$project": {"image": 1}},
    ]
    # Materialize the results up front so the cursor is exhausted before
    # the Claude calls below are made.
    matches = list(celeb_images.aggregate(pipeline))

    images_with_descriptions = []
    for image_doc in matches:
        # The 'image' field is base64-decoded and opened as an image.
        pil_image = Image.open(io.BytesIO(base64.b64decode(image_doc['image'])))
        # Re-encode as JPEG because the Claude request declares image/jpeg.
        # The gallery still receives the image as opened, unconverted.
        description = generate_image_description_with_claude(_encode_jpeg_base64(pil_image))
        images_with_descriptions.append((pil_image, description))
    return images_with_descriptions
# Gradio Interface
# A Blocks page with a Markdown header and an Interface that wires the
# image/text inputs to start_image_search and shows the results in a gallery.
with gr.Blocks() as demo:
    # The Markdown text is kept flush-left so the string content is unchanged.
    gr.Markdown("""
# MongoDB's Vector Celeb Image Matcher
Upload an image and find the most similar celeb image from the database, along with an AI-generated description.
💪 Make a great pose to impact the search! 🤯
""")
    image_input = gr.Image(type="pil", label="Upload an image")
    text_input = gr.Textbox(label="Enter an adjustment to the image")
    results_gallery = gr.Gallery(
        label="Located images with AI-generated descriptions",
        show_label=True,
        elem_id="gallery",
        columns=[3],
        rows=[1],
        object_fit="contain",
        height="auto",
    )
    gr.Interface(
        fn=start_image_search,
        inputs=[image_input, text_input],
        outputs=results_gallery,
    )

# Start the app (the stray trailing "|" after this call was removed: it made
# the statement a syntax error).
demo.launch()