import torch
import gradio as gr
from transformers import AutoTokenizer
from model_smol2 import LlamaForCausalLM, config_model
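# LlamaForCausalLM and config_model come from the local model_smol2 module,
# which defines the custom model architecture and its configuration.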
# Instantiate the model
model = LlamaForCausalLM(config_model)
# Load the checkpoint
checkpoint_path = "/Users/shriti/Downloads/Assign13_ERAV3/deply/final_checkpoint.pt"
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# Load the tokenizer (replace with the appropriate tokenizer if you're using a custom one)
TOKENIZER_PATH = "HuggingFaceTB/cosmo2-tokenizer"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else "[PAD]"
# Text generation function
def generate_text(
    prompt, max_length=50, temperature=0.7, top_k=50, repetition_penalty=1.2, n_gram_block=2
):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    generated_tokens = input_ids[0].tolist()
    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)  # model outputs
            # Check if the output is a dictionary with logits
            if isinstance(outputs, dict) and 'logits' in outputs:
                logits = outputs['logits'][:, -1, :]
            else:
                # If not, treat the output as a plain tensor
                logits = outputs[:, -1, :]
            # Repetition penalty: dampen tokens that were already generated
            # (divide positive logits, multiply negative ones, CTRL-style)
            for token_id in set(generated_tokens):
                if logits[0, token_id] > 0:
                    logits[0, token_id] /= repetition_penalty
                else:
                    logits[0, token_id] *= repetition_penalty
            # n-gram blocking: ban any token that would complete an n-gram
            # already present in the generated sequence
            if len(generated_tokens) >= n_gram_block:
                prefix = tuple(generated_tokens[-(n_gram_block - 1):]) if n_gram_block > 1 else ()
                for i in range(len(generated_tokens) - n_gram_block + 1):
                    if tuple(generated_tokens[i:i + n_gram_block - 1]) == prefix:
                        logits[:, generated_tokens[i + n_gram_block - 1]] -= 1e9
            # Temperature scaling, then sample the next token from the top-k candidates
            logits /= temperature
            top_k_logits, top_k_indices = torch.topk(logits, top_k, dim=-1)
            probs = torch.softmax(top_k_logits, dim=-1)
            next_token_idx = torch.multinomial(probs, num_samples=1)
            next_token = top_k_indices[0, next_token_idx[0]]
            # Append the sampled token and feed the extended sequence back in
            generated_tokens.append(next_token.item())
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
            if next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(generated_tokens, skip_special_tokens=True)
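# Example usage (assuming the checkpoint and tokenizer above load successfully):
#   generate_text("Once upon a time", max_length=30, temperature=0.8)
# returns the decoded continuation as a plain string.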
# Gradio UI
def generate_response(prompt, max_length, temperature, top_k, repetition_penalty, n_gram_block):
    return generate_text(prompt, max_length, temperature, top_k, repetition_penalty, n_gram_block)
with gr.Blocks() as demo:
    gr.Markdown("# Smol2 Text Generator")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Input Prompt", placeholder="Enter your text prompt here...")
            max_length = gr.Slider(label="Max Length", minimum=10, maximum=200, value=50)
            temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, value=0.7, step=0.1)
            top_k = gr.Slider(label="Top K", minimum=10, maximum=100, value=50, step=1)
            repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.2, step=0.1)
            n_gram_block = gr.Slider(label="N-Gram Blocking", minimum=1, maximum=5, value=2, step=1)
            generate_button = gr.Button("Generate Text")
        with gr.Column():
            output_text = gr.Textbox(label="Generated Text", lines=10)
    generate_button.click(
        generate_response,
        inputs=[prompt_input, max_length, temperature, top_k, repetition_penalty, n_gram_block],
        outputs=[output_text],
    )
demo.launch()