from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
import torch
import os

# Load Dataset
dataset = load_dataset('csv', data_files={'train': './data/raw_data.csv'}, delimiter=",")
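# Assumed schema (not shown in the original): raw_data.csv contains a 'text'
# column and an integer 'label' column with values in {0, 1}, which Trainer
# reads directly as the classification target.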
# Load Pretrained Tokenizer and Model
model_name = "xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Tokenization: truncate and pad each map() batch to a common length
def preprocess_function(examples):
    return tokenizer(examples['text'], truncation=True, padding=True)

encoded_dataset = dataset.map(preprocess_function, batched=True)
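# Sketch of a common alternative (an assumption, not part of the original
# script): tokenize without padding and let DataCollatorWithPadding pad each
# training batch dynamically, which avoids padding to the longest example in
# every map() batch.
# from transformers import DataCollatorWithPadding
# data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# trainer = Trainer(..., data_collator=data_collator)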
# Training Arguments
training_args = TrainingArguments(
    output_dir="./checkpoints",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    save_steps=100,
    save_total_limit=1,
    logging_dir="./logs",
    logging_steps=10,
    evaluation_strategy="no",  # renamed to eval_strategy in recent transformers releases
    push_to_hub=False,
    load_best_model_at_end=False
)
# Trainer Setup | |
trainer = Trainer( | |
model=model, | |
args=training_args, | |
train_dataset=encoded_dataset['train'] | |
) | |
# Start Training | |
trainer.train() | |
# Save Final Fine-tuned Model
save_directory = "./models/fine_tuned_xlm_roberta"
os.makedirs(save_directory, exist_ok=True)
model.save_pretrained(save_directory)
tokenizer.save_pretrained(save_directory)
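# Quick sanity check (an illustrative sketch, not part of the original script);
# the example sentence and the binary label scheme are assumptions.
model.eval()
sample = tokenizer("This is a sample sentence.", return_tensors="pt", truncation=True)
sample = {k: v.to(model.device) for k, v in sample.items()}  # Trainer may have left the model on GPU
with torch.no_grad():
    logits = model(**sample).logits
print(f"Sanity-check prediction: class {int(torch.argmax(logits, dim=-1))}")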
# Quantize Model (Make Lightweight)
def quantize_model(model_path):
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.to(torch.device('cpu'))  # dynamic quantization runs on CPU
    # Swap every nn.Linear for a dynamically quantized int8 version
    model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    quantized_model_path = model_path + "_quantized"
    os.makedirs(quantized_model_path, exist_ok=True)
    # safetensors cannot serialize quantized tensors, so fall back to torch.save;
    # note that from_pretrained alone will not rebuild the quantized layers on load
    model.save_pretrained(quantized_model_path, safe_serialization=False)
    tokenizer.save_pretrained(quantized_model_path)  # uses the module-level tokenizer
    print(f"Quantized model saved to {quantized_model_path}")

quantize_model(save_directory)
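# Illustrative sketch (an assumption, not part of the original script): reloading
# the quantized checkpoint. Re-apply dynamic quantization to a fresh base model
# so the module structure matches, then load the saved weights into it.
def load_quantized_model(quantized_path):
    base = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
    base.to(torch.device('cpu'))
    quantized = torch.quantization.quantize_dynamic(base, {torch.nn.Linear}, dtype=torch.qint8)
    # weights_only=False is safe here because the checkpoint was produced locally
    state_dict = torch.load(os.path.join(quantized_path, "pytorch_model.bin"), weights_only=False)
    quantized.load_state_dict(state_dict)
    quantized.eval()
    return quantized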