therealcyberlord commited on
Commit
33921d3
·
1 Parent(s): 3e99287

visual changes to app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -7
app.py CHANGED
@@ -12,12 +12,9 @@ if torch.cuda.is_available():
12
  device = torch.device("cuda")
13
 
14
  latent_size = 100
15
- display_width = 450
16
  checkpoint_path = "Checkpoints/150epochs.chkpt"
17
 
18
  st.title("Generating Abstract Art")
19
- st.text("start generating (left side bar)")
20
- st.text("Made by Xingyu B.")
21
 
22
  st.sidebar.subheader("Configurations")
23
  seed = st.sidebar.slider('Seed', -10000, 10000, 0)
@@ -59,7 +56,6 @@ if generate:
59
  # use srgan for super resolution
60
  if use_srgan == "Yes":
61
  # restore to the checkpoint
62
- st.write("Using DCGAN then ESRGAN upscale...")
63
  esrgan_generator = SRGAN.GeneratorRRDB(channels=3, filters=64, num_res_blocks=23).to(device)
64
  esrgan_checkpoint = load_esrgan()
65
  esrgan_generator.load_state_dict(esrgan_checkpoint)
@@ -71,15 +67,14 @@ if generate:
71
 
72
  for i in range(len(color_match)):
73
  # denormalize and permute to correct color channel
74
- st.image(denormalize_images(color_match[i]).permute(1, 2, 0).numpy(), width=display_width)
75
 
76
 
77
  # default setting -> vanilla dcgan generation
78
  if use_srgan == "No":
79
  fakes = fakes.cpu()
80
- st.write("Using DCGAN Model...")
81
  for i in range(len(fakes)):
82
- st.image(denormalize_images(fakes[i]).permute(1, 2, 0).numpy(), width=display_width)
83
 
84
 
85
 
 
12
  device = torch.device("cuda")
13
 
14
  latent_size = 100
 
15
  checkpoint_path = "Checkpoints/150epochs.chkpt"
16
 
17
  st.title("Generating Abstract Art")
 
 
18
 
19
  st.sidebar.subheader("Configurations")
20
  seed = st.sidebar.slider('Seed', -10000, 10000, 0)
 
56
  # use srgan for super resolution
57
  if use_srgan == "Yes":
58
  # restore to the checkpoint
 
59
  esrgan_generator = SRGAN.GeneratorRRDB(channels=3, filters=64, num_res_blocks=23).to(device)
60
  esrgan_checkpoint = load_esrgan()
61
  esrgan_generator.load_state_dict(esrgan_checkpoint)
 
67
 
68
  for i in range(len(color_match)):
69
  # denormalize and permute to correct color channel
70
+ st.image(denormalize_images(color_match[i]).permute(1, 2, 0).numpy())
71
 
72
 
73
  # default setting -> vanilla dcgan generation
74
  if use_srgan == "No":
75
  fakes = fakes.cpu()
 
76
  for i in range(len(fakes)):
77
+ st.image(denormalize_images(fakes[i]).permute(1, 2, 0).numpy())
78
 
79
 
80