Update app.py
app.py CHANGED
@@ -16,6 +16,7 @@ from pipeline_flux_control_removal import FluxControlRemovalPipeline
 
 torch.set_grad_enabled(False)
 os.environ['GRADIO_TEMP_DIR'] = './tmp'
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 image_path = mask_path = None
 image_examples = [...]
 image_path = mask_path =None
@@ -78,7 +79,7 @@ def load_model(base_model_path, lora_path):
         base_model_path,
         transformer=transformer,
         torch_dtype=torch.bfloat16
-    ).to(
+    ).to(device)
     pipe.transformer.to(torch.bfloat16)
     gr.Info(str(f"Model loading: {int((80 / 100) * 100)}%"))
     gr.Info(str(f"Inject LoRA: {lora_path}"))
@@ -146,7 +147,7 @@ def predict(
         width=H,
         height=W,
         num_inference_steps=ddim_steps,
-        generator=torch.Generator(
+        generator=torch.Generator(device).manual_seed(seed),
         guidance_scale=scale,
         max_sequence_length=512,
     ).images[0]
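Taken together, the three hunks make the Space device-agnostic: the device is resolved once at import time, the pipeline is moved to that device instead of a hard-coded CUDA target, and the sampling generator is created on the same device and seeded. Below is a minimal sketch of the resulting pattern, using the stock diffusers FluxPipeline as a stand-in for the repo's FluxControlRemovalPipeline; the model id, prompt, and sampling parameters are illustrative and not taken from app.py.

import torch
from diffusers import FluxPipeline  # stand-in; the Space uses its own FluxControlRemovalPipeline

# Resolve the device once instead of hard-coding "cuda", as the commit does.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pipeline in bfloat16 and move it to whatever device is available.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative model id
    torch_dtype=torch.bfloat16,
).to(device)

# Seed a generator on the same device so repeated runs produce the same image.
seed = 42
generator = torch.Generator(device).manual_seed(seed)

image = pipe(
    "a photo of the scene with the object removed",  # illustrative prompt
    num_inference_steps=28,
    guidance_scale=3.5,
    max_sequence_length=512,
    generator=generator,
).images[0]

One behavioral note: the generator's device selects which RNG stream PyTorch uses, so a seeded run on CPU will not necessarily reproduce the same image as the same seed on GPU.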