File size: 1,568 Bytes
cd267d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
from __future__ import annotations
from functools import cached_property
from diffusers import (
StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
)
from asdff.base import AdPipelineBase
class AdPipeline(AdPipelineBase, StableDiffusionPipeline):
    """Text-to-image Stable Diffusion pipeline combined with ``AdPipelineBase``.

    Supplies the concrete pipeline types this mixin works with: an inpainting
    pipeline assembled from this pipeline's own components
    (``inpaint_pipeline``) and the plain text-to-image class
    (``txt2img_class``). How ``AdPipelineBase`` consumes them is defined in
    ``asdff.base`` (not visible here).
    """

    @cached_property
    def inpaint_pipeline(self) -> StableDiffusionInpaintPipeline:
        """Return a ``StableDiffusionInpaintPipeline`` reusing this pipeline's modules.

        The VAE, text encoder, tokenizer, UNet, scheduler, safety checker and
        feature extractor objects held by ``self`` are handed to the new
        pipeline. The ``requires_safety_checker`` flag is copied from
        ``self.config``.
        """
        # cached_property: the pipeline is built on first access and stored on
        # the instance. Reassigning a component afterwards (for example
        # ``self.scheduler``) is NOT reflected in the cached inpaint pipeline.
        shared = {
            "vae": self.vae,
            "text_encoder": self.text_encoder,
            "tokenizer": self.tokenizer,
            "unet": self.unet,
            "scheduler": self.scheduler,
            "safety_checker": self.safety_checker,
            "feature_extractor": self.feature_extractor,
        }
        return StableDiffusionInpaintPipeline(
            **shared,
            requires_safety_checker=self.config.requires_safety_checker,
        )

    @property
    def txt2img_class(self) -> type[StableDiffusionPipeline]:
        """Pipeline class used for plain text-to-image generation."""
        return StableDiffusionPipeline
class AdCnPipeline(AdPipelineBase, StableDiffusionControlNetPipeline):
    """ControlNet Stable Diffusion pipeline combined with ``AdPipelineBase``.

    Supplies the concrete pipeline types this mixin works with: a ControlNet
    inpainting pipeline assembled from this pipeline's own components
    (``inpaint_pipeline``) and the ControlNet text-to-image class
    (``txt2img_class``). How ``AdPipelineBase`` consumes them is defined in
    ``asdff.base`` (not visible here).
    """

    @cached_property
    def inpaint_pipeline(self) -> StableDiffusionControlNetInpaintPipeline:
        """Return a ``StableDiffusionControlNetInpaintPipeline`` reusing this pipeline's modules.

        The VAE, text encoder, tokenizer, UNet, ControlNet, scheduler, safety
        checker and feature extractor objects held by ``self`` are handed to
        the new pipeline. The ``requires_safety_checker`` flag is copied from
        ``self.config``.
        """
        # cached_property: the pipeline is built on first access and stored on
        # the instance. Reassigning a component afterwards (for example
        # ``self.scheduler`` or ``self.controlnet``) is NOT reflected in the
        # cached inpaint pipeline.
        shared = {
            "vae": self.vae,
            "text_encoder": self.text_encoder,
            "tokenizer": self.tokenizer,
            "unet": self.unet,
            "controlnet": self.controlnet,
            "scheduler": self.scheduler,
            "safety_checker": self.safety_checker,
            "feature_extractor": self.feature_extractor,
        }
        return StableDiffusionControlNetInpaintPipeline(
            **shared,
            requires_safety_checker=self.config.requires_safety_checker,
        )

    @property
    def txt2img_class(self) -> type[StableDiffusionControlNetPipeline]:
        """Pipeline class used for ControlNet text-to-image generation."""
        return StableDiffusionControlNetPipeline
|