Spaces:
Sleeping
Sleeping
import ast | |
import builtins | |
import dis | |
import enum | |
import inspect | |
import re | |
import typing | |
import warnings | |
from textwrap import dedent | |
from typing import Type | |
import torch | |
from torch._C import ( | |
_GeneratorType, | |
AnyType, | |
AwaitType, | |
BoolType, | |
ComplexType, | |
DeviceObjType, | |
DictType, | |
EnumType, | |
FloatType, | |
FutureType, | |
InterfaceType, | |
IntType, | |
ListType, | |
NoneType, | |
NumberType, | |
OptionalType, | |
StreamObjType, | |
StringType, | |
TensorType, | |
TupleType, | |
UnionType, | |
) | |
from torch._sources import get_source_lines_and_file | |
from .._jit_internal import ( # type: ignore[attr-defined] | |
_Await, | |
_qualified_name, | |
Any, | |
BroadcastingList1, | |
BroadcastingList2, | |
BroadcastingList3, | |
Dict, | |
Future, | |
is_await, | |
is_dict, | |
is_future, | |
is_ignored_fn, | |
is_list, | |
is_optional, | |
is_tuple, | |
is_union, | |
List, | |
Optional, | |
Tuple, | |
Union, | |
) | |
from ._state import _get_script_class | |
if torch.distributed.rpc.is_available(): | |
from torch._C import RRefType | |
from .._jit_internal import is_rref, RRef | |
from torch._ops import OpOverloadPacket | |
class Module:
    """Minimal stand-in for a module: a name plus a mapping of its members.

    Members are exposed through attribute access, so ``Module("torch",
    {"Tensor": T}).Tensor`` evaluates to ``T``.
    """

    def __init__(self, name, members):
        self.name = name
        self.members = members

    def __getattr__(self, name):
        # Only invoked when regular attribute lookup fails, i.e. for anything
        # other than ``name`` / ``members`` themselves.
        try:
            member = self.members[name]
        except KeyError:
            # Suppress the KeyError context; callers only see the RuntimeError.
            raise RuntimeError(
                f"Module {self.name} has no member called {name}"
            ) from None
        return member
class EvalEnv:
    """Name-resolution mapping used when evaluating type-comment annotations.

    A name is looked up in the predefined ``env`` table first. If it is not
    there, the resolution callback ``rcb`` is used when one was supplied;
    otherwise the name is looked up on the ``builtins`` module (yielding
    ``None`` when it does not exist there either).
    """

    # Class-level table: shared by every EvalEnv instance.
    env = {
        "torch": Module("torch", {"Tensor": torch.Tensor}),
        "Tensor": torch.Tensor,
        "typing": Module("typing", {"Tuple": Tuple}),
        "Tuple": Tuple,
        "List": List,
        "Dict": Dict,
        "Optional": Optional,
        "Union": Union,
        "Future": Future,
        "Await": _Await,
    }

    def __init__(self, rcb):
        self.rcb = rcb
        # RRef is only importable when RPC is available. Note this writes into
        # the shared class-level ``env`` dict, not a per-instance copy.
        if torch.distributed.rpc.is_available():
            self.env["RRef"] = RRef

    def __getitem__(self, name):
        known_names = self.env
        if name in known_names:
            return known_names[name]
        resolver = self.rcb
        if resolver is None:
            return getattr(builtins, name, None)
        return resolver(name)
def get_signature(fn, rcb, loc, is_method):
    """Return ``(param_types, return_type)`` for ``fn``, or None if unannotated.

    Real annotations are tried first (via ``try_real_annotations``). When they
    yield nothing, the function source is searched for a ``# type:`` comment,
    which is then parsed with ``parse_type_line`` using ``rcb`` to resolve
    names.
    """
    annotated = fn.op if isinstance(fn, OpOverloadPacket) else fn
    signature = try_real_annotations(annotated, loc)

    if signature is not None:
        if is_method:
            # If this is a method, then the signature will include a type for
            # `self`, but type comments do not contain a `self`. So strip it
            # away here so everything is consistent (`inspect.ismethod` does
            # not work here since `fn` is unbound at this point)
            param_types, return_type = signature
            signature = (param_types[1:], return_type)
        return signature

    type_line = None
    try:
        source = dedent("".join(get_source_lines_and_file(fn)[0]))
        type_line = get_type_line(source)
    except TypeError:
        pass

    # This might happen both because we failed to get the source of fn, or
    # because it didn't have any annotations.
    if type_line is None:
        return None
    return parse_type_line(type_line, rcb, loc)
def is_function_or_method(the_callable):
    """Return True for Python functions and bound methods only.

    A stricter version of `inspect.isroutine` that does not pass for built-in
    functions.
    """
    predicates = (inspect.isfunction, inspect.ismethod)
    return any(predicate(the_callable) for predicate in predicates)
def is_vararg(the_callable):
    """Return True if ``the_callable`` declares a ``*args`` parameter.

    Callables that are not plain functions/methods are inspected through their
    ``__call__`` attribute; anything that still is not a function or method
    yields False.
    """
    target = the_callable
    if callable(target) and not is_function_or_method(target):  # noqa: B004
        # If `the_callable` is a class, de-sugar the call so we can still get
        # the signature
        target = target.__call__

    if not is_function_or_method(target):
        return False
    return inspect.getfullargspec(target).varargs is not None
def get_param_names(fn, n_args):
    """Return the positional parameter names of ``fn``.

    When ``fn`` (after unwrapping an ``OpOverloadPacket`` and de-sugaring a
    callable object to its ``__call__``) is not a Python function or method,
    the placeholder names ``"0" .. str(n_args - 1)`` are returned instead.
    """
    target = fn.op if isinstance(fn, OpOverloadPacket) else fn

    callable_object = (
        not is_function_or_method(target)
        and callable(target)
        and is_function_or_method(target.__call__)
    )  # noqa: B004
    if callable_object:
        # De-sugar calls to classes
        target = target.__call__

    if not is_function_or_method(target):
        # The `fn` was not a method or function (maybe a class with a __call__
        # method, so use a default param name list)
        return list(map(str, range(n_args)))

    if is_ignored_fn(target):
        target = inspect.unwrap(target)
    return inspect.getfullargspec(target).args
def check_fn(fn, loc):
    """Verify that the source of ``fn`` is exactly one top-level function.

    Returns silently when the source cannot be retrieved. Raises
    ``torch.jit.frontend.FrontendError`` (anchored at ``loc``) when the source
    is a class definition or anything other than a single function definition.
    """
    # Make sure the function definition is not a class instantiation
    try:
        source_lines = get_source_lines_and_file(fn)[0]
        source = dedent("".join(source_lines))
    except (OSError, TypeError):
        return
    if source is None:
        return

    top_level = ast.parse(source).body
    only_stmt = top_level[0] if len(top_level) == 1 else None

    if isinstance(only_stmt, ast.ClassDef):
        raise torch.jit.frontend.FrontendError(
            loc,
            f"Cannot instantiate class '{only_stmt.name}' in a script function",
        )
    if not isinstance(only_stmt, ast.FunctionDef):
        raise torch.jit.frontend.FrontendError(
            loc, "Expected a single top-level function"
        )
def _eval_no_call(stmt, glob, loc): | |
"""Evaluate statement as long as it does not contain any method/function calls.""" | |
bytecode = compile(stmt, "", mode="eval") | |
for insn in dis.get_instructions(bytecode): | |
if "CALL" in insn.opname: | |
raise RuntimeError( | |
f"Type annotation should not contain calls, but '{stmt}' does" | |
) | |
return eval(bytecode, glob, loc) # type: ignore[arg-type] # noqa: P204 | |
def parse_type_line(type_line, rcb, loc):
    """Parse a type annotation specified as a comment.

    Example inputs:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor]
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor

    Returns a ``(arg_types, return_type)`` pair produced by ``ann_to_type``.
    Raises RuntimeError when either half of the comment cannot be evaluated.
    """
    arg_ann_str, ret_ann_str = split_type_line(type_line)

    def evaluate(expression, failure_message):
        # Names are resolved through a fresh EvalEnv backed by `rcb`.
        try:
            return _eval_no_call(expression, {}, EvalEnv(rcb))
        except (NameError, SyntaxError) as e:
            raise RuntimeError(failure_message) from e

    arg_ann = evaluate(
        arg_ann_str, "Failed to parse the argument list of a type annotation"
    )
    # A single parenthesised argument evaluates to a bare value, not a tuple.
    if not isinstance(arg_ann, tuple):
        arg_ann = (arg_ann,)

    ret_ann = evaluate(
        ret_ann_str, "Failed to parse the return type of a type annotation"
    )

    arg_types = [ann_to_type(ann, loc) for ann in arg_ann]
    return arg_types, ann_to_type(ret_ann, loc)
def get_type_line(source):
    """Try to find the line containing a comment with the type annotation.

    Args:
        source: Source text of a function.

    Returns:
        The stripped ``# type: (...) -> ...`` line, or None when the source
        contains no type comment. For the PEP 484 multi-line form (one
        ``# type:`` comment per parameter followed by a ``# type: (...) -> R``
        line) the per-parameter types are joined and substituted for the
        ``...`` placeholder of the return line.

    Raises:
        RuntimeError: if a comment looks like a mistyped ``# type:`` prefix, or
            if several type comments exist but none is the return line.
    """
    type_comment = "# type:"
    lines = list(enumerate(source.split("\n")))

    # `type: ignore` comments may be needed in JIT'ed functions for mypy, due
    # to the hack in torch/_VF.py.
    # An ignore type comment can be of following format:
    #   1) type: ignore
    #   2) type: ignore[rule-code]
    # This ignore statement must be at the end of the line
    # adding an extra backslash before the space, to avoid triggering
    # one of the checks in .github/workflows/lint.yml
    type_pattern = re.compile("# type:\\ ignore(\\[[a-zA-Z-]+\\])?$")
    type_lines = [
        (line_num, line)
        for line_num, line in lines
        if type_comment in line and not type_pattern.search(line)
    ]

    if not type_lines:
        # Catch common typo patterns like extra spaces, typo in 'ignore', etc.
        wrong_type_pattern = re.compile("#[\t ]*type[\t ]*(?!: ignore(\\[.*\\])?$):")
        wrong_type_lines = [
            entry for entry in lines if wrong_type_pattern.search(entry[1])
        ]
        if wrong_type_lines:
            raise RuntimeError(
                "The annotation prefix in line "
                + str(wrong_type_lines[0][0])
                + " is probably invalid.\nIt must be '# type:'"
                + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)"  # noqa: B950
                + "\nfor examples"
            )
        return None

    if len(type_lines) == 1:
        # Only 1 type line, quit now
        return type_lines[0][1].strip()

    # Parse split up argument types according to PEP 484
    # https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code
    return_marker = "# type: (...) -> "
    return_line = None
    parameter_type_lines = []
    for _, line in type_lines:
        if return_marker in line:
            return_line = line
            break
        # Every entry of `type_lines` contains `type_comment`, so anything
        # before the return line is a per-parameter annotation.
        parameter_type_lines.append(line)

    if return_line is None:
        raise RuntimeError(
            "Return type line '# type: (...) -> ...' not found on multiline "
            "type annotation\nfor type lines:\n"
            + "\n".join([line[1] for line in type_lines])
            + "\n(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)"
        )

    def get_parameter_type(line):
        item_type = line[line.find(type_comment) + len(type_comment) :]
        return item_type.strip()

    parameter_types = ", ".join(map(get_parameter_type, parameter_type_lines))

    # Substitute only the argument placeholder (an ellipsis inside the return
    # type, e.g. `Tuple[int, ...]`, must survive) and strip the result so it
    # has the same shape as the single-line case above.
    return return_line.replace(
        return_marker, f"# type: ({parameter_types}) -> ", 1
    ).strip()
def split_type_line(type_line):
    """Split the comment with the type annotation into parts for argument and return types.

    For example, for an input of:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor]

    This function will return:
        ("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]")

    The ``# type:`` prefix is located inside the line, so leading whitespace
    (or anything else) in front of the comment is ignored.

    Raises:
        RuntimeError: if the line contains no ``->``.
    """
    prefix = "# type:"
    prefix_pos = type_line.find(prefix)
    # When the prefix is absent, fall back to the historical fixed offset.
    start_offset = max(prefix_pos, 0) + len(prefix)

    try:
        arrow_pos = type_line.index("->")
    except ValueError:
        raise RuntimeError(
            "Syntax error in type annotation (couldn't find `->`)"
        ) from None

    arg_part = type_line[start_offset:arrow_pos].strip()
    ret_part = type_line[arrow_pos + 2 :].strip()
    return arg_part, ret_part
def try_real_annotations(fn, loc):
    """Try to use the Py3.5+ annotation syntax to get the type.

    Returns ``(arg_types, return_type)`` converted with ``ann_to_type``, or
    None when no signature can be obtained or nothing in it is annotated.
    """
    try:
        # Note: anything annotated as `Optional[T]` will automatically
        # be returned as `Union[T, None]` per
        # https://github.com/python/typing/blob/master/src/typing.py#L850
        sig = inspect.signature(fn)
    except ValueError:
        return None

    parameters = list(sig.parameters.values())
    annotations = [sig.return_annotation, *(p.annotation for p in parameters)]
    if all(annotation is sig.empty for annotation in annotations):
        return None

    arg_types = [ann_to_type(p.annotation, loc) for p in parameters]
    return_type = ann_to_type(sig.return_annotation, loc)
    return arg_types, return_type
# Finds common type for enum values belonging to an Enum class. If not all
# values have the same type, AnyType is returned.
def get_enum_value_type(e: Type[enum.Enum], loc):
    """Return the IR type shared by the values of enum class ``e``.

    The Python types of all member values are converted with
    ``try_ann_to_type`` and passed to ``torch._C.unify_type_list``; a falsy
    result from that call is replaced by ``AnyType``.

    Raises:
        ValueError: if ``e`` defines no members.
    """
    enum_values = list(e)
    if not enum_values:
        # Report the enum class itself; `e.__class__` would only name the
        # metaclass and not identify which enum is empty.
        raise ValueError(f"No enum values defined for: '{e}'")

    types = {type(v.value) for v in enum_values}
    ir_types = [try_ann_to_type(t, loc) for t in types]

    # If Enum values are of different types, an exception will be raised here.
    # Even though Python supports this case, we chose to not implement it to
    # avoid overcomplicate logic here for a rare use case. Please report a
    # feature request if you find it necessary.
    res = torch._C.unify_type_list(ir_types)
    if not res:
        return AnyType.get()
    return res
def is_tensor(ann):
    """Return True if class ``ann`` should be treated as a Tensor annotation.

    Subclasses of ``torch.Tensor`` qualify directly. Subclasses of the
    dtype-specific tensor types (``torch.LongTensor`` etc.) also qualify, but
    a warning is emitted because the dtype is not enforced.
    """
    if issubclass(ann, torch.Tensor):
        return True

    dtype_specific_types = (
        torch.LongTensor,
        torch.DoubleTensor,
        torch.FloatTensor,
        torch.IntTensor,
        torch.ShortTensor,
        torch.HalfTensor,
        torch.CharTensor,
        torch.ByteTensor,
        torch.BoolTensor,
    )
    if not issubclass(ann, dtype_specific_types):
        return False

    warnings.warn(
        "TorchScript will treat type annotations of Tensor "
        "dtype-specific subtypes as if they are normal Tensors. "
        "dtype constraints are not enforced in compilation either."
    )
    return True
def _fake_rcb(inp): | |
return None | |
def try_ann_to_type(ann, loc, rcb=None):
    """Convert a Python annotation into a TorchScript IR type.

    The annotation is matched against the supported forms in a fixed order
    (missing annotation, None, tensors, containers, Optional/Union, RRef,
    Future/Await, primitives, interfaces, torch objects, enums, script
    classes). Anything left over is handed to
    ``torch._C._resolve_type_from_object``.

    Args:
        ann: The annotation object to convert.
        loc: Source location, used for error messages and class compilation.
        rcb: Optional resolution callback for the final fallback; a callback
            that always returns None is used when omitted.

    Returns:
        The IR type, or whatever the final fallback returns when nothing else
        matched (``ann_to_type`` treats a None result as "unknown type").

    Raises:
        ValueError: if the key or value type of a Dict cannot be resolved.
        AssertionError: if a member of an Optional or Union cannot be resolved.
    """
    ann_args = typing.get_args(ann)  # always returns a tuple!

    if ann is inspect.Signature.empty:
        return TensorType.getInferred()
    if ann is None:
        return NoneType.get()
    if inspect.isclass(ann) and is_tensor(ann):
        return TensorType.get()
    if is_tuple(ann):
        # Special case for the empty Tuple type annotation `Tuple[()]`
        if len(ann_args) == 1 and ann_args[0] == ():
            return TupleType([])
        return TupleType([try_ann_to_type(a, loc) for a in ann_args])
    if is_list(ann):
        elem_type = try_ann_to_type(ann_args[0], loc)
        if elem_type:
            return ListType(elem_type)
        # An unresolved element type falls through to the remaining checks.
    if is_dict(ann):
        key = try_ann_to_type(ann_args[0], loc)
        value = try_ann_to_type(ann_args[1], loc)
        # Raise error if key or value is None
        if key is None:
            raise ValueError(
                f"Unknown type annotation: '{ann_args[0]}' at {loc.highlight()}"
            )
        if value is None:
            raise ValueError(
                f"Unknown type annotation: '{ann_args[1]}' at {loc.highlight()}"
            )
        return DictType(key, value)
    if is_optional(ann):
        # Identity check instead of `issubclass`: the second argument may be a
        # typing generic such as `List[int]` (e.g. `Union[None, List[int]]`),
        # for which `issubclass` raises TypeError.
        if ann_args[1] is type(None):
            contained = ann_args[0]
        else:
            contained = ann_args[1]
        valid_type = try_ann_to_type(contained, loc)
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}"
        # Explicit raise (rather than `assert`) so the check survives `-O`.
        if not valid_type:
            raise AssertionError(msg.format(repr(ann), repr(contained), repr(loc)))
        return OptionalType(valid_type)
    if is_union(ann):
        # TODO: this is hack to recognize NumberType
        if set(ann_args) == {int, float, complex}:
            return NumberType.get()
        inner: List = []
        # We need these extra checks because both `None` and invalid
        # values will return `None`
        # TODO: Determine if the other cases need to be fixed as well
        for a in ann_args:
            if a is None:
                inner.append(NoneType.get())
                # Already handled; converting `None` again would append a
                # second NoneType for the same member.
                continue
            maybe_type = try_ann_to_type(a, loc)
            msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}"
            if not maybe_type:
                # Name the member that failed (`a`), not the failed result.
                raise AssertionError(msg.format(repr(ann), repr(a), repr(loc)))
            inner.append(maybe_type)
        return UnionType(inner)  # type: ignore[arg-type]
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann_args[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann_args[0], loc))
    if is_await(ann):
        elementType = try_ann_to_type(ann_args[0], loc) if ann_args else AnyType.get()
        return AwaitType(elementType)
    if ann is float:
        return FloatType.get()
    if ann is complex:
        return ComplexType.get()
    if ann is int or ann is torch.SymInt:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(ann.__torch_script_interface__)
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.Generator:
        return _GeneratorType.get()
    if ann is torch.Stream:
        return StreamObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        if _get_script_class(ann) is None:
            scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
            name = scripted_class.qualified_name()
        else:
            name = _qualified_name(ann)
        return EnumType(name, get_enum_value_type(ann, loc), list(ann))
    if inspect.isclass(ann):
        maybe_script_class = _get_script_class(ann)
        if maybe_script_class is not None:
            return maybe_script_class
        if torch._jit_internal.can_compile_class(ann):
            return torch.jit._script._recursive_compile_class(ann, loc)

    # Maybe resolve a NamedTuple to a Tuple Type
    if rcb is None:
        rcb = _fake_rcb
    return torch._C._resolve_type_from_object(ann, loc, rcb)
def ann_to_type(ann, loc, rcb=None):
    """Convert annotation ``ann`` to an IR type, raising if it is unknown.

    Thin wrapper around ``try_ann_to_type`` that turns a None result into a
    ValueError mentioning the annotation and the highlighted location.
    """
    resolved = try_ann_to_type(ann, loc, rcb)
    if resolved is None:
        raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}")
    return resolved
# Names exported by `from <this module> import *`.
__all__ = [
    # Typing aliases and predicates re-exported from `_jit_internal`.
    "Any",
    "List",
    "BroadcastingList1",
    "BroadcastingList2",
    "BroadcastingList3",
    "Tuple",
    "is_tuple",
    "is_list",
    "Dict",
    "is_dict",
    "is_optional",
    "is_union",
    # IR type classes re-exported from `torch._C`.
    "TensorType",
    "TupleType",
    "FloatType",
    "ComplexType",
    "IntType",
    "ListType",
    "StringType",
    "DictType",
    "AnyType",
    # Defined in this module.
    "Module",
    # TODO: Consider not exporting these during wildcard import (reserve
    # that for the types; for idiomatic typing code.)
    "get_signature",
    "check_fn",
    "get_param_names",
    "parse_type_line",
    "get_type_line",
    "split_type_line",
    "try_real_annotations",
    "try_ann_to_type",
    "ann_to_type",
]