""" |
|
Once a model has been fine-pruned, the weights that are masked during the forward pass can be pruned once for all. |
|
For instance, once the a model from the :class:`~emmental.MaskedBertForSequenceClassification` is trained, it can be saved (and then loaded) |
|
as a standard :class:`~transformers.BertForSequenceClassification`. |
|
""" |

import argparse
import os
import shutil

import torch

from emmental.modules import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer


def main(args):
    pruning_method = args.pruning_method
    threshold = args.threshold

    model_name_or_path = args.model_name_or_path.rstrip("/")
    target_model_path = args.target_model_path

    print(f"Load fine-pruned model from {model_name_or_path}")
    # This is pure checkpoint surgery, so the state dict is loaded on CPU.
    model = torch.load(os.path.join(model_name_or_path, "pytorch_model.bin"), map_location="cpu")
    pruned_model = {}

    for name, tensor in model.items():
        if "embeddings" in name or "LayerNorm" in name or "pooler" in name:
            # Embeddings, LayerNorms and the pooler are never masked: copy them as-is.
            pruned_model[name] = tensor
            print(f"Copied layer {name}")
        elif "classifier" in name or "qa_output" in name:
            # Task-specific heads (sequence classification or QA) are not pruned either.
            pruned_model[name] = tensor
            print(f"Copied layer {name}")
        elif "bias" in name:
            # Only weight matrices carry mask scores; biases are copied unchanged.
            pruned_model[name] = tensor
            print(f"Copied layer {name}")
        else:
            if pruning_method == "magnitude":
                # Magnitude pruning: the mask is computed directly from the weight magnitudes.
                mask = MagnitudeBinarizer.apply(inputs=tensor, threshold=threshold)
                pruned_model[name] = tensor * mask
                print(f"Pruned layer {name}")
            elif pruning_method == "topK":
                if "mask_scores" in name:
                    continue
                # Strip the trailing "weight" to look up the matching "<prefix>mask_scores" tensor.
                prefix_ = name[:-6]
                scores = model[f"{prefix_}mask_scores"]
                mask = TopKBinarizer.apply(scores, threshold)
                pruned_model[name] = tensor * mask
                print(f"Pruned layer {name}")
            elif pruning_method == "sigmoied_threshold":
                if "mask_scores" in name:
                    continue
                prefix_ = name[:-6]
                scores = model[f"{prefix_}mask_scores"]
                mask = ThresholdBinarizer.apply(scores, threshold, True)
                pruned_model[name] = tensor * mask
                print(f"Pruned layer {name}")
            elif pruning_method == "l0":
                if "mask_scores" in name:
                    continue
                prefix_ = name[:-6]
                scores = model[f"{prefix_}mask_scores"]
                # Deterministic L0 gate: stretch the sigmoid to (l, r) = (-0.1, 1.1) and clamp to [0, 1].
                l, r = -0.1, 1.1
                s = torch.sigmoid(scores)
                s_bar = s * (r - l) + l
                mask = s_bar.clamp(min=0.0, max=1.0)
                pruned_model[name] = tensor * mask
                print(f"Pruned layer {name}")
            else:
                raise ValueError("Unknown pruning method")

    if target_model_path is None:
        target_model_path = os.path.join(
            os.path.dirname(model_name_or_path), f"bertarized_{os.path.basename(model_name_or_path)}"
        )

    if not os.path.isdir(target_model_path):
        shutil.copytree(model_name_or_path, target_model_path)
        print(f"\nCreated folder {target_model_path}")

    torch.save(pruned_model, os.path.join(target_model_path, "pytorch_model.bin"))
    print("\nPruned model saved! See you later!")

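
# Not part of the original script: a minimal sanity-check sketch. Since the binarizers produce
# (mostly) exact 0/1 masks, the pruned weight matrices in the saved state dict should contain a
# large fraction of exact zeros; this hypothetical helper reports that fraction.
def report_sparsity(state_dict_path):
    """Print the share of exactly-zero parameters in a saved `pytorch_model.bin`."""
    state_dict = torch.load(state_dict_path, map_location="cpu")
    zeros, total = 0, 0
    for tensor in state_dict.values():
        zeros += int((tensor == 0).sum())
        total += tensor.numel()
    print(f"Global sparsity: {100.0 * zeros / total:.2f}% ({zeros}/{total} exact zeros)")
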

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--pruning_method",
        choices=["l0", "magnitude", "topK", "sigmoied_threshold"],
        type=str,
        required=True,
        help=(
            "Pruning method (l0 = L0 regularization, magnitude = magnitude pruning, topK = movement pruning,"
            " sigmoied_threshold = soft movement pruning)"
        ),
    )
    parser.add_argument(
        "--threshold",
        type=float,
        required=False,
        help=(
            "For `magnitude` and `topK`, it is the level of remaining weights (in %%) in the fine-pruned model. "
            "For `sigmoied_threshold`, it is the threshold tau against which the (sigmoied) scores are compared. "
            "Not needed for `l0`."
        ),
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        required=True,
        help="Folder containing the model that was previously fine-pruned",
    )
    parser.add_argument(
        "--target_model_path",
        default=None,
        type=str,
        required=False,
        help=(
            "Folder where the pruned model will be saved. If not provided, defaults to a sibling folder named "
            "`bertarized_<model_name_or_path>`."
        ),
    )

    args = parser.parse_args()

    main(args)