breadlicker45 committed
Commit b9c7982 · verified · 1 Parent(s): c580f5e

Update app.py

Files changed (1): app.py (+26 -14)
app.py CHANGED
@@ -3,10 +3,13 @@ from transformers import (
     PaliGemmaProcessor,
     PaliGemmaForConditionalGeneration,
 )
-from PIL import Image
+from transformers.image_utils import load_image
 import torch
 import os
 import spaces  # Import the spaces module
+import requests
+from io import BytesIO
+from PIL import Image
 
 
 def load_model():
@@ -24,39 +27,48 @@ def load_model():
     processor = PaliGemmaProcessor.from_pretrained(model_id, use_auth_token=token)
     device = "cuda" if torch.cuda.is_available() else "cpu"
     model = PaliGemmaForConditionalGeneration.from_pretrained(
-        model_id, use_auth_token=token, torch_dtype=torch.bfloat16
-    ).to(device)
+        model_id, torch_dtype=torch.bfloat16, use_auth_token=token
+    ).to(device).eval()
 
     return processor, model
 
 
 @spaces.GPU  # Decorate the function that uses the GPU
-def process_image_and_text(image, text_input):
+def process_image_and_text(image_pil, text_input):
     """Extract text from image using PaliGemma2."""
     processor, model = load_model()
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    # Preprocess the image and text
-    inputs = processor(text=text_input, images=image, return_tensors="pt").to(
+
+    # Load the image using load_image
+    # Convert PIL image to bytes
+    buffered = BytesIO()
+    image_pil.save(buffered, format="JPEG")
+    image_bytes = buffered.getvalue()
+    image = load_image(image_bytes)
+
+    # Use the provided text input
+    model_inputs = processor(text=text_input, images=image, return_tensors="pt").to(
         device, dtype=torch.bfloat16
     )
+    input_len = model_inputs["input_ids"].shape[-1]
 
-    # Generate predictions
-    with torch.no_grad():
-        generated_ids = model.generate(**inputs, max_new_tokens=100)
-        text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    with torch.inference_mode():
+        generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
+        generation = generation[0][input_len:]
+        decoded = processor.decode(generation, skip_special_tokens=True)
 
-    return text
+    return decoded
 
 
 if __name__ == "__main__":
     iface = gr.Interface(
         fn=process_image_and_text,
         inputs=[
-            gr.Image(type="pil", label="Upload an image containing text"),
+            gr.Image(type="pil", label="Upload an image"),
             gr.Textbox(label="Enter Text Prompt"),
         ],
-        outputs=gr.Textbox(label="Extracted/Generated Text"),
-        title="Text Reading/Generation with PaliGemma2",
+        outputs=gr.Textbox(label="Generated Text"),
+        title="PaliGemma2 Image and Text to Text",
         description="Upload an image and enter a text prompt. The model will generate text based on both.",
     )
     iface.launch()
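
Note on the new image path: the updated process_image_and_text serializes the Gradio-supplied PIL image to in-memory JPEG bytes and hands those bytes to transformers' load_image. In recent transformers releases, load_image documents URL, local-path, base64-string, and PIL.Image inputs, so whether raw bytes are accepted may depend on the installed version; since the input is already a PIL image, passing image_pil (after image_pil.convert("RGB")) straight to the processor would sidestep both the round-trip and the lossy JPEG re-encode.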
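
For reference, a self-contained sketch of the generation pattern this commit adopts: encode the image and prompt together, record the prompt length, generate greedily, then decode only the tokens after input_len so the prompt is not echoed back. The model id and HF_TOKEN lookup below are assumptions for illustration (the Space's actual model_id and token handling sit earlier in app.py, outside this diff), and token= stands in for the deprecated use_auth_token= used in the commit.

import os

import torch
from PIL import Image
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor

model_id = "google/paligemma2-3b-pt-224"  # assumed for the sketch; not visible in this diff
token = os.environ.get("HF_TOKEN")  # assumed auth setup

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = PaliGemmaProcessor.from_pretrained(model_id, token=token)
model = (
    PaliGemmaForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, token=token
    )
    .to(device)
    .eval()  # inference only, matching the commit
)

image = Image.open("example.jpg").convert("RGB")  # any local test image
prompt = "caption en"  # a typical PaliGemma task prompt

model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(
    device, dtype=torch.bfloat16
)
input_len = model_inputs["input_ids"].shape[-1]  # prompt length in tokens

with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)

# Decode only the newly generated tokens, dropping the echoed prompt.
decoded = processor.decode(generation[0][input_len:], skip_special_tokens=True)
print(decoded)

The input_len slice is what replaces the old batch_decode call, which returned the prompt and the completion together; do_sample=False pins generate to greedy decoding, so the Space's output is deterministic for a given image and prompt.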