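"""Translate a file (or a directory of files) line by line with a 🤗 Transformers model.

Reads sentences, one per line, translates them with a seq2seq translation model
(e.g. M2M100, NLLB-200, SeamlessM4T) or a prompted language model, and writes one
translation per line to the output path. Runs on CPU, single/multi GPU or TPU via
🤗 Accelerate.

Example invocation (a sketch; assumes this file is saved as translate.py, and the
paths and language codes are placeholders):

    accelerate launch translate.py \
        --sentences_path input.en.txt \
        --output_path output.es.txt \
        --source_lang en \
        --target_lang es \
        --model_name facebook/m2m100_1.2B
"""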
import argparse
import glob
import math
import os
from typing import Optional

import gradio
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    PreTrainedTokenizerBase,
    DataCollatorForSeq2Seq,
)
from accelerate import Accelerator, DistributedType, find_executable_batch_size

from model import load_model_for_inference
from dataset import DatasetReader, count_lines
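

# Escape carriage returns, newlines and tabs so that each translation occupies
# exactly one line in the output file (mirroring the one-sentence-per-line input).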
def encode_string(text):
    return text.replace("\r", r"\r").replace("\n", r"\n").replace("\t", r"\t")


def get_dataloader(
    accelerator: Accelerator,
    filename: str,
    tokenizer: PreTrainedTokenizerBase,
    batch_size: int,
    max_length: int,
    prompt: str,
) -> DataLoader:
    dataset = DatasetReader(
        filename=filename,
        tokenizer=tokenizer,
        max_length=max_length,
        prompt=prompt,
    )
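
    # TPUs recompile the computation graph for every new tensor shape, so pad every
    # batch to the same max_length to keep shapes static. On GPU/CPU, pad dynamically
    # to the longest sequence in the batch (rounded to a multiple of 8, which plays
    # well with tensor cores).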
    if accelerator.distributed_type == DistributedType.TPU:
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            padding="max_length",
            max_length=max_length,
            label_pad_token_id=tokenizer.pad_token_id,
            return_tensors="pt",
        )
    else:
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            padding=True,
            label_pad_token_id=tokenizer.pad_token_id,
            # max_length is not needed here, we already truncate in the preprocess function
            pad_to_multiple_of=8,
            return_tensors="pt",
        )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=data_collator,
        num_workers=0,  # Disable multiprocessing
    )


def main(
    sentences_path: Optional[str],
    sentences_dir: Optional[str],
    files_extension: str,
    output_path: str,
    source_lang: Optional[str],
    target_lang: Optional[str],
    starting_batch_size: int,
    model_name: str = "facebook/m2m100_1.2B",
    lora_weights_name_or_path: Optional[str] = None,
    force_auto_device_map: bool = False,
    precision: Optional[str] = None,
    max_length: int = 256,
    num_beams: int = 4,
    num_return_sequences: int = 1,
    do_sample: bool = False,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 1.0,
    keep_special_tokens: bool = False,
    keep_tokenization_spaces: bool = False,
    repetition_penalty: Optional[float] = None,
    prompt: Optional[str] = None,
    trust_remote_code: bool = False,
):
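    """Load the model and translate every input file, writing one translation per line.

    Exactly one of sentences_path (a single file) or sentences_dir (a directory of
    files with the given extension) must be provided.
    """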
    accelerator = Accelerator()

    if force_auto_device_map and starting_batch_size >= 64:
        print(
            f"WARNING: You are using a very large batch size ({starting_batch_size}) together with the "
            f"auto_device_map flag. auto_device_map offloads model parameters to the CPU when they do not fit "
            f"in GPU VRAM, so a very large batch size will offload many parameters and slow down inference. "
            f"Consider using a smaller batch size, e.g. '--starting_batch_size 8'."
        )

    if sentences_path is None and sentences_dir is None:
        raise ValueError(
            "You must specify either --sentences_path or --sentences_dir. Use --help for more details."
        )

    if sentences_path is not None and sentences_dir is not None:
        raise ValueError(
            "You must specify either --sentences_path or --sentences_dir, not both. Use --help for more details."
        )
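
    # Map the --precision flag to either a bitsandbytes quantization level (4/8 bit)
    # or a torch dtype; None keeps the dtype of the checkpoint weights.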
    if precision is None:
        quantization = None
        dtype = None
    elif precision in ("8", "4"):
        quantization = int(precision)
        dtype = None
    elif precision == "fp16":
        quantization = None
        dtype = "float16"
    elif precision == "bf16":
        quantization = None
        dtype = "bfloat16"
    elif precision == "32":
        quantization = None
        dtype = "float32"
    else:
        raise ValueError(
            f"Precision {precision} not supported. Please choose between 8, 4, fp16, bf16, 32 or None."
        )
    model, tokenizer = load_model_for_inference(
        weights_path=model_name,
        quantization=quantization,
        lora_weights_name_or_path=lora_weights_name_or_path,
        torch_dtype=dtype,
        force_auto_device_map=force_auto_device_map,
        trust_remote_code=trust_remote_code,
    )
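
    # Multilingual MT tokenizers such as M2M100 and NLLB-200 expose lang_code_to_id;
    # its presence is used as a heuristic for "this model needs language codes".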
    is_translation_model = hasattr(tokenizer, "lang_code_to_id")
    lang_code_to_idx = None

    if (
        is_translation_model
        and (source_lang is None or target_lang is None)
        and "small100" not in model_name
    ):
        raise ValueError(
            f"The model you are using requires a source and target language. "
            f"Please specify them with --source_lang and --target_lang. "
            f"The supported languages are: {tokenizer.lang_code_to_id.keys()}"
        )
    if not is_translation_model and (
        source_lang is not None or target_lang is not None
    ):
        if prompt is None:
            print(
                "WARNING: You are using a model that does not support source and target language parameters, "
                "but you specified them. You probably want to use m2m100/nllb200 for translation, or "
                "set --prompt to define the task for your model."
            )
        else:
            print(
                "WARNING: You are using a model that does not support source and target language parameters, "
                "but you specified them."
            )

    if prompt is not None and "%%SENTENCE%%" not in prompt:
        raise ValueError(
            f"The prompt must contain the %%SENTENCE%% token to indicate where the sentence should be inserted. "
            f"Your prompt: {prompt}"
        )

    if is_translation_model:
        try:
            _ = tokenizer.lang_code_to_id[source_lang]
        except KeyError:
            raise KeyError(
                f"Language {source_lang} not found in tokenizer. Available languages: {tokenizer.lang_code_to_id.keys()}"
            )
        tokenizer.src_lang = source_lang

        try:
            lang_code_to_idx = tokenizer.lang_code_to_id[target_lang]
        except KeyError:
            raise KeyError(
                f"Language {target_lang} not found in tokenizer. Available languages: {tokenizer.lang_code_to_id.keys()}"
            )

        if "small100" in model_name:
            tokenizer.tgt_lang = target_lang
            # We don't need to force the BOS token, so we set is_translation_model to False
            is_translation_model = False

    if model.config.model_type == "seamless_m4t":
        # Loading a seamless_m4t model; we need to set a few things to ensure compatibility
        supported_langs = tokenizer.additional_special_tokens
        supported_langs = [lang.replace("__", "") for lang in supported_langs]

        if source_lang is None or target_lang is None:
            raise ValueError(
                f"The model you are using requires a source and target language. "
                f"Please specify them with --source_lang and --target_lang. "
                f"The supported languages are: {supported_langs}"
            )
        if source_lang not in supported_langs:
            raise ValueError(
                f"Language {source_lang} not found in tokenizer. Available languages: {supported_langs}"
            )
        if target_lang not in supported_langs:
            raise ValueError(
                f"Language {target_lang} not found in tokenizer. Available languages: {supported_langs}"
            )
        tokenizer.src_lang = source_lang
    gen_kwargs = {
        "max_new_tokens": max_length,
        "num_beams": num_beams,
        "num_return_sequences": num_return_sequences,
        "do_sample": do_sample,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
    }

    if repetition_penalty is not None:
        gen_kwargs["repetition_penalty"] = repetition_penalty
    if is_translation_model:
        gen_kwargs["forced_bos_token_id"] = lang_code_to_idx

    if model.config.model_type == "seamless_m4t":
        gen_kwargs["tgt_lang"] = target_lang

    if accelerator.is_main_process:
        print(
            f"** Translation **\n"
            f"Input file: {sentences_path}\n"
            f"Sentences dir: {sentences_dir}\n"
            f"Output file: {output_path}\n"
            f"Source language: {source_lang}\n"
            f"Target language: {target_lang}\n"
            f"Force target lang as BOS token: {is_translation_model}\n"
            f"Prompt: {prompt}\n"
            f"Starting batch size: {starting_batch_size}\n"
            f"Device: {str(accelerator.device).split(':')[0]}\n"
            f"Num. Devices: {accelerator.num_processes}\n"
            f"Distributed type: {accelerator.distributed_type}\n"
            f"Max length: {max_length}\n"
            f"Quantization: {quantization}\n"
            f"Precision: {dtype}\n"
            f"Model: {model_name}\n"
            f"LoRA weights: {lora_weights_name_or_path}\n"
            f"Force auto device map: {force_auto_device_map}\n"
            f"Keep special tokens: {keep_special_tokens}\n"
            f"Keep tokenization spaces: {keep_tokenization_spaces}\n"
        )

        print("** Generation parameters **")
        print("\n".join(f"{k}: {v}" for k, v in gen_kwargs.items()))
        print("\n")
    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def inference(batch_size, sentences_path, output_path):
        nonlocal model, tokenizer, max_length, gen_kwargs, precision, prompt, is_translation_model
| print(f"Translating {sentences_path} with batch size {batch_size}") | |
| total_lines: int = count_lines(sentences_path) | |
| data_loader = get_dataloader( | |
| accelerator=accelerator, | |
| filename=sentences_path, | |
| tokenizer=tokenizer, | |
| batch_size=batch_size, | |
| max_length=max_length, | |
| prompt=prompt, | |
| ) | |
| model, data_loader = accelerator.prepare(model, data_loader) | |
| samples_seen: int = 0 | |
| with tqdm( | |
| total=total_lines, | |
| desc="Dataset translation", | |
| leave=True, | |
| ascii=True, | |
| disable=(not accelerator.is_main_process), | |
| ) as pbar, open(output_path, "w", encoding="utf-8") as output_file: | |
| with torch.no_grad(): | |
| for step, batch in enumerate(data_loader): | |
| batch["input_ids"] = batch["input_ids"] | |
| batch["attention_mask"] = batch["attention_mask"] | |
                    generated_tokens = accelerator.unwrap_model(model).generate(
                        **batch,
                        **gen_kwargs,
                    )

                    generated_tokens = accelerator.pad_across_processes(
                        generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
                    )
                    generated_tokens = (
                        accelerator.gather(generated_tokens).cpu().numpy()
                    )

                    tgt_text = tokenizer.batch_decode(
                        generated_tokens,
                        skip_special_tokens=not keep_special_tokens,
                        clean_up_tokenization_spaces=not keep_tokenization_spaces,
                    )

                    if accelerator.is_main_process:
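                        # The prepared DataLoader pads the last batch with duplicated
                        # samples so every process receives a full batch; on the final
                        # step, slice the gathered text down to
                        # total_lines * num_return_sequences to drop those duplicates.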
                        if (
                            step
                            == math.ceil(
                                math.ceil(total_lines / batch_size)
                                / accelerator.num_processes
                            )
                            - 1
                        ):
                            tgt_text = tgt_text[
                                : (total_lines * num_return_sequences) - samples_seen
                            ]
                        else:
                            samples_seen += len(tgt_text)

                        print(
                            "\n".join(
                                [encode_string(sentence) for sentence in tgt_text]
                            ),
                            file=output_file,
                        )
                        pbar.update(len(tgt_text) // gen_kwargs["num_return_sequences"])

        print(f"Translation done. Output written to {output_path}\n")

    if sentences_path is not None:
        os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
        inference(sentences_path=sentences_path, output_path=output_path)

    if sentences_dir is not None:
        print(
            f"Translating all files in {sentences_dir} with extension {files_extension}"
        )
        os.makedirs(os.path.abspath(output_path), exist_ok=True)
        for filename in glob.glob(
            os.path.join(
                sentences_dir, f"*.{files_extension}" if files_extension else "*"
            )
        ):
            output_filename = os.path.join(output_path, os.path.basename(filename))
            inference(sentences_path=filename, output_path=output_filename)

    print("Translation done.\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the translation experiments")
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument(
        "--sentences_path",
        default=None,
        type=str,
        help="Path to a txt file containing the sentences to translate, one sentence per line.",
    )
    input_group.add_argument(
        "--sentences_dir",
        type=str,
        default=None,
        help="Path to a directory containing the sentences to translate. "
        "Sentences must be in .txt files containing one sentence per line.",
    )
    parser.add_argument(
        "--files_extension",
        type=str,
        default="txt",
        help="If sentences_dir is specified, extension of the files to translate. Defaults to txt. "
        "If set to an empty string, we will translate all files in the directory.",
    )

    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path to a txt file where the translated sentences will be written. If the input is a directory, "
        "the output will be a directory with the same structure.",
    )
    parser.add_argument(
        "--source_lang",
        type=str,
        default=None,
        required=False,
        help="Source language id. See: supported_languages.md. Required for m2m100 and nllb200.",
    )

    parser.add_argument(
        "--target_lang",
        type=str,
        default=None,
        required=False,
        help="Target language id. See: supported_languages.md. Required for m2m100 and nllb200.",
    )
    parser.add_argument(
        "--starting_batch_size",
        type=int,
        default=128,
        help="Starting batch size; we will automatically reduce it if we encounter an OOM error. "
        "If you use multiple devices, we will divide this number by the number of devices.",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="facebook/m2m100_1.2B",
        help="Path to the model to use. See: https://huggingface.co/models",
    )

    parser.add_argument(
        "--lora_weights_name_or_path",
        type=str,
        default=None,
        help="If the model uses LoRA weights, path to those weights. See: https://github.com/huggingface/peft",
    )

    parser.add_argument(
        "--force_auto_device_map",
        action="store_true",
        help="Whether to force the use of the auto device map. If set, "
        "the model will be split across GPUs and CPU to fit in memory. "
        "Otherwise, a full copy of the model will be loaded onto each GPU. Defaults to False.",
    )

    parser.add_argument(
        "--max_length",
        type=int,
        default=256,
        help="Maximum number of tokens in the source sentence and the generated sentence. "
        "Increase this value to translate longer sentences, at the cost of increased memory usage.",
    )
    parser.add_argument(
        "--num_beams",
        type=int,
        default=5,
        help="Number of beams for beam search. The m2m100 authors recommend 5, but it might use too much memory.",
    )

    parser.add_argument(
        "--num_return_sequences",
        type=int,
        default=1,
        help="Number of candidate translations to return for each sentence (num_return_sequences <= num_beams).",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default=None,
        choices=["bf16", "fp16", "32", "4", "8"],
        help="Precision of the model: bf16, fp16, 32, 8 or 4 "
        "(4-bit/8-bit quantization requires the bitsandbytes library: https://github.com/TimDettmers/bitsandbytes). "
        "If None, we will use the torch.dtype of the model weights.",
    )
    parser.add_argument(
        "--do_sample",
        action="store_true",
        help="Use sampling instead of beam search.",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="Temperature for sampling, value used only if do_sample is True.",
    )

    parser.add_argument(
        "--top_k",
        type=int,
        default=100,
        help="If do_sample is True, will sample from the top k most likely tokens.",
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.75,
        help="If do_sample is True, will sample from the smallest set of tokens whose "
        "cumulative probability exceeds top_p (nucleus sampling).",
    )
    parser.add_argument(
        "--keep_special_tokens",
        action="store_true",
        help="Keep special tokens in the decoded text.",
    )

    parser.add_argument(
        "--keep_tokenization_spaces",
        action="store_true",
        help="Do not clean up tokenization spaces in the decoded text.",
    )

    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=None,
        help="Repetition penalty.",
    )

    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Prompt to use for generation. "
        "It must include the special token %%SENTENCE%% which will be replaced by the sentence to translate.",
    )

    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="If set, trust remote code in HuggingFace models. This is required for some models.",
    )
    args = parser.parse_args()

    main(
        sentences_path=args.sentences_path,
        sentences_dir=args.sentences_dir,
        files_extension=args.files_extension,
        output_path=args.output_path,
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        starting_batch_size=args.starting_batch_size,
        model_name=args.model_name,
        lora_weights_name_or_path=args.lora_weights_name_or_path,
        force_auto_device_map=args.force_auto_device_map,
        max_length=args.max_length,
        num_beams=args.num_beams,
        num_return_sequences=args.num_return_sequences,
        precision=args.precision,
        do_sample=args.do_sample,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        keep_special_tokens=args.keep_special_tokens,
        keep_tokenization_spaces=args.keep_tokenization_spaces,
        repetition_penalty=args.repetition_penalty,
        prompt=args.prompt,
        trust_remote_code=args.trust_remote_code,
    )
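
    # NOTE: gradio.Interface invokes fn with the textbox string, but main() expects
    # file paths plus several required arguments, so it cannot be wired to a textbox
    # directly. The wrapper below is a minimal sketch (an assumption about the
    # intended demo: translate one sentence at a time, reusing the CLI arguments
    # parsed above). It writes the input to a temporary file, runs the full pipeline
    # on it, and returns the translated line. It reloads the model on every call, so
    # it is slow; a production demo would load the model once.
    def gradio_translate(text: str) -> str:
        import tempfile

        with tempfile.TemporaryDirectory() as tmp_dir:
            input_path = os.path.join(tmp_dir, "input.txt")
            output_file = os.path.join(tmp_dir, "output.txt")
            with open(input_path, "w", encoding="utf-8") as f:
                f.write(text + "\n")
            main(
                sentences_path=input_path,
                sentences_dir=None,
                files_extension="txt",
                output_path=output_file,
                source_lang=args.source_lang,
                target_lang=args.target_lang,
                starting_batch_size=1,
                model_name=args.model_name,
                precision=args.precision,
                prompt=args.prompt,
                trust_remote_code=args.trust_remote_code,
            )
            with open(output_file, "r", encoding="utf-8") as f:
                return f.read().strip()
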
    demo = gradio.Interface(fn=gradio_translate, inputs="textbox", outputs="textbox")
    demo.launch(share=True)