pszemraj committed
Commit 8621629
1 Parent(s): 23625d3

Update README.md

Files changed (1)
  1. README.md +130 -0
README.md CHANGED
@@ -134,3 +134,133 @@ Below is a quick script that can be used as a reference/starting point for writi
<details>
<summary>🔥 Unleash the Power of Code Generation! Click to Reveal the Magic! 🔮</summary>

Are you ready to witness the incredible possibilities of code generation? 🚀 The script below shows how the model generates and completes code.

Use it as a reference/starting point for running code-completion inference with this model.

```python
"""
Simple script for testing model(s) designed to generate/complete code.

See details/args with:
    python textgen_inference_code.py --help
"""
import logging
import random
import time

import fire
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(format="%(levelname)s - %(message)s", level=logging.INFO)


class Timer:
    """
    Basic timer utility.
    """

    def __enter__(self):
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.perf_counter()
        self.elapsed_time = self.end_time - self.start_time
        logging.info(f"Elapsed time: {self.elapsed_time:.4f} seconds")


def load_model(model_name, use_fast=False):
    """Load the model and tokenizer."""
    logging.info(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=use_fast)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype="auto", device_map="auto"
    )
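    # note: torch.compile incurs one-time compilation overhead on the first
    # generate() call, but can speed up repeated generation afterwards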
    model = torch.compile(model)
    return tokenizer, model


def run_inference(prompt, model, tokenizer, max_new_tokens: int = 256):
    """
    Generate a completion for the given prompt.

    Args:
        prompt (str): the (code) prompt to complete
        model: the loaded causal language model
        tokenizer: the tokenizer matching the model
        max_new_tokens (int, optional): max number of new tokens to generate

    Returns:
        str: the decoded output text
    """
    logging.info(f"Running inference with max_new_tokens={max_new_tokens} ...")
    with Timer() as timer:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
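        # beam search decoding with light anti-repetition constraints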
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            min_new_tokens=8,
            renormalize_logits=True,
            no_repeat_ngram_size=8,
            repetition_penalty=1.04,
            num_beams=4,
            early_stopping=True,
        )
    text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    logging.info(f"Output text:\n\n{text}")
    return text


def main(
    model_name="BEE-spoke-data/smol_llama-101M-GQA-python",
    prompt: str = None,
    use_fast=False,
    n_tokens: int = 256,
):
    """
    Load the model and generate a completion for the prompt.

    Args:
        model_name (str, optional): model to load, as a hub ID or local path
        prompt (str, optional): specify the prompt directly (default: random choice from list)
        use_fast (bool, optional): whether to use the fast tokenizer
        n_tokens (int, optional): max new tokens to generate
    """
    logging.info(f"Inference with:\t{model_name}, max_new_tokens:{n_tokens}")

    if prompt is None:
        prompt_list = [
            '''
def print_primes(n: int):
    """
    Print all primes between 1 and n
    """''',
            "def quantum_analysis(",
            "def sanitize_filenames(target_dir:str, recursive:False, extension",
        ]
        prompt = random.SystemRandom().choice(prompt_list)

    logging.info(f"Using prompt:\t{prompt}")

    tokenizer, model = load_model(model_name, use_fast=use_fast)

    run_inference(prompt, model, tokenizer, n_tokens)


if __name__ == "__main__":
    fire.Fire(main)
```
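
Since the script hands `main` to `fire.Fire`, its keyword arguments double as command-line flags (e.g. `--model_name`, `--prompt`, `--n_tokens`). You can also import the helpers directly; a minimal sketch, assuming the script is saved as `textgen_inference_code.py` (the filename used in its docstring) and with a made-up prompt:

```python
# assumes the script above is saved as textgen_inference_code.py
from textgen_inference_code import load_model, run_inference

tokenizer, model = load_model("BEE-spoke-data/smol_llama-101M-GQA-python")
# "def fibonacci(n: int):" is just an illustrative prompt
text = run_inference("def fibonacci(n: int):", model, tokenizer, max_new_tokens=128)
```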

Wowoweewa!! It can create some file-cleaning utilities.

</details>

---