danilohssantana committed on
Commit
a2b6d64
·
1 Parent(s): 3a8bfcd

fixing image loading

Browse files
Files changed (3) hide show
  1. main.py +66 -12
  2. model.py +7 -2
  3. test.py +13 -0
main.py CHANGED
@@ -1,17 +1,19 @@
 
 
 
 
1
  from fastapi import FastAPI, Query
2
- from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
3
  from qwen_vl_utils import process_vision_info
4
- import torch
5
 
6
  app = FastAPI()
7
 
8
  checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
9
- min_pixels = 256*28*28
10
- max_pixels = 1280*28*28
11
  processor = AutoProcessor.from_pretrained(
12
- checkpoint,
13
- min_pixels=min_pixels,
14
- max_pixels=max_pixels
15
  )
16
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
17
  checkpoint,
@@ -20,17 +22,64 @@ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
20
  # attn_implementation="flash_attention_2",
21
  )
22
 
 
23
  @app.get("/")
24
  def read_root():
25
  return {"message": "API is live. Use the /predict endpoint."}
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  @app.get("/predict")
28
  def predict(image_url: str = Query(...), prompt: str = Query(...)):
 
 
 
29
  messages = [
30
- {"role": "system", "content": "You are a helpful assistant with vision abilities."},
31
- {"role": "user", "content": [{"type": "image", "image": image_url}, {"type": "text", "text": prompt}]},
 
 
 
 
 
 
 
 
 
32
  ]
33
- text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
 
34
  image_inputs, video_inputs = process_vision_info(messages)
35
  inputs = processor(
36
  text=[text],
@@ -41,8 +90,13 @@ def predict(image_url: str = Query(...), prompt: str = Query(...)):
41
  ).to(model.device)
42
  with torch.no_grad():
43
  generated_ids = model.generate(**inputs, max_new_tokens=128)
44
- generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
 
 
 
45
  output_texts = processor.batch_decode(
46
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
 
 
47
  )
48
  return {"response": output_texts[0]}
 
1
+ import base64
2
+ from io import BytesIO
3
+
4
+ import torch
5
  from fastapi import FastAPI, Query
6
+ from PIL import Image
7
  from qwen_vl_utils import process_vision_info
8
+ from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
9
 
10
  app = FastAPI()
11
 
12
  checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
13
+ min_pixels = 256 * 28 * 28
14
+ max_pixels = 1280 * 28 * 28
15
  processor = AutoProcessor.from_pretrained(
16
+ checkpoint, min_pixels=min_pixels, max_pixels=max_pixels
 
 
17
  )
18
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
19
  checkpoint,
 
22
  # attn_implementation="flash_attention_2",
23
  )
24
 
25
+
26
@app.get("/")
def read_root():
    """Landing route: report that the service is up and point at /predict."""
    status_payload = {"message": "API is live. Use the /predict endpoint."}
    return status_payload
29
 
30
+
31
def encode_image(image_path, max_size=(800, 800), quality=85):
    """
    Converts an image to a Base64-encoded JPEG string with optimized size.

    The image may be given as a local file path or as an http(s) URL. URLs
    are downloaded into memory first, because ``Image.open`` only accepts
    file paths and file-like objects, not remote addresses.

    Args:
        image_path (str): Local path or http(s) URL of the image.
        max_size (tuple): The maximum width and height of the resized image.
        quality (int): The compression quality (1-100, higher means better
            quality but bigger size).

    Returns:
        str: Base64-encoded representation of the optimized image, or None
        if the image could not be fetched, opened or encoded. Callers must
        check for None before building a data URI from the result.
    """
    # NOTE(review): the /predict endpoint forwards a user-supplied query
    # value here, so this can read arbitrary local files and fetch arbitrary
    # URLs. Consider restricting allowed sources before exposing publicly.
    try:
        source = image_path
        if isinstance(image_path, str) and image_path.lower().startswith(
            ("http://", "https://")
        ):
            # Function-scope import: keeps the module import block untouched.
            from urllib.request import Request, urlopen

            # An explicit User-Agent is sent because some hosts refuse the
            # default urllib one; the timeout prevents a stalled download
            # from blocking the request handler indefinitely.
            request = Request(
                image_path,
                headers={"User-Agent": "Mozilla/5.0 (compatible; encode_image)"},
            )
            with urlopen(request, timeout=10) as response:
                source = BytesIO(response.read())

        with Image.open(source) as img:
            # Convert to RGB (avoid issues with PNG transparency)
            img = img.convert("RGB")

            # Resize in place while maintaining aspect ratio
            img.thumbnail(max_size, Image.LANCZOS)

            # Save as JPEG into an in-memory buffer to reduce size
            buffer = BytesIO()
            img.save(buffer, format="JPEG", quality=quality)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
    except Exception as e:
        # Deliberate best-effort boundary: this helper never raises; any
        # failure is reported and signalled to the caller as None.
        print(f"❌ Error encoding image {image_path}: {e}")
        return None
60
+
61
+
62
  @app.get("/predict")
63
  def predict(image_url: str = Query(...), prompt: str = Query(...)):
64
+
65
+ image = encode_image(image_url)
66
+
67
  messages = [
68
+ {
69
+ "role": "system",
70
+ "content": "You are a helpful assistant with vision abilities.",
71
+ },
72
+ {
73
+ "role": "user",
74
+ "content": [
75
+ {"type": "image", "image": f"data:image;base64,{image}"},
76
+ {"type": "text", "text": prompt},
77
+ ],
78
+ },
79
  ]
80
+ text = processor.apply_chat_template(
81
+ messages, tokenize=False, add_generation_prompt=True
82
+ )
83
  image_inputs, video_inputs = process_vision_info(messages)
84
  inputs = processor(
85
  text=[text],
 
90
  ).to(model.device)
91
  with torch.no_grad():
92
  generated_ids = model.generate(**inputs, max_new_tokens=128)
93
+ generated_ids_trimmed = [
94
+ out_ids[len(in_ids) :]
95
+ for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
96
+ ]
97
  output_texts = processor.batch_decode(
98
+ generated_ids_trimmed,
99
+ skip_special_tokens=True,
100
+ clean_up_tokenization_spaces=False,
101
  )
102
  return {"response": output_texts[0]}
model.py CHANGED
@@ -1,10 +1,15 @@
1
  import requests
2
 
3
- url = "https://<uname>-<spacename>.hf.space/predict"
 
 
 
 
 
4
 
5
  # Define the parameters
6
  params = {
7
- "image_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
8
  "prompt": "describe",
9
  }
10
 
 
1
  import requests
2
 
3
+ # curl -G "https://<uname>-<spacename>.hf.space/predict" \
4
+ # --data-urlencode "image_url=https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg" \
5
+ # --data-urlencode "prompt=Describe this image."
6
+
7
+
8
+ url = "https://danilohssantana-qwen2-5-vl-api.hf.space/predict"
9
 
10
  # Define the parameters
11
  params = {
12
+ "image_url": "https://cdn.britannica.com/35/238335-050-2CB2EB8A/Lionel-Messi-Argentina-Netherlands-World-Cup-Qatar-2022.jpg",
13
  "prompt": "describe",
14
  }
15
 
test.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+
3
+ import requests
4
+ from PIL import Image
5
+
6
# Sample image used to smoke-test downloading and decoding.
image_url = "https://cdn.britannica.com/35/238335-050-2CB2EB8A/Lionel-Messi-Argentina-Netherlands-World-Cup-Qatar-2022.jpg"

# A timeout keeps the script from hanging forever on an unresponsive host.
# The body is read in one go through response.content below, so streaming
# mode is not requested.
response = requests.get(image_url, timeout=30)

if response.status_code == 200:
    # Decode the downloaded bytes from memory and open them in a viewer.
    image = Image.open(BytesIO(response.content))
    image.show()
else:
    print(f"Failed to download image. Status code: {response.status_code}")