File size: 4,417 Bytes
067283f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import pytest
from comfy_execution.validation import validate_node_input


def test_exact_match():
    """Identical type strings validate, regardless of member order."""
    cases = [
        ("STRING", "STRING"),
        ("STRING,INT", "STRING,INT"),
        ("INT,STRING", "STRING,INT"),  # member order must not matter
    ]
    for received, declared in cases:
        assert validate_node_input(received, declared)


def test_strict_mode():
    """With ``strict=True`` every received type must appear in the input type."""
    accepted = (
        ("STRING", "STRING,INT"),
        ("INT", "STRING,INT"),
        ("STRING,INT", "STRING,INT,BOOLEAN"),
    )
    rejected = (
        ("STRING,INT", "STRING"),
        ("STRING,BOOLEAN", "STRING"),
        ("INT,BOOLEAN", "STRING,INT"),
    )

    # Received members form a subset of the input members -> valid.
    for received, declared in accepted:
        assert validate_node_input(received, declared, strict=True)

    # At least one received member is missing from the input -> invalid.
    for received, declared in rejected:
        assert not validate_node_input(received, declared, strict=True)


def test_non_strict_mode():
    """Default (non-strict) validation only needs the type sets to overlap."""
    overlapping = (
        ("STRING,BOOLEAN", "STRING,INT"),
        ("STRING,INT", "INT,BOOLEAN"),
        ("STRING", "STRING,INT"),
    )
    disjoint = (
        ("BOOLEAN", "STRING,INT"),
        ("FLOAT", "STRING,INT"),
        ("FLOAT,BOOLEAN", "STRING,INT"),
    )

    # Any shared member is enough.
    for received, declared in overlapping:
        assert validate_node_input(received, declared)

    # No shared member at all -> invalid.
    for received, declared in disjoint:
        assert not validate_node_input(received, declared)


def test_whitespace_handling():
    """Whitespace around comma-separated members is ignored on either side."""
    pairs = [
        ("STRING, INT", "STRING,INT"),
        ("STRING,INT", "STRING, INT"),
        (" STRING , INT ", "STRING,INT"),
        ("STRING,INT", " STRING , INT "),
    ]
    for received, declared in pairs:
        assert validate_node_input(received, declared)


def test_empty_strings():
    """An empty type string matches another empty string but not a real type."""
    empty = ""
    assert validate_node_input(empty, empty)
    for received, declared in (("STRING", empty), (empty, "STRING")):
        assert not validate_node_input(received, declared)


def test_single_vs_multiple():
    """Compare a single type with a multi-type union, in both directions."""
    union = "STRING,INT,BOOLEAN"
    assert validate_node_input("STRING", union)
    # Non-strict: overlap with the single input type is sufficient...
    assert validate_node_input(union, "STRING", strict=False)
    # ...strict: every received member must be accepted, so this fails.
    assert not validate_node_input(union, "STRING", strict=True)


def test_non_string():
    """A non-string type marker matches itself but not a different object."""
    first, second = object(), object()
    assert validate_node_input(first, first)
    assert not validate_node_input(first, second)


class NotEqualsOverrideTest(str):
    """``str`` subclass whose ``__ne__`` encodes custom type-matching rules.

    Only ``__ne__`` is overridden; equality and hashing stay those of ``str``.
    Supported values:

    * ``"*"`` -- compares "not unequal" (``False``) to everything.
    * ``"LONGER_THAN_2"`` -- unequal only to values whose ``len()`` is <= 2.

    Any other value raises ``TypeError`` when compared with ``!=`` against a
    non-wildcard operand.
    """

    def __ne__(self, value: object) -> bool:
        # A wildcard on either side never counts as "not equal".
        wildcard = self == "*" or value == "*"
        if wildcard:
            return False
        # Only the LONGER_THAN_2 rule is implemented; anything else is misuse.
        # (Use ``==`` here: ``!=`` on self would recurse into this method.)
        if not self == "LONGER_THAN_2":
            raise TypeError("This is a class for unit tests only.")
        # Values longer than 2 are treated as matching (i.e. not unequal).
        return len(value) <= 2


def test_ne_override():
    """A ``"*"`` type with an ``__ne__`` override validates against every value tried.

    The wildcard is bound to ``any_type`` rather than ``any`` so the builtin
    ``any()`` is not shadowed inside this test.
    """
    any_type = NotEqualsOverrideTest("*")
    invalid_type = "INVALID_TYPE"
    obj = object()
    assert validate_node_input(any_type, any_type)
    assert validate_node_input(any_type, invalid_type)
    assert validate_node_input(any_type, obj)
    assert validate_node_input(any_type, {})
    assert validate_node_input(any_type, [])
    assert validate_node_input(any_type, [1, 2, 3])


def test_ne_custom_override():
    """A custom ``__ne__`` rule (``"LONGER_THAN_2"``) drives validation."""
    special = NotEqualsOverrideTest("LONGER_THAN_2")

    # Accepted: the "*" wildcard and values with length > 2
    # (``special`` itself qualifies by length).
    for candidate in (special, "*", "INVALID_TYPE", [1, 2, 3]):
        assert validate_node_input(special, candidate)

    # Rejected: short values (length <= 2).
    for candidate in ([1, 2], "TY"):
        assert not validate_node_input(special, candidate)


@pytest.mark.parametrize(
    ("received", "input_type", "strict", "expected"),
    [
        ("STRING", "STRING", False, True),
        ("STRING,INT", "STRING,INT", False, True),
        ("STRING", "STRING,INT", True, True),
        ("STRING,INT", "STRING", True, False),
        ("BOOLEAN", "STRING,INT", False, False),
        ("STRING,BOOLEAN", "STRING,INT", False, True),
    ],
)
def test_parametrized_cases(received, input_type, strict, expected):
    """Table-driven checks covering exact, strict and non-strict matching.

    Each row is ``(received, input_type, strict, expected)``. ``strict`` is
    passed by keyword, matching how the other tests in this module call
    ``validate_node_input``, and avoids a bare positional boolean.
    """
    assert validate_node_input(received, input_type, strict=strict) == expected