import gradio as gr
from transformers import AutoProcessor, Gemma3nForConditionalGeneration
from PIL import Image
import requests
import torch
import io
import os
from huggingface_hub import login
# Authenticate with the Hugging Face Hub; the gated Gemma checkpoint requires a valid token.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
# Initialize the model and processor
model_id = "google/gemma-3n-e4b-it"
try:
model = Gemma3nForConditionalGeneration.from_pretrained(
model_id, device_map="auto", torch_dtype=torch.bfloat16
).eval()
processor = AutoProcessor.from_pretrained(model_id)
except Exception as e:
    raise RuntimeError(f"Failed to load model or processor: {e}") from e
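
# Note (assumption): device_map="auto" with torch.bfloat16 expects an accelerator with
# bf16 support; on CPU-only hardware the dtype may need to be dropped or set to float32.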
def process_inputs(image_input, image_url, text_prompt):
"""
Process image (from file or URL) and text prompt to generate a response using the Gemma model.
Args:
image_input: Uploaded image file
image_url: URL of an image
text_prompt: Text input from the user
Returns:
Generated text response from the model
"""
try:
# Handle image input: prioritize uploaded image, then URL, then None
image = None
if image_input is not None:
image = Image.open(image_input).convert("RGB")
        elif image_url:
            # Fetch the image over HTTP with a timeout so a bad URL cannot hang the request.
            response = requests.get(image_url, timeout=20)
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content)).convert("RGB")
# Prepare messages for the model
messages = [
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant."}]
},
{
"role": "user",
"content": []
}
]
# Add image to content if provided
if image is not None:
messages[1]["content"].append({"type": "image", "image": image})
# Add text prompt if provided
if text_prompt:
messages[1]["content"].append({"type": "text", "text": text_prompt})
else:
return "Please provide a text prompt."
# Process inputs using the processor
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(model.device)
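        # Keep the prompt length so only the newly generated tokens are decoded below
        # (model.generate returns the prompt tokens followed by the completion).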
input_len = inputs["input_ids"].shape[-1]
# Generate response
with torch.inference_mode():
generation = model.generate(**inputs, max_new_tokens=500, do_sample=False)
generation = generation[0][input_len:]
# Decode and return the response
decoded = processor.decode(generation, skip_special_tokens=True)
return decoded
except Exception as e:
return f"Error: {str(e)}"
# Define the Gradio interface
iface = gr.Interface(
fn=process_inputs,
inputs=[
        gr.Image(type="filepath", label="Upload Image (optional)"),
gr.Textbox(label="Image URL (optional)", placeholder="Enter image URL"),
gr.Textbox(label="Text Prompt", placeholder="Enter your prompt here")
],
outputs=gr.Textbox(label="Model Response"),
title="Gemma-3 Multimodal App (Authenticated)",
description="Upload an image or provide an image URL, and enter a text prompt to interact with the Gemma-3 model. Ensure you have authenticated with a valid Hugging Face access token.",
allow_flagging="never"
)
# Launch the app
iface.launch()