Added type annotations for pipeline init args
Browse filesRemoving warnings as part of [Comprehensive type checking for `from_pretrained `kwargs](https://github.com/huggingface/diffusers/pull/10758)
- pipeline.py +2 -2
pipeline.py
CHANGED
@@ -18,7 +18,7 @@ from typing import Optional, Tuple, Union
|
|
18 |
|
19 |
import torch
|
20 |
|
21 |
-
from diffusers import DiffusionPipeline, ImagePipelineOutput
|
22 |
|
23 |
|
24 |
class CustomPipeline(DiffusionPipeline):
|
@@ -33,7 +33,7 @@ class CustomPipeline(DiffusionPipeline):
|
|
33 |
[`DDPMScheduler`], or [`DDIMScheduler`].
|
34 |
"""
|
35 |
|
36 |
-
def __init__(self, unet, scheduler):
|
37 |
super().__init__()
|
38 |
self.register_modules(unet=unet, scheduler=scheduler)
|
39 |
|
|
|
18 |
|
19 |
import torch
|
20 |
|
21 |
+
from diffusers import DiffusionPipeline, ImagePipelineOutput, SchedulerMixin, UNet2DModel
|
22 |
|
23 |
|
24 |
class CustomPipeline(DiffusionPipeline):
|
|
|
33 |
[`DDPMScheduler`], or [`DDIMScheduler`].
|
34 |
"""
|
35 |
|
36 |
+
def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin):
|
37 |
super().__init__()
|
38 |
self.register_modules(unet=unet, scheduler=scheduler)
|
39 |
|