Tousifahamed commited on
Commit
a7f53d7
·
verified ·
1 Parent(s): 7d3f5e9

Upload 2 files

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -22,15 +22,17 @@ def load_quantized_model(checkpoint_path):
22
  tie_word_embeddings=True,
23
  )
24
 
25
- # Set the quantization configuration for the embedding layer
26
- model.embed_tokens.qconfig = quantization.float_qparams_weight_only_qconfig
 
 
27
 
28
  # Apply static quantization to the rest of the model
29
  model.qconfig = quantization.default_qconfig
30
  model = quantization.prepare(model, inplace=False)
31
  model = quantization.convert(model, inplace=False)
32
 
33
- # Load the quantized checkpoint
34
  checkpoint = torch.load(checkpoint_path, map_location="cpu")
35
  model.load_state_dict(checkpoint["model_state_dict"])
36
 
 
22
  tie_word_embeddings=True,
23
  )
24
 
25
+ # Set quantization config for ALL embedding layers
26
+ for name, module in model.named_modules():
27
+ if isinstance(module, nn.Embedding):
28
+ module.qconfig = quantization.float_qparams_weight_only_qconfig
29
 
30
  # Apply static quantization to the rest of the model
31
  model.qconfig = quantization.default_qconfig
32
  model = quantization.prepare(model, inplace=False)
33
  model = quantization.convert(model, inplace=False)
34
 
35
+ # Load the checkpoint
36
  checkpoint = torch.load(checkpoint_path, map_location="cpu")
37
  model.load_state_dict(checkpoint["model_state_dict"])
38