import torch
torch.backends.quantized.engine = 'fbgemm'
print("PyTorch version:", torch.__version__)
print("Supported quantized engines:", torch.backends.quantized.supported_engines)
import torch.nn as nn
from transformers import AutoTokenizer
from model import TransformerModel
import gradio as gr
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
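# Assumption: the cosmo2 tokenizer is the one used by the SmolLM2 family, so its vocabulary
# should line up with the vocab_size=49152 passed to TransformerModel below.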
def load_quantized_model(checkpoint_path):
    # 1. Create the float model
    model = TransformerModel(
        vocab_size=49152,
        hidden_size=576,
        num_hidden_layers=30,
        num_attention_heads=9,
        intermediate_size=1536,
        num_key_value_heads=3,
        max_position_embeddings=2048,
        rms_norm_eps=1e-5,
        hidden_act="silu",
        tie_word_embeddings=True,
    )

    # 2. Load the actual checkpoint weights
    # If "quantized_model.pt" is a state_dict, do:
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint)  # or checkpoint["model_state_dict"] if saved that way
    model.eval()

    # 3. Dynamically quantize the relevant layers
    # For embeddings, we typically use torch.quint8
    # so we don't run into any embedding dtype errors
    quantized_model = torch.quantization.quantize_dynamic(
        model,
        {nn.Linear, nn.Embedding},
        dtype=torch.quint8,
    )
    return quantized_model
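# Note (assumption, not part of the original flow): on some PyTorch versions, quantizing
# nn.Linear and nn.Embedding together under a single dtype can raise a qconfig/dtype error.
# If that happens, a per-module qconfig_spec is one possible workaround:
#
#   from torch.ao.quantization import default_dynamic_qconfig, float_qparams_weight_only_qconfig
#   quantized_model = torch.quantization.quantize_dynamic(
#       model,
#       {nn.Linear: default_dynamic_qconfig, nn.Embedding: float_qparams_weight_only_qconfig},
#   )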
# 4. Load the quantized model
model = load_quantized_model("quantized_model.pt")
# 5. Inference function
def generate_text(prompt, max_length=50, temperature=1.0, top_k=50):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            do_sample=True,
        )
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return generated_text
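# Note: the generate() call above assumes TransformerModel (defined in model.py) exposes a
# HuggingFace-style generate() method accepting max_length, temperature, top_k and do_sample.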
# 6. Gradio interface
interface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
        gr.Slider(minimum=10, maximum=200, value=50, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=50, label="Top-k Sampling"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Text Generation with Quantized SMOL-LM2",
    description="Generate text using a dynamically quantized SMOL-LM2 model.",
)
interface.launch()