DonImages commited on
Commit
375075f
·
verified ·
1 Parent(s): 4b8e1f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -13
app.py CHANGED
@@ -1,12 +1,11 @@
1
  import torch
2
  from torch import nn, optim
3
  from torch.utils.data import DataLoader, Dataset
4
- from torchvision import transforms, datasets, models
5
  from PIL import Image
6
  import json
7
  import os
8
  import gradio as gr
9
- import shutil
10
 
11
  # Paths
12
  image_folder = "Images/"
@@ -96,27 +95,26 @@ def train_lora(image_folder, metadata):
96
  if batch_idx % 10 == 0: # Log every 10 batches
97
  print(f"Batch {batch_idx}, Loss: {loss.item()}")
98
 
99
- # Save the trained model to /mnt/data/ for Hugging Face Space to access
100
- save_path = '/mnt/data/lora_model.pth'
101
- torch.save(model.state_dict(), save_path)
102
- print(f"Model saved at {save_path}")
103
 
104
- # Move the file to a location where we can access it for download
105
- # Here, /mnt/data is directly accessible from the Hugging Face Space interface
106
- print(f"Training completed. The model is saved and ready for download at {save_path}.")
107
-
108
- return f"Training completed. Download the model from: [Download Model](sandbox:/mnt/data/lora_model.pth)"
109
 
110
  # Gradio App
111
  def start_training_gradio():
112
  print("Loading metadata and preparing dataset...")
113
  metadata = load_metadata(metadata_file)
114
- return train_lora(image_folder, metadata)
 
115
 
 
116
  demo = gr.Interface(
117
  fn=start_training_gradio,
118
  inputs=None,
119
- outputs="text",
120
  title="Train LoRA Model",
121
  description="Fine-tune a model using LoRA for consistent image generation."
122
  )
 
1
  import torch
2
  from torch import nn, optim
3
  from torch.utils.data import DataLoader, Dataset
4
+ from torch torchvision import transforms, datasets, models
5
  from PIL import Image
6
  import json
7
  import os
8
  import gradio as gr
 
9
 
10
  # Paths
11
  image_folder = "Images/"
 
95
  if batch_idx % 10 == 0: # Log every 10 batches
96
  print(f"Batch {batch_idx}, Loss: {loss.item()}")
97
 
98
+ # Save the trained model
99
+ model_path = "lora_model.pth"
100
+ torch.save(model.state_dict(), model_path)
101
+ print(f"Model saved as {model_path}")
102
 
103
+ print("Training completed.")
104
+ return model_path # Return the path of the saved model
 
 
 
105
 
106
  # Gradio App
107
  def start_training_gradio():
108
  print("Loading metadata and preparing dataset...")
109
  metadata = load_metadata(metadata_file)
110
+ model_path = train_lora(image_folder, metadata)
111
+ return model_path # This will return the model file path for download
112
 
113
+ # Gradio interface
114
  demo = gr.Interface(
115
  fn=start_training_gradio,
116
  inputs=None,
117
+ outputs=gr.File(),
118
  title="Train LoRA Model",
119
  description="Fine-tune a model using LoRA for consistent image generation."
120
  )