My-AI-Projects commited on
Commit
5e69d3c
Β·
verified Β·
1 Parent(s): 68cc003

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -39
app.py CHANGED
@@ -1,46 +1,30 @@
1
- import gradio as gr
2
  import torch
3
- from transformers import DalleBartTokenizer, DalleBartForConditionalGeneration
4
- from PIL import Image
5
- import io
6
 
7
- # Load model and tokenizer
8
- model_id = "dalle-mini/dalle-mini" # Example model id; adjust if needed
9
- model = DalleBartForConditionalGeneration.from_pretrained(model_id)
10
- tokenizer = DalleBartTokenizer.from_pretrained(model_id)
11
 
12
- # Function to generate image
13
- def generate_image(prompt, num_inference_steps=50):
14
  inputs = tokenizer(prompt, return_tensors="pt")
15
 
16
- # Generate images
17
  with torch.no_grad():
18
- outputs = model.generate(**inputs, num_beams=num_inference_steps)
19
-
20
- # Convert tensor to PIL image
21
- image = Image.fromarray(outputs[0].cpu().numpy().astype('uint8'))
22
 
23
- return image
24
-
25
- # Define the Gradio interface
26
- with gr.Blocks() as demo:
27
- gr.Markdown("# Text to Image Generation")
28
-
29
- with gr.Row():
30
- prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt here...")
31
- num_inference_steps = gr.Slider(minimum=1, maximum=50, step=1, value=28, label="Number of Inference Steps")
32
-
33
- with gr.Row():
34
- generate_button = gr.Button("Generate Image")
35
-
36
- result = gr.Image(label="Generated Image")
37
-
38
- # Connect the function to the button
39
- generate_button.click(
40
- fn=generate_image,
41
- inputs=[prompt, num_inference_steps],
42
- outputs=result
43
- )
44
-
45
- # Launch the Gradio app
46
- demo.launch()
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
  import torch
3
+ import gradio as gr
 
 
4
 
5
+ # Load the model and tokenizer
6
+ tokenizer = AutoTokenizer.from_pretrained("dalle-mini/dalle-mega")
7
+ model = AutoModelForCausalLM.from_pretrained("dalle-mini/dalle-mega")
 
8
 
9
+ # Define the function for Gradio interface
10
+ def generate_image(prompt):
11
  inputs = tokenizer(prompt, return_tensors="pt")
12
 
13
+ # Generate image (or output) using the model
14
  with torch.no_grad():
15
+ outputs = model.generate(**inputs)
 
 
 
16
 
17
+ # Convert output to a format suitable for Gradio
18
+ # This part may need to be adapted based on actual output format
19
+ return outputs
20
+
21
+ # Set up Gradio interface
22
+ iface = gr.Interface(
23
+ fn=generate_image,
24
+ inputs=gr.Textbox(label="Enter prompt"),
25
+ outputs=gr.Image(type="pil", label="Generated Image"),
26
+ live=True
27
+ )
28
+
29
+ # Launch the app
30
+ iface.launch()