Spaces:
Sleeping
Sleeping
File size: 1,937 Bytes
6fc683c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
def emulate_int(w, bits, method, scale=None, zero_point=None):
    """Dispatch to the module-level ``emulate_int8_<method>`` function.

    The target is looked up by name in this module's globals, so ``method``
    must be the suffix of a function named ``emulate_int8_<method>`` defined
    in this module; an unknown name raises ``KeyError``.

    Args:
        w: value to quantize, forwarded unchanged to the selected function.
        bits: bit width, forwarded as the ``bits`` keyword.
        method: suffix selecting which ``emulate_int8_*`` function to call.
        scale: optional scale, forwarded as-is.
        zero_point: optional zero point, forwarded as-is.

    Returns:
        Whatever the selected function returns.
    """
    quantizer_name = f"emulate_int8_{method}"
    quantizer = globals()[quantizer_name]
    return quantizer(w, bits=bits, scale=scale, zero_point=zero_point)
def quantize(w, scale, zero_point, bits=8):
    """Snap ``w`` onto a ``bits``-bit integer grid and map it back.

    The value is divided by ``scale`` and shifted by ``zero_point``, rounded,
    clamped to ``[0, 2 ** bits - 1]``, then shifted and rescaled back, so the
    result lives in the same range as ``w`` but only takes grid values.

    Args:
        w: tensor to quantize.
        scale: step size of the grid.
        zero_point: offset added before rounding and removed afterwards.
        bits: bit width of the grid (default 8).

    Returns:
        The quantize-dequantized tensor.
    """
    # Highest integer level; 255 with the default of 8 bits.
    top_level = 2 ** bits - 1
    levels = torch.round(w / scale + zero_point)
    levels = torch.clamp(levels, 0, top_level)
    return (levels - zero_point) * scale
def emulate_int8_histogram(w, scale=None, zero_point=None, bits=8):
    """Quantize-dequantize ``w`` using parameters from a ``HistogramObserver``.

    When ``scale`` is ``None``, a fresh ``HistogramObserver`` is run on a
    float32 copy of ``w`` to compute ``scale`` and ``zero_point``. Otherwise
    the caller-supplied values are used as given.

    Args:
        w: tensor to quantize.
        scale: optional precomputed scale; ``None`` triggers calibration.
        zero_point: optional precomputed zero point (only used when
            ``scale`` is provided).
        bits: bit width passed to ``quantize``. The observer itself is
            built with its default settings and does not receive ``bits``.

    Returns:
        Tuple ``(quantized_w, scale, zero_point)``.
    """
    if scale is None:
        obs = torch.ao.quantization.observer.HistogramObserver()
        obs.to(device=w.device)
        # The observer is fed a float32 view of the weights.
        _ = obs(w.float())
        scale, zero_point = obs.calculate_qparams()
        # Follow w's own device and dtype rather than forcing CUDA, so this
        # also works for CPU tensors and on hosts without a GPU.
        scale = scale.to(device=w.device, dtype=w.dtype)
        zero_point = zero_point.to(device=w.device, dtype=w.dtype)
    return quantize(w, scale, zero_point, bits=bits), scale, zero_point
def emulate_int8_channel(w, scale=None, zero_point=None, bits=8):
    """Quantize-dequantize ``w`` with per-channel min/max parameters.

    When ``scale`` is ``None``, a ``PerChannelMinMaxObserver`` configured
    with ``ch_axis=-1`` and ``torch.per_channel_symmetric`` is run on ``w``
    to compute ``scale`` and ``zero_point``. Otherwise the caller-supplied
    values are used as given.

    Args:
        w: tensor to quantize.
        scale: optional precomputed scale; ``None`` triggers calibration.
        zero_point: optional precomputed zero point (only used when
            ``scale`` is provided).
        bits: bit width passed to ``quantize``. The observer itself is
            built without reference to ``bits``.

    Returns:
        Tuple ``(quantized_w, scale, zero_point)``.
    """
    if scale is None:
        obs = torch.ao.quantization.observer.PerChannelMinMaxObserver(
            ch_axis=-1, qscheme=torch.per_channel_symmetric
        )
        obs.to(device=w.device)
        _ = obs(w)
        # NOTE(review): this unpacking expects get_qparams() to return a
        # 3-tuple whose last item (the channel axis) is not needed here;
        # confirm against the installed torch version.
        scale, zero_point, _ = obs.get_qparams()
        # Follow w's own device and dtype rather than forcing CUDA, so this
        # also works for CPU tensors and on hosts without a GPU.
        scale = scale.to(device=w.device, dtype=w.dtype)
        zero_point = zero_point.to(device=w.device, dtype=w.dtype)
    return quantize(w, scale, zero_point, bits=bits), scale, zero_point
def emulate_int8_tensor(w, scale=None, zero_point=None, bits=8):
    """Quantize-dequantize ``w`` with whole-tensor min/max parameters.

    When ``scale`` is ``None``, a default ``MinMaxObserver`` is run on ``w``
    to compute ``scale`` and ``zero_point``. Otherwise the caller-supplied
    values are used as given.

    Args:
        w: tensor to quantize.
        scale: optional precomputed scale; ``None`` triggers calibration.
        zero_point: optional precomputed zero point (only used when
            ``scale`` is provided).
        bits: bit width passed to ``quantize``. The observer itself is
            built with its default settings and does not receive ``bits``.

    Returns:
        Tuple ``(quantized_w, scale, zero_point)``.
    """
    if scale is None:
        obs = torch.ao.quantization.observer.MinMaxObserver()
        obs.to(device=w.device)
        _ = obs(w)
        scale, zero_point = obs.calculate_qparams()
        # Follow w's own device and dtype rather than forcing CUDA, so this
        # also works for CPU tensors and on hosts without a GPU.
        scale = scale.to(device=w.device, dtype=w.dtype)
        zero_point = zero_point.to(device=w.device, dtype=w.dtype)
    return quantize(w, scale, zero_point, bits=bits), scale, zero_point
|