from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering, pipeline
from gtts import gTTS
import easyocr
import numpy as np
import torch
import tempfile
import gradio as gr

app = FastAPI()

# Load the VQA model
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

# Load the image captioning model
captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")

# Load the EasyOCR reader (English, French, Arabic)
reader = easyocr.Reader(['en', 'fr', 'ar'])


def classify_question(question: str):
    """Route a question to OCR, captioning, or VQA based on simple keywords."""
    question_lower = question.lower()
    if any(word in question_lower for word in ["text", "say", "written", "read"]):
        return "ocr"
    elif any(word in question_lower for word in ["caption", "describe", "what is in the image"]):
        return "caption"
    else:
        return "vqa"


def answer_question_from_image(image, question):
    if image is None or not question.strip():
        return "Please upload an image and ask a question.", None

    # The models expect 3-channel RGB input
    image = image.convert("RGB")
    mode = classify_question(question)

    if mode == "ocr":
        try:
            # EasyOCR accepts a file path, bytes, or numpy array, not a PIL image
            result = reader.readtext(np.array(image))
            text = " ".join([entry[1] for entry in result])
            answer = text.strip() or "No readable text found."
        except Exception as e:
            answer = f"OCR error: {e}"
    elif mode == "caption":
        try:
            answer = captioner(image)[0]['generated_text']
        except Exception as e:
            answer = f"Captioning error: {e}"
    else:
        try:
            inputs = vqa_processor(image, question, return_tensors="pt")
            with torch.no_grad():
                outputs = vqa_model(**inputs)
            predicted_id = outputs.logits.argmax(-1).item()
            answer = vqa_model.config.id2label[predicted_id]
        except Exception as e:
            answer = f"VQA error: {e}"

    # Convert the answer to speech; reserve the temp file name first so the
    # handle is closed before gTTS writes to it (required on Windows)
    try:
        tts = gTTS(text=answer)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
            audio_path = tmp.name
        tts.save(audio_path)
    except Exception as e:
        return f"Answer: {answer}\n\n⚠️ Audio generation error: {e}", None

    return answer, audio_path


def process_image_question(image: Image.Image, question: str):
    answer, audio_path = answer_question_from_image(image, question)
    return answer, audio_path


gui = gr.Interface(
    fn=process_image_question,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(lines=2, placeholder="Ask a question about the image...", label="Question")
    ],
    outputs=[
        gr.Textbox(label="Answer", lines=5),
        gr.Audio(label="Answer (Audio)", type="filepath")
    ],
    title="🧠 Image QA with Voice",
    description="Upload an image and ask a question. Works for OCR, captioning, and VQA."
)

# Mount the Gradio UI under its own path; mounting it at "/" would shadow the
# FastAPI root route below and turn the redirect into an infinite loop
GRADIO_PATH = "/gradio"
app = gr.mount_gradio_app(app, gui, path=GRADIO_PATH)


@app.get("/")
def home():
    return RedirectResponse(url=GRADIO_PATH)
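

# Optional local entry point, a minimal sketch rather than part of the original
# deployment: the module name ("app") and the port are assumptions, with 7860
# chosen only because it matches Gradio's default.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)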