from flask import Flask, jsonify, request
from pathlib import Path
import sys
import torch
import os
from torch import autocast
from diffusers import StableDiffusionPipeline, DDIMScheduler, DiffusionPipeline
import streamlit as st
from huggingface_hub import login

# Log in to the Hugging Face Hub (prompts for a token on first run).
login()

# model_path = WEIGHTS_DIR # If you want to use a previously trained model saved in gdrive, replace this with the full path of the model in gdrive
# headers = {"Authorization": "Bearer xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"}

# pipe = StableDiffusionPipeline.from_pretrained("Divyanshu04/Finetuned-sd-vae", safety_checker=None, torch_dtype=torch.float32).to("cuda")

# Load the fine-tuned pipeline. use_auth_token=True reuses the token cached by login()
# instead of hard-coding a secret in the source, and .to("cuda") is required because the
# generator and autocast below both assume the model is on the GPU.
pipe = DiffusionPipeline.from_pretrained(
    "Divyanshu04/Finetuned-sd-vae",
    torch_dtype=torch.float32,
    use_auth_token=True,
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.enable_xformers_memory_efficient_attention()

g_cuda = None

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # project root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))

app = Flask(__name__)


# @app.route("/", methods=["POST"])
def generate():
    # prompt = request.form['prompt']
    # negative_prompt = request.form['Negative prompt']
    # num_samples = request.form['No. of samples']

    prompt = st.text_area(label="prompt", key="pmpt")
    negative_prompt = st.text_area(label="Negative prompt", key="ng_pmpt")
    # number_input returns a float by default; constrain it to whole numbers >= 1.
    num_samples = st.number_input("No. of samples", min_value=1, value=1, step=1)

    res = st.button("Generate", type="primary")

    if res:
        guidance_scale = 7.5
        num_inference_steps = 24
        height = 512
        width = 512

        # Seeded CUDA generator for reproducible samples.
        g_cuda = torch.Generator(device="cuda")
        seed = 52362
        g_cuda.manual_seed(seed)

        # commandline_args = os.environ.get('COMMANDLINE_ARGS', "--skip-torch-cuda-test --no-half")

        with autocast("cuda"), torch.inference_mode():
            images = pipe(
                prompt,
                height=height,
                width=width,
                negative_prompt=negative_prompt,
                num_images_per_prompt=int(num_samples),
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=g_cuda,
            ).images

        # Render the generated images in the Streamlit page.
        for img in images:
            st.image(img)

        return {"message": "successful"}

    else:
        st.write("")


# driver function
if __name__ == "__main__":
    generate()
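# Usage note (assumption: this file is saved as app.py; adjust the filename as needed):
#   streamlit run app.py
# The Flask app object above is left over from an earlier HTTP version of this script
# and is unused in the Streamlit flow.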