# Hugging Face Spaces status at capture time: Runtime error
import random

import gradio as gr
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from transformers import DalleBartProcessor, FlaxDalleBartForConditionalGeneration
# Hub repository that holds both the processor files and the model weights.
MODEL_ID = "dalle-mini/dalle-mega"

# Both objects are loaded once, at import time, and shared by every request.
# NOTE(review): `transformers` does not appear to provide `DalleBartProcessor`
# or `FlaxDalleBartForConditionalGeneration`; these classes seem to live in the
# separate `dalle-mini` package (`from dalle_mini import DalleBart,
# DalleBartProcessor`). An ImportError there is a likely cause of the Space's
# "Runtime error" status — confirm and fix the import.
processor = DalleBartProcessor.from_pretrained(MODEL_ID)
model = FlaxDalleBartForConditionalGeneration.from_pretrained(MODEL_ID)
# Function to generate an image from a text prompt | |
def generate_image(prompt): | |
inputs = processor([prompt], return_tensors="jax", padding="max_length", truncation=True, max_length=128) | |
# Generate images | |
images = model.generate(**replicate(inputs.data), do_sample=True, num_beams=1, num_return_sequences=1) | |
# Post-process image for display (convert to PIL image format) | |
image = images[0] # assuming single image output | |
return image # return the generated image | |
# Build the Gradio UI: one multi-line text box in, one image out.
iface = gr.Interface(
    fn=generate_image,
    # `gr.inputs.*` was deprecated in Gradio 3 and removed in Gradio 4;
    # input components are now top-level classes such as `gr.Textbox`.
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt"),
    outputs="image",
    title="DALL-E Mini Image Generator",
    description="Generate images from text prompts using DALL-E Mini model.",
)
# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module elsewhere does not block on `launch()`.
if __name__ == "__main__":
    iface.launch()