fix: improve embedding model validation logic for dataset operations (#3235)
Browse filesWhat problem does this PR solve?
When creating or updating datasets with custom embedding models (e.g.,
Ollama), the validation logic was too restrictive and prevented valid
models from being used. The previous implementation would reject valid
custom models if they weren't in the predefined list, even when they
existed in TenantLLMService.
Changes:
- Simplify and improve the embedding model validation flow in
create/update endpoints
- Check TenantLLMService for custom models before rejecting
- Make validation logic more consistent between create and update
operations
### What problem does this PR solve?
This fix allows users to successfully create and update datasets with
custom embedding models while maintaining proper validation checks.
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
---------
Co-authored-by: Jin Hai <[email protected]>
Co-authored-by: Kevin Hu <[email protected]>
Co-authored-by: liuhua <[email protected]>
- api/apps/sdk/dataset.py +11 -22
@@ -159,21 +159,15 @@ def create(tenant_id):
|
|
159 |
embd_model = LLMService.query(
|
160 |
llm_name=req["embedding_model"], model_type="embedding"
|
161 |
)
|
|
|
|
|
|
|
|
|
|
|
162 |
if not embd_model:
|
163 |
return get_error_data_result(
|
164 |
f"`embedding_model` {req.get('embedding_model')} doesn't exist"
|
165 |
)
|
166 |
-
if embd_model:
|
167 |
-
if req[
|
168 |
-
"embedding_model"
|
169 |
-
] not in valid_embedding_models and not TenantLLMService.query(
|
170 |
-
tenant_id=tenant_id,
|
171 |
-
model_type="embedding",
|
172 |
-
llm_name=req.get("embedding_model"),
|
173 |
-
):
|
174 |
-
return get_error_data_result(
|
175 |
-
f"`embedding_model` {req.get('embedding_model')} doesn't exist"
|
176 |
-
)
|
177 |
key_mapping = {
|
178 |
"chunk_num": "chunk_count",
|
179 |
"doc_num": "document_count",
|
@@ -403,21 +397,16 @@ def update(tenant_id, dataset_id):
|
|
403 |
embd_model = LLMService.query(
|
404 |
llm_name=req["embedding_model"], model_type="embedding"
|
405 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
406 |
if not embd_model:
|
407 |
return get_error_data_result(
|
408 |
f"`embedding_model` {req.get('embedding_model')} doesn't exist"
|
409 |
)
|
410 |
-
if embd_model:
|
411 |
-
if req[
|
412 |
-
"embedding_model"
|
413 |
-
] not in valid_embedding_models and not TenantLLMService.query(
|
414 |
-
tenant_id=tenant_id,
|
415 |
-
model_type="embedding",
|
416 |
-
llm_name=req.get("embedding_model"),
|
417 |
-
):
|
418 |
-
return get_error_data_result(
|
419 |
-
f"`embedding_model` {req.get('embedding_model')} doesn't exist"
|
420 |
-
)
|
421 |
req["embd_id"] = req.pop("embedding_model")
|
422 |
if "name" in req:
|
423 |
req["name"] = req["name"].strip()
|
|
|
159 |
embd_model = LLMService.query(
|
160 |
llm_name=req["embedding_model"], model_type="embedding"
|
161 |
)
|
162 |
+
if embd_model:
|
163 |
+
if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding",llm_name=req.get("embedding_model"),):
|
164 |
+
return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
|
165 |
+
if not embd_model:
|
166 |
+
embd_model=TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model"))
|
167 |
if not embd_model:
|
168 |
return get_error_data_result(
|
169 |
f"`embedding_model` {req.get('embedding_model')} doesn't exist"
|
170 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
key_mapping = {
|
172 |
"chunk_num": "chunk_count",
|
173 |
"doc_num": "document_count",
|
|
|
397 |
embd_model = LLMService.query(
|
398 |
llm_name=req["embedding_model"], model_type="embedding"
|
399 |
)
|
400 |
+
if embd_model:
|
401 |
+
if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding",llm_name=req.get("embedding_model"),):
|
402 |
+
return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
|
403 |
+
if not embd_model:
|
404 |
+
embd_model=TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model"))
|
405 |
+
|
406 |
if not embd_model:
|
407 |
return get_error_data_result(
|
408 |
f"`embedding_model` {req.get('embedding_model')} doesn't exist"
|
409 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
410 |
req["embd_id"] = req.pop("embedding_model")
|
411 |
if "name" in req:
|
412 |
req["name"] = req["name"].strip()
|