Elron committed on
Commit
66630b0
·
verified ·
1 Parent(s): 43c8216

Upload folder using huggingface_hub

Browse files
Files changed (13) hide show
  1. collections_operators.py +5 -1
  2. dataset.py +1 -0
  3. inference.py +21 -1
  4. metric.py +1 -0
  5. metrics.py +73 -0
  6. operators.py +1 -1
  7. schema.py +3 -0
  8. serializers.py +31 -0
  9. templates.py +4 -0
  10. tool_calling.py +119 -0
  11. type_utils.py +5 -0
  12. types.py +17 -1
  13. version.py +1 -1
collections_operators.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Any, Generator, List, Optional
2
 
3
  from .dict_utils import dict_get, dict_set
4
  from .operators import FieldOperator, StreamOperator
@@ -12,6 +12,10 @@ class Dictify(FieldOperator):
12
  def process_value(self, tup: Any) -> Any:
13
  return dict(zip(self.with_keys, tup))
14
 
 
 
 
 
15
 
16
  class Wrap(FieldOperator):
17
  inside: str
 
1
+ from typing import Any, Dict, Generator, List, Optional
2
 
3
  from .dict_utils import dict_get, dict_set
4
  from .operators import FieldOperator, StreamOperator
 
12
  def process_value(self, tup: Any) -> Any:
13
  return dict(zip(self.with_keys, tup))
14
 
15
class DictToTuplesList(FieldOperator):
    """Field operator that turns a dictionary into a list of ``(key, value)`` pairs."""

    def process_value(self, dic: Dict) -> Any:
        """Return the items of ``dic`` as a list, in the dictionary's iteration order."""
        pairs = [*dic.items()]
        return pairs
19
 
20
  class Wrap(FieldOperator):
21
  inside: str
dataset.py CHANGED
@@ -68,6 +68,7 @@ from .system_prompts import __file__ as _
68
  from .task import __file__ as _
69
  from .templates import __file__ as _
70
  from .text_utils import __file__ as _
 
71
  from .type_utils import __file__ as _
72
  from .types import __file__ as _
73
  from .utils import __file__ as _
 
68
  from .task import __file__ as _
69
  from .templates import __file__ as _
70
  from .text_utils import __file__ as _
71
+ from .tool_calling import __file__ as _
72
  from .type_utils import __file__ as _
73
  from .types import __file__ as _
74
  from .utils import __file__ as _
inference.py CHANGED
@@ -342,6 +342,14 @@ class InferenceEngine(Artifact):
342
  }
343
  ]
344
 
 
 
 
 
 
 
 
 
345
 
346
  class LogProbInferenceEngine(abc.ABC, Artifact):
347
  """Abstract base class for inference with log probs."""
@@ -3164,12 +3172,14 @@ class LiteLLMInferenceEngine(
3164
  # Introduce a slight delay to prevent burstiness
3165
  await asyncio.sleep(0.01)
3166
  messages = self.to_messages(instance)
 
3167
  kwargs = self.to_dict([StandardAPIParamsMixin])
3168
  kwargs = {k: v for k, v in kwargs.items() if v is not None}
3169
  del kwargs["credentials"]
3170
  try:
3171
  response = await self._completion(
3172
  messages=messages,
 
3173
  max_retries=self.max_retries,
3174
  drop_params=False,
3175
  **self.credentials,
@@ -3181,8 +3191,17 @@ class LiteLLMInferenceEngine(
3181
  ) from e
3182
 
3183
  usage = response.get("usage", {})
 
 
 
 
 
 
 
 
 
3184
  return TextGenerationInferenceOutput(
3185
- prediction=response["choices"][0]["message"]["content"],
3186
  input_tokens=usage.get("prompt_tokens"),
3187
  output_tokens=usage.get("completion_tokens"),
3188
  model_name=response.get("model", self.model),
@@ -3267,6 +3286,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
3267
  "watsonx-sdk": { # checked from ibm_watsonx_ai.APIClient().foundation_models.ChatModels
3268
  "granite-20b-code-instruct": "ibm/granite-20b-code-instruct",
3269
  "granite-3-2-8b-instruct": "ibm/granite-3-2-8b-instruct",
 
3270
  "granite-3-2b-instruct": "ibm/granite-3-2b-instruct",
3271
  "granite-3-8b-instruct": "ibm/granite-3-8b-instruct",
3272
  "granite-34b-code-instruct": "ibm/granite-34b-code-instruct",
 
342
  }
343
  ]
344
 
345
def to_tools(self, instance):
    """Return the tool definitions attached to ``instance``, if any.

    Tools are read from the ``"__tools__"`` entry of ``instance["task_data"]``.
    ``task_data`` may be a mapping or its JSON-encoded string form; a string
    is decoded with ``json.loads`` before the lookup.

    Args:
        instance: Mapping that may contain a ``"task_data"`` entry.

    Returns:
        The value stored under ``"__tools__"``, or ``None`` when ``task_data``
        is absent/``None`` or does not contain that key.
    """
    task_data = instance.get("task_data")
    if isinstance(task_data, str):
        task_data = json.loads(task_data)
    # An instance without task_data previously crashed with
    # "TypeError: argument of type 'NoneType' is not iterable".
    if task_data is None or "__tools__" not in task_data:
        return None
    return task_data["__tools__"]
352
+
353
 
354
  class LogProbInferenceEngine(abc.ABC, Artifact):
355
  """Abstract base class for inference with log probs."""
 
3172
  # Introduce a slight delay to prevent burstiness
3173
  await asyncio.sleep(0.01)
3174
  messages = self.to_messages(instance)
3175
+ tools = self.to_tools(instance)
3176
  kwargs = self.to_dict([StandardAPIParamsMixin])
3177
  kwargs = {k: v for k, v in kwargs.items() if v is not None}
3178
  del kwargs["credentials"]
3179
  try:
3180
  response = await self._completion(
3181
  messages=messages,
3182
+ tools=tools,
3183
  max_retries=self.max_retries,
3184
  drop_params=False,
3185
  **self.credentials,
 
3191
  ) from e
3192
 
3193
  usage = response.get("usage", {})
3194
+
3195
+ if tools is None:
3196
+ prediction = response["choices"][0]["message"]["content"]
3197
+ else:
3198
+ try:
3199
+ func_call = response["choices"][0]["message"]["tool_calls"][0]["function"]
3200
+ prediction = f'{{"name": "{func_call.name}", "arguments": {func_call.arguments}}}'
3201
+ except:
3202
+ prediction = response["choices"][0]["message"]["content"] or ""
3203
  return TextGenerationInferenceOutput(
3204
+ prediction=prediction,
3205
  input_tokens=usage.get("prompt_tokens"),
3206
  output_tokens=usage.get("completion_tokens"),
3207
  model_name=response.get("model", self.model),
 
3286
  "watsonx-sdk": { # checked from ibm_watsonx_ai.APIClient().foundation_models.ChatModels
3287
  "granite-20b-code-instruct": "ibm/granite-20b-code-instruct",
3288
  "granite-3-2-8b-instruct": "ibm/granite-3-2-8b-instruct",
3289
+ "granite-3-3-8b-instruct": "ibm/granite-3-3-8b-instruct",
3290
  "granite-3-2b-instruct": "ibm/granite-3-2b-instruct",
3291
  "granite-3-8b-instruct": "ibm/granite-3-8b-instruct",
3292
  "granite-34b-code-instruct": "ibm/granite-34b-code-instruct",
metric.py CHANGED
@@ -65,6 +65,7 @@ from .system_prompts import __file__ as _
65
  from .task import __file__ as _
66
  from .templates import __file__ as _
67
  from .text_utils import __file__ as _
 
68
  from .type_utils import __file__ as _
69
  from .types import __file__ as _
70
  from .utils import __file__ as _
 
65
  from .task import __file__ as _
66
  from .templates import __file__ as _
67
  from .text_utils import __file__ as _
68
+ from .tool_calling import __file__ as _
69
  from .type_utils import __file__ as _
70
  from .types import __file__ as _
71
  from .utils import __file__ as _
metrics.py CHANGED
@@ -63,7 +63,9 @@ from .operators import ArtifactFetcherMixin, Copy, Set
63
  from .random_utils import get_seed
64
  from .settings_utils import get_settings
65
  from .stream import MultiStream, Stream
 
66
  from .type_utils import Type, isoftype, parse_type_string, to_type_string
 
67
  from .utils import deep_copy, recursive_copy, retry_connection_with_exponential_backoff
68
 
69
  logger = get_logger()
@@ -786,6 +788,77 @@ class F1Fast(MapReduceMetric[str, Tuple[int, int]]):
786
 
787
  return result
788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
789
 
790
  class MetricWithConfidenceInterval(Metric):
791
  # The number of resamples used to estimate the confidence intervals of this metric.
 
63
  from .random_utils import get_seed
64
  from .settings_utils import get_settings
65
  from .stream import MultiStream, Stream
66
+ from .tool_calling import convert_chat_api_format_to_tool
67
  from .type_utils import Type, isoftype, parse_type_string, to_type_string
68
+ from .types import ToolCall
69
  from .utils import deep_copy, recursive_copy, retry_connection_with_exponential_backoff
70
 
71
  logger = get_logger()
 
788
 
789
  return result
790
 
791
class ToolCallingMetric(ReductionInstanceMetric[str, Dict[str, float]]):
    """Scores a predicted tool call against one or more reference tool calls.

    For every instance ``map`` reports five scores between 0 and 1, which are
    then aggregated with ``MeanReduction``:

    * ``exact_match`` - 1.0 when ``str(prediction)`` equals ``str(reference)``
      for at least one reference.
    * ``tool_choice`` - 1.0 when the predicted tool name equals the name of at
      least one reference.
    * ``parameter_choice`` - best value over the references of the fraction of
      predicted argument names that also appear in the reference.
    * ``parameter_values`` - best value over the references of the fraction of
      predicted arguments whose value matches the reference value of the same
      name (equal, or one contained in the other).
    * ``parameters_types`` - fraction of predicted arguments whose value passes
      ``isoftype`` against the parameter type declared by the tool.

    A prediction without arguments scores 1.0 on the three per-argument scores
    (for the first two, provided there is at least one reference).
    """

    main_score = "exact_match"
    reduction = MeanReduction()
    prediction_type = ToolCall

    @staticmethod
    def _values_match(predicted: Any, expected: Any) -> bool:
        """Tell whether two argument values are equal or one contains the other."""
        try:
            # Equality is checked first: for scalars such as ints or bools the
            # containment test raises TypeError, which used to make identical
            # values count as a mismatch.
            if predicted == expected:
                return True
            return predicted in expected or expected in predicted
        except (TypeError, ValueError):
            # Values that support neither comparison nor containment.
            return False

    def map(
        self, prediction: ToolCall, references: List[ToolCall], task_data: Dict[str, Any]
    ) -> Dict[str, float]:
        """Compute the per-instance scores described in the class docstring.

        Args:
            prediction: The predicted call, with ``"name"`` and ``"arguments"``.
            references: The acceptable calls for this instance.
            task_data: Must contain ``"__tools__"``, the tool definitions in
                the format accepted by ``convert_chat_api_format_to_tool``.

        Returns:
            A dict with the keys ``exact_match``, ``tool_choice``,
            ``parameter_choice``, ``parameters_types`` and ``parameter_values``.
        """
        arguments = prediction["arguments"]
        num_arguments = len(arguments)

        exact_match = float(
            str(prediction) in [str(reference) for reference in references]
        )
        tool_choice = float(
            str(prediction["name"])
            in [str(reference["name"]) for reference in references]
        )

        # Both per-reference scores keep the best value over all references.
        parameter_choice = 0.0
        parameter_values = 0.0
        for reference in references:
            if num_arguments > 0:
                reference_arguments = reference["arguments"]

                shared_names = set(arguments).intersection(reference_arguments)
                name_score = len(shared_names) / num_arguments

                value_matches = 0
                for key, val in arguments.items():
                    try:
                        expected = reference_arguments[key]
                    except (KeyError, TypeError):
                        # The reference has no argument of that name.
                        continue
                    if self._values_match(val, expected):
                        value_matches += 1
                value_score = value_matches / num_arguments
            else:
                name_score = 1.0
                value_score = 1.0

            parameter_choice = max(parameter_choice, name_score)
            parameter_values = max(parameter_values, value_score)

        # With an empty tool list there is no declared type to violate; the
        # name used to be left unbound in that case (UnboundLocalError).
        parameters_types = 1.0
        # NOTE(review): every iteration overwrites the score, so only the last
        # tool in the list decides the result. Scoring against the tool named
        # by the prediction is presumably the intent - confirm before changing,
        # since that would alter reported scores.
        for tool in task_data["__tools__"]:
            tool = convert_chat_api_format_to_tool(tool)
            declared_types = {
                param["name"]: param["type"] for param in tool["parameters"]
            }
            if num_arguments > 0:
                # Arguments the tool does not declare are checked against Any.
                correct_parameters_types = sum(
                    1
                    for key, value in arguments.items()
                    if isoftype(value, declared_types.get(key, Any))
                )
                parameters_types = correct_parameters_types / num_arguments
            else:
                parameters_types = 1.0

        return {
            self.main_score: exact_match,
            "tool_choice": tool_choice,
            "parameter_choice": parameter_choice,
            "parameters_types": parameters_types,
            "parameter_values": parameter_values,
        }
861
+
862
 
863
  class MetricWithConfidenceInterval(Metric):
864
  # The number of resamples used to estimate the confidence intervals of this metric.
operators.py CHANGED
@@ -930,7 +930,7 @@ class Cast(FieldOperator):
930
  failure_default: Optional[Any] = "__UNDEFINED__"
931
 
932
  def prepare(self):
933
- self.types = {"int": int, "float": float, "str": str, "bool": bool}
934
 
935
  def process_value(self, value):
936
  try:
 
930
  failure_default: Optional[Any] = "__UNDEFINED__"
931
 
932
  def prepare(self):
933
+ self.types = {"int": int, "float": float, "str": str, "bool": bool, "tuple": tuple}
934
 
935
  def process_value(self, value):
936
  try:
schema.py CHANGED
@@ -141,6 +141,9 @@ class FinalizeDataset(InstanceOperatorValidator):
141
  }
142
  if use_reference_fields:
143
  task_data = {**task_data, **instance["reference_fields"]}
 
 
 
144
  return task_data
145
 
146
  def serialize_instance_fields(self, instance, task_data):
 
141
  }
142
  if use_reference_fields:
143
  task_data = {**task_data, **instance["reference_fields"]}
144
+
145
+ if "__tools__" in instance:
146
+ task_data["__tools__"] = instance["__tools__"]
147
  return task_data
148
 
149
  def serialize_instance_fields(self, instance, task_data):
serializers.py CHANGED
@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Union
7
  from .dataclass import AbstractField, Field
8
  from .operators import InstanceFieldOperator
9
  from .settings_utils import get_constants
 
10
  from .type_utils import isoftype, to_type_string
11
  from .types import (
12
  Dialog,
@@ -16,6 +17,8 @@ from .types import (
16
  Number,
17
  SQLDatabase,
18
  Table,
 
 
19
  Video,
20
  )
21
 
@@ -161,15 +164,43 @@ class MultiDocumentSerializer(DocumentSerializer):
161
  return "\n\n".join(documents)
162
 
163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  class MultiTypeSerializer(Serializer):
165
  serializers: List[SingleTypeSerializer] = Field(
166
  default_factory=lambda: [
167
  DocumentSerializer(),
 
168
  DialogSerializer(),
169
  MultiDocumentSerializer(),
170
  ImageSerializer(),
171
  VideoSerializer(),
172
  TableSerializer(),
 
173
  DialogSerializer(),
174
  ]
175
  )
 
7
  from .dataclass import AbstractField, Field
8
  from .operators import InstanceFieldOperator
9
  from .settings_utils import get_constants
10
+ from .tool_calling import convert_to_chat_api_format
11
  from .type_utils import isoftype, to_type_string
12
  from .types import (
13
  Dialog,
 
17
  Number,
18
  SQLDatabase,
19
  Table,
20
+ Tool,
21
+ ToolCall,
22
  Video,
23
  )
24
 
 
164
  return "\n\n".join(documents)
165
 
166
 
167
+
168
class ToolsSerializer(SingleTypeSerializer):
    """Serializes a list of tools as chat-API JSON and records them on the instance.

    Besides returning the JSON text, ``serialize`` has two side effects:

    * every converted tool is appended to ``instance["__tools__"]`` (the list
      is created when missing), so repeated calls on the same instance
      accumulate entries;
    * each tool's ``"parameters"`` entry is replaced in place by the JSON
      schema produced for it.
    """

    serialized_type = List[Tool]
    _requirements_list: List[str] = ["pydantic"]

    def serialize(self, value: List[Tool], instance: Dict[str, Any]) -> str:
        """Convert ``value`` to chat-API format and return the JSON text.

        Args:
            value: The tools to serialize; modified in place (see class docstring).
            instance: The instance being processed; receives the converted
                tools under ``"__tools__"``.

        Returns:
            ``instance["__tools__"]`` - including any entries that were already
            there - dumped as JSON with an indent of 4.
        """
        registered_tools = instance.setdefault("__tools__", [])
        for tool in value:
            chat_api_tool = convert_to_chat_api_format(tool=tool)
            registered_tools.append(chat_api_tool)
            # Swap the Python-typed parameter list for its JSON schema.
            tool["parameters"] = chat_api_tool["function"]["parameters"]
        return json.dumps(registered_tools, indent=4)
184
+
185
class ToolCallSerializer(SingleTypeSerializer):
    """Serializes a single ``ToolCall`` as a JSON string."""

    serialized_type = ToolCall
    _requirements_list: List[str] = ["pydantic"]

    def serialize(self, value: ToolCall, instance: Dict[str, Any]) -> str:
        """Return ``value`` encoded with ``json.dumps``; ``instance`` is not used."""
        encoded_call = json.dumps(value)
        return encoded_call
192
+
193
  class MultiTypeSerializer(Serializer):
194
  serializers: List[SingleTypeSerializer] = Field(
195
  default_factory=lambda: [
196
  DocumentSerializer(),
197
+ ToolCallSerializer(),
198
  DialogSerializer(),
199
  MultiDocumentSerializer(),
200
  ImageSerializer(),
201
  VideoSerializer(),
202
  TableSerializer(),
203
+ ToolsSerializer(),
204
  DialogSerializer(),
205
  ]
206
  )
templates.py CHANGED
@@ -19,6 +19,8 @@ from .serializers import (
19
  Serializer,
20
  SQLDatabaseAsSchemaSerializer,
21
  TableSerializer,
 
 
22
  VideoSerializer,
23
  )
24
  from .settings_utils import get_constants
@@ -63,6 +65,8 @@ class Template(InstanceOperator):
63
  ImageSerializer(),
64
  VideoSerializer(),
65
  TableSerializer(),
 
 
66
  DialogSerializer(),
67
  ListSerializer(),
68
  SQLDatabaseAsSchemaSerializer(),
 
19
  Serializer,
20
  SQLDatabaseAsSchemaSerializer,
21
  TableSerializer,
22
+ ToolCallSerializer,
23
+ ToolsSerializer,
24
  VideoSerializer,
25
  )
26
  from .settings_utils import get_constants
 
65
  ImageSerializer(),
66
  VideoSerializer(),
67
  TableSerializer(),
68
+ ToolCallSerializer(),
69
+ ToolsSerializer(),
70
  DialogSerializer(),
71
  ListSerializer(),
72
  SQLDatabaseAsSchemaSerializer(),
tool_calling.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Type
2
+
3
+ from .operators import FieldOperator
4
+ from .types import Parameter, Tool
5
+
6
+
7
def convert_to_chat_api_format(tool: Tool) -> Dict[str, Any]:
    """Convert a ``Tool`` into a chat-API function description.

    A pydantic model named ``"<tool name>Params"`` is created with one required
    field per tool parameter, and its JSON schema becomes the ``"parameters"``
    entry of the result.

    Args:
        tool: Tool with ``"name"``, ``"description"`` and ``"parameters"``;
            a parameter without a ``"type"`` entry is typed ``Any``.

    Returns:
        ``{"type": "function", "function": {"name", "description", "parameters"}}``.
    """
    from pydantic import create_model

    # An Ellipsis default tells pydantic that the field is required.
    required_fields = {
        param["name"]: (param.get("type", Any), ...) for param in tool["parameters"]
    }
    params_model = create_model(f"{tool['name']}Params", **required_fields)
    parameters_schema = params_model.model_json_schema()

    function_spec = {
        "name": tool["name"],
        "description": tool["description"],
        "parameters": parameters_schema,
    }
    return {"type": "function", "function": function_spec}
29
+
30
+
31
def convert_chat_api_format_to_tool(chat_api_tool: Dict[str, Any]) -> Tool:
    """Convert a Chat API formatted tool back to the ``Tool`` structure.

    Both layouts are accepted: the wrapped one, where name, description and
    parameters live under a ``"function"`` entry, and the flat one, where they
    sit at the top level. Values under ``"function"`` take precedence.

    Args:
        chat_api_tool: A dictionary representing a tool in Chat API format.

    Returns:
        A ``Tool`` dictionary with name, description and one parameter per
        entry of the schema's ``"properties"``. Name and description default to
        ``""``; a tool without a parameters schema yields no parameters.
    """
    function_info = chat_api_tool.get("function", {})
    name = function_info.get("name", chat_api_tool.get("name", ""))
    description = function_info.get(
        "description", chat_api_tool.get("description", "")
    )

    # A tool without parameters may omit the schema (or carry None); the old
    # "" default made the .get() below fail with AttributeError.
    schema = function_info.get("parameters", chat_api_tool.get("parameters")) or {}
    properties = schema.get("properties") or {}

    # Map every JSON-schema property to a Python type.
    parameters: List[Parameter] = [
        {"name": param_name, "type": json_schema_to_python_type(param_schema)}
        for param_name, param_schema in properties.items()
    ]

    tool: Tool = {
        "name": name,
        "description": description,
        "parameters": parameters,
    }
    return tool
68
+
69
def json_schema_to_python_type(schema: Dict[str, Any]) -> Type:
    """Convert a JSON schema fragment to the corresponding Python type.

    Mapping:
        * ``string``/``integer``/``number``/``boolean``/``null`` ->
          ``str``/``int``/``float``/``bool``/``type(None)``
        * ``array`` -> ``List[<item type>]`` (``List[Any]`` when ``items`` is
          missing, empty, or not a single schema object)
        * ``object`` -> ``Dict[str, Any]``
        * a list of types, ``anyOf`` or ``oneOf`` -> ``Union`` of the alternatives
        * ``$ref`` and anything unrecognized -> ``Any`` (references are not resolved)

    Args:
        schema: The JSON schema fragment describing one value.

    Returns:
        The Python type (or typing construct) for ``schema``.
    """
    # Union is not part of this module's top-level typing imports.
    from typing import Union

    schema_type = schema.get("type")

    # JSON Schema allows "type" to list several alternatives, for example
    # ["string", "null"]. A list is unhashable, so it has to be handled before
    # the dictionary lookup below.
    if isinstance(schema_type, list):
        alternatives = [
            json_schema_to_python_type({**schema, "type": single_type})
            for single_type in schema_type
        ]
        return Union[tuple(alternatives)] if alternatives else Any

    simple_types = {
        "string": str,
        "integer": int,
        "number": float,
        "boolean": bool,
        "null": type(None),
    }
    if schema_type in simple_types:
        return simple_types[schema_type]

    if schema_type == "array":
        items = schema.get("items", {})
        # Only a non-empty schema object describes the element type; tuple-style
        # lists and boolean schemas fall back to Any.
        if not isinstance(items, dict) or not items:
            return List[Any]
        return List[json_schema_to_python_type(items)]

    if schema_type == "object":
        return Dict[str, Any]

    if "anyOf" in schema or "oneOf" in schema:
        union_schemas = schema.get("anyOf", []) or schema.get("oneOf", [])
        union_types = [json_schema_to_python_type(s) for s in union_schemas]
        return Union[tuple(union_types)] if union_types else Any

    # "$ref" is not resolved and, like any unrecognized schema, maps to Any.
    return Any
114
+
115
+
116
class ToTool(FieldOperator):
    """Field operator that converts a chat-API tool dictionary into a ``Tool``."""

    def process_value(self, value: Dict[str, Any]) -> Tool:
        """Return ``value`` converted by ``convert_chat_api_format_to_tool``."""
        converted_tool = convert_chat_api_format_to_tool(value)
        return converted_tool
type_utils.py CHANGED
@@ -69,6 +69,8 @@ def is_typed_dict(object):
69
 
70
  def is_type(object):
71
  """Checks if the provided object is a type, including generics, Literal, TypedDict, and NewType."""
 
 
72
  return (
73
  isinstance(object, (type, *_generics_types))
74
  or is_new_type(object)
@@ -487,6 +489,9 @@ def isoftype(object, typing_type):
487
  if not is_type(typing_type):
488
  raise UnsupportedTypeError(typing_type)
489
 
 
 
 
490
  if is_new_type(typing_type):
491
  typing_type = typing_type.__supertype__
492
 
 
69
 
70
  def is_type(object):
71
  """Checks if the provided object is a type, including generics, Literal, TypedDict, and NewType."""
72
+ if object is typing.Type:
73
+ return True
74
  return (
75
  isinstance(object, (type, *_generics_types))
76
  or is_new_type(object)
 
489
  if not is_type(typing_type):
490
  raise UnsupportedTypeError(typing_type)
491
 
492
+ if typing_type is typing.Type:
493
+ return is_type(object)
494
+
495
  if is_new_type(typing_type):
496
  typing_type = typing_type.__supertype__
497
 
types.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, List, Literal, NewType, Optional, TypedDict, Union
2
 
3
  from .type_utils import register_type
4
 
@@ -51,6 +51,18 @@ class SQLDatabase(TypedDict):
51
  dbms: Optional[str]
52
  data: Optional[Dict[str, Dict]]
53
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  register_type(Text)
56
  register_type(Number)
@@ -64,3 +76,7 @@ register_type(Document)
64
  register_type(MultiDocument)
65
  register_type(RagResponse)
66
  register_type(SQLDatabase)
 
 
 
 
 
1
+ from typing import Any, Dict, List, Literal, NewType, Optional, Type, TypedDict, Union
2
 
3
  from .type_utils import register_type
4
 
 
51
  dbms: Optional[str]
52
  data: Optional[Dict[str, Dict]]
53
 
54
class Parameter(TypedDict):
    """One named parameter of a tool."""

    name: str
    # An actual Python type object rather than a type name; may be None.
    type: Optional[Type]
57
+
58
class Tool(TypedDict):
    """A callable tool: its name, a description, and its parameters."""

    name: str
    description: str
    parameters: List[Parameter]
62
+
63
class ToolCall(TypedDict):
    """A call to a tool: the tool's name and the arguments passed, keyed by name."""

    name: str
    arguments: Dict[str, Any]
66
 
67
  register_type(Text)
68
  register_type(Number)
 
76
  register_type(MultiDocument)
77
  register_type(RagResponse)
78
  register_type(SQLDatabase)
79
+ register_type(Parameter)
80
+ register_type(Tool)
81
+ register_type(ToolCall)
82
+
version.py CHANGED
@@ -1 +1 @@
1
- version = "1.22.3"
 
1
+ version = "1.22.4"