pranavajay commited on
Commit
bf08430
·
verified ·
1 Parent(s): 8a49fba

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +3 -4
api.py CHANGED
@@ -9,9 +9,8 @@ import datetime
9
  from fastapi import FastAPI, HTTPException, Request, Response
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from pydantic import BaseModel, constr, conint
12
- from diffusers import (FluxPipeline, FluxControlNetPipeline,
13
- FluxControlNetModel, FluxImg2ImgPipeline,
14
- FluxInpaintPipeline, CogVideoXImageToVideoPipeline)
15
  from diffusers.utils import load_image
16
  from PIL import Image
17
  from collections import defaultdict
@@ -151,7 +150,7 @@ async def set_controlnet_adapter(adapter: str, is_inpainting: bool = False):
151
 
152
  controlnet_model_path = adapter_controlnet_mapping[adapter]
153
  controlnet = FluxControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.bfloat16)
154
- pipeline_cls = FluxControlNetPipeline if not is_inpainting else FluxInpaintPipeline
155
  flux_controlnet_pipe = pipeline_cls.from_pretrained(
156
  "pranavajay/flow", controlnet=controlnet, torch_dtype=torch.bfloat16
157
  )
 
9
  from fastapi import FastAPI, HTTPException, Request, Response
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from pydantic import BaseModel, constr, conint
12
+ from diffusers import FluxPipeline, FluxControlNetModel, FluxImg2ImgPipeline, FluxInpaintPipeline, CogVideoXImageToVideoPipeline
13
+ from diffusers.pipelines import FluxControlNetPipeline, FluxControlNetInpaintPipeline
 
14
  from diffusers.utils import load_image
15
  from PIL import Image
16
  from collections import defaultdict
 
150
 
151
  controlnet_model_path = adapter_controlnet_mapping[adapter]
152
  controlnet = FluxControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.bfloat16)
153
+ pipeline_cls = FluxControlNetPipeline if not is_inpainting else FluxControlNetInpaintPipeline
154
  flux_controlnet_pipe = pipeline_cls.from_pretrained(
155
  "pranavajay/flow", controlnet=controlnet, torch_dtype=torch.bfloat16
156
  )