Divyanshu04 commited on
Commit
5200ca1
·
1 Parent(s): deefcce
Files changed (1) hide show
  1. Text2image-api.py +81 -0
Text2image-api.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, jsonify, request
2
+ from pathlib import Path
3
+ import sys
4
+ import torch
5
+ import os
6
+ from torch import autocast
7
+ from diffusers import StableDiffusionPipeline, DDIMScheduler
8
+ import streamlit as st
9
+
10
# NOTE(review): the diffusion pipeline initialisation below is disabled;
# `generate()` references `pipe`, which stays undefined until these lines
# are restored.
# model_path = WEIGHTS_DIR  # full path of a previously trained model (e.g. in gdrive)
# pipe = StableDiffusionPipeline.from_pretrained(model_path, safety_checker=None, torch_dtype=torch.float32).to("cuda")
# pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
# pipe.enable_xformers_memory_efficient_attention()
# g_cuda = None

# Make the directory containing this script importable, then keep ROOT as a
# path relative to the current working directory.
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # directory containing this script
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))

# Flask application object. No routes are registered at the moment — the
# handlers below are commented out, so `generate()` runs as a plain
# Streamlit script.
app = Flask(__name__)

# @app.route('/', methods=['GET', 'POST'])
# def home():
#     if request.method == 'GET':
#         data = "Text2Image"
#         return jsonify({'service': data})

# @app.route("/", methods=["POST"])
34
def generate():
    """Streamlit entry point: collect a prompt and run Stable Diffusion.

    Reads a prompt, a negative prompt and a sample count from Streamlit
    widgets; when the button is pressed, runs the module-level ``pipe``
    diffusion pipeline with fixed generation settings and displays the
    resulting images.

    Returns:
        dict: a small status message. (These look like Flask responses,
        but the route decorator is commented out, so they are only plain
        return values.)
    """
    prompt = st.text_area(label="prompt", key="pmpt")
    negative_prompt = st.text_area(label="Negative prompt", key="ng_pmpt")
    # st.number_input returns a float by default, but diffusers'
    # num_images_per_prompt requires an int — request an integer widget
    # (int value/step) and cast defensively at the call site below.
    num_samples = st.number_input("No. of samples", min_value=1, value=1, step=1)

    pressed = st.button("Reset", type="primary")

    if pressed:
        # Fixed generation settings.
        guidance_scale = 7.5
        num_inference_steps = 24
        height = 512
        width = 512

        # Seeded CUDA generator for reproducible sampling.
        # NOTE(review): requires a CUDA device — raises on CPU-only hosts.
        g_cuda = torch.Generator(device='cuda')
        seed = 52362
        g_cuda.manual_seed(seed)

        # NOTE(review): `pipe` is never defined — its module-level
        # initialisation is commented out — so this branch raises
        # NameError until the pipeline setup is restored.
        with autocast("cuda"), torch.inference_mode():
            images = pipe(
                prompt,
                height=height,
                width=width,
                negative_prompt=negative_prompt,
                num_images_per_prompt=int(num_samples),
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=g_cuda,
            ).images

        # Show the generated images instead of silently discarding them.
        for img in images:
            st.image(img)

        return {"message": "successful"}

    return {"message": "Running.."}
75
+
76
+
77
+
78
+
79
# Script entry point: render the Streamlit UI defined in generate().
if __name__ == "__main__":
    generate()