bhlewis's picture
Update app.py
54d96e0 verified
raw
history blame
2.44 kB
import gradio as gr
import json
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer
from datasets import Dataset
import shutil
import os
# Load tokenized data
def load_data(tokenized_file):
    """Read the uploaded tokenized dataset from a JSON file.

    Args:
        tokenized_file: Either a filesystem path (``str`` or ``os.PathLike``)
            or an object exposing the path through a ``.name`` attribute
            (e.g. a temp-file wrapper handed over by the upload widget).

    Returns:
        The parsed JSON content, exactly as ``json.load`` produces it.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Depending on the Gradio version/configuration the upload arrives either
    # as a plain path or as a file-like wrapper; accept both.
    if isinstance(tokenized_file, (str, os.PathLike)):
        path = tokenized_file
    else:
        path = tokenized_file.name

    # JSON is UTF-8 by specification; do not depend on the platform's
    # default locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)
# Fine-tune the model
def fine_tune_model(tokenized_file, progress=gr.Progress()):
    """Fine-tune ``anferico/bert-for-patents`` on an uploaded tokenized dataset.

    The uploaded JSON is loaded via ``load_data``, turned into a ``Dataset``
    with ``Dataset.from_dict``, split 80/20 into train/eval sets, and used to
    train a two-label sequence classifier for three epochs. The resulting
    model and tokenizer are saved to ``./fine_tuned_patentbert`` and packed
    into a zip archive.

    Args:
        tokenized_file: The uploaded file as delivered by the ``gr.File``
            input, or ``None`` when no file is present.
            NOTE(review): presumably a column-oriented mapping with the
            model inputs and a ``labels`` column -- confirm against the
            tokenization step that produces this file.
        progress: Gradio progress tracker used to report coarse status.

    Returns:
        A ``(status_message, archive_path)`` tuple matching the interface's
        textbox and file outputs. ``archive_path`` is ``None`` when no file
        was supplied.
    """
    # Gradio passes None when there is no upload (e.g. the file was cleared
    # while the interface is live); bail out instead of crashing in load_data.
    if tokenized_file is None:
        return "No file uploaded. Upload a tokenized JSON file to start fine-tuning.", None

    base_model = 'anferico/bert-for-patents'
    export_dir = './fine_tuned_patentbert'

    tokenized_data = load_data(tokenized_file)

    # Convert tokenized data to Dataset
    dataset = Dataset.from_dict(tokenized_data)

    # Split the dataset into train and validation sets
    tokenized_datasets = dataset.train_test_split(test_size=0.2)

    model = AutoModelForSequenceClassification.from_pretrained(base_model, num_labels=2)
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=64,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        # Evaluate and checkpoint on the same schedule so the best checkpoint
        # can be restored when training finishes.
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets['train'],
        eval_dataset=tokenized_datasets['test'],
    )

    # Only coarse progress is reported: one tick before training starts and
    # one after it finishes.
    progress(0.5, "Fine-tuning the model...")
    trainer.train()
    progress(1.0, "Fine-tuning complete.")

    # The tokenizer is saved next to the weights so the archive is
    # self-contained.
    model.save_pretrained(export_dir)
    tokenizer.save_pretrained(export_dir)

    # Create a zip file of the fine-tuned model; use the path make_archive
    # reports rather than re-deriving the file name by hand.
    archive_path = shutil.make_archive('fine_tuned_patentbert', 'zip', export_dir)

    return "Model fine-tuned and saved successfully. Download the model using the link below.", archive_path
# Create Gradio interface
# Input: the JSON file holding the pre-tokenized dataset.
upload_input = gr.File(label="Upload Tokenized Data JSON")

# Outputs: a status message and the zipped fine-tuned model.
status_output = gr.Textbox(label="Processing Information")
download_output = gr.File(label="Download Fine-Tuned Model")

iface = gr.Interface(
    title="Fine-Tune Patent BERT Model",
    description="Upload tokenized JSON file to fine-tune the BERT model.",
    fn=fine_tune_model,
    inputs=[upload_input],
    outputs=[status_output, download_output],
    # live=True makes Gradio re-run fn automatically whenever the input
    # changes, so fine-tuning starts as soon as a file is uploaded.
    live=True,
)

# Start the web server for the interface.
iface.launch()