File size: 1,558 Bytes
39113b9
30d5af0
 
 
69aa3f2
39113b9
82eb2a3
d3bd556
30d5af0
 
feb267b
f2e596b
30d5af0
 
39113b9
 
 
 
 
 
9f13edb
30d5af0
 
21edb75
69aa3f2
ad91d57
30d5af0
 
 
 
 
ad91d57
 
 
625973c
ad91d57
 
 
21edb75
2b19e7c
345caa6
30d5af0
39113b9
30d5af0
 
c652364
8e4b98b
 
c652364
5cbcd45
b1e58f3
30d5af0
39113b9
30d5af0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import gradio as gr
import numpy as np
import clip
import torch
from PIL import Image
import base64
from io import BytesIO
from decimal import Decimal

# Pick the compute device first so it can be handed to ``clip.load``
# explicitly; the same ``device`` is reused for every input tensor built in
# ``find_similarity``.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the "ViT-L/14@336px" CLIP checkpoint together with its matching image
# preprocessing transform, then put the model in inference (eval) mode.
model, preprocess = clip.load("ViT-L/14@336px", device=device)
model.to(device)
model.eval()

# Define a function to find similarity
def find_similarity(base64_image, text_input):
    """Return the CLIP cosine similarity between a base64 image and a text prompt.

    The image is decoded, run through the CLIP ``preprocess`` transform and
    encoded with ``model.encode_image``; the prompt is tokenized and encoded
    with ``model.encode_text``. Both embeddings are L2-normalised, so their
    dot product is the cosine similarity.

    Args:
        base64_image: Base64-encoded image bytes. Surrounding whitespace and
            a ``data:<mime>;base64,`` prefix (as copied from a browser data
            URL) are tolerated.
        text_input: Prompt to compare against the image. Prompts longer than
            CLIP's context window are truncated instead of raising.

    Returns:
        The similarity as a plain Python ``float``, or ``None`` when either
        input is missing or blank. The UI runs with ``live=True``, so this is
        called while the user is still filling in the boxes.

    Raises:
        Errors from ``base64.b64decode`` (malformed base64) or
        ``PIL.Image.open`` (bytes that are not an image) propagate unchanged.
    """
    payload = (base64_image or "").strip()
    # A data URL's "data:image/png;base64," header is not base64 payload;
    # b64decode would silently keep its alphabet characters (letters, digits,
    # '/') and corrupt the decoded bytes, so drop everything up to the comma.
    if payload.startswith("data:"):
        payload = payload.partition(",")[2]
    if not payload or not (text_input or "").strip():
        return None

    image_bytes = base64.b64decode(payload)
    # Preprocess inside the ``with`` so Pillow has finished reading the image
    # before it is closed.
    with Image.open(BytesIO(image_bytes)) as image:
        image_tensor = preprocess(image).unsqueeze(0).to(device)

    # truncate=True: without it clip.tokenize raises for prompts that exceed
    # the model's context length.
    text_tokens = clip.tokenize([text_input], truncate=True).to(device)

    with torch.no_grad():
        image_features = model.encode_image(image_tensor)
        text_features = model.encode_text(text_tokens)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # Batch of one on each side -> a single-element result; .item() turns
        # it into a Python float, which is what the "number" output expects.
        similarity = (text_features @ image_features.T).item()

    return similarity

# Web UI: a multi-line box for the base64 image payload and a single-line box
# for the prompt both feed ``find_similarity``; its result is shown as a
# number. ``live=True`` re-runs the comparison whenever an input changes.
image_box = gr.Textbox(label="Base64 Image", lines=8)
prompt_box = gr.Textbox(label="Text Input")

iface = gr.Interface(
    fn=find_similarity,
    inputs=[image_box, prompt_box],
    outputs="number",
    title="CLIP Model Image-Text Cosine Similarity",
    description="Upload a base64 image and enter text to find their cosine similarity.",
    live=True,
)

# Start serving the interface.
iface.launch()