DonImages commited on
Commit
4b8e1f8
·
verified ·
1 Parent(s): d8ec44f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -6,6 +6,7 @@ from PIL import Image
6
  import json
7
  import os
8
  import gradio as gr
 
9
 
10
  # Paths
11
  image_folder = "Images/"
@@ -95,23 +96,22 @@ def train_lora(image_folder, metadata):
95
  if batch_idx % 10 == 0: # Log every 10 batches
96
  print(f"Batch {batch_idx}, Loss: {loss.item()}")
97
 
98
- # Define the folder where the model will be saved
99
- save_folder = "models"
100
- os.makedirs(save_folder, exist_ok=True) # Create the folder if it doesn't exist
 
101
 
102
- # Save the trained model in the specified folder
103
- model_save_path = os.path.join(save_folder, "lora_model.pth")
104
- torch.save(model.state_dict(), model_save_path)
105
- print(f"Model saved at {model_save_path}")
106
 
107
- print("Training completed.")
108
 
109
  # Gradio App
110
  def start_training_gradio():
111
  print("Loading metadata and preparing dataset...")
112
  metadata = load_metadata(metadata_file)
113
- train_lora(image_folder, metadata)
114
- return "Training completed. Check the model outputs!"
115
 
116
  demo = gr.Interface(
117
  fn=start_training_gradio,
 
6
  import json
7
  import os
8
  import gradio as gr
9
+ import shutil
10
 
11
  # Paths
12
  image_folder = "Images/"
 
96
  if batch_idx % 10 == 0: # Log every 10 batches
97
  print(f"Batch {batch_idx}, Loss: {loss.item()}")
98
 
99
+ # Save the trained model to /mnt/data/ for Hugging Face Space to access
100
+ save_path = '/mnt/data/lora_model.pth'
101
+ torch.save(model.state_dict(), save_path)
102
+ print(f"Model saved at {save_path}")
103
 
104
+ # Move the file to a location where we can access it for download
105
+ # Here, /mnt/data is directly accessible from the Hugging Face Space interface
106
+ print(f"Training completed. The model is saved and ready for download at {save_path}.")
 
107
 
108
+ return f"Training completed. Download the model from: [Download Model](sandbox:/mnt/data/lora_model.pth)"
109
 
110
  # Gradio App
111
  def start_training_gradio():
112
  print("Loading metadata and preparing dataset...")
113
  metadata = load_metadata(metadata_file)
114
+ return train_lora(image_folder, metadata)
 
115
 
116
  demo = gr.Interface(
117
  fn=start_training_gradio,