Spaces:
Runtime error
Runtime error
File size: 1,438 Bytes
5e69d3c 60c6128 afb2266 d13afd7 60c6128 afb2266 d13afd7 afb2266 ddc7e76 60c6128 afb2266 60c6128 5e69d3c ddc7e76 60c6128 afb2266 60c6128 5e69d3c ddc7e76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import gradio as gr
from transformers import DalleBartProcessor, FlaxDalleBartForConditionalGeneration
from PIL import Image
import numpy as np
import jax
import jax.numpy as jnp
# Hub repository that provides both the text processor and the model weights.
# Using one constant keeps the two loads pointed at the same checkpoint.
# NOTE(review): I could not find DalleBartProcessor /
# FlaxDalleBartForConditionalGeneration in the transformers API; they appear
# to come from the separate dalle-mini package
# (`from dalle_mini import DalleBart, DalleBartProcessor`). Verify the import.
DALLE_REPO = "dalle-mini/dalle-mega"

# Loaded once at import time; generate_image() reuses these module globals
# for every request.
processor = DalleBartProcessor.from_pretrained(DALLE_REPO)
model = FlaxDalleBartForConditionalGeneration.from_pretrained(DALLE_REPO)
def _to_uint8_image(array):
    """Return ``array`` as a uint8 array suitable for ``Image.fromarray``.

    Values are clipped to [0, 255] before casting, because a bare
    ``astype(np.uint8)`` wraps integer values (and is undefined for floats)
    outside that range, which shows up as speckled pixels. Float inputs are
    rounded to the nearest integer instead of being truncated.
    """
    arr = np.asarray(array)
    # NOTE(review): like the original cast, this assumes pixel values are on
    # a 0-255 scale. If the decoder actually returns floats in [0, 1],
    # multiply by 255 before calling this helper.
    if np.issubdtype(arr.dtype, np.floating):
        arr = np.rint(arr)
    return np.clip(arr, 0, 255).astype(np.uint8)


def generate_image(prompt, seed=None):
    """Generate one image for ``prompt`` and return it as a PIL image.

    Args:
        prompt: Text prompt entered in the Gradio textbox.
        seed: Optional integer seed for the JAX PRNG used during sampling.
            When ``None`` (the default, and what the Gradio UI passes), a
            fresh random seed is drawn so repeated runs on the same prompt
            produce different samples.

    Returns:
        A ``PIL.Image.Image`` built from the first generated sample.

    Raises:
        gr.Error: If ``prompt`` is empty or whitespace-only; Gradio shows the
            message in the UI.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a text prompt.")

    if seed is None:
        # Without an explicit prng_key, Flax generate() falls back to a fixed
        # PRNGKey(0), so "sampling" would return the same image every time.
        seed = int(np.random.default_rng().integers(0, 2**31 - 1))
    key = jax.random.PRNGKey(seed)

    # Tokenize the prompt, padded/truncated to exactly 64 tokens.
    inputs = processor(
        [prompt],
        return_tensors="jax",
        padding="max_length",
        truncation=True,
        max_length=64,
    )

    # Flax generate() does not implement beam *sampling* (do_sample=True
    # together with num_beams > 1 raises NotImplementedError), so sample with
    # a single beam.
    outputs = model.generate(
        **inputs,
        prng_key=key,
        do_sample=True,
        num_beams=1,
        num_return_sequences=1,
    )

    # NOTE(review): on Flax BART-style models, decode() is the text-decoder
    # forward pass (it expects encoder outputs), not an image decoder.
    # DALL-E mini/mega emits VQGAN image tokens that normally need a separate
    # VQGAN decoder to become pixels. Confirm this call against the installed
    # model.
    images = jax.device_get(model.decode(outputs.sequences))

    # Only one sequence was requested, so take the first entry.
    return Image.fromarray(_to_uint8_image(images[0]))
# Gradio UI: a two-line prompt textbox in, a single image out.
prompt_box = gr.Textbox(lines=2, placeholder="Enter your text prompt")

iface = gr.Interface(
    fn=generate_image,
    inputs=prompt_box,
    outputs="image",
    title="DALL-E Mini/Mega Image Generator",
    description="Generate images from text prompts using the DALL-E Mega model.",
)

# Start the Gradio web server for this app.
iface.launch()
|