danilohssantana commited on
Commit
bcf356c
·
1 Parent(s): 7c13927

fixing endpoint parameters

Browse files
Files changed (2) hide show
  1. main.py +9 -1
  2. model.py +2 -1
main.py CHANGED
@@ -8,8 +8,16 @@ from fastapi import FastAPI, File, UploadFile, HTTPException
8
  from qwen_vl_utils import process_vision_info
9
  from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration
10
 
 
 
 
11
  app = FastAPI()
12
 
 
 
 
 
 
13
  checkpoint = "Qwen/Qwen2-VL-2B-Instruct"
14
  min_pixels = 256 * 28 * 28
15
  max_pixels = 1280 * 28 * 28
@@ -86,7 +94,7 @@ async def upload_and_encode_image(file: UploadFile = File(...)):
86
  raise HTTPException(status_code=400, detail=f"Invalid file: {e}")
87
 
88
  @app.post("/predict")
89
- def predict(data: any = Query(...)):
90
  """
91
  Generates a description for an image using the Qwen-2-VL model.
92
 
 
8
  from qwen_vl_utils import process_vision_info
9
  from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration
10
 
11
+ from fastapi import FastAPI, Body
12
+ from pydantic import BaseModel
13
+
14
  app = FastAPI()
15
 
16
+ # Define request model
17
+ class PredictRequest(BaseModel):
18
+ image_base64: str
19
+ prompt: str
20
+
21
  checkpoint = "Qwen/Qwen2-VL-2B-Instruct"
22
  min_pixels = 256 * 28 * 28
23
  max_pixels = 1280 * 28 * 28
 
94
  raise HTTPException(status_code=400, detail=f"Invalid file: {e}")
95
 
96
  @app.post("/predict")
97
+ def predict(data: PredictRequest):
98
  """
99
  Generates a description for an image using the Qwen-2-VL model.
100
 
model.py CHANGED
@@ -28,7 +28,8 @@ predict_payload = {
28
  "prompt": "describe the image",
29
  }
30
 
31
- predict_response = requests.post(f"{BASE_URL}/predict", params=predict_payload)
 
32
 
33
  # Step 4: Print the response
34
  if predict_response.status_code == 200:
 
28
  "prompt": "describe the image",
29
  }
30
 
31
+ predict_response = requests.post(f"{BASE_URL}/predict", json=predict_payload)
32
+
33
 
34
  # Step 4: Print the response
35
  if predict_response.status_code == 200: