GhAyoub committed on
Commit b717804 · 1 Parent(s): ea6bbd5

[OCR API] Optimization pass.

Files changed (2)
  1. main.py +39 -10
  2. prompts.txt +0 -0
main.py CHANGED
@@ -2,47 +2,76 @@ from fastapi import FastAPI, Query
 from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
 from qwen_vl_utils import process_vision_info
 import torch
+import requests
+from PIL import Image
+from io import BytesIO
 
 app = FastAPI()
 
+# Load model and processor
 checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
-min_pixels = 256*28*28
-max_pixels = 1280*28*28
+
+# Check for Metal GPU support on macOS
+device = "mps" if torch.backends.mps.is_available() else "cpu"
+
 processor = AutoProcessor.from_pretrained(
     checkpoint,
-    min_pixels=min_pixels,
-    max_pixels=max_pixels
+    min_pixels=256 * 28 * 28,
+    max_pixels=1280 * 28 * 28
 )
+
 model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     checkpoint,
-    torch_dtype=torch.bfloat16,
-    device_map="auto",
-    # attn_implementation="flash_attention_2",
+    torch_dtype=torch.float16 if device == "mps" else torch.bfloat16,  # Use float16 on Apple Metal
+    device_map={"": 0} if device == "mps" else "cpu",
+    attn_implementation="flash_attention_2",  # If it supports Mac
 )
 
+
+# Function to load and resize images (reduces processing time)
+def load_and_resize_image(image_url):
+    response = requests.get(image_url)
+    image = Image.open(BytesIO(response.content)).convert("RGB")
+    image = image.resize((512, 512))  # Resize to 512x512 to speed up processing
+    return image
+
+
 @app.get("/")
 def read_root():
     return {"message": "API is live. Use the /predict endpoint."}
 
+
 @app.get("/predict")
 def predict(image_url: str = Query(...), prompt: str = Query(...)):
     messages = [
         {"role": "system", "content": "You are a helpful assistant with vision abilities."},
         {"role": "user", "content": [{"type": "image", "image": image_url}, {"type": "text", "text": prompt}]},
     ]
+
     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    image_inputs, video_inputs = process_vision_info(messages)
+
+    # Process image
+    image_inputs = [load_and_resize_image(image_url)]
+    video_inputs = None
+
+    # Process inputs
     inputs = processor(
         text=[text],
         images=image_inputs,
         videos=video_inputs,
         padding=True,
+        truncation=True,  # Ensures token limit
+        max_length=512,  # Prevents excessive memory usage
         return_tensors="pt",
-    ).to(model.device)
+    ).to(device)
+
+    # Generate response
     with torch.no_grad():
-        generated_ids = model.generate(**inputs, max_new_tokens=128)
+        generated_ids = model.generate(**inputs, max_new_tokens=64)  # Reduced for faster inference
+
     generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
     output_texts = processor.batch_decode(
         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
     )
+
     return {"response": output_texts[0]}
prompts.txt ADDED
File without changes
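
For reference, a minimal client sketch for exercising the /predict endpoint defined in main.py. This is not part of the commit: it assumes the app is served locally (for example with `uvicorn main:app --port 8000`), and the host, port, image URL, and prompt below are placeholders.

# client_example.py -- hypothetical local test of the OCR API
import requests

BASE_URL = "http://127.0.0.1:8000"  # assumed local uvicorn address, not from the commit

params = {
    "image_url": "https://example.com/sample-document.png",  # placeholder: any reachable image URL
    "prompt": "Extract all text from this image.",
}

# /predict is a GET endpoint taking image_url and prompt as query parameters
resp = requests.get(f"{BASE_URL}/predict", params=params, timeout=120)
resp.raise_for_status()

# The endpoint returns a JSON body of the form {"response": "<decoded model output>"}
print(resp.json()["response"])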