bhlewis's picture
Update app.py
ba4060a verified
raw
history blame
2.59 kB
import gradio as gr
import json
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer
from datasets import Dataset
import shutil
import os
# Load tokenized data
def load_data(tokenized_file):
    """Read and parse the uploaded tokenized-data JSON file.

    Args:
        tokenized_file: The value delivered by a ``gr.File`` input. Depending on
            the Gradio version this is either a filesystem path (``str`` or
            ``os.PathLike``) or a temp-file wrapper exposing the on-disk path
            via its ``.name`` attribute; both forms are accepted.

    Returns:
        The decoded JSON document, exactly as produced by ``json.load``.

    Raises:
        ValueError: If ``tokenized_file`` is ``None`` (nothing was uploaded).
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    if tokenized_file is None:
        raise ValueError("No tokenized data file was provided.")
    # Check for a path first: a pathlib.Path also has a ``.name`` attribute,
    # but that is only the basename, not the full path.
    if isinstance(tokenized_file, (str, os.PathLike)):
        path = tokenized_file
    else:
        path = tokenized_file.name
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)
# Fine-tune the model
def fine_tune_model(tokenized_file, progress=gr.Progress()):
    """Fine-tune ``anferico/bert-for-patents`` as a 2-label sequence classifier.

    The uploaded JSON is loaded with ``load_data`` and turned into a
    ``datasets.Dataset``. It is split 80/20 into train/validation and trained
    for 3 epochs with per-epoch evaluation and checkpointing. The model and
    tokenizer are saved to ``./fine_tuned_patentbert``, and that directory is
    zipped for download.

    Args:
        tokenized_file: The ``gr.File`` upload containing the tokenized data.
        progress: Gradio progress tracker, injected by Gradio at call time.

    Returns:
        A ``(status_message, archive_path)`` tuple for the Textbox and File
        outputs.

    Raises:
        gr.Error: If no file was uploaded.
    """
    import inspect

    import torch

    if tokenized_file is None:
        raise gr.Error("Please upload a tokenized data JSON file first.")

    tokenized_data = load_data(tokenized_file)
    # Dataset.from_dict expects a column-oriented mapping (column -> list of
    # values). NOTE(review): presumably this holds input_ids / attention_mask /
    # labels from the upstream tokenizer script — confirm the schema there.
    dataset = Dataset.from_dict(tokenized_data)
    # Hold out 20% of the rows for per-epoch evaluation.
    tokenized_datasets = dataset.train_test_split(test_size=0.2)

    model_name = 'anferico/bert-for-patents'
    model_dir = './fine_tuned_patentbert'
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Newer transformers releases renamed `evaluation_strategy` to
    # `eval_strategy` and later dropped the old keyword. Pass whichever one
    # the installed version accepts.
    accepted_args = inspect.signature(TrainingArguments).parameters
    strategy_key = 'eval_strategy' if 'eval_strategy' in accepted_args else 'evaluation_strategy'

    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=8,  # small batch to limit memory use
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=2,  # effective train batch of 16 per device
        # Mixed precision is only valid on a GPU; TrainingArguments raises on
        # CPU-only hosts if fp16 is requested unconditionally.
        fp16=torch.cuda.is_available(),
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        **{strategy_key: "epoch"},
        # Must match the eval strategy for load_best_model_at_end to work.
        save_strategy="epoch",
        load_best_model_at_end=True,
    )

    # NOTE(review): no tokenizer/data collator is passed, so batches are
    # collated as-is. This presumably requires the uploaded sequences to be
    # pre-padded to equal length — confirm.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets['train'],
        eval_dataset=tokenized_datasets['test'],
    )

    # Coarse progress markers only; per-step Trainer progress is not forwarded.
    progress(0.5, "Fine-tuning the model...")
    trainer.train()
    progress(1.0, "Fine-tuning complete.")

    model.save_pretrained(model_dir)
    tokenizer.save_pretrained(model_dir)
    # Create a zip file of the fine-tuned model; make_archive returns its path.
    archive_path = shutil.make_archive('fine_tuned_patentbert', 'zip', model_dir)
    return (
        "Model fine-tuned and saved successfully. Download the model using the link below.",
        archive_path,
    )
# Create Gradio interface
iface = gr.Interface(
    fn=fine_tune_model,
    inputs=[
        gr.File(label="Upload Tokenized Data JSON"),
    ],
    outputs=[
        gr.Textbox(label="Processing Information"),
        gr.File(label="Download Fine-Tuned Model"),
    ],
    title="Fine-Tune Patent BERT Model",
    description="Upload tokenized JSON file to fine-tune the BERT model.",
    # `live` is left at its default (False). live=True re-runs `fn` on every
    # input change, including clearing the upload. It is not needed for the
    # progress bar.
)
# Launch the interface. Queueing backs gr.Progress updates and lets
# long-running fine-tuning jobs finish without request timeouts.
iface.queue()
iface.launch()