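"""Gradio demo: text generation with a dynamically quantized SMOL-LM2 model.

Loads a float TransformerModel (defined in model.py), applies post-training
dynamic quantization to its Linear and Embedding layers, and serves a simple
text-generation UI with Gradio.
"""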
import torch
torch.backends.quantized.engine = 'fbgemm'  # FBGEMM is the x86 CPU backend for quantized kernels

print("PyTorch version:", torch.__version__)
print("Supported quantized engines:", torch.backends.quantized.supported_engines)

import torch.nn as nn
from transformers import AutoTokenizer
from model import TransformerModel
import gradio as gr

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")

def load_quantized_model(checkpoint_path):
    """Build the float model, load the checkpoint weights, and return a dynamically quantized copy."""
    # 1. Create the float model (hyperparameters follow the SmolLM2-135M configuration)
    model = TransformerModel(
        vocab_size=49152,
        hidden_size=576,
        num_hidden_layers=30,
        num_attention_heads=9,
        intermediate_size=1536,
        num_key_value_heads=3,
        max_position_embeddings=2048,
        rms_norm_eps=1e-5,
        hidden_act="silu",
        tie_word_embeddings=True,
    )
    
    # 2. Load the actual checkpoint weights
    #    If "quantized_model.pt" is a state_dict, do:
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint)  # or checkpoint["model_state_dict"] if saved that way
    model.eval()

    # 3. Dynamically quantize the relevant layers.
    #    nn.Embedding only supports weight-only quantization with quint8
    #    float-qparams, while dynamically quantized nn.Linear expects qint8,
    #    so each module type gets its own qconfig rather than one shared dtype.
    quantized_model = torch.quantization.quantize_dynamic(
        model,
        qconfig_spec={
            nn.Linear: torch.quantization.default_dynamic_qconfig,
            nn.Embedding: torch.quantization.float_qparams_weight_only_qconfig,
        },
    )

    return quantized_model

# 4. Load the quantized model
model = load_quantized_model("quantized_model.pt")

# 5. Inference function
def generate_text(prompt, max_length=50, temperature=1.0, top_k=50):
    """Sample a continuation of `prompt` using temperature and top-k sampling."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            do_sample=True,
        )
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return generated_text
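
# Example (illustrative): generate_text("Once upon a time", max_length=60, top_k=40)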

# 6. Gradio interface
interface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
        gr.Slider(minimum=10, maximum=200, value=50, step=1, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k Sampling"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Text Generation with Quantized SMOL-LM2",
    description="Generate text using a dynamically quantized SMOL-LM2 model.",
)

interface.launch()
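
# Note: when running outside Hugging Face Spaces, interface.launch(share=True)
# can be used to expose a temporary public URL for the demo.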