""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet) |
|
""" |
|
|
|
|
|
import argparse |
|
import logging |
|
from typing import Tuple |
|
|
|
import numpy as np |
|
import torch |
|
|
|
from transformers import ( |
|
CTRLLMHeadModel, |
|
CTRLTokenizer, |
|
GenerationMixin, |
|
GPT2LMHeadModel, |
|
GPT2Tokenizer, |
|
OpenAIGPTLMHeadModel, |
|
OpenAIGPTTokenizer, |
|
TransfoXLLMHeadModel, |
|
TransfoXLTokenizer, |
|
XLMTokenizer, |
|
XLMWithLMHeadModel, |
|
XLNetLMHeadModel, |
|
XLNetTokenizer, |
|
) |
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
|
|
|
logging.basicConfig( |
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
|
datefmt="%m/%d/%Y %H:%M:%S", |
|
level=logging.INFO, |
|
) |
|
logger = logging.getLogger(__name__) |
|
|
|
MAX_LENGTH = 10000  # Hard fallback cap used when --length is negative and the model reports no usable maximum
|
|
|
MODEL_CLASSES = { |
|
"gpt2": (GPT2LMHeadModel, GPT2Tokenizer), |
|
"ctrl": (CTRLLMHeadModel, CTRLTokenizer), |
|
"openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), |
|
"xlnet": (XLNetLMHeadModel, XLNetTokenizer), |
|
"transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer), |
|
"xlm": (XLMWithLMHeadModel, XLMTokenizer), |
|
}


# Padding prefix prepended to short prompts for XLNet and Transformer-XL (see
# prepare_xlnet_input / prepare_transfoxl_input below); these models tend to
# generate poorly from very short contexts.
PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
|
|
|
|
|
def set_seed(args): |
|
np.random.seed(args.seed) |
|
torch.manual_seed(args.seed) |
|
if args.n_gpu > 0: |
|
torch.cuda.manual_seed_all(args.seed) |
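

#
# Model-specific prompt preparation helpers. Each takes (args, model, tokenizer,
# prompt_text) and returns the (possibly rewritten) prompt text; they are
# dispatched through PREPROCESSING_FUNCTIONS below.
#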
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_ctrl_input(args, _, tokenizer, prompt_text): |
|
if args.temperature > 0.7: |
|
logger.info("CTRL typically works better with lower temperatures (and lower top_k).") |
|
|
|
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False) |
|
if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()): |
|
logger.info("WARNING! You are not starting your generation from a control code so you won't get good results") |
|
return prompt_text |
|
|
|
|
|
def prepare_xlm_input(args, model, tokenizer, prompt_text):
    # Set the language id for checkpoints that use language embeddings, asking
    # interactively when --xlm_language is missing or not among the available languages.
    use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
|
if hasattr(model.config, "lang2id") and use_lang_emb: |
|
available_languages = model.config.lang2id.keys() |
|
if args.xlm_language in available_languages: |
|
language = args.xlm_language |
|
else: |
|
language = None |
|
while language not in available_languages: |
|
language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ") |
|
|
|
        model.config.lang_id = model.config.lang2id[language]

    return prompt_text
|
|
|
|
|
def prepare_xlnet_input(args, _, tokenizer, prompt_text): |
|
prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX |
|
prompt_text = prefix + prompt_text |
|
return prompt_text |
|
|
|
|
|
def prepare_transfoxl_input(args, _, tokenizer, prompt_text): |
|
prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX |
|
prompt_text = prefix + prompt_text |
|
return prompt_text |
|
|
|
|
|
PREPROCESSING_FUNCTIONS = { |
|
"ctrl": prepare_ctrl_input, |
|
"xlm": prepare_xlm_input, |
|
"xlnet": prepare_xlnet_input, |
|
"transfo-xl": prepare_transfoxl_input, |
|
} |
|
|
|
|
|
def adjust_length_to_model(length, max_sequence_length):
    if length < 0 and max_sequence_length > 0:
        length = max_sequence_length
    elif 0 < max_sequence_length < length:
        length = max_sequence_length  # Never generate more than the model can handle
    elif length < 0:
        length = MAX_LENGTH  # Avoid an effectively infinite generation loop
    return length
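

#
# Helpers used only for --jit: they build dummy inputs for torch.jit.trace and
# wrap the traced module so `generate()` can still drive it.
#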
|
|
|
|
|
def sparse_model_config(model_config):
    """Return (num_layer, num_head, embedding_size_per_head) from a model config,
    handling the different attribute names used across architectures."""
|
embedding_size = None |
|
if hasattr(model_config, "hidden_size"): |
|
embedding_size = model_config.hidden_size |
|
elif hasattr(model_config, "n_embed"): |
|
embedding_size = model_config.n_embed |
|
elif hasattr(model_config, "n_embd"): |
|
embedding_size = model_config.n_embd |
|
|
|
num_head = None |
|
if hasattr(model_config, "num_attention_heads"): |
|
num_head = model_config.num_attention_heads |
|
elif hasattr(model_config, "n_head"): |
|
num_head = model_config.n_head |
|
|
|
if embedding_size is None or num_head is None or num_head == 0: |
|
raise ValueError("Check the model config") |
|
|
|
num_embedding_size_per_head = int(embedding_size / num_head) |
|
num_layer = model_config.n_layer |
|
|
|
return num_layer, num_head, num_embedding_size_per_head |
|
|
|
|
|
def prepare_jit_inputs(inputs, model, tokenizer):
    """Build dummy (input_ids, past_key_values, attention_mask) inputs with a past
    length of 1, in the layout the model expects, for use with torch.jit.trace."""
|
num_batch = len(inputs) |
|
dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True) |
|
num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config) |
|
if model.config.model_type == "bloom": |
|
past_key_values = tuple( |
|
( |
|
torch.zeros(int(num_attention_heads * num_batch), num_embedding_size_per_head, 1) |
|
.to(model.config.torch_dtype) |
|
.to(model.device), |
|
torch.zeros(int(num_attention_heads * num_batch), 1, num_embedding_size_per_head) |
|
.to(model.config.torch_dtype) |
|
.to(model.device), |
|
) |
|
for _ in range(num_block_layers) |
|
) |
|
else: |
|
past_key_values = tuple( |
|
( |
|
torch.zeros(num_batch, num_attention_heads, 1, num_embedding_size_per_head) |
|
.to(model.config.torch_dtype) |
|
.to(model.device), |
|
torch.zeros(num_batch, num_attention_heads, 1, num_embedding_size_per_head) |
|
.to(model.config.torch_dtype) |
|
.to(model.device), |
|
) |
|
for _ in range(num_block_layers) |
|
) |
|
|
|
dummy_input["attention_mask"] = torch.cat( |
|
[ |
|
torch.zeros(dummy_input["attention_mask"].shape[0], 1).to(dummy_input["attention_mask"].dtype), |
|
dummy_input["attention_mask"], |
|
], |
|
-1, |
|
) |
|
|
|
if model.config.use_cache: |
|
jit_inputs = ( |
|
dummy_input["input_ids"].to(model.device), |
|
past_key_values, |
|
dummy_input["attention_mask"].to(model.device), |
|
) |
|
else: |
|
jit_inputs = ( |
|
dummy_input["input_ids"].to(model.device), |
|
dummy_input["attention_mask"].to(model.device), |
|
) |
|
|
|
return jit_inputs |
|
|
|
|
|
class _ModelFallbackWrapper(GenerationMixin):
    """Wrap a jit-traced model so it can be used with `generate()`, falling back
    to the original eager model whenever no `past_key_values` are available yet."""

    __slots__ = ("_optimized", "_default")
|
|
|
def __init__(self, optimized, default): |
|
self._optimized = optimized |
|
self._default = default |
|
|
|
def __call__(self, *args, **kwargs): |
|
if kwargs["past_key_values"] is None: |
|
return self._default(*args, **kwargs) |
|
        # The traced module takes positional tensor arguments only, so drop
        # position_ids and any None or boolean kwargs before calling it.
        trace_graph_inputs = []
        kwargs.pop("position_ids", None)
|
for k, v in kwargs.items(): |
|
if v is not None and not isinstance(v, bool): |
|
trace_graph_inputs.append(v) |
|
trace_graph_inputs = tuple(trace_graph_inputs) |
|
outputs = self._optimized(*trace_graph_inputs) |
|
lm_logits = outputs[0] |
|
past_key_values = outputs[1] |
|
fixed_output = CausalLMOutputWithPast( |
|
loss=None, |
|
logits=lm_logits, |
|
past_key_values=past_key_values, |
|
hidden_states=None, |
|
attentions=None, |
|
) |
|
return fixed_output |
|
|
|
def __getattr__(self, item): |
|
return getattr(self._default, item) |
|
|
|
def prepare_inputs_for_generation( |
|
self, input_ids, past_key_values=None, inputs_embeds=None, use_cache=None, **kwargs |
|
): |
|
return self._default.prepare_inputs_for_generation( |
|
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, **kwargs |
|
) |
|
|
|
def _reorder_cache( |
|
self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor |
|
) -> Tuple[Tuple[torch.Tensor]]: |
|
""" |
|
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
beam_idx at every generation step. |
|
""" |
|
return self._default._reorder_cache(past_key_values, beam_idx) |
|
|
|
|
|
def main(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument( |
|
"--model_type", |
|
default=None, |
|
type=str, |
|
required=True, |
|
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), |
|
) |
|
parser.add_argument( |
|
"--model_name_or_path", |
|
default=None, |
|
type=str, |
|
required=True, |
|
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()), |
|
) |
|
|
|
parser.add_argument("--prompt", type=str, default="") |
|
parser.add_argument("--length", type=int, default=20) |
|
parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped") |
|
|
|
parser.add_argument( |
|
"--temperature", |
|
type=float, |
|
default=1.0, |
|
help="temperature of 1.0 has no effect, lower tend toward greedy sampling", |
|
) |
|
parser.add_argument( |
|
"--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2" |
|
) |
|
parser.add_argument("--k", type=int, default=0) |
|
parser.add_argument("--p", type=float, default=0.9) |
|
|
|
parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.") |
|
parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.") |
|
parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.") |
|
|
|
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") |
|
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") |
|
parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.") |
|
parser.add_argument( |
|
"--fp16", |
|
action="store_true", |
|
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", |
|
) |
|
parser.add_argument( |
|
"--jit", type=bool, default=False, help="Whether or not to use jit trace to accelerate inference" |
|
) |
|
args = parser.parse_args() |
|
|
|
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") |
|
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count() |
|
|
|
logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bits training: {args.fp16}") |
|
|
|
set_seed(args) |
|
|
|
|
|
    # Initialize the tokenizer and model classes for the requested model type
    try:
        args.model_type = args.model_type.lower()
        model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    except KeyError:
        raise KeyError(
            f"The model type {args.model_type!r} is not supported. You are welcome to add it and open a PR :)"
        )
|
|
|
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) |
|
if tokenizer.pad_token is None: |
|
tokenizer.pad_token = tokenizer.eos_token |
|
model = model_class.from_pretrained(args.model_name_or_path) |
|
model.to(args.device) |
|
|
|
if args.fp16: |
|
model.half() |
|
|
|
args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings) |
|
logger.info(args) |
|
|
|
prompt_text = args.prompt if args.prompt else input("Model prompt >>> ") |
|
|
|
|
|
    # Some model types need their prompt pre-processed (control codes, language
    # ids, padding prefix) before encoding
    requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
|
if requires_preprocessing: |
|
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type) |
|
preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text) |
|
|
|
if model.__class__.__name__ in ["TransfoXLLMHeadModel"]: |
|
tokenizer_kwargs = {"add_space_before_punct_symbol": True} |
|
else: |
|
tokenizer_kwargs = {} |
|
|
|
encoded_prompt = tokenizer.encode( |
|
preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs |
|
) |
|
else: |
|
prefix = args.prefix if args.prefix else args.padding_text |
|
encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt") |
|
encoded_prompt = encoded_prompt.to(args.device) |
|
|
|
if encoded_prompt.size()[-1] == 0: |
|
input_ids = None |
|
else: |
|
input_ids = encoded_prompt |
|
|
|
    if args.jit:
        # Trace the model once on dummy inputs, freeze the traced graph, warm it
        # up, and then route generation through it with an eager-model fallback.
        jit_input_texts = ["jit"]
|
jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer) |
|
torch._C._jit_set_texpr_fuser_enabled(False) |
|
model.config.return_dict = False |
|
traced_model = torch.jit.trace(model, jit_inputs, strict=False) |
|
traced_model = torch.jit.freeze(traced_model.eval()) |
|
traced_model(*jit_inputs) |
|
traced_model(*jit_inputs) |
|
|
|
model = _ModelFallbackWrapper(traced_model, model) |
|
|
|
    # Sample with the requested decoding parameters; max_length includes the prompt tokens
    output_sequences = model.generate(
|
input_ids=input_ids, |
|
max_length=args.length + len(encoded_prompt[0]), |
|
temperature=args.temperature, |
|
top_k=args.k, |
|
top_p=args.p, |
|
repetition_penalty=args.repetition_penalty, |
|
do_sample=True, |
|
num_return_sequences=args.num_return_sequences, |
|
) |
|
|
|
|
|
    # Remove the batch dimension when returning multiple sequences
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()
|
|
|
generated_sequences = [] |
|
|
|
for generated_sequence_idx, generated_sequence in enumerate(output_sequences): |
|
print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===") |
|
        generated_sequence = generated_sequence.tolist()

        # Decode the generated token ids back into text
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
|
|
|
|
|
        # Truncate everything after the stop token, if one was given
        text = text[: text.find(args.stop_token) if args.stop_token else None]
|
|
|
|
|
        # Prepend the original prompt and drop the decoded prompt/prefix portion
        # (which may include the PREFIX padding added during pre-processing)
        total_sequence = (
            prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
        )
|
|
|
generated_sequences.append(total_sequence) |
|
print(total_sequence) |
|
|
|
return generated_sequences |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|