"""Image-captioning API.

Exposes two endpoints:

* ``POST /imagecaption/`` - caption an uploaded image and synthesize the
  caption to an MP3 file with gTTS.
* ``GET /files/{filename}`` - download a previously generated MP3.
"""

import os
import shutil
import tempfile

import torch
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from gtts import gTTS
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline

app = FastAPI()

# CORS Configuration
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# very permissive policy; confirm it is intended before deploying publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize models: prefer GIT, fall back to the ViT-GPT2 pipeline if GIT
# cannot be loaded.
try:
    processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
    git_model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
    git_model.eval()
    USE_GIT = True
except Exception as e:
    print(f"[INFO] Falling back to ViT: {e}")
    captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
    USE_GIT = False

# Accepted upload content types mapped to the suffix used for the temp file.
# The suffix is derived from the validated content type rather than from the
# client-supplied filename, which may be missing or contain characters that
# are unsafe in a path or URL.
_SUFFIX_BY_CONTENT_TYPE = {
    'image/jpeg': '.jpg',
    'image/png': '.png',
    'image/gif': '.gif',
    'image/webp': '.webp',
}

# Naming scheme of the generated audio files; get_file only serves files that
# follow it.
_AUDIO_PREFIX = "caption_"
_AUDIO_SUFFIX = ".mp3"


def generate_caption(image_path: str) -> str:
    """Return a caption for the image stored at *image_path*.

    Uses the GIT model when it was loaded at import time, otherwise the
    fallback image-to-text pipeline.

    Raises:
        RuntimeError: if loading the image or running the model fails; the
            original exception is attached as the cause.
    """
    try:
        if USE_GIT:
            # Close the underlying file handle as soon as the pixels have been
            # converted; convert() returns an independent image.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")
            inputs = processor(images=image, return_tensors="pt")
            outputs = git_model.generate(**inputs, max_length=50)
            caption = processor.batch_decode(outputs, skip_special_tokens=True)[0]
        else:
            result = captioner(image_path)
            caption = result[0]['generated_text']
        return caption
    except Exception as e:
        raise RuntimeError(f"Caption generation failed: {str(e)}") from e


@app.post("/imagecaption/")
async def caption_image(file: UploadFile = File(...)):
    """Caption an uploaded image and synthesize the caption as speech.

    Returns a dict with the caption under ``"answer"`` and the download path
    of the generated MP3 under ``"audio"``.

    Raises:
        HTTPException: 400 for an unsupported content type, 500 when
            captioning or speech synthesis fails.
    """
    # Validate file type
    suffix = _SUFFIX_BY_CONTENT_TYPE.get(file.content_type)
    if suffix is None:
        raise HTTPException(
            status_code=400,
            detail="Please upload a valid image (JPEG, PNG, GIF, or WEBP)"
        )

    temp_path = None
    try:
        # Save temp file
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp:
            temp_path = temp.name
            shutil.copyfileobj(file.file, temp)

        # Generate caption
        # NOTE(review): captioning and TTS are blocking calls made from an
        # async handler; consider offloading them to a worker thread.
        caption = generate_caption(temp_path)

        # Generate audio
        # NOTE(review): generated MP3 files are never removed; a cleanup
        # strategy is needed for long-running deployments.
        audio_name = f"{_AUDIO_PREFIX}{os.path.basename(temp_path)}{_AUDIO_SUFFIX}"
        audio_path = os.path.join(tempfile.gettempdir(), audio_name)
        tts = gTTS(text=caption)
        tts.save(audio_path)

        return {
            "answer": caption,
            "audio": f"/files/{audio_name}"
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=str(e)
        )
    finally:
        # Always remove the uploaded image, whether or not captioning worked.
        if temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)


@app.get("/files/{filename}")
async def get_file(filename: str):
    """Serve a generated caption MP3 from the system temp directory.

    Only plain file names matching the ``caption_*.mp3`` scheme produced by
    ``caption_image`` are served, so the endpoint cannot be used to read
    other files that happen to live in the temp directory.

    Raises:
        HTTPException: 404 when the name is not allowed or no such file exists.
    """
    is_plain_name = os.path.basename(filename) == filename
    is_caption_audio = (
        filename.startswith(_AUDIO_PREFIX) and filename.endswith(_AUDIO_SUFFIX)
    )
    if is_plain_name and is_caption_audio:
        file_path = os.path.join(tempfile.gettempdir(), filename)
        if os.path.isfile(file_path):
            return FileResponse(file_path)
    raise HTTPException(status_code=404, detail="File not found")