Update app.py
Browse files
app.py
CHANGED
@@ -95,9 +95,14 @@ def train_lora(image_folder, metadata):
|
|
95 |
if batch_idx % 10 == 0: # Log every 10 batches
|
96 |
print(f"Batch {batch_idx}, Loss: {loss.item()}")
|
97 |
|
98 |
-
#
|
99 |
-
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
print("Training completed.")
|
103 |
|
|
|
95 |
if batch_idx % 10 == 0: # Log every 10 batches
|
96 |
print(f"Batch {batch_idx}, Loss: {loss.item()}")
|
97 |
|
98 |
+
# Define the folder where the model will be saved
|
99 |
+
save_folder = "models"
|
100 |
+
os.makedirs(save_folder, exist_ok=True) # Create the folder if it doesn't exist
|
101 |
+
|
102 |
+
# Save the trained model in the specified folder
|
103 |
+
model_save_path = os.path.join(save_folder, "lora_model.pth")
|
104 |
+
torch.save(model.state_dict(), model_save_path)
|
105 |
+
print(f"Model saved at {model_save_path}")
|
106 |
|
107 |
print("Training completed.")
|
108 |
|