Commit 
							
							·
						
						0a05763
	
1
								Parent(s):
							
							b6c8684
								
Add more web test cases (#3702)
Browse files### What problem does this PR solve?
Test cases about dataset
### Type of change
- [x] Other (please describe): test cases
---------
Signed-off-by: jinhai <[email protected]>
    	
        api/apps/kb_app.py
    CHANGED
    
    | @@ -29,6 +29,7 @@ from api.db.db_models import File | |
| 29 | 
             
            from api.utils.api_utils import get_json_result
         | 
| 30 | 
             
            from api import settings
         | 
| 31 | 
             
            from rag.nlp import search
         | 
|  | |
| 32 |  | 
| 33 |  | 
| 34 | 
             
            @manager.route('/create', methods=['post'])
         | 
| @@ -36,10 +37,19 @@ from rag.nlp import search | |
| 36 | 
             
            @validate_request("name")
         | 
| 37 | 
             
            def create():
         | 
| 38 | 
             
                req = request.json
         | 
| 39 | 
            -
                 | 
| 40 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 41 | 
             
                    KnowledgebaseService.query,
         | 
| 42 | 
            -
                    name= | 
| 43 | 
             
                    tenant_id=current_user.id,
         | 
| 44 | 
             
                    status=StatusEnum.VALID.value)
         | 
| 45 | 
             
                try:
         | 
| @@ -73,7 +83,8 @@ def update(): | |
| 73 | 
             
                    if not KnowledgebaseService.query(
         | 
| 74 | 
             
                            created_by=current_user.id, id=req["kb_id"]):
         | 
| 75 | 
             
                        return get_json_result(
         | 
| 76 | 
            -
                            data=False, message='Only owner of knowledgebase authorized for this operation.', | 
|  | |
| 77 |  | 
| 78 | 
             
                    e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
         | 
| 79 | 
             
                    if not e:
         | 
| @@ -81,7 +92,8 @@ def update(): | |
| 81 | 
             
                            message="Can't find this knowledgebase!")
         | 
| 82 |  | 
| 83 | 
             
                    if req["name"].lower() != kb.name.lower() \
         | 
| 84 | 
            -
                            and len( | 
|  | |
| 85 | 
             
                        return get_data_error_result(
         | 
| 86 | 
             
                            message="Duplicated knowledgebase name.")
         | 
| 87 |  | 
| @@ -152,10 +164,11 @@ def rm(): | |
| 152 | 
             
                    )
         | 
| 153 | 
             
                try:
         | 
| 154 | 
             
                    kbs = KnowledgebaseService.query(
         | 
| 155 | 
            -
             | 
| 156 | 
             
                    if not kbs:
         | 
| 157 | 
             
                        return get_json_result(
         | 
| 158 | 
            -
                            data=False, message='Only owner of knowledgebase authorized for this operation.', | 
|  | |
| 159 |  | 
| 160 | 
             
                    for doc in DocumentService.query(kb_id=req["kb_id"]):
         | 
| 161 | 
             
                        if not DocumentService.remove_document(doc, kbs[0].tenant_id):
         | 
|  | |
| 29 | 
             
            from api.utils.api_utils import get_json_result
         | 
| 30 | 
             
            from api import settings
         | 
| 31 | 
             
            from rag.nlp import search
         | 
| 32 | 
            +
            from api.constants import DATASET_NAME_LIMIT
         | 
| 33 |  | 
| 34 |  | 
| 35 | 
             
            @manager.route('/create', methods=['post'])
         | 
|  | |
| 37 | 
             
            @validate_request("name")
         | 
| 38 | 
             
            def create():
         | 
| 39 | 
             
                req = request.json
         | 
| 40 | 
            +
                dataset_name = req["name"]
         | 
| 41 | 
            +
                if not isinstance(dataset_name, str):
         | 
| 42 | 
            +
                    return get_data_error_result(message="Dataset name must be string.")
         | 
| 43 | 
            +
                if dataset_name == "":
         | 
| 44 | 
            +
                    return get_data_error_result(message="Dataset name can't be empty.")
         | 
| 45 | 
            +
                if len(dataset_name) >= DATASET_NAME_LIMIT:
         | 
| 46 | 
            +
                    return get_data_error_result(
         | 
| 47 | 
            +
                        message=f"Dataset name length is {len(dataset_name)} which is large than {DATASET_NAME_LIMIT}")
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                dataset_name = dataset_name.strip()
         | 
| 50 | 
            +
                dataset_name = duplicate_name(
         | 
| 51 | 
             
                    KnowledgebaseService.query,
         | 
| 52 | 
            +
                    name=dataset_name,
         | 
| 53 | 
             
                    tenant_id=current_user.id,
         | 
| 54 | 
             
                    status=StatusEnum.VALID.value)
         | 
| 55 | 
             
                try:
         | 
|  | |
| 83 | 
             
                    if not KnowledgebaseService.query(
         | 
| 84 | 
             
                            created_by=current_user.id, id=req["kb_id"]):
         | 
| 85 | 
             
                        return get_json_result(
         | 
| 86 | 
            +
                            data=False, message='Only owner of knowledgebase authorized for this operation.',
         | 
| 87 | 
            +
                            code=settings.RetCode.OPERATING_ERROR)
         | 
| 88 |  | 
| 89 | 
             
                    e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
         | 
| 90 | 
             
                    if not e:
         | 
|  | |
| 92 | 
             
                            message="Can't find this knowledgebase!")
         | 
| 93 |  | 
| 94 | 
             
                    if req["name"].lower() != kb.name.lower() \
         | 
| 95 | 
            +
                            and len(
         | 
| 96 | 
            +
                        KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) > 1:
         | 
| 97 | 
             
                        return get_data_error_result(
         | 
| 98 | 
             
                            message="Duplicated knowledgebase name.")
         | 
| 99 |  | 
|  | |
| 164 | 
             
                    )
         | 
| 165 | 
             
                try:
         | 
| 166 | 
             
                    kbs = KnowledgebaseService.query(
         | 
| 167 | 
            +
                        created_by=current_user.id, id=req["kb_id"])
         | 
| 168 | 
             
                    if not kbs:
         | 
| 169 | 
             
                        return get_json_result(
         | 
| 170 | 
            +
                            data=False, message='Only owner of knowledgebase authorized for this operation.',
         | 
| 171 | 
            +
                            code=settings.RetCode.OPERATING_ERROR)
         | 
| 172 |  | 
| 173 | 
             
                    for doc in DocumentService.query(kb_id=req["kb_id"]):
         | 
| 174 | 
             
                        if not DocumentService.remove_document(doc, kbs[0].tenant_id):
         | 
    	
        api/constants.py
    CHANGED
    
    | @@ -23,3 +23,5 @@ API_VERSION = "v1" | |
| 23 | 
             
            RAG_FLOW_SERVICE_NAME = "ragflow"
         | 
| 24 | 
             
            REQUEST_WAIT_SEC = 2
         | 
| 25 | 
             
            REQUEST_MAX_WAIT_SEC = 300
         | 
|  | |
|  | 
|  | |
| 23 | 
             
            RAG_FLOW_SERVICE_NAME = "ragflow"
         | 
| 24 | 
             
            REQUEST_WAIT_SEC = 2
         | 
| 25 | 
             
            REQUEST_MAX_WAIT_SEC = 300
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            DATASET_NAME_LIMIT = 128
         | 
    	
        rag/utils/infinity_conn.py
    CHANGED
    
    | @@ -310,7 +310,9 @@ class InfinityConnection(DocStoreConnection): | |
| 310 | 
             
                        table_name = f"{indexName}_{knowledgebaseId}"
         | 
| 311 | 
             
                        table_instance = db_instance.get_table(table_name)
         | 
| 312 | 
             
                        kb_res = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
         | 
| 313 | 
            -
                         | 
|  | |
|  | |
| 314 | 
             
                    self.connPool.release_conn(inf_conn)
         | 
| 315 | 
             
                    res = concat_dataframes(df_list, ["id"])
         | 
| 316 | 
             
                    res_fields = self.getFields(res, res.columns)
         | 
|  | |
| 310 | 
             
                        table_name = f"{indexName}_{knowledgebaseId}"
         | 
| 311 | 
             
                        table_instance = db_instance.get_table(table_name)
         | 
| 312 | 
             
                        kb_res = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
         | 
| 313 | 
            +
                        if len(kb_res) != 0 and kb_res.shape[0] > 0:
         | 
| 314 | 
            +
                            df_list.append(kb_res)
         | 
| 315 | 
            +
             | 
| 316 | 
             
                    self.connPool.release_conn(inf_conn)
         | 
| 317 | 
             
                    res = concat_dataframes(df_list, ["id"])
         | 
| 318 | 
             
                    res_fields = self.getFields(res, res.columns)
         | 
    	
        sdk/python/test/test_frontend_api/common.py
    CHANGED
    
    | @@ -3,6 +3,8 @@ import requests | |
| 3 |  | 
| 4 | 
             
            HOST_ADDRESS = os.getenv('HOST_ADDRESS', 'http://127.0.0.1:9380')
         | 
| 5 |  | 
|  | |
|  | |
| 6 | 
             
            def create_dataset(auth, dataset_name):
         | 
| 7 | 
             
                authorization = {"Authorization": auth}
         | 
| 8 | 
             
                url = f"{HOST_ADDRESS}/v1/kb/create"
         | 
| @@ -24,3 +26,9 @@ def rm_dataset(auth, dataset_id): | |
| 24 | 
             
                json = {"kb_id": dataset_id}
         | 
| 25 | 
             
                res = requests.post(url=url, headers=authorization, json=json)
         | 
| 26 | 
             
                return res.json()
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 3 |  | 
| 4 | 
             
            HOST_ADDRESS = os.getenv('HOST_ADDRESS', 'http://127.0.0.1:9380')
         | 
| 5 |  | 
| 6 | 
            +
            DATASET_NAME_LIMIT = 128
         | 
| 7 | 
            +
             | 
| 8 | 
             
            def create_dataset(auth, dataset_name):
         | 
| 9 | 
             
                authorization = {"Authorization": auth}
         | 
| 10 | 
             
                url = f"{HOST_ADDRESS}/v1/kb/create"
         | 
|  | |
| 26 | 
             
                json = {"kb_id": dataset_id}
         | 
| 27 | 
             
                res = requests.post(url=url, headers=authorization, json=json)
         | 
| 28 | 
             
                return res.json()
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            def update_dataset(auth, json_req):
         | 
| 31 | 
            +
                authorization = {"Authorization": auth}
         | 
| 32 | 
            +
                url = f"{HOST_ADDRESS}/v1/kb/update"
         | 
| 33 | 
            +
                res = requests.post(url=url, headers=authorization, json=json_req)
         | 
| 34 | 
            +
                return res.json()
         | 
    	
        sdk/python/test/test_frontend_api/test_dataset.py
    CHANGED
    
    | @@ -1,6 +1,8 @@ | |
| 1 | 
            -
            from common import HOST_ADDRESS, create_dataset, list_dataset, rm_dataset
         | 
| 2 | 
            -
            import  | 
| 3 | 
            -
             | 
|  | |
|  | |
| 4 |  | 
| 5 | 
             
            def test_dataset(get_auth):
         | 
| 6 | 
             
                # create dataset
         | 
| @@ -56,8 +58,76 @@ def test_dataset_1k_dataset(get_auth): | |
| 56 | 
             
                    assert res.get("code") == 0, f"{res.get('message')}"
         | 
| 57 | 
             
                print(f"{len(dataset_list)} datasets are deleted")
         | 
| 58 |  | 
| 59 | 
            -
             | 
| 60 | 
            -
            # create  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 61 | 
             
            # update dataset with different parameters
         | 
| 62 | 
            -
            # create duplicated name dataset
         | 
| 63 | 
            -
            #
         | 
|  | |
| 1 | 
            +
            from common import HOST_ADDRESS, create_dataset, list_dataset, rm_dataset, update_dataset, DATASET_NAME_LIMIT
         | 
| 2 | 
            +
            import re
         | 
| 3 | 
            +
            import pytest
         | 
| 4 | 
            +
            import random
         | 
| 5 | 
            +
            import string
         | 
| 6 |  | 
| 7 | 
             
            def test_dataset(get_auth):
         | 
| 8 | 
             
                # create dataset
         | 
|  | |
| 58 | 
             
                    assert res.get("code") == 0, f"{res.get('message')}"
         | 
| 59 | 
             
                print(f"{len(dataset_list)} datasets are deleted")
         | 
| 60 |  | 
| 61 | 
            +
            def test_duplicated_name_dataset(get_auth):
         | 
| 62 | 
            +
                # create dataset
         | 
| 63 | 
            +
                for i in range(20):
         | 
| 64 | 
            +
                    res = create_dataset(get_auth, "test_create_dataset")
         | 
| 65 | 
            +
                    assert res.get("code") == 0, f"{res.get('message')}"
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                # list dataset
         | 
| 68 | 
            +
                res = list_dataset(get_auth, 1)
         | 
| 69 | 
            +
                data = res.get("data")
         | 
| 70 | 
            +
                dataset_list = []
         | 
| 71 | 
            +
                pattern = r'^test_create_dataset.*'
         | 
| 72 | 
            +
                for item in data:
         | 
| 73 | 
            +
                    dataset_name = item.get("name")
         | 
| 74 | 
            +
                    dataset_id = item.get("id")
         | 
| 75 | 
            +
                    dataset_list.append(dataset_id)
         | 
| 76 | 
            +
                    match = re.match(pattern, dataset_name)
         | 
| 77 | 
            +
                    assert match != None
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                for dataset_id in dataset_list:
         | 
| 80 | 
            +
                    res = rm_dataset(get_auth, dataset_id)
         | 
| 81 | 
            +
                    assert res.get("code") == 0, f"{res.get('message')}"
         | 
| 82 | 
            +
                print(f"{len(dataset_list)} datasets are deleted")
         | 
| 83 | 
            +
             | 
| 84 | 
            +
            def test_invalid_name_dataset(get_auth):
         | 
| 85 | 
            +
                # create dataset
         | 
| 86 | 
            +
                # with pytest.raises(Exception) as e:
         | 
| 87 | 
            +
                res = create_dataset(get_auth, 0)
         | 
| 88 | 
            +
                assert res['code'] == 102
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                res = create_dataset(get_auth, "")
         | 
| 91 | 
            +
                assert res['code'] == 102
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                long_string = ""
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                while len(long_string) <= DATASET_NAME_LIMIT:
         | 
| 96 | 
            +
                    long_string += random.choice(string.ascii_letters + string.digits)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                res = create_dataset(get_auth, long_string)
         | 
| 99 | 
            +
                assert res['code'] == 102
         | 
| 100 | 
            +
                print(res)
         | 
| 101 | 
            +
             | 
| 102 | 
            +
            def test_update_different_params_dataset(get_auth):
         | 
| 103 | 
            +
                # create dataset
         | 
| 104 | 
            +
                res = create_dataset(get_auth, "test_create_dataset")
         | 
| 105 | 
            +
                assert res.get("code") == 0, f"{res.get('message')}"
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                # list dataset
         | 
| 108 | 
            +
                page_number = 1
         | 
| 109 | 
            +
                dataset_list = []
         | 
| 110 | 
            +
                while True:
         | 
| 111 | 
            +
                    res = list_dataset(get_auth, page_number)
         | 
| 112 | 
            +
                    data = res.get("data")
         | 
| 113 | 
            +
                    for item in data:
         | 
| 114 | 
            +
                        dataset_id = item.get("id")
         | 
| 115 | 
            +
                        dataset_list.append(dataset_id)
         | 
| 116 | 
            +
                    if len(dataset_list) < page_number * 150:
         | 
| 117 | 
            +
                        break
         | 
| 118 | 
            +
                    page_number += 1
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                print(f"found {len(dataset_list)} datasets")
         | 
| 121 | 
            +
                dataset_id = dataset_list[0]
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                json_req = {"kb_id": dataset_id, "name": "test_update_dataset", "description": "test", "permission": "me", "parser_id": "presentation"}
         | 
| 124 | 
            +
                res = update_dataset(get_auth, json_req)
         | 
| 125 | 
            +
                assert res.get("code") == 0, f"{res.get('message')}"
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                # delete dataset
         | 
| 128 | 
            +
                for dataset_id in dataset_list:
         | 
| 129 | 
            +
                    res = rm_dataset(get_auth, dataset_id)
         | 
| 130 | 
            +
                    assert res.get("code") == 0, f"{res.get('message')}"
         | 
| 131 | 
            +
                print(f"{len(dataset_list)} datasets are deleted")
         | 
| 132 | 
            +
             | 
| 133 | 
             
            # update dataset with different parameters
         | 
|  | |
|  | 
    	
        printEnvironment.sh → show_env.sh
    RENAMED
    
    | @@ -15,7 +15,7 @@ get_distro_info() { | |
| 15 | 
             
                echo "$distro_id $distro_version (Kernel version: $kernel_version)"
         | 
| 16 | 
             
            }
         | 
| 17 |  | 
| 18 | 
            -
            # get Git  | 
| 19 | 
             
            git_repo_name=''
         | 
| 20 | 
             
            if git rev-parse --is-inside-work-tree > /dev/null 2>&1; then
         | 
| 21 | 
             
                git_repo_name=$(basename "$(git rev-parse --show-toplevel)")
         | 
| @@ -48,8 +48,8 @@ else | |
| 48 | 
             
                python_version="Python not installed"
         | 
| 49 | 
             
            fi
         | 
| 50 |  | 
| 51 | 
            -
            # Print all  | 
| 52 | 
            -
            echo "Current  | 
| 53 |  | 
| 54 | 
             
            # get Commit ID
         | 
| 55 | 
             
            git_version=$(git log -1 --pretty=format:'%h')
         | 
|  | |
| 15 | 
             
                echo "$distro_id $distro_version (Kernel version: $kernel_version)"
         | 
| 16 | 
             
            }
         | 
| 17 |  | 
| 18 | 
            +
            # get Git repository name
         | 
| 19 | 
             
            git_repo_name=''
         | 
| 20 | 
             
            if git rev-parse --is-inside-work-tree > /dev/null 2>&1; then
         | 
| 21 | 
             
                git_repo_name=$(basename "$(git rev-parse --show-toplevel)")
         | 
|  | |
| 48 | 
             
                python_version="Python not installed"
         | 
| 49 | 
             
            fi
         | 
| 50 |  | 
| 51 | 
            +
            # Print all information
         | 
| 52 | 
            +
            echo "Current Repository: $git_repo_name"
         | 
| 53 |  | 
| 54 | 
             
            # get Commit ID
         | 
| 55 | 
             
            git_version=$(git log -1 --pretty=format:'%h')
         |