Elron commited on
Commit
371536d
·
verified ·
1 Parent(s): db06ad0

Upload templates.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. templates.py +15 -7
templates.py CHANGED
@@ -1,6 +1,6 @@
1
  import json
2
  from abc import abstractmethod
3
- from typing import Any, Dict, List, Optional, Tuple
4
 
5
  from .collections import ListCollection
6
  from .dataclass import NonPositionalField
@@ -393,14 +393,22 @@ class KeyValTemplate(Template):
393
 
394
 
395
  class OutputQuantizingTemplate(InputOutputTemplate):
396
- quantum: float = 0.1
397
 
398
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
399
- quantum_str = f"{self.quantum:.10f}".rstrip("0").rstrip(".")
400
- quantized_outputs = {
401
- key: f"{round(value / self.quantum) * self.quantum:{quantum_str}}"
402
- for key, value in outputs.items()
403
- }
 
 
 
 
 
 
 
 
404
  return super().outputs_to_target_and_references(quantized_outputs)
405
 
406
 
 
1
  import json
2
  from abc import abstractmethod
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
 
5
  from .collections import ListCollection
6
  from .dataclass import NonPositionalField
 
393
 
394
 
395
  class OutputQuantizingTemplate(InputOutputTemplate):
396
+ quantum: Union[float, int] = 0.1 # Now supports both int and float
397
 
398
  def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
399
+ if isinstance(self.quantum, int):
400
+ # When quantum is an int, format quantized values as ints
401
+ quantized_outputs = {
402
+ key: f"{int(round(value / self.quantum) * self.quantum)}"
403
+ for key, value in outputs.items()
404
+ }
405
+ else:
406
+ # When quantum is a float, format quantized values with precision based on quantum
407
+ quantum_str = f"{self.quantum:.10f}".rstrip("0").rstrip(".")
408
+ quantized_outputs = {
409
+ key: f"{round(value / self.quantum) * self.quantum:{quantum_str}}"
410
+ for key, value in outputs.items()
411
+ }
412
  return super().outputs_to_target_and_references(quantized_outputs)
413
 
414