VinitT commited on
Commit
d2190eb
·
verified ·
1 Parent(s): 6368cfb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -2
app.py CHANGED
@@ -6,6 +6,8 @@ from diffusers import StableDiffusionPipeline
6
  from transformers import CLIPTokenizer
7
  import os
8
  import zipfile
 
 
9
  import gradio as gr
10
 
11
  # Define the device
@@ -104,9 +106,24 @@ def zip_model(model_path):
104
 
105
  # Gradio interface functions
106
  def start_fine_tuning(uploaded_files, prompts, num_epochs):
107
- images = [Image.open(file).convert("RGB") for file in uploaded_files]
 
 
 
 
 
 
 
 
 
 
 
108
  model_save_path = "fine_tuned_model"
109
  fine_tune_model(images, prompts, model_save_path, num_epochs=int(num_epochs))
 
 
 
 
110
  return "Fine-tuning completed! Model is ready for download."
111
 
112
  def download_model():
@@ -156,4 +173,4 @@ with gr.Blocks() as demo:
156
 
157
  generate_button.click(generate_new_image, [prompt_input], generated_image)
158
 
159
- demo.launch()
 
6
  from transformers import CLIPTokenizer
7
  import os
8
  import zipfile
9
+ import tempfile
10
+ import shutil
11
  import gradio as gr
12
 
13
  # Define the device
 
106
 
107
  # Gradio interface functions
108
  def start_fine_tuning(uploaded_files, prompts, num_epochs):
109
+ # Create a temporary directory for storing files
110
+ temp_dir = tempfile.mkdtemp()
111
+ print("Temporary directory:", temp_dir)
112
+
113
+ images = []
114
+ for file in uploaded_files:
115
+ # Store the uploaded file in the temp directory
116
+ image_path = os.path.join(temp_dir, file.name)
117
+ with open(image_path, 'wb') as f:
118
+ f.write(file.read()) # Save file content
119
+ images.append(Image.open(image_path).convert("RGB"))
120
+
121
  model_save_path = "fine_tuned_model"
122
  fine_tune_model(images, prompts, model_save_path, num_epochs=int(num_epochs))
123
+
124
+ # Clean up the temporary directory after fine-tuning
125
+ shutil.rmtree(temp_dir)
126
+
127
  return "Fine-tuning completed! Model is ready for download."
128
 
129
  def download_model():
 
173
 
174
  generate_button.click(generate_new_image, [prompt_input], generated_image)
175
 
176
+ demo.launch()