File size: 2,217 Bytes
4de66a5
 
c80d1cd
 
3d254a3
 
 
 
4de66a5
 
 
 
 
0d7c941
 
 
 
 
 
 
 
e25f345
00aa8e0
5ff57a4
e25f345
cde57f0
43ab04d
601c77e
4564715
0d7c941
 
80f43e3
45378bd
 
 
 
43ab04d
45378bd
3d254a3
4c5fefe
3ae8213
3d254a3
2a90a2f
 
559a972
0d7c941
00aa8e0
10e80f6
2a90a2f
 
dcf7046
4de66a5
a91b851
00aa8e0
b332329
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
import gradio as gr
import logging

# Verbose log line layout: timestamp, level, logger name, function, message.
LOG_FORMAT = '%(asctime)s - %(levelname)s  [%(name)s] %(funcName)s -> %(message)s'
# DEBUG level at the root logger: everything (including library chatter) is emitted.
logging.basicConfig(level = logging.DEBUG, format = LOG_FORMAT)
logger = logging.getLogger(__name__)
logger.debug('Logging started')

# Hub token read from the environment; None when the Space has no secret configured.
API_KEY=os.environ.get('HUGGING_FACE_HUB_TOKEN', None)
# Markdown footer shown below the Gradio interface.
article = """---
This space was created using [SD Space Creator](https://huggingface.co/spaces/anzorq/sd-space-creator)."""

class Model:
    """Plain record describing one hosted diffusion model.

    Holds the display name, the Hub path used to load the model, and the
    trigger ``prefix`` that selects it from a prompt. The two ``pipe_*``
    slots start as ``None`` and are meant to be filled in lazily by callers.
    """

    def __init__(self, name, path="", prefix=""):
        # Store the identifying fields in one multiple-assignment.
        self.name, self.path, self.prefix = name, path, prefix
        # Pipelines are not constructed here — loaded on demand elsewhere.
        self.pipe_t2i = None
        self.pipe_i2i = None
# Registry of selectable models. `prefix` is the trigger phrase that
# selectModel() searches for inside the user's message; the last entry
# (SD21) acts as the default.
models = [
   Model("Marvel","models/ItsJayQz/Marvel_WhatIf_Diffusion", "whatif style"), 
   Model("Cyberpunk Anime Diffusion", "models/DGSpitzer/Cyberpunk-Anime-Diffusion", "dgs illustration style"),
   Model("Guan Yu Diffusion", "models/DGSpitzer/Guan-Yu-Diffusion", "Guan-Yu style"),
   Model("Portrait plus", "models/wavymulder/portraitplus", "portrait+ style"),
   Model("classic Disney", "models/nitrosocke/classic-anim-diffusion", "classic disney style"),
   Model("vintedois", "models/22h/vintedois-diffusion-v0-1", "vintedois style"),
   Model("dreamlike", "models/dreamlike-art/dreamlike-diffusion-1.0","dreamlike style"),
   Model("SD21","models/stabilityai/stable-diffusion-2-1", "sd21 default style")
]

# NOTE(review): appears unused by the visible code — confirm before removing.
custom_model = "models/dreamlike-art/dreamlike-diffusion-1.0"

def selectModel(message):
    """Return the Hub path of the first model whose trigger prefix occurs in *message*.

    Args:
        message: Free-form prompt text. Matching is case-insensitive.

    Returns:
        str: the matching ``Model.path``, or the last registry entry's path
        (the default model) when no prefix is found.
    """
    message = message.lower()
    for model in models:
        # Lowercase the prefix too: entries such as "Guan-Yu style" contain
        # uppercase letters and could never match the lowered message before.
        if model.prefix.lower() in message:
            logging.warning('model selected = %s', model.path)
            return model.path
    # No prefix matched: fall back to the last (default) model explicitly.
    # The original fallback returned the undefined name `prompt` (NameError)
    # as part of a tuple, and leaked the loop index; both are fixed here so
    # the function always returns a single str.
    fallback = models[-1].path
    logging.warning('model selected = %s', fallback)
    return fallback

# Two free-text inputs: the model name and the prompt.
# NOTE(review): `gr.inputs.*` is the legacy pre-3.x Gradio namespace — confirm
# the pinned gradio version supports it.
inputs_prompt = (gr.inputs.Textbox(label="Model"), gr.inputs.Textbox(label="Prompt"))

# Build and launch the demo UI.
# Fixes over the original: missing commas after `inputs=...` and `title=...`
# made this a SyntaxError, and `title` was a (str, str) tuple where Gradio
# expects a plain string.
# NOTE(review): `gr.Interface.load` normally takes a model/Space name, not
# `fn=` — this may be intended as plain `gr.Interface(...)`; confirm against
# the pinned gradio version.
sandbox = gr.Interface.load(
    fn=selectModel,
    inputs=inputs_prompt,
    title="AlStable sandbox",
    description="""Demo for <a href="https://huggingface.co/stabilityai/stable-diffusion-2-1">AlStable</a> Stable Diffusion model.""",
    article=article,
    api_key=API_KEY
)

# Enable request queuing (up to 20 concurrent workers) and start the server.
sandbox.queue(concurrency_count=20).launch()