bla committed on
Commit 40415b8 · verified · 1 Parent(s): cf4cb00

Update main.py

Files changed (1)
  1. main.py +108 -59
main.py CHANGED
@@ -1,47 +1,47 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
 import base64
+import json
 import logging
+import os
 import time
+import uuid
 from io import BytesIO
 
 import torch
-from fastapi import Body, FastAPI, File, HTTPException, Query, UploadFile
+from fastapi import FastAPI, HTTPException, UploadFile, File
+from fastapi.staticfiles import StaticFiles
 from PIL import Image
 from pydantic import BaseModel
 from qwen_vl_utils import process_vision_info
-from transformers import (
-    AutoProcessor,
-    Qwen2_5_VLForConditionalGeneration,
-    Qwen2VLForConditionalGeneration,
-)
+from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+
+# Create the temporary folder if it doesn't exist.
+TEMP_DIR = "temp"
+os.makedirs(TEMP_DIR, exist_ok=True)
 
 app = FastAPI()
 
+# Mount the temporary folder so annotated images can be served at /temp/<filename>
+app.mount("/temp", StaticFiles(directory=TEMP_DIR), name="temp")
 
-# Define request model
+# Define the request model
 class PredictRequest(BaseModel):
     image_base64: list[str]
     prompt: str
 
-
-# checkpoint = "Qwen/Qwen2-VL-2B-Instruct"
-# min_pixels = 256 * 28 * 28
-# max_pixels = 1280 * 28 * 28
-# processor = AutoProcessor.from_pretrained(
-#     checkpoint, min_pixels=min_pixels, max_pixels=max_pixels
-# )
-# model = Qwen2VLForConditionalGeneration.from_pretrained(
-#     checkpoint,
-#     torch_dtype=torch.bfloat16,
-#     device_map="auto",
-#     # attn_implementation="flash_attention_2",
-# )
-# checkpoint = "Qwen/Qwen2.5-VL-7B-Instruct"
+# Use the desired checkpoint: Qwen/Qwen2.5-VL-3B-Instruct-AWQ
 checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct-AWQ"
 min_pixels = 256 * 28 * 28
 max_pixels = 1280 * 28 * 28
+
+# Load the processor with the image resolution settings
 processor = AutoProcessor.from_pretrained(
     checkpoint, min_pixels=min_pixels, max_pixels=max_pixels
 )
+
+# Load the Qwen2.5-VL model.
 model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     checkpoint,
     torch_dtype="auto",
@@ -49,12 +49,10 @@ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     # attn_implementation="flash_attention_2",
 )
 
-
 @app.get("/")
 def read_root():
     return {"message": "API is live. Use the /predict endpoint."}
 
-
 def encode_image(image_data: BytesIO, max_size=(800, 800), quality=85):
     """
     Converts an image from file data to a Base64-encoded string with optimized size.
@@ -69,7 +67,6 @@ def encode_image(image_data: BytesIO, max_size=(800, 800), quality=85):
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error encoding image: {e}")
 
-
 @app.post("/encode-image/")
 async def upload_and_encode_image(file: UploadFile = File(...)):
     """
@@ -82,46 +79,43 @@ async def upload_and_encode_image(file: UploadFile = File(...)):
     except Exception as e:
         raise HTTPException(status_code=400, detail=f"Invalid file: {e}")
 
-
 @app.post("/predict")
-def predict(data: PredictRequest):
+def predict(data: PredictRequest, annotate: bool = False):
     """
-    Generates a description for an image using the Qwen-2-VL model.
-
-    Args:
-        data (PredictRequest): The request containing encoded images and a prompt.
-
-    Returns:
-        dict: The generated description of the image(s).
+    Generates a description (e.g. bounding boxes with labels) for image(s) using Qwen2.5-VL-3B-Instruct-AWQ.
+    If 'annotate' is True (as a query parameter), the first image is annotated with the predicted bounding boxes,
+    stored in a temporary folder, and its URL is returned.
+
+    Request:
+      - image_base64: List of base64-encoded images.
+      - prompt: A prompt string.
+
+    Response (JSON):
+      {
+        "response": <text generated by Qwen2.5-VL>,
+        "annotated_image_url": "/temp/<filename>"  # only if annotate=True
+      }
     """
-
     logging.warning("Calling /predict endpoint...")
 
-    # Ensure image_base64 is a list (even if a single image is provided)
-    image_list = (
-        data.image_base64
-        if isinstance(data.image_base64, list)
-        else [data.image_base64]
-    )
+    # Ensure image_base64 is a list.
+    image_list = data.image_base64 if isinstance(data.image_base64, list) else [data.image_base64]
 
-    # Create the input message structure with multiple images
+    # Create input messages: include all images and then the prompt.
     messages = [
         {
             "role": "user",
             "content": [
                 {"type": "image", "image": f"data:image;base64,{image}"}
                 for image in image_list
-            ]
-            + [{"type": "text", "text": data.prompt}],
+            ] + [{"type": "text", "text": data.prompt}],
         }
     ]
 
-    logging.info("Processing inputs...", len(image_list))
+    logging.info("Processing inputs... Number of images: %d", len(image_list))
 
-    # Prepare inputs for the model
-    text = processor.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
+    # Prepare inputs for the model using the processor's chat interface.
+    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     image_inputs, video_inputs = process_vision_info(messages)
     inputs = processor(
         text=[text],
@@ -132,21 +126,76 @@ def predict(data: PredictRequest):
     ).to(model.device)
 
     logging.warning("Starting generation...")
-
     start_time = time.time()
 
-    # Generate the output
+    # Generate output using the model.
     generated_ids = model.generate(**inputs, max_new_tokens=2056)
     generated_ids_trimmed = [
-        out_ids[len(in_ids) :]
-        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
     ]
     output_text = processor.batch_decode(
-        generated_ids_trimmed,
-        skip_special_tokens=True,
-        clean_up_tokenization_spaces=False,
+        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
     )
-
-    logging.warning(f"Generation completed in {time.time() - start_time:.2f}s.")
-
-    return {"response": output_text[0] if output_text else "No description generated."}
+    generation_time = time.time() - start_time
+    logging.warning("Generation completed in %.2fs.", generation_time)
+
+    # The generated output text is expected to be JSON (e.g., a list of detections).
+    result_text = output_text[0] if output_text else "No description generated."
+    response_data = {"response": result_text}
+
+    if annotate:
+        # Decode the first image for annotation.
+        try:
+            img_str = image_list[0]
+            # If the image string contains a data URI prefix, remove it.
+            if img_str.startswith("data:image"):
+                img_str = img_str.split(",")[1]
+            img_data = base64.b64decode(img_str)
+            image = Image.open(BytesIO(img_data))
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Error decoding image for annotation: {e}")
+
+        # Determine image dimensions (width, height)
+        input_wh = image.size
+        resolution_wh = input_wh  # Assuming no resolution change
+
+        # Parse the detection result from the model output.
+        try:
+            detection_result = json.loads(result_text)
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Error parsing detection result: {e}")
+
+        # Use the supervision library to create detections and annotate the image.
+        try:
+            import supervision as sv
+            detections = sv.Detections.from_vlm(
+                vlm=sv.VLM.QWEN_2_5_VL,
+                result=detection_result,
+                input_wh=input_wh,
+                resolution_wh=resolution_wh
+            )
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Error creating detections: {e}")
+
+        try:
+            box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
+            label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
+
+            annotated_image = image.copy()
+            annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)
+            annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Error annotating image: {e}")
+
+        # Save the annotated image in the temporary folder.
+        try:
+            filename = f"{uuid.uuid4()}.jpg"
+            filepath = os.path.join(TEMP_DIR, filename)
+            annotated_image.save(filepath, format="JPEG")
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Error saving annotated image: {e}")
+
+        # Add the annotated image URL to the response.
+        response_data["annotated_image_url"] = f"/temp/{filename}"
+
+    return response_data
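For reference, a minimal client-side sketch of how the updated /predict endpoint could be exercised once this revision is deployed. The host and port (http://localhost:8000), the requests dependency, and the sample image path are assumptions for illustration, not part of the commit.

# Hypothetical client for the /predict endpoint in this revision.
# Assumes the app is served at http://localhost:8000 and that example.jpg exists locally.
import base64

import requests  # assumed client-side dependency

with open("example.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "image_base64": [image_b64],
    "prompt": "Detect every object in the image and return bounding boxes with labels as JSON.",
}

# annotate=true enables the new supervision-based annotation branch;
# the resulting image is served from the mounted /temp route.
resp = requests.post(
    "http://localhost:8000/predict",
    params={"annotate": "true"},
    json=payload,
    timeout=300,
)
resp.raise_for_status()
result = resp.json()
print(result["response"])
print(result.get("annotated_image_url"))

If "annotated_image_url" is present in the response, the annotated JPEG can be fetched from the same host under the mounted /temp route.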
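The annotation branch calls json.loads on the raw model text and hands the result to supervision. A sketch of the kind of detection payload that would satisfy that parsing step; the bbox_2d/label keys follow the commonly documented Qwen2.5-VL grounding format and are an assumption here, since the commit itself does not pin down a schema.

# Hypothetical detection JSON in the shape the annotate path expects to parse.
# The bbox_2d/label keys are assumed from typical Qwen2.5-VL grounding output; not defined by this commit.
import json

result_text = '[{"bbox_2d": [48, 60, 320, 410], "label": "dog"}]'
detection_result = json.loads(result_text)  # mirrors the endpoint's parsing step

for det in detection_result:
    x1, y1, x2, y2 = det["bbox_2d"]
    print(f'{det["label"]}: ({x1}, {y1}) -> ({x2}, {y2})')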