File size: 2,692 Bytes
2ba0248
 
80a9af6
2ba0248
 
 
 
 
 
 
15825cc
2ba0248
 
e61c1cd
 
67f296e
df7325b
67f296e
2ba0248
e61c1cd
2ba0248
9b516eb
129d16e
 
 
9b516eb
129d16e
9b516eb
129d16e
2ba0248
 
c8ce5f0
2ba0248
 
5e2a1a2
9b516eb
2ba0248
 
 
 
 
64ed02c
 
 
2ba0248
 
 
 
 
e61c1cd
2ba0248
e61c1cd
2ba0248
 
35a385f
 
 
 
 
 
398493c
35a385f
e61c1cd
 
 
2ba0248
 
e61c1cd
2ba0248
 
 
 
 
 
 
 
 
398493c
35a385f
2ba0248
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
from sentence_transformers import SentenceTransformer
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image
import torch
from torchvision import io
from typing import Dict
from datetime import datetime
import numpy as np
import base64
import os, stat, io

# Local directory that holds both the SmolVLM-500M-Instruct model weights and
# its processor configuration (both loaders below read from it).
VLM_PATH = "./SmolVLM-500M-Instruct"

# Load the vision-language model in full precision (float32) on the CPU.
# The eager attention implementation is requested explicitly rather than
# letting transformers choose one.
model = AutoModelForVision2Seq.from_pretrained(
    VLM_PATH,
    torch_dtype=torch.float32,
    _attn_implementation="eager",
    device_map="cpu",
)
processor = AutoProcessor.from_pretrained(VLM_PATH)

def array_to_image(image_array):
    """Turn raw uploaded pixel data into an RGB ``PIL.Image``.

    The data is cast to ``uint8`` before conversion, and the result is always
    converted to RGB mode.

    Raises:
        ValueError: If ``image_array`` is ``None`` (nothing was uploaded).
    """
    if image_array is not None:
        pixels = np.uint8(image_array)
        return Image.fromarray(pixels).convert("RGB")
    raise ValueError("No image provided. Please upload an image before submitting.")

# Local directory holding the all-MiniLM-L6-v2 sentence-transformer.
_EMBEDDING_MODEL_PATH = "./all-MiniLM-L6-v2"
# Shared embedder instance; populated on first use by _get_embedding_model().
_embedding_model = None


def _get_embedding_model():
    """Return the shared SentenceTransformer, constructing it on the first call only."""
    global _embedding_model
    if _embedding_model is None:
        # Build once and reuse: constructing the model loads it from disk,
        # which is far too slow to repeat on every request. If two requests
        # race on the very first call, each may load it once and the last
        # assignment wins -- wasteful but harmless.
        _embedding_model = SentenceTransformer(_EMBEDDING_MODEL_PATH)
    return _embedding_model


def generate_embeddings(text):
    """Encode ``text`` with the local all-MiniLM-L6-v2 sentence-transformer.

    Args:
        text: The text to embed (a string, or a list of strings).

    Returns:
        The result of ``SentenceTransformer.encode(text)`` with default options
        -- presumably a numpy array (1-D for a single string, 2-D for a list);
        confirm against the installed sentence-transformers version.
    """
    return _get_embedding_model().encode(text)

def describe_image(image_array):
    """Describe an uploaded image in detail and embed that description.

    This is the handler passed to ``gr.Interface`` as ``fn``. ``image_array``
    is whatever the ``gr.Image()`` input delivers (presumably a numpy array,
    the component's default -- confirm if the component is reconfigured).

    Args:
        image_array: The uploaded image data, or ``None`` if nothing was uploaded.

    Returns:
        A ``(description, embedding)`` tuple: the generated description as a
        single string (for the Description textbox) and its sentence embedding
        as a plain list of floats (for the Embeddings JSON view).

    Raises:
        ValueError: If no image was provided (raised by ``array_to_image``).
    """
    image = array_to_image(image_array)

    # Single user turn: an image slot followed by the instruction text. The
    # actual image is handed to the processor separately via ``images=``.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Make a very detailed description of the image."},
            ],
        }
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[image], return_tensors="pt")

    # Greedy decoding: no beam search and no sampling.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=500,
            num_beams=1,
            do_sample=False,
        )

    # Each generated row starts with the prompt tokens; slice them off so only
    # the newly generated answer is decoded.
    answer_ids = [
        sequence[len(prompt_ids):]
        for prompt_ids, sequence in zip(inputs.input_ids, generated_ids)
    ]
    decoded = processor.batch_decode(
        answer_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    # batch_decode yields one string per batch row and this handler sends a
    # single image, so unwrap it: the Textbox should show the text itself, not
    # a list repr, and the embedder should receive one string, not a batch.
    description = decoded[0]

    # Hand gr.JSON plain JSON-native data rather than an array object.
    embedding = np.asarray(generate_embeddings(description)).tolist()
    return description, embedding

# Gradio UI: a single image input feeding describe_image, whose two return
# values are shown in a description textbox and a JSON embeddings view.
# Components are created in the same order as before: input first, then outputs.
image_input = gr.Image()
description_output = gr.Textbox(label="Description")
embeddings_output = gr.JSON(label="Embeddings")

iface = gr.Interface(
    fn=describe_image,
    inputs=image_input,
    outputs=[description_output, embeddings_output],
    title="Image Description with SmolVLM-500M-Instruct and Textual embeddings with all-MiniLM-L6-v2",
    description="Upload an image to get a detailed description using the SmolVLM-500M-Instruct model.",
)

# Start the Gradio server. launch() also accepts share=True if a shareable
# link is wanted.
iface.launch()