Mohammed Tawileh jinhai-2012 Kevin Hu liuhua commited on
Commit
564d9fe
·
1 Parent(s): bc9f8af

fix: improve embedding model validation logic for dataset operations (#3235)

Browse files

What problem does this PR solve?
When creating or updating datasets with custom embedding models (e.g.,
Ollama), the validation logic was too restrictive and prevented valid
models from being used. The previous implementation would reject valid
custom models if they weren't in the predefined list, even when they
existed in TenantLLMService.

Changes:
- Simplify and improve the embedding model validation flow in
create/update endpoints
- Check TenantLLMService for custom models before rejecting
- Make validation logic more consistent between create and update
operations

### What problem does this PR solve?

This fix allows users to successfully create and update datasets with
custom embedding models while maintaining proper validation checks.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Jin Hai <[email protected]>
Co-authored-by: Kevin Hu <[email protected]>
Co-authored-by: liuhua <[email protected]>

Files changed (1) hide show
  1. api/apps/sdk/dataset.py +11 -22
api/apps/sdk/dataset.py CHANGED
@@ -159,21 +159,15 @@ def create(tenant_id):
159
  embd_model = LLMService.query(
160
  llm_name=req["embedding_model"], model_type="embedding"
161
  )
 
 
 
 
 
162
  if not embd_model:
163
  return get_error_data_result(
164
  f"`embedding_model` {req.get('embedding_model')} doesn't exist"
165
  )
166
- if embd_model:
167
- if req[
168
- "embedding_model"
169
- ] not in valid_embedding_models and not TenantLLMService.query(
170
- tenant_id=tenant_id,
171
- model_type="embedding",
172
- llm_name=req.get("embedding_model"),
173
- ):
174
- return get_error_data_result(
175
- f"`embedding_model` {req.get('embedding_model')} doesn't exist"
176
- )
177
  key_mapping = {
178
  "chunk_num": "chunk_count",
179
  "doc_num": "document_count",
@@ -403,21 +397,16 @@ def update(tenant_id, dataset_id):
403
  embd_model = LLMService.query(
404
  llm_name=req["embedding_model"], model_type="embedding"
405
  )
 
 
 
 
 
 
406
  if not embd_model:
407
  return get_error_data_result(
408
  f"`embedding_model` {req.get('embedding_model')} doesn't exist"
409
  )
410
- if embd_model:
411
- if req[
412
- "embedding_model"
413
- ] not in valid_embedding_models and not TenantLLMService.query(
414
- tenant_id=tenant_id,
415
- model_type="embedding",
416
- llm_name=req.get("embedding_model"),
417
- ):
418
- return get_error_data_result(
419
- f"`embedding_model` {req.get('embedding_model')} doesn't exist"
420
- )
421
  req["embd_id"] = req.pop("embedding_model")
422
  if "name" in req:
423
  req["name"] = req["name"].strip()
 
159
  embd_model = LLMService.query(
160
  llm_name=req["embedding_model"], model_type="embedding"
161
  )
162
+ if embd_model:
163
+ if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding",llm_name=req.get("embedding_model"),):
164
+ return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
165
+ if not embd_model:
166
+ embd_model=TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model"))
167
  if not embd_model:
168
  return get_error_data_result(
169
  f"`embedding_model` {req.get('embedding_model')} doesn't exist"
170
  )
 
 
 
 
 
 
 
 
 
 
 
171
  key_mapping = {
172
  "chunk_num": "chunk_count",
173
  "doc_num": "document_count",
 
397
  embd_model = LLMService.query(
398
  llm_name=req["embedding_model"], model_type="embedding"
399
  )
400
+ if embd_model:
401
+ if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding",llm_name=req.get("embedding_model"),):
402
+ return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
403
+ if not embd_model:
404
+ embd_model=TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model"))
405
+
406
  if not embd_model:
407
  return get_error_data_result(
408
  f"`embedding_model` {req.get('embedding_model')} doesn't exist"
409
  )
 
 
 
 
 
 
 
 
 
 
 
410
  req["embd_id"] = req.pop("embedding_model")
411
  if "name" in req:
412
  req["name"] = req["name"].strip()