DonImages commited on
Commit
d8ec44f
·
verified ·
1 Parent(s): 8894709

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -95,9 +95,14 @@ def train_lora(image_folder, metadata):
95
  if batch_idx % 10 == 0: # Log every 10 batches
96
  print(f"Batch {batch_idx}, Loss: {loss.item()}")
97
 
98
- # Save the trained model
99
- torch.save(model.state_dict(), "lora_model.pth")
100
- print("Model saved as lora_model.pth")
 
 
 
 
 
101
 
102
  print("Training completed.")
103
 
 
95
  if batch_idx % 10 == 0: # Log every 10 batches
96
  print(f"Batch {batch_idx}, Loss: {loss.item()}")
97
 
98
+ # Define the folder where the model will be saved
99
+ save_folder = "models"
100
+ os.makedirs(save_folder, exist_ok=True) # Create the folder if it doesn't exist
101
+
102
+ # Save the trained model in the specified folder
103
+ model_save_path = os.path.join(save_folder, "lora_model.pth")
104
+ torch.save(model.state_dict(), model_save_path)
105
+ print(f"Model saved at {model_save_path}")
106
 
107
  print("Training completed.")
108