File size: 4,432 Bytes
d19fddf
 
 
 
 
 
 
 
 
 
 
20dc8e8
 
 
 
 
 
2d8e3da
d19fddf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20dc8e8
 
 
 
 
 
 
 
 
 
 
 
d19fddf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7099625
d19fddf
 
 
 
 
 
 
72b6d2f
d19fddf
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import csv
import functools
import pickle

import gradio as gr
import numpy as np
import pandas as pd
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import CLIPProcessor, CLIPModel
from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor

# RCLIP: identifier handed to ``from_pretrained`` and the pickle file that
# holds the matching image embeddings.
model_path_rclip, embeddings_file_rclip = (
    "kaveh/rclip",
    "./image_embeddings_rclip.pkl",
)

# PubMedCLIP: same pair of settings for the alternative model.
model_path_pubmedclip, embeddings_file_pubmedclip = (
    "flaviagiammarino/pubmed-clip-vit-base-patch32",
    "./image_embeddings_pubmedclip.pkl",
)

# Tab-separated file read by ``load_image_ids`` (column 0: image id,
# column 1: caption).
csv_path = "./captions.txt"

def load_image_ids(csv_file):
    """Read image ids and captions from a tab-separated file.

    Every non-empty row must carry the image id in column 0 and its caption
    in column 1; any further columns are ignored. Blank lines are skipped.

    Args:
        csv_file: Path of the tab-separated file to read (decoded as UTF-8).

    Returns:
        A tuple ``(ids, captions)`` of two equally long lists of strings,
        in file order.

    Raises:
        IndexError: If a non-empty row has fewer than two columns.
    """
    ids = []
    captions = []
    # newline="" lets the csv module do its own line-ending handling, and the
    # explicit encoding keeps decoding independent of the platform locale.
    with open(csv_file, newline="", encoding="utf-8") as f:
        for row in csv.reader(f, delimiter="\t"):
            if not row:
                # csv.reader yields [] for a blank line; indexing it would
                # raise IndexError.
                continue
            ids.append(row[0])
            captions.append(row[1])
    return ids, captions

def load_embeddings(embeddings_file):
    """Unpickle and return the object stored in ``embeddings_file``.

    NOTE(review): ``pickle.load`` can execute arbitrary code while
    deserialising, so this must only ever be pointed at trusted files.
    """
    with open(embeddings_file, mode="rb") as handle:
        return pickle.load(handle)


def find_similar_images(query_embedding, image_embeddings, k=2):
    """Return the ``k`` entries most cosine-similar to a query embedding.

    Args:
        query_embedding: Array-like query vector; it is reshaped to a single
            row before the comparison.
        image_embeddings: Candidate embeddings, passed unchanged as the second
            argument of ``cosine_similarity`` (presumably one row per image —
            confirm against how the pickle files were produced).
        k: Maximum number of results to return.

    Returns:
        A tuple ``(closest_indices, scores)``: an array of row indices into
        ``image_embeddings`` ordered by decreasing similarity, and a list with
        the similarity of each of those rows in the same order.
    """
    query_row = np.asarray(query_embedding).reshape(1, -1)
    similarities = cosine_similarity(query_row, image_embeddings)[0]
    closest_indices = np.argsort(similarities)[::-1][:k]
    # Read the scores through the selected indices instead of sorting the
    # similarities a second time: one sort fewer, and each score is guaranteed
    # to belong to the index it is paired with.
    scores = list(similarities[closest_indices])
    return closest_indices, scores


@functools.lru_cache(maxsize=None)
def _load_model_assets(model_id):
    """Load the model, processor and image embeddings for ``model_id`` once.

    Results are memoised per ``model_id`` so repeated queries do not reload
    the model or re-read the embeddings file.

    Raises:
        ValueError: If ``model_id`` is neither ``"rclip"`` nor ``"pubmedclip"``.
    """
    if model_id == "rclip":
        model = VisionTextDualEncoderModel.from_pretrained(model_path_rclip)
        processor = VisionTextDualEncoderProcessor.from_pretrained(model_path_rclip)
        embeddings_file = embeddings_file_rclip
    elif model_id == "pubmedclip":
        model = CLIPModel.from_pretrained(model_path_pubmedclip)
        processor = CLIPProcessor.from_pretrained(model_path_pubmedclip)
        embeddings_file = embeddings_file_pubmedclip
    else:
        raise ValueError(
            f"Unknown model_id {model_id!r}; expected 'rclip' or 'pubmedclip'"
        )
    return model, processor, load_embeddings(embeddings_file)


def main(query, model_id="rclip", k=2):
    """Retrieve the images whose embeddings best match a text query.

    Args:
        query: Text to search for; it is encoded with the selected model's
            text tower.
        model_id: Which model and embedding file to use, ``"rclip"`` or
            ``"pubmedclip"``.
        k: Number of results requested; converted with ``int``.

    Returns:
        A tuple ``(images, table)``. ``images`` is a list of PIL images opened
        from ``./images/<id>.jpg``. ``table`` is a DataFrame with the columns
        ``#``, ``path``, ``caption`` and ``score``, one row per result.

    Raises:
        ValueError: If ``model_id`` is not a supported model name.
    """
    # Convert once so every later use agrees on the same integer.
    k = int(k)
    model, processor, image_embeddings = _load_model_assets(model_id)

    # Embed the query
    inputs = processor(text=query, images=None, return_tensors="pt", padding=True)
    with torch.no_grad():
        query_embedding = model.get_text_features(**inputs)[0].numpy()

    # Get image names
    ids, captions = load_image_ids(csv_path)

    # Find similar images
    similar_image_indices, scores = find_similar_images(
        query_embedding, image_embeddings, k=k
    )

    # Build the results
    similar_image_names = [f"./images/{ids[index]}.jpg" for index in similar_image_indices]
    similar_image_captions = [captions[index] for index in similar_image_indices]
    similar_images = [Image.open(path) for path in similar_image_names]

    # Number the rows from what was actually found, which can be fewer than k
    # when there are fewer than k embeddings.
    results = pd.DataFrame(
        {
            "#": list(range(1, len(similar_image_names) + 1)),
            "path": similar_image_names,
            "caption": similar_image_captions,
            "score": scores,
        }
    )
    return similar_images, results


# Define the Gradio interface
# Heading and tagline rendered as Markdown at the top of the page.
title = "RCLIP Image Retrieval"
description = "CLIP model fine-tuned on the ROCO dataset"

# Example inputs offered in the UI: [query text, number of results].
examples = [
    [prompt, 5]
    for prompt in (
        "Chest X-ray photos",
        "Orthopantogram (OPG)",
        "Brain Scan",
        "tomography",
    )
]

def _search_from_ui(query_text, top_k):
    """Adapt the UI inputs (query text, slider value) to ``main``.

    Gradio passes input values positionally, so wiring ``main`` directly to
    ``[query, n_s]`` would hand the slider value to ``model_id`` rather than
    ``k``. This wrapper forwards it as ``k`` explicitly.
    """
    return main(query_text, k=int(top_k))


with gr.Blocks(title=title) as demo:
    # Header: title and description on the left, logo on the right.
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown("# " + title)
            gr.Markdown(description)
        # The logo is hot-linked from an external host.
        gr.HTML(
            value='<img src="https://newresults.co.uk/wp-content/uploads/2022/02/teesside-university-logo.png" alt="teesside logo" width="120" height="70">',
            show_label=False,
            scale=1,
        )

    # Search controls: query box, search button and result-count slider.
    with gr.Column(variant="compact"):
        with gr.Row(variant="compact"):
            query = gr.Textbox(
                value="Chest X-Ray Photos",
                label="Enter your query",
                show_label=False,
                placeholder="Enter your query",
                scale=5,
            )
            btn = gr.Button("Search query", variant="primary", scale=1)

        n_s = gr.Slider(2, 10, label='Number of Top Results', value=5, step=1.0, show_label=True)

    # Results: image gallery plus a table describing each hit.
    with gr.Column(variant="compact"):
        gr.Markdown("## Results")
        gallery = gr.Gallery(
            label="found images",
            show_label=True,
            elem_id="gallery",
            columns=[2],
            rows=[4],
            object_fit="contain",
            height="400px",
            preview=True,
        )
        gr.Markdown("Information of the found images")
        df = gr.DataFrame()
    btn.click(_search_from_ui, [query, n_s], [gallery, df])

    with gr.Column(variant="compact"):
        gr.Markdown("## Examples")
        gr.Examples(examples, [query, n_s])


# ``debug`` takes a boolean; the string 'True' only worked because it is truthy.
demo.launch(debug=True)