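# Gradio demo for the Smol2 model: load a trained checkpoint and serve an
# interactive text-generation UI with temperature, top-k, repetition-penalty,
# and n-gram-blocking controls. `model_smol2` is assumed to be a local module
# defining the LlamaForCausalLM architecture and its config.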
import torch
import gradio as gr
from transformers import AutoTokenizer
from model_smol2 import LlamaForCausalLM, config_model

# Instantiate the model
model = LlamaForCausalLM(config_model)

# Load the checkpoint (a dict whose 'model_state_dict' entry holds the weights)
checkpoint_path = "/Users/shriti/Downloads/Assign13_ERAV3/deply/final_checkpoint.pt"
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # inference mode: disables dropout

# Load the tokenizer (replace with the appropriate tokenizer if you're using a custom one)
TOKENIZER_PATH = "HuggingFaceTB/cosmo2-tokenizer"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
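
# Some tokenizers ship without a pad token; fall back to the EOS token (or a
# literal "[PAD]") so padding-dependent tokenizer calls don't fail.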
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else "[PAD]"

# Text generation function
def generate_text(
    prompt, max_length=50, temperature=0.7, top_k=50, repetition_penalty=1.2, n_gram_block=2
):
    # Gradio sliders deliver floats; cast the controls that must be integers
    max_length, top_k, n_gram_block = int(max_length), int(top_k), int(n_gram_block)
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    generated_tokens = input_ids[0].tolist()
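
    # Autoregressive decoding: one forward pass per step, adjust the logits of
    # the final position, sample the next token, and append it to the input.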
    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            # The model may return a dict containing logits or a plain tensor;
            # either way, keep only the logits for the last position
            if isinstance(outputs, dict) and 'logits' in outputs:
                logits = outputs['logits'][:, -1, :]
            else:
                logits = outputs[:, -1, :]
            # Repetition penalty (CTRL-style): divide positive logits and
            # multiply negative ones, so repeated tokens always become less likely
            for token_id in set(generated_tokens):
                if logits[0, token_id] > 0:
                    logits[0, token_id] /= repetition_penalty
                else:
                    logits[0, token_id] *= repetition_penalty
            # n-gram blocking: ban any token that would recreate an n-gram
            # already present in the generated sequence
            if len(generated_tokens) >= n_gram_block:
                prefix = tuple(generated_tokens[len(generated_tokens) - n_gram_block + 1:])
                for i in range(len(generated_tokens) - n_gram_block + 1):
                    if tuple(generated_tokens[i:i + n_gram_block - 1]) == prefix:
                        logits[:, generated_tokens[i + n_gram_block - 1]] -= 1e9
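
            # Temperature scaling sharpens (<1) or flattens (>1) the distribution;
            # top-k then restricts sampling to the k most likely tokens.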
            logits /= temperature
            top_k_logits, top_k_indices = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1)
            probs = torch.softmax(top_k_logits, dim=-1)
            next_token_idx = torch.multinomial(probs, num_samples=1)
            next_token = top_k_indices[0, next_token_idx[0]]

            generated_tokens.append(next_token.item())
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

            # Stop early once the model emits the end-of-sequence token
            if next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(generated_tokens, skip_special_tokens=True)

# Gradio UI
# Thin wrapper so the click handler has an explicit argument list that mirrors
# the Gradio inputs below
def generate_response(prompt, max_length, temperature, top_k, repetition_penalty, n_gram_block):
    return generate_text(prompt, max_length, temperature, top_k, repetition_penalty, n_gram_block)

with gr.Blocks() as demo:
    gr.Markdown("# Smol2 Text Generator")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Input Prompt", placeholder="Enter your text prompt here...")
            max_length = gr.Slider(label="Max Length", minimum=10, maximum=200, value=50, step=1)
            temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, value=0.7, step=0.1)
            top_k = gr.Slider(label="Top K", minimum=10, maximum=100, value=50, step=1)
            repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.2, step=0.1)
            n_gram_block = gr.Slider(label="N-Gram Blocking", minimum=1, maximum=5, value=2, step=1)
            generate_button = gr.Button("Generate Text")
        with gr.Column():
            output_text = gr.Textbox(label="Generated Text", lines=10)
    generate_button.click(
        generate_response,
        inputs=[prompt_input, max_length, temperature, top_k, repetition_penalty, n_gram_block],
        outputs=[output_text],
    )

demo.launch()
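
# Run locally with `python app.py` and open the printed URL; on Hugging Face
# Spaces the app starts automatically when the Space boots.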