thefreeham commited on
Commit
c4c34a2
·
1 Parent(s): 73203ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -71
app.py CHANGED
@@ -1,77 +1,7 @@
1
- import argparse
2
- import base64
3
- import os
4
- from pathlib import Path
5
- from io import BytesIO
6
- import time
7
-
8
- from flask import Flask, request, jsonify
9
- from flask_cors import CORS, cross_origin
10
- from consts import IMAGES_OUTPUT_DIR
11
- from utils import parse_arg_boolean, parse_arg_dalle_version
12
- from consts import ModelSize
13
-
14
-
15
  import gradio as gr
16
 
17
  def greet(name):
18
  return "Hello " + name + "!!"
19
 
20
  iface = gr.Interface(fn=greet, inputs="text", outputs="text")
21
- iface.launch()
22
-
23
-
24
- app = Flask(__name__)
25
- CORS(app)
26
- print("--> Starting DALL-E Server. This might take up to two minutes.")
27
-
28
- from dalle_model import DalleModel
29
- dalle_model = None
30
-
31
- parser = argparse.ArgumentParser(description = "A DALL-E app to turn your textual prompts into visionary delights")
32
- parser.add_argument("--port", type=int, default=8000, help = "backend port")
33
- parser.add_argument("--model_version", type = parse_arg_dalle_version, default = ModelSize.MINI, help = "Mini, Mega, or Mega_full")
34
- parser.add_argument("--save_to_disk", type = parse_arg_boolean, default = False, help = "Should save generated images to disk")
35
- args = parser.parse_args()
36
-
37
- @app.route("/dalle", methods=["POST"])
38
- @cross_origin()
39
- def generate_images_api():
40
- json_data = request.get_json(force=True)
41
- text_prompt = json_data["text"]
42
- num_images = json_data["num_images"]
43
- generated_imgs = dalle_model.generate_images(text_prompt, num_images)
44
-
45
- generated_images = []
46
- if args.save_to_disk:
47
- dir_name = os.path.join(IMAGES_OUTPUT_DIR,f"{time.strftime('%Y-%m-%d_%H:%M:%S')}_{text_prompt}")
48
- Path(dir_name).mkdir(parents=True, exist_ok=True)
49
-
50
- for idx, img in enumerate(generated_imgs):
51
- if args.save_to_disk:
52
- img.save(os.path.join(dir_name, f'{idx}.jpeg'), format="JPEG")
53
-
54
- buffered = BytesIO()
55
- img.save(buffered, format="JPEG")
56
- img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
57
- generated_images.append(img_str)
58
-
59
- print(f"Created {num_images} images from text prompt [{text_prompt}]")
60
- return jsonify(generated_images)
61
-
62
-
63
- @app.route("/", methods=["GET"])
64
- @cross_origin()
65
- def health_check():
66
- return jsonify(success=True)
67
-
68
-
69
- with app.app_context():
70
- dalle_model = DalleModel(args.model_version)
71
- dalle_model.generate_images("warm-up", 1)
72
- print("--> DALL-E Server is up and running!")
73
- print(f"--> Model selected - DALL-E {args.model_version}")
74
-
75
-
76
- if __name__ == "__main__":
77
- app.run(host="0.0.0.0", port=args.port, debug=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
  def greet(name):
4
  return "Hello " + name + "!!"
5
 
6
  iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
+ iface.launch()