djrana commited on
Commit
03a8354
·
verified ·
1 Parent(s): 9647f84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -19
app.py CHANGED
@@ -1,28 +1,31 @@
1
  import gradio as gr
2
  from transformers import pipeline
 
 
3
 
4
- pipe = pipeline('text-generation', model_id='Ar4ikov/gpt2-650k-stable-diffusion-prompt-generator')
5
-
6
- def get_valid_prompt(text: str) -> str:
7
- dot_split = text.split('.')[0]
8
- n_split = text.split('\n')[0]
9
 
10
- return {
11
- len(dot_split) < len(n_split): dot_split,
12
- len(n_split) > len(dot_split): n_split,
13
- len(n_split) == len(dot_split): dot_split
14
- }[True]
15
 
16
- def generate_prompt(prompt):
17
- valid_prompt = get_valid_prompt(pipe(prompt, max_length=77)[0]['generated_text'])
18
- return valid_prompt
19
 
20
  iface = gr.Interface(
21
- fn=generate_prompt,
22
- inputs="text",
23
- outputs="text",
24
- title="Prompt Generator",
25
- description="Enter a prompt and get the valid prompt generated by the script."
 
 
 
 
 
 
 
26
  )
27
 
28
- iface.launch()
 
1
  import gradio as gr
2
  from transformers import pipeline
3
+ import os
4
+ os.system("pip install -r requirements.txt")
5
 
6
+ pipe = pipeline(
7
+ "text-generation",
8
+ model="Ar4ikov/gpt2-650k-stable-diffusion-prompt-generator",
9
+ tokenizer="gpt2"
10
+ )
11
 
 
 
 
 
 
12
 
13
+ def generate_text(prompt):
14
+ return pipe(prompt, max_length=77)[0]["generated_text"]
 
15
 
16
  iface = gr.Interface(
17
+ fn=generate_text,
18
+
19
+ #input is a text box
20
+ inputs=gr.Textbox(lines=5, label="Prompt"),
21
+
22
+
23
+ # output is a text box with copy button
24
+ outputs=gr.Textbox(label="Output", show_copy_button=True),
25
+
26
+ title="GPT-2 650k Stable Diffusion Prompt Generator",
27
+ description="GPT-2 650k Stable Diffusion Prompt Generator",
28
+ api_name="predict"
29
  )
30
 
31
+ iface.launch(show_api=True)