File size: 2,658 Bytes
9b6e78f
 
 
 
 
 
 
 
 
 
 
 
 
 
c82c2e4
 
 
68040ae
bfe9080
68040ae
5e22df1
 
 
c82c2e4
68040ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b6e78f
68040ae
 
 
 
 
 
9b6e78f
68040ae
 
 
 
9b6e78f
68040ae
 
9b6e78f
 
68040ae
 
 
 
9b6e78f
 
68040ae
 
 
 
 
9b6e78f
bfe9080
9b6e78f
68040ae
 
 
bfe9080
68040ae
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import argparse
import base64
import os
from pathlib import Path
from io import BytesIO
import time

from flask import Flask, request, jsonify
from flask_cors import CORS, cross_origin
from consts import IMAGES_OUTPUT_DIR
from utils import parse_arg_boolean, parse_arg_dalle_version
from consts import ModelSize


import gradio as gr

def _safe_dir_component(text, max_length=100):
    """Return *text* reduced to something usable as one directory name.

    Every character that is not alphanumeric and not one of space, ``.``,
    ``_`` or ``-`` is replaced with ``_`` (this removes path separators and
    colons), and the result is cut to at most *max_length* characters.
    """
    cleaned = "".join(
        ch if ch.isalnum() or ch in " ._-" else "_" for ch in str(text)
    )
    return cleaned[:max_length]


def greet(name):
    """Build and start the DALL-E Flask backend, then return a greeting.

    Each call creates a Flask app with CORS enabled, parses the command-line
    options (``--port``, ``--model_version``, ``--save_to_disk``), registers
    the ``POST /dalle`` and ``GET /`` routes, instantiates ``DalleModel`` and
    runs one warm-up generation. When the module is executed as a script the
    Flask server is then run on ``0.0.0.0:<port>``, and the greeting is only
    returned once ``app.run`` returns.

    NOTE(review): all of this is repeated on every call, because the whole
    server setup lives inside this function — confirm that is intended.

    Args:
        name: Text appended to the greeting.

    Returns:
        The string ``"Hello " + name + "!!"``.
    """
    app = Flask(__name__)
    CORS(app)
    print("--> Starting DALL-E Server. This might take up to two minutes.")

    # Kept as a function-local import, after the start-up banner is printed.
    from dalle_model import DalleModel
    dalle_model = None

    parser = argparse.ArgumentParser(description = "A DALL-E app to turn your textual prompts into visionary delights")
    parser.add_argument("--port", type=int, default=8000, help = "backend port")
    parser.add_argument("--model_version", type = parse_arg_dalle_version, default = ModelSize.MINI, help = "Mini, Mega, or Mega_full")
    parser.add_argument("--save_to_disk", type = parse_arg_boolean, default = False, help = "Should save generated images to disk")
    # This runs inside a UI callback, so sys.argv may carry options that
    # belong to whatever launched the process. parse_known_args() ignores
    # them instead of terminating the process with SystemExit.
    args, _unrecognized = parser.parse_known_args()

    @app.route("/dalle", methods=["POST"])
    @cross_origin()
    def generate_images_api():
        """Generate images for the prompt in the JSON request body.

        The body must be a JSON object with the keys ``text`` and
        ``num_images``. Responds with a JSON list of base64-encoded JPEG
        strings, or with HTTP 400 when the body is not shaped like that.
        """
        json_data = request.get_json(force=True)
        if (
            not isinstance(json_data, dict)
            or "text" not in json_data
            or "num_images" not in json_data
        ):
            # A malformed request is a client error, not a server crash.
            return jsonify(error="JSON body must contain 'text' and 'num_images'"), 400

        text_prompt = json_data["text"]
        num_images = json_data["num_images"]
        generated_imgs = dalle_model.generate_images(text_prompt, num_images)

        output_dir = None
        if args.save_to_disk:
            # The prompt comes straight from the HTTP request: strip path
            # separators and cap its length so it cannot escape
            # IMAGES_OUTPUT_DIR or exceed filename limits. The timestamp
            # avoids ':' because it is not allowed in Windows file names.
            timestamp = time.strftime('%Y-%m-%d_%H-%M-%S')
            output_dir = Path(IMAGES_OUTPUT_DIR) / f"{timestamp}_{_safe_dir_component(text_prompt)}"
            output_dir.mkdir(parents=True, exist_ok=True)

        generated_images = []
        for idx, img in enumerate(generated_imgs):
            # Encode once and reuse the bytes for both the file on disk and
            # the base64 payload.
            buffered = BytesIO()
            img.save(buffered, format="JPEG")
            jpeg_bytes = buffered.getvalue()

            if output_dir is not None:
                (output_dir / f"{idx}.jpeg").write_bytes(jpeg_bytes)

            generated_images.append(base64.b64encode(jpeg_bytes).decode("utf-8"))

        print(f"Created {num_images} images from text prompt [{text_prompt}]")
        return jsonify(generated_images)

    @app.route("/", methods=["GET"])
    @cross_origin()
    def health_check():
        """Respond with ``{"success": true}``."""
        return jsonify(success=True)

    with app.app_context():
        # Instantiate the model and run one throwaway generation before the
        # server is reported as ready.
        dalle_model = DalleModel(args.model_version)
        dalle_model.generate_images("warm-up", 1)
        print("--> DALL-E Server is up and running!")
        print(f"--> Model selected - DALL-E {args.model_version}")

    if __name__ == "__main__":
        # The greeting below is not returned until app.run() returns.
        app.run(host="0.0.0.0", port=args.port, debug=False)

    return "Hello " + name + "!!"
# Wrap `greet` in a Gradio interface with one text input and one text output.
iface = gr.Interface(fn=greet, inputs="text", outputs="text")
# Launch the interface as soon as this module is executed; there is no
# `__main__` guard around this call.
iface.launch()