Upload app.py
app.py CHANGED
@@ -23,19 +23,21 @@ def load_quantized_model(checkpoint_path):
         tie_word_embeddings=True,
     )

-    #
-
-
-
+    # Dynamic quantization for embeddings
+    model.embed_tokens = torch.ao.quantization.quantize_dynamic(
+        model.embed_tokens, {nn.Embedding}, dtype=torch.qint8
+    )
+    model.embed_positions = torch.ao.quantization.quantize_dynamic(
+        model.embed_positions, {nn.Embedding}, dtype=torch.qint8
+    )

-    #
-    model.qconfig = quantization.default_qconfig
-    model = quantization.prepare(model, inplace=False)
-    model = quantization.convert(model, inplace=False)
+    # Static quantization for other layers
+    model.qconfig = torch.ao.quantization.default_qconfig
+    model = torch.ao.quantization.prepare(model, inplace=False)
+    model = torch.ao.quantization.convert(model, inplace=False)

-    # Load
+    # Load checkpoint
     checkpoint = torch.load(checkpoint_path, map_location="cpu")
-    # model.load_state_dict(checkpoint["model_state_dict"])
     model.load_state_dict(checkpoint)

     model.eval()
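For reference, below is a self-contained sketch of the embedding-quantization step this commit adds, written against a toy decoder, since the Space's real model class is not visible in this hunk (`ToyDecoder` and all its sizes are invented). Two eager-mode details worth flagging, hedged from PyTorch's documented behavior rather than from this repo: `nn.Embedding` supports only weight-only quantization with float qparams (`torch.quint8` via `float_qparams_weight_only_qconfig`), so the diff's `dtype=torch.qint8` would likely be rejected at conversion time; and eager-mode `convert()` swaps child modules, so calling `quantize_dynamic` directly on a bare `nn.Embedding`, as the diff does, can hand the module back unquantized. The sketch therefore quantizes from the parent module with an explicit per-type qconfig dict:

import torch
import torch.nn as nn
import torch.ao.quantization as quantization
from torch.ao.quantization import (
    default_dynamic_qconfig,            # qint8 weights, dynamic activations (Linear)
    float_qparams_weight_only_qconfig,  # quint8 weight-only (Embedding)
)

# Sketch only: ToyDecoder stands in for the Space's real model, which the
# hunk does not show; vocab/position/hidden sizes below are invented.
class ToyDecoder(nn.Module):
    def __init__(self, vocab_size=1000, max_positions=128, dim=64):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, dim)
        self.embed_positions = nn.Embedding(max_positions, dim)
        self.lm_head = nn.Linear(dim, vocab_size)

    def forward(self, input_ids):
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        hidden = self.embed_tokens(input_ids) + self.embed_positions(positions)
        return self.lm_head(hidden)

model = ToyDecoder()
model.eval()

# Quantize from the parent module with a per-type qconfig dict: embedding
# tables get weight-only quint8, Linear layers get dynamic qint8.
model = quantization.quantize_dynamic(
    model,
    qconfig_spec={
        nn.Embedding: float_qparams_weight_only_qconfig,
        nn.Linear: default_dynamic_qconfig,
    },
)

ids = torch.randint(0, 1000, (1, 16))
print(model(ids).shape)  # torch.Size([1, 16, 1000])

Separately, the static `prepare`/`convert` pass in the diff normally expects calibration data to be run through the model between the two calls, and a float checkpoint generally will not `load_state_dict` cleanly into an already-converted model (quantized modules store different keys), so which side of the conversion the Space's checkpoint was saved on may be worth checking.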