My-AI-Projects committed on
Commit
60c6128
·
verified ·
1 Parent(s): 032de9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -13
app.py CHANGED
@@ -1,30 +1,38 @@
1
  import gradio as gr
 
 
 
2
  import jax
3
  import jax.numpy as jnp
4
- from flax.jax_utils import replicate
5
- from transformers import DalleBartProcessor, FlaxDalleBartForConditionalGeneration
6
 
7
- # Load the model and processor
8
  processor = DalleBartProcessor.from_pretrained("dalle-mini/dalle-mega")
9
  model = FlaxDalleBartForConditionalGeneration.from_pretrained("dalle-mini/dalle-mega")
10
 
11
  # Function to generate an image from a text prompt
12
  def generate_image(prompt):
13
- inputs = processor([prompt], return_tensors="jax", padding="max_length", truncation=True, max_length=128)
14
- # Generate images
15
- images = model.generate(**replicate(inputs.data), do_sample=True, num_beams=1, num_return_sequences=1)
 
 
 
 
 
 
 
 
 
16
 
17
- # Post-process image for display (convert to PIL image format)
18
- image = images[0] # assuming single image output
19
- return image # return the generated image
20
 
21
  # Create Gradio interface
22
  iface = gr.Interface(
23
- fn=generate_image, # Function to generate image
24
- inputs=gr.inputs.Textbox(lines=2, placeholder="Enter your prompt"), # Textbox input
25
  outputs="image", # Output as an image
26
- title="DALL-E Mini Image Generator",
27
- description="Generate images from text prompts using DALL-E Mini model."
28
  )
29
 
30
  # Launch the app
 
1
  import gradio as gr
2
+ from transformers import DalleBartProcessor, FlaxDalleBartForConditionalGeneration
3
+ from PIL import Image
4
+ import numpy as np
5
  import jax
6
  import jax.numpy as jnp
 
 
7
 
8
+ # Load the DALL-E Mega model and processor
9
  processor = DalleBartProcessor.from_pretrained("dalle-mini/dalle-mega")
10
  model = FlaxDalleBartForConditionalGeneration.from_pretrained("dalle-mini/dalle-mega")
11
 
12
  # Function to generate an image from a text prompt
13
  def generate_image(prompt):
14
+ # Process the prompt
15
+ inputs = processor([prompt], return_tensors="jax", padding="max_length", truncation=True, max_length=64)
16
+
17
+ # Generate the images
18
+ outputs = model.generate(**inputs, do_sample=True, num_beams=4, num_return_sequences=1)
19
+
20
+ # Decode the images and convert them to displayable format
21
+ images = model.decode(outputs.sequences)
22
+ images = jax.device_get(images)
23
+
24
+ # Convert to a PIL image
25
+ pil_img = Image.fromarray(np.asarray(images[0]).astype(np.uint8))
26
 
27
+ return pil_img
 
 
28
 
29
  # Create Gradio interface
30
  iface = gr.Interface(
31
+ fn=generate_image, # Function to generate the image
32
+ inputs=gr.Textbox(lines=2, placeholder="Enter your text prompt"), # Input textbox for the prompt
33
  outputs="image", # Output as an image
34
+ title="DALL-E Mini/Mega Image Generator",
35
+ description="Generate images from text prompts using the DALL-E Mega model."
36
  )
37
 
38
  # Launch the app