parth parekh committed on
Commit
838063e
·
1 Parent(s): e43c18e

testing out torch jit

Browse files
Files changed (1) hide show
  1. predictor.py +18 -10
predictor.py CHANGED
@@ -82,21 +82,29 @@ test_sentences = [
82
  "Lets do '42069' tonight it will be really fun what do you say ?"
83
  ]
84
 
 
 
85
 
86
- # Function to predict
 
 
 
 
87
def predict(text):
    """Classify *text* and return the index of the highest-scoring class.

    The raw text is tokenized by ``text_pipeline`` and right-padded with
    zeros when shorter than the widest convolution filter, so every filter
    can slide over the input at least once.
    """
    min_len = max(FILTER_SIZES)  # widest conv filter dictates the minimum length
    with torch.no_grad():
        token_ids = torch.tensor([text_pipeline(text)])
        shortfall = min_len - token_ids.size(1)
        if shortfall > 0:
            # Right-pad with zeros up to the largest filter size.
            pad = torch.zeros(1, shortfall, dtype=torch.long)
            token_ids = torch.cat([token_ids, pad], dim=1)
        logits = model(token_ids.to(device))
        return torch.argmax(logits, dim=1).item()
97
 
98
 
99
-
100
  # Test the sentences
101
  for i, sentence in enumerate(test_sentences, 1):
102
  prediction = predict(sentence)
 
82
  "Lets do '42069' tonight it will be really fun what do you say ?"
83
  ]
84
 
85
# Compile the model with TorchScript for faster inference.
scripted_model = torch.jit.script(model)

# The widest convolution filter fixes the minimum sequence length; keep one
# reusable zero tensor on the target device so padding never reallocates.
MAX_LEN = max(FILTER_SIZES)
padding_tensor = torch.zeros(1, MAX_LEN, dtype=torch.long).to(device)
91
+
92
# Prediction function using the TorchScript model and inference-mode optimizations.
def predict(text):
    """Run the scripted model on *text*; return the predicted class index.

    Inputs shorter than ``MAX_LEN`` are right-padded from the preallocated
    ``padding_tensor`` so no fresh padding tensor is allocated per call.
    """
    with torch.inference_mode():  # cheaper than no_grad for pure inference
        token_ids = torch.tensor([text_pipeline(text)]).to(device)

        # Pad only when the tokenized input is shorter than the widest filter.
        shortfall = MAX_LEN - token_ids.size(1)
        if shortfall > 0:
            token_ids = torch.cat(
                [token_ids, padding_tensor[:, :shortfall]], dim=1
            )

        # Forward pass through the JIT-compiled model.
        logits = scripted_model(token_ids)

        # Index of the highest-scoring class.
        return torch.argmax(logits, dim=1).item()
106
 
107
 
 
108
  # Test the sentences
109
  for i, sentence in enumerate(test_sentences, 1):
110
  prediction = predict(sentence)