# NOTE: removed Hugging Face file-viewer residue ("Willow123's picture",
# "Upload 12 files", commit hash, "raw / history blame", "49.7 kB") that was
# captured along with the source and is not part of the program.
# Standard-library / third-party imports. The sys.path tweaks let the demo
# find the sibling `demo_asset` package whether it is launched from the repo
# root or from inside this directory.
import os
import re
import sys
sys.path.insert(0, '.')
sys.path.insert(0, '..')
import argparse
import gradio as gr
# Keep gradio's temporary upload files inside the working directory.
os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.getcwd(), 'tmp')
import copy
import time
import shutil
import requests
from PIL import Image, ImageFile
import torch
import transformers
from transformers import StoppingCriteriaList, AutoTokenizer, AutoModel
# Tolerate partially-downloaded image files instead of raising on open().
ImageFile.LOAD_TRUNCATED_IMAGES = True
from demo_asset.assets.css_html_js import custom_css
from demo_asset.gradio_patch import Chatbot as grChatbot
from demo_asset.serve_utils import Stream, Iteratorize
from demo_asset.conversation import CONV_VISION_7132_v2, StoppingCriteriaSub
from demo_asset.download import download_image_thread
# Maximum number of article sections (rows of widgets) the UI pre-allocates.
max_section = 60
# Reusable gradio button-state updates shared by the event handlers below.
no_change_btn = gr.Button.update()
disable_btn = gr.Button.update(interactive=False)
enable_btn = gr.Button.update(interactive=True)
# Stream tokens to the UI incrementally instead of returning text in one shot.
chat_stream_output = True
article_stream_output = True
def get_urls(caption, exclude, timeout=60):
    """Query the openxlab similar-image service for images matching *caption*.

    Args:
        caption: text query sent to the search endpoint.
        exclude: previously-returned image indices the service should skip.
        timeout: seconds to wait for the HTTP response (new parameter with a
            default, so existing callers are unaffected; previously the call
            could block forever).

    Returns:
        (urls, idx): list of image URLs and their service-side indices.

    Raises:
        requests.RequestException on network failure/timeout; KeyError if the
        response payload lacks the expected fields.
    """
    headers = {'Content-Type': 'application/json'}
    json_data = {'caption': caption, 'exclude': exclude, 'need_idxs': True}
    response = requests.post('https://lingbi.openxlab.org.cn/image/similar',
                             headers=headers,
                             json=json_data,
                             timeout=timeout)
    # Decode the JSON body once instead of twice (each .json() call
    # re-parses the full response).
    data = response.json()['data']
    return data['image_urls'], data['indices']
class Demo_UI:
    def __init__(self, folder):
        """Load the InternLM-XComposer checkpoint from *folder* onto CUDA and
        prepare tokenizer handles, stop-word criteria and length limits."""
        self.llm_model = AutoModel.from_pretrained(folder, trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(folder, trust_remote_code=True)
        # The remote-code model expects the tokenizer under both names.
        self.llm_model.internlm_tokenizer = tokenizer
        self.llm_model.tokenizer = tokenizer
        self.llm_model.eval().to('cuda')
        self.device = 'cuda'
        print(f" load model done: ", type(self.llm_model))
        # Decoded text of the end-of-human / end-of-assistant special tokens,
        # used when assembling chat-style prompts.
        self.eoh = self.llm_model.internlm_tokenizer.decode(
            torch.Tensor([103027]), skip_special_tokens=True)
        self.eoa = self.llm_model.internlm_tokenizer.decode(
            torch.Tensor([103028]), skip_special_tokens=True)
        # The last vocabulary id doubles as the start-of-image marker.
        self.soi_id = len(tokenizer) - 1
        self.soi_token = '<SOI_TOKEN>'
        self.vis_processor = self.llm_model.vis_processor
        self.device = 'cuda'  # NOTE(review): duplicate assignment; harmless
        # Token sequences that terminate generation (ids are model-specific).
        stop_words_ids = [
            torch.tensor([943]).to(self.device),
            torch.tensor([2917, 44930]).to(self.device),
            torch.tensor([45623]).to(self.device),  ### new setting
            torch.tensor([46323]).to(self.device),  ### new setting
            torch.tensor([103027]).to(self.device),  ### new setting
            torch.tensor([103028]).to(self.device),  ### new setting
        ]
        self.stopping_criteria = StoppingCriteriaList(
            [StoppingCriteriaSub(stops=stop_words_ids)])
        # Matches section markers such as <Seg12> in model output.
        self.r2 = re.compile(r'<Seg[0-9]*>')
        # Hard cap on the wrapped (text + image) embedding sequence length.
        self.max_txt_len = 1680
def reset(self):
self.output_text = ''
self.caps = {}
self.show_caps = False
self.show_ids = {}
def get_images_xlab(self, caption, loc, exclude):
urls, idxs = get_urls(caption.strip()[:53], exclude)
print(urls[0])
print('download image with url')
download_image_thread(urls,
folder='articles/' + self.title,
index=self.show_ids[loc] * 1000 + loc,
num_processes=4)
print('image downloaded')
return idxs
    def generate(self, text, random, beam, max_length, repetition):
        """Text-only generation helper.

        Embeds *text*, runs generate() under no_grad/autocast with the shared
        stopping criteria, and returns the decoded string truncated at the
        '<TOKENS_UNUSED_1>' stop marker.
        """
        input_tokens = self.llm_model.internlm_tokenizer(
            text, return_tensors="pt",
            add_special_tokens=True).to(self.llm_model.device)
        # Despite the name, these are plain token embeddings (no images here).
        img_embeds = self.llm_model.internlm_model.model.embed_tokens(
            input_tokens.input_ids)
        with torch.no_grad():
            with self.llm_model.maybe_autocast():
                outputs = self.llm_model.internlm_model.generate(
                    inputs_embeds=img_embeds,
                    stopping_criteria=self.stopping_criteria,
                    do_sample=random,
                    num_beams=beam,
                    max_length=max_length,
                    repetition_penalty=float(repetition),
                )
        # Skip the first generated id, then cut at the stop marker.
        output_text = self.llm_model.internlm_tokenizer.decode(
            outputs[0][1:], add_special_tokens=False)
        output_text = output_text.split('<TOKENS_UNUSED_1>')[0]
        return output_text
def generate_text(self, title, beam, repetition, text_num, random):
text = ' <|User|>:根据给定标题写一个图文并茂,不重复的文章:{}\n'.format(
title) + self.eoh + ' <|Bot|>:'
print('random generate:{}'.format(random))
output_text = self.generate(text, random, beam, text_num, repetition)
return output_text
def generate_loc(self, text_sections, image_num, progress):
full_txt = ''.join(text_sections)
input_text = f' <|User|>:给定文章"{full_txt}" 根据上述文章,选择适合插入图像的{image_num}行' + ' \n<TOKENS_UNUSED_0> <|Bot|>:适合插入图像的行是'
for _ in progress.tqdm([1], desc="image spotting"):
output_text = self.generate(input_text,
random=False,
beam=5,
max_length=300,
repetition=1.)
inject_text = '适合插入图像的行是' + output_text
print(inject_text)
locs = []
for m in self.r2.findall(inject_text):
locs.append(int(m[4:-1]))
print(locs)
return inject_text, locs
def generate_cap(self, text_sections, pos, progress):
pasts = ''
caps = {}
for idx, po in progress.tqdm(enumerate(pos), desc="image captioning"):
full_txt = ''.join(text_sections[:po + 2])
if idx > 0:
past = pasts[:-2] + '。'
else:
past = pasts
input_text = f' <|User|>: 给定文章"{full_txt}" {past}给出适合在<Seg{po}>后插入的图像对应的标题。' + ' \n<TOKENS_UNUSED_0> <|Bot|>: 标题是"'
cap_text = self.generate(input_text,
random=False,
beam=1,
max_length=100,
repetition=5.)
cap_text = cap_text.split('"')[0].strip()
print(cap_text)
caps[po] = cap_text
if idx == 0:
pasts = f'现在<Seg{po}>后插入图像对应的标题是"{cap_text}", '
else:
pasts += f'<Seg{po}>后插入图像对应的标题是"{cap_text}", '
print(caps)
return caps
def generate_loc_cap(self, text_sections, image_num, progress):
inject_text, locs = self.generate_loc(text_sections, image_num,
progress)
caps = self.generate_cap(text_sections, locs, progress)
return caps
    def interleav_wrap(self, img_embeds, text):
        """Splice image embeddings into the text embedding sequence at each
        '<ImageHere>' placeholder and return the concatenated embeddings,
        truncated to ``self.max_txt_len``.

        *img_embeds* is (num_images, im_len, C); *text* is a 1-element list
        whose string must contain exactly num_images placeholders.
        """
        batch_size = img_embeds.shape[0]
        im_len = img_embeds.shape[1]
        text = text[0]
        text = text.replace('<Img>', '')
        text = text.replace('</Img>', '')
        parts = text.split('<ImageHere>')
        assert batch_size + 1 == len(parts)
        # NOTE(review): warp_tokens/warp_attns and soi_embeds are collected
        # but never used below — only warp_embeds feeds the result.
        warp_tokens = []
        warp_embeds = []
        warp_attns = []
        soi = (torch.ones([1, 1]) * self.soi_id).long().to(img_embeds.device)
        soi_embeds = self.llm_model.internlm_model.model.embed_tokens(soi)
        temp_len = 0
        for idx, part in enumerate(parts):
            if len(part) > 0:
                part_tokens = self.llm_model.internlm_tokenizer(
                    part, return_tensors="pt",
                    add_special_tokens=False).to(img_embeds.device)
                part_embeds = self.llm_model.internlm_model.model.embed_tokens(
                    part_tokens.input_ids)
                warp_tokens.append(part_tokens.input_ids)
                warp_embeds.append(part_embeds)
                temp_len += part_embeds.shape[1]
            if idx < batch_size:
                # Interleave the idx-th image's embeddings after this part.
                warp_tokens.append(soi.expand(-1, img_embeds[idx].shape[0]))
                # warp_tokens.append(soi.expand(-1, img_embeds[idx].shape[0] + 1))
                # warp_embeds.append(soi_embeds) ### 1, 1, C
                warp_embeds.append(img_embeds[idx].unsqueeze(0)) ### 1, 34, C
                temp_len += im_len
            if temp_len > self.max_txt_len:
                # Stop early once the length budget is exhausted.
                break
        warp_embeds = torch.cat(warp_embeds, dim=1)
        return warp_embeds[:, :self.max_txt_len].to(img_embeds.device)
def align_text(self, samples):
text_new = []
text = [t + self.eoa + ' </s>' for t in samples["text_input"]]
for i in range(len(text)):
temp = text[i]
temp = temp.replace('###Human', '<|User|>')
temp = temp.replace('### Human', '<|User|>')
temp = temp.replace('<|User|> :', '<|User|>:')
temp = temp.replace('<|User|>: ', '<|User|>:')
temp = temp.replace('<|User|>', ' <|User|>')
temp = temp.replace('###Assistant', '<|Bot|>')
temp = temp.replace('### Assistant', '<|Bot|>')
temp = temp.replace('<|Bot|> :', '<|Bot|>:')
temp = temp.replace('<|Bot|>: ', '<|Bot|>:')
temp = temp.replace('<|Bot|>', self.eoh + ' <|Bot|>')
if temp.find('<|User|>') > temp.find('<|Bot|>'):
temp = temp.replace(' <|User|>', self.eoa + ' <|User|>')
text_new.append(temp)
#print (temp)
return text_new
    def model_select_image(self, output_text, caps, root, progress):
        """For each captioned location, show the model the running article
        context plus the 4 downloaded candidates and let it answer A/B/C/D.

        Returns {line_index: chosen_candidate_index (0-3)}. Previously chosen
        images are kept in *pre_img* so later choices see earlier picks.
        """
        print('model_select_image')
        pre_text = ''
        pre_img = []
        pre_text_list = []
        ans2idx = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
        selected = {k: 0 for k in caps.keys()}
        for i, text in enumerate(output_text.split('\n')):
            pre_text += text + '\n'
            if i in caps:
                # Candidate batch = all previously selected images + the 4
                # fresh candidates for this slot.
                images = copy.deepcopy(pre_img)
                for j in range(4):
                    image = Image.open(
                        os.path.join(
                            root, f'temp_{self.show_ids[i] * 1000 + i}_{j}.png'
                        )).convert("RGB")
                    image = self.vis_processor(image)
                    images.append(image)
                images = torch.stack(images, dim=0)
                pre_text_list.append(pre_text)
                pre_text = ''
                images = images.cuda()
                instruct = ' <|User|>:根据给定上下文和候选图像,选择合适的配图:'
                input_text = '<ImageHere>'.join(
                    pre_text_list
                ) + '\n\n候选图像包括: A.<ImageHere>\nB.<ImageHere>\nC.<ImageHere>\nD.<ImageHere>\n\n<TOKENS_UNUSED_0> <|Bot|>:最合适的图是'
                input_text = instruct + input_text
                samples = {}
                samples['text_input'] = [input_text]
                self.llm_model.debug_flag = 0
                with torch.no_grad():
                    with torch.cuda.amp.autocast():
                        img_embeds = self.llm_model.encode_img(images)
                        input_text = self.align_text(samples)
                        img_embeds = self.interleav_wrap(
                            img_embeds, input_text)
                    bos = torch.ones(
                        [1, 1]) * self.llm_model.internlm_tokenizer.bos_token_id
                    bos = bos.long().to(images.device)
                    meta_embeds = self.llm_model.internlm_model.model.embed_tokens(
                        bos)
                    inputs_embeds = torch.cat([meta_embeds, img_embeds], dim=1)
                    with torch.cuda.amp.autocast():
                        # [:, :-2] trims the trailing tokens added by
                        # align_text (eoa + ' </s>') — presumably; confirm.
                        outputs = self.llm_model.internlm_model.generate(
                            inputs_embeds=inputs_embeds[:, :-2],
                            do_sample=False,
                            num_beams=5,
                            max_length=10,
                            repetition_penalty=1.,
                        )
                out_text = self.llm_model.internlm_tokenizer.decode(
                    outputs[0][1:], add_special_tokens=False)
                # NOTE(review): bare except hides everything (bad letter,
                # IndexError on empty output, ...); falls back to 'A'.
                try:
                    answer = out_text[1] if out_text[0] == ' ' else out_text[0]
                    pre_img.append(images[len(pre_img) + ans2idx[answer]].cpu())
                except:
                    print('Select fail, use first image')
                    answer = 'A'
                    pre_img.append(images[len(pre_img) + ans2idx[answer]].cpu())
                selected[i] = ans2idx[answer]
        return selected
    def show_md(self, text_sections, title, caps, selected, show_cap=False):
        """Build the gradio update lists that render the article.

        Returns (components, md_shows): *components* is the flat list expected
        by the article-tab outputs (markdown + gallery + button + caption
        textbox + search button + editor updates, each padded to
        ``max_section`` entries); *md_shows* is just the visible markdowns.
        """
        md_shows = []
        ga_shows = []
        btn_shows = []
        cap_textboxs, cap_searchs = [], []
        editers = []
        for i in range(len(text_sections)):
            if i in caps:
                if show_cap:
                    # NOTE(review): this branch uses 'file/articles/...' while
                    # the no-caption branch uses 'file=articles/...' — looks
                    # inconsistent; confirm against the gradio version in use.
                    md = text_sections[
                        i] + '\n' + '<div align="center"> <img src="file/articles/{}/temp_{}_{}.png" width = 500/> {} </div>'.format(
                            title, self.show_ids[i] * 1000 + i, selected[i],
                            caps[i])
                else:
                    md = text_sections[
                        i] + '\n' + '<div align="center"> <img src="file=articles/{}/temp_{}_{}.png" width = 500/> </div>'.format(
                            title, self.show_ids[i] * 1000 + i, selected[i])
                # Gallery shows all four downloaded candidates for this slot.
                img_list = [('articles/{}/temp_{}_{}.png'.format(
                    title, self.show_ids[i] * 1000 + i,
                    j), 'articles/{}/temp_{}_{}.png'.format(
                        title, self.show_ids[i] * 1000 + i, j))
                            for j in range(4)]
                ga_show = gr.Gallery.update(visible=True, value=img_list)
                ga_shows.append(ga_show)
                # Trash-can glyph: clicking removes this section's image.
                btn_show = gr.Button.update(visible=True,
                                            value='\U0001f5d1\uFE0F')
                cap_textboxs.append(
                    gr.Textbox.update(visible=True, value=caps[i]))
                cap_searchs.append(gr.Button.update(visible=True))
            else:
                md = text_sections[i]
                ga_show = gr.Gallery.update(visible=False, value=[])
                ga_shows.append(ga_show)
                # Plus glyph: clicking opens the add-image widgets.
                btn_show = gr.Button.update(visible=True, value='\u2795')
                cap_textboxs.append(gr.Textbox.update(visible=False))
                cap_searchs.append(gr.Button.update(visible=False))
            md_show = gr.Markdown.update(visible=True, value=md)
            md_shows.append(md_show)
            btn_shows.append(btn_show)
            editers.append(gr.update(visible=True))
            print(i, md)
        # Hide the remaining pre-allocated rows up to max_section.
        md_hides = []
        ga_hides = []
        btn_hides = []
        for i in range(max_section - len(text_sections)):
            md_hide = gr.Markdown.update(visible=False, value='')
            md_hides.append(md_hide)
            btn_hide = gr.Button.update(visible=False)
            btn_hides.append(btn_hide)
            editers.append(gr.update(visible=False))
        for i in range(max_section - len(ga_shows)):
            ga_hide = gr.Gallery.update(visible=False, value=[])
            ga_hides.append(ga_hide)
            cap_textboxs.append(gr.Textbox.update(visible=False))
            cap_searchs.append(gr.Button.update(visible=False))
        return md_shows + md_hides + ga_shows + ga_hides + btn_shows + btn_hides + cap_textboxs + cap_searchs + editers, md_shows
def generate_article(self,
title,
beam,
repetition,
text_num,
msi,
random,
progress=gr.Progress()):
self.reset()
self.title = title
if article_stream_output:
text = ' <|User|>:根据给定标题写一个图文并茂,不重复的文章:{}\n'.format(
title) + self.eoh + ' <|Bot|>:'
input_tokens = self.llm_model.internlm_tokenizer(
text, return_tensors="pt",
add_special_tokens=True).to(self.llm_model.device)
img_embeds = self.llm_model.internlm_model.model.embed_tokens(
input_tokens.input_ids)
generate_params = dict(
inputs_embeds=img_embeds,
num_beams=beam,
do_sample=random,
stopping_criteria=self.stopping_criteria,
repetition_penalty=float(repetition),
max_length=text_num,
bos_token_id=self.llm_model.internlm_tokenizer.bos_token_id,
eos_token_id=self.llm_model.internlm_tokenizer.eos_token_id,
pad_token_id=self.llm_model.internlm_tokenizer.pad_token_id,
)
output_text = "▌"
with self.generate_with_streaming(**generate_params) as generator:
for output in generator:
decoded_output = self.llm_model.internlm_tokenizer.decode(
output[1:])
if output[-1] in [
self.llm_model.internlm_tokenizer.eos_token_id
]:
break
output_text = decoded_output.replace('\n', '\n\n') + "▌"
yield (output_text,) + (gr.Markdown.update(visible=False),) * (max_section - 1) + (gr.Gallery.update(visible=False),) * max_section + \
(gr.Button.update(visible=False),) * max_section + (gr.Textbox.update(visible=False),) * max_section + (gr.Button.update(visible=False),) * max_section + \
(gr.update(visible=False),) * max_section + (disable_btn,) * 2
time.sleep(0.03)
output_text = output_text[:-1]
yield (output_text,) + (gr.Markdown.update(visible=False),) * (max_section - 1) + (gr.Gallery.update(visible=False),) * max_section + \
(gr.Button.update(visible=False),) * max_section + (gr.Textbox.update(visible=False),) * max_section + (gr.Button.update(visible=False),) * max_section +\
(gr.update(visible=False),) * max_section + (disable_btn,) * 2
else:
output_text = self.generate_text(title, beam, repetition, text_num,
random)
print(output_text)
output_text = re.sub(r'(\n[ \t]*)+', '\n', output_text)
if output_text[-1] == '\n':
output_text = output_text[:-1]
print(output_text)
output_text = '\n'.join(output_text.split('\n')[:max_section])
text_sections = output_text.split('\n')
idx_text_sections = [
f'<Seg{i}>' + ' ' + it + '\n' for i, it in enumerate(text_sections)
]
caps = self.generate_loc_cap(idx_text_sections, '', progress)
#caps = {0: '成都的三日游路线图,包括春熙路、太古里、IFS国金中心、大慈寺、宽窄巷子、奎星楼街、九眼桥(酒吧一条街)、武侯祠、锦里、杜甫草堂、浣花溪公园、青羊宫、金沙遗址博物馆、文殊院、人民公园、熊猫基地、望江楼公园、东郊记忆、建设路小吃街、电子科大清水河校区、三圣乡万福花卉市场、龙湖滨江天街购物广场和返程。', 2: '春熙路的繁华景象,各种时尚潮流的品牌店和美食餐厅鳞次栉比。', 4: 'IFS国金中心的豪华购物中心,拥有众多国际知名品牌的旗舰店和专卖店,同时还有电影院、健身房 配套设施。', 6: '春熙路上的著名景点——太古里,是一个集购物、餐饮、娱乐于一体的高端时尚街区,也是成都著名的网红打卡地之一。', 8: '大慈寺的外观,是一座历史悠久的佛教寺庙,始建于唐朝,有着深厚的文化底蕴和历史价值。'}
#self.show_ids = {k:0 for k in caps.keys()}
self.show_ids = {k: 1 for k in caps.keys()}
print(caps)
self.ex_idxs = []
for loc, cap in progress.tqdm(caps.items(), desc="download image"):
#self.show_ids[loc] += 1
idxs = self.get_images_xlab(cap, loc, self.ex_idxs)
self.ex_idxs.extend(idxs)
if msi:
self.selected = self.model_select_image(output_text, caps,
'articles/' + title,
progress)
else:
self.selected = {k: 0 for k in caps.keys()}
components, md_shows = self.show_md(text_sections, title, caps,
self.selected)
self.show_caps = False
self.output_text = output_text
self.caps = caps
if article_stream_output:
yield components + [enable_btn] * 2
else:
return components + [enable_btn] * 2
def adjust_img(self, img_num, progress=gr.Progress()):
text_sections = self.output_text.split('\n')
idx_text_sections = [
f'<Seg{i}>' + ' ' + it + '\n' for i, it in enumerate(text_sections)
]
img_num = min(img_num, len(text_sections))
caps = self.generate_loc_cap(idx_text_sections, int(img_num), progress)
#caps = {1:'318川藏线沿途的风景照片', 4:'泸定桥的全景照片', 6:'折多山垭口的全景照片', 8:'稻城亚丁机场的全景照片', 10:'姊妹湖的全景照片'}
print(caps)
sidxs = []
for loc, cap in caps.items():
if loc in self.show_ids:
self.show_ids[loc] += 1
else:
self.show_ids[loc] = 1
idxs = self.get_images_xlab(cap, loc, sidxs)
sidxs.extend(idxs)
self.sidxs = sidxs
self.selected = {k: 0 for k in caps.keys()}
components, md_shows = self.show_md(text_sections, self.title, caps,
self.selected)
self.caps = caps
return components
def add_delete_image(self, text, status, index):
index = int(index)
if status == '\U0001f5d1\uFE0F':
if index in self.caps:
self.caps.pop(index)
self.selected.pop(index)
md_show = gr.Markdown.update(value=text.split('\n')[0])
gallery = gr.Gallery.update(visible=False, value=[])
btn_show = gr.Button.update(value='\u2795')
cap_textbox = gr.Textbox.update(visible=False)
cap_search = gr.Button.update(visible=False)
else:
md_show = gr.Markdown.update()
gallery = gr.Gallery.update(visible=True, value=[])
btn_show = gr.Button.update(value='\U0001f5d1\uFE0F')
cap_textbox = gr.Textbox.update(visible=True)
cap_search = gr.Button.update(visible=True)
return md_show, gallery, btn_show, cap_textbox, cap_search
def search_image(self, text, index):
index = int(index)
if text == '':
return gr.Gallery.update()
if index in self.show_ids:
self.show_ids[index] += 1
else:
self.show_ids[index] = 1
self.caps[index] = text
idxs = self.get_images_xlab(text, index, self.ex_idxs)
self.ex_idxs.extend(idxs)
img_list = [('articles/{}/temp_{}_{}.png'.format(
self.title, self.show_ids[index] * 1000 + index,
j), 'articles/{}/temp_{}_{}.png'.format(
self.title, self.show_ids[index] * 1000 + index, j))
for j in range(4)]
ga_show = gr.Gallery.update(visible=True, value=img_list)
return ga_show
def replace_image(self, article, index, evt: gr.SelectData):
index = int(index)
self.selected[index] = evt.index
if '<div align="center">' in article:
return re.sub(r'file=.*.png', 'file={}'.format(evt.value), article)
else:
return article + '\n' + '<div align="center"> <img src="file={}" width = 500/> </div>'.format(
evt.value)
def add_delete_caption(self):
self.show_caps = False if self.show_caps else True
text_sections = self.output_text.split('\n')
components, _ = self.show_md(text_sections,
self.title,
self.caps,
selected=self.selected,
show_cap=self.show_caps)
return components
def save(self):
folder = 'save_articles/' + self.title
if os.path.exists(folder):
for item in os.listdir(folder):
os.remove(os.path.join(folder, item))
os.makedirs(folder, exist_ok=True)
save_text = ''
count = 0
if len(self.output_text) > 0:
text_sections = self.output_text.split('\n')
for i in range(len(text_sections)):
if i in self.caps:
if self.show_caps:
md = text_sections[
i] + '\n' + '<div align="center"> <img src="temp_{}_{}.png" width = 500/> {} </div>'.format(
self.show_ids[i] * 1000 + i, self.selected[i],
self.caps[i])
else:
md = text_sections[
i] + '\n' + '<div align="center"> <img src="temp_{}_{}.png" width = 500/> </div>'.format(
self.show_ids[i] * 1000 + i, self.selected[i])
count += 1
else:
md = text_sections[i]
save_text += md + '\n\n'
save_text = save_text[:-2]
with open(os.path.join(folder, 'io.MD'), 'w') as f:
f.writelines(save_text)
for k in self.caps.keys():
shutil.copy(
os.path.join(
'articles', self.title,
f'temp_{self.show_ids[k] * 1000 + k}_{self.selected[k]}.png'
), folder)
archived = shutil.make_archive(folder, 'zip', folder)
return archived
    def get_context_emb(self, state, img_list):
        """Build the chat input embeddings, splicing the cached image
        embeddings into the prompt at each '<Img><ImageHere></Img>' marker."""
        prompt = state.get_prompt()
        print(prompt)
        prompt_segs = prompt.split('<Img><ImageHere></Img>')
        assert len(prompt_segs) == len(
            img_list
        ) + 1, "Unmatched numbers of image placeholders and images."
        # Tokenize each text segment; only the first gets special tokens.
        seg_tokens = [
            self.llm_model.internlm_tokenizer(seg,
                                              return_tensors="pt",
                                              add_special_tokens=i == 0).to(
                                                  self.device).input_ids
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [
            self.llm_model.internlm_model.model.embed_tokens(seg_t)
            for seg_t in seg_tokens
        ]
        # Interleave: text0, img0, text1, img1, ..., textN.
        mixed_embs = [
            emb for pair in zip(seg_embs[:-1], img_list) for emb in pair
        ] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs
def chat_ask(self, state, img_list, text, image):
print(1111)
state.skip_next = False
if len(text) <= 0 and image is None:
state.skip_next = True
return (state, img_list, state.to_gradio_chatbot(), "",
None) + (no_change_btn, ) * 2
if image is not None:
image_pt = self.vis_processor(image).unsqueeze(0).to(0)
image_emb = self.llm_model.encode_img(image_pt)
img_list.append(image_emb)
state.append_message(state.roles[0],
["<Img><ImageHere></Img>", image])
if len(state.messages) > 0 and state.messages[-1][0] == state.roles[
0] and isinstance(state.messages[-1][1], list):
#state.messages[-1][1] = ' '.join([state.messages[-1][1], text])
state.messages[-1][1][0] = ' '.join(
[state.messages[-1][1][0], text])
else:
state.append_message(state.roles[0], text)
print(state.messages)
state.append_message(state.roles[1], None)
return (state, img_list, state.to_gradio_chatbot(), "",
None) + (disable_btn, ) * 2
def generate_with_callback(self, callback=None, **kwargs):
kwargs.setdefault("stopping_criteria",
transformers.StoppingCriteriaList())
kwargs["stopping_criteria"].append(Stream(callback_func=callback))
with torch.no_grad():
with self.llm_model.maybe_autocast():
self.llm_model.internlm_model.generate(**kwargs)
    def generate_with_streaming(self, **kwargs):
        """Wrap callback-based generation in an iterator context manager that
        yields partial token sequences."""
        return Iteratorize(self.generate_with_callback, kwargs, callback=None)
    def chat_answer(self, state, img_list, max_output_tokens,
                    repetition_penalty, num_beams, do_sample):
        """Generate the assistant reply for the queued chat turn.

        Streams partial text into the last chat message when
        ``chat_stream_output`` is set; otherwise decodes the full output in
        one pass. Yields/returns (state, chatbot) plus two button updates.
        """
        if state.skip_next:
            # chat_ask flagged an empty submission; do nothing.
            return (state, state.to_gradio_chatbot()) + (no_change_btn, ) * 2
        # Interleave prompt-text embeddings with the cached image embeddings.
        embs = self.get_context_emb(state, img_list)
        if chat_stream_output:
            generate_params = dict(
                inputs_embeds=embs,
                num_beams=num_beams,
                do_sample=do_sample,
                stopping_criteria=self.stopping_criteria,
                repetition_penalty=float(repetition_penalty),
                max_length=max_output_tokens,
                bos_token_id=self.llm_model.internlm_tokenizer.bos_token_id,
                eos_token_id=self.llm_model.internlm_tokenizer.eos_token_id,
                pad_token_id=self.llm_model.internlm_tokenizer.pad_token_id,
            )
            state.messages[-1][-1] = "▌"  # typing-cursor placeholder
            with self.generate_with_streaming(**generate_params) as generator:
                for output in generator:
                    decoded_output = self.llm_model.internlm_tokenizer.decode(
                        output[1:])
                    # 333 and 497 act as extra hard-stop token ids here —
                    # model-specific; TODO confirm their meaning.
                    if output[-1] in [
                            self.llm_model.internlm_tokenizer.eos_token_id, 333, 497
                    ]:
                        break
                    state.messages[-1][-1] = decoded_output + "▌"
                    yield (state,
                           state.to_gradio_chatbot()) + (disable_btn, ) * 2
                    time.sleep(0.03)
            # Drop the trailing cursor character.
            state.messages[-1][-1] = state.messages[-1][-1][:-1]
            yield (state, state.to_gradio_chatbot()) + (enable_btn, ) * 2
            return
        else:
            outputs = self.llm_model.internlm_model.generate(
                inputs_embeds=embs,
                max_new_tokens=max_output_tokens,
                stopping_criteria=self.stopping_criteria,
                num_beams=num_beams,
                #temperature=float(temperature),
                do_sample=do_sample,
                repetition_penalty=float(repetition_penalty),
                bos_token_id=self.llm_model.internlm_tokenizer.bos_token_id,
                eos_token_id=self.llm_model.internlm_tokenizer.eos_token_id,
                pad_token_id=self.llm_model.internlm_tokenizer.pad_token_id,
            )
            output_token = outputs[0]
            if output_token[0] == 0:
                # Strip a leading pad/unk id if present.
                output_token = output_token[1:]
            output_text = self.llm_model.internlm_tokenizer.decode(
                output_token, add_special_tokens=False)
            print(output_text)
            output_text = output_text.split('<TOKENS_UNUSED_1>')[
                0]  # remove the stop sign '###'
            output_text = output_text.split('Assistant:')[-1].strip()
            output_text = output_text.replace("<s>", "")
            state.messages[-1][1] = output_text
            return (state, state.to_gradio_chatbot()) + (enable_btn, ) * 2
    def clear_answer(self, state):
        """Blank the last assistant message so it can be regenerated."""
        state.messages[-1][-1] = None
        return (state, state.to_gradio_chatbot())
    def chat_clear_history(self):
        """Start a fresh conversation: new state, empty image list, cleared
        chatbot/textbox/imagebox, and both action buttons disabled."""
        state = CONV_VISION_7132_v2.copy()
        return (state, [], state.to_gradio_chatbot(), "",
                None) + (disable_btn, ) * 2
def load_demo():
    """Initial page-load handler: create a fresh conversation state and
    reveal the chat widgets (they start hidden until the app is ready)."""
    state = CONV_VISION_7132_v2.copy()
    return (state, [], gr.Chatbot.update(visible=True),
            gr.Textbox.update(visible=True), gr.Button.update(visible=True),
            gr.Row.update(visible=True), gr.Accordion.update(visible=True))
def change_language(lang):
    """Return gradio updates that relabel every widget in the other language.

    *lang* is the current text of the language toggle button: '中文' switches
    the UI to Chinese; anything else (normally 'English') switches it to
    English. Fix: the second branch is now a plain ``else`` so an unexpected
    value can no longer leave every local unbound and raise NameError at the
    return statement.
    """
    if lang == '中文':
        lang_btn = gr.update(value='English')
        title = gr.update(label='根据给定标题写一个图文并茂的文章:')
        btn = gr.update(value='生成')
        parameter_article = gr.update(label='高级设置')
        beam = gr.update(label='集束大小')
        repetition = gr.update(label='重复惩罚')
        text_num = gr.update(label='最多输出字数')
        msi = gr.update(label='模型选图')
        random = gr.update(label='采样')
        img_num = gr.update(label='生成文章后,可选择全文配图数量')
        adjust_btn = gr.update(value='固定数量配图')
        cap_searchs, editers = [], []
        for _ in range(max_section):
            cap_searchs.append(gr.update(value='搜索'))
            editers.append(gr.update(label='编辑'))
        save_btn = gr.update(value='文章下载')
        save_file = gr.update(label='文章下载')
        parameter_chat = gr.update(label='参数')
        chat_text_num = gr.update(label='最多输出字数')
        chat_beam = gr.update(label='集束大小')
        chat_repetition = gr.update(label='重复惩罚')
        chat_random = gr.update(label='采样')
        chat_textbox = gr.update(placeholder='输入聊天内容并回车')
        submit_btn = gr.update(value='提交')
        regenerate_btn = gr.update(value='🔄 重新生成')
        clear_btn = gr.update(value='🗑️ 清空聊天框')
    else:  # 'English' — also the safe fallback for any unexpected value
        lang_btn = gr.update(value='中文')
        title = gr.update(
            label='Write an illustrated article based on the given title:')
        btn = gr.update(value='Submit')
        parameter_article = gr.update(label='Advanced Settings')
        beam = gr.update(label='Beam Size')
        repetition = gr.update(label='Repetition_penalty')
        text_num = gr.update(label='Max output tokens')
        msi = gr.update(label='Model selects images')
        random = gr.update(label='Do_sample')
        img_num = gr.update(
            label=
            'Select the number of the inserted image after article generation.'
        )
        adjust_btn = gr.update(value='Insert a fixed number of images')
        cap_searchs, editers = [], []
        for _ in range(max_section):
            cap_searchs.append(gr.update(value='Search'))
            editers.append(gr.update(label='edit'))
        save_btn = gr.update(value='Save article')
        save_file = gr.update(label='Save article')
        parameter_chat = gr.update(label='Parameters')
        chat_text_num = gr.update(label='Max output tokens')
        chat_beam = gr.update(label='Beam Size')
        chat_repetition = gr.update(label='Repetition_penalty')
        chat_random = gr.update(label='Do_sample')
        chat_textbox = gr.update(placeholder='Enter text and press ENTER')
        submit_btn = gr.update(value='Submit')
        regenerate_btn = gr.update(value='🔄 Regenerate')
        clear_btn = gr.update(value='🗑️ Clear history')
    # Order must match the outputs list wired to lang_btn.click below.
    return [lang_btn, title, btn, parameter_article, beam, repetition, text_num, msi, random, img_num, adjust_btn] +\
        cap_searchs + editers + [save_btn, save_file] + [parameter_chat, chat_text_num, chat_beam, chat_repetition, chat_random] + \
        [chat_textbox, submit_btn, regenerate_btn, clear_btn]
# --- Script entry: CLI arguments and model construction ---
parser = argparse.ArgumentParser()
# Path or hub id of the pretrained InternLM-XComposer checkpoint.
parser.add_argument("--folder", default='internlm/internlm-xcomposer-7b')
# NOTE(review): --private is parsed but never read below — confirm intent.
parser.add_argument("--private", default=False, action='store_true')
args = parser.parse_args()
demo_ui = Demo_UI(args.folder)  # loads the model onto CUDA; slow
# --- Gradio UI definition -------------------------------------------------
# Two tabs: (1) illustrated-article writing with up to `max_section` editable
# sections, (2) a multimodal chatbot. Widgets created here are wired to
# Demo_UI methods at the bottom of each tab.
with gr.Blocks(css=custom_css, title='浦语·灵笔 (InternLM-XComposer)') as demo:
    with gr.Row():
        with gr.Column(scale=20):
            #gr.HTML("""<h1 align="center" id="space-title" style="font-size:35px;">🤗 浦语·灵笔 (InternLM-XComposer)</h1>""")
            gr.HTML(
                """<h1 align="center"><img src="https://raw.githubusercontent.com/panzhang0212/interleaved_io/main/logo.png", alt="InternLM-XComposer" border="0" style="margin: 0 auto; height: 200px;" /></a> </h1>"""
            )
        with gr.Column(scale=1, min_width=100):
            lang_btn = gr.Button("中文")  # toggles all labels zh <-> en
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("📝 创作图文并茂文章 (Write Interleaved-text-image Article)"):
            with gr.Row():
                title = gr.Textbox(
                    label=
                    'Write an illustrated article based on the given title:',
                    scale=2)
                btn = gr.Button("Submit", scale=1)
            with gr.Row():
                img_num = gr.Slider(
                    minimum=1.0,
                    maximum=30.0,
                    value=5.0,
                    step=1.0,
                    scale=2,
                    label=
                    'Select the number of the inserted image after article generation.'
                )
                # Enabled only after an article has been generated.
                adjust_btn = gr.Button('Insert a fixed number of images',
                                       interactive=False,
                                       scale=1)
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Accordion("Advanced Settings",
                                      open=False,
                                      visible=True) as parameter_article:
                        beam = gr.Slider(minimum=1.0,
                                         maximum=6.0,
                                         value=5.0,
                                         step=1.0,
                                         label='Beam Size')
                        repetition = gr.Slider(minimum=0.0,
                                               maximum=10.0,
                                               value=5.0,
                                               step=0.1,
                                               label='Repetition_penalty')
                        text_num = gr.Slider(minimum=100.0,
                                             maximum=2000.0,
                                             value=1000.0,
                                             step=1.0,
                                             label='Max output tokens')
                        msi = gr.Checkbox(value=True,
                                          label='Model selects images')
                        random = gr.Checkbox(label='Do_sample')
                with gr.Column(scale=1):
                    gr.Examples(
                        examples=[["又见敦煌"], ["星链新闻稿"], ["如何养好一只宠物"],
                                  ["Shanghai Travel Guide in English"], ["Travel guidance of London in English"], ["Advertising for Genshin Impact in English"]],
                        inputs=[title],
                    )
            # One widget row per potential article section; extras start
            # hidden and are toggled by show_md's update lists.
            articles = []
            gallerys = []
            add_delete_btns = []
            cap_textboxs = []
            cap_searchs = []
            editers = []
            with gr.Column():
                for i in range(max_section):
                    with gr.Row():
                        visible = True if i == 0 else False
                        with gr.Column(scale=2):
                            article = gr.Markdown(visible=visible,
                                                  elem_classes='feedback')
                            articles.append(article)
                        with gr.Column(scale=1):
                            with gr.Accordion('edit',
                                              open=False,
                                              visible=False) as editer:
                                with gr.Row():
                                    cap_textbox = gr.Textbox(show_label=False,
                                                             interactive=True,
                                                             scale=6,
                                                             visible=False)
                                    cap_search = gr.Button(value="Search",
                                                           visible=False,
                                                           scale=1)
                                with gr.Row():
                                    gallery = gr.Gallery(visible=False,
                                                         columns=2,
                                                         height='auto')
                                add_delete_btn = gr.Button(visible=False)
                            # Clicking a gallery thumbnail swaps this
                            # section's image in the markdown.
                            gallery.select(demo_ui.replace_image, [
                                articles[i],
                                gr.Number(value=i, visible=False)
                            ], articles[i])
                        gallerys.append(gallery)
                        add_delete_btns.append(add_delete_btn)
                        cap_textboxs.append(cap_textbox)
                        cap_searchs.append(cap_search)
                        editers.append(editer)
            save_btn = gr.Button("Save article")
            save_file = gr.File(label="Save article")
            # Per-section wiring: add/delete image and caption search.
            for i in range(max_section):
                add_delete_btns[i].click(demo_ui.add_delete_image,
                                         inputs=[
                                             articles[i],
                                             add_delete_btns[i],
                                             gr.Number(value=i,
                                                       visible=False)
                                         ],
                                         outputs=[
                                             articles[i], gallerys[i],
                                             add_delete_btns[i],
                                             cap_textboxs[i],
                                             cap_searchs[i]
                                         ])
                cap_searchs[i].click(demo_ui.search_image,
                                     inputs=[
                                         cap_textboxs[i],
                                         gr.Number(value=i, visible=False)
                                     ],
                                     outputs=gallerys[i])
            btn.click(
                demo_ui.generate_article,
                inputs=[title, beam, repetition, text_num, msi, random],
                outputs=articles + gallerys + add_delete_btns +
                cap_textboxs + cap_searchs + editers + [btn, adjust_btn])
            # cap_btn.click(demo_ui.add_delete_caption, inputs=None, outputs=articles)
            save_btn.click(demo_ui.save, inputs=None, outputs=save_file)
            adjust_btn.click(demo_ui.adjust_img,
                             inputs=img_num,
                             outputs=articles + gallerys +
                             add_delete_btns + cap_textboxs + cap_searchs +
                             editers)
        with gr.TabItem("💬 多模态对话 (Multimodal Chat)", elem_id="chat", id=0):
            chat_state = gr.State()
            img_list = gr.State()  # per-session encoded image embeddings
            with gr.Row():
                with gr.Column(scale=3):
                    imagebox = gr.Image(type="pil")
                    with gr.Accordion("Parameters", open=True,
                                      visible=False) as parameter_row:
                        chat_max_output_tokens = gr.Slider(
                            minimum=0,
                            maximum=1024,
                            value=512,
                            step=64,
                            interactive=True,
                            label="Max output tokens",
                        )
                        chat_num_beams = gr.Slider(
                            minimum=1,
                            maximum=5,
                            value=3,
                            step=1,
                            interactive=True,
                            label="Beam Size",
                        )
                        chat_repetition_penalty = gr.Slider(
                            minimum=1,
                            maximum=5,
                            value=1,
                            step=0.1,
                            interactive=True,
                            label="Repetition_penalty",
                        )
                        # chat_temperature = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, interactive=True,
                        #                              label="Temperature", )
                        chat_do_sample = gr.Checkbox(interactive=True,
                                                     value=True,
                                                     label="Do_sample")
                with gr.Column(scale=6):
                    chatbot = grChatbot(elem_id="chatbot",
                                        visible=False,
                                        height=750)
                    with gr.Row():
                        with gr.Column(scale=8):
                            chat_textbox = gr.Textbox(
                                show_label=False,
                                placeholder="Enter text and press ENTER",
                                visible=False).style(container=False)
                        with gr.Column(scale=1, min_width=60):
                            submit_btn = gr.Button(value="Submit",
                                                   visible=False)
                    with gr.Row(visible=True) as button_row:
                        regenerate_btn = gr.Button(value="🔄 Regenerate",
                                                   interactive=False)
                        clear_btn = gr.Button(value="🗑️ Clear history",
                                              interactive=False)
            btn_list = [regenerate_btn, clear_btn]
            parameter_list = [
                chat_max_output_tokens, chat_repetition_penalty,
                chat_num_beams, chat_do_sample
            ]
            # ENTER and the Submit button both queue the turn (chat_ask),
            # then stream the answer (chat_answer).
            chat_textbox.submit(
                demo_ui.chat_ask,
                [chat_state, img_list, chat_textbox, imagebox],
                [chat_state, img_list, chatbot, chat_textbox, imagebox] +
                btn_list).then(demo_ui.chat_answer,
                               [chat_state, img_list] + parameter_list,
                               [chat_state, chatbot] + btn_list)
            submit_btn.click(
                demo_ui.chat_ask,
                [chat_state, img_list, chat_textbox, imagebox],
                [chat_state, img_list, chatbot, chat_textbox, imagebox] +
                btn_list).then(demo_ui.chat_answer,
                               [chat_state, img_list] + parameter_list,
                               [chat_state, chatbot] + btn_list)
            regenerate_btn.click(demo_ui.clear_answer, chat_state,
                                 [chat_state, chatbot]).then(
                                     demo_ui.chat_answer,
                                     [chat_state, img_list] + parameter_list,
                                     [chat_state, chatbot] + btn_list)
            clear_btn.click(
                demo_ui.chat_clear_history, None,
                [chat_state, img_list, chatbot, chat_textbox, imagebox] +
                btn_list)
            # Reveal the chat widgets once the page has loaded.
            demo.load(load_demo, None, [
                chat_state, img_list, chatbot, chat_textbox, submit_btn,
                parameter_row
            ])
    # Output order must match the return order of change_language.
    lang_btn.click(change_language, inputs=lang_btn, outputs=[lang_btn, title, btn, parameter_article] +\
        [beam, repetition, text_num, msi, random, img_num, adjust_btn] + cap_searchs + editers +\
        [save_btn, save_file] + [parameter_row, chat_max_output_tokens, chat_num_beams, chat_repetition_penalty, chat_do_sample] +\
        [chat_textbox, submit_btn, regenerate_btn, clear_btn])

demo.queue(concurrency_count=8, status_update_rate=10, api_open=False)
if __name__ == "__main__":
    # Blocking call; serves the Gradio app on the default host/port.
    demo.launch()