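"""Gradio demo app for text generation with a locally trained SmolLM/Llama-style model."""
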
import gradio as gr
import torch
from transformers import AutoTokenizer
import yaml
from SmolLm3 import LlamaModel
def generate_helper(model, idx, max_new_tokens, context_length, temperature=1.0, top_k=None, eos_token=None, device=None):
    model = model.to(device)
    idx = idx.to(device)
    model.eval()
    for _ in range(max_new_tokens):
        # Crop the running sequence to the model's context window
        idx_cond = idx[:, -context_length:]
        with torch.no_grad():
            logits, _ = model(idx_cond)  # unpack logits and loss (loss is ignored)
        logits = logits.view(idx_cond.shape[0], -1, model.config['vocab_size'])  # reshape to [batch, seq, vocab]
        # Keep only the logits for the last position
        logits = logits[:, -1, :]  # shape: [batch_size, vocab_size]
        if top_k is not None:
            # Top-k sampling: mask everything below the k-th largest logit
            top_logits, _ = torch.topk(logits, top_k)
            min_logit = top_logits[:, -1].unsqueeze(-1)
            logits = torch.where(logits < min_logit,
                                 torch.tensor(float('-inf')).to(logits.device),
                                 logits)
        if temperature > 0.0:
            # Temperature scaling followed by multinomial sampling
            logits /= temperature
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            # Greedy decoding when temperature is 0
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)
        if eos_token is not None and idx_next.item() == eos_token:
            break
        idx = torch.cat((idx, idx_next), dim=1)
    model.train()  # put the model back into training mode
    return idx

def get_config(config_path):
    with open(config_path, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    return config

def load_model_from_checkpoint(config_path, checkpoint_path, device):
    config = get_config(config_path)
    model = LlamaModel(config['model'])
    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
    state_dict = checkpoint['model_state_dict']
    # Strip the torch.compile prefix so keys match the uncompiled model
    state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    return model

def load_weights(config, weights_path, device):
    model = LlamaModel(config['model'])
    model.load_state_dict(torch.load(weights_path, map_location=torch.device(device)))
    return model

def get_tokenizer(config):
    tokenizer_path = config['tokenizer']['tokenizer_name_or_path']
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    tokenizer.pad_token = tokenizer.eos_token
    vocab_size = tokenizer.vocab_size
    return tokenizer, vocab_size

def generate_text(model, tokenizer, input_text, max_new_tokens, context_length, temperature, top_k, eos_token, device):
    encoded_text = tokenizer.encode(input_text, return_tensors="pt").to(device)
    generated_text = generate_helper(model,
                                     idx=encoded_text,
                                     max_new_tokens=max_new_tokens,
                                     context_length=context_length,
                                     temperature=temperature,
                                     top_k=top_k,
                                     eos_token=eos_token,
                                     device=device)
    return tokenizer.decode(generated_text.squeeze(0))

# Initialize model and tokenizer
def initialize_model():
    config_path = "config_smollm2_135M.yaml"
    checkpoint_path = "/Users/chiragtagadiya/Documents/Final_training_before_stop_smolllm3/checkpoints/model_37000_steps_avg_loss_2.85920_optimizer_lr_0.00000003.pth"  # Update this path
    weights_path = "model_weights_35000_step.pt"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load configuration
    config = get_config(config_path)

    # Load model from exported weights (full-checkpoint loading kept for reference)
    # model = load_model_from_checkpoint(config_path, checkpoint_path, device)
    model = load_weights(config, weights_path, device)
    model.to(device)
    model.eval()

    # Load tokenizer
    tokenizer, vocab_size = get_tokenizer(config)

    return model, tokenizer, device

def generate_response(prompt, max_new_tokens):
    generated_text = generate_text(
        model=model,
        tokenizer=tokenizer,
        input_text=prompt,
        max_new_tokens=max_new_tokens,
        context_length=256,
        temperature=0.9,
        top_k=2,
        eos_token=tokenizer.eos_token_id,
        device=device
    )
    return generated_text

# Initialize global variables
model, tokenizer, device = initialize_model()

# Create Gradio interface
iface = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(
            lines=3,
            placeholder="Enter your prompt here...",
            label="Input Prompt"
        ),
        gr.Slider(
            minimum=50,
            maximum=256,
            value=100,
            step=10,
            label="Max New Tokens"
        )
    ],
    outputs=gr.Textbox(
        lines=5,
        label="Generated Text"
    ),
    title="SmolLM Text Generator",
    description="Enter a prompt and adjust the maximum number of new tokens to generate text with the SmolLM model."
)

if __name__ == "__main__":
    iface.launch()