Tousifahamed committed on
Commit 1f2619d · verified · 1 Parent(s): 7b27885

Upload app.py

Files changed (1)
  1. app.py +12 -10
app.py CHANGED
@@ -23,19 +23,21 @@ def load_quantized_model(checkpoint_path):
         tie_word_embeddings=True,
     )
 
-    # Set quantization config for ALL embedding layers
-    for name, module in model.named_modules():
-        if isinstance(module, nn.Embedding):  # Now works because `nn` is imported
-            module.qconfig = quantization.float_qparams_weight_only_qconfig
+    # Dynamic quantization for embeddings
+    model.embed_tokens = torch.ao.quantization.quantize_dynamic(
+        model.embed_tokens, {nn.Embedding}, dtype=torch.qint8
+    )
+    model.embed_positions = torch.ao.quantization.quantize_dynamic(
+        model.embed_positions, {nn.Embedding}, dtype=torch.qint8
+    )
 
-    # Apply static quantization to the rest of the model
-    model.qconfig = quantization.default_qconfig
-    model = quantization.prepare(model, inplace=False)
-    model = quantization.convert(model, inplace=False)
+    # Static quantization for other layers
+    model.qconfig = torch.ao.quantization.default_qconfig
+    model = torch.ao.quantization.prepare(model, inplace=False)
+    model = torch.ao.quantization.convert(model, inplace=False)
 
-    # Load the checkpoint
+    # Load checkpoint
     checkpoint = torch.load(checkpoint_path, map_location="cpu")
-    # model.load_state_dict(checkpoint["model_state_dict"])
     model.load_state_dict(checkpoint)
 
     model.eval()
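
For reference, a minimal, standalone sketch of the weight-only dynamic quantization that quantize_dynamic applies to an embedding layer. The TinyEmbedder wrapper and the vocabulary/embedding sizes are illustrative, not taken from app.py; the sketch passes dtype=torch.quint8, which makes quantize_dynamic select the float_qparams weight-only qconfig that PyTorch's quantized Embedding expects, and it wraps the embedding in a small module because convert only swaps submodules of the model it is given.

import torch
import torch.nn as nn

class TinyEmbedder(nn.Module):
    """Illustrative wrapper (not from app.py) holding a single embedding."""

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(32000, 768)  # illustrative sizes

    def forward(self, ids):
        return self.embed_tokens(ids)

model = TinyEmbedder()

# Weight-only dynamic quantization: the embedding weight is stored as quint8,
# while lookups still produce float activations.
qmodel = torch.ao.quantization.quantize_dynamic(
    model, {nn.Embedding}, dtype=torch.quint8
)

ids = torch.tensor([1, 5, 42])
out = qmodel(ids)
print(type(qmodel.embed_tokens))  # a quantized Embedding module
print(out.shape)                  # torch.Size([3, 768])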