Tousifahamed committed on
Commit
55e33aa
·
verified ·
1 Parent(s): 825827f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -22
app.py CHANGED
@@ -5,6 +5,8 @@ from transformers import AutoTokenizer
5
  from model import TransformerModel
6
  import gradio as gr
7
 
 
 
8
  # Load the tokenizer
9
  tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
10
 
@@ -22,32 +24,20 @@ def load_quantized_model(checkpoint_path):
22
  tie_word_embeddings=True,
23
  )
24
 
25
- # Dynamic quant for embeddings
26
- model.embed_tokens = torch.quantization.quantize_dynamic(
27
- model.embed_tokens, {nn.Embedding}, dtype=torch.qint8
28
- )
29
- model.embed_positions = torch.quantization.quantize_dynamic(
30
- model.embed_positions, {nn.Embedding}, dtype=torch.qint8
31
- )
32
-
33
- # Static quant config for the rest of the model
34
- model.qconfig = torch.quantization.get_default_qconfig("fbgemm") # CPU
35
- model = torch.quantization.prepare(model, inplace=False)
36
 
37
- #
38
- # >>> RUN CALIBRATION HERE (forward pass with sample data) <<<
39
- # e.g. with torch.no_grad():
40
- # for input_ids in some_calibration_loader:
41
- # outputs = model(input_ids)
42
- #
43
 
 
44
  model = torch.quantization.convert(model, inplace=False)
45
-
46
- # Load checkpoint
47
- checkpoint = torch.load(checkpoint_path, map_location="cpu")
48
- model.load_state_dict(checkpoint)
49
 
50
- model.eval()
51
  return model
52
 
53
 
 
5
  from model import TransformerModel
6
  import gradio as gr
7
 
8
+ from torch.ao.quantization.qconfig import float_qparams_weight_only_qconfig
9
+
10
  # Load the tokenizer
11
  tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
12
 
 
24
  tie_word_embeddings=True,
25
  )
26
 
27
+ # This qconfig is typically for your other layers
28
+ default_qconfig = torch.quantization.get_default_qconfig("fbgemm")
29
+ model.qconfig = default_qconfig
30
+
31
+ # For embeddings, force the specialized config:
32
+ model.embed_tokens.qconfig = float_qparams_weight_only_qconfig
33
+ model.embed_positions.qconfig = float_qparams_weight_only_qconfig
 
 
 
 
34
 
35
+ # Then prepare, calibrate, and convert
36
+ model = torch.quantization.prepare(model, inplace=False)
 
 
 
 
37
 
38
+ # Calibration pass here...
39
  model = torch.quantization.convert(model, inplace=False)
 
 
 
 
40
 
 
41
  return model
42
 
43