fffiloni committed (verified)
Commit 28a5c9e · 1 Parent(s): d9cb349

add custom seed parameter

Files changed (1)
  1. app.py +4 -2
app.py CHANGED
@@ -57,7 +57,7 @@ def start_over(gallery_state, loaded_model_setup):
     loaded_model_setup = None # Reset loaded model setup to prevent re-triggering old state
     return gallery_state, None, None, gr.update(visible=False), loaded_model_setup
 
-def setup_model(prompt, model, num_iterations, learning_rate, hps_w, imgrw_w, pcks_w, clip_w, progress=gr.Progress(track_tqdm=True)):
+def setup_model(prompt, model, seed, num_iterations, learning_rate, hps_w, imgrw_w, pcks_w, clip_w, progress=gr.Progress(track_tqdm=True)):
 
     """Clear CUDA memory before starting the training."""
     torch.cuda.empty_cache() # Free up cached memory
@@ -67,6 +67,7 @@ def setup_model(prompt, model, num_iterations, learning_rate, hps_w, imgrw_w, pcks_w, clip_w, progress=gr.Progress(track_tqdm=True)):
     args.task = "single"
     args.prompt = prompt
     args.model = model
+    args.seed = seed
     args.n_iters = num_iterations
     args.lr = learning_rate
     args.cache_dir = "./HF_model_cache"
@@ -209,6 +210,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
         prompt = gr.Textbox(label="Prompt")
         with gr.Row():
             chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd"], label="Model", value="sd-turbo")
+            seed = gr.Number(label="seed", value=0)
             model_status = gr.Textbox(label="model status", visible=False)
 
         with gr.Row():
@@ -247,7 +249,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
         outputs = [gallery_state, output_image, status, iter_gallery, loaded_model_setup] # Ensure loaded_model_setup is reset
     ).then(
         fn = setup_model,
-        inputs = [prompt, chosen_model, n_iter, hps_w, imgrw_w, pcks_w, clip_w, learning_rate],
+        inputs = [prompt, chosen_model, seed, n_iter, hps_w, imgrw_w, pcks_w, clip_w, learning_rate],
         outputs = [output_image, loaded_model_setup] # Load the new setup into the state
     ).then(
         fn = generate_image,
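
Note: within this diff the new value is only stored on args.seed and threaded through the Gradio event chain; how it is consumed during generation is not shown. Below is a minimal sketch of the usual way such a seed is applied for reproducible sampling, assuming the pipeline runs on PyTorch (the surrounding code already calls torch.cuda.empty_cache()). apply_seed is a hypothetical helper for illustration, not part of this repo.

import random

import numpy as np
import torch


def apply_seed(seed):
    """Hypothetical helper: seed the common RNG sources so a fixed seed reproduces the same image."""
    seed = int(seed)                      # gr.Number typically delivers a float
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # seed every visible GPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # A dedicated generator can also be passed to the sampling call,
    # e.g. pipe(prompt, generator=generator) in a diffusers-style pipeline.
    return torch.Generator(device=device).manual_seed(seed)

Called as generator = apply_seed(args.seed) at the start of generation, repeated runs with the same prompt and seed would then produce the same output.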