File size: 3,735 Bytes
78e2fb5
 
f82013f
 
78e2fb5
f82013f
 
 
 
 
b04693a
7890ebe
 
 
f82013f
 
 
 
 
7890ebe
f82013f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78e2fb5
 
 
 
 
 
 
 
28d6e58
 
f82013f
ea20bfc
78e2fb5
 
 
 
 
 
 
 
28d6e58
 
 
 
ea20bfc
28d6e58
 
 
 
 
 
 
 
 
78e2fb5
 
 
 
 
 
28d6e58
90ec86f
78e2fb5
 
f82013f
78e2fb5
 
 
 
 
 
 
f82013f
aba8522
 
 
 
 
 
12c1229
78e2fb5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
from share_btn import community_icon_html, loading_icon_html, share_js
import random
import re

import torch
from transformers import AutoModelWithLMHead, AutoTokenizer, pipeline, set_seed

import gradio as grad
from diffusers import StableDiffusionPipeline

# Hugging Face Hub repository holding the GPT-2 horoscope model (per its repo
# name); both the tokenizer and the LM-head weights come from this one repo.
HOROSCOPE_MODEL_ID = "shahp7575/gpt2-horoscopes"
tokenizer = AutoTokenizer.from_pretrained(HOROSCOPE_MODEL_ID)
model = AutoModelWithLMHead.from_pretrained(HOROSCOPE_MODEL_ID)

# Heavy pipelines are built once at import time -- the same way `tokenizer` and
# `model` are loaded above -- instead of being re-created on every request.
prompt_pipe = pipeline("text-generation", model="Gustavosta/MagicPrompt-Stable-Diffusion", tokenizer="gpt2")
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")


def fn(sign, cat):
    """Generate a horoscope and an image illustrating it.

    The horoscope model produces text for the chosen category/sign. That text
    is expanded by the MagicPrompt model into an image prompt, which is then
    rendered with Stable Diffusion.

    Args:
        sign: Star sign selected in the UI (e.g. "aries"); case-insensitive.
        cat: Horoscope category selected in the UI (e.g. "love").

    Returns:
        A two-element list ``[image, horoscope_text]``: the first image from
        the Stable Diffusion pipeline and the horoscope text used to seed it.

    Raises:
        ValueError: if ``sign`` or ``cat`` is empty or None.
    """
    if not sign or not cat:
        raise ValueError("Please choose both a star sign and a category.")
    # Use the caller's sign, normalised to lower case so every dropdown option
    # (including a capitalised "Pisces") reaches the model in the same form.
    sign = sign.strip().lower()

    prompt = f"<|category|> {cat} <|horoscope|> {sign}"
    prompt_encoded = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)

    sample_outputs = model.generate(
        prompt_encoded,
        do_sample=True,
        top_k=40,
        max_length=300,
        top_p=0.95,
        temperature=0.95,
        num_beams=4,
        num_return_sequences=4,
    )

    final_out = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    # Drop the first four space-separated words -- presumably the echoed prompt
    # (category / sign) rather than horoscope text. TODO confirm against the
    # tokenizer's special-token handling.
    starting_text = " ".join(final_out.split(" ")[4:])

    seed = random.randint(100, 1000000)
    set_seed(seed)

    # `max_length` counts tokens, not characters: size it from the prompt's
    # token count plus 60-90 new tokens so it cannot balloon for long texts.
    prompt_tokens = len(prompt_pipe.tokenizer(starting_text)["input_ids"])
    response = prompt_pipe(
        starting_text,
        max_length=prompt_tokens + random.randint(60, 90),
        num_return_sequences=1,
    )
    image = sd_pipe(response[0]["generated_text"], num_inference_steps=5).images[0]
    return [image, starting_text]


# Gradio UI: two dropdowns (star sign, category) feed `fn`, whose image and
# horoscope text are shown in `gallery` and `horoscope_output`.
block = gr.Blocks(css="./css.css")

with block:
    with gr.Group():
        with gr.Box():
            with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
                # Star-sign selector; passed to fn as `sign`. A default value
                # prevents fn from receiving None when nothing is picked.
                sign_input = gr.Dropdown(
                    label="Star Sign",
                    choices=["aries", "taurus", "gemini", "cancer", "leo", "virgo", "libra", "scorpio", "sagittarius", "capricorn", "aquarius", "pisces"],
                    value="aries",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    elem_id="prompt-text-input",
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )

                # Horoscope category; passed to fn as `cat`.
                # NOTE(review): shares elem_id with the sign dropdown (duplicate
                # DOM id) -- kept because css.css may target it; confirm.
                category_input = gr.Dropdown(
                    choices=["love", "career", "wellness"],
                    value="love",
                    label="Category",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    elem_id="prompt-text-input",
                ).style(
                    border=(True, True, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )

                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                    full_width=False,
                )

        gallery = gr.Image(
            interactive=False,
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2], height="auto")
        # Output-only box for the generated horoscope. It has its own name so it
        # no longer shadows the star-sign dropdown, and "Text" is its label
        # rather than a pre-filled value.
        horoscope_output = gr.Textbox(label="Text")

        with gr.Group(elem_id="container-advanced-btns"):
            with gr.Group(elem_id="share-btn-container"):
                community_icon = gr.HTML(community_icon_html)
                loading_icon = gr.HTML(loading_icon_html)
                share_button = gr.Button("Share to community", elem_id="share-btn")

        btn.click(fn=fn, inputs=[sign_input, category_input], outputs=[gallery, horoscope_output])
        # Client-side only: runs the share_js snippet, no Python callback.
        share_button.click(
            None,
            [],
            [],
            _js=share_js,
        )


block.queue(concurrency_count=40, max_size=20).launch(max_threads=150)