# Web-page residue captured with this file (not part of the program):
# Pash1986's picture / Update app.py / a8ca956 verified / raw / history blame / 4.53 kB
import gradio as gr
from pymongo import MongoClient
from PIL import Image
import base64
import os
import io
import boto3
import json
# AWS Bedrock client setup
# Credentials are read from the AWS_ACCESS_KEY / AWS_SECRET_KEY environment
# variables; the region is fixed to us-east-1.
bedrock_runtime = boto3.client(
    'bedrock-runtime',
    region_name="us-east-1",
    aws_access_key_id=os.environ.get('AWS_ACCESS_KEY'),
    aws_secret_access_key=os.environ.get('AWS_SECRET_KEY'),
)
# Function to construct the request body for Bedrock
def construct_bedrock_body(base64_string, text, output_embedding_length=1024):
    """Build the JSON request body for the Titan multimodal embedding model.

    Args:
        base64_string: Base64-encoded image, sent as ``inputImage``.
        text: Optional text to embed together with the image. Falsy values
            (``None`` or ``""``) leave ``inputText`` out of the request.
        output_embedding_length: Value sent as
            ``embeddingConfig.outputEmbeddingLength``. Defaults to 1024, the
            value this function always used before the parameter existed.
            NOTE(review): a different length presumably has to match the
            dimension of the vector index being queried -- confirm before
            changing it.

    Returns:
        The request body serialized as a JSON string.
    """
    payload = {
        "inputImage": base64_string,
        "embeddingConfig": {"outputEmbeddingLength": output_embedding_length},
    }
    # Only send inputText when there is something to send.
    if text:
        payload["inputText"] = text
    return json.dumps(payload)
# Function to get the embedding from Bedrock model
def get_embedding_from_titan_multimodal(body):
    """Invoke the Titan multimodal embedding model and return its embedding.

    Args:
        body: JSON request body string passed to ``invoke_model`` as-is.

    Returns:
        The value stored under the ``"embedding"`` key of the decoded
        response body.
    """
    invocation = {
        "modelId": "amazon.titan-embed-image-v1",
        "contentType": "application/json",
        "accept": "application/json",
        "body": body,
    }
    response = bedrock_runtime.invoke_model(**invocation)
    # The response "body" is a stream; read it fully before decoding the JSON.
    raw_payload = response.get("body").read()
    return json.loads(raw_payload)["embedding"]
# MongoDB setup
# Database and collection names are fixed; the connection string is read from
# the MONGODB_ATLAS_URI environment variable.
db_name = 'celebrity_1000_embeddings'
collection_name = 'celeb_images'
uri = os.environ.get('MONGODB_ATLAS_URI')
client = MongoClient(uri)
celeb_images = client[db_name][collection_name]
# Function to generate image description using Claude 3 Sonnet
def generate_image_description_with_claude(image_base64):
    """Ask Claude 3 Sonnet (through Bedrock) who appears in an image.

    Args:
        image_base64: Base64-encoded image as a ``str``. The request declares
            its media type as ``image/jpeg``.

    Returns:
        The text of the first content block in the response that carries
        non-empty text, or ``"No description available"`` when the response
        has no ``content`` list, the list is empty, or no block has text.
    """
    request = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 1000,
        "system": "Please respond as a celebrity reporter.",
        "messages": [{
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": image_base64,
                    },
                },
                {"type": "text", "text": "Who is in this image?"},
            ],
        }],
    }
    claude_response = bedrock_runtime.invoke_model(
        body=json.dumps(request),
        modelId="anthropic.claude-3-sonnet-20240229-v1:0",
        accept="application/json",
        contentType="application/json",
    )
    response_body = json.loads(claude_response.get("body").read())

    # Indexing content[0] directly would raise on a missing or empty content
    # list; scan the blocks instead so the fallback string is actually reached.
    for block in response_body.get("content") or []:
        block_text = block.get("text")
        if block_text:
            return block_text
    return "No description available"
# Main function to start image search
def start_image_search(image, text):
    """Find the stored images most similar to an upload and describe each one.

    The upload is resized to 800x600, encoded as JPEG, embedded (together
    with the optional text) by the Titan multimodal model, and used as the
    query vector of a ``$vectorSearch`` over the ``celeb_images`` collection.
    Each hit is then described by Claude.

    Args:
        image: PIL image supplied by the Gradio image input, or ``None`` when
            nothing was uploaded.
        text: Optional text forwarded to the embedding request.

    Returns:
        A list of ``(PIL image, description)`` tuples, at most three long.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first, make sure to press the 'Submit' button after selecting the image.")

    # NOTE: the fixed 800x600 target does not preserve the aspect ratio; it is
    # kept as-is so query embeddings stay comparable with earlier behaviour.
    query_image = image.resize((800, 600))
    body = construct_bedrock_body(_to_jpeg_base64(query_image, quality=85), text)
    embedding = get_embedding_from_titan_multimodal(body)

    pipeline = [
        {
            "$vectorSearch": {
                "index": "vector_index",
                "path": "embeddings",
                "queryVector": embedding,
                "numCandidates": 15,
                "limit": 3,
            }
        },
        {"$project": {"image": 1}},
    ]

    images_with_descriptions = []
    for image_doc in celeb_images.aggregate(pipeline):
        # Stored images are base64 strings; decode once for display and
        # re-encode as JPEG because the Claude request declares image/jpeg.
        pil_image = Image.open(io.BytesIO(base64.b64decode(image_doc['image'])))
        description = generate_image_description_with_claude(_to_jpeg_base64(pil_image))
        images_with_descriptions.append((pil_image, description))
    return images_with_descriptions


def _to_jpeg_base64(image, **save_options):
    """Encode a PIL image as JPEG and return the bytes as a base64 ``str``.

    Extra keyword arguments (e.g. ``quality``) are passed to ``Image.save``.
    """
    # The JPEG writer cannot store alpha or palette modes (RGBA, LA, P, ...),
    # which used to raise OSError for e.g. PNG uploads with transparency.
    # Convert those to RGB; modes the writer already accepts are untouched.
    if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", **save_options)
    return base64.b64encode(buffer.getvalue()).decode('utf-8')
# Gradio Interface
with gr.Blocks() as demo:
    # Page heading and short usage instructions.
    gr.Markdown("""
# MongoDB's Vector Celeb Image Matcher
Upload an image and find the most similar celeb image from the database, along with an AI-generated description.
💪 Make a great pose to impact the search! 🤯
""")
    # Components are created in the same order as before: image, textbox, gallery.
    search_inputs = [
        gr.Image(type="pil", label="Upload an image"),
        gr.Textbox(label="Enter an adjustment to the image"),
    ]
    results_gallery = gr.Gallery(
        label="Located images with AI-generated descriptions",
        show_label=True,
        elem_id="gallery",
        columns=[3],
        rows=[1],
        object_fit="contain",
        height="auto",
    )
    # Wire the search function to the inputs and the gallery output.
    gr.Interface(fn=start_image_search, inputs=search_inputs, outputs=results_gallery)
demo.launch()