Update app.py
Browse files
app.py
CHANGED
|
@@ -95,9 +95,14 @@ def train_lora(image_folder, metadata):
|
|
| 95 |
if batch_idx % 10 == 0: # Log every 10 batches
|
| 96 |
print(f"Batch {batch_idx}, Loss: {loss.item()}")
|
| 97 |
|
| 98 |
-
#
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
print("Training completed.")
|
| 103 |
|
|
|
|
| 95 |
if batch_idx % 10 == 0: # Log every 10 batches
|
| 96 |
print(f"Batch {batch_idx}, Loss: {loss.item()}")
|
| 97 |
|
| 98 |
+
# Define the folder where the model will be saved
|
| 99 |
+
save_folder = "models"
|
| 100 |
+
os.makedirs(save_folder, exist_ok=True) # Create the folder if it doesn't exist
|
| 101 |
+
|
| 102 |
+
# Save the trained model in the specified folder
|
| 103 |
+
model_save_path = os.path.join(save_folder, "lora_model.pth")
|
| 104 |
+
torch.save(model.state_dict(), model_save_path)
|
| 105 |
+
print(f"Model saved at {model_save_path}")
|
| 106 |
|
| 107 |
print("Training completed.")
|
| 108 |
|