# Hugging Face Space (status shown on the page at capture time: Sleeping)
import gradio as gr | |
from transformers import ( | |
PaliGemmaProcessor, | |
PaliGemmaForConditionalGeneration, | |
) | |
from transformers.image_utils import load_image | |
import torch | |
import os | |
import spaces # Import the spaces module | |
import requests | |
from io import BytesIO | |
from PIL import Image | |
# Already-loaded (processor, model) pairs keyed by model id, so the checkpoint
# is instantiated once per process rather than on every call to load_model().
_MODEL_CACHE = {}


def load_model():
    """Load the PaliGemma2 processor and model using a Hugging Face token.

    The token is read from the ``HUGGINGFACEHUB_API_TOKEN`` environment
    variable. The model is loaded in bfloat16, moved to CUDA when available
    (CPU otherwise) and put in eval mode. The loaded pair is cached, so
    repeated calls return the same objects without reloading the weights.

    Returns:
        tuple: ``(processor, model)``.

    Raises:
        ValueError: If the token environment variable is unset or empty.
    """
    token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
    if not token:
        raise ValueError(
            "Hugging Face API token not found. Please set it in the environment variables."
        )

    model_id = "google/paligemma2-28b-pt-896"
    cached = _MODEL_CACHE.get(model_id)
    if cached is not None:
        return cached

    # `token=` is the current keyword; `use_auth_token=` is deprecated.
    processor = PaliGemmaProcessor.from_pretrained(model_id, token=token)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = (
        PaliGemmaForConditionalGeneration.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, token=token
        )
        .to(device)
        .eval()
    )

    _MODEL_CACHE[model_id] = (processor, model)
    return processor, model
# NOTE(review): the previous comment here said to decorate this function for
# GPU use, but no decorator is applied (the file imports `spaces`). Confirm
# whether `@spaces.GPU` is needed for the hardware this Space runs on.
def process_image_and_text(image_pil, text_input):
    """Generate text from an image and a text prompt using PaliGemma2.

    Args:
        image_pil: The uploaded image as a PIL image, or ``None`` when
            nothing was uploaded.
        text_input: The text prompt passed to the processor alongside the
            image.

    Returns:
        str: The decoded newly generated tokens (the prompt tokens are
        sliced off), with special tokens skipped.

    Raises:
        ValueError: If no image was provided, or (from ``load_model``) if the
            Hugging Face token is missing.
    """
    # Fail early with a clear message instead of an AttributeError further down.
    if image_pil is None:
        raise ValueError("No image provided. Please upload an image.")

    processor, model = load_model()

    # Pass the PIL image straight to load_image (it accepts PIL images per the
    # transformers docs). The former JPEG round-trip handed it raw bytes and
    # also could not encode modes JPEG does not support, such as RGBA.
    image = load_image(image_pil)

    # Put the inputs on whatever device the model actually lives on.
    model_inputs = processor(text=text_input, images=image, return_tensors="pt").to(
        model.device, dtype=torch.bfloat16
    )
    input_len = model_inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)

    # Keep only the tokens generated after the prompt.
    new_tokens = generation[0][input_len:]
    return processor.decode(new_tokens, skip_special_tokens=True)
if __name__ == "__main__":
    # Build the UI components first, then wire them to the inference function.
    image_input = gr.Image(type="pil", label="Upload an image")
    prompt_input = gr.Textbox(label="Enter Text Prompt")
    result_output = gr.Textbox(label="Generated Text")

    demo = gr.Interface(
        fn=process_image_and_text,
        inputs=[image_input, prompt_input],
        outputs=result_output,
        title="PaliGemma2 Image and Text to Text",
        description="Upload an image and enter a text prompt. The model will generate text based on both.",
    )
    # Start the Gradio server.
    demo.launch()