Tousifahamed commited on
Commit
632a181
·
verified ·
1 Parent(s): a7f53d7

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import torch
 
2
  import torch.ao.quantization as quantization
3
  from transformers import AutoTokenizer
4
- from model import TransformerModel # Replace with your model class
5
  import gradio as gr
6
 
7
  # Load the tokenizer
@@ -24,7 +25,7 @@ def load_quantized_model(checkpoint_path):
24
 
25
  # Set quantization config for ALL embedding layers
26
  for name, module in model.named_modules():
27
- if isinstance(module, nn.Embedding):
28
  module.qconfig = quantization.float_qparams_weight_only_qconfig
29
 
30
  # Apply static quantization to the rest of the model
@@ -39,7 +40,6 @@ def load_quantized_model(checkpoint_path):
39
  model.eval()
40
  return model
41
 
42
- import gradio as gr
43
 
44
  # Load the quantized model
45
  model = load_quantized_model("checkpoint_quantized.pt")
 
1
  import torch
2
+ import torch.nn as nn # Added missing import
3
  import torch.ao.quantization as quantization
4
  from transformers import AutoTokenizer
5
+ from model import TransformerModel
6
  import gradio as gr
7
 
8
  # Load the tokenizer
 
25
 
26
  # Set quantization config for ALL embedding layers
27
  for name, module in model.named_modules():
28
+ if isinstance(module, nn.Embedding): # Now works because `nn` is imported
29
  module.qconfig = quantization.float_qparams_weight_only_qconfig
30
 
31
  # Apply static quantization to the rest of the model
 
40
  model.eval()
41
  return model
42
 
 
43
 
44
  # Load the quantized model
45
  model = load_quantized_model("checkpoint_quantized.pt")