# Image-captioning web service: a Gradio UI mounted on a FastAPI app.
# Prefers microsoft/git-large-coco; if that fails to load, falls back to the
# nlpconnect/vit-gpt2-image-captioning "image-to-text" pipeline.
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
from PIL import Image
import torch
from fastapi import FastAPI
from fastapi.responses import RedirectResponse

# Initialize FastAPI
app = FastAPI()

# Preferred captioning model and the lighter fallback used if it cannot load.
GIT_MODEL_ID = "microsoft/git-large-coco"
FALLBACK_MODEL_ID = "nlpconnect/vit-gpt2-image-captioning"

# Load models - prefer the GIT model, fall back to the smaller pipeline.
try:
    processor = AutoProcessor.from_pretrained(GIT_MODEL_ID)
    git_model = AutoModelForCausalLM.from_pretrained(GIT_MODEL_ID)
    print(f"Successfully loaded {GIT_MODEL_ID} model")
    USE_GIT = True
except Exception as e:
    # Any load failure (download, disk, incompatible weights) selects the fallback.
    print(f"Failed to load GIT model: {e}. Falling back to smaller model")
    captioner = pipeline("image-to-text", model=FALLBACK_MODEL_ID)
    USE_GIT = False


def generate_caption(image_path):
    """Generate a caption for the image file at ``image_path``.

    Uses the GIT model when it loaded successfully, otherwise the fallback
    image-to-text pipeline.

    Args:
        image_path: Filesystem path to the image.

    Returns:
        The caption text, or "Could not generate caption" if captioning raised.
    """
    try:
        if USE_GIT:
            # Close the file handle promptly; convert() returns an independent,
            # fully loaded copy in RGB so the processor always sees 3 channels.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            inputs = processor(images=image, return_tensors="pt")
            outputs = git_model.generate(**inputs, max_length=50)
            return processor.batch_decode(outputs, skip_special_tokens=True)[0]
        # The fallback pipeline accepts a path and loads the image itself.
        result = captioner(image_path)
        return result[0]['generated_text']
    except Exception as e:
        # Best-effort: report the failure in the UI instead of erroring out.
        print(f"Caption generation error: {e}")
        return "Could not generate caption"


def process_image(file_path: str):
    """Handle a Gradio upload and return the formatted caption text.

    Args:
        file_path: Path of the uploaded image (Gradio ``type="filepath"``);
            empty/None when nothing was uploaded.

    Returns:
        A user-facing message: the caption, an upload prompt, or an error.
    """
    if not file_path:
        return "Please upload an image first"
    try:
        caption = generate_caption(file_path)
        return f"📷 Image Caption:\n{caption}"
    except Exception as e:
        return f"Error processing image: {str(e)}"


# Gradio Interface
with gr.Blocks(title="Image Captioning Service", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🖼️ Image Captioning Service")
    gr.Markdown("Upload an image to get automatic captioning")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="filepath")
            analyze_btn = gr.Button("Generate Caption", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Caption Result", lines=5)
    analyze_btn.click(
        fn=process_image,
        inputs=[image_input],
        outputs=[output],
    )

# Mount Gradio app to FastAPI at the root path. The UI is served directly at
# "/", so no root redirect route is registered: a "/" -> "/" redirect would be
# shadowed by this mount, and if it were ever matched it would loop forever.
app = gr.mount_gradio_app(app, demo, path="/")