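"""Convert a fairseq IndicTrans2 checkpoint on disk into a Hugging Face
``IndicTransForConditionalGeneration`` model and save it via ``save_pretrained``."""
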
import argparse

import torch
import torch.nn as nn

from configuration_indictrans import IndicTransConfig
from modeling_indictrans import IndicTransForConditionalGeneration


def remove_ignore_keys_(state_dict):
    """Drop fairseq bookkeeping entries that have no counterpart in the HF model."""
    ignore_keys = [
        "encoder.version",
        "decoder.version",
        "model.encoder.version",
        "model.decoder.version",
        "_float_tensor",
        "encoder.embed_positions._float_tensor",
        "decoder.embed_positions._float_tensor",
    ]
    for k in ignore_keys:
        # pop with a default so checkpoints missing a key don't raise
        state_dict.pop(k, None)


def make_linear_from_emb(emb):
    vocab_size, emb_size = emb.shape
    # nn.Linear(in_features, out_features) stores its weight as
    # (out_features, in_features), so projecting hidden states to vocab
    # logits takes emb_size in and vocab_size out.
    lin_layer = nn.Linear(emb_size, vocab_size, bias=False)
    lin_layer.weight.data = emb.data
    return lin_layer


def convert_fairseq_IT2_checkpoint_from_disk(checkpoint_path):
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    # Older fairseq checkpoints store the model config under "args",
    # newer ones under "cfg" -> "model".
    args = ckpt.get("args") or ckpt["cfg"]["model"]
    state_dict = ckpt["model"]
    remove_ignore_keys_(state_dict)
    encoder_vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
    decoder_vocab_size = state_dict["decoder.embed_tokens.weight"].shape[0]

    config = IndicTransConfig(
        encoder_vocab_size=encoder_vocab_size,
        decoder_vocab_size=decoder_vocab_size,
        max_source_positions=args.max_source_positions,
        max_target_positions=args.max_target_positions,
        encoder_layers=args.encoder_layers,
        decoder_layers=args.decoder_layers,
        layernorm_embedding=args.layernorm_embedding,
        encoder_normalize_before=args.encoder_normalize_before,
        decoder_normalize_before=args.decoder_normalize_before,
        encoder_attention_heads=args.encoder_attention_heads,
        decoder_attention_heads=args.decoder_attention_heads,
        encoder_ffn_dim=args.encoder_ffn_embed_dim,
        decoder_ffn_dim=args.decoder_ffn_embed_dim,
        encoder_embed_dim=args.encoder_embed_dim,
        decoder_embed_dim=args.decoder_embed_dim,
        encoder_layerdrop=args.encoder_layerdrop,
        decoder_layerdrop=args.decoder_layerdrop,
        dropout=args.dropout,
        attention_dropout=args.attention_dropout,
        activation_dropout=args.activation_dropout,
        activation_function=args.activation_fn,
        share_decoder_input_output_embed=args.share_decoder_input_output_embed,
        scale_embedding=not args.no_scale_embedding,
    )

    model = IndicTransForConditionalGeneration(config)
    model.model.load_state_dict(state_dict, strict=False)
    if not args.share_decoder_input_output_embed:
        # The fairseq checkpoint carries a separate output projection;
        # wrap it in a Linear layer to serve as the LM head.
        model.lm_head = make_linear_from_emb(
            state_dict["decoder.output_projection.weight"]
        )
    print(model)
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--fairseq_path",
        default="indic-en/model/checkpoint_best.pt",
        type=str,
        help="Path to a fairseq model checkpoint on the local filesystem.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default="indic-en/hf_model",
        type=str,
        help="Path to the output PyTorch model.",
    )

    args = parser.parse_args()
    model = convert_fairseq_IT2_checkpoint_from_disk(args.fairseq_path)
    model.save_pretrained(args.pytorch_dump_folder_path)
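
# Example invocation (defaults shown; the script filename below is hypothetical):
#   python convert_indictrans_checkpoint_to_pytorch.py \
#       --fairseq_path indic-en/model/checkpoint_best.pt \
#       --pytorch_dump_folder_path indic-en/hf_model
#
# The dumped folder can then be reloaded with the standard HF API, e.g.:
#   model = IndicTransForConditionalGeneration.from_pretrained("indic-en/hf_model")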