import numpy as np
from PIL import Image
from araclip import AraClip
import gradio as gr
# Load the AraClip model once at module import so every search() call reuses it.
# NOTE(review): from_pretrained presumably fetches/loads weights from the
# "Arabic-Clip/araclip" hub repo — confirm against the araclip package docs.
model = AraClip.from_pretrained("Arabic-Clip/araclip")

def search(labels, image):
    """Return the candidate label most similar to *image* under AraClip.

    Parameters
    ----------
    labels : str
        Comma-separated candidate labels, e.g. ``"حصان, كلب, قطة"``.
    image : PIL.Image.Image
        Query image to compare against the labels.

    Returns
    -------
    str
        The best-matching label, or ``""`` when no non-empty label was given.
    """
    # Strip FIRST, then drop empties: the original filtered before stripping,
    # so a whitespace-only segment ("a, ,b") survived as "" and was embedded.
    candidates = [item.strip() for item in labels.split(",")]
    candidates = [item for item in candidates if item]
    if not candidates:
        # Guard: np.stack([]) / np.argmax on an empty array would raise.
        return ""

    # Embed the image and every label, then rank labels by dot-product
    # similarity against the image embedding.
    image_features = model.embed(image=image)
    text_features = np.stack([model.embed(text=label) for label in candidates])
    similarities = text_features @ image_features
    # int(...) converts np.int64 to a plain Python index.
    return candidates[int(np.argmax(similarities))]

# Build the Gradio UI: a comma-separated label textbox plus an image input,
# producing the single most probable label as output.
labels_input = gr.Textbox(label="labels", info="separate labels with ',' ")
image_input = gr.Image(type="pil")
result_output = gr.Textbox(label="most probable label")

demo = gr.Interface(
    fn=search,
    inputs=[labels_input, image_input],
    outputs=[result_output],
    examples=[["حصان, كلب, قطة", "cat.png"]],
    theme="ocean",
    title="AraClip",
)

demo.launch(debug=True)