Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,16 +12,14 @@ from image_datasets.dataset import image_resize
|
|
| 12 |
from src.flux.util import load_ae, load_clip, load_flow_model2, load_t5, tensor_to_pil_image
|
| 13 |
from src.flux.xflux_pipeline import XFluxSampler
|
| 14 |
args = OmegaConf.load("inference_configs/inference.yaml")
|
| 15 |
-
is_schnell = args.model_name == "flux-schnell"
|
| 16 |
-
'/home/user/app/assets/0_camera_zoom/20486354.png'
|
| 17 |
-
'/home/user/app/assets/0_camera_zoom/20486354.png'
|
| 18 |
# sampler = None
|
| 19 |
-
device = torch.device("cuda")
|
| 20 |
-
dtype = torch.bfloat16
|
| 21 |
-
dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype)
|
| 22 |
-
vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype)
|
| 23 |
-
t5 = load_t5(device="cpu", max_length=256 if is_schnell else 512).to(device, dtype=dtype)
|
| 24 |
-
clip = load_clip("cpu").to(device, dtype=dtype)
|
| 25 |
#test push
|
| 26 |
@spaces.GPU
|
| 27 |
def generate(image: Image.Image, edit_prompt: str):
|
|
@@ -29,26 +27,21 @@ def generate(image: Image.Image, edit_prompt: str):
|
|
| 29 |
|
| 30 |
|
| 31 |
|
| 32 |
-
vae.requires_grad_(False)
|
| 33 |
-
t5.requires_grad_(False)
|
| 34 |
-
clip.requires_grad_(False)
|
| 35 |
|
| 36 |
-
model_path = hf_hub_download(
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
)
|
| 41 |
-
state_dict = load_file(model_path)
|
| 42 |
-
dit.load_state_dict(state_dict)
|
| 43 |
-
dit.eval()
|
| 44 |
-
dit.to(device, dtype=dtype)
|
| 45 |
|
| 46 |
sampler = XFluxSampler(
|
| 47 |
-
clip=clip,
|
| 48 |
-
t5=t5,
|
| 49 |
-
ae=vae,
|
| 50 |
-
model=dit,
|
| 51 |
-
device=device,
|
| 52 |
ip_loaded=False,
|
| 53 |
spatial_condition=False,
|
| 54 |
clip_image_processor=None,
|
|
|
|
| 12 |
from src.flux.util import load_ae, load_clip, load_flow_model2, load_t5, tensor_to_pil_image
|
| 13 |
from src.flux.xflux_pipeline import XFluxSampler
|
| 14 |
args = OmegaConf.load("inference_configs/inference.yaml")
|
| 15 |
+
# is_schnell = args.model_name == "flux-schnell"
|
|
|
|
|
|
|
| 16 |
# sampler = None
|
| 17 |
+
# device = torch.device("cuda")
|
| 18 |
+
# dtype = torch.bfloat16
|
| 19 |
+
# dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype)
|
| 20 |
+
# vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype)
|
| 21 |
+
# t5 = load_t5(device="cpu", max_length=256 if is_schnell else 512).to(device, dtype=dtype)
|
| 22 |
+
# clip = load_clip("cpu").to(device, dtype=dtype)
|
| 23 |
#test push
|
| 24 |
@spaces.GPU
|
| 25 |
def generate(image: Image.Image, edit_prompt: str):
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
|
| 30 |
+
# vae.requires_grad_(False)
|
| 31 |
+
# t5.requires_grad_(False)
|
| 32 |
+
# clip.requires_grad_(False)
|
| 33 |
|
| 34 |
+
# model_path = hf_hub_download(
|
| 35 |
+
# repo_id="Boese0601/ByteMorpher",
|
| 36 |
+
# filename="dit.safetensors",
|
| 37 |
+
# use_auth_token=os.getenv("HF_TOKEN")
|
| 38 |
+
# )
|
| 39 |
+
# state_dict = load_file(model_path)
|
| 40 |
+
# dit.load_state_dict(state_dict)
|
| 41 |
+
# dit.eval()
|
| 42 |
+
# dit.to(device, dtype=dtype)
|
| 43 |
|
| 44 |
sampler = XFluxSampler(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
ip_loaded=False,
|
| 46 |
spatial_condition=False,
|
| 47 |
clip_image_processor=None,
|