Spaces:
Running
on
Zero
Running
on
Zero
Update model.py
Browse files
model.py
CHANGED
|
@@ -57,7 +57,7 @@ class Model:
|
|
| 57 |
beta_schedule="scaled_linear",
|
| 58 |
num_train_timesteps=1000,
|
| 59 |
steps_offset=1
|
| 60 |
-
)
|
| 61 |
# pipe.enable_freeu(b1=1.1, b2=1.1, s1=0.5, s2=0.7)
|
| 62 |
pipe.enable_xformers_memory_efficient_attention()
|
| 63 |
pipe.force_zeros_for_empty_prompt = False
|
|
@@ -70,34 +70,34 @@ class Model:
|
|
| 70 |
print(f'Loaded {model_id}...')
|
| 71 |
return pipe
|
| 72 |
|
| 73 |
-
def set_base_model(self, base_model_id: str) -> str:
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
|
| 85 |
def load_controlnet_weight(self, task_name: str) -> None:
|
| 86 |
print('Entered load_controlnet_weight....')
|
| 87 |
-
if task_name == self.task_name:
|
| 88 |
-
|
| 89 |
-
if self.pipe is not None and hasattr(self.pipe, "controlnet"):
|
| 90 |
-
|
| 91 |
-
torch.cuda.empty_cache()
|
| 92 |
-
gc.collect()
|
| 93 |
-
model_id = CONTROLNET_MODEL_IDS[task_name]
|
| 94 |
-
controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
|
| 95 |
-
print(f'Loaded {model_id}...')
|
| 96 |
-
controlnet.to(self.device)
|
| 97 |
-
torch.cuda.empty_cache()
|
| 98 |
-
gc.collect()
|
| 99 |
-
self.pipe.controlnet = controlnet
|
| 100 |
-
self.task_name = task_name
|
| 101 |
|
| 102 |
def get_prompt(self, prompt: str, additional_prompt: str) -> str:
|
| 103 |
if not prompt:
|
|
|
|
| 57 |
beta_schedule="scaled_linear",
|
| 58 |
num_train_timesteps=1000,
|
| 59 |
steps_offset=1
|
| 60 |
+
).to('cuda')
|
| 61 |
# pipe.enable_freeu(b1=1.1, b2=1.1, s1=0.5, s2=0.7)
|
| 62 |
pipe.enable_xformers_memory_efficient_attention()
|
| 63 |
pipe.force_zeros_for_empty_prompt = False
|
|
|
|
| 70 |
print(f'Loaded {model_id}...')
|
| 71 |
return pipe
|
| 72 |
|
| 73 |
+
# def set_base_model(self, base_model_id: str) -> str:
|
| 74 |
+
# if not base_model_id or base_model_id == self.base_model_id:
|
| 75 |
+
# return self.base_model_id
|
| 76 |
+
# del self.pipe
|
| 77 |
+
# torch.cuda.empty_cache()
|
| 78 |
+
# gc.collect()
|
| 79 |
+
# try:
|
| 80 |
+
# self.pipe = self.load_pipe(base_model_id, self.task_name)
|
| 81 |
+
# except Exception:
|
| 82 |
+
# self.pipe = self.load_pipe(self.base_model_id, self.task_name)
|
| 83 |
+
# return self.base_model_id
|
| 84 |
|
| 85 |
def load_controlnet_weight(self, task_name: str) -> None:
|
| 86 |
print('Entered load_controlnet_weight....')
|
| 87 |
+
# if task_name == self.task_name:
|
| 88 |
+
# return
|
| 89 |
+
# if self.pipe is not None and hasattr(self.pipe, "controlnet"):
|
| 90 |
+
# del self.pipe.controlnet
|
| 91 |
+
# torch.cuda.empty_cache()
|
| 92 |
+
# gc.collect()
|
| 93 |
+
# model_id = CONTROLNET_MODEL_IDS[task_name]
|
| 94 |
+
# controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
|
| 95 |
+
# print(f'Loaded {model_id}...')
|
| 96 |
+
# controlnet.to(self.device)
|
| 97 |
+
# torch.cuda.empty_cache()
|
| 98 |
+
# gc.collect()
|
| 99 |
+
# self.pipe.controlnet = controlnet
|
| 100 |
+
# self.task_name = task_name
|
| 101 |
|
| 102 |
def get_prompt(self, prompt: str, additional_prompt: str) -> str:
|
| 103 |
if not prompt:
|