Spaces:
Running
Running
| import os | |
| from collections import OrderedDict | |
| from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig | |
| from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm | |
| from toolkit.sd_device_states_presets import get_train_sd_device_state_preset | |
| from toolkit.stable_diffusion_model import StableDiffusion | |
| import gc | |
| import torch | |
| from jobs.process import BaseExtensionProcess | |
| from toolkit.train_tools import get_torch_dtype | |
def flush():
    """Free memory held by unreachable Python objects and by the CUDA caching allocator.

    ``gc.collect()`` runs first so that tensors kept alive only by reference cycles
    are actually released; ``torch.cuda.empty_cache()`` then returns the now-unused
    cached blocks. Calling ``empty_cache`` first would leave those tensors' blocks
    allocated.
    """
    gc.collect()
    torch.cuda.empty_cache()
class PureLoraGenerator(BaseExtensionProcess):
    """Extension process that converts a model's UNet to LoRM and renders sample images.

    Config keys read in ``__init__``:
        output_folder (required): folder the generated image paths are built under.
        device: torch device string, default ``'cuda'``.
        model (required): keyword arguments for ``ModelConfig``.
        sample (required): keyword arguments for ``SampleConfig`` (prompts, seed,
            size, sampler, ...).
        dtype: dtype name passed to ``StableDiffusion``, default ``'float16'``.
        lorm: optional keyword arguments for ``LoRMConfig``; ``None`` when absent.
    """

    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.output_folder = self.get_conf('output_folder', required=True)
        self.device = self.get_conf('device', 'cuda')
        self.device_torch = torch.device(self.device)
        self.model_config = ModelConfig(**self.get_conf('model', required=True))
        self.generate_config = SampleConfig(**self.get_conf('sample', required=True))
        self.dtype = self.get_conf('dtype', 'float16')
        self.torch_dtype = get_torch_dtype(self.dtype)
        lorm_config = self.get_conf('lorm', None)
        # NOTE(review): run() passes this straight to convert_diffusers_unet_to_lorm;
        # confirm that call tolerates None when no 'lorm' section is configured.
        self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None
        self.device_state_preset = get_train_sd_device_state_preset(
            device=self.device_torch,
        )
        self.progress_bar = None
        self.sd = StableDiffusion(
            device=self.device,
            model_config=self.model_config,
            dtype=self.dtype,
        )

    def run(self):
        """Load the model, convert its UNet to LoRM and generate one image per prompt.

        ``self.sd`` is deleted and memory is released afterwards, also when loading,
        conversion or generation raises, so an instance can only be run once.
        """
        super().run()
        print("Loading model...")
        try:
            with torch.no_grad():
                self._load_model_to_device()

                print("Converting to LoRM UNet")
                # replace the unet with LoRMUnet
                convert_diffusers_unet_to_lorm(
                    self.sd.unet,
                    config=self.lorm_config,
                )

                gen_img_config_list = self._build_gen_img_configs()
                # send to be generated
                self.sd.generate_images(
                    gen_img_config_list,
                    sampler=self.generate_config.sampler,
                )
                print("Done generating images")
        finally:
            # cleanup: drop the model and return GPU memory even on failure
            del self.sd
            gc.collect()
            torch.cuda.empty_cache()

    def _load_model_to_device(self):
        """Load ``self.sd`` and put its UNet and text encoder(s) in eval mode on the device."""
        self.sd.load_model()
        self.sd.unet.eval()
        self.sd.unet.to(self.device_torch)
        # text_encoder may be a single module or a list of modules
        if isinstance(self.sd.text_encoder, list):
            for te in self.sd.text_encoder:
                te.eval()
                te.to(self.device_torch)
        else:
            self.sd.text_encoder.eval()
            # NOTE(review): the list branch moves each encoder with te.to(); this moves
            # the whole StableDiffusion wrapper instead -- confirm that is intended.
            self.sd.to(self.device_torch)

    def _build_gen_img_configs(self):
        """Return one ``GenerateImageConfig`` per prompt in the sample config.

        Every image uses ``sample_config.seed`` unless ``walk_seed`` is set, in which
        case the prompt at index ``i`` uses ``seed + i``.
        """
        sample_config = self.generate_config
        start_seed = sample_config.seed
        # "[time]" / "[count]" appear to be placeholders filled in downstream -- TODO confirm
        filename = f"[time]_[count].{sample_config.ext}"
        output_path = os.path.join(self.output_folder, filename)
        return [
            GenerateImageConfig(
                prompt=prompt,  # it will autoparse the prompt
                width=sample_config.width,
                height=sample_config.height,
                negative_prompt=sample_config.neg,
                seed=start_seed + i if sample_config.walk_seed else start_seed,
                guidance_scale=sample_config.guidance_scale,
                guidance_rescale=sample_config.guidance_rescale,
                num_inference_steps=sample_config.sample_steps,
                network_multiplier=sample_config.network_multiplier,
                output_path=output_path,
                output_ext=sample_config.ext,
                adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
            )
            for i, prompt in enumerate(sample_config.prompts)
        ]