import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
from PIL import Image
import torch
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
# Initialize FastAPI
app = FastAPI()
# Load models - Using microsoft/git-large-coco
try:
    # Load the better model
    processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
    git_model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
    print("Successfully loaded microsoft/git-large-coco model")
    USE_GIT = True
except Exception as e:
    print(f"Failed to load GIT model: {e}. Falling back to smaller model")
    captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
    USE_GIT = False
def generate_caption(image_path):
    """Generate caption using the best available model"""
    try:
        if USE_GIT:
            image = Image.open(image_path)
            inputs = processor(images=image, return_tensors="pt")
            outputs = git_model.generate(**inputs, max_length=50)
            return processor.batch_decode(outputs, skip_special_tokens=True)[0]
        else:
            result = captioner(image_path)
            return result[0]['generated_text']
    except Exception as e:
        print(f"Caption generation error: {e}")
        return "Could not generate caption"
def process_image(file_path: str):
    """Handle image processing for the Gradio interface"""
    if not file_path:
        return "Please upload an image first"
    try:
        caption = generate_caption(file_path)
        return f"📷 Image Caption:\n{caption}"
    except Exception as e:
        return f"Error processing image: {str(e)}"
# Gradio Interface
with gr.Blocks(title="Image Captioning Service", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🖼️ Image Captioning Service")
    gr.Markdown("Upload an image to get automatic captioning")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="filepath")
            analyze_btn = gr.Button("Generate Caption", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Caption Result", lines=5)
    analyze_btn.click(
        fn=process_image,
        inputs=[image_input],
        outputs=[output]
    )
# Mount Gradio app to FastAPI at the root path
app = gr.mount_gradio_app(app, demo, path="/")

# Helper for redirecting to the interface; because the Gradio UI is already
# mounted at "/", it is not registered as a FastAPI route here.
def redirect_to_interface():
    return RedirectResponse(url="/")
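
# ---------------------------------------------------------------------------
# Local launch sketch (assumption: this file is saved as app.py; the filename
# is not given in the source). Since the Gradio UI is mounted onto the FastAPI
# app, the combined app is served with an ASGI server such as uvicorn rather
# than demo.launch(), e.g.:
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
# Or, equivalently, by running this module directly:
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)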