Spaces:
Runtime error
Runtime error
stakelovelace
commited on
Commit
·
acc7015
1
Parent(s):
2094fe7
commit from tesla
Browse files
app.py
CHANGED
|
@@ -42,8 +42,9 @@ def train_model(model, tokenizer, data, device):
|
|
| 42 |
training_args = TrainingArguments(
|
| 43 |
output_dir='./results',
|
| 44 |
num_train_epochs=3,
|
| 45 |
-
per_device_train_batch_size=
|
| 46 |
-
gradient_accumulation_steps=
|
|
|
|
| 47 |
warmup_steps=500,
|
| 48 |
weight_decay=0.01,
|
| 49 |
logging_dir='./logs',
|
|
@@ -89,7 +90,7 @@ def generate_api_query(model, tokenizer, prompt, desired_output, api_name, base_
|
|
| 89 |
input_ids = input_ids.to(model.device)
|
| 90 |
|
| 91 |
# Generate query using model with temperature for randomness
|
| 92 |
-
output = model.generate(input_ids, max_length=
|
| 93 |
|
| 94 |
# Decode the generated query tokens
|
| 95 |
query = tokenizer.decode(output[0], skip_special_tokens=True)
|
|
|
|
| 42 |
training_args = TrainingArguments(
|
| 43 |
output_dir='./results',
|
| 44 |
num_train_epochs=3,
|
| 45 |
+
per_device_train_batch_size=8,
|
| 46 |
+
gradient_accumulation_steps=4,
|
| 47 |
+
fp16=True, # Enable mixed precision
|
| 48 |
warmup_steps=500,
|
| 49 |
weight_decay=0.01,
|
| 50 |
logging_dir='./logs',
|
|
|
|
| 90 |
input_ids = input_ids.to(model.device)
|
| 91 |
|
| 92 |
# Generate query using model with temperature for randomness
|
| 93 |
+
output = model.generate(input_ids, max_length=128, truncation=True, padding='max_length', temperature=0.1, do_sample=True)
|
| 94 |
|
| 95 |
# Decode the generated query tokens
|
| 96 |
query = tokenizer.decode(output[0], skip_special_tokens=True)
|