File size: 2,109 Bytes
4de66a5
 
c80d1cd
 
3d254a3
 
 
 
4de66a5
 
 
 
 
0d7c941
 
 
 
 
 
 
 
 
e25f345
00aa8e0
5ff57a4
e25f345
cde57f0
4564715
0d7c941
 
4c5fefe
45378bd
 
 
 
 
 
3d254a3
4c5fefe
3ae8213
3d254a3
45378bd
0d7c941
00aa8e0
10e80f6
f25ca48
060702a
4c5fefe
dcf7046
4de66a5
a91b851
00aa8e0
3d254a3
c907be6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import gradio as gr
import logging

# Log line layout: timestamp, level, logger name, calling function, message.
LOG_FORMAT = '%(asctime)s - %(levelname)s  [%(name)s] %(funcName)s -> %(message)s'

# Configure the root logger once, emitting everything from DEBUG upwards.
logging.basicConfig(format=LOG_FORMAT, level=logging.DEBUG)

# Logger named after this module; announce that logging is live.
logger = logging.getLogger(__name__)
logger.debug('Logging started')

# Hugging Face token passed to the interface loader as `api_key`;
# evaluates to None when the environment variable is not set.
API_KEY = os.environ.get('HUGGING_FACE_HUB_TOKEN')

# Markdown text passed to the interface as its `article`.
article = """---
This space was created using [SD Space Creator](https://huggingface.co/spaces/anzorq/sd-space-creator)."""

class Model:
    """A selectable model entry: display name, model path and trigger prefix.

    All fields are plain public attributes. ``pipe_t2i`` and ``pipe_i2i`` are
    initialised to ``None`` and nothing in this class assigns them afterwards
    (presumably text-to-image / image-to-image pipeline slots — confirm).
    """

    def __init__(self, name: str, path: str = "", prefix: str = ""):
        self.name = name  # human-readable label
        self.path = path  # model identifier string
        self.prefix = prefix  # phrase looked for in a prompt to select this model
        self.pipe_t2i = None
        self.pipe_i2i = None

    def __repr__(self) -> str:
        # Readable form for logs/debugging instead of <Model object at 0x...>.
        return (
            f"{type(self).__name__}(name={self.name!r}, "
            f"path={self.path!r}, prefix={self.prefix!r})"
        )

# Candidate models in the order selectModel tries them:
# Model(display name, model path, trigger prefix). When no prefix matches a
# prompt, the final entry is the one that gets selected.
models = [
    Model("Marvel", "models/ItsJayQz/Marvel_WhatIf_Diffusion", "whatif style"),
    Model("Cyberpunk Anime Diffusion", "models/DGSpitzer/Cyberpunk-Anime-Diffusion", "dgs illustration style"),
    Model("Guan Yu Diffusion", "models/DGSpitzer/Guan-Yu-Diffusion", "Guan-Yu style"),
    Model("Portrait plus", "models/wavymulder/portraitplus", "portrait+ style"),
    Model("classic Disney", "models/nitrosocke/classic-anim-diffusion", "classic disney style"),
    Model("SD21", "models/stabilityai/stable-diffusion-2-1", "sd21 default style"),
]

# Model name handed to gr.Interface.load when the UI is built.
custom_model = "models/stabilityai/stable-diffusion-2-1"

def selectModel(message, candidates=None):
    """Return the path of the first model whose trigger prefix is in *message*.

    Both the prompt and each candidate's ``prefix`` are lower-cased before the
    substring test, so prefixes containing capitals (e.g. ``"Guan-Yu style"``)
    still match. Candidates are tried in order and the first hit wins; an
    empty prefix matches every prompt. If nothing matches, the last candidate
    is selected as the fallback.

    Args:
        message: The user's prompt. ``None`` is treated as an empty prompt.
        candidates: Optional sequence of objects exposing ``path`` and
            ``prefix`` attributes. Defaults to the module-level ``models``.

    Returns:
        The ``path`` of the selected model.

    Raises:
        ValueError: If there are no candidates (nothing to match or fall back to).
    """
    if candidates is None:
        candidates = models
    if not candidates:
        # Without candidates there is neither a match nor a fallback.
        raise ValueError("selectModel: no models configured")

    # getLogger caches by name, so this is the same object as the module logger.
    log = logging.getLogger(__name__)

    text = (message or "").lower()
    chosen = next(
        (cand for cand in candidates if cand.prefix.lower() in text),
        candidates[-1],
    )
    log.warning('model selected = %s', chosen.path)
    return chosen.path

# Build the Gradio UI from the hosted model named by `custom_model`.
# NOTE(review): `custom_model` is a fixed string set before any prompt exists,
# so the model that selectModel picks from a prompt is never used to decide
# what gets loaded here — confirm whether per-prompt switching was intended.
# NOTE(review): in gradio 3.x, extra kwargs given to `gr.Interface.load` appear
# to be merged over the loaded interface's settings, so `fn=selectModel` would
# presumably replace the model's own inference function (the UI would then
# return a path string rather than an image) — verify against the installed
# gradio version.
sandbox = gr.Interface.load(
    fn= selectModel,
    name= custom_model, 
    title="""AlStable sandbox""",
    # Single prompt box with its label hidden, limited to 2 lines.
    inputs = gr.Textbox(label="Prompt", show_label=False, max_lines=2, placeholder="Enter your prompt", elem_id="input-prompt"), 
    description="""Demo for <a href="https://huggingface.co/stabilityai/stable-diffusion-2-1">AlStable</a> Stable Diffusion model.""",
    article=article,
    # Value of HUGGING_FACE_HUB_TOKEN, or None when the variable is unset.
    api_key=API_KEY
)
logging.warning('model chosen = '+custom_model)
# Enable request queueing and start the web server.
# NOTE(review): `concurrency_count` is a gradio 3.x `queue()` argument that
# was removed in gradio 4 — pin the gradio version if this must keep working.
sandbox.queue(concurrency_count=20).launch()