Upload 2 files
- app.py +11 -11
- generate.py +28 -35
app.py
CHANGED
@@ -91,17 +91,17 @@ def instruct_generate(
     encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
     # prompt_length = encoded.size(0)

-
-
-
-
-
-
-
-
-
-
-    y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
+    y = generate(
+        model,
+        idx=encoded,
+        max_seq_length=max_new_tokens,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        top_k=top_k,
+        eos_id=tokenizer.eos_id
+    )
+
+    # y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)

     output = tokenizer.decode(y)
     output = output.split("### Response:")[1].strip()
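Note that the restored generate() in generate.py (see below) crops its context to max_seq_length, so passing max_seq_length=max_new_tokens here caps the context the model sees at max_new_tokens tokens. A minimal sketch of that cropping with made-up values, not part of the diff:

```python
import torch

idx = torch.arange(10)   # pretend: 10 token ids generated so far
max_seq_length = 4       # what app.py passes: max_seq_length=max_new_tokens
t = 10                   # current position in generate()'s loop

idx_cond = idx[:t]
# t > max_seq_length, so only the last max_seq_length tokens reach the model
idx_cond = idx_cond if t <= max_seq_length else idx_cond[-max_seq_length:]
print(idx_cond)          # tensor([6, 7, 8, 9])
```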
generate.py
CHANGED
@@ -12,16 +12,15 @@ wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))

 from lit_llama import LLaMA, Tokenizer
-from lit_llama.utils import lazy_load, llama_model_lookup
+from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup


 @torch.no_grad()
 def generate(
-    model: LLaMA,
+    model: torch.nn.Module,
     idx: torch.Tensor,
     max_new_tokens: int,
-    *,
-    max_seq_length: Optional[int] = None,
+    max_seq_length: int,
     temperature: float = 1.0,
     top_k: Optional[int] = None,
     eos_id: Optional[int] = None,
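With the keyword-only marker and the None default gone, max_seq_length is now a required parameter that can be passed positionally (as main() does) or by keyword (as app.py does). A toy call for illustration; it assumes generate.py and lit_llama are importable, and the stand-in model below is hypothetical:

```python
import torch
from generate import generate  # the function patched in this commit


class TinyModel(torch.nn.Module):
    vocab_size = 16

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # return random logits of shape (batch, seq_len, vocab)
        return torch.randn(idx.size(0), idx.size(1), self.vocab_size)


model = TinyModel()
prompt = torch.tensor([1, 2, 3])

y = generate(
    model,
    idx=prompt,
    max_new_tokens=8,
    max_seq_length=16,   # required now: no keyword-only marker, no None default
    temperature=1.0,
    top_k=4,
    eos_id=0,
)
print(y)  # prompt ids followed by up to 8 sampled ids
```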
@@ -42,49 +41,35 @@ def generate(
     # create an empty tensor of the expected final shape and fill in the current tokens
     T = idx.size(0)
     T_new = T + max_new_tokens
-    if max_seq_length is None:
-        max_seq_length = min(T_new, model.config.block_size)
-
-    device, dtype = idx.device, idx.dtype
-    # create an empty tensor of the expected final shape and fill in the current tokens
-    empty = torch.empty(T_new, dtype=dtype, device=device)
+    empty = torch.empty(T_new, dtype=idx.dtype, device=idx.device)
     empty[:T] = idx
     idx = empty
-    input_pos = torch.arange(0, T, device=device)
-
-    if idx.device.type == "xla":
-        import torch_xla.core.xla_model as xm
-
-        xm.mark_step()

     # generate max_new_tokens tokens
-    for _ in range(max_new_tokens):
-        x = idx.index_select(0, input_pos).view(1, -1)
+    for t in range(T, T_new):
+        # ignore the not-filled-yet tokens
+        idx_cond = idx[:t]
+        # if the sequence context is growing too long we must crop it at max_seq_length
+        idx_cond = idx_cond if t <= max_seq_length else idx_cond[-max_seq_length:]

         # forward
-        logits = model(x, max_seq_length, input_pos)
+        logits = model(idx_cond.view(1, -1))
         logits = logits[0, -1] / temperature

         # optionally crop the logits to only the top k options
         if top_k is not None:
             v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-            logits = torch.where(logits < v[[-1]], -float("Inf"), logits)
+            logits[logits < v[[-1]]] = -float("Inf")

         probs = torch.nn.functional.softmax(logits, dim=-1)
-        idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)
-
-        # advance
-        input_pos = input_pos[-1:] + 1
-
-        if idx.device.type == "xla":
-            xm.mark_step()
+        idx_next = torch.multinomial(probs, num_samples=1)

         # concatenate the new generation
-        idx = idx.index_copy(0, input_pos, idx_next)
+        idx[t] = idx_next

         # if <eos> token is triggered, return the output (stop generation)
         if idx_next == eos_id:
-            return idx[:input_pos]  # include the EOS token
+            return idx[:t + 1]  # include the EOS token

     return idx
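For reference, the top-k masking and sampling step inside the restored loop can be exercised on its own; the logits below are made up:

```python
import torch

logits = torch.tensor([2.0, 0.5, 1.0, 3.0, -1.0])
temperature, top_k = 0.8, 2

logits = logits / temperature
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[[-1]]] = -float("Inf")   # mask everything below the k-th largest logit

probs = torch.nn.functional.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
print(probs)     # only two non-zero probabilities (indices 0 and 3)
print(idx_next)  # a single sampled token id, 0 or 3
```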
@@ -118,22 +103,24 @@ def main(
     assert checkpoint_path.is_file(), checkpoint_path
     assert tokenizer_path.is_file(), tokenizer_path

-    precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
-    fabric = L.Fabric(devices=1, precision=precision)
+    fabric = L.Fabric(devices=1)
+    dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32

     print("Loading model ...", file=sys.stderr)
     t0 = time.time()
     with lazy_load(checkpoint_path) as checkpoint:
         name = llama_model_lookup(checkpoint)

-        with fabric.init_module(empty_init=True):
+        with EmptyInitOnDevice(
+            device=fabric.device, dtype=dtype, quantization_mode=quantize
+        ):
             model = LLaMA.from_name(name)

         model.load_state_dict(checkpoint)
     print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)

     model.eval()
-    model = fabric.setup(model)
+    model = fabric.setup_module(model)

     tokenizer = Tokenizer(tokenizer_path)
     encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
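The loading path this hunk restores can be sketched in isolation, mirroring the calls in the diff; the checkpoint path is hypothetical and quantize stands in for the value main() receives:

```python
from pathlib import Path

import torch
import lightning as L
from lit_llama import LLaMA
from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup

checkpoint_path = Path("checkpoints/lit-llama/7B/lit-llama.pth")  # hypothetical path
quantize = None  # no quantization in this sketch

fabric = L.Fabric(devices=1)
dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32

with lazy_load(checkpoint_path) as checkpoint:
    name = llama_model_lookup(checkpoint)  # infer the model variant from the checkpoint
    with EmptyInitOnDevice(device=fabric.device, dtype=dtype, quantization_mode=quantize):
        model = LLaMA.from_name(name)      # allocate weights directly on the target device/dtype
    model.load_state_dict(checkpoint)

model.eval()
model = fabric.setup_module(model)
```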
@@ -142,10 +129,16 @@ def main(
     L.seed_everything(1234)
     for i in range(num_samples):
         t0 = time.perf_counter()
-        y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
+        y = generate(
+            model,
+            encoded,
+            max_new_tokens,
+            model.config.block_size,  # type: ignore[union-attr,arg-type]
+            temperature=temperature,
+            top_k=top_k,
+        )
         t = time.perf_counter() - t0

-        model.reset_cache()
         print(tokenizer.decode(y))
         tokens_generated = y.size(0) - prompt_length
         print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
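The throughput bookkeeping in this loop is easy to reproduce with stand-in numbers (nothing below comes from the diff): subtract the prompt length from the output length and divide by wall time.

```python
import time
import torch

prompt_length = 8
y = torch.arange(40)           # stand-in output: 8 prompt ids + 32 generated ids

t0 = time.perf_counter()
time.sleep(0.25)               # stand-in for the timed generate() call
t = time.perf_counter() - t0

tokens_generated = y.size(0) - prompt_length
print(f"Time for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec")
# roughly: 0.25 sec total, 128 tokens/sec
```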