StableDiffusion_test / model_t.py
Nguyendo1999's picture
Update model_t.py
5e9c22e
raw
history blame
466 Bytes
import time
import keras_cv
from tensorflow import keras
import matplotlib.pyplot as plt
# Build the Stable Diffusion pipeline configured for 512x512 output images.
model = keras_cv.models.StableDiffusion(
    img_height=512,
    img_width=512,
)

# Generate a batch of three images from a single text prompt.
images = model.text_to_image(
    "photograph of an astronaut riding a horse",
    batch_size=3,
)
def plot_images(images):
    """Draw every image in ``images`` side by side in one matplotlib figure.

    A 20x20-inch figure is created and divided into a single row with one
    subplot per image; the axes of each subplot are hidden. The figure is
    left as the current pyplot figure: this function neither calls
    ``plt.show()`` nor saves anything, and it returns ``None``.

    Args:
        images: A sized, iterable collection of images that ``plt.imshow``
            accepts (presumably the batch returned by ``text_to_image`` --
            confirm against the caller).
    """
    plt.figure(figsize=(20, 20))
    count = len(images)
    for position, image in enumerate(images, start=1):
        # Subplot indices are 1-based: one row, ``count`` columns.
        plt.subplot(1, count, position)
        plt.imshow(image)
        plt.axis("off")
# Lay out the generated batch in a single-row figure.
# NOTE(review): nothing after this call invokes plt.show() or plt.savefig();
# outside a notebook the figure is presumably never displayed -- confirm the
# intended execution environment.
plot_images(images)