ikraamkb committed on
Commit 4a81c80 · verified · 1 Parent(s): 974f8bb

Update app.py

Files changed (1)
  1. app.py +40 -92
app.py CHANGED
@@ -1,70 +1,3 @@
- """from fastapi import FastAPI, UploadFile, Form
- from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
- import os
- import shutil
- from PIL import Image
- from transformers import ViltProcessor, ViltForQuestionAnswering
- from gtts import gTTS
- import torch
- import tempfile
- import gradio as gr
-
-
-
- app = FastAPI()
-
- # Load VQA Model
- vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
- vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
-
-
- def answer_question_from_image(image, question):
-     if image is None or not question.strip():
-         return "Please upload an image and ask a question.", None
-
-     # Process with model
-     inputs = vqa_processor(image, question, return_tensors="pt")
-     with torch.no_grad():
-         outputs = vqa_model(**inputs)
-     predicted_id = outputs.logits.argmax(-1).item()
-     answer = vqa_model.config.id2label[predicted_id]
-
-     # Generate TTS audio
-     try:
-         tts = gTTS(text=answer)
-         with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
-             tts.save(tmp.name)
-             audio_path = tmp.name
-     except Exception as e:
-         return f"Answer: {answer}\n\n⚠️ Audio generation error: {e}", None
-
-     return answer, audio_path
-
-
- def process_image_question(image: Image.Image, question: str):
-     answer, audio_path = answer_question_from_image(image, question)
-     return answer, audio_path
-
-
- gui = gr.Interface(
-     fn=process_image_question,
-     inputs=[
-         gr.Image(type="pil", label="Upload Image"),
-         gr.Textbox(lines=2, placeholder="Ask a question about the image...", label="Question")
-     ],
-     outputs=[
-         gr.Textbox(label="Answer", lines=5),
-         gr.Audio(label="Answer (Audio)", type="filepath")
-     ],
-     title="🧠 Image QA with Voice",
-     description="Upload an image and ask a question. You'll get a text + spoken answer."
- )
-
- app = gr.mount_gradio_app(app, gui, path="/")
-
- @app.get("/")
- def home():
-     return RedirectResponse(url="/") """
  from fastapi import FastAPI, UploadFile, Form
  from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
  import os
@@ -72,6 +5,7 @@ import shutil
  from PIL import Image
  from transformers import ViltProcessor, ViltForQuestionAnswering, pipeline
  from gtts import gTTS
+ import pytesseract
  import torch
  import tempfile
  import gradio as gr
@@ -82,42 +16,56 @@ app = FastAPI()
  vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
  vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

- # Load GPT model for rewriting short answers
- gpt_rewriter = pipeline("text-generation", model="EleutherAI/gpt-neo-1.3B")
+ # Load image captioning model
+ captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")

- def rewrite_answer(question: str, short_answer: str):
-     prompt = f"Q: {question}\nA: {short_answer}\n\nRespond with a full sentence:"
-     try:
-         result = gpt_rewriter(prompt, max_length=50, do_sample=False)
-         full_sentence = result[0]['generated_text'].split("Respond with a full sentence:")[-1].strip()
-         return full_sentence
-     except Exception as e:
-         return short_answer  # fallback
+ def classify_question(question: str):
+     question_lower = question.lower()
+     if any(word in question_lower for word in ["text", "say", "written", "read"]):
+         return "ocr"
+     elif any(word in question_lower for word in ["caption", "describe", "what is in the image"]):
+         return "caption"
+     else:
+         return "vqa"

  def answer_question_from_image(image, question):
      if image is None or not question.strip():
          return "Please upload an image and ask a question.", None

-     # Process with model
-     inputs = vqa_processor(image, question, return_tensors="pt")
-     with torch.no_grad():
-         outputs = vqa_model(**inputs)
-     predicted_id = outputs.logits.argmax(-1).item()
-     short_answer = vqa_model.config.id2label[predicted_id]
-
-     # Rewrite short answer using GPT
-     full_answer = rewrite_answer(question, short_answer)
+     mode = classify_question(question)
+
+     if mode == "ocr":
+         try:
+             text = pytesseract.image_to_string(image)
+             answer = text.strip() or "No readable text found."
+         except Exception as e:
+             answer = f"OCR Error: {e}"
+
+     elif mode == "caption":
+         try:
+             answer = captioner(image)[0]['generated_text']
+         except Exception as e:
+             answer = f"Captioning error: {e}"
+
+     else:
+         try:
+             inputs = vqa_processor(image, question, return_tensors="pt")
+             with torch.no_grad():
+                 outputs = vqa_model(**inputs)
+             predicted_id = outputs.logits.argmax(-1).item()
+             answer = vqa_model.config.id2label[predicted_id]
+         except Exception as e:
+             answer = f"VQA error: {e}"

-     # Generate TTS audio
      try:
-         tts = gTTS(text=full_answer)
+         tts = gTTS(text=answer)
          with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
              tts.save(tmp.name)
              audio_path = tmp.name
      except Exception as e:
-         return f"Answer: {full_answer}\n\n⚠️ Audio generation error: {e}", None
+         return f"Answer: {answer}\n\n⚠️ Audio generation error: {e}", None

-     return full_answer, audio_path
+     return answer, audio_path

  def process_image_question(image: Image.Image, question: str):
      answer, audio_path = answer_question_from_image(image, question)
@@ -134,11 +82,11 @@ gui = gr.Interface(
          gr.Audio(label="Answer (Audio)", type="filepath")
      ],
      title="🧠 Image QA with Voice",
-     description="Upload an image and ask a question. You'll get a detailed text + spoken answer."
+     description="Upload an image and ask a question. Works for OCR, captioning, and VQA."
  )

  app = gr.mount_gradio_app(app, gui, path="/")

  @app.get("/")
  def home():
-     return RedirectResponse(url="/")
+     return RedirectResponse(url="/")
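
Note on the new import: pytesseract is only a Python wrapper around the Tesseract OCR binary, so the host also needs the system package. A minimal sketch, assuming this runs as a Hugging Face Space that installs apt dependencies from a packages.txt file at the repo root:

# packages.txt (assumed deployment detail, not part of this commit)
tesseract-ocr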
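Since the Gradio UI is mounted on a FastAPI app rather than launched directly, serving it locally goes through an ASGI server. A sketch, assuming uvicorn is installed:

uvicorn app:app --host 0.0.0.0 --port 7860   # 7860 is the port Spaces conventionally uses; any free port works locally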
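For a quick check of the new routing logic, a minimal sketch, assuming app.py is importable (importing it loads both models, which needs network on first run) and a local test image exists:

# smoke_test.py (hypothetical helper, not part of the commit)
from PIL import Image
from app import classify_question, answer_question_from_image

# Keyword routing: "text"/"say"/"written"/"read" -> OCR,
# "caption"/"describe" -> captioning, anything else -> VQA
assert classify_question("What does the sign say?") == "ocr"
assert classify_question("Describe this photo") == "caption"
assert classify_question("How many dogs are there?") == "vqa"

# End-to-end call; example.jpg is a hypothetical local test image,
# and gTTS needs network access to synthesize the audio
answer, audio_path = answer_question_from_image(Image.open("example.jpg"), "How many dogs are there?")
print(answer, audio_path)

One caveat on the design: the routing is substring-based, so a question like "Is this an essay?" matches "say" and gets sent to OCR rather than VQA.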