Spaces:
Running
Running
Commit
·
bcf356c
1
Parent(s):
7c13927
fixing endpoint parameters
Browse files
main.py
CHANGED
@@ -8,8 +8,16 @@ from fastapi import FastAPI, File, UploadFile, HTTPException
|
|
8 |
from qwen_vl_utils import process_vision_info
|
9 |
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration
|
10 |
|
|
|
|
|
|
|
11 |
app = FastAPI()
|
12 |
|
|
|
|
|
|
|
|
|
|
|
13 |
checkpoint = "Qwen/Qwen2-VL-2B-Instruct"
|
14 |
min_pixels = 256 * 28 * 28
|
15 |
max_pixels = 1280 * 28 * 28
|
@@ -86,7 +94,7 @@ async def upload_and_encode_image(file: UploadFile = File(...)):
|
|
86 |
raise HTTPException(status_code=400, detail=f"Invalid file: {e}")
|
87 |
|
88 |
@app.post("/predict")
|
89 |
-
def predict(data:
|
90 |
"""
|
91 |
Generates a description for an image using the Qwen-2-VL model.
|
92 |
|
|
|
8 |
from qwen_vl_utils import process_vision_info
|
9 |
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration
|
10 |
|
11 |
+
from fastapi import FastAPI, Body
|
12 |
+
from pydantic import BaseModel
|
13 |
+
|
14 |
app = FastAPI()
|
15 |
|
16 |
+
# Define request model
|
17 |
+
class PredictRequest(BaseModel):
|
18 |
+
image_base64: str
|
19 |
+
prompt: str
|
20 |
+
|
21 |
checkpoint = "Qwen/Qwen2-VL-2B-Instruct"
|
22 |
min_pixels = 256 * 28 * 28
|
23 |
max_pixels = 1280 * 28 * 28
|
|
|
94 |
raise HTTPException(status_code=400, detail=f"Invalid file: {e}")
|
95 |
|
96 |
@app.post("/predict")
|
97 |
+
def predict(data: PredictRequest):
|
98 |
"""
|
99 |
Generates a description for an image using the Qwen-2-VL model.
|
100 |
|
model.py
CHANGED
@@ -28,7 +28,8 @@ predict_payload = {
|
|
28 |
"prompt": "describe the image",
|
29 |
}
|
30 |
|
31 |
-
predict_response = requests.post(f"{BASE_URL}/predict",
|
|
|
32 |
|
33 |
# Step 4: Print the response
|
34 |
if predict_response.status_code == 200:
|
|
|
28 |
"prompt": "describe the image",
|
29 |
}
|
30 |
|
31 |
+
predict_response = requests.post(f"{BASE_URL}/predict", json=predict_payload)
|
32 |
+
|
33 |
|
34 |
# Step 4: Print the response
|
35 |
if predict_response.status_code == 200:
|