Spaces:
Running
Running
File size: 31,867 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 |
# mypy: ignore-errors
import collections
import dataclasses
import functools
import inspect
import sys
from typing import Dict, List, Optional
from torch._subclasses.fake_tensor import is_fake
from .. import variables
from ..bytecode_transformation import (
create_call_function,
create_call_method,
create_instruction,
)
from ..eval_frame import skip_code
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GetItemSource
from ..utils import dict_keys, dict_values, istype, specialize_symnode
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
# [Adding a new supported class within the keys of ConstDictVariable]
# - Add its tracker type to is_hashable
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl
def is_hashable(x):
    """Return whether VariableTracker ``x`` can be used as a ConstDictVariable key.

    Tensors qualify only when they carry an ``example_value`` (a fake tensor),
    tuples qualify when every element does, and otherwise ``x`` must be one of
    a fixed set of tracker types.
    """
    if isinstance(x, variables.TensorVariable):
        # Tensors are hashable if they have an example_value (a fake tensor).
        # Most VTs should have one; it'd be nice to assert that they all do.
        example = x.as_proxy().node.meta.get("example_value")
        return example is not None

    if isinstance(x, variables.TupleVariable):
        return all(map(is_hashable, x.items))

    # Built lazily on each call: attributes are read from the ``variables``
    # package at call time rather than at import time.
    hashable_trackers = (
        variables.BuiltinVariable,
        variables.SymNodeVariable,
        variables.ConstantVariable,
        variables.EnumVariable,
        variables.user_defined.UserDefinedClassVariable,
        variables.UserFunctionVariable,
        variables.SkipFunctionVariable,
        variables.misc.NumpyVariable,
        variables.NNModuleVariable,
        variables.MethodWrapperVariable,
        variables.TorchInGraphFunctionVariable,
        variables.TypingVariable,
        variables.FunctoolsPartialVariable,
    )
    return isinstance(x, hashable_trackers)
class ConstDictVariable(VariableTracker):
    """Tracks a Python dict (``user_cls``: dict, OrderedDict, ...) during tracing.

    ``self.items`` maps ``_HashableTracker``-wrapped key VariableTrackers to
    value VariableTrackers; the wrapper supplies the hashing and equality a
    real dict needs for its keys.
    """

    class _HashableTracker:
        """
        Auxiliary opaque internal class that wraps a VariableTracker and makes it hashable.
        This should not be seen or touched by anything outside of ConstDictVariable and its children.
        Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing.
        """

        def __init__(self, vt):
            # We specialize SymNodes
            vt = specialize_symnode(vt)
            # TODO Temporarily remove to figure out what keys are we breaking on
            # and add proper support for them
            if not is_hashable(vt):
                unimplemented(f"Dict key of type {type(vt)}. Key: {vt}")
            self.vt = vt

        @property
        def underlying_value(self):
            """Python value standing in for the key when hashing/comparing.

            Tensors use their ``example_value`` meta, tuples recurse
            element-wise, nn modules / user functions use the wrapped object,
            and everything else uses its Python constant.
            """
            if isinstance(self.vt, variables.TensorVariable):
                x = self.vt.as_proxy().node.meta["example_value"]
            elif isinstance(self.vt, variables.TupleVariable):
                Hashable = ConstDictVariable._HashableTracker
                x = tuple(Hashable(e).underlying_value for e in self.vt.items)
            elif isinstance(self.vt, variables.NNModuleVariable):
                return self.vt.module
            elif isinstance(self.vt, variables.UserFunctionVariable):
                return self.vt.get_function()
            else:
                x = self.vt.as_python_constant()
            return x

        def __hash__(self):
            return hash(self.underlying_value)

        @staticmethod
        def _eq_impl(a, b):
            # TODO: Put this in utils and share it between variables/builtin.py and here
            if type(a) != type(b):
                return False
            elif isinstance(a, tuple):
                Hashable = ConstDictVariable._HashableTracker
                return len(a) == len(b) and all(
                    Hashable._eq_impl(u, v) for u, v in zip(a, b)
                )
            elif is_fake(a):
                # Fake tensors compare by identity, not by (elementwise) ==
                return a is b
            else:
                return a == b

        def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool:
            Hashable = ConstDictVariable._HashableTracker
            assert isinstance(other, Hashable) or ConstantVariable.is_literal(
                other
            ), type(other)
            if isinstance(other, Hashable):
                return Hashable._eq_impl(self.underlying_value, other.underlying_value)

            # constant
            return Hashable._eq_impl(self.underlying_value, other)

    def __init__(
        self, items: Dict[VariableTracker, VariableTracker], user_cls=dict, **kwargs
    ):
        super().__init__(**kwargs)

        Hashable = ConstDictVariable._HashableTracker

        # Keys will just be HashableTrackers when cloning, in any other case they'll be VariableTrackers
        assert all(
            isinstance(x, (VariableTracker, Hashable))
            and isinstance(v, VariableTracker)
            for x, v in items.items()
        )

        def make_hashable(key):
            return key if isinstance(key, Hashable) else Hashable(key)

        self.items = {make_hashable(x): v for x, v in items.items()}
        self.user_cls = user_cls

    def as_proxy(self):
        return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}

    def as_python_constant(self):
        return {
            k.vt.as_python_constant(): v.as_python_constant()
            for k, v in self.items.items()
        }

    def keys_as_python_constant(self):
        """Return a dict of Python-constant keys to (still tracked) value VTs."""
        return {k.vt.as_python_constant(): v for k, v in self.items.items()}

    def python_type(self):
        return self.user_cls

    def __contains__(self, vt):
        """Membership test for a key VariableTracker; unhashable keys are never present."""
        assert isinstance(vt, VariableTracker)
        Hashable = ConstDictVariable._HashableTracker
        return is_hashable(vt) and Hashable(vt) in self.items

    def reconstruct(self, codegen):
        # instructions to load collections.OrderedDict if necessary
        if self.user_cls is collections.OrderedDict:
            codegen.extend_output(
                [
                    codegen.create_load_python_module(collections, True),
                    codegen.create_load_attr("OrderedDict"),
                ]
            )
        # instructions to build the dict keys and values
        for key, value in self.items.items():
            codegen(key.vt)
            codegen(value)
        # BUILD_MAP and calling collections.OrderedDict if necessary
        if self.user_cls is collections.OrderedDict:
            codegen.extend_output(
                [
                    create_instruction("BUILD_MAP", arg=len(self.items)),
                    *create_call_function(1, False),
                ]
            )
        # BUILD_MAP only if user_cls is dict
        else:
            codegen.append_output(create_instruction("BUILD_MAP", arg=len(self.items)))

    def getitem_const(self, arg: VariableTracker):
        """Return the value VT stored under key ``arg``.

        Raises:
            KeyError: if the key is absent. The exception carries the key's
            underlying Python value. ``arg.value`` is not available for every
            hashable key type (``underlying_value`` reads modules and functions
            through other accessors), so it is not used here.
        """
        key = ConstDictVariable._HashableTracker(arg)
        if key not in self.items:
            raise KeyError(key.underlying_value)
        return self.items[key]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Model the dict method ``name`` called with ``args``/``kwargs``.

        Mutating methods are only handled when ``self.mutable_local`` is set;
        everything unhandled falls back to ``VariableTracker.call_method``.
        """
        from . import (
            BuiltinVariable,
            ConstantVariable,
            ListIteratorVariable,
            ListVariable,
            TupleVariable,
        )

        Hashable = ConstDictVariable._HashableTracker

        arg_hashable = args and is_hashable(args[0])

        if name == "__getitem__":
            assert len(args) == 1
            return self.getitem_const(args[0])
        elif name == "items":
            assert not (args or kwargs)
            return TupleVariable(
                [TupleVariable([k.vt, v]) for k, v in self.items.items()]
            )
        elif name == "keys":
            assert not (args or kwargs)
            return DictKeys(self)
        elif name == "values":
            assert not (args or kwargs)
            return DictValues(self)
        elif name == "copy":
            assert not (args or kwargs)
            return self.clone(items=self.items.copy(), mutable_local=MutableLocal())
        elif name == "__len__":
            assert not (args or kwargs)
            return ConstantVariable.create(len(self.items))
        elif name == "__setitem__" and arg_hashable and self.mutable_local:
            assert not kwargs and len(args) == 2
            tx.output.side_effects.mutation(self)
            self.items[Hashable(args[0])] = args[1]
            return ConstantVariable.create(None)
        elif name in ("pop", "get") and len(args) in (1, 2) and args[0] not in self:
            # Missing key: an explicit default is returned by both get and pop.
            if len(args) == 2:
                return args[1]
            if name == "get":
                return ConstantVariable.create(None)
            # dict.pop(missing_key) without a default raises KeyError in Python
            # rather than returning None; graph-break so eager execution raises it.
            unimplemented(f"dict.pop on missing key {args[0]} without a default")
        elif name == "pop" and arg_hashable and self.mutable_local:
            tx.output.side_effects.mutation(self)
            return self.items.pop(Hashable(args[0]))
        elif name == "clear" and self.mutable_local:
            # Guarded by mutable_local like the other mutating methods above.
            tx.output.side_effects.mutation(self)
            self.items.clear()
            return ConstantVariable.create(None)
        elif (
            name == "update"
            and len(args) == 1
            and isinstance(
                args[0],
                (
                    ConstDictVariable,
                    ListVariable,
                    TupleVariable,
                    ListIteratorVariable,
                ),
            )
            and self.mutable_local
        ):
            tx.output.side_effects.mutation(self)
            if isinstance(args[0], ConstDictVariable):
                dict_vt = args[0]
            else:
                dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
            self.items.update(dict_vt.items)
            # Wrap strings
            kwarg_items = {
                Hashable(ConstantVariable.create(k)): v for k, v in kwargs.items()
            }
            self.items.update(kwarg_items)
            return ConstantVariable.create(None)
        elif name in ("get", "__getattr__") and args and args[0] in self:
            return self.getitem_const(args[0])
        elif name == "__contains__" and len(args) == 1:
            return ConstantVariable.create(args[0] in self)
        else:
            return super().call_method(tx, name, args, kwargs)

    def unpack_var_sequence(self, tx):
        # Iterating a dict yields its keys.
        return [x.vt for x in self.items.keys()]
class DefaultDictVariable(ConstDictVariable):
    """Tracks a ``collections.defaultdict``.

    ``default_factory`` is called (via ``call_function``) to create the value
    for a missing key on ``__getitem__``.
    """

    def __init__(self, items, user_cls, default_factory=None, **kwargs):
        super().__init__(items, user_cls, **kwargs)
        assert user_cls is collections.defaultdict
        self.default_factory = default_factory

    def is_python_constant(self):
        # Return false for unsupported defaults. This ensures that a bad handler
        # path is not taken in BuiltinVariable for getitem.
        # NOTE(review): if default_factory is a VariableTracker (as
        # is_supported_arg suggests), this membership test compares it with raw
        # Python types and may never match. Confirm the intent before changing it.
        factory_not_builtin = self.default_factory not in [list, tuple, dict]
        if factory_not_builtin and not self.items:
            return False
        return super().is_python_constant()

    @staticmethod
    def is_supported_arg(arg):
        """Whether ``arg`` is an acceptable default_factory tracker."""
        if not isinstance(arg, variables.BuiltinVariable):
            return isinstance(arg, variables.functions.BaseUserFunctionVariable)
        return arg.fn in [list, tuple, dict]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name != "__getitem__":
            return super().call_method(tx, name, args, kwargs)

        assert len(args) == 1
        key = args[0]
        if key in self:
            return self.getitem_const(key)
        if self.default_factory is None:
            raise KeyError(f"{key}")

        # Missing key: build the default, store it, and hand it back.
        default_var = self.default_factory.call_function(tx, [], {})
        super().call_method(tx, "__setitem__", (key, default_var), kwargs)
        return default_var
class SetVariable(ConstDictVariable):
    """Models a Python set as a ConstDictVariable whose values are all None."""

    def __init__(
        self,
        items: List[VariableTracker],
        **kwargs,
    ):
        placeholder = SetVariable._default_value()
        super().__init__(dict.fromkeys(items, placeholder), **kwargs)

    @property
    def set_items(self):
        # A fresh set of the _HashableTracker keys on every access.
        return set(self.items)

    @staticmethod
    def _default_value():
        # Variable used to fill in the values of the backing dictionary
        return ConstantVariable.create(None)

    def as_proxy(self):
        return {member.vt.as_proxy() for member in self.set_items}

    def python_type(self):
        return set

    def as_python_constant(self):
        return {member.vt.as_python_constant() for member in self.set_items}

    def reconstruct(self, codegen):
        members = self.set_items
        codegen.foreach([member.vt for member in members])
        codegen.append_output(create_instruction("BUILD_SET", arg=len(members)))

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> "VariableTracker":
        # Set methods are forwarded to the dictionary model.
        if name == "add":
            assert not kwargs
            assert len(args) == 1
            # s.add(x) is modelled as d[x] = None
            entry = (args[0], SetVariable._default_value())
            return super().call_method(tx, "__setitem__", entry, kwargs)

        if name == "pop":
            assert not kwargs
            assert not args
            # Choose an arbitrary item and remove it via the dict pop method
            popped = self.set_items.pop().vt
            super().call_method(tx, name, (popped,), kwargs)
            return popped

        return super().call_method(tx, name, args, kwargs)

    def getitem_const(self, arg: VariableTracker):
        raise RuntimeError("Illegal to getitem on a set")
class DictView(VariableTracker):
    """
    Models _PyDictViewObject

    This is an "abstract" class. Subclasses set ``kv`` ("keys" or "values")
    and implement ``view_items_vt``.
    """

    kv: Optional[str] = None

    def __init__(self, dv_dict: ConstDictVariable, **kwargs):
        super().__init__(**kwargs)
        assert self.kv in ("keys", "values")
        assert isinstance(dv_dict, ConstDictVariable)
        self.dv_dict = dv_dict

    @property
    def view_items(self):
        # Calls dict.keys() or dict.values() on the backing items dict.
        view_method = getattr(self.dv_dict.items, self.kv)
        return view_method()

    @property
    def view_items_vt(self):
        # Returns an iterable of the unpacked items; implemented in subclasses
        raise NotImplementedError()

    def unpack_var_sequence(self, tx):
        entries = self.view_items
        if self.kv == "keys":
            # Keys are _HashableTracker wrappers; hand out the wrapped VTs.
            return [entry.vt for entry in entries]
        return list(entries)

    def reconstruct(self, codegen):
        # Rebuild the dict, then call its .keys()/.values() method.
        codegen(self.dv_dict)
        codegen.extend_output(
            [
                create_instruction("LOAD_METHOD", argval=self.kv),
                *create_call_method(0),
            ]
        )

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        if name != "__len__":
            return super().call_method(tx, name, args, kwargs)
        # A view has the same length as its dict.
        return self.dv_dict.call_method(tx, name, args, kwargs)
class DictKeys(DictView):
    """View returned by ``dict.keys()``; entries are _HashableTracker keys."""

    kv = "keys"

    @property
    def set_items(self):
        return set(self.view_items)

    @property
    def view_items_vt(self):
        # Unwrap each _HashableTracker to its VariableTracker.
        return [hashable.vt for hashable in self.view_items]

    def python_type(self):
        return dict_keys

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        if name != "__contains__":
            return super().call_method(tx, name, args, kwargs)
        # `x in d.keys()` is answered by the dict itself.
        return self.dv_dict.call_method(tx, name, args, kwargs)
class DictValues(DictView):
    """View returned by ``dict.values()``.

    DictValues is an iterable but cannot be compared.
    """

    kv = "values"

    @property
    def view_items_vt(self):
        # Values are stored as plain VariableTrackers already.
        return [value for value in self.view_items]

    def python_type(self):
        return dict_values
def _is_matching_transformers_cls(cls) -> bool:
    """Return True if ``cls`` subclasses ``transformers.file_utils.ModelOutput``.

    Only ``sys.modules`` is consulted, so transformers is never imported here.
    If that module has not been imported, nothing matches.
    """
    file_utils = sys.modules.get("transformers.file_utils")
    if file_utils is None:
        return False
    return issubclass(cls, file_utils.ModelOutput)
def _is_matching_diffusers_cls(cls) -> bool:
    """Return True if ``cls`` subclasses ``diffusers.utils.BaseOutput``.

    Only ``sys.modules`` is consulted, so diffusers is never imported here.
    If that module has not been imported, nothing matches.
    """
    diffusers_utils = sys.modules.get("diffusers.utils")
    if diffusers_utils is None:
        return False
    return issubclass(cls, diffusers_utils.BaseOutput)
def _call_hasattr_customobj(self, tx, name: str) -> "VariableTracker":
    """Shared ``call_hasattr`` for DataClassVariable and CustomizedDictVariable,
    whose attributes are stored as dict items.

    Returns a ConstantVariable bool when the answer can be determined, guarding
    on the real object when it has a source. Otherwise it calls ``unimplemented``
    (graph break).
    """
    # ``name`` is a plain str. The _HashableTracker keys in self.items accept
    # literal comparisons, so this finds field-name keys.
    if name in self.items or hasattr(self.user_cls, name):
        return ConstantVariable.create(True)
    if istype(self.mutable_local, MutableLocal) and self.source is None:
        # Something created locally can't have any extra fields on it
        return ConstantVariable.create(False)
    if self.mutable_local is None and self.source:
        # Only the example-value lookup is expected to miss; keep the try narrow
        # so a KeyError from guard construction is not silently swallowed.
        try:
            example = tx.output.root_tx.get_example_value(self.source)
        except KeyError:
            pass
        else:
            install_guard(
                AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
            )
            return ConstantVariable.create(hasattr(example, name))
    unimplemented(
        f"hasattr({self.__class__.__name__}, {name}) {self.mutable_local} {self.source}"
    )
class DataClassVariable(ConstDictVariable):
    """
    This is a bit of a hack to deal with
    transformers.file_utils.ModelOutput() from huggingface.

    ModelOutput causes trouble because it is a mix of a dataclass and an
    OrderedDict and it calls super() methods implemented in C.

    Dataclass fields are stored as dict items keyed by ConstantVariable field
    names.
    """

    # ModelOutput() excludes None, though generic dataclasses don't
    include_none = False

    @staticmethod
    @functools.lru_cache(None)
    def _patch_once():
        """Mark the functions defined on ModelOutput / BaseOutput as skipped.

        Cached so it runs at most once. A library that cannot be imported is
        ignored.
        """

        def skip_methods(base_cls):
            for obj in base_cls.__dict__.values():
                # skip_code needs a code object. Some callables in a class
                # namespace (e.g. nested classes) have no __code__; this mirrors
                # the hasattr check in CustomizedDictVariable.create.
                if callable(obj) and hasattr(obj, "__code__"):
                    skip_code(obj.__code__)

        try:
            from transformers.file_utils import ModelOutput
        except ImportError:
            pass
        else:
            skip_methods(ModelOutput)

        try:
            from diffusers.utils import BaseOutput
        except ImportError:
            pass
        else:
            skip_methods(BaseOutput)

    @staticmethod
    def is_matching_cls(cls):
        return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)

    @classmethod
    def is_matching_object(cls, obj):
        return cls.is_matching_cls(type(obj))

    @classmethod
    def create(cls, user_cls, args, kwargs, options):
        """Model ``user_cls(*args, **kwargs)``.

        Arguments are bound to the dataclass fields with defaults applied.
        Non-VariableTracker arguments must be None, unless ``include_none`` is
        set, in which case they must be literals. None fields are dropped when
        ``include_none`` is False.
        """
        DataClassVariable._patch_once()

        skip_code(user_cls.__init__.__code__)
        keys = [f.name for f in dataclasses.fields(user_cls)]
        bound = inspect.signature(user_cls).bind(*args, **kwargs)
        bound.apply_defaults()
        assert set(bound.arguments.keys()) == set(keys)
        items = {}
        for key in keys:
            val = bound.arguments[key]
            key_vt = ConstantVariable.create(key)
            if isinstance(val, VariableTracker):
                items[key_vt] = val
            elif cls.include_none:
                assert variables.ConstantVariable.is_literal(val)
                items[key_vt] = variables.ConstantVariable.create(val)
            else:
                assert val is None, f"unexpected {val}"

        if len(items) == 1:
            # items is keyed by ConstantVariable field names, so the previous
            # items[keys[0]] lookup with a raw string could never succeed.
            # Fetch the sole value directly instead.
            (only_value,) = items.values()
            if not isinstance(only_value, variables.TensorVariable):
                unimplemented("DataClassVariable iterator constructor")
                # TODO(jansel): implement unpacking logic in ModelOutput.__post_init__

        return cls(items, user_cls, **options)

    @classmethod
    def wrap(cls, builder, obj):
        """Wrap an existing dataclass instance ``obj`` using ``builder``."""
        user_cls = type(obj)
        keys = [f.name for f in dataclasses.fields(user_cls)]

        items = {}
        for key in keys:
            # __init__ function of a dataclass might not have yet defined the key
            if not hasattr(obj, key):
                continue
            val = getattr(obj, key)
            # Still built for fields that are excluded below, as before.
            # NOTE(review): the builder call may have side effects (e.g. guards);
            # confirm before skipping it for excluded fields.
            var = builder.__class__(
                tx=builder.tx, source=AttrSource(builder.source, key)
            )(val)
            if val is not None or cls.include_none:
                items[ConstantVariable.create(key)] = var
        return cls(items, user_cls)

    def __init__(self, items, user_cls, **options):
        super().__init__(items, user_cls, **options)
        assert self.is_matching_cls(user_cls)

    def as_proxy(self):
        raise NotImplementedError()

    def reconstruct(self, codegen):
        # Emit user_cls(field=value, ...) with the fields passed as keywords.
        codegen.extend_output([codegen._create_load_const(self.user_cls)])
        # All the keys are just wrapped strings
        d = self.keys_as_python_constant()
        codegen.foreach(d.values())
        keys = tuple(d.keys())
        codegen.extend_output(codegen.create_call_function_kw(len(keys), keys, True))

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__getitem__":
            assert not kwargs and len(args) == 1
            val = args[0]
            if val.python_type() is str:
                return self.getitem_const(val)
            else:
                # Non-string indices address the tuple form of the output.
                return self.call_method(tx, "to_tuple", [], {}).call_method(
                    tx, "__getitem__", args, kwargs
                )
        elif name == "to_tuple":
            assert not (args or kwargs)
            return variables.TupleVariable(list(self.items.values()))
        elif name == "__setattr__":
            name = "__setitem__"
        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """Field access: stored items first, then literal dataclass defaults."""
        name_vt = ConstantVariable.create(name)
        if name_vt in self:
            return self.call_method(tx, "__getitem__", [name_vt], {})
        elif not self.include_none:
            defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)}
            if name in defaults:
                assert variables.ConstantVariable.is_literal(defaults[name])
                return variables.ConstantVariable.create(defaults[name])
        # Previously the result of the super call was dropped (returned None).
        return super().var_getattr(tx, name)

    call_hasattr = _call_hasattr_customobj
class CustomizedDictVariable(ConstDictVariable):
    """Tracks user dict subclasses matched by ``is_matching_cls``.

    These are OrderedDict subclasses using the default ``__init__`` and
    HuggingFace ModelOutput / BaseOutput subclasses.
    """

    @staticmethod
    def is_matching_cls(cls):
        # True if using default OrderedDict.__init__ and did not implement __post_init__
        if (
            issubclass(cls, collections.OrderedDict)
            and cls.__init__ is collections.OrderedDict.__init__
            and not hasattr(cls, "__post_init__")
        ):
            return True
        # hack for HF usecase:
        #   assume dataclass annotation for ModelOutput subclass
        #   assume self.create is AA to ModelOutput.__post_init__
        return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)

    @classmethod
    def is_matching_object(cls, obj):
        return cls.is_matching_cls(type(obj))

    # called from user_defined.py
    # when is_matching_cls(cls) is true
    @classmethod
    def create(cls, user_cls, args, kwargs, options):
        """Model ``user_cls(*args, **kwargs)``.

        Supported forms: dataclass construction, keyword-only construction, or
        a single ConstDictVariable positional argument. Anything else calls
        ``unimplemented``.
        """
        # avoid tracing when returning ModelOutput from forward func
        for attr_name in ("__init__", "__post_init__", "__setattr__", "__setitem__"):
            if hasattr(user_cls, attr_name):
                fn = getattr(user_cls, attr_name)
                assert callable(fn), f"expect callable attr {attr_name}"
                if hasattr(fn, "__code__"):
                    skip_code(fn.__code__)

        if dataclasses.is_dataclass(user_cls):
            # @dataclass CustomDict(a=1, b=2)
            bound = inspect.signature(user_cls).bind(*args, **kwargs)
            bound.apply_defaults()

            def make_var(x):
                if isinstance(x, VariableTracker):
                    return x
                elif ConstantVariable.is_literal(x):
                    return ConstantVariable.create(x)
                else:
                    unimplemented(
                        "expect VariableTracker or ConstantVariable.is_literal"
                    )

            items = {
                ConstantVariable.create(k): make_var(v)
                for k, v in bound.arguments.items()
            }
        elif not args:
            # CustomDict(a=1, b=2) in the general (non-dataclass) case.
            items = {ConstantVariable.create(k): v for k, v in kwargs.items()}
        elif len(args) == 1 and isinstance(args[0], ConstDictVariable) and not kwargs:
            # CustomDict({'a': 1, 'b': 2})
            items = args[0].items
        else:
            unimplemented("custom dict init with args/kwargs unimplemented")

        return cls(items, user_cls, **options)

    # called from builder.py
    @classmethod
    def wrap(cls, builder, obj):
        raise NotImplementedError()

    def __init__(self, items, user_cls, **options):
        super().__init__(items, user_cls, **options)
        assert self.is_matching_cls(user_cls)

    def as_proxy(self):
        raise NotImplementedError()

    # 'RETURN_VALUE triggered compile'
    # called from torch/_dynamo/codegen.py
    def reconstruct(self, codegen):
        # Emit user_cls(key=value, ...) with the keys passed as keywords.
        codegen.extend_output([codegen._create_load_const(self.user_cls)])
        # All the keys are just wrapped strings
        d = self.keys_as_python_constant()
        codegen.foreach(d.values())
        keys = tuple(d.keys())
        codegen.extend_output(codegen.create_call_function_kw(len(keys), keys, True))

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Dispatch a method call on the custom dict.

        Methods inherited unchanged from dict/OrderedDict use the
        ConstDictVariable model. A few user-overridable methods are inlined.
        Anything else graph-breaks.
        """
        fn = getattr(self.user_cls, name, None)
        if fn is None:
            # Missing (or None) attribute: graph-break instead of raising
            # AttributeError inside the tracer.
            unimplemented(f"custom dict: call_method unimplemented name={name}")
        source = None if self.source is None else AttrSource(self.source, name)

        if hasattr(fn, "__objclass__") and fn.__objclass__ in (
            dict,
            collections.OrderedDict,
        ):
            # for python dict method without overridden
            return super().call_method(tx, name, args, kwargs)
        elif name in ("__getitem__", "to_tuple", "__setitem__", "__setattr__"):
            # for user overridden method
            return tx.inline_user_function_return(
                variables.UserFunctionVariable(fn, source=source),
                [self] + list(args),
                kwargs,
            )

        # unimplemented takes a single message string; the old "%s"-style extra
        # positional argument was a TypeError rather than a graph break.
        unimplemented(f"custom dict: call_method unimplemented name={name}")

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        name_vt = ConstantVariable.create(name)
        if name_vt in self:
            return self.call_method(tx, "__getitem__", [name_vt], {})
        # Previously the result of the super call was dropped (returned None).
        return super().var_getattr(tx, name)

    call_hasattr = _call_hasattr_customobj
@functools.lru_cache(None)
def _install_PretrainedConfig_patch():
    """Replace ``PretrainedConfig.__eq__`` with an instance-``__dict__`` comparison.

    Cached so the monkeypatch is installed at most once.
    """
    import transformers

    # We need to monkeypatch transformers here, sadly.
    # TODO(voz): Upstream to transformers lib

    def _dynamo_overriden_transformers_eq(self, other):
        # Objects without an instance __dict__ never compare equal to a config.
        return hasattr(other, "__dict__") and self.__dict__ == other.__dict__

    config_cls = transformers.configuration_utils.PretrainedConfig
    config_cls.__eq__ = _dynamo_overriden_transformers_eq
class HFPretrainedConfigVariable(VariableTracker):
    """
    Hack for HuggingFace PretrainedConfig

    Wraps a real config object ``obj``. Attribute reads and hasattr checks are
    answered directly from it and returned as ConstantVariables.
    """

    @staticmethod
    def is_matching_cls(cls):
        config_utils = sys.modules.get("transformers.configuration_utils")
        if config_utils is None or not issubclass(cls, config_utils.PretrainedConfig):
            return False
        # Lazily install monkeypatch the first time we see it in dynamo
        _install_PretrainedConfig_patch()
        return True

    @classmethod
    def is_matching_object(cls, obj):
        return cls.is_matching_cls(type(obj))

    def __init__(self, obj, **kwargs):
        super().__init__(**kwargs)
        self.obj = obj
        assert self.is_matching_cls(type(obj))

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        from . import ConstantVariable

        attr_value = getattr(self.obj, name)
        return ConstantVariable.create(attr_value)

    def call_hasattr(self, tx, name: str) -> "VariableTracker":
        present = hasattr(self.obj, name)
        return variables.ConstantVariable.create(present)
class PythonSysModulesVariable(VariableTracker):
    """Special case for sys.modules.

    Without this we will guard on the exact set of modules imported in the
    lifetime of the python program. Instead, each lookup guards only on the
    presence or absence of the single key it touches.
    """

    def python_type(self):
        return dict

    def reconstruct(self, codegen):
        # Emit bytecode that loads ``sys`` and reads its ``modules`` attribute.
        load_sys = codegen.create_load_python_module(sys, True)
        load_modules = codegen.create_load_attr("modules")
        codegen.extend_output([load_sys, load_modules])

    def call_method(
        self, tx, name, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
    ):
        from .builder import VariableBuilder

        special_cases = {
            "__getitem__": self.call_getitem,
            "get": self.call_get,
            "__contains__": self.call_contains,
        }
        if name in special_cases:
            return special_cases[name](tx, *args, **kwargs)

        # Fallback to dict implementation
        real_dict = VariableBuilder(tx, self.source)(sys.modules)
        return real_dict.call_method(tx, name, args, kwargs)

    def _contains_helper(self, tx, key: VariableTracker):
        """Return ``(k, has_key)`` for the constant key and guard on membership."""
        k = key.as_python_constant()
        has_key = k in sys.modules
        # The guard is inverted when the key is currently absent.
        contains_guard = functools.partial(
            GuardBuilder.DICT_CONTAINS, key=k, invert=not has_key
        )
        install_guard(self.make_guard(contains_guard))
        return k, has_key

    def call_contains(self, tx, key: VariableTracker):
        _, has_key = self._contains_helper(tx, key)
        return ConstantVariable.create(value=has_key)

    def call_get(
        self, tx, key: VariableTracker, default: Optional[VariableTracker] = None
    ):
        from .builder import VariableBuilder

        k, has_key = self._contains_helper(tx, key)
        if not has_key:
            if default is None:
                return ConstantVariable.create(value=None)
            return default
        return VariableBuilder(tx, GetItemSource(self.source, k))(sys.modules[k])

    def call_getitem(self, tx, key: VariableTracker):
        from .builder import VariableBuilder

        k, _ = self._contains_helper(tx, key)
        # A missing key raises KeyError from sys.modules[k], as in Python.
        return VariableBuilder(tx, GetItemSource(self.source, k))(sys.modules[k])
|