ImageCaptioning / app.py
krishnv's picture
Update app.py
31e8f8b verified
raw
history blame
1.38 kB
import gradio as gr
from PIL import Image
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizerFast,
    VisionEncoderDecoderModel,
    ViTFeatureExtractor,
    ViTImageProcessor,
)
# Load the captioning model and its matching pre/post-processors.
# NOTE: the previous checkpoint, "microsoft/git-base", is a GIT model rather
# than a VisionEncoderDecoderModel checkpoint, so loading it with
# VisionEncoderDecoderModel failed. The checkpoint below is a ViT-encoder /
# GPT-2-decoder VisionEncoderDecoderModel published for image captioning.
MODEL_ID = "nlpconnect/vit-gpt2-image-captioning"
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)
# ViTImageProcessor supersedes the deprecated ViTFeatureExtractor; the name
# `feature_extractor` is kept because caption_images() refers to it.
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_ID)
# AutoTokenizer instantiates whichever tokenizer class the checkpoint declares.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Define the captioning function
def caption_images(image):
# Preprocess the image
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
# Generate captions
encoder_outputs = model.generate(pixel_values.to('cpu'), num_beams=5)
generated_sentence = tokenizer.batch_decode(encoder_outputs, skip_special_tokens=True)
return generated_sentence[0].strip()
# Define Gradio interface components: one image in, one caption out.
# gr.inputs.* / gr.outputs.* were deprecated in Gradio 3 and removed in
# Gradio 4; the top-level component classes below work in both.
inputs = [
    gr.Image(type="pil", label="Original Image"),
]
outputs = [
    gr.Textbox(label="Caption"),
]
# Define Gradio app properties
title = "Simple Image Captioning Application"
description = "Upload an image to see the caption generated"
# Examples are a list of rows, each row holding one value per input component.
example = [["messi.jpg"]]  # Replace with a valid path to an example image
# Create and launch the Gradio interface
demo = gr.Interface(
    fn=caption_images,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=example,
)
demo.launch(debug=True)