breadlicker45 commited on
Commit
dcd8e07
·
verified ·
1 Parent(s): dc3c0b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -18
app.py CHANGED
@@ -24,37 +24,40 @@ def load_model():
24
 
25
  # Load the processor and model using the correct identifier
26
  model_id = "google/paligemma2-28b-pt-448"
27
- processor = PaliGemmaProcessor.from_pretrained(model_id, use_auth_token=token)
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
  model = PaliGemmaForConditionalGeneration.from_pretrained(
30
- model_id, torch_dtype=torch.bfloat16, use_auth_token=token
31
  ).to(device).eval()
32
 
33
  return processor, model
34
 
35
 
36
- @spaces.GPU # Decorate the function that uses the GPU
37
  def process_image_and_text(image_pil, text_input):
38
  """Extract text from image using PaliGemma2."""
39
- processor, model = load_model()
40
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
41
 
42
- # Load the image using load_image
43
- # We can pass the PIL image directly to load_image
44
- image = load_image(image_pil)
45
 
46
- # Use the provided text input
47
- model_inputs = processor(text=text_input, images=image, return_tensors="pt").to(
48
- device, dtype=torch.bfloat16
49
- )
50
- input_len = model_inputs["input_ids"].shape[-1]
51
 
52
- with torch.inference_mode():
53
- generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
54
- generation = generation[0][input_len:]
55
- decoded = processor.decode(generation, skip_special_tokens=True)
56
 
57
- return decoded
 
 
 
58
 
59
 
60
  if __name__ == "__main__":
 
24
 
25
  # Load the processor and model using the correct identifier
26
  model_id = "google/paligemma2-28b-pt-448"
27
+ processor = PaliGemmaProcessor.from_pretrained(model_id, token=token)
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
  model = PaliGemmaForConditionalGeneration.from_pretrained(
30
+ model_id, torch_dtype=torch.bfloat16, token=token
31
  ).to(device).eval()
32
 
33
  return processor, model
34
 
35
 
36
+ @spaces.GPU(duration=120) # Increased timeout to 120 seconds
37
  def process_image_and_text(image_pil, text_input):
38
  """Extract text from image using PaliGemma2."""
39
+ try:
40
+ processor, model = load_model()
41
+ device = "cuda" if torch.cuda.is_available() else "cpu"
42
 
43
+ # Load the image using load_image
44
+ image = load_image(image_pil)
 
45
 
46
+ # Use the provided text input
47
+ model_inputs = processor(text=text_input, images=image, return_tensors="pt").to(
48
+ device, dtype=torch.bfloat16
49
+ )
50
+ input_len = model_inputs["input_ids"].shape[-1]
51
 
52
+ with torch.inference_mode():
53
+ generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
54
+ generation = generation[0][input_len:]
55
+ decoded = processor.decode(generation, skip_special_tokens=True)
56
 
57
+ return decoded
58
+ except Exception as e:
59
+ print(f"Error during GPU task: {e}")
60
+ raise gr.Error(f"GPU task failed: {e}")
61
 
62
 
63
  if __name__ == "__main__":