# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from typing import Optional

import fire

from llama import Llama


def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.2,
    top_p: float = 0.9,
    max_seq_len: int = 256,
    max_batch_size: int = 4,
    max_gen_len: Optional[int] = None,
):
    """Run text completion on a fixed set of code prompts and print the results.

    A generator is built with ``Llama.build`` and asked to complete two
    hard-coded prompts. Each prompt is then printed, followed by its
    generated continuation and a separator line.

    Args:
        ckpt_dir: Forwarded to ``Llama.build`` as ``ckpt_dir``.
        tokenizer_path: Forwarded to ``Llama.build`` as ``tokenizer_path``.
        temperature: Forwarded to ``generator.text_completion``.
        top_p: Forwarded to ``generator.text_completion``.
        max_seq_len: Forwarded to ``Llama.build`` as ``max_seq_len``.
        max_batch_size: Forwarded to ``Llama.build`` as ``max_batch_size``.
        max_gen_len: Forwarded to ``generator.text_completion``; ``None`` is
            passed through unchanged (presumably the generator then picks its
            own limit -- confirm against ``Llama.text_completion``).
    """
    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )

    # The backslash right after each opening triple quote suppresses the
    # leading newline, so every prompt starts directly with its first line
    # of code.
    prompts = [
        # For these prompts, the expected answer is the natural continuation of the prompt
        """\
def fizzbuzz(n: int):""",
        """\
import argparse

def main(string: str):
    print(string)
    print(string[::-1])

if __name__ == "__main__":""",
    ]

    results = generator.text_completion(
        prompts,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
    )

    # Results are paired with prompts by position.
    for prompt, result in zip(prompts, results):
        print(prompt)
        print(f"> {result['generation']}")
        print("\n==================================\n")


if __name__ == "__main__":
    # Fire turns main's parameters into command-line flags.
    fire.Fire(main)