# Hugging Face Spaces status: Runtime error
import gradio as gr | |
from transformers import DalleBartProcessor, FlaxDalleBartForConditionalGeneration | |
from PIL import Image | |
import numpy as np | |
import jax | |
import jax.numpy as jnp | |
# Hub checkpoint shared by the processor and the model; both are loaded once,
# at module import time, and reused by every request.
_CHECKPOINT = "dalle-mini/dalle-mega"

# NOTE(review): these classes are imported from `transformers` above, but
# DalleBart classes are normally shipped by the separate `dalle-mini` package
# — confirm this import actually resolves.
processor = DalleBartProcessor.from_pretrained(_CHECKPOINT)
model = FlaxDalleBartForConditionalGeneration.from_pretrained(_CHECKPOINT)
def _to_uint8(image):
    """Return ``image`` as a uint8 array suitable for ``Image.fromarray``.

    Float arrays whose values are all at or below 1.0 are treated as
    normalized ``[0, 1]`` pixels and rescaled to ``[0, 255]``; other float
    arrays are assumed to already be on the ``[0, 255]`` scale. Floats are
    rounded (not truncated), and every value is clipped to ``[0, 255]`` before
    the cast so out-of-range pixels saturate instead of wrapping around.
    """
    image = np.asarray(image)
    if np.issubdtype(image.dtype, np.floating):
        # A plain astype(np.uint8) on [0, 1] floats yields an almost black image.
        if image.size and image.max() <= 1.0:
            image = image * 255.0
        image = np.rint(image)
    return np.clip(image, 0, 255).astype(np.uint8)


# Function to generate an image from a text prompt
def generate_image(prompt, seed=None):
    """Generate a single image for ``prompt`` with the loaded DALL-E Mega model.

    Args:
        prompt: Text description of the desired image, as typed into the
            Gradio textbox.
        seed: Optional integer seed for the sampling PRNG. ``None`` (the
            default, and what the Gradio interface passes) draws a fresh
            random seed so repeated calls can produce different images; pass
            an int for reproducible output.

    Returns:
        A ``PIL.Image.Image`` built from the first decoded image.

    Raises:
        gr.Error: If ``prompt`` is empty or only whitespace, so the user sees
            a readable message in the UI instead of a failed generation.
    """
    if prompt is None or not prompt.strip():
        raise gr.Error("Please enter a text prompt.")

    # Tokenize the prompt, requesting padding/truncation to 64 tokens.
    inputs = processor(
        [prompt],
        return_tensors="jax",
        padding="max_length",
        truncation=True,
        max_length=64,
    )

    # Flax `generate` falls back to jax.random.PRNGKey(0) when no key is
    # given, so without an explicit key every call would sample identically.
    if seed is None:
        seed = int(np.random.default_rng().integers(0, 2**31 - 1))
    prng_key = jax.random.PRNGKey(seed)

    # Flax generation supports sampling only with num_beams=1; do_sample=True
    # together with num_beams > 1 ("beam sampling") raises NotImplementedError.
    outputs = model.generate(
        **inputs,
        do_sample=True,
        num_beams=1,
        num_return_sequences=1,
        prng_key=prng_key,
    )

    # NOTE(review): DALL-E Mini/Mega models emit image-token codes, and turning
    # them into pixels normally needs a separate VQGAN decoder — confirm that
    # `model.decode` really returns pixel arrays here.
    images = np.asarray(jax.device_get(model.decode(outputs.sequences)))
    return Image.fromarray(_to_uint8(images[0]))
# Gradio UI: a two-line prompt textbox goes in, the generated image comes out.
_prompt_input = gr.Textbox(lines=2, placeholder="Enter your text prompt")

iface = gr.Interface(
    fn=generate_image,
    inputs=_prompt_input,
    outputs="image",
    title="DALL-E Mini/Mega Image Generator",
    description="Generate images from text prompts using the DALL-E Mega model.",
)
# Start the Gradio server only when this file is executed as a script
# (e.g. `python <this file>`), not when it is merely imported by other code
# or tooling.
if __name__ == "__main__":
    iface.launch()