danilohssantana commited on
Commit
d7e7825
·
1 Parent(s): a2b6d64

trying new model

Browse files
Files changed (1) hide show
  1. main.py +68 -15
main.py CHANGED
@@ -5,17 +5,17 @@ import torch
5
  from fastapi import FastAPI, Query
6
  from PIL import Image
7
  from qwen_vl_utils import process_vision_info
8
- from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
9
 
10
  app = FastAPI()
11
 
12
- checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
13
  min_pixels = 256 * 28 * 28
14
  max_pixels = 1280 * 28 * 28
15
  processor = AutoProcessor.from_pretrained(
16
  checkpoint, min_pixels=min_pixels, max_pixels=max_pixels
17
  )
18
- model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
19
  checkpoint,
20
  torch_dtype=torch.bfloat16,
21
  device_map="auto",
@@ -58,25 +58,34 @@ def encode_image(image_path, max_size=(800, 800), quality=85):
58
  print(f"❌ Error encoding image {image_path}: {e}")
59
  return None
60
 
61
-
62
  @app.get("/predict")
63
- def predict(image_url: str = Query(...), prompt: str = Query(...)):
 
 
 
 
 
 
 
 
 
 
64
 
65
  image = encode_image(image_url)
66
 
 
 
67
  messages = [
68
- {
69
- "role": "system",
70
- "content": "You are a helpful assistant with vision abilities.",
71
- },
72
  {
73
  "role": "user",
74
  "content": [
75
  {"type": "image", "image": f"data:image;base64,{image}"},
76
  {"type": "text", "text": prompt},
77
  ],
78
- },
79
  ]
 
 
80
  text = processor.apply_chat_template(
81
  messages, tokenize=False, add_generation_prompt=True
82
  )
@@ -87,16 +96,60 @@ def predict(image_url: str = Query(...), prompt: str = Query(...)):
87
  videos=video_inputs,
88
  padding=True,
89
  return_tensors="pt",
90
- ).to(model.device)
91
- with torch.no_grad():
92
- generated_ids = model.generate(**inputs, max_new_tokens=128)
 
93
  generated_ids_trimmed = [
94
  out_ids[len(in_ids) :]
95
  for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
96
  ]
97
- output_texts = processor.batch_decode(
98
  generated_ids_trimmed,
99
  skip_special_tokens=True,
100
  clean_up_tokenization_spaces=False,
101
  )
102
- return {"response": output_texts[0]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from fastapi import FastAPI, Query
6
  from PIL import Image
7
  from qwen_vl_utils import process_vision_info
8
+ from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration
9
 
10
  app = FastAPI()
11
 
12
+ checkpoint = "Qwen/Qwen2-VL-3B-Instruct"
13
  min_pixels = 256 * 28 * 28
14
  max_pixels = 1280 * 28 * 28
15
  processor = AutoProcessor.from_pretrained(
16
  checkpoint, min_pixels=min_pixels, max_pixels=max_pixels
17
  )
18
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
19
  checkpoint,
20
  torch_dtype=torch.bfloat16,
21
  device_map="auto",
 
58
  print(f"❌ Error encoding image {image_path}: {e}")
59
  return None
60
 
 
61
  @app.get("/predict")
62
+ def describe_image_with_qwen2_vl(image_url: str = Query(...), prompt: str = Query(...)):
63
+ """
64
+ Generates a description for an image using the Qwen-2-VL model.
65
+
66
+ Args:
67
+ image_url (str): The URL of the image to describe.
68
+ prompt (str): The text prompt to guide the model's response.
69
+
70
+ Returns:
71
+ str: The generated description of the image.
72
+ """
73
 
74
  image = encode_image(image_url)
75
 
76
+
77
+ # Create the input message structure
78
  messages = [
 
 
 
 
79
  {
80
  "role": "user",
81
  "content": [
82
  {"type": "image", "image": f"data:image;base64,{image}"},
83
  {"type": "text", "text": prompt},
84
  ],
85
+ }
86
  ]
87
+
88
+ # Prepare inputs for the model
89
  text = processor.apply_chat_template(
90
  messages, tokenize=False, add_generation_prompt=True
91
  )
 
96
  videos=video_inputs,
97
  padding=True,
98
  return_tensors="pt",
99
+ ).to("cuda:0")
100
+
101
+ # Generate the output
102
+ generated_ids = model.generate(**inputs, max_new_tokens=2056)
103
  generated_ids_trimmed = [
104
  out_ids[len(in_ids) :]
105
  for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
106
  ]
107
+ output_text = processor.batch_decode(
108
  generated_ids_trimmed,
109
  skip_special_tokens=True,
110
  clean_up_tokenization_spaces=False,
111
  )
112
+
113
+ return {"response": output_text[0] if output_text else "No description generated."}
114
+
115
+ # @app.get("/predict")
116
+ # def predict(image_url: str = Query(...), prompt: str = Query(...)):
117
+
118
+ # image = encode_image(image_url)
119
+
120
+ # messages = [
121
+ # {
122
+ # "role": "system",
123
+ # "content": "You are a helpful assistant with vision abilities.",
124
+ # },
125
+ # {
126
+ # "role": "user",
127
+ # "content": [
128
+ # {"type": "image", "image": f"data:image;base64,{image}"},
129
+ # {"type": "text", "text": prompt},
130
+ # ],
131
+ # },
132
+ # ]
133
+ # text = processor.apply_chat_template(
134
+ # messages, tokenize=False, add_generation_prompt=True
135
+ # )
136
+ # image_inputs, video_inputs = process_vision_info(messages)
137
+ # inputs = processor(
138
+ # text=[text],
139
+ # images=image_inputs,
140
+ # videos=video_inputs,
141
+ # padding=True,
142
+ # return_tensors="pt",
143
+ # ).to(model.device)
144
+ # with torch.no_grad():
145
+ # generated_ids = model.generate(**inputs, max_new_tokens=128)
146
+ # generated_ids_trimmed = [
147
+ # out_ids[len(in_ids) :]
148
+ # for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
149
+ # ]
150
+ # output_texts = processor.batch_decode(
151
+ # generated_ids_trimmed,
152
+ # skip_special_tokens=True,
153
+ # clean_up_tokenization_spaces=False,
154
+ # )
155
+ # return {"response": output_texts[0]}