Tousifahamed committed on
Commit 57a4ca3 · verified · 1 Parent(s): a222a4b

Upload app.py

Files changed (1)
  1. app.py +21 -26
app.py CHANGED
@@ -1,21 +1,19 @@
  import torch
- torch.backends.quantized.engine = 'fbgemm' # ensure we use fbgemm
+ torch.backends.quantized.engine = 'fbgemm'

  print("PyTorch version:", torch.__version__)
  print("Supported quantized engines:", torch.backends.quantized.supported_engines)

  import torch.nn as nn
- import torch.quantization # <--- Use the older namespace for default_qconfig
  from transformers import AutoTokenizer
  from model import TransformerModel
  import gradio as gr

- from torch.ao.quantization.qconfig import float_qparams_weight_only_qconfig
-
  # Load the tokenizer
  tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")

  def load_quantized_model(checkpoint_path):
+     # 1. Create the float model
      model = TransformerModel(
          vocab_size=49152,
          hidden_size=576,
@@ -29,30 +27,29 @@ def load_quantized_model(checkpoint_path):
          tie_word_embeddings=True,
      )

-     # This qconfig is typically for your other layers
-     default_qconfig = torch.quantization.get_default_qconfig("fbgemm")
-     model.qconfig = default_qconfig
-
-     # For embeddings, force the specialized config:
-     model.embed_tokens.qconfig = float_qparams_weight_only_qconfig
-     model.embed_positions.qconfig = float_qparams_weight_only_qconfig
-
-     # Then prepare, calibrate, and convert
-     model = torch.quantization.prepare(model, inplace=False)
-
-     # Calibration pass here...
-     model = torch.quantization.convert(model, inplace=False)
+     # 2. Load the actual checkpoint weights
+     # If "quantized_model.pt" is a state_dict, do:
+     checkpoint = torch.load(checkpoint_path, map_location="cpu")
+     model.load_state_dict(checkpoint)  # or checkpoint["model_state_dict"] if saved that way
+     model.eval()

-     return model
+     # 3. Dynamically quantize relevant layers
+     # For embeddings, we typically use torch.quint8
+     # so we don't run into any embedding dtype errors
+     quantized_model = torch.quantization.quantize_dynamic(
+         model,
+         {nn.Linear, nn.Embedding},
+         dtype=torch.quint8
+     )

+     return quantized_model

- # Load the quantized model
+ # 4. Load the quantized model
  model = load_quantized_model("quantized_model.pt")

- # Function to generate text
+ # 5. Inference function
  def generate_text(prompt, max_length=50, temperature=1.0, top_k=50):
      input_ids = tokenizer.encode(prompt, return_tensors="pt")
-
      with torch.no_grad():
          output_ids = model.generate(
              input_ids,
@@ -61,11 +58,10 @@ def generate_text(prompt, max_length=50, temperature=1.0, top_k=50):
              top_k=top_k,
              do_sample=True,
          )
-
      generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
      return generated_text

- # Gradio Interface
+ # 6. Gradio interface
  interface = gr.Interface(
      fn=generate_text,
      inputs=[
@@ -76,8 +72,7 @@ interface = gr.Interface(
      ],
      outputs=gr.Textbox(label="Generated Text"),
      title="Text Generation with Quantized SMOL-LM2",
-     description="Generate text using a quantized version of the SMOL-LM2 model.",
+     description="Generate text using a dynamically quantized SMOL-LM2 model.",
  )

- # Launch the app
- interface.launch()
+ interface.launch()
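What the new load_quantized_model does: build the float TransformerModel, load the checkpoint weights into it, then apply post-training dynamic quantization, which stores weights in 8-bit and quantizes activations on the fly at inference, so no calibration pass is needed. Below is a minimal sketch of torch.quantization.quantize_dynamic; the toy Sequential model and layer sizes are hypothetical stand-ins for the real TransformerModel, and qint8 Linear quantization is shown for simplicity (the committed code instead passes {nn.Linear, nn.Embedding} with dtype=torch.quint8, since embedding layers only support weight-only quint8 quantization).

import torch
import torch.nn as nn

# Hypothetical stand-in for TransformerModel, just to illustrate the API
float_model = nn.Sequential(
    nn.Linear(64, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
).eval()

# Dynamic quantization: weights are converted to int8 ahead of time,
# activations are quantized per batch at inference time, no calibration data needed
quantized_model = torch.quantization.quantize_dynamic(
    float_model,
    {nn.Linear},        # module types to replace with dynamically quantized versions
    dtype=torch.qint8,  # int8 weights for Linear layers
)

print(quantized_model)  # Linear layers now appear as DynamicQuantizedLinear modules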
 
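For comparison, the removed code was attempting eager-mode static post-training quantization, which needs per-module qconfigs, observer insertion via prepare, a calibration pass over representative inputs, and a final convert; the model's forward must also route tensors through QuantStub/DeQuantStub boundaries. A rough sketch of that flow, again on a hypothetical toy module rather than the real TransformerModel:

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    # Hypothetical toy network with the quant/dequant boundaries that
    # eager-mode static quantization requires
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> int8 on entry
        self.fc1 = nn.Linear(64, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 10)
        self.dequant = torch.quantization.DeQuantStub()  # int8 -> float on exit

    def forward(self, x):
        x = self.quant(x)
        x = self.fc2(self.relu(self.fc1(x)))
        return self.dequant(x)

model = TinyNet().eval()
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")  # x86 backend

# Insert observers that record activation ranges
prepared = torch.quantization.prepare(model, inplace=False)

# Calibration pass: feed representative inputs so the observers see real data
with torch.no_grad():
    for _ in range(10):
        prepared(torch.randn(8, 64))

# Replace observed modules with statically quantized ones
quantized = torch.quantization.convert(prepared, inplace=False)
print(quantized(torch.randn(1, 64)).shape)

Dynamic quantization sidesteps the calibration step and the stub plumbing entirely, which is presumably why this commit switches to quantize_dynamic.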