guiyrt commited on
Commit
bad185b
·
verified ·
1 Parent(s): 278ec4f

Added type annotations for pipeline init args

Browse files

Removing warnings as part of [Comprehensive type checking for `from_pretrained `kwargs](https://github.com/huggingface/diffusers/pull/10758)

Files changed (1) hide show
  1. pipeline.py +2 -2
pipeline.py CHANGED
@@ -18,7 +18,7 @@ from typing import Optional, Tuple, Union
18
 
19
  import torch
20
 
21
- from diffusers import DiffusionPipeline, ImagePipelineOutput
22
 
23
 
24
  class CustomPipeline(DiffusionPipeline):
@@ -33,7 +33,7 @@ class CustomPipeline(DiffusionPipeline):
33
  [`DDPMScheduler`], or [`DDIMScheduler`].
34
  """
35
 
36
- def __init__(self, unet, scheduler):
37
  super().__init__()
38
  self.register_modules(unet=unet, scheduler=scheduler)
39
 
 
18
 
19
  import torch
20
 
21
+ from diffusers import DiffusionPipeline, ImagePipelineOutput, SchedulerMixin, UNet2DModel
22
 
23
 
24
  class CustomPipeline(DiffusionPipeline):
 
33
  [`DDPMScheduler`], or [`DDIMScheduler`].
34
  """
35
 
36
+ def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin):
37
  super().__init__()
38
  self.register_modules(unet=unet, scheduler=scheduler)
39