rectified onnx inference code
Browse files
README.md
CHANGED
@@ -106,7 +106,7 @@ predicted_class_id = logits.argmax().item()
|
|
106 |
loaded_model.config.id2label[predicted_class_id]
|
107 |
```
|
108 |
|
109 |
-
Optimum with ONNX
|
110 |
|
111 |
Loading the model requires the 🤗 Optimum library installed.
|
112 |
```shell
|
@@ -115,12 +115,12 @@ pip install transformers optimum[onnxruntime] optimum
|
|
115 |
|
116 |
```python
|
117 |
model_path = "philomath-1209/programming-language-identification"
|
118 |
-
|
119 |
from transformers import pipeline, AutoTokenizer
|
120 |
from optimum.onnxruntime import ORTModelForSequenceClassification
|
121 |
|
122 |
-
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
123 |
-
model = ORTModelForSequenceClassification.from_pretrained(model_path, export=
|
124 |
|
125 |
text = """
|
126 |
PROGRAM Triangle
|
@@ -141,9 +141,10 @@ text = """
|
|
141 |
END FUNCTION Area
|
142 |
|
143 |
"""
|
144 |
-
inputs =
|
145 |
with torch.no_grad():
|
146 |
-
logits =
|
147 |
predicted_class_id = logits.argmax().item()
|
148 |
-
|
|
|
149 |
```
|
|
|
106 |
loaded_model.config.id2label[predicted_class_id]
|
107 |
```
|
108 |
|
109 |
+
### Optimum with ONNX inference
|
110 |
|
111 |
Loading the model requires the 🤗 Optimum library installed.
|
112 |
```shell
|
|
|
115 |
|
116 |
```python
|
117 |
model_path = "philomath-1209/programming-language-identification"
|
118 |
+
import torch
|
119 |
from transformers import pipeline, AutoTokenizer
|
120 |
from optimum.onnxruntime import ORTModelForSequenceClassification
|
121 |
|
122 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="onnx")
|
123 |
+
model = ORTModelForSequenceClassification.from_pretrained(model_path, export=False, subfolder="onnx")
|
124 |
|
125 |
text = """
|
126 |
PROGRAM Triangle
|
|
|
141 |
END FUNCTION Area
|
142 |
|
143 |
"""
|
144 |
+
inputs = tokenizer(text, return_tensors="pt",truncation=True)
|
145 |
with torch.no_grad():
|
146 |
+
logits = model(**inputs).logits
|
147 |
predicted_class_id = logits.argmax().item()
|
148 |
+
model.config.id2label[predicted_class_id]
|
149 |
+
|
150 |
```
|