Sakalti commited on
Commit
1cd0800
·
verified ·
1 Parent(s): 48fa0bf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
3
+ from datasets import load_dataset
4
+ import os
5
+
6
+ def train_and_deploy(write_token, repo_name, license_text):
7
+ # トークンを環境変数に設定
8
+ os.environ['HF_WRITE_TOKEN'] = write_token
9
+
10
+ # ライセンスファイルを作成
11
+ with open("LICENSE", "w") as f:
12
+ f.write(license_text)
13
+
14
+ # モデルとトークナイザーの読み込み
15
+ model_name = "HuggingfaceH4/zephyr-7b-beta" # トレーニング対象のモデル
16
+ model = AutoModelForCausalLM.from_pretrained(model_name)
17
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
18
+
19
+ # 日本語データセットの読み込み
20
+ dataset = load_dataset("Sakalti/hachiwari")
21
+
22
+ # データセットのトークン化
23
+ def tokenize_function(examples):
24
+ return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)
25
+
26
+ tokenized_datasets = dataset.map(tokenize_function, batched=True)
27
+
28
+ # トレーニング設定
29
+ training_args = TrainingArguments(
30
+ output_dir="./results",
31
+ per_device_train_batch_size=8,
32
+ per_device_eval_batch_size=8,
33
+ evaluation_strategy="epoch",
34
+ save_strategy="epoch",
35
+ logging_dir="./logs",
36
+ logging_steps=10,
37
+ num_train_epochs=3, # トレーニングエポック数
38
+ push_to_hub=True, # Hugging Face Hubにプッシュ
39
+ hub_token=write_token,
40
+ hub_model_id=repo_name # ユーザーが入力したリポジトリ名
41
+ )
42
+
43
+ # Trainerの設定
44
+ trainer = Trainer(
45
+ model=model,
46
+ args=training_args,
47
+ train_dataset=tokenized_datasets["train"],
48
+ eval_dataset=tokenized_datasets["test"],
49
+ )
50
+
51
+ # トレーニング実行
52
+ trainer.train()
53
+
54
+ # モデルをHugging Face Hubにプッシュ
55
+ trainer.push_to_hub()
56
+
57
+ return f"モデルが'{repo_name}'リポジトリにデプロイされました!"
58
+
59
+ # Gradio UI
60
+ with gr.Blocks() as demo:
61
+ gr.Markdown("### Zephyr-7B モデルの日本語特化トレーニングとデプロイ")
62
+ token_input = gr.Textbox(label="Hugging Face Write Token", placeholder="トークンを入力してください...")
63
+ repo_input = gr.Textbox(label="リポジトリ名", placeholder="デプロイするリポジトリ名を入力してください...")
64
+ license_input = gr.Textbox(label="ライセンス", placeholder="ライセンス情報を入力してください...")
65
+ output = gr.Textbox(label="出力")
66
+ train_button = gr.Button("デプロイ")
67
+
68
+ train_button.click(fn=train_and_deploy, inputs=[token_input, repo_input, license_input], outputs=output)
69
+
70
+ demo.launch()