Update app.py
Browse files
app.py
CHANGED
@@ -2,7 +2,8 @@ import gradio as gr
|
|
2 |
import json
|
3 |
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer
|
4 |
from datasets import Dataset
|
5 |
-
import
|
|
|
6 |
|
7 |
# Load tokenized data
|
8 |
def load_data(tokenized_file):
|
@@ -51,7 +52,10 @@ def fine_tune_model(tokenized_file, progress=gr.Progress()):
|
|
51 |
model.save_pretrained('./fine_tuned_patentbert')
|
52 |
tokenizer.save_pretrained('./fine_tuned_patentbert')
|
53 |
|
54 |
-
|
|
|
|
|
|
|
55 |
|
56 |
# Create Gradio interface
|
57 |
iface = gr.Interface(
|
@@ -59,7 +63,10 @@ iface = gr.Interface(
|
|
59 |
inputs=[
|
60 |
gr.File(label="Upload Tokenized Data JSON")
|
61 |
],
|
62 |
-
outputs=
|
|
|
|
|
|
|
63 |
title="Fine-Tune Patent BERT Model",
|
64 |
description="Upload tokenized JSON file to fine-tune the BERT model.",
|
65 |
live=True # Enable live updates for progress
|
|
|
2 |
import json
|
3 |
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer
|
4 |
from datasets import Dataset
|
5 |
+
import shutil
|
6 |
+
import os
|
7 |
|
8 |
# Load tokenized data
|
9 |
def load_data(tokenized_file):
|
|
|
52 |
model.save_pretrained('./fine_tuned_patentbert')
|
53 |
tokenizer.save_pretrained('./fine_tuned_patentbert')
|
54 |
|
55 |
+
# Create a zip file of the fine-tuned model
|
56 |
+
shutil.make_archive('fine_tuned_patentbert', 'zip', './fine_tuned_patentbert')
|
57 |
+
|
58 |
+
return "Model fine-tuned and saved successfully. Download the model using the link below.", "fine_tuned_patentbert.zip"
|
59 |
|
60 |
# Create Gradio interface
|
61 |
iface = gr.Interface(
|
|
|
63 |
inputs=[
|
64 |
gr.File(label="Upload Tokenized Data JSON")
|
65 |
],
|
66 |
+
outputs=[
|
67 |
+
gr.Textbox(label="Processing Information"),
|
68 |
+
gr.File(label="Download Fine-Tuned Model")
|
69 |
+
],
|
70 |
title="Fine-Tune Patent BERT Model",
|
71 |
description="Upload tokenized JSON file to fine-tune the BERT model.",
|
72 |
live=True # Enable live updates for progress
|