File size: 2,381 Bytes
c91d9f3
c580f5e
 
 
 
b9c7982
c91d9f3
 
 
b9c7982
 
 
c91d9f3
 
 
 
 
 
 
 
 
 
 
 
 
dc3c0b2
c580f5e
 
 
b9c7982
 
c91d9f3
 
 
 
 
b9c7982
c91d9f3
 
c580f5e
b9c7982
 
04d8090
 
b9c7982
 
 
c580f5e
33262af
b9c7982
c91d9f3
b9c7982
 
 
 
c91d9f3
b9c7982
c91d9f3
 
 
 
33262af
 
b9c7982
33262af
 
b9c7982
 
33262af
c91d9f3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import functools
import os
from io import BytesIO

import gradio as gr
import requests
import spaces  # Import the spaces module
import torch
from PIL import Image
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)
from transformers.image_utils import load_image


@functools.lru_cache(maxsize=1)
def load_model():
    """Load the PaliGemma2 processor and model using a Hugging Face token.

    The token is read from the ``HUGGINGFACEHUB_API_TOKEN`` environment
    variable. The model is loaded in bfloat16, moved to CUDA when
    ``torch.cuda.is_available()`` (CPU otherwise) and switched to eval mode.

    A successful result is memoised, so repeated calls within the same
    process reuse the already-loaded objects instead of reloading the
    checkpoint. A failed call (missing token) is not cached, so the token
    can be set and the call retried.

    Returns:
        A ``(processor, model)`` tuple.

    Raises:
        ValueError: If the token environment variable is unset or empty.
    """
    token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
    if not token:
        raise ValueError(
            "Hugging Face API token not found. Please set it in the environment variables."
        )

    model_id = "google/paligemma2-28b-pt-448"

    # `token=` replaces the deprecated `use_auth_token=` keyword.
    processor = PaliGemmaProcessor.from_pretrained(model_id, token=token)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, token=token
    ).to(device).eval()

    return processor, model


@spaces.GPU  # Decorate the function that uses the GPU
def process_image_and_text(image_pil, text_input):
    """Generate text from an image and a text prompt with PaliGemma2.

    Args:
        image_pil: The uploaded image (the UI is configured to deliver a PIL
            image). ``None`` when the user submitted without an image.
        text_input: The text prompt passed to the processor alongside the image.

    Returns:
        The decoded continuation produced by the model, with the prompt
        tokens and special tokens removed.

    Raises:
        gr.Error: If no image was provided.
    """
    # Fail early with a readable message instead of an opaque error from
    # load_image when the image input is empty.
    if image_pil is None:
        raise gr.Error("Please upload an image before submitting.")

    processor, model = load_model()

    # load_image accepts a PIL image directly.
    image = load_image(image_pil)

    # Place the inputs on whatever device the model actually lives on, so the
    # two can never disagree; floating-point inputs are cast to bfloat16, the
    # dtype the model is loaded with.
    model_inputs = processor(text=text_input, images=image, return_tensors="pt").to(
        model.device, dtype=torch.bfloat16
    )
    input_len = model_inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        # Greedy decoding (no sampling), capped at 100 newly generated tokens.
        generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)

    # generate() returns prompt + continuation; keep only the continuation.
    new_tokens = generation[0][input_len:]
    return processor.decode(new_tokens, skip_special_tokens=True)


if __name__ == "__main__":
    # Input widgets: the image is delivered to the handler as a PIL object,
    # the prompt as plain text.
    image_input = gr.Image(type="pil", label="Upload an image")
    prompt_input = gr.Textbox(label="Enter Text Prompt")

    # Output widget showing the text returned by the handler.
    text_output = gr.Textbox(label="Generated Text")

    demo = gr.Interface(
        fn=process_image_and_text,
        inputs=[image_input, prompt_input],
        outputs=text_output,
        title="PaliGemma2 Image and Text to Text",
        description="Upload an image and enter a text prompt. The model will generate text based on both.",
    )
    demo.launch()