File size: 2,049 Bytes
5b2cc7a d8f7287 5b2cc7a d8f7287 5b2cc7a d8f7287 5b2cc7a d8f7287 5b2cc7a d8f7287 5b2cc7a d8f7287 5b2cc7a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import argparse
import os
import yaml
from PIL import Image
import numpy as np
import torch
from cdim.noise import get_noise
from cdim.operators import get_operator
from cdim.image_utils import save_to_image
def load_image(path) -> torch.Tensor:
    """Load an image file as a normalized ``(1, 3, 256, 256)`` tensor.

    The image is converted to 8-bit RGB, resized to 256x256 with bicubic
    resampling, and rescaled from ``[0, 255]`` to ``[-1, 1]``.

    Args:
        path: Anything accepted by ``PIL.Image.open`` (typically a file path).

    Returns:
        A ``torch.float`` tensor laid out as (batch, channel, height, width).
    """
    # The context manager releases the underlying file once pixels are read.
    with Image.open(path) as source:
        # Force three 8-bit channels: grayscale / palette inputs would
        # otherwise produce a 2-D array (breaking the permute below) and
        # RGBA inputs would produce four channels.
        resized = source.convert("RGB").resize((256, 256), Image.BICUBIC)
        pixels = np.array(resized)

    # (H, W, C) -> (1, H, W, C) -> (1, C, H, W)
    batch = torch.from_numpy(pixels).unsqueeze(0).permute(0, 3, 1, 2)
    return (batch / 127.5 - 1.0).to(torch.float)
def load_yaml(file_path: str) -> dict:
    """Parse the YAML file at ``file_path`` and return the loaded document.

    Args:
        file_path: Path to the YAML file to read.

    Returns:
        The object produced by ``yaml.load`` for the file's contents.
    """
    # Decode explicitly as UTF-8 (the YAML default) instead of relying on
    # the platform's locale encoding.
    with open(file_path, encoding="utf-8") as stream:
        # NOTE(review): FullLoader accepts more than plain scalars, lists and
        # dicts. If config files can come from untrusted sources, consider
        # switching to yaml.safe_load.
        return yaml.load(stream, Loader=yaml.FullLoader)
def main(args):
    """Create a noisy measurement of the input image and save it as a PNG.

    Reads ``args.cuda``, ``args.output_dir``, ``args.input_image``,
    ``args.noise_config`` and ``args.operator_config``. The result is written
    to ``noisy_measurement.png`` inside ``args.output_dir``.
    """
    # Use the GPU only when it is both requested and actually available.
    if args.cuda and torch.cuda.is_available():
        device_name = "cuda"
    else:
        device_name = "cpu"
    print(f"Using device {device_name}")
    target_device = torch.device(device_name)

    os.makedirs(args.output_dir, exist_ok=True)
    image = load_image(args.input_image).to(target_device)

    # Load the noise function
    noise_kwargs = load_yaml(args.noise_config)
    add_noise = get_noise(**noise_kwargs)
    print(add_noise)

    # Load the measurement function A
    operator_kwargs = load_yaml(args.operator_config)
    operator_kwargs["device"] = target_device
    measure = get_operator(**operator_kwargs)
    print(measure)

    # Apply the operator first, then corrupt its output with noise.
    clean_measurement = measure(image)
    noisy_measurement = add_noise(clean_measurement)

    destination = os.path.join(args.output_dir, "noisy_measurement.png")
    save_to_image(noisy_measurement, destination)
if __name__ == "__main__":
    cli = argparse.ArgumentParser()

    # Required positional arguments, in command-line order.
    # NOTE(review): T, K and model are parsed but not read by main() in this
    # file -- confirm whether that is intended.
    positional_specs = (
        ("input_image", str),
        ("T", int),
        ("K", int),
        ("model", str),
        ("operator_config", str),
        ("noise_config", str),
    )
    for arg_name, arg_type in positional_specs:
        cli.add_argument(arg_name, type=arg_type)

    # Optional flags; BooleanOptionalAction also provides --no-cuda.
    cli.add_argument("--output-dir", type=str, default=".")
    cli.add_argument("--cuda", action=argparse.BooleanOptionalAction, default=True)

    main(cli.parse_args())