theSure commited on
Commit
e904e7b
Β·
verified Β·
1 Parent(s): ed5622a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -16,6 +16,7 @@ from pipeline_flux_control_removal import FluxControlRemovalPipeline
16
 
17
  torch.set_grad_enabled(False)
18
  os.environ['GRADIO_TEMP_DIR'] = './tmp'
 
19
  image_path = mask_path = None
20
  image_examples = [...]
21
  image_path = mask_path =None
@@ -78,7 +79,7 @@ def load_model(base_model_path, lora_path):
78
  base_model_path,
79
  transformer=transformer,
80
  torch_dtype=torch.bfloat16
81
- ).to("cuda")
82
  pipe.transformer.to(torch.bfloat16)
83
  gr.Info(str(f"Model loading: {int((80 / 100) * 100)}%"))
84
  gr.Info(str(f"Inject LoRA: {lora_path}"))
@@ -146,7 +147,7 @@ def predict(
146
  width=H,
147
  height=W,
148
  num_inference_steps=ddim_steps,
149
- generator=torch.Generator("cuda").manual_seed(seed),
150
  guidance_scale=scale,
151
  max_sequence_length=512,
152
  ).images[0]
 
16
 
17
  torch.set_grad_enabled(False)
18
  os.environ['GRADIO_TEMP_DIR'] = './tmp'
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  image_path = mask_path = None
21
  image_examples = [...]
22
  image_path = mask_path =None
 
79
  base_model_path,
80
  transformer=transformer,
81
  torch_dtype=torch.bfloat16
82
+ ).to(device)
83
  pipe.transformer.to(torch.bfloat16)
84
  gr.Info(str(f"Model loading: {int((80 / 100) * 100)}%"))
85
  gr.Info(str(f"Inject LoRA: {lora_path}"))
 
147
  width=H,
148
  height=W,
149
  num_inference_steps=ddim_steps,
150
+ generator=torch.Generator(device).manual_seed(seed),
151
  guidance_scale=scale,
152
  max_sequence_length=512,
153
  ).images[0]