from typing import List, Optional
import torch
from torch.backends._nnapi.serializer import _NnapiSerializer

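# Compilation preference values, mirroring the ANEURALNETWORKS_PREFER_*
# PreferenceCode constants in the NNAPI C API.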
ANEURALNETWORKS_PREFER_LOW_POWER = 0
ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1
ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2


class NnapiModule(torch.nn.Module):
"""Torch Module that wraps an NNAPI Compilation.
This module handles preparing the weights, initializing the
NNAPI TorchBind object, and adjusting the memory formats
of all inputs and outputs.
"""
    # torch.classes._nnapi.Compilation is a TorchBind class registered at
    # runtime, which static type checkers cannot see; hence the ignore below.
comp: Optional[torch.classes._nnapi.Compilation] # type: ignore[name-defined]
weights: List[torch.Tensor]
    out_templates: List[torch.Tensor]

def __init__(
self,
shape_compute_module: torch.nn.Module,
ser_model: torch.Tensor,
weights: List[torch.Tensor],
inp_mem_fmts: List[int],
out_mem_fmts: List[int],
compilation_preference: int,
relax_f32_to_f16: bool,
):
super().__init__()
self.shape_compute_module = shape_compute_module
self.ser_model = ser_model
self.weights = weights
self.inp_mem_fmts = inp_mem_fmts
self.out_mem_fmts = out_mem_fmts
self.out_templates = []
self.comp = None
self.compilation_preference = compilation_preference
        self.relax_f32_to_f16 = relax_f32_to_f16

@torch.jit.export
def init(self, args: List[torch.Tensor]):
assert self.comp is None
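        # Run the code-generated shape computation: it mutates ser_model in
        # place with concrete operand shapes and returns output templates.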
self.out_templates = self.shape_compute_module.prepare(self.ser_model, args) # type: ignore[operator]
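        # NNAPI reads weight buffers directly, so they must be contiguous.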
self.weights = [w.contiguous() for w in self.weights]
comp = torch.classes._nnapi.Compilation()
comp.init2(
self.ser_model,
self.weights,
self.compilation_preference,
self.relax_f32_to_f16,
)
        self.comp = comp

def forward(self, args: List[torch.Tensor]) -> List[torch.Tensor]:
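        # Initialize lazily on the first call, once real input shapes are known.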
if self.comp is None:
self.init(args)
comp = self.comp
assert comp is not None
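        # Allocate output buffers from the templates computed in init().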
outs = [torch.empty_like(out) for out in self.out_templates]
assert len(args) == len(self.inp_mem_fmts)
fixed_args = []
for idx in range(len(args)):
fmt = self.inp_mem_fmts[idx]
# These constants match the values in DimOrder in serializer.py
# TODO: See if it's possible to use those directly.
if fmt == 0:
fixed_args.append(args[idx].contiguous())
elif fmt == 1:
fixed_args.append(args[idx].permute(0, 2, 3, 1).contiguous())
else:
                raise ValueError("Invalid mem_fmt")
comp.run(fixed_args, outs)
assert len(outs) == len(self.out_mem_fmts)
for idx in range(len(self.out_templates)):
fmt = self.out_mem_fmts[idx]
# These constants match the values in DimOrder in serializer.py
# TODO: See if it's possible to use those directly.
if fmt in (0, 2):
pass
elif fmt == 1:
outs[idx] = outs[idx].permute(0, 3, 1, 2)
else:
                raise ValueError("Invalid mem_fmt")
        return outs


def convert_model_to_nnapi(
model,
inputs,
serializer=None,
return_shapes=None,
use_int16_for_qint16=False,
compilation_preference=ANEURALNETWORKS_PREFER_SUSTAINED_SPEED,
relax_f32_to_f16=False,
):
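    """Convert a scripted or traced module into an NNAPI-backed ScriptModule.

    A minimal usage sketch (names are illustrative; per the PyTorch NNAPI
    workflow, example inputs for CNNs are typically contiguous in
    channels-last order):

        traced = torch.jit.trace(my_module, example_input)
        nnapi_model = convert_model_to_nnapi(traced, example_input)

    `serializer`, `return_shapes`, and `use_int16_for_qint16` are forwarded
    to `process_for_nnapi`. `compilation_preference` should be one of the
    ANEURALNETWORKS_PREFER_* constants above, and `relax_f32_to_f16` allows
    NNAPI to run float32 computations with float16 precision.
    """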
(
shape_compute_module,
ser_model_tensor,
used_weights,
inp_mem_fmts,
out_mem_fmts,
retval_count,
) = process_for_nnapi(
model, inputs, serializer, return_shapes, use_int16_for_qint16
)
nnapi_model = NnapiModule(
shape_compute_module,
ser_model_tensor,
used_weights,
inp_mem_fmts,
out_mem_fmts,
compilation_preference,
relax_f32_to_f16,
    )

class NnapiInterfaceWrapper(torch.nn.Module):
"""NNAPI list-ifying and de-list-ifying wrapper.
NNAPI always expects a list of inputs and provides a list of outputs.
This module allows us to accept inputs as separate arguments.
It returns results as either a single tensor or tuple,
matching the original module.
"""
def __init__(self, mod):
super().__init__()
            self.mod = mod

wrapper_model_py = NnapiInterfaceWrapper(nnapi_model)
wrapper_model = torch.jit.script(wrapper_model_py)
# TODO: Maybe make these names match the original.
arg_list = ", ".join(f"arg_{idx}" for idx in range(len(inputs)))
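    # A negative retval_count means the original module returned a single
    # tensor; otherwise the trailing commas make ret_expr build a tuple.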
if retval_count < 0:
ret_expr = "retvals[0]"
else:
ret_expr = "".join(f"retvals[{idx}], " for idx in range(retval_count))
wrapper_model.define(
f"def forward(self, {arg_list}):\n"
f" retvals = self.mod([{arg_list}])\n"
f" return {ret_expr}\n"
)
    return wrapper_model


def process_for_nnapi(
model, inputs, serializer=None, return_shapes=None, use_int16_for_qint16=False
):
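    """Freeze and serialize a scripted module for NNAPI.

    Returns the code-generated shape-computation module, the serialized model
    packed into an int32 tensor, the weight tensors the model uses, the input
    and output memory formats, and the return-value count of the original
    module.
    """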
model = torch.jit.freeze(model)
if isinstance(inputs, torch.Tensor):
inputs = [inputs]
serializer = serializer or _NnapiSerializer(
config=None, use_int16_for_qint16=use_int16_for_qint16
)
(
ser_model,
used_weights,
inp_mem_fmts,
out_mem_fmts,
shape_compute_lines,
retval_count,
) = serializer.serialize_model(model, inputs, return_shapes)
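    # Pack the serialized model into an int32 tensor so it can be captured
    # by TorchScript and saved as part of the module.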
ser_model_tensor = torch.tensor(ser_model, dtype=torch.int32)
# We have to create a new class here every time this function is called
# because module.define adds a method to the *class*, not the instance.
class ShapeComputeModule(torch.nn.Module):
"""Code-gen-ed module for tensor shape computation.
module.prepare will mutate ser_model according to the computed operand
shapes, based on the shapes of args. Returns a list of output templates.
"""
pass
shape_compute_module = torch.jit.script(ShapeComputeModule())
real_shape_compute_lines = [
"def prepare(self, ser_model: torch.Tensor, args: List[torch.Tensor]) -> List[torch.Tensor]:\n",
] + [f" {line}\n" for line in shape_compute_lines]
shape_compute_module.define("".join(real_shape_compute_lines))
return (
shape_compute_module,
ser_model_tensor,
used_weights,
inp_mem_fmts,
out_mem_fmts,
retval_count,
)