import onnx
import onnxruntime
import torch
from transformers import BertForTokenClassification

from .config_train import model_load_path, onnx_path, tokenizer


def convert_to_onnx(model_path, tokenizer):
    """Convert the fine-tuned BERT token-classification model to ONNX.

    Loads the checkpoint from ``model_path``, traces it with a dummy
    Vietnamese sentence, and writes the exported graph to ``onnx_path``
    (taken from ``config_train``).

    Args:
        model_path: Path or model identifier accepted by
            ``BertForTokenClassification.from_pretrained``.
        tokenizer: HuggingFace tokenizer matching the model; must support
            ``return_tensors="pt"``.

    Returns:
        None. Side effect: writes the ONNX file and prints its location.
    """
    model = BertForTokenClassification.from_pretrained(model_path)
    model.eval()  # disable dropout etc. so the traced graph is deterministic

    # Dummy input used only to trace the graph; the dynamic_axes mapping
    # below lets the exported model accept any batch size / sequence length.
    dummy_sentence = "Tôi muốn đi cắm trại ngắm hoàng hôn trên biển cùng gia đình"
    inputs = tokenizer(
        dummy_sentence, return_tensors="pt", padding=True, truncation=True
    )

    # no_grad: export only traces the forward pass — no gradients needed.
    with torch.no_grad():
        torch.onnx.export(
            model,
            (inputs["input_ids"], inputs["attention_mask"]),  # tuple of model inputs
            onnx_path,
            export_params=True,  # embed learned weights in the ONNX file
            opset_version=14,  # opset 14+ required for current BERT ops
            input_names=["input_ids", "attention_mask"],
            output_names=["logits"],
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "sequence_length"},
                "attention_mask": {0: "batch_size", 1: "sequence_length"},
                "logits": {0: "batch_size", 1: "sequence_length"},
            },
        )
    print(f"✅ ONNX model saved to {onnx_path}")


# NOTE(review): this runs on import; consider guarding with
# `if __name__ == "__main__":` if the module is ever imported elsewhere.
convert_to_onnx(model_load_path, tokenizer)