Spaces:
Runtime error
Runtime error
add inference script features
Browse files- climategan_wrapper.py +82 -1
climategan_wrapper.py
CHANGED
|
@@ -5,7 +5,7 @@ import os
|
|
| 5 |
import re
|
| 6 |
from pathlib import Path
|
| 7 |
from uuid import uuid4
|
| 8 |
-
|
| 9 |
import numpy as np
|
| 10 |
import torch
|
| 11 |
from diffusers import StableDiffusionInpaintPipeline
|
|
@@ -541,3 +541,84 @@ class ClimateGAN:
|
|
| 541 |
im = Image.fromarray(uint8(im))
|
| 542 |
imstem = f"{im_path.stem}---{overwrite_prefix}{painter_prefix}_{event}"
|
| 543 |
im.save(im_path.parent / (imstem + im_path.suffix))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import re
|
| 6 |
from pathlib import Path
|
| 7 |
from uuid import uuid4
|
| 8 |
+
from minydra import resolved_args
|
| 9 |
import numpy as np
|
| 10 |
import torch
|
| 11 |
from diffusers import StableDiffusionInpaintPipeline
|
|
|
|
| 541 |
im = Image.fromarray(uint8(im))
|
| 542 |
imstem = f"{im_path.stem}---{overwrite_prefix}{painter_prefix}_{event}"
|
| 543 |
im.save(im_path.parent / (imstem + im_path.suffix))
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
if __name__ == "__main__":
|
| 547 |
+
print("Run `$ python climategan_wrapper.py help` for usage instructions\n")
|
| 548 |
+
|
| 549 |
+
# parse arguments
|
| 550 |
+
args = resolved_args(
|
| 551 |
+
defaults={
|
| 552 |
+
"input_folder": None,
|
| 553 |
+
"output_folder": None,
|
| 554 |
+
"painter": "both",
|
| 555 |
+
"help": False,
|
| 556 |
+
}
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
# print help
|
| 560 |
+
if args.help:
|
| 561 |
+
print(
|
| 562 |
+
"Usage: python inference.py input_folder=/path/to/folder\n"
|
| 563 |
+
+ "By default inferences will be stored in the input folder.\n"
|
| 564 |
+
+ "Add `output_folder=/path/to/folder` for a different output folder.\n"
|
| 565 |
+
+ "By default, both ClimateGAN and Stable Diffusion will be used."
|
| 566 |
+
+ "Change this by adding `painter=climategan` or"
|
| 567 |
+
+ " `painter=stable_diffusion`.\n"
|
| 568 |
+
+ "Make sure you have agreed to the terms of use for the models."
|
| 569 |
+
+ "In particular, visit SD's model card to agree to the terms of use:"
|
| 570 |
+
+ " https://huggingface.co/runwayml/stable-diffusion-inpainting"
|
| 571 |
+
)
|
| 572 |
+
# print args
|
| 573 |
+
args.pretty_print()
|
| 574 |
+
|
| 575 |
+
# load models
|
| 576 |
+
cg = ClimateGAN("models/climategan")
|
| 577 |
+
|
| 578 |
+
# check painter type
|
| 579 |
+
assert args.painter in {"climategan", "stable_diffusion", "both",}, (
|
| 580 |
+
f"Unknown painter {args.painter}. "
|
| 581 |
+
+ "Allowed values are 'climategan', 'stable_diffusion' and 'both'."
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
# load SD pipeline if need be
|
| 585 |
+
if args.painter != "climate_gan":
|
| 586 |
+
cg._setup_stable_diffusion()
|
| 587 |
+
|
| 588 |
+
# resolve input folder path
|
| 589 |
+
in_path = Path(args.input_folder).expanduser().resolve()
|
| 590 |
+
assert in_path.exists(), f"Folder {str(in_path)} does not exist"
|
| 591 |
+
|
| 592 |
+
# output is input if not specified
|
| 593 |
+
if args.output_folder is None:
|
| 594 |
+
out_path = in_path
|
| 595 |
+
|
| 596 |
+
# find images in input folder
|
| 597 |
+
im_paths = [
|
| 598 |
+
p
|
| 599 |
+
for p in in_path.iterdir()
|
| 600 |
+
if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
|
| 601 |
+
]
|
| 602 |
+
assert im_paths, f"No images found in {str(im_paths)}"
|
| 603 |
+
|
| 604 |
+
print(f"\nFound {len(im_paths)} images in {str(in_path)}\n")
|
| 605 |
+
|
| 606 |
+
# infer and write
|
| 607 |
+
for i, im_path in enumerate(im_paths):
|
| 608 |
+
print(">>> Processing", f"{i}/{len(im_paths)}", im_path.name)
|
| 609 |
+
outs = cg.infer_single(
|
| 610 |
+
np.array(Image.open(im_path)),
|
| 611 |
+
args.painter,
|
| 612 |
+
as_pil_image=True,
|
| 613 |
+
concats=[
|
| 614 |
+
"input",
|
| 615 |
+
"masked_input",
|
| 616 |
+
"climategan_flood",
|
| 617 |
+
"stable_copy_flood",
|
| 618 |
+
],
|
| 619 |
+
)
|
| 620 |
+
for k, v in outs.items():
|
| 621 |
+
name = f"{im_path.stem}---{k}{im_path.suffix}"
|
| 622 |
+
im = Image.fromarray(uint8(v))
|
| 623 |
+
im.save(out_path / name)
|
| 624 |
+
print(">>> Done", f"{i}/{len(im_paths)}", im_path.name, end="\n\n")
|