File size: 3,045 Bytes
daae077
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import gradio as gr
from pymongo import MongoClient
from PIL import Image
import base64
import os
import io
import boto3
import json

# Bedrock runtime client shared by the embedding helper below.
# Credentials are read from the AWS_ACCESS_KEY / AWS_SECRET_KEY environment
# variables; os.environ.get yields None for any that are unset.
# NOTE(review): with None credentials boto3 presumably falls back to its
# default credential chain — confirm this is the intended deployment model.
bedrock_runtime = boto3.client(
    "bedrock-runtime",
    aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"),
    aws_secret_access_key=os.environ.get("AWS_SECRET_KEY"),
    region_name="us-east-1",
)

def construct_bedrock_body(base64_string, text, embedding_length=1024):
    """Build the JSON request body for the Titan multimodal embedding model.

    Args:
        base64_string: Base64-encoded image bytes, as a ``str``.
        text: Optional text to embed together with the image. Falsy values
            (``None`` or ``""``, e.g. an empty textbox) are left out of the
            request, so the embedding is computed from the image alone.
        embedding_length: Value sent as
            ``embeddingConfig.outputEmbeddingLength``. Defaults to 1024.
            NOTE(review): presumably this must match the dimensionality of
            the ``embeddings`` vector index that is queried — confirm before
            changing it.

    Returns:
        The request body serialized as a JSON string.
    """
    payload = {
        "inputImage": base64_string,
        "embeddingConfig": {"outputEmbeddingLength": embedding_length},
    }
    # Added last so the key order (and therefore the serialized string)
    # matches the two-branch version this replaces.
    if text:
        payload["inputText"] = text
    return json.dumps(payload)


def get_embedding_from_titan_multimodal(body):
    """Send a prepared request body to Titan Multimodal Embeddings.

    Args:
        body: JSON request string, e.g. the output of
            ``construct_bedrock_body``.

    Returns:
        The value stored under the ``embedding`` key of the model's JSON
        response.
    """
    reply = bedrock_runtime.invoke_model(
        body=body,
        modelId="amazon.titan-embed-image-v1",
        accept="application/json",
        contentType="application/json",
    )
    # The response "body" is a stream; read it fully before parsing.
    raw_json = reply.get("body").read()
    parsed = json.loads(raw_json)
    return parsed["embedding"]

# MongoDB Atlas location of the celebrity image embeddings.
db_name = "celebrity_1000_embeddings"
collection_name = "celeb_images"

uri = os.environ.get("MONGODB_ATLAS_URI")
# NOTE(review): if MONGODB_ATLAS_URI is unset, uri is None and MongoClient
# presumably falls back to its default host — confirm that is acceptable.
client = MongoClient(uri)

# Collection queried by start_image_search.
celeb_images = client[db_name][collection_name]

def _image_to_base64(image):
    """Return *image* encoded as JPEG, then Base64, as a ``str``."""
    # JPEG cannot store alpha or palette modes (e.g. RGBA, P), and
    # Image.save raises for them; normalize to RGB before encoding.
    if image.mode != "RGB":
        image = image.convert("RGB")
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    # Base64 output is pure ASCII; decode to str so it can be placed in JSON.
    return base64.b64encode(buffered.getvalue()).decode("ascii")


def _decode_images(docs):
    """Decode the Base64 ``image`` field of each result document into a PIL image."""
    return [Image.open(io.BytesIO(base64.b64decode(doc["image"]))) for doc in docs]


def start_image_search(image, text):
    """Return the stored celebrity images most similar to *image*.

    The image (plus optional *text*) is embedded through Bedrock, and the
    resulting vector drives an Atlas ``$vectorSearch`` over ``celeb_images``.

    Args:
        image: PIL image from the upload widget, or ``None`` if nothing was
            uploaded.
        text: Optional text that adjusts the query; falsy values are ignored.

    Returns:
        A list of up to 3 PIL images decoded from the matching documents'
        Base64 ``image`` field.

    Raises:
        gr.Error: if *image* is ``None``, so the message is shown to the user.
    """
    if image is None:
        # Alert the user to upload an image
        raise gr.Error("Please upload an image first, make sure to press the 'Submit' button after selecting the image.")

    body = construct_bedrock_body(_image_to_base64(image), text)
    embedding = get_embedding_from_titan_multimodal(body)

    docs = celeb_images.aggregate([
        {
            "$vectorSearch": {
                "index": "vector_index",
                "path": "embeddings",
                "queryVector": embedding,
                # Candidates considered by the approximate search before the
                # top `limit` results are returned.
                "numCandidates": 15,
                "limit": 3,
            }
        },
        # Only the encoded image is needed downstream.
        {"$project": {"image": 1}},
    ])
    return _decode_images(docs)
    
with gr.Blocks() as demo:
    gr.Markdown(
    """
    # MongoDB's Vector Celeb Image matcher  

    Upload an image and find the most similar celeb image from the database. 

    💪 Make a great pose to impact the search! 🤯
   
    """)

    # Search form: an uploaded image plus optional adjustment text, with the
    # matches returned by start_image_search shown in a gallery.
    gr.Interface(
        fn=start_image_search,
        inputs=[
            gr.Image(type="pil", label="Upload an image"),
            gr.Textbox(label="Enter an adjustment to the image"),
        ],
        outputs=gr.Gallery(
            label="Generated images",
            show_label=True,
            elem_id="gallery",
            columns=[3],
            rows=[1],
            object_fit="contain",
            height="auto",
        ),
    )

demo.launch()