# Hugging Face Space by johnsu6616, duplicated from hahahafofo/prompt_generator
# (commit 7a666f0).
import random
import re
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoProcessor
from transformers import pipeline, set_seed
# Device used for the image-captioning model (GPU when available).
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Image captioning model + processor (presumably GIT-base fine-tuned on COCO).
big_processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
big_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")

# Text-generation pipeline that free-writes text-to-image prompts.
text_pipe = pipeline('text-generation', model='succinctly/text2image-prompt-generator')

# Translation models, put in eval mode: Chinese -> English, English -> Chinese.
zh2en_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
zh2en_model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').eval()
en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh").eval()
def load_prompter():
    """Load the Promptist rewriting model together with a GPT-2 tokenizer.

    The tokenizer reuses its EOS token as the padding token and pads on the
    left (the usual configuration for decoder-only generation).

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token
    tok.padding_side = "left"
    model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
    return model, tok


# Loaded once at import time; generate_prompter uses these module globals.
prompter_model, prompter_tokenizer = load_prompter()
def generate_prompter(plain_text, max_new_tokens=75, num_beams=8, num_return_sequences=8, length_penalty=-1.0):
    """Rewrite an English prompt into Stable-Diffusion-style prompts with Promptist.

    Args:
        plain_text: English prompt to rephrase.
        max_new_tokens: Maximum number of tokens generated per candidate.
        num_beams: Beam width for deterministic beam search (``do_sample=False``).
        num_return_sequences: Number of candidates to return. It is clamped to
            ``num_beams`` because beam search cannot return more sequences
            than it keeps.
        length_penalty: Beam-search length penalty; negative values favour
            shorter candidates.

    Returns:
        The rewritten candidates, one per line.
    """
    # Gradio sliders deliver floats, but generate() needs integer counts.
    max_new_tokens = int(max_new_tokens)
    num_beams = max(1, int(num_beams))
    num_return_sequences = max(1, min(int(num_return_sequences), num_beams))

    prompt = plain_text.strip() + " Rephrase:"
    input_ids = prompter_tokenizer(prompt, return_tensors="pt").input_ids
    eos_id = prompter_tokenizer.eos_token_id
    outputs = prompter_model.generate(
        input_ids,
        do_sample=False,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        eos_token_id=eos_id,
        pad_token_id=eos_id,
        length_penalty=float(length_penalty),
    )
    # A causal LM returns prompt + continuation. Decode only the new tokens so
    # the echoed prompt can never leak into the result. A text replace of the
    # unstripped input missed it whenever the input had surrounding whitespace.
    new_tokens = outputs[:, input_ids.shape[-1]:]
    output_texts = prompter_tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
    return "\n".join(text.strip() for text in output_texts)
def translate_zh2en(text):
    """Translate Chinese ``text`` into English using the opus-mt-zh-en model.

    Returns:
        The first decoded translation as a plain string.
    """
    batch = zh2en_tokenizer([text], return_tensors='pt')
    with torch.no_grad():
        generated = zh2en_model.generate(**batch)
    translations = zh2en_tokenizer.batch_decode(generated, skip_special_tokens=True)
    return translations[0]
def translate_en2zh(text):
    """Translate English ``text`` into Chinese using the opus-mt-en-zh model.

    Returns:
        The first decoded translation as a plain string.
    """
    batch = en2zh_tokenizer([text], return_tensors="pt")
    with torch.no_grad():
        generated = en2zh_model.generate(**batch)
    (first, *_rest) = en2zh_tokenizer.batch_decode(generated, skip_special_tokens=True)
    return first
# Whitespace-delimited tokens that contain a dot (URLs, domains such as
# "artstation.com", file names); they are stripped from generated prompts.
# `\S` keeps every match inside a single line. The former `[^ ]` also matched
# "\n", so a line ending in "." was glued to the next line and words were lost.
_DOTTED_TOKEN_RE = re.compile(r'\S+\.\S+')


def _clean_generated(text_in_english, generated_texts):
    """Filter raw generator outputs and join the survivors one per line.

    A candidate is kept when it differs from the input, is more than 4
    characters longer than it, and does not end with ':', '-' or '—'.
    Dotted tokens and angle brackets are then removed.
    """
    kept = [
        line
        for line in (text.strip() for text in generated_texts)
        if line != text_in_english
        and len(line) > len(text_in_english) + 4
        and not line.endswith((':', '-', '—'))
    ]
    cleaned = _DOTTED_TOKEN_RE.sub('', "\n".join(kept))
    return cleaned.replace('<', '').replace('>', '')


def text_generate(text_in_english):
    """Free-write text-to-image prompts that continue ``text_in_english``.

    Each attempt requests 8 continuations from ``text_pipe``. Up to 6 attempts
    are made, stopping at the first one that leaves a non-empty cleaned result.

    Returns:
        ``(prompts, prompts_zh)``: the kept English prompts one per line, and
        the Chinese translation of each non-empty line, one per line. Both are
        ``""`` if every attempt was filtered out.
    """
    # Re-seed the generation RNGs from a random value so every click differs.
    set_seed(random.randint(100, 1000000))
    result = ""
    for _ in range(6):
        # Vary the length limit between attempts.
        sequences = text_pipe(text_in_english, max_length=random.randint(60, 90), num_return_sequences=8)
        result = _clean_generated(text_in_english, (seq['generated_text'] for seq in sequences))
        if result != '':
            break
    translated = "\n".join(translate_en2zh(line) for line in result.split("\n") if line)
    return result, translated
def get_prompt_from_image(input_image):
    """Caption an image with the captioning model and return it as a prompt.

    Args:
        input_image: PIL image from the Gradio ``Image(type='pil')`` component,
            or ``None`` when the user submits without uploading an image.

    Returns:
        The generated caption (also printed to stdout), or ``""`` when no
        image was supplied.
    """
    # An empty Image component yields None. Without this guard the call below
    # would raise AttributeError on None.convert.
    if input_image is None:
        return ""
    image = input_image.convert('RGB')
    pixel_values = big_processor(images=image, return_tensors="pt").to(device).pixel_values
    # Moving the model on each call is effectively a no-op once it already
    # lives on `device`, and it keeps this function self-sufficient.
    generated_ids = big_model.to(device).generate(pixel_values=pixel_values, max_length=50)
    generated_caption = big_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(generated_caption)
    return generated_caption
# Gradio UI with two tabs:
#   - a text tab: translate, Promptist rewrite, and free prompt generation;
#   - an image tab: image captioning.
# Component creation order inside the context managers defines the layout.
with gr.Blocks() as block:
    with gr.Column():
        # Tab label: "Text generation".
        with gr.Tab('文本生成'):
            with gr.Row():
                # "Your idea" (input, translated by translate_zh2en) and
                # "Translation result (prompt input)".
                input_text = gr.Textbox(lines=6, label='你的想法', placeholder='在此输入内容...')
                translate_output = gr.Textbox(lines=6, label='翻译结果(Prompt输入)')
            # Collapsed "SD optimization settings": sliders forwarded as the
            # generation arguments of generate_prompter.
            # NOTE(review): Gradio sliders may deliver floats, and the UI allows
            # num_return_sequences > num_beams; confirm generate_prompter copes.
            with gr.Accordion('SD优化参数设置', open=False):
                max_new_tokens = gr.Slider(1, 255, 75, label='max_new_tokens', step=1)
                # NOTE(review): "nub_beams" looks like a typo for num_beams; it is
                # only referenced in the generate_prompter wiring below.
                nub_beams = gr.Slider(1, 30, 8, label='num_beams', step=1)
                num_return_sequences = gr.Slider(1, 30, 8, label='num_return_sequences', step=1)
                length_penalty = gr.Slider(-1.0, 1.0, -1.0, label='length_penalty')
            # "SD-optimized prompt" (generate_prompter output).
            generate_prompter_output = gr.Textbox(lines=6, label='SD优化的 Prompt')
            # "Made-up prompt" (text_generate output) and its Chinese translation.
            output = gr.Textbox(lines=6, label='瞎编的 Prompt')
            output_zh = gr.Textbox(lines=6, label='瞎编的 Prompt(zh)')
            with gr.Row():
                # Buttons: "Translate", "SD optimize", "Make up".
                translate_btn = gr.Button('翻译')
                generate_prompter_btn = gr.Button('SD优化')
                gpt_btn = gr.Button('瞎编')
        # Tab label: "Generate from image".
        with gr.Tab('从图片中生成'):
            with gr.Row():
                input_image = gr.Image(type='pil')
                # "Submit".
                img_btn = gr.Button('提交')
            # "Generated prompt" (image caption).
            output_image = gr.Textbox(lines=6, label='生成的 Prompt')
    # Event wiring. "SD optimize" and "Make up" both read the translated text,
    # so the user is expected to press "Translate" first.
    translate_btn.click(
        fn=translate_zh2en,
        inputs=input_text,
        outputs=translate_output
    )
    generate_prompter_btn.click(
        fn=generate_prompter,
        inputs=[translate_output, max_new_tokens, nub_beams, num_return_sequences, length_penalty],
        outputs=generate_prompter_output
    )
    gpt_btn.click(
        fn=text_generate,
        inputs=translate_output,
        outputs=[output, output_zh]
    )
    img_btn.click(
        fn=get_prompt_from_image,
        inputs=input_image,
        outputs=output_image
    )
# Serve the UI on all interfaces with a bounded request queue (64 pending jobs).
# The queue is enabled by .queue(). The former `enable_queue=True` launch
# argument only duplicated it: it is deprecated in Gradio 3.x, and Gradio 4's
# launch() no longer accepts it.
block.queue(max_size=64).launch(show_api=False, debug=True, share=False, server_name='0.0.0.0')