import functools

import gradio as gr
import torch
from PIL import Image

from model import GitBaseCocoModel

# Checkpoint identifier handed to GitBaseCocoModel.
CHECKPOINT = "microsoft/git-base-coco"


@functools.lru_cache(maxsize=1)
def _load_model(device: str, checkpoint: str) -> GitBaseCocoModel:
    """
    Builds the captioning model once and reuses it on later calls.
    -----
    Parameters:
        device: str
            The device string passed to GitBaseCocoModel ("cuda" or "cpu").
        checkpoint: str
            The checkpoint identifier passed to GitBaseCocoModel.
    -----
    Returns:
        GitBaseCocoModel
    """
    # NOTE(review): assumes GitBaseCocoModel keeps no per-request state, so a
    # single instance can safely serve every request -- confirm in model.py.
    return GitBaseCocoModel(device, checkpoint)


def generate_captions(
    image: Image.Image,
    max_len: int,
    num_captions: int,
) -> str:
    """
    Generates captions for the given image.
    -----
    Parameters:
        image: PIL.Image
            The image to generate captions for.
        max_len: int
            The maximum length of the caption.
        num_captions: int
            The number of captions to generate.
    -----
    Returns:
        str
            The generated captions joined into one string, one per line.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # The model arguments never change between requests, so the instance is
    # fetched from a cache instead of being rebuilt on every call.
    model = _load_model(device, CHECKPOINT)

    captions = model.generate(image, max_len, num_captions)

    # Convert list to a single string separated by newlines.
    return "\n".join(captions)


title = "Git-Base-COCO Image Captioning"
description = "A model for generating captions for images."

interface = gr.Interface(
    fn=generate_captions,
    inputs=[
        gr.inputs.Image(type="pil", label="Image"),
        gr.inputs.Slider(
            minimum=20,
            maximum=100,
            step=5,
            default=50,
            label="Maximum Caption Length",
        ),
        gr.inputs.Slider(
            minimum=1,
            maximum=10,
            step=1,
            default=1,
            label="Number of Captions to Generate",
        ),
    ],
    outputs=[
        gr.outputs.Textbox(label="Caption"),
    ],
    title=title,
    description=description,
)


if __name__ == "__main__":
    interface.launch(
        enable_queue=True,
        debug=True,
    )