GeminiFan207 committed on
Commit eb813a3 · verified · 1 Parent(s): 6ad04d7

Rename model to generate.py

Files changed (2)
  1. generate.py +181 -0
  2. model +0 -0
generate.py ADDED
@@ -0,0 +1,181 @@
+ import argparse
+ import logging
+
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+ logger = logging.getLogger(__name__)
+
+ # Load model and tokenizer
+ def load_model_and_tokenizer(model_name: str) -> tuple:
+     """
+     Load the pre-trained model and tokenizer.
+
+     Args:
+         model_name (str): Name or path of the pre-trained model.
+
+     Returns:
+         tuple: (model, tokenizer)
+     """
+     logger.info(f"Loading model: {model_name}...")
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(model_name)
+         model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+         )
+         logger.info("Model and tokenizer loaded successfully.")
+         return model, tokenizer
+     except Exception as e:
+         logger.error(f"Error loading model: {e}")
+         raise
+
+ # Generate text
+ def generate_text(
+     model,
+     tokenizer,
+     prompt: str,
+     max_length: int = 100,
+     temperature: float = 1.0,
+     top_k: int = 50,
+     top_p: float = 0.95,
+ ) -> str:
+     """
+     Generate text based on the given prompt.
+
+     Args:
+         model: Pre-trained language model.
+         tokenizer: Tokenizer for the model.
+         prompt (str): Input prompt for text generation.
+         max_length (int): Maximum total length (prompt + generated tokens).
+         temperature (float): Sampling temperature (higher = more random).
+         top_k (int): Top-k sampling (0 disables top-k filtering).
+         top_p (float): Top-p (nucleus) sampling (1.0 disables nucleus filtering).
+
+     Returns:
+         str: Generated text.
+     """
+     try:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         model.to(device)
+         # Keep the BatchEncoding intact so input_ids and attention_mask both move to the device.
+         inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs,
+                 max_length=max_length,
+                 temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+                 do_sample=True,
+                 pad_token_id=tokenizer.eos_token_id,
+             )
+
+         generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         logger.info("Text generation completed successfully.")
+         return generated_text
+     except Exception as e:
+         logger.error(f"Error generating text: {e}")
+         raise
+
+ # Save generated text to a file
+ def save_to_file(text: str, filename: str) -> None:
+     """
+     Save the generated text to a file.
+
+     Args:
+         text (str): Generated text.
+         filename (str): Name of the output file.
+     """
+     try:
+         with open(filename, "w", encoding="utf-8") as file:
+             file.write(text)
+         logger.info(f"Generated text saved to {filename}.")
+     except Exception as e:
+         logger.error(f"Error saving to file: {e}")
+         raise
+
+ # Main function
+ def main():
+     # Parse command-line arguments
+     parser = argparse.ArgumentParser(
+         description="Generate text using a pre-trained language model.",
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+     )
+     parser.add_argument(
+         "--model",
+         type=str,
+         default="mistralai/Mixtral-8x7B-v0.1",
+         help="Name or path of the pre-trained model.",
+     )
+     parser.add_argument(
+         "--prompt",
+         type=str,
+         required=True,
+         help="Input prompt for text generation.",
+     )
+     parser.add_argument(
+         "--max_length",
+         type=int,
+         default=100,
+         help="Maximum total length (prompt + generated tokens).",
+     )
+     parser.add_argument(
+         "--temperature",
+         type=float,
+         default=1.0,
+         help="Sampling temperature (higher = more random).",
+     )
+     parser.add_argument(
+         "--top_k",
+         type=int,
+         default=50,
+         help="Top-k sampling (0 disables top-k filtering).",
+     )
+     parser.add_argument(
+         "--top_p",
+         type=float,
+         default=0.95,
+         help="Top-p (nucleus) sampling (1.0 disables nucleus filtering).",
+     )
+     parser.add_argument(
+         "--output_file",
+         type=str,
+         help="File to save the generated text.",
+     )
+     args = parser.parse_args()
+
+     # Load model and tokenizer
+     try:
+         model, tokenizer = load_model_and_tokenizer(args.model)
+     except Exception as e:
+         logger.error(f"Failed to load model: {e}")
+         return
+
+     # Generate text
+     try:
+         logger.info("Generating text...")
+         generated_text = generate_text(
+             model,
+             tokenizer,
+             args.prompt,
+             max_length=args.max_length,
+             temperature=args.temperature,
+             top_k=args.top_k,
+             top_p=args.top_p,
+         )
+
+         # Print the generated text
+         print("\nGenerated Text:")
+         print(generated_text)
+
+         # Save to file if specified
+         if args.output_file:
+             save_to_file(generated_text, args.output_file)
+     except Exception as e:
+         logger.error(f"Failed to generate text: {e}")
+
+ if __name__ == "__main__":
+     main()
model DELETED
File without changes
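
For a quick smoke test of the new script, the helpers can also be imported and called directly. A minimal sketch, assuming generate.py is on the Python path and substituting a small model such as gpt2 (an assumption for CPU-friendly testing, not the script's default):

# Usage sketch: gpt2 stands in for the large default model so this runs without a GPU.
from generate import load_model_and_tokenizer, generate_text

model, tokenizer = load_model_and_tokenizer("gpt2")
print(generate_text(model, tokenizer, "Once upon a time", max_length=50))

# Equivalent CLI invocation:
#   python generate.py --model gpt2 --prompt "Once upon a time" --max_length 50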