Spaces:
Running
Running
import os | |
import re | |
import sys | |
sys.path.insert(0, '.') | |
sys.path.insert(0, '..') | |
import argparse | |
import gradio as gr | |
os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.getcwd(), 'tmp') | |
import copy | |
import time | |
import shutil | |
import requests | |
from PIL import Image, ImageFile | |
import torch | |
import transformers | |
from transformers import StoppingCriteriaList, AutoTokenizer, AutoModel | |
ImageFile.LOAD_TRUNCATED_IMAGES = True | |
from demo_asset.assets.css_html_js import custom_css | |
from demo_asset.gradio_patch import Chatbot as grChatbot | |
from demo_asset.serve_utils import Stream, Iteratorize | |
from demo_asset.conversation import CONV_VISION_7132_v2, StoppingCriteriaSub | |
from demo_asset.download import download_image_thread | |
# Number of article-section slots the UI pre-allocates; generated articles
# are truncated to this many lines.
max_section = 60

# Shared Gradio button updates reused by the callbacks.
no_change_btn = gr.Button.update()
disable_btn = gr.Button.update(interactive=False)
enable_btn = gr.Button.update(interactive=True)

# Stream tokens to the UI as they are generated (chat tab / article tab).
chat_stream_output = article_stream_output = True
def get_urls(caption, exclude, timeout=60):
    """Query the OpenXLab image-retrieval service for images matching a caption.

    Args:
        caption: Text description to search with.
        exclude: Retrieval indices (as returned by earlier calls) that the
            service should leave out of the results.
        timeout: Seconds to wait for the service. Without a timeout,
            ``requests`` waits forever and a stalled service would hang the
            Gradio worker.

    Returns:
        ``(urls, idx)``: the image URLs and their retrieval indices, taken
        from ``data.image_urls`` / ``data.indices`` of the JSON response.

    Raises:
        requests.HTTPError: if the service answers with an error status.
        requests.Timeout: if the service does not respond within ``timeout``.
    """
    headers = {'Content-Type': 'application/json'}
    json_data = {'caption': caption, 'exclude': exclude, 'need_idxs': True}
    response = requests.post('https://lingbi.openxlab.org.cn/image/similar',
                             headers=headers,
                             json=json_data,
                             timeout=timeout)
    response.raise_for_status()
    # Parse the body once instead of calling .json() per field.
    data = response.json()['data']
    return data['image_urls'], data['indices']
class Demo_UI:
    """Backend for the InternLM-XComposer Gradio demo.

    Holds the loaded model/tokenizer and the per-session article state
    (generated text, captions, downloaded/selected images) that the Gradio
    callbacks read and update.
    """

    def __init__(self, folder):
        """Load model and tokenizer from ``folder`` and move the model to CUDA.

        Args:
            folder: Path or hub id passed to ``from_pretrained``.
        """
        self.llm_model = AutoModel.from_pretrained(folder, trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(folder, trust_remote_code=True)
        self.llm_model.internlm_tokenizer = tokenizer
        self.llm_model.tokenizer = tokenizer
        self.llm_model.eval().to('cuda')
        self.device = 'cuda'
        print(" load model done: ", type(self.llm_model))

        # Decoded text of special token ids 103027 / 103028. They are spliced
        # into prompts as end-of-user (eoh, before " <|Bot|>") and
        # end-of-assistant (eoa, after a reply) markers.
        self.eoh = self.llm_model.internlm_tokenizer.decode(
            torch.Tensor([103027]), skip_special_tokens=True)
        self.eoa = self.llm_model.internlm_tokenizer.decode(
            torch.Tensor([103028]), skip_special_tokens=True)
        # Last vocabulary id; presumably the start-of-image token (see
        # soi_token) -- confirm against the model's tokenizer.
        self.soi_id = len(tokenizer) - 1
        self.soi_token = '<SOI_TOKEN>'
        self.vis_processor = self.llm_model.vis_processor

        # Generation stops as soon as any of these token sequences appears.
        stop_words_ids = [
            torch.tensor([943]).to(self.device),
            torch.tensor([2917, 44930]).to(self.device),
            torch.tensor([45623]).to(self.device),  # new setting
            torch.tensor([46323]).to(self.device),  # new setting
            torch.tensor([103027]).to(self.device),  # new setting
            torch.tensor([103028]).to(self.device),  # new setting
        ]
        self.stopping_criteria = StoppingCriteriaList(
            [StoppingCriteriaSub(stops=stop_words_ids)])
        # Matches the "<SegN>" section markers used in location prompts.
        self.r2 = re.compile(r'<Seg[0-9]*>')
        # Upper bound on the interleaved text+image embedding length.
        self.max_txt_len = 1680

    def reset(self):
        """Clear per-article state before a new article is generated."""
        self.output_text = ''
        self.caps = {}
        self.show_caps = False
        self.show_ids = {}

    def get_images_xlab(self, caption, loc, exclude):
        """Retrieve candidate images for ``caption`` and download them.

        The caption is stripped and cut to 53 characters before querying.
        Files go to ``articles/<title>``; ``self.show_ids[loc] * 1000 + loc``
        is passed as the download index.

        Args:
            caption: Caption text to search with.
            loc: Section index the images are for.
            exclude: Retrieval indices to exclude from the results.

        Returns:
            The retrieval indices of the returned images, so callers can
            exclude them from later searches.
        """
        urls, idxs = get_urls(caption.strip()[:53], exclude)
        if urls:
            print(urls[0])
        print('download image with url')
        download_image_thread(urls,
                              folder='articles/' + self.title,
                              index=self.show_ids[loc] * 1000 + loc,
                              num_processes=4)
        print('image downloaded')
        return idxs

    def generate(self, text, random, beam, max_length, repetition):
        """Run text-only generation for the prompt ``text``.

        Args:
            text: Full prompt string.
            random: Passed as ``do_sample``.
            beam: Number of beams.
            max_length: ``max_length`` passed to ``generate``.
            repetition: Repetition penalty.

        Returns:
            The decoded output without its first token, truncated at the
            first ``<TOKENS_UNUSED_1>`` marker.
        """
        input_tokens = self.llm_model.internlm_tokenizer(
            text, return_tensors="pt",
            add_special_tokens=True).to(self.llm_model.device)
        img_embeds = self.llm_model.internlm_model.model.embed_tokens(
            input_tokens.input_ids)
        with torch.no_grad():
            with self.llm_model.maybe_autocast():
                outputs = self.llm_model.internlm_model.generate(
                    inputs_embeds=img_embeds,
                    stopping_criteria=self.stopping_criteria,
                    do_sample=random,
                    num_beams=beam,
                    max_length=max_length,
                    repetition_penalty=float(repetition),
                )
        output_text = self.llm_model.internlm_tokenizer.decode(
            outputs[0][1:], add_special_tokens=False)
        output_text = output_text.split('<TOKENS_UNUSED_1>')[0]
        return output_text

    def generate_text(self, title, beam, repetition, text_num, random):
        """Generate the article body for ``title`` (non-streaming).

        Returns:
            The generated article text.
        """
        # Prompt: "Write an illustrated, non-repetitive article for the
        # given title: <title>".
        text = ' <|User|>:根据给定标题写一个图文并茂,不重复的文章:{}\n'.format(
            title) + self.eoh + ' <|Bot|>:'
        print('random generate:{}'.format(random))
        output_text = self.generate(text, random, beam, text_num, repetition)
        return output_text

    def generate_loc(self, text_sections, image_num, progress):
        """Ask the model which sections should get an image.

        Args:
            text_sections: Article lines, each prefixed with ``<SegN>``.
            image_num: Number of images to request; inserted verbatim into
                the prompt, so an empty string omits the count.
            progress: Gradio progress tracker.

        Returns:
            ``(inject_text, locs)``: the model's answer and the section
            indices parsed from its ``<SegN>`` markers.
        """
        full_txt = ''.join(text_sections)
        # Prompt: 'Given the article "...", choose {image_num} lines suitable
        # for inserting images' / answer prefix 'The suitable lines are'.
        input_text = f' <|User|>:给定文章"{full_txt}" 根据上述文章,选择适合插入图像的{image_num}行' + ' \n<TOKENS_UNUSED_0> <|Bot|>:适合插入图像的行是'
        for _ in progress.tqdm([1], desc="image spotting"):
            output_text = self.generate(input_text,
                                        random=False,
                                        beam=5,
                                        max_length=300,
                                        repetition=1.)
        inject_text = '适合插入图像的行是' + output_text
        print(inject_text)
        # "<Seg12>" -> 12
        locs = [int(m[4:-1]) for m in self.r2.findall(inject_text)]
        print(locs)
        return inject_text, locs

    def generate_cap(self, text_sections, pos, progress):
        """Generate an image caption for each section index in ``pos``.

        Each prompt contains the article up to one section past the target
        and the captions already produced, so later captions can avoid
        repeating earlier ones.

        Returns:
            Mapping of section index -> caption text.
        """
        pasts = ''
        caps = {}
        for idx, po in progress.tqdm(enumerate(pos), desc="image captioning"):
            full_txt = ''.join(text_sections[:po + 2])
            if idx > 0:
                # Replace the trailing '", ' separator's ", " with a full stop.
                past = pasts[:-2] + '。'
            else:
                past = pasts
            # Prompt: 'Given the article "...", {past} give the caption of the
            # image to insert after <SegN>.' / answer prefix 'The caption is "'.
            input_text = f' <|User|>: 给定文章"{full_txt}" {past}给出适合在<Seg{po}>后插入的图像对应的标题。' + ' \n<TOKENS_UNUSED_0> <|Bot|>: 标题是"'
            cap_text = self.generate(input_text,
                                     random=False,
                                     beam=1,
                                     max_length=100,
                                     repetition=5.)
            cap_text = cap_text.split('"')[0].strip()
            print(cap_text)
            caps[po] = cap_text
            if idx == 0:
                pasts = f'现在<Seg{po}>后插入图像对应的标题是"{cap_text}", '
            else:
                pasts += f'<Seg{po}>后插入图像对应的标题是"{cap_text}", '
        print(caps)
        return caps

    def generate_loc_cap(self, text_sections, image_num, progress):
        """Pick image locations, then caption each one.

        Returns:
            Mapping of section index -> caption text.
        """
        _, locs = self.generate_loc(text_sections, image_num, progress)
        caps = self.generate_cap(text_sections, locs, progress)
        return caps

    def interleav_wrap(self, img_embeds, text):
        """Interleave text-token embeddings with image embeddings.

        ``text`` is a one-element list whose string has exactly one
        ``<ImageHere>`` placeholder per image in ``img_embeds``. The
        ``<Img>``/``</Img>`` wrappers are removed first. Each text part is
        embedded, and the matching image embedding is inserted where its
        placeholder was. Concatenation stops once the running length exceeds
        ``self.max_txt_len``.

        Assumes ``img_embeds`` is (num_images, tokens_per_image, C) -- TODO
        confirm against ``encode_img``.

        Returns:
            Tensor ``(1, L, C)`` truncated to at most ``self.max_txt_len``
            positions.
        """
        batch_size = img_embeds.shape[0]
        im_len = img_embeds.shape[1]
        text = text[0]
        text = text.replace('<Img>', '')
        text = text.replace('</Img>', '')
        parts = text.split('<ImageHere>')
        assert batch_size + 1 == len(parts)
        warp_embeds = []
        temp_len = 0
        for idx, part in enumerate(parts):
            if len(part) > 0:
                part_tokens = self.llm_model.internlm_tokenizer(
                    part, return_tensors="pt",
                    add_special_tokens=False).to(img_embeds.device)
                part_embeds = self.llm_model.internlm_model.model.embed_tokens(
                    part_tokens.input_ids)
                warp_embeds.append(part_embeds)
                temp_len += part_embeds.shape[1]
            if idx < batch_size:
                warp_embeds.append(img_embeds[idx].unsqueeze(0))
                temp_len += im_len
            if temp_len > self.max_txt_len:
                break
        warp_embeds = torch.cat(warp_embeds, dim=1)
        return warp_embeds[:, :self.max_txt_len].to(img_embeds.device)

    def align_text(self, samples):
        """Normalise the speaker tags in ``samples["text_input"]``.

        Each text gets ``eoa + ' </s>'`` appended. "###Human"-style tags are
        rewritten to ``' <|User|>:'`` and "###Assistant"-style tags to
        ``eoh + ' <|Bot|>:'``. If the first ``<|User|>`` comes after the
        first ``<|Bot|>``, ``eoa`` is also inserted before every user tag.

        Returns:
            List of normalised strings.
        """
        text_new = []
        text = [t + self.eoa + ' </s>' for t in samples["text_input"]]
        for i in range(len(text)):
            temp = text[i]
            temp = temp.replace('###Human', '<|User|>')
            temp = temp.replace('### Human', '<|User|>')
            temp = temp.replace('<|User|> :', '<|User|>:')
            temp = temp.replace('<|User|>: ', '<|User|>:')
            temp = temp.replace('<|User|>', ' <|User|>')
            temp = temp.replace('###Assistant', '<|Bot|>')
            temp = temp.replace('### Assistant', '<|Bot|>')
            temp = temp.replace('<|Bot|> :', '<|Bot|>:')
            temp = temp.replace('<|Bot|>: ', '<|Bot|>:')
            temp = temp.replace('<|Bot|>', self.eoh + ' <|Bot|>')
            if temp.find('<|User|>') > temp.find('<|Bot|>'):
                temp = temp.replace(' <|User|>', self.eoa + ' <|User|>')
            text_new.append(temp)
        return text_new

    def model_select_image(self, output_text, caps, root, progress):
        """Let the model choose one of four candidate images per captioned line.

        Walks the article line by line. At each line index in ``caps`` it
        loads ``temp_<show_id*1000+i>_<j>.png`` (j = 0..3) from ``root``. It
        then asks the model to answer A/B/C/D, given the text so far
        interleaved with the images already chosen. If the answer cannot be
        parsed, candidate A is used.

        Args:
            output_text: The article, one section per line.
            caps: Mapping of line index -> caption.
            root: Directory containing the downloaded candidates.
            progress: Unused; kept for interface compatibility.

        Returns:
            Mapping of line index -> chosen candidate index (0-3).
        """
        print('model_select_image')
        pre_text = ''
        pre_img = []
        pre_text_list = []
        ans2idx = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
        selected = {k: 0 for k in caps.keys()}
        for i, text in enumerate(output_text.split('\n')):
            pre_text += text + '\n'
            if i in caps:
                # Previously chosen images first, then this line's 4 candidates.
                images = copy.deepcopy(pre_img)
                for j in range(4):
                    image = Image.open(
                        os.path.join(
                            root, f'temp_{self.show_ids[i] * 1000 + i}_{j}.png'
                        )).convert("RGB")
                    image = self.vis_processor(image)
                    images.append(image)
                images = torch.stack(images, dim=0)
                pre_text_list.append(pre_text)
                pre_text = ''
                images = images.cuda()
                # Prompt: "Based on the context and candidate images, choose a
                # suitable picture:" ... "The most suitable image is".
                instruct = ' <|User|>:根据给定上下文和候选图像,选择合适的配图:'
                input_text = '<ImageHere>'.join(
                    pre_text_list
                ) + '\n\n候选图像包括: A.<ImageHere>\nB.<ImageHere>\nC.<ImageHere>\nD.<ImageHere>\n\n<TOKENS_UNUSED_0> <|Bot|>:最合适的图是'
                input_text = instruct + input_text
                samples = {}
                samples['text_input'] = [input_text]
                self.llm_model.debug_flag = 0
                with torch.no_grad():
                    with torch.cuda.amp.autocast():
                        img_embeds = self.llm_model.encode_img(images)
                    input_text = self.align_text(samples)
                    img_embeds = self.interleav_wrap(img_embeds, input_text)
                    bos = torch.ones(
                        [1, 1]) * self.llm_model.internlm_tokenizer.bos_token_id
                    bos = bos.long().to(images.device)
                    meta_embeds = self.llm_model.internlm_model.model.embed_tokens(
                        bos)
                    inputs_embeds = torch.cat([meta_embeds, img_embeds], dim=1)
                    with torch.cuda.amp.autocast():
                        outputs = self.llm_model.internlm_model.generate(
                            inputs_embeds=inputs_embeds[:, :-2],
                            do_sample=False,
                            num_beams=5,
                            max_length=10,
                            repetition_penalty=1.,
                        )
                out_text = self.llm_model.internlm_tokenizer.decode(
                    outputs[0][1:], add_special_tokens=False)
                try:
                    answer = out_text[1] if out_text[0] == ' ' else out_text[0]
                    pre_img.append(images[len(pre_img) + ans2idx[answer]].cpu())
                except (IndexError, KeyError):
                    # Empty/short output or an answer outside A-D.
                    print('Select fail, use first image')
                    answer = 'A'
                    pre_img.append(images[len(pre_img) + ans2idx[answer]].cpu())
                selected[i] = ans2idx[answer]
        return selected

    def show_md(self, text_sections, title, caps, selected, show_cap=False):
        """Build the Gradio updates that render the article section by section.

        Sections in ``caps`` get their chosen image (``selected[i]``) appended
        as a centred ``<img>``. When ``show_cap`` is set, the caption is shown
        too. These sections also get a visible gallery of the four
        candidates, a delete button, and the caption widgets. Other sections
        show plain text and a "➕" button. Unused slots up to
        ``max_section`` are hidden.

        Returns:
            ``(updates, md_shows)``. ``updates`` is the flat list ordered as
            markdowns, galleries, buttons, caption boxes, search buttons,
            editors. ``md_shows`` holds the Markdown updates of the visible
            sections.
        """
        md_shows = []
        ga_shows = []
        btn_shows = []
        cap_textboxs, cap_searchs = [], []
        editers = []
        for i in range(len(text_sections)):
            if i in caps:
                img_id = self.show_ids[i] * 1000 + i
                # Both branches use Gradio's "file=" route so the image is
                # actually served (and replace_image can find it).
                if show_cap:
                    md = text_sections[
                        i] + '\n' + '<div align="center"> <img src="file=articles/{}/temp_{}_{}.png" width = 500/> {} </div>'.format(
                            title, img_id, selected[i], caps[i])
                else:
                    md = text_sections[
                        i] + '\n' + '<div align="center"> <img src="file=articles/{}/temp_{}_{}.png" width = 500/> </div>'.format(
                            title, img_id, selected[i])
                img_list = [('articles/{}/temp_{}_{}.png'.format(title, img_id, j),
                             'articles/{}/temp_{}_{}.png'.format(title, img_id, j))
                            for j in range(4)]
                ga_show = gr.Gallery.update(visible=True, value=img_list)
                ga_shows.append(ga_show)
                btn_show = gr.Button.update(visible=True,
                                            value='\U0001f5d1\uFE0F')
                cap_textboxs.append(
                    gr.Textbox.update(visible=True, value=caps[i]))
                cap_searchs.append(gr.Button.update(visible=True))
            else:
                md = text_sections[i]
                ga_show = gr.Gallery.update(visible=False, value=[])
                ga_shows.append(ga_show)
                btn_show = gr.Button.update(visible=True, value='\u2795')
                cap_textboxs.append(gr.Textbox.update(visible=False))
                cap_searchs.append(gr.Button.update(visible=False))
            md_show = gr.Markdown.update(visible=True, value=md)
            md_shows.append(md_show)
            btn_shows.append(btn_show)
            editers.append(gr.update(visible=True))
            print(i, md)
        md_hides = []
        ga_hides = []
        btn_hides = []
        for _ in range(max_section - len(text_sections)):
            md_hides.append(gr.Markdown.update(visible=False, value=''))
            btn_hides.append(gr.Button.update(visible=False))
            editers.append(gr.update(visible=False))
        for _ in range(max_section - len(ga_shows)):
            ga_hides.append(gr.Gallery.update(visible=False, value=[]))
            cap_textboxs.append(gr.Textbox.update(visible=False))
            cap_searchs.append(gr.Button.update(visible=False))
        return md_shows + md_hides + ga_shows + ga_hides + btn_shows + btn_hides + cap_textboxs + cap_searchs + editers, md_shows

    def generate_article(self,
                         title,
                         beam,
                         repetition,
                         text_num,
                         msi,
                         random,
                         progress=gr.Progress()):
        """Generate an illustrated article for ``title`` (Gradio generator).

        Steps:
          1. Generate the text, streamed when ``article_stream_output``.
          2. Split it into at most ``max_section`` lines.
          3. Let the model pick image locations and captions.
          4. Download candidate images.
          5. Optionally let the model pick among the candidates (``msi``).

        Yields:
            While streaming: the partial text in the first Markdown slot, with
            all other widgets hidden and both action buttons disabled.
            Finally, always: the ``show_md`` updates followed by two
            ``enable_btn`` updates. This function contains ``yield``, so it is
            a generator even on the non-streaming path; a ``return value``
            there would never reach Gradio.
        """
        self.reset()
        self.title = title
        if article_stream_output:
            text = ' <|User|>:根据给定标题写一个图文并茂,不重复的文章:{}\n'.format(
                title) + self.eoh + ' <|Bot|>:'
            input_tokens = self.llm_model.internlm_tokenizer(
                text, return_tensors="pt",
                add_special_tokens=True).to(self.llm_model.device)
            img_embeds = self.llm_model.internlm_model.model.embed_tokens(
                input_tokens.input_ids)
            generate_params = dict(
                inputs_embeds=img_embeds,
                num_beams=beam,
                do_sample=random,
                stopping_criteria=self.stopping_criteria,
                repetition_penalty=float(repetition),
                max_length=text_num,
                bos_token_id=self.llm_model.internlm_tokenizer.bos_token_id,
                eos_token_id=self.llm_model.internlm_tokenizer.eos_token_id,
                pad_token_id=self.llm_model.internlm_tokenizer.pad_token_id,
            )
            output_text = "▌"
            with self.generate_with_streaming(**generate_params) as generator:
                for output in generator:
                    decoded_output = self.llm_model.internlm_tokenizer.decode(
                        output[1:])
                    if output[-1] in [
                            self.llm_model.internlm_tokenizer.eos_token_id
                    ]:
                        break
                    output_text = decoded_output.replace('\n', '\n\n') + "▌"
                    yield (output_text,) + (gr.Markdown.update(visible=False),) * (max_section - 1) + (gr.Gallery.update(visible=False),) * max_section + \
                        (gr.Button.update(visible=False),) * max_section + (gr.Textbox.update(visible=False),) * max_section + (gr.Button.update(visible=False),) * max_section + \
                        (gr.update(visible=False),) * max_section + (disable_btn,) * 2
                    time.sleep(0.03)
            # Drop the trailing cursor.
            output_text = output_text[:-1]
            yield (output_text,) + (gr.Markdown.update(visible=False),) * (max_section - 1) + (gr.Gallery.update(visible=False),) * max_section + \
                (gr.Button.update(visible=False),) * max_section + (gr.Textbox.update(visible=False),) * max_section + (gr.Button.update(visible=False),) * max_section +\
                (gr.update(visible=False),) * max_section + (disable_btn,) * 2
        else:
            output_text = self.generate_text(title, beam, repetition, text_num,
                                             random)
        print(output_text)
        # Collapse blank / whitespace-only lines so each line is one section.
        output_text = re.sub(r'(\n[ \t]*)+', '\n', output_text)
        if output_text.endswith('\n'):
            output_text = output_text[:-1]
        print(output_text)
        output_text = '\n'.join(output_text.split('\n')[:max_section])
        text_sections = output_text.split('\n')
        idx_text_sections = [
            f'<Seg{i}>' + ' ' + it + '\n' for i, it in enumerate(text_sections)
        ]
        # An empty image_num leaves the image count out of the prompt.
        caps = self.generate_loc_cap(idx_text_sections, '', progress)
        self.show_ids = {k: 1 for k in caps.keys()}
        print(caps)
        self.ex_idxs = []
        for loc, cap in progress.tqdm(caps.items(), desc="download image"):
            idxs = self.get_images_xlab(cap, loc, self.ex_idxs)
            self.ex_idxs.extend(idxs)
        if msi:
            self.selected = self.model_select_image(output_text, caps,
                                                    'articles/' + title,
                                                    progress)
        else:
            self.selected = {k: 0 for k in caps.keys()}
        components, _ = self.show_md(text_sections, title, caps,
                                     self.selected)
        self.show_caps = False
        self.output_text = output_text
        self.caps = caps
        yield components + [enable_btn] * 2

    def adjust_img(self, img_num, progress=gr.Progress()):
        """Re-place and re-caption images using a fixed image count.

        ``img_num`` is capped at the number of sections. Fresh candidates are
        downloaded for every chosen location, and the first candidate is
        selected for each.

        Returns:
            The ``show_md`` update list.
        """
        text_sections = self.output_text.split('\n')
        idx_text_sections = [
            f'<Seg{i}>' + ' ' + it + '\n' for i, it in enumerate(text_sections)
        ]
        img_num = min(img_num, len(text_sections))
        caps = self.generate_loc_cap(idx_text_sections, int(img_num), progress)
        print(caps)
        sidxs = []
        for loc, cap in caps.items():
            # New download generation for this section -> new file names.
            self.show_ids[loc] = self.show_ids.get(loc, 0) + 1
            idxs = self.get_images_xlab(cap, loc, sidxs)
            sidxs.extend(idxs)
        self.sidxs = sidxs
        self.selected = {k: 0 for k in caps.keys()}
        components, _ = self.show_md(text_sections, self.title, caps,
                                     self.selected)
        self.caps = caps
        return components

    def add_delete_image(self, text, status, index):
        """Toggle the image slot of section ``index``.

        ``status`` is the section button's current label. The trash icon
        forgets the section's caption and selection and keeps only the first
        Markdown line, dropping the appended image. Any other label (the
        "➕" button) reveals an empty gallery plus the caption box and search
        button.

        Returns:
            ``(markdown, gallery, button, caption_textbox, caption_search)``
            updates.
        """
        index = int(index)
        if status == '\U0001f5d1\uFE0F':
            self.caps.pop(index, None)
            self.selected.pop(index, None)
            md_show = gr.Markdown.update(value=text.split('\n')[0])
            gallery = gr.Gallery.update(visible=False, value=[])
            btn_show = gr.Button.update(value='\u2795')
            cap_textbox = gr.Textbox.update(visible=False)
            cap_search = gr.Button.update(visible=False)
        else:
            md_show = gr.Markdown.update()
            gallery = gr.Gallery.update(visible=True, value=[])
            btn_show = gr.Button.update(value='\U0001f5d1\uFE0F')
            cap_textbox = gr.Textbox.update(visible=True)
            cap_search = gr.Button.update(visible=True)
        return md_show, gallery, btn_show, cap_textbox, cap_search

    def search_image(self, text, index):
        """Download new candidates for section ``index`` using caption ``text``.

        Records ``text`` as the section's caption. If the section has no
        selection yet, the first candidate is selected, so that ``save`` and
        ``show_md`` do not hit a missing ``selected`` entry.

        Returns:
            A gallery update with the four new candidates (or no change when
            ``text`` is empty).
        """
        index = int(index)
        if text == '':
            return gr.Gallery.update()
        self.show_ids[index] = self.show_ids.get(index, 0) + 1
        self.caps[index] = text
        self.selected.setdefault(index, 0)
        idxs = self.get_images_xlab(text, index, self.ex_idxs)
        self.ex_idxs.extend(idxs)
        img_id = self.show_ids[index] * 1000 + index
        img_list = [('articles/{}/temp_{}_{}.png'.format(self.title, img_id, j),
                     'articles/{}/temp_{}_{}.png'.format(self.title, img_id, j))
                    for j in range(4)]
        ga_show = gr.Gallery.update(visible=True, value=img_list)
        return ga_show

    def replace_image(self, article, index, evt: gr.SelectData):
        """Show the gallery pick ``evt`` as the image of section ``index``.

        Records ``evt.index`` as the selected candidate. If the section
        already has an image, its ``file=...png`` source is rewritten;
        otherwise a centred ``<img>`` tag is appended.

        Returns:
            The updated Markdown text.
        """
        index = int(index)
        self.selected[index] = evt.index
        new_src = 'file={}'.format(evt.value)
        if '<div align="center">' in article:
            # Stay inside the src="..." attribute, and use a callable so
            # backslashes in the path are not treated as regex escapes.
            return re.sub(r'file=[^"]*\.png', lambda _m: new_src, article)
        return article + '\n' + '<div align="center"> <img src="file={}" width = 500/> </div>'.format(
            evt.value)

    def add_delete_caption(self):
        """Toggle caption display under images and re-render the article.

        Returns:
            The ``show_md`` update list.
        """
        self.show_caps = not self.show_caps
        text_sections = self.output_text.split('\n')
        components, _ = self.show_md(text_sections,
                                     self.title,
                                     self.caps,
                                     selected=self.selected,
                                     show_cap=self.show_caps)
        return components

    def save(self):
        """Export the article as Markdown plus its images, then zip it.

        Empties or creates ``save_articles/<title>`` and writes ``io.MD``
        there. Images are referenced by bare file name, with captions when
        ``self.show_caps`` is set. Each selected image is copied next to the
        Markdown, and the folder is archived.

        Returns:
            Path of the created zip archive.
        """
        folder = 'save_articles/' + self.title
        if os.path.exists(folder):
            for item in os.listdir(folder):
                os.remove(os.path.join(folder, item))
        os.makedirs(folder, exist_ok=True)
        save_text = ''
        if len(self.output_text) > 0:
            text_sections = self.output_text.split('\n')
            for i in range(len(text_sections)):
                if i in self.caps:
                    img_id = self.show_ids[i] * 1000 + i
                    if self.show_caps:
                        md = text_sections[
                            i] + '\n' + '<div align="center"> <img src="temp_{}_{}.png" width = 500/> {} </div>'.format(
                                img_id, self.selected[i], self.caps[i])
                    else:
                        md = text_sections[
                            i] + '\n' + '<div align="center"> <img src="temp_{}_{}.png" width = 500/> </div>'.format(
                                img_id, self.selected[i])
                else:
                    md = text_sections[i]
                save_text += md + '\n\n'
            save_text = save_text[:-2]
        # Explicit UTF-8: the article may contain non-ASCII (e.g. Chinese)
        # text, and the platform default encoding may not support it.
        with open(os.path.join(folder, 'io.MD'), 'w', encoding='utf-8') as f:
            f.write(save_text)
        for k in self.caps.keys():
            shutil.copy(
                os.path.join(
                    'articles', self.title,
                    f'temp_{self.show_ids[k] * 1000 + k}_{self.selected[k]}.png'
                ), folder)
        archived = shutil.make_archive(folder, 'zip', folder)
        return archived

    def get_context_emb(self, state, img_list):
        """Embed the conversation prompt, splicing in image embeddings.

        Each ``<Img><ImageHere></Img>`` placeholder in ``state.get_prompt()``
        is replaced, in order, by the matching entry of ``img_list``. Only the
        first text segment gets the tokenizer's special tokens.

        Returns:
            The concatenated embeddings (concatenated along dim 1).
        """
        prompt = state.get_prompt()
        print(prompt)
        prompt_segs = prompt.split('<Img><ImageHere></Img>')
        assert len(prompt_segs) == len(
            img_list
        ) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.llm_model.internlm_tokenizer(seg,
                                              return_tensors="pt",
                                              add_special_tokens=i == 0).to(
                                                  self.device).input_ids
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [
            self.llm_model.internlm_model.model.embed_tokens(seg_t)
            for seg_t in seg_tokens
        ]
        mixed_embs = [
            emb for pair in zip(seg_embs[:-1], img_list) for emb in pair
        ] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs

    def chat_ask(self, state, img_list, text, image):
        """Append the user's text and/or image to the conversation.

        If there is neither text nor image, the next answer step is skipped.
        An uploaded image is encoded and added to ``img_list``. Text that
        follows an image message from the same user is merged into that
        message.

        Returns:
            ``(state, img_list, chatbot, "", None, btn, btn)``.
        """
        state.skip_next = False
        if len(text) <= 0 and image is None:
            state.skip_next = True
            return (state, img_list, state.to_gradio_chatbot(), "",
                    None) + (no_change_btn, ) * 2
        if image is not None:
            image_pt = self.vis_processor(image).unsqueeze(0).to(0)
            image_emb = self.llm_model.encode_img(image_pt)
            img_list.append(image_emb)
            state.append_message(state.roles[0],
                                 ["<Img><ImageHere></Img>", image])
        if len(state.messages) > 0 and state.messages[-1][0] == state.roles[
                0] and isinstance(state.messages[-1][1], list):
            state.messages[-1][1][0] = ' '.join(
                [state.messages[-1][1][0], text])
        else:
            state.append_message(state.roles[0], text)
        print(state.messages)
        state.append_message(state.roles[1], None)
        return (state, img_list, state.to_gradio_chatbot(), "",
                None) + (disable_btn, ) * 2

    def generate_with_callback(self, callback=None, **kwargs):
        """Run ``internlm_model.generate`` and report progress via ``callback``.

        A ``Stream`` criterion wrapping ``callback`` is added to a copy of the
        caller's stopping criteria. Callers pass the shared
        ``self.stopping_criteria``; appending to it directly would leave one
        stale ``Stream`` in that list per streamed generation, and those stale
        streams would be invoked by every later ``generate`` call.
        """
        criteria = transformers.StoppingCriteriaList(
            kwargs.get("stopping_criteria") or [])
        criteria.append(Stream(callback_func=callback))
        kwargs["stopping_criteria"] = criteria
        with torch.no_grad():
            with self.llm_model.maybe_autocast():
                self.llm_model.internlm_model.generate(**kwargs)

    def generate_with_streaming(self, **kwargs):
        """Wrap ``generate_with_callback`` in an ``Iteratorize`` iterator.

        Used as a context manager. The items it yields are presumably the
        token-id sequences passed to the ``Stream`` callback -- confirm
        against ``demo_asset.serve_utils``.
        """
        return Iteratorize(self.generate_with_callback, kwargs, callback=None)

    def chat_answer(self, state, img_list, max_output_tokens,
                    repetition_penalty, num_beams, do_sample):
        """Generate the assistant reply for the conversation (Gradio generator).

        Yields:
            ``(state, chatbot, btn, btn)`` tuples. When streaming, buttons
            are disabled during generation and enabled at the end.
            Non-stream and skip paths also *yield* their single result,
            because a ``return value`` inside a generator never reaches
            Gradio.
        """
        if state.skip_next:
            yield (state, state.to_gradio_chatbot()) + (no_change_btn, ) * 2
            return
        embs = self.get_context_emb(state, img_list)
        if chat_stream_output:
            generate_params = dict(
                inputs_embeds=embs,
                num_beams=num_beams,
                do_sample=do_sample,
                stopping_criteria=self.stopping_criteria,
                repetition_penalty=float(repetition_penalty),
                max_length=max_output_tokens,
                bos_token_id=self.llm_model.internlm_tokenizer.bos_token_id,
                eos_token_id=self.llm_model.internlm_tokenizer.eos_token_id,
                pad_token_id=self.llm_model.internlm_tokenizer.pad_token_id,
            )
            state.messages[-1][-1] = "▌"
            with self.generate_with_streaming(**generate_params) as generator:
                for output in generator:
                    decoded_output = self.llm_model.internlm_tokenizer.decode(
                        output[1:])
                    if output[-1] in [
                            self.llm_model.internlm_tokenizer.eos_token_id, 333, 497
                    ]:
                        break
                    state.messages[-1][-1] = decoded_output + "▌"
                    yield (state,
                           state.to_gradio_chatbot()) + (disable_btn, ) * 2
                    time.sleep(0.03)
            # Drop the trailing cursor.
            state.messages[-1][-1] = state.messages[-1][-1][:-1]
            yield (state, state.to_gradio_chatbot()) + (enable_btn, ) * 2
            return
        outputs = self.llm_model.internlm_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_output_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=do_sample,
            repetition_penalty=float(repetition_penalty),
            bos_token_id=self.llm_model.internlm_tokenizer.bos_token_id,
            eos_token_id=self.llm_model.internlm_tokenizer.eos_token_id,
            pad_token_id=self.llm_model.internlm_tokenizer.pad_token_id,
        )
        output_token = outputs[0]
        if output_token[0] == 0:
            output_token = output_token[1:]
        output_text = self.llm_model.internlm_tokenizer.decode(
            output_token, add_special_tokens=False)
        print(output_text)
        # Cut at the stop marker.
        output_text = output_text.split('<TOKENS_UNUSED_1>')[0]
        output_text = output_text.split('Assistant:')[-1].strip()
        output_text = output_text.replace("<s>", "")
        state.messages[-1][1] = output_text
        yield (state, state.to_gradio_chatbot()) + (enable_btn, ) * 2

    def clear_answer(self, state):
        """Blank the last (assistant) message, e.g. before regenerating.

        Returns:
            ``(state, chatbot)``.
        """
        state.messages[-1][-1] = None
        return (state, state.to_gradio_chatbot())

    def chat_clear_history(self):
        """Start a fresh conversation.

        Returns:
            ``(state, [], chatbot, "", None, btn, btn)`` with buttons disabled.
        """
        state = CONV_VISION_7132_v2.copy()
        return (state, [], state.to_gradio_chatbot(), "",
                None) + (disable_btn, ) * 2
def load_demo():
    """Start a fresh conversation and reveal the chat widgets.

    Returns:
        ``(state, img_list, chatbot, textbox, button, row, accordion)``:
        a copy of ``CONV_VISION_7132_v2``, an empty image list, and five
        ``visible=True`` updates.
    """
    conversation = CONV_VISION_7132_v2.copy()
    show = dict(visible=True)
    return (
        conversation,
        [],
        gr.Chatbot.update(**show),
        gr.Textbox.update(**show),
        gr.Button.update(**show),
        gr.Row.update(**show),
        gr.Accordion.update(**show),
    )
def change_language(lang):
    """Return Gradio updates that relabel the whole UI in another language.

    ``lang`` is the current caption of the language toggle: ``'中文'``
    switches every label to Chinese, and ``'English'`` switches to English.

    Returns:
        A flat list of updates in this order:
          - language button, title box, submit button;
          - article settings accordion, beam, repetition, max tokens,
            model-select, sampling;
          - image-count slider, adjust button;
          - ``max_section`` search buttons, then ``max_section`` editors;
          - save button, save file;
          - chat settings accordion, max tokens, beam, repetition, sampling;
          - chat textbox, submit, regenerate, clear.
    """
    if lang == '中文':
        words = {
            'lang_btn': 'English',
            'title': '根据给定标题写一个图文并茂的文章:',
            'btn': '生成',
            'parameter_article': '高级设置',
            'beam': '集束大小',
            'repetition': '重复惩罚',
            'text_num': '最多输出字数',
            'msi': '模型选图',
            'random': '采样',
            'img_num': '生成文章后,可选择全文配图数量',
            'adjust_btn': '固定数量配图',
            'cap_search': '搜索',
            'editer': '编辑',
            'save': '文章下载',
            'parameter_chat': '参数',
            'chat_text_num': '最多输出字数',
            'chat_beam': '集束大小',
            'chat_repetition': '重复惩罚',
            'chat_random': '采样',
            'chat_textbox': '输入聊天内容并回车',
            'submit_btn': '提交',
            'regenerate_btn': '🔄 重新生成',
            'clear_btn': '🗑️ 清空聊天框',
        }
    elif lang == 'English':
        words = {
            'lang_btn': '中文',
            'title': 'Write an illustrated article based on the given title:',
            'btn': 'Submit',
            'parameter_article': 'Advanced Settings',
            'beam': 'Beam Size',
            'repetition': 'Repetition_penalty',
            'text_num': 'Max output tokens',
            'msi': 'Model selects images',
            'random': 'Do_sample',
            'img_num':
            'Select the number of the inserted image after article generation.',
            'adjust_btn': 'Insert a fixed number of images',
            'cap_search': 'Search',
            'editer': 'edit',
            'save': 'Save article',
            'parameter_chat': 'Parameters',
            'chat_text_num': 'Max output tokens',
            'chat_beam': 'Beam Size',
            'chat_repetition': 'Repetition_penalty',
            'chat_random': 'Do_sample',
            'chat_textbox': 'Enter text and press ENTER',
            'submit_btn': 'Submit',
            'regenerate_btn': '🔄 Regenerate',
            'clear_btn': '🗑️ Clear history',
        }

    article_part = [
        gr.update(value=words['lang_btn']),
        gr.update(label=words['title']),
        gr.update(value=words['btn']),
    ]
    article_part.extend(
        gr.update(label=words[key])
        for key in ('parameter_article', 'beam', 'repetition', 'text_num',
                    'msi', 'random', 'img_num'))
    article_part.append(gr.update(value=words['adjust_btn']))

    per_section = ([gr.update(value=words['cap_search'])] * 0 +
                   [gr.update(value=words['cap_search']) for _ in range(max_section)] +
                   [gr.update(label=words['editer']) for _ in range(max_section)])

    saving = [gr.update(value=words['save']), gr.update(label=words['save'])]

    chat_part = [
        gr.update(label=words[key])
        for key in ('parameter_chat', 'chat_text_num', 'chat_beam',
                    'chat_repetition', 'chat_random')
    ]
    chat_part += [
        gr.update(placeholder=words['chat_textbox']),
        gr.update(value=words['submit_btn']),
        gr.update(value=words['regenerate_btn']),
        gr.update(value=words['clear_btn']),
    ]
    return article_part + per_section + saving + chat_part
# Command-line options: the model location (local path or hub id) and a
# --private switch, presumably read by the launch code later in the file.
parser = argparse.ArgumentParser()
parser.add_argument("--folder", default='internlm/internlm-xcomposer-7b')
parser.add_argument("--private", action='store_true', default=False)
args = parser.parse_args()

# Load the model once; every Gradio callback shares this instance.
demo_ui = Demo_UI(args.folder)
# ---------------------------------------------------------------------------
# UI layout and event wiring.
# Two tabs: (1) interleaved text-image article writing, (2) multimodal chat.
# The language toggle (`lang_btn`) relabels components via `change_language`;
# the `outputs` list given to it at the bottom must stay in the same order as
# the list that function returns.
# ---------------------------------------------------------------------------
with gr.Blocks(css=custom_css, title='浦语·灵笔 (InternLM-XComposer)') as demo:
    # Header: centered logo on the left, language toggle on the right.
    with gr.Row():
        with gr.Column(scale=20):
            gr.HTML(
                """<h1 align="center"><img src="https://raw.githubusercontent.com/panzhang0212/interleaved_io/main/logo.png" alt="InternLM-XComposer" border="0" style="margin: 0 auto; height: 200px;" /></h1>"""
            )
        with gr.Column(scale=1, min_width=100):
            lang_btn = gr.Button("中文")
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        # ---- Tab 1: illustrated-article generation ---------------------
        with gr.TabItem("📝 创作图文并茂文章 (Write Interleaved-text-image Article)"):
            with gr.Row():
                title = gr.Textbox(
                    label=
                    'Write an illustrated article based on the given title:',
                    scale=2)
                btn = gr.Button("Submit", scale=1)
            with gr.Row():
                # Image count consumed by `adjust_btn`. The button starts
                # non-interactive; it is an output of `generate_article`,
                # which presumably re-enables it — confirm in Demo_UI.
                img_num = gr.Slider(
                    minimum=1.0,
                    maximum=30.0,
                    value=5.0,
                    step=1.0,
                    scale=2,
                    label=
                    'Select the number of the inserted image after article generation.'
                )
                adjust_btn = gr.Button('Insert a fixed number of images',
                                       interactive=False,
                                       scale=1)
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Accordion("Advanced Settings",
                                      open=False,
                                      visible=True) as parameter_article:
                        beam = gr.Slider(minimum=1.0,
                                         maximum=6.0,
                                         value=5.0,
                                         step=1.0,
                                         label='Beam Size')
                        # NOTE(review): minimum is 0.0 here, but the chat tab's
                        # slider starts at 1. A penalty of 0 is typically
                        # rejected by HF generation — confirm how
                        # generate_article uses this value.
                        repetition = gr.Slider(minimum=0.0,
                                               maximum=10.0,
                                               value=5.0,
                                               step=0.1,
                                               label='Repetition_penalty')
                        text_num = gr.Slider(minimum=100.0,
                                             maximum=2000.0,
                                             value=1000.0,
                                             step=1.0,
                                             label='Max output tokens')
                        msi = gr.Checkbox(value=True,
                                          label='Model selects images')
                        # NOTE(review): this module-level name shadows the
                        # stdlib `random` module name.
                        random = gr.Checkbox(label='Do_sample')
                with gr.Column(scale=1):
                    gr.Examples(
                        examples=[["又见敦煌"], ["星链新闻稿"], ["如何养好一只宠物"],
                                  ["Shanghai Travel Guide in English"], ["Travel guidance of London in English"], ["Advertising for Genshin Impact in English"]],
                        inputs=[title],
                    )
            # Per-section component lists. Each ends up with `max_section`
            # entries, indexed by section number.
            articles = []
            gallerys = []
            add_delete_btns = []
            cap_textboxs = []
            cap_searchs = []
            editers = []
            with gr.Column():
                for i in range(max_section):
                    with gr.Row():
                        # Only the first section's text starts visible.
                        visible = i == 0
                        with gr.Column(scale=2):
                            article = gr.Markdown(visible=visible,
                                                  elem_classes='feedback')
                            articles.append(article)
                        with gr.Column(scale=1):
                            with gr.Accordion('edit',
                                              open=False,
                                              visible=False) as editer:
                                with gr.Row():
                                    cap_textbox = gr.Textbox(show_label=False,
                                                             interactive=True,
                                                             scale=6,
                                                             visible=False)
                                    cap_search = gr.Button(value="Search",
                                                           visible=False,
                                                           scale=1)
                                with gr.Row():
                                    gallery = gr.Gallery(visible=False,
                                                         columns=2,
                                                         height='auto')
                                add_delete_btn = gr.Button(visible=False)
                            # A hidden Number carries this section's index
                            # into the callback.
                            gallery.select(demo_ui.replace_image, [
                                articles[i],
                                gr.Number(value=i, visible=False)
                            ], articles[i])
                            gallerys.append(gallery)
                            add_delete_btns.append(add_delete_btn)
                            cap_textboxs.append(cap_textbox)
                            cap_searchs.append(cap_search)
                            editers.append(editer)
                save_btn = gr.Button("Save article")
                save_file = gr.File(label="Save article")
                # Per-section handlers. Each is wired with a hidden Number
                # holding its section index.
                for i in range(max_section):
                    add_delete_btns[i].click(demo_ui.add_delete_image,
                                             inputs=[
                                                 articles[i],
                                                 add_delete_btns[i],
                                                 gr.Number(value=i,
                                                           visible=False)
                                             ],
                                             outputs=[
                                                 articles[i], gallerys[i],
                                                 add_delete_btns[i],
                                                 cap_textboxs[i],
                                                 cap_searchs[i]
                                             ])
                    cap_searchs[i].click(demo_ui.search_image,
                                         inputs=[
                                             cap_textboxs[i],
                                             gr.Number(value=i, visible=False)
                                         ],
                                         outputs=gallerys[i])
                btn.click(
                    demo_ui.generate_article,
                    inputs=[title, beam, repetition, text_num, msi, random],
                    outputs=articles + gallerys + add_delete_btns +
                    cap_textboxs + cap_searchs + editers + [btn, adjust_btn])
                save_btn.click(demo_ui.save, inputs=None, outputs=save_file)
                adjust_btn.click(demo_ui.adjust_img,
                                 inputs=img_num,
                                 outputs=articles + gallerys +
                                 add_delete_btns + cap_textboxs + cap_searchs +
                                 editers)
        # ---- Tab 2: multimodal chat ------------------------------------
        with gr.TabItem("💬 多模态对话 (Multimodal Chat)", elem_id="chat", id=0):
            chat_state = gr.State()
            img_list = gr.State()
            with gr.Row():
                with gr.Column(scale=3):
                    imagebox = gr.Image(type="pil")
                    # Hidden at build time; `parameter_row` is an output of
                    # `load_demo`, which presumably reveals it on page load.
                    with gr.Accordion("Parameters", open=True,
                                      visible=False) as parameter_row:
                        chat_max_output_tokens = gr.Slider(
                            minimum=0,
                            maximum=1024,
                            value=512,
                            step=64,
                            interactive=True,
                            label="Max output tokens",
                        )
                        chat_num_beams = gr.Slider(
                            minimum=1,
                            maximum=5,
                            value=3,
                            step=1,
                            interactive=True,
                            label="Beam Size",
                        )
                        chat_repetition_penalty = gr.Slider(
                            minimum=1,
                            maximum=5,
                            value=1,
                            step=0.1,
                            interactive=True,
                            label="Repetition_penalty",
                        )
                        chat_do_sample = gr.Checkbox(interactive=True,
                                                     value=True,
                                                     label="Do_sample")
                with gr.Column(scale=6):
                    chatbot = grChatbot(elem_id="chatbot",
                                        visible=False,
                                        height=750)
                    with gr.Row():
                        with gr.Column(scale=8):
                            # NOTE(review): `.style()` is deprecated in later
                            # Gradio 3.x releases in favour of constructor
                            # kwargs (`container=False`); switch once the
                            # pinned Gradio version is confirmed.
                            chat_textbox = gr.Textbox(
                                show_label=False,
                                placeholder="Enter text and press ENTER",
                                visible=False).style(container=False)
                        with gr.Column(scale=1, min_width=60):
                            submit_btn = gr.Button(value="Submit",
                                                   visible=False)
                    with gr.Row(visible=True) as button_row:
                        regenerate_btn = gr.Button(value="🔄 Regenerate",
                                                   interactive=False)
                        clear_btn = gr.Button(value="🗑️ Clear history",
                                              interactive=False)
            btn_list = [regenerate_btn, clear_btn]
            # Argument order expected by `chat_answer` after (state, img_list).
            parameter_list = [
                chat_max_output_tokens, chat_repetition_penalty,
                chat_num_beams, chat_do_sample
            ]
            # Pressing ENTER and clicking Submit run the same two-step chain:
            # record the question, then stream the answer.
            for chat_trigger in (chat_textbox.submit, submit_btn.click):
                chat_trigger(
                    demo_ui.chat_ask,
                    [chat_state, img_list, chat_textbox, imagebox],
                    [chat_state, img_list, chatbot, chat_textbox, imagebox] +
                    btn_list).then(demo_ui.chat_answer,
                                   [chat_state, img_list] + parameter_list,
                                   [chat_state, chatbot] + btn_list)
            regenerate_btn.click(demo_ui.clear_answer, chat_state,
                                 [chat_state, chatbot]).then(
                                     demo_ui.chat_answer,
                                     [chat_state, img_list] + parameter_list,
                                     [chat_state, chatbot] + btn_list)
            clear_btn.click(
                demo_ui.chat_clear_history, None,
                [chat_state, img_list, chatbot, chat_textbox, imagebox] +
                btn_list)
    demo.load(load_demo, None, [
        chat_state, img_list, chatbot, chat_textbox, submit_btn,
        parameter_row
    ])
    # The order must match the list returned by `change_language`.
    lang_btn.click(change_language, inputs=lang_btn, outputs=[lang_btn, title, btn, parameter_article] +\
        [beam, repetition, text_num, msi, random, img_num, adjust_btn] + cap_searchs + editers +\
        [save_btn, save_file] + [parameter_row, chat_max_output_tokens, chat_num_beams, chat_repetition_penalty, chat_do_sample] +\
        [chat_textbox, submit_btn, regenerate_btn, clear_btn])
demo.queue(concurrency_count=8, status_update_rate=10, api_open=False)
if __name__ == "__main__":
    demo.launch()