File size: 1,662 Bytes
cf8bafe
4115131
 
65ad1db
 
c9e21fd
a97c5ef
3faa194
 
 
 
 
cf8bafe
65ad1db
 
cf8bafe
65ad1db
 
 
 
 
cb460c0
 
 
 
c249291
 
3faa194
c249291
 
3faa194
384d32a
c249291
36c015f
c249291
3faa194
b9e2ccf
3faa194
c249291
 
b2e5b52
c249291
 
b2e5b52
c249291
384d32a
 
3faa194
c249291
 
36c015f
c249291
384d32a
36c015f
3faa194
 
 
cb460c0
65ad1db
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import gradio as gr
from gradio.inputs import Textbox

import torch
from diffusers import StableDiffusionPipeline
import boto3
from io import BytesIO
import os

# Credentials and the target bucket are read from environment variables;
# each of these is None when the corresponding variable is unset.
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
S3_BUCKET_NAME = os.getenv("BUCKET_NAME")

# The diffusion pipeline is loaded once, at import time, in float32 and
# placed on the GPU when CUDA is available (CPU otherwise).
model_id = "CompVis/stable-diffusion-v1-4"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

pipe = StableDiffusionPipeline.from_pretrained(
    model_id, torch_dtype=torch.float32
).to(device)

def text_to_image(prompt, save_as, key_id):
    """Generate an image from *prompt* and upload it to S3 as a WebP file.

    The image is stored under ``public/<name>.webp`` in ``S3_BUCKET_NAME``.

    Args:
        prompt: Text prompt passed to the module-level diffusion ``pipe``.
        save_as: Base name for the stored object. Runs of whitespace are
            replaced by ``-`` and ``.webp`` is appended.
        key_id: Caller-supplied value that must equal the
            ``AWS_ACCESS_KEY_ID`` environment variable.

    Returns:
        The object name (without the ``public/`` prefix) on success,
        ``"not permition"`` if *key_id* does not match (or the env var is
        unset), or ``"invalid save_as"`` if *save_as* is empty/blank.
    """
    import hmac

    # NOTE(review): an AWS access key ID is not itself a secret; using it as a
    # shared password is weak. Consider a dedicated secret for this check.
    # compare_digest gives a constant-time comparison. The explicit truthiness
    # check refuses access when the env var is unset (the old `!=` test let
    # key_id=None through in that case). Bytes are compared so non-ASCII input
    # cannot raise TypeError.
    if not (
        AWS_ACCESS_KEY_ID
        and isinstance(key_id, str)
        and hmac.compare_digest(AWS_ACCESS_KEY_ID.encode(), key_id.encode())
    ):
        return "not permition"

    # Validate the name before spending a full diffusion run on the request.
    # NOTE(review): '/' is not stripped, so a name containing it lands in a
    # sub-prefix of public/ — confirm whether that is intended.
    words = (save_as or "").split()
    if not words:
        return "invalid save_as"
    image_name = "-".join(words) + ".webp"

    image = pipe(prompt).images[0]

    s3 = boto3.client(
        "s3",
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    )

    # Encode into memory and stream the buffer to S3; rewind before upload so
    # the reader starts at byte 0.
    with BytesIO() as image_buffer:
        image.save(image_buffer, format="WEBP")
        image_buffer.seek(0)
        s3.upload_fileobj(image_buffer, S3_BUCKET_NAME, "public/" + image_name)

    return image_name



# Web UI: three text inputs map positionally onto
# text_to_image(prompt, save_as, key_id); its string result is shown as text.
# gr.Textbox replaces the deprecated gradio.inputs.Textbox (removed in
# Gradio 4).
iface = gr.Interface(
    fn=text_to_image,
    inputs=[
        gr.Textbox(label="prompt"),
        gr.Textbox(label="s3_save_as"),
        # Masked so the credential is not displayed in clear text.
        gr.Textbox(label="aws_key_id", type="password"),
    ],
    outputs="text",
)
iface.launch()