# Charm_15 / generate.py
# GeminiFan207's picture
# Rename model to generate.py
# eb813a3 verified
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import argparse
import logging
from typing import List, Optional
# Configure logging: records at INFO and above, each prefixed with a
# timestamp and the level name. Note that logging.basicConfig does nothing
# if the root logger already has handlers attached.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Load model and tokenizer
def load_model_and_tokenizer(model_name: str) -> tuple:
    """
    Fetch a pre-trained causal language model together with its tokenizer.

    Weights are requested in float16 when ``torch.cuda.is_available()`` is
    true and in float32 otherwise. The model is not moved to any device here.

    Args:
        model_name (str): Model identifier or local path understood by
            ``from_pretrained``.

    Returns:
        tuple: ``(model, tokenizer)``.

    Raises:
        Exception: Anything raised while loading is logged, then re-raised
            unchanged.
    """
    logger.info(f"Loading model: {model_name}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Half precision only pays off (and is only well supported) on GPU.
        weight_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=weight_dtype,
        )
        logger.info("Model and tokenizer loaded successfully.")
    except Exception as exc:
        logger.error(f"Error loading model: {exc}")
        raise
    return model, tokenizer
# Generate text
def generate_text(
    model,
    tokenizer,
    prompt: str,
    max_length: int = 100,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
) -> str:
    """
    Generate text for the given prompt using sampling (``do_sample=True``).

    When CUDA is available, both the model and the encoded prompt are moved
    to the GPU before generating.

    Args:
        model: Pre-trained language model exposing ``generate``.
        tokenizer: Tokenizer for the model.
        prompt (str): Input prompt for text generation.
        max_length (int): Maximum total length, in tokens, of the generated
            sequence *including* the prompt tokens (forwarded unchanged to
            ``generate(max_length=...)``).
        temperature (float): Sampling temperature (higher = more random).
        top_k (int): Keep only the ``top_k`` most likely tokens at each step
            (0 disables top-k filtering; sampling is still used).
        top_p (float): Nucleus sampling threshold (1.0 disables nucleus
            filtering; sampling is still used).

    Returns:
        str: ``outputs[0]`` decoded with special tokens removed. For causal
        LMs this is typically the prompt followed by the continuation.

    Raises:
        Exception: Any error from tokenization or generation is logged and
            re-raised.
    """
    try:
        encoded = tokenizer(prompt, return_tensors="pt")
        if torch.cuda.is_available():
            model.to("cuda")
            # The comprehension produces a plain dict, so the fields below are
            # read by key: attribute access (``encoded.input_ids``) only works
            # on the tokenizer's own BatchEncoding and crashed on GPU.
            encoded = {key: value.to("cuda") for key, value in encoded.items()}

        input_ids = encoded["input_ids"]
        # Forward the tokenizer's mask so generate() does not have to infer it.
        attention_mask = encoded.get("attention_mask")
        # Many causal LMs define no pad token; fall back to EOS so generate()
        # has an explicit pad id instead of warning and picking one itself.
        pad_token_id = getattr(tokenizer, "pad_token_id", None)
        if pad_token_id is None:
            pad_token_id = getattr(tokenizer, "eos_token_id", None)

        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                do_sample=True,
                pad_token_id=pad_token_id,
            )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info("Text generation completed successfully.")
        return generated_text
    except Exception as e:
        logger.error(f"Error generating text: {e}")
        raise
# Save generated text to a file
def save_to_file(text: str, filename: str) -> None:
    """
    Save the generated text to a file, overwriting any existing content.

    The file is written as UTF-8. Without an explicit encoding, ``open()``
    uses a locale-dependent default (e.g. cp1252 on many Windows setups),
    which cannot encode most non-ASCII model output.

    Args:
        text (str): Generated text.
        filename (str): Path of the output file.

    Raises:
        Exception: Any error while opening or writing the file is logged and
            re-raised.
    """
    try:
        with open(filename, "w", encoding="utf-8") as file:
            file.write(text)
        logger.info(f"Generated text saved to {filename}.")
    except Exception as e:
        logger.error(f"Error saving to file: {e}")
        raise
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Generate text using a pre-trained language model.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--model",
        type=str,
        # The previous default "mistralai/Mistral-8x7B" is not a Hub repo id,
        # so running without --model always failed to load.
        default="mistralai/Mixtral-8x7B-v0.1",
        help="Name or path of the pre-trained model.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Input prompt for text generation.",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=100,
        help="Maximum length in tokens of the generated sequence, prompt included.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Sampling temperature (higher = more random).",
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=50,
        help="Top-k filtering (0 = disabled).",
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.95,
        help="Top-p (nucleus) filtering (1.0 = disabled).",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        help="File to save the generated text.",
    )
    return parser


def _validate_args(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
    """Reject out-of-range sampling settings before the (slow) model load.

    Calls ``parser.error``, which exits with status 2, on the first problem.
    """
    if args.max_length <= 0:
        parser.error("--max_length must be a positive integer.")
    # Written as ``not (x > 0)`` so NaN is rejected too.
    if not args.temperature > 0:
        parser.error("--temperature must be greater than 0 (sampling is always on).")
    if args.top_k < 0:
        parser.error("--top_k must be >= 0.")
    if not 0.0 <= args.top_p <= 1.0:
        parser.error("--top_p must be between 0 and 1.")


# Main function
def main(argv: Optional[List[str]] = None) -> int:
    """
    Command-line entry point.

    Parses arguments, loads the model, generates text, prints it and
    optionally saves it to ``--output_file``.

    Args:
        argv (Optional[List[str]]): Arguments to parse instead of
            ``sys.argv[1:]``. Useful for tests and programmatic use.

    Returns:
        int: Exit status. 0 on success, 1 if loading, generating or saving
        failed. Invalid arguments make argparse exit with status 2.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    _validate_args(parser, args)

    # Load model and tokenizer
    try:
        model, tokenizer = load_model_and_tokenizer(args.model)
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        return 1

    # Generate text
    try:
        logger.info("Generating text...")
        generated_text = generate_text(
            model,
            tokenizer,
            args.prompt,
            max_length=args.max_length,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
    except Exception as e:
        logger.error(f"Failed to generate text: {e}")
        return 1

    # Print the generated text
    print("\nGenerated Text:")
    print(generated_text)

    # Save to file if specified; reported separately so a write error is not
    # mislabelled as a generation failure.
    if args.output_file:
        try:
            save_to_file(generated_text, args.output_file)
        except Exception as e:
            logger.error(f"Failed to save generated text: {e}")
            return 1
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    # SystemExit(None) exits with status 0.
    raise SystemExit(main())