File size: 3,003 Bytes
d5d3aa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1795a1a
d5d3aa6
 
 
 
 
 
 
 
 
1795a1a
d5d3aa6
 
 
 
 
 
 
889642a
d5d3aa6
 
 
 
 
889642a
1795a1a
d5d3aa6
889642a
d5d3aa6
 
 
1795a1a
d5d3aa6
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import shutil
import tempfile

import torch
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse
from gtts import gTTS
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline

app = FastAPI()

# CORS: accept cross-origin requests from any origin, using any HTTP method
# and any request header.
# NOTE(review): combining a wildcard origin with allow_credentials=True is
# discouraged by the CORS spec and the Starlette docs; confirm this is
# intended before adding any cookie- or auth-based endpoints.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Initialize models.
# Load the captioning backend once, at import time. Microsoft's GIT model is
# tried first; if any step of loading it raises, fall back to the ViT-GPT2
# "image-to-text" pipeline. USE_GIT records which backend succeeded, and
# `captioner` is only defined on the fallback path.
USE_GIT = False
try:
    processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
    git_model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
    git_model.eval()  # switch the model to inference mode
    USE_GIT = True
except Exception as load_error:
    print(f"[INFO] Falling back to ViT: {load_error}")
    captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")

def generate_caption(image_path: str) -> str:
    """Return a caption for the image stored at *image_path*.

    Uses the GIT model (``processor`` + ``git_model``) when ``USE_GIT`` is
    true, otherwise the ViT-GPT2 ``captioner`` pipeline.

    Raises:
        Exception: "Caption generation failed: ..." wrapping any error from
            image loading or inference; the original error is kept as
            ``__cause__`` so tracebacks still show where it came from.
    """
    try:
        if USE_GIT:
            # Use a context manager so the file handle is released right away:
            # Pillow can keep it open for multi-frame formats (GIF/WebP),
            # which would e.g. block deleting the temp file on Windows.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            inputs = processor(images=image, return_tensors="pt")
            outputs = git_model.generate(**inputs, max_length=50)
            return processor.batch_decode(outputs, skip_special_tokens=True)[0]
        result = captioner(image_path)
        return result[0]['generated_text']
    except Exception as e:
        raise Exception(f"Caption generation failed: {str(e)}") from e

def _text_to_mp3(text: str, path: str) -> None:
    """Synthesize *text* to speech with gTTS and save it as an MP3 at *path*."""
    gTTS(text=text).save(path)


@app.post("/imagecaption/")
async def caption_image(file: UploadFile = File(...)):
    """Caption an uploaded image and synthesize the caption as speech.

    The upload is copied to a temporary file and captioned with
    ``generate_caption``. The caption is then converted to an MP3 with gTTS.
    The MP3 is left in the system temp directory so the client can download
    it through the ``/files/{filename}`` route.

    Returns:
        ``{"answer": <caption>, "audio": "/files/<mp3 file name>"}``.

    Raises:
        HTTPException: 400 for an unsupported content type; 500 for any
            failure while saving, captioning or synthesizing.
    """
    # Validate file type
    valid_types = ['image/jpeg', 'image/png', 'image/gif', 'image/webp']
    if file.content_type not in valid_types:
        raise HTTPException(
            status_code=400,
            detail="Please upload a valid image (JPEG, PNG, GIF, or WEBP)"
        )

    # UploadFile.filename can be None; use no suffix in that case instead of
    # letting os.path.splitext raise a TypeError.
    suffix = os.path.splitext(file.filename or "")[1]
    temp_path = None
    try:
        # Record the path before copying so the finally-block can remove the
        # temp file even when the copy itself fails.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp:
            temp_path = temp.name
            shutil.copyfileobj(file.file, temp)

        # Model inference and gTTS (which calls a remote service) are
        # blocking; run them in the threadpool so they don't stall the event
        # loop for every other request.
        caption = await run_in_threadpool(generate_caption, temp_path)

        audio_path = os.path.join(tempfile.gettempdir(), f"caption_{os.path.basename(temp_path)}.mp3")
        await run_in_threadpool(_text_to_mp3, caption, audio_path)
        # NOTE(review): generated MP3s are never deleted and accumulate in the
        # temp directory; consider a cleanup policy.

        return {
            "answer": caption,
            "audio": f"/files/{os.path.basename(audio_path)}"
        }

    except HTTPException:
        raise
    except Exception as e:
        # NOTE(review): str(e) exposes internal error text to clients.
        raise HTTPException(
            status_code=500,
            detail=str(e)
        ) from e
    finally:
        if temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)

@app.get("/files/{filename}")
async def get_file(filename: str):
    """Serve a caption MP3 previously written to the system temp directory.

    Only names of the form ``caption_*.mp3`` with no path separators are
    served; these are the files ``/imagecaption/`` produces.

    Raises:
        HTTPException: 404 if the name is not an allowed caption file or the
            file does not exist.
    """
    # The temp directory is shared with other processes and users; without
    # this filter any client could read arbitrary files stored there.
    is_caption_audio = (
        os.path.basename(filename) == filename
        and filename.startswith("caption_")
        and filename.endswith(".mp3")
    )
    if not is_caption_audio:
        raise HTTPException(status_code=404, detail="File not found")

    file_path = os.path.join(tempfile.gettempdir(), filename)
    # isfile (not exists) so directories are never handed to FileResponse.
    if os.path.isfile(file_path):
        return FileResponse(file_path)
    raise HTTPException(status_code=404, detail="File not found")