File size: 3,793 Bytes
f271aef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os

import gradio as gr
import torch
from transformers import AutoTokenizer

from model_smol2 import LlamaForCausalLM, config_model

# Instantiate the model
model = LlamaForCausalLM(config_model)

# Load the checkpoint.
# The location can be overridden with the SMOL2_CHECKPOINT_PATH environment
# variable so the app is not tied to one machine's directory layout; the
# original absolute path remains the default.
checkpoint_path = os.environ.get(
    "SMOL2_CHECKPOINT_PATH",
    "/Users/shriti/Downloads/Assign13_ERAV3/deply/final_checkpoint.pt",
)
# NOTE(review): torch.load unpickles the file, so only trusted checkpoints
# should be loaded here. Passing weights_only=True may be possible if the
# checkpoint holds only tensors and plain containers -- confirm before changing.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # switch the model to evaluation mode before serving requests

# Load the tokenizer (replace with the appropriate one if you use a custom tokenizer)
TOKENIZER_PATH = "HuggingFaceTB/cosmo2-tokenizer"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
if tokenizer.pad_token is None:
    # Fall back to the EOS token as padding; use a literal "[PAD]" only when
    # the tokenizer defines no (non-empty) EOS token either.
    tokenizer.pad_token = tokenizer.eos_token or "[PAD]"


# Text generation helpers and entry point
def _apply_repetition_penalty(logits, token_ids, penalty):
    """Lower, in place, the score of every token id in ``token_ids``.

    Positive logits are divided by ``penalty`` and negative logits are
    multiplied by it, so a penalty greater than 1 always reduces the score.
    Dividing unconditionally would *raise* the score of tokens whose logit
    is negative.
    """
    if not token_ids or penalty == 1.0:
        return
    index = torch.tensor(sorted(token_ids), dtype=torch.long, device=logits.device)
    seen = logits[:, index]
    logits[:, index] = torch.where(seen < 0, seen * penalty, seen / penalty)


def _banned_ngram_tokens(token_ids, ngram_size):
    """Return the token ids that would repeat an n-gram already in ``token_ids``.

    A token is banned when appending it to the last ``ngram_size - 1`` tokens
    would reproduce an n-gram that already occurs earlier in the sequence.
    With ``ngram_size == 1`` every token seen so far is banned; a size below 1
    disables blocking.
    """
    if ngram_size < 1 or len(token_ids) < ngram_size:
        return set()
    prefix_len = ngram_size - 1
    # Slice from an explicit index: ``token_ids[-0:]`` would be the whole list.
    prefix = token_ids[len(token_ids) - prefix_len:]
    banned = set()
    for start in range(len(token_ids) - prefix_len):
        if token_ids[start:start + prefix_len] == prefix:
            banned.add(token_ids[start + prefix_len])
    return banned


def generate_text(
    prompt, max_length=50, temperature=0.7, top_k=50, repetition_penalty=1.2, n_gram_block=2
):
    """Sample a continuation of ``prompt`` from the module-level ``model``.

    Tokens are drawn one at a time with temperature-scaled top-k sampling,
    after applying a repetition penalty and no-repeat n-gram blocking to the
    logits of the last position.

    Args:
        prompt: Text to continue; encoded with the module-level ``tokenizer``.
        max_length: Maximum number of new tokens to sample.
        temperature: Positive divisor applied to the logits before sampling.
        top_k: Number of highest-scoring tokens to sample from; clamped to
            the range ``[1, vocabulary size]``.
        repetition_penalty: Positive factor used to down-weight tokens that
            already appear in the sequence (1.0 disables it).
        n_gram_block: Size of the n-grams that must not repeat (values below
            1 disable blocking).

    Returns:
        The prompt plus generated tokens decoded to a string with special
        tokens skipped, or an empty string if the prompt encodes to no tokens.

    Raises:
        ValueError: If ``temperature`` or ``repetition_penalty`` is not positive.
    """
    # UI widgets may deliver these as floats; range() and torch.topk need ints.
    max_length = int(max_length)
    top_k = int(top_k)
    n_gram_block = int(n_gram_block)
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if repetition_penalty <= 0:
        raise ValueError("repetition_penalty must be positive")

    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    if input_ids.numel() == 0:
        # Nothing to condition on; avoid feeding an empty sequence to the model.
        return ""
    generated_tokens = input_ids[0].tolist()

    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)

            # The model may return a mapping with a 'logits' entry or a plain
            # tensor; either way keep only the scores for the last position.
            if isinstance(outputs, dict) and 'logits' in outputs:
                logits = outputs['logits'][:, -1, :]
            else:
                logits = outputs[:, -1, :]

            _apply_repetition_penalty(logits, set(generated_tokens), repetition_penalty)

            logits = logits / temperature

            # Mask after temperature scaling so banned tokens stay at -inf.
            banned = _banned_ngram_tokens(generated_tokens, n_gram_block)
            if banned:
                banned_index = torch.tensor(
                    sorted(banned), dtype=torch.long, device=logits.device
                )
                logits[:, banned_index] = float("-inf")

            # torch.topk raises if k exceeds the size of the last dimension.
            k = max(1, min(top_k, logits.size(-1)))
            top_k_logits, top_k_indices = torch.topk(logits, k, dim=-1)
            probs = torch.softmax(top_k_logits, dim=-1)

            # Sample a position within the top-k set, then map it back to a
            # vocabulary id.
            next_token_idx = torch.multinomial(probs, num_samples=1)
            next_token = top_k_indices[0, next_token_idx[0]]
            next_token_id = next_token.item()

            generated_tokens.append(next_token_id)
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

            if next_token_id == tokenizer.eos_token_id:
                break

    return tokenizer.decode(generated_tokens, skip_special_tokens=True)


# Gradio UI
def generate_response(prompt, max_length, temperature, top_k, repetition_penalty, n_gram_block):
    """Click handler for the UI: forward the control values to ``generate_text``.

    Returns whatever ``generate_text`` returns for the given settings.
    """
    generated = generate_text(
        prompt,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        n_gram_block=n_gram_block,
    )
    return generated

with gr.Blocks() as demo:
    gr.Markdown("# Smol2 Text Generator")

    with gr.Row():
        # First column: the prompt, the sampling controls and the trigger button.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Input Prompt",
                placeholder="Enter your text prompt here...",
            )
            max_length = gr.Slider(
                label="Max Length", minimum=10, maximum=200, value=50
            )
            temperature = gr.Slider(
                label="Temperature", minimum=0.1, maximum=1.5, value=0.7, step=0.1
            )
            top_k = gr.Slider(
                label="Top K", minimum=10, maximum=100, value=50, step=1
            )
            repetition_penalty = gr.Slider(
                label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.2, step=0.1
            )
            n_gram_block = gr.Slider(
                label="N-Gram Blocking", minimum=1, maximum=5, value=2, step=1
            )
            generate_button = gr.Button("Generate Text")

        # Second column: where the generated text is shown.
        with gr.Column():
            output_text = gr.Textbox(label="Generated Text", lines=10)

    # The order here must match the parameter order of generate_response.
    generation_inputs = [
        prompt_input,
        max_length,
        temperature,
        top_k,
        repetition_penalty,
        n_gram_block,
    ]
    generate_button.click(
        generate_response,
        inputs=generation_inputs,
        outputs=[output_text],
    )

# Start the Gradio app.
demo.launch()