Upload folder using huggingface_hub
- collections_operators.py +5 -1
- dataset.py +1 -0
- inference.py +21 -1
- metric.py +1 -0
- metrics.py +73 -0
- operators.py +1 -1
- schema.py +3 -0
- serializers.py +31 -0
- templates.py +4 -0
- tool_calling.py +119 -0
- type_utils.py +5 -0
- types.py +17 -1
- version.py +1 -1
collections_operators.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Any, Generator, List, Optional
+from typing import Any, Dict, Generator, List, Optional
 
 from .dict_utils import dict_get, dict_set
 from .operators import FieldOperator, StreamOperator
@@ -12,6 +12,10 @@ class Dictify(FieldOperator):
     def process_value(self, tup: Any) -> Any:
         return dict(zip(self.with_keys, tup))
 
+class DictToTuplesList(FieldOperator):
+
+    def process_value(self, dic: Dict) -> Any:
+        return list(dic.items())
 
 class Wrap(FieldOperator):
     inside: str

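For reference, a tiny sketch (not part of the diff) of what the new DictToTuplesList operator yields; the input dict is invented for illustration, and process_value is simply list(dic.items()):

dic = {"temperature": 0.2, "unit": "celsius"}  # hypothetical field value
print(list(dic.items()))  # [('temperature', 0.2), ('unit', 'celsius')]
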
dataset.py
CHANGED
@@ -68,6 +68,7 @@ from .system_prompts import __file__ as _
 from .task import __file__ as _
 from .templates import __file__ as _
 from .text_utils import __file__ as _
+from .tool_calling import __file__ as _
 from .type_utils import __file__ as _
 from .types import __file__ as _
 from .utils import __file__ as _

inference.py
CHANGED
@@ -342,6 +342,14 @@ class InferenceEngine(Artifact):
             }
         ]
 
+    def to_tools(self, instance):
+        task_data = instance.get("task_data")
+        if isinstance(task_data, str):
+            task_data = json.loads(task_data)
+        if "__tools__" in task_data:
+            return task_data["__tools__"]
+        return None
+
 
 class LogProbInferenceEngine(abc.ABC, Artifact):
     """Abstract base class for inference with log probs."""
@@ -3164,12 +3172,14 @@ class LiteLLMInferenceEngine(
         # Introduce a slight delay to prevent burstiness
         await asyncio.sleep(0.01)
         messages = self.to_messages(instance)
+        tools = self.to_tools(instance)
         kwargs = self.to_dict([StandardAPIParamsMixin])
         kwargs = {k: v for k, v in kwargs.items() if v is not None}
         del kwargs["credentials"]
         try:
             response = await self._completion(
                 messages=messages,
+                tools=tools,
                 max_retries=self.max_retries,
                 drop_params=False,
                 **self.credentials,
@@ -3181,8 +3191,17 @@ class LiteLLMInferenceEngine(
             ) from e
 
         usage = response.get("usage", {})
+
+        if tools is None:
+            prediction = response["choices"][0]["message"]["content"]
+        else:
+            try:
+                func_call = response["choices"][0]["message"]["tool_calls"][0]["function"]
+                prediction = f'{{"name": "{func_call.name}", "arguments": {func_call.arguments}}}'
+            except:
+                prediction = response["choices"][0]["message"]["content"] or ""
         return TextGenerationInferenceOutput(
-            prediction=
+            prediction=prediction,
             input_tokens=usage.get("prompt_tokens"),
             output_tokens=usage.get("completion_tokens"),
             model_name=response.get("model", self.model),
@@ -3267,6 +3286,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
         "watsonx-sdk": {  # checked from ibm_watsonx_ai.APIClient().foundation_models.ChatModels
             "granite-20b-code-instruct": "ibm/granite-20b-code-instruct",
             "granite-3-2-8b-instruct": "ibm/granite-3-2-8b-instruct",
+            "granite-3-3-8b-instruct": "ibm/granite-3-3-8b-instruct",
             "granite-3-2b-instruct": "ibm/granite-3-2b-instruct",
             "granite-3-8b-instruct": "ibm/granite-3-8b-instruct",
             "granite-34b-code-instruct": "ibm/granite-34b-code-instruct",

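A minimal sketch (not part of the diff) of the tool-call prediction string the LiteLLM path now builds; the response is mocked with SimpleNamespace, and the attribute names .name/.arguments are assumed to match what the real response object exposes:

from types import SimpleNamespace

# Hypothetical stand-in for response["choices"][0]["message"]["tool_calls"][0]["function"]
func_call = SimpleNamespace(name="get_weather", arguments='{"location": "Paris"}')
prediction = f'{{"name": "{func_call.name}", "arguments": {func_call.arguments}}}'
print(prediction)  # {"name": "get_weather", "arguments": {"location": "Paris"}}
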
metric.py
CHANGED
@@ -65,6 +65,7 @@ from .system_prompts import __file__ as _
 from .task import __file__ as _
 from .templates import __file__ as _
 from .text_utils import __file__ as _
+from .tool_calling import __file__ as _
 from .type_utils import __file__ as _
 from .types import __file__ as _
 from .utils import __file__ as _

metrics.py
CHANGED
@@ -63,7 +63,9 @@ from .operators import ArtifactFetcherMixin, Copy, Set
 from .random_utils import get_seed
 from .settings_utils import get_settings
 from .stream import MultiStream, Stream
+from .tool_calling import convert_chat_api_format_to_tool
 from .type_utils import Type, isoftype, parse_type_string, to_type_string
+from .types import ToolCall
 from .utils import deep_copy, recursive_copy, retry_connection_with_exponential_backoff
 
 logger = get_logger()
@@ -786,6 +788,77 @@ class F1Fast(MapReduceMetric[str, Tuple[int, int]]):
 
         return result
 
+class ToolCallingMetric(ReductionInstanceMetric[str, Dict[str, float]]):
+    main_score = "exact_match"
+    reduction = MeanReduction()
+    prediction_type = ToolCall
+
+    def map(
+        self, prediction: ToolCall, references: List[ToolCall], task_data: Dict[str, Any]
+    ) -> Dict[str, float]:
+
+        exact_match = float(
+            str(prediction) in [str(reference) for reference in references]
+        )
+
+        tool_choice = float(
+            str(prediction["name"]) in [str(reference["name"]) for reference in references]
+        )
+
+        parameter_choice = 0.0
+        for reference in references:
+            if len(prediction["arguments"]) > 0:
+                score = len(set(prediction["arguments"]).intersection(set(reference["arguments"]))) / len(set(prediction["arguments"]))
+            else:
+                score = 1.0
+            if score > parameter_choice:
+                parameter_choice = score
+
+        parameter_values = 0.0
+        for reference in references:
+            value_matches = 0
+            for key, val in prediction["arguments"].items():
+                try:
+                    if val in reference["arguments"][key] or reference["arguments"][key] in val:
+                        value_matches += 1
+                except:
+                    pass
+
+            if len(prediction["arguments"]) > 0:
+                score = value_matches / len(prediction["arguments"])
+            else:
+                score = 1.0
+            if score > parameter_values:
+                parameter_values = score
+
+        for tool in task_data["__tools__"]:
+            tool = convert_chat_api_format_to_tool(tool)
+            tool_params_types = {}
+            for param in tool["parameters"]:
+                tool_params_types[param["name"]] = param["type"]
+            correct_parameters_types = 0
+            for key, value in prediction["arguments"].items():
+                typing_type = tool_params_types.get(key, Any)
+                if isoftype(value, typing_type):
+                    correct_parameters_types += 1
+            if len(prediction["arguments"]) > 0:
+                parameters_types = correct_parameters_types / len(prediction["arguments"])
+            else:
+                parameters_types = 1.0
+
+        return {
+            self.main_score: exact_match,
+            "tool_choice": tool_choice,
+            "parameter_choice": parameter_choice,
+            "parameters_types": parameters_types,
+            "parameter_values": parameter_values
+        }
+
 
 class MetricWithConfidenceInterval(Metric):
     # The number of resamples used to estimate the confidence intervals of this metric.

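To make the new sub-scores concrete, here is a hand computation in plain Python that mirrors the arithmetic in ToolCallingMetric.map for a single invented prediction/reference pair (the real metric iterates over all references and keeps the best score per component):

prediction = {"name": "get_weather", "arguments": {"location": "Paris", "unit": "celsius"}}
reference = {"name": "get_weather", "arguments": {"location": "Paris"}}

exact_match = float(str(prediction) == str(reference))        # 0.0
tool_choice = float(prediction["name"] == reference["name"])  # 1.0
# share of predicted argument names that also appear in the reference
parameter_choice = len(set(prediction["arguments"]) & set(reference["arguments"])) / len(prediction["arguments"])  # 0.5
# share of predicted argument values that substring-match the reference value for the same key
value_matches = sum(
    1
    for key, val in prediction["arguments"].items()
    if key in reference["arguments"]
    and (val in reference["arguments"][key] or reference["arguments"][key] in val)
)
parameter_values = value_matches / len(prediction["arguments"])  # 0.5
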
operators.py
CHANGED
@@ -930,7 +930,7 @@ class Cast(FieldOperator):
     failure_default: Optional[Any] = "__UNDEFINED__"
 
     def prepare(self):
-        self.types = {"int": int, "float": float, "str": str, "bool": bool}
+        self.types = {"int": int, "float": float, "str": str, "bool": bool, "tuple": tuple}
 
     def process_value(self, value):
         try:

schema.py
CHANGED
@@ -141,6 +141,9 @@ class FinalizeDataset(InstanceOperatorValidator):
         }
         if use_reference_fields:
             task_data = {**task_data, **instance["reference_fields"]}
+
+        if "__tools__" in instance:
+            task_data["__tools__"] = instance["__tools__"]
         return task_data
 
     def serialize_instance_fields(self, instance, task_data):

serializers.py
CHANGED
@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Union
 from .dataclass import AbstractField, Field
 from .operators import InstanceFieldOperator
 from .settings_utils import get_constants
+from .tool_calling import convert_to_chat_api_format
 from .type_utils import isoftype, to_type_string
 from .types import (
     Dialog,
@@ -16,6 +17,8 @@ from .types import (
     Number,
     SQLDatabase,
     Table,
+    Tool,
+    ToolCall,
     Video,
 )
 
@@ -161,15 +164,43 @@ class MultiDocumentSerializer(DocumentSerializer):
         return "\n\n".join(documents)
 
 
+
+class ToolsSerializer(SingleTypeSerializer):
+
+    serialized_type = List[Tool]
+    _requirements_list: List[str] = ["pydantic"]
+
+    def serialize(self, value: List[Tool], instance: Dict[str, Any]) -> str:
+        if "__tools__" not in instance:
+            instance["__tools__"] = []
+        tool = []
+        for tool in value:
+            chat_api_tool = convert_to_chat_api_format(tool=tool)
+            instance["__tools__"].append(
+                chat_api_tool
+            )
+            tool["parameters"] = chat_api_tool["function"]["parameters"]
+        return json.dumps(instance["__tools__"], indent=4)
+
+class ToolCallSerializer(SingleTypeSerializer):
+
+    serialized_type = ToolCall
+    _requirements_list: List[str] = ["pydantic"]
+
+    def serialize(self, value: ToolCall, instance: Dict[str, Any]) -> str:
+        return json.dumps(value)
+
 class MultiTypeSerializer(Serializer):
     serializers: List[SingleTypeSerializer] = Field(
         default_factory=lambda: [
             DocumentSerializer(),
+            ToolCallSerializer(),
             DialogSerializer(),
             MultiDocumentSerializer(),
             ImageSerializer(),
             VideoSerializer(),
             TableSerializer(),
+            ToolsSerializer(),
            DialogSerializer(),
         ]
     )

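A quick sketch of the serialized forms: ToolCallSerializer emits a plain json.dumps of the ToolCall dict, while ToolsSerializer collects the Chat-API form of each tool into instance["__tools__"] and returns it pretty-printed. The example value below is invented:

import json

tool_call = {"name": "get_weather", "arguments": {"location": "Paris"}}  # example ToolCall
print(json.dumps(tool_call))
# {"name": "get_weather", "arguments": {"location": "Paris"}}
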
templates.py
CHANGED
@@ -19,6 +19,8 @@ from .serializers import (
     Serializer,
     SQLDatabaseAsSchemaSerializer,
     TableSerializer,
+    ToolCallSerializer,
+    ToolsSerializer,
     VideoSerializer,
 )
 from .settings_utils import get_constants
@@ -63,6 +65,8 @@ class Template(InstanceOperator):
             ImageSerializer(),
             VideoSerializer(),
             TableSerializer(),
+            ToolCallSerializer(),
+            ToolsSerializer(),
             DialogSerializer(),
             ListSerializer(),
             SQLDatabaseAsSchemaSerializer(),

tool_calling.py
ADDED
@@ -0,0 +1,119 @@
+from typing import Any, Dict, List, Type
+
+from .operators import FieldOperator
+from .types import Parameter, Tool
+
+
+def convert_to_chat_api_format(tool: Tool) -> Dict[str, Any]:
+
+    from pydantic import create_model
+
+    field_definitions = {}
+    for param in tool["parameters"]:
+        param_name = param["name"]
+        param_type = param.get("type", Any)
+        field_definitions[param_name] = (param_type, ...)  # ... means required in Pydantic
+
+    model = create_model(f"{tool['name']}Params", **field_definitions)
+
+    schema = model.model_json_schema()
+
+    return {
+        "type": "function",
+        "function": {
+            "name": tool["name"],
+            "description": tool["description"],
+            "parameters": schema
+        }
+    }
+
+
+def convert_chat_api_format_to_tool(chat_api_tool: Dict[str, Any]) -> Tool:
+    """Convert a Chat API formatted tool back to the original Tool structure.
+
+    Args:
+        chat_api_tool: A dictionary representing a tool in Chat API format
+
+    Returns:
+        A Tool dictionary with name, description, and parameters
+    """
+    # Extract function information
+    function_info = chat_api_tool.get("function", {})
+    name = function_info.get("name", chat_api_tool.get("name", ""))
+    description = function_info.get("description", chat_api_tool.get("description", ""))
+
+    # Extract parameters from schema
+    parameters: List[Parameter] = []
+    schema = function_info.get("parameters", chat_api_tool.get("parameters", ""))
+    properties = schema.get("properties", {})
+
+    for param_name, param_schema in properties.items():
+        # Map JSON schema type to Python type
+        param_type = json_schema_to_python_type(param_schema)
+
+        parameter: Parameter = {
+            "name": param_name,
+            "type": param_type
+        }
+        parameters.append(parameter)
+
+    # Construct and return the Tool
+    tool: Tool = {
+        "name": name,
+        "description": description,
+        "parameters": parameters
+    }
+
+    return tool
+
+def json_schema_to_python_type(schema: Dict[str, Any]) -> Type:
+    """Convert JSON schema type to Python type."""
+    from typing import Any, Dict, List, Union
+
+    schema_type = schema.get("type")
+
+    # Handle simple types
+    simple_types = {
+        "string": str,
+        "integer": int,
+        "number": float,
+        "boolean": bool,
+        "null": type(None)
+    }
+
+    if schema_type in simple_types:
+        return simple_types[schema_type]
+
+    # Handle arrays
+    if schema_type == "array":
+        items = schema.get("items", {})
+        if not items:
+            return List[Any]
+
+        item_type = json_schema_to_python_type(items)
+        return List[item_type]
+
+    # Handle objects
+    if schema_type == "object":
+        return Dict[str, Any]
+
+    # Handle unions with anyOf/oneOf
+    if "anyOf" in schema or "oneOf" in schema:
+        union_schemas = schema.get("anyOf", []) or schema.get("oneOf", [])
+        union_types = [json_schema_to_python_type(s) for s in union_schemas]
+        # Use Union for Python 3.9+ or create Union using typing module
+        return Union[tuple(union_types)] if union_types else Any
+
+    # Handle references (simplified)
+    if "$ref" in schema:
+        # In a real implementation, you'd resolve references
+        return Any
+
+    # Default to Any for unrecognized schema types
+    return Any
+
+
+class ToTool(FieldOperator):
+
+    def process_value(self, value: Dict[str, Any]) -> Tool:
+        return convert_chat_api_format_to_tool(value)

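A sketch of the round trip the two converters perform. It assumes pydantic v2 (model_json_schema) is installed and that the module is importable as unitxt.tool_calling; the tool definition is invented, and the exact schema keys (title, required, ...) depend on the pydantic version:

from unitxt.tool_calling import convert_chat_api_format_to_tool, convert_to_chat_api_format

tool = {
    "name": "get_weather",
    "description": "Look up current weather",
    "parameters": [{"name": "location", "type": str}],
}
chat_api_tool = convert_to_chat_api_format(tool)
# chat_api_tool["function"]["parameters"] is a JSON schema, roughly:
# {"properties": {"location": {"title": "Location", "type": "string"}},
#  "required": ["location"], "title": "get_weatherParams", "type": "object"}
print(convert_chat_api_format_to_tool(chat_api_tool))
# {"name": "get_weather", "description": "Look up current weather",
#  "parameters": [{"name": "location", "type": <class 'str'>}]}
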
type_utils.py
CHANGED
@@ -69,6 +69,8 @@ def is_typed_dict(object):
 
 def is_type(object):
     """Checks if the provided object is a type, including generics, Literal, TypedDict, and NewType."""
+    if object is typing.Type:
+        return True
     return (
         isinstance(object, (type, *_generics_types))
         or is_new_type(object)
@@ -487,6 +489,9 @@ def isoftype(object, typing_type):
     if not is_type(typing_type):
         raise UnsupportedTypeError(typing_type)
 
+    if typing_type is typing.Type:
+        return is_type(object)
+
     if is_new_type(typing_type):
         typing_type = typing_type.__supertype__
 

types.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Literal, NewType, Optional, TypedDict, Union
+from typing import Any, Dict, List, Literal, NewType, Optional, Type, TypedDict, Union
 
 from .type_utils import register_type
 
@@ -51,6 +51,18 @@ class SQLDatabase(TypedDict):
     dbms: Optional[str]
     data: Optional[Dict[str, Dict]]
 
+class Parameter(TypedDict):
+    name: str
+    type: Optional[Type]  # Using actual Python type objects
+
+class Tool(TypedDict):
+    name: str
+    description: str
+    parameters: List[Parameter]
+
+class ToolCall(TypedDict):
+    name: str
+    arguments: Dict[str, Any]
 
 register_type(Text)
 register_type(Number)
@@ -64,3 +76,7 @@ register_type(Document)
 register_type(MultiDocument)
 register_type(RagResponse)
 register_type(SQLDatabase)
+register_type(Parameter)
+register_type(Tool)
+register_type(ToolCall)
+

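A minimal usage sketch for the three new TypedDicts; the import path unitxt.types is an assumption, and the values are invented:

from unitxt.types import Parameter, Tool, ToolCall

weather_tool: Tool = {
    "name": "get_weather",
    "description": "Look up current weather",
    "parameters": [Parameter(name="location", type=str)],
}
call: ToolCall = {"name": "get_weather", "arguments": {"location": "Paris"}}
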
version.py
CHANGED
@@ -1 +1 @@
-version = "1.22.
+version = "1.22.4"