"""Gradio demo that generates an image from a text prompt with DALL-E Mini (Mega).

NOTE(review): ``DalleBartProcessor`` and ``FlaxDalleBartForConditionalGeneration``
are imported from ``transformers`` below. I could not confirm that these classes
exist in ``transformers``; the reference DALL-E Mini implementation ships its
model/processor in the separate ``dalle_mini`` package. Confirm the import
source before deploying.
"""

import random

import gradio as gr
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from transformers import DalleBartProcessor, FlaxDalleBartForConditionalGeneration

MODEL_NAME = "dalle-mini/dalle-mega"
# Prompts are padded/truncated to this many tokens before generation.
MAX_PROMPT_LENGTH = 128

# Load the model and processor once at import time so every request reuses them.
processor = DalleBartProcessor.from_pretrained(MODEL_NAME)
model = FlaxDalleBartForConditionalGeneration.from_pretrained(MODEL_NAME)


def generate_image(prompt, seed=None):
    """Generate an image for ``prompt``.

    Args:
        prompt: Text description of the desired image.
        seed: Optional integer seed for sampling. When ``None`` (the default),
            a fresh random seed is drawn, so repeated calls with the same
            prompt produce different samples.

    Returns:
        The first element of ``model.generate``'s output.
        NOTE(review): for a seq2seq ``generate`` call this is presumably the
        generated image-token ids, not pixels. DALL-E Mini normally decodes
        those tokens with a VQGAN before display. TODO: add that decode step
        so the Gradio "image" output receives an actual image array.

    Raises:
        gr.Error: If ``prompt`` is empty or whitespace only.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a non-empty prompt.")

    inputs = processor(
        [prompt],
        return_tensors="jax",
        padding="max_length",
        truncation=True,
        max_length=MAX_PROMPT_LENGTH,
    )

    if seed is None:
        seed = random.randint(0, 2**31 - 1)
    # Sampling needs an explicit PRNG key. transformers' Flax generate() falls
    # back to PRNGKey(0) when none is given, which would make every call with
    # the same prompt return the same sample.
    prng_key = jax.random.PRNGKey(seed)

    # Pass the inputs un-replicated. flax's ``replicate`` adds a leading device
    # axis intended for ``jax.pmap``; a plain generate() call would see inputs
    # with the wrong rank.
    outputs = model.generate(
        **inputs.data,
        do_sample=True,
        num_beams=1,
        num_return_sequences=1,
        prng_key=prng_key,
    )
    return outputs[0]


# Gradio UI: one multi-line text box in, one image out.
iface = gr.Interface(
    fn=generate_image,
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt"),
    outputs="image",
    title="DALL-E Mini Image Generator",
    description="Generate images from text prompts using DALL-E Mini model.",
)

if __name__ == "__main__":
    # Start the (blocking) web server only when run as a script, not on import.
    iface.launch()