My-AI-Projects commited on
Commit
afb2266
Β·
verified Β·
1 Parent(s): b7c099e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -13
app.py CHANGED
@@ -1,23 +1,30 @@
1
  import gradio as gr
2
- from transformers import eBart
3
- import torch
 
 
4
 
5
- # Load the eBart model
6
- model = eBart.from_pretrained("dalle-mini/dalle-mega")
 
7
 
8
- # Define a function to generate an image from text
9
  def generate_image(prompt):
10
- inputs = model.prepare_inputs_for_generation(prompt)
11
- outputs = model.generate(inputs)
12
- return outputs # You may need to convert it to a displayable image format depending on the model output
 
 
 
 
13
 
14
  # Create Gradio interface
15
  iface = gr.Interface(
16
- fn=generate_image, # Function that takes a prompt and returns an image
17
- inputs=gr.inputs.Textbox(lines=2, placeholder="Enter your prompt"), # Textbox input for the prompt
18
- outputs="image", # Output is an image
19
- title="eBart DALL-E Mega Image Generator",
20
- description="Generate images from text prompts using the DALL-E Mega model."
21
  )
22
 
23
  # Launch the app
 
1
  import gradio as gr
2
+ import jax
3
+ import jax.numpy as jnp
4
+ from flax.jax_utils import replicate
5
+ from transformers import DalleBartProcessor, FlaxDalleBartForConditionalGeneration
6
 
7
+ # Load the model and processor
8
+ processor = DalleBartProcessor.from_pretrained("dalle-mini/dalle-mega")
9
+ model = FlaxDalleBartForConditionalGeneration.from_pretrained("dalle-mini/dalle-mega")
10
 
11
+ # Function to generate an image from a text prompt
12
  def generate_image(prompt):
13
+ inputs = processor([prompt], return_tensors="jax", padding="max_length", truncation=True, max_length=128)
14
+ # Generate images
15
+ images = model.generate(**replicate(inputs.data), do_sample=True, num_beams=1, num_return_sequences=1)
16
+
17
+ # Post-process image for display (convert to PIL image format)
18
+ image = images[0] # assuming single image output
19
+ return image # return the generated image
20
 
21
  # Create Gradio interface
22
  iface = gr.Interface(
23
+ fn=generate_image, # Function to generate image
24
+ inputs=gr.inputs.Textbox(lines=2, placeholder="Enter your prompt"), # Textbox input
25
+ outputs="image", # Output as an image
26
+ title="DALL-E Mini Image Generator",
27
+ description="Generate images from text prompts using DALL-E Mini model."
28
  )
29
 
30
  # Launch the app