"""Run the imgen3flip model on one 8x8 RGB image and save the 64x64 result.

Usage:
    python <this-script> <input-filename> <output-filename>

The input is read with torchvision, scaled to [0, 1], and passed to the
model as a single-item batch in (batch, height, width, channel) layout.
The model's ``im64`` output is converted to a PIL image and written to
the output path.
"""
from imgen3flip import weights_path, Model, ImageBatch, OPTS
import torch
import torchvision as TV
import torchvision.transforms.functional as VF
import sys


# --- Validate the environment and the command line -------------------------
# Explicit checks (not ``assert``) so they still run under ``python -O``.
if not weights_path.exists():
    sys.exit("Model weights do not exist")

if len(sys.argv) != 3:
    sys.exit(f"Usage: {sys.argv[0]} <input-filename> <output-filename>")

input_filename = sys.argv[1]
output_filename = sys.argv[2]

if input_filename == output_filename:
    sys.exit("Use different file names")

# --- Load the model --------------------------------------------------------
print("Loading the model")
model = Model()
# map_location="cpu" lets a checkpoint saved on another device be read here;
# load_state_dict() then copies the values into the model's own parameters.
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
# Inference mode for layers such as dropout/batch-norm (no-op if the model
# has none).
model.eval()

# --- Load and prepare the input image --------------------------------------
print(f"Loading 8x8 input image from {input_filename}")

# read_image() returns a channels-first tensor; [:3] drops any channel past
# the first three (e.g. alpha).
image = TV.io.read_image(input_filename)[:3]
if image.shape[0] != 3:
    sys.exit("RGB image expected")
if tuple(image.shape[1:]) != (8, 8):
    # Checked explicitly: a bare view(1, 8, 8, 3) would fail with a cryptic
    # shape error, or silently reinterpret another 192-element image.
    sys.exit(f"8x8 image expected, got {image.shape[2]}x{image.shape[1]}")

# Scale to [0, 1] (assumes 8-bit pixel values, as read_image yields for
# ordinary images).
image = image / 255.0
# (C, H, W) -> (1, H, W, C): channels-last plus a leading batch dimension.
image = image.permute(1, 2, 0).unsqueeze(0)

# --- Run the model ---------------------------------------------------------
# ImageBatch also takes a target image and a loss; inference has neither, so
# placeholders are supplied. OPTS is forwarded as tensor keyword options
# (presumably device/dtype -- confirm in imgen3flip).
dummy_target = torch.zeros(1, 64, 64, 3, **OPTS)
dummy_loss = torch.tensor(-1, **OPTS)
inference_batch = ImageBatch(
    im8=image.to(**OPTS),
    im64=dummy_target,
    loss=dummy_loss,
)
# No gradients are needed: the output is detached immediately below.
with torch.no_grad():
    result = model(inference_batch)

# --- Convert the output and save it ----------------------------------------
new_image = result.im64.detach().float().cpu()
# Take the only batch item and go back to channels-first for torchvision.
new_image = new_image[0].permute(2, 0, 1)
assert new_image.shape == (3, 64, 64)
# to_pil_image() scales floats by 255 and casts to 8-bit without clamping, so
# out-of-range values would wrap around; clamp them first.
new_image = new_image.clamp(0.0, 1.0)
img = VF.to_pil_image(new_image)

print(f"Writing {img.height}x{img.width} image to {output_filename}")
img.save(output_filename)