Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -15,7 +15,6 @@ from diffusers.utils import load_image
|
|
15 |
from pipeline_flux_control_removal import FluxControlRemovalPipeline
|
16 |
|
17 |
torch.set_grad_enabled(False)
|
18 |
-
os.environ['GRADIO_TEMP_DIR'] = './tmp'
|
19 |
device = "cuda"
|
20 |
print(device)
|
21 |
image_path = mask_path = None
|
@@ -52,7 +51,7 @@ image_examples = [
|
|
52 |
]
|
53 |
|
54 |
]
|
55 |
-
@spaces.GPU
|
56 |
def load_model(base_model_path, lora_path):
|
57 |
global pipe
|
58 |
transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder='transformer', torch_dtype=torch.bfloat16)
|
@@ -87,7 +86,7 @@ def load_model(base_model_path, lora_path):
|
|
87 |
pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors")
|
88 |
gr.Info(str(f"Model loading: {int((100 / 100) * 100)}%"))
|
89 |
|
90 |
-
@spaces.GPU
|
91 |
def set_seed(seed):
|
92 |
torch.manual_seed(seed)
|
93 |
torch.cuda.manual_seed(seed)
|
@@ -95,7 +94,7 @@ def set_seed(seed):
|
|
95 |
np.random.seed(seed)
|
96 |
random.seed(seed)
|
97 |
|
98 |
-
@spaces.GPU
|
99 |
def predict(
|
100 |
input_image,
|
101 |
prompt,
|
|
|
15 |
from pipeline_flux_control_removal import FluxControlRemovalPipeline
|
16 |
|
17 |
torch.set_grad_enabled(False)
|
|
|
18 |
device = "cuda"
|
19 |
print(device)
|
20 |
image_path = mask_path = None
|
|
|
51 |
]
|
52 |
|
53 |
]
|
54 |
+
@spaces.GPU(duration=120)
|
55 |
def load_model(base_model_path, lora_path):
|
56 |
global pipe
|
57 |
transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder='transformer', torch_dtype=torch.bfloat16)
|
|
|
86 |
pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors")
|
87 |
gr.Info(str(f"Model loading: {int((100 / 100) * 100)}%"))
|
88 |
|
89 |
+
@spaces.GPU(duration=120)
|
90 |
def set_seed(seed):
|
91 |
torch.manual_seed(seed)
|
92 |
torch.cuda.manual_seed(seed)
|
|
|
94 |
np.random.seed(seed)
|
95 |
random.seed(seed)
|
96 |
|
97 |
+
@spaces.GPU(duration=120)
|
98 |
def predict(
|
99 |
input_image,
|
100 |
prompt,
|