File size: 1,274 Bytes
5e69d3c
afb2266
 
 
 
d13afd7
afb2266
 
 
d13afd7
afb2266
ddc7e76
afb2266
 
 
 
 
 
 
5e69d3c
ddc7e76
 
afb2266
 
 
 
 
5e69d3c
 
ddc7e76
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import gradio as gr
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from transformers import DalleBartProcessor, FlaxDalleBartForConditionalGeneration

# Hugging Face Hub repository that both the processor and the model weights are loaded from.
_MODEL_REPO = "dalle-mini/dalle-mega"

# Loaded once at import time; generate_image() reads these module-level objects on
# every request instead of reloading them.
# NOTE(review): these DalleBart* classes are imported from `transformers` in this file,
# but they may actually be provided by the separate `dalle-mini` package -- confirm.
processor = DalleBartProcessor.from_pretrained(_MODEL_REPO)
model = FlaxDalleBartForConditionalGeneration.from_pretrained(_MODEL_REPO)

# Function to generate an image from a text prompt
def generate_image(prompt, seed=None):
    """Sample one DALL-E generation for a single text prompt.

    Args:
        prompt: Text prompt to condition generation on.
        seed: Optional integer seed for the sampling PRNG. When ``None`` (the
            default, and what the Gradio interface passes), a fresh seed is
            drawn so repeated calls produce different samples.

    Returns:
        The generated token sequence for the prompt (the first row of
        ``outputs.sequences``).

    NOTE(review): ``model.generate`` returns token ids, not pixels. A VQGAN
    decoding step (not present in this file) is presumably still required
    before Gradio can display a real image -- TODO confirm and add.
    """
    import random  # local import: only used to draw a per-call seed

    if seed is None:
        # Flax `generate` uses a fixed default PRNG key when none is passed,
        # which would make every "sampled" result identical across calls.
        seed = random.randrange(2**31)
    prng_key = jax.random.PRNGKey(seed)

    inputs = processor([prompt], return_tensors="jax", padding="max_length", truncation=True, max_length=128)

    # generate() is called directly (not through jax.pmap), so the inputs are
    # passed unchanged; flax's replicate() would prepend a device axis here.
    outputs = model.generate(
        **inputs.data,
        prng_key=prng_key,
        do_sample=True,
        num_beams=1,
        num_return_sequences=1,
    )

    # One prompt in the batch -> return its single generated sequence.
    return outputs.sequences[0]

# Create Gradio interface
iface = gr.Interface(
    fn=generate_image,  # called with the textbox contents as its only argument
    # gr.inputs.* was deprecated in Gradio 3.x and removed in 4.0;
    # gr.Textbox is the supported component in both versions.
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt"),
    outputs="image",  # rendered with Gradio's image output component
    title="DALL-E Mini Image Generator",
    description="Generate images from text prompts using DALL-E Mini model.",
)

# Launch the app. This call is at module level, so importing this file also
# starts the server.
iface.launch()