rasbt committed on
Commit
3e575a3
·
verified ·
1 Parent(s): a01bfd8

Delete main.py

Browse files
Files changed (1) hide show
  1. main.py +0 -80
main.py DELETED
@@ -1,80 +0,0 @@
1
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/07_gpt_to_llama/standalone-llama32.ipynb

"""Run greedy text generation with a locally stored Llama 3.2 checkpoint.

Loads the model weights named by MODEL_FILE, tokenizes PROMPT, generates up
to MAX_NEW_TOKENS tokens, and prints the decoded output along with timing
(and peak CUDA memory, when a GPU is used).
"""

import time

import torch

from model import Llama3Model, generate, text_to_token_ids, token_ids_to_text
from tokenizer import Llama3Tokenizer, ChatFormat, clean_text

#######################################
# Model settings

MODEL_FILE = "llama3.2-1B-instruct.pth"
# MODEL_FILE = "llama3.2-1B-base.pth"
# MODEL_FILE = "llama3.2-3B-instruct.pth"
# MODEL_FILE = "llama3.2-3B-base.pth"

MODEL_CONTEXT_LENGTH = 8192  # Supports up to 131_072

# Text generation settings
# Instruct checkpoints expect a chat-style question; base checkpoints
# work better with a plain completion prefix.
if "instruct" in MODEL_FILE:
    PROMPT = "What do llamas eat?"
else:
    PROMPT = "Llamas eat"

MAX_NEW_TOKENS = 150
# TEMPERATURE = 0 with TOP_K = 1 yields deterministic (greedy) decoding.
TEMPERATURE = 0.
TOP_K = 1
#######################################

# Select the matching architecture config from the checkpoint filename.
if "1B" in MODEL_FILE:
    from model import LLAMA32_CONFIG_1B as LLAMA32_CONFIG
elif "3B" in MODEL_FILE:
    from model import LLAMA32_CONFIG_3B as LLAMA32_CONFIG
else:
    raise ValueError("Incorrect model file name")

model = Llama3Model(LLAMA32_CONFIG)

# Bug fix: the original code called undefined `Tokenizer(...)`; the class
# imported above is `Llama3Tokenizer`, which raised a NameError at startup.
tokenizer = Llama3Tokenizer("tokenizer.model")

# Instruct models need the Llama 3 chat template wrapped around the prompt.
if "instruct" in MODEL_FILE:
    tokenizer = ChatFormat(tokenizer)

# weights_only=True avoids arbitrary-code execution from untrusted pickles.
model.load_state_dict(torch.load(MODEL_FILE, weights_only=True))

# Prefer CUDA, then Apple Silicon (MPS), then CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else
    torch.device("mps") if torch.backends.mps.is_available() else
    torch.device("cpu")
)
model.to(device)

torch.manual_seed(123)

start = time.time()

token_ids = generate(
    model=model,
    idx=text_to_token_ids(PROMPT, tokenizer).to(device),
    max_new_tokens=MAX_NEW_TOKENS,
    context_size=LLAMA32_CONFIG["context_length"],
    top_k=TOP_K,
    temperature=TEMPERATURE
)

print(f"Time: {time.time() - start:.2f} sec")

if torch.cuda.is_available():
    max_mem_bytes = torch.cuda.max_memory_allocated()
    max_mem_gb = max_mem_bytes / (1024 ** 3)
    print(f"Max memory allocated: {max_mem_gb:.2f} GB")

output_text = token_ids_to_text(token_ids, tokenizer)

# Strip chat-template special tokens from instruct-model output.
if "instruct" in MODEL_FILE:
    output_text = clean_text(output_text)

print("\n\nOutput text:\n\n", output_text)