Sakalti committed
Commit b7e4c73 · verified · 1 Parent(s): d865a70

Update app.py

Files changed (1)
  app.py  +8 -11
app.py CHANGED
@@ -1,5 +1,4 @@
  import os
- import torch
  from dotenv import load_dotenv
  from datasets import load_dataset
  from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
@@ -8,15 +7,13 @@ from huggingface_hub import login
  # === Load token ===
  load_dotenv()
  HF_TOKEN = os.getenv("HF_TOKEN")
-
  if not HF_TOKEN:
      raise ValueError("Hugging Face token not found. Check the `.env` file or your environment variables.")
-
  login(HF_TOKEN)

  # === Settings ===
- BASE_MODEL = "Sakalti/Template-4"
- HF_REPO = "Sakalti/Template-16"
+ BASE_MODEL = "Sakalti/template-4"  # corrected model name
+ HF_REPO = "Sakalti/template-16"

  # === Load data ===
  dataset = load_dataset("Verah/JParaCrawl-Filtered-English-Japanese-Parallel-Corpus", split="train")
@@ -25,12 +22,12 @@ dataset = load_dataset("Verah/JParaCrawl-Filtered-English-Japanese-Parallel-Corpus", split="train")
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
  model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)

- # === Data preprocessing ===
+ # === Ultra-simple tokenization function ===
  def preprocess(examples):
-     texts = [f"英語: {ex['en']}\n日本語:" for ex in examples]
-     model_inputs = tokenizer(texts, max_length=256, truncation=True)
-     model_inputs["labels"] = model_inputs["input_ids"]
-     return model_inputs
+     texts = [ex["en"] + " " + ex["ja"] for ex in examples]
+     tokenized = tokenizer(texts, max_length=256, truncation=True)
+     tokenized["labels"] = tokenized["input_ids"].copy()
+     return tokenized

  tokenized_dataset = dataset.map(preprocess, batched=True, remove_columns=dataset.column_names)

@@ -42,7 +39,7 @@ training_args = TrainingArguments(
      per_device_train_batch_size=2,
      num_train_epochs=3,
      save_total_limit=2,
-     save_steps=500,  # save every 500 steps (as requested)
+     save_steps=500,
      push_to_hub=True,
      hub_model_id=HF_REPO,
      hub_token=HF_TOKEN,
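
Note on the new preprocess: with batched=True, datasets.map passes `examples` as a dict of column lists (e.g. {"en": [...], "ja": [...]}), so `for ex in examples` iterates over column names rather than rows. The tokenizer call also does not pad, so the default Trainer collator cannot batch the variable-length examples. Below is a minimal sketch, not part of this commit, of a column-wise variant that leaves padding and label creation to DataCollatorForLanguageModeling with mlm=False (an assumption, provided the tokenizer defines a pad token):

from transformers import DataCollatorForLanguageModeling

def preprocess(examples):
    # zip the parallel "en"/"ja" columns instead of iterating over the batch dict
    texts = [en + " " + ja for en, ja in zip(examples["en"], examples["ja"])]
    # no padding here; the collator pads each batch dynamically
    return tokenizer(texts, max_length=256, truncation=True)

# mlm=False copies input_ids into labels (padding positions become -100);
# assumes tokenizer.pad_token is set
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

With this variant, data_collator would be passed to Trainer alongside tokenized_dataset.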