Tousifahamed committed (verified)
Commit ad95929 · 1 Parent(s): b789c6c

Upload 2 files

Files changed (1): app.py (+14 -8)
app.py CHANGED
@@ -1,4 +1,5 @@
 import torch
+import torch.ao.quantization as quantization
 from transformers import AutoTokenizer
 from model import TransformerModel  # Replace with your model class
 import gradio as gr
@@ -21,15 +22,13 @@ def load_quantized_model(checkpoint_path):
         tie_word_embeddings=True,
     )
 
-    # Apply dynamic quantization to the embedding layer
-    model.embed_tokens = torch.quantization.quantize_dynamic(
-        model.embed_tokens, {torch.nn.Embedding}, dtype=torch.qint8
-    )
+    # Set the quantization configuration for the embedding layer
+    model.embed_tokens.qconfig = quantization.float_qparams_weight_only_qconfig
 
     # Apply static quantization to the rest of the model
-    model.qconfig = torch.quantization.default_qconfig
-    model = torch.quantization.prepare(model, inplace=False)
-    model = torch.quantization.convert(model, inplace=False)
+    model.qconfig = quantization.default_qconfig
+    model = quantization.prepare(model, inplace=False)
+    model = quantization.convert(model, inplace=False)
 
     # Load the quantized checkpoint
     checkpoint = torch.load(checkpoint_path, map_location="cpu")
@@ -38,12 +37,19 @@ def load_quantized_model(checkpoint_path):
     model.eval()
     return model
 
-
 import gradio as gr
 
 # Load the quantized model
 model = load_quantized_model("checkpoint_quantized.pt")
 
+# Set the quantization configuration for the embedding layer
+model.embed_tokens.qconfig = quantization.float_qparams_weight_only_qconfig
+
+# Apply static quantization to the rest of the model
+model.qconfig = quantization.default_qconfig
+model = quantization.prepare(model, inplace=False)
+model = quantization.convert(model, inplace=False)
+
 # Function to generate text
 def generate_text(prompt, max_length=50, temperature=1.0, top_k=50):
     input_ids = tokenizer.encode(prompt, return_tensors="pt")
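
For context on the change: PyTorch's eager-mode quantization treats `nn.Embedding` specially, and it is generally only supported through the weight-only `float_qparams_weight_only_qconfig`; a plain qint8 dynamic quantization of an embedding, as in the removed code, is typically rejected. Below is a minimal, self-contained sketch of the pattern this commit adopts, using a toy `TinyModel` with made-up sizes rather than the repo's `TransformerModel`:

import torch
import torch.nn as nn
import torch.ao.quantization as quantization

# Toy stand-in for the checkpoint's model; vocab/dim sizes are hypothetical.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(1000, 64)

    def forward(self, input_ids):
        return self.embed_tokens(input_ids)

model = TinyModel().eval()

# Embeddings use weight-only quantization with float qparams; giving the
# layer its own qconfig keeps it separate from the model-wide default.
model.embed_tokens.qconfig = quantization.float_qparams_weight_only_qconfig

# prepare/convert swap the embedding for its quantized counterpart.
model = quantization.prepare(model, inplace=False)
model = quantization.convert(model, inplace=False)

print(type(model.embed_tokens))  # quantized Embedding module
print(model(torch.tensor([[1, 2, 3]])).shape)  # torch.Size([1, 3, 64])

Note that the same qconfig assignment must be applied before loading a state dict saved from a converted model, which is presumably why the commit repeats the prepare/convert steps both inside and after `load_quantized_model`.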