import os import re import sys sys.path.insert(0, '.') sys.path.insert(0, '..') import argparse import gradio as gr os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.getcwd(), 'tmp') import copy import time import shutil import requests from PIL import Image, ImageFile import torch import transformers from transformers import StoppingCriteriaList, AutoTokenizer, AutoModel ImageFile.LOAD_TRUNCATED_IMAGES = True from demo_asset.assets.css_html_js import custom_css from demo_asset.gradio_patch import Chatbot as grChatbot from demo_asset.serve_utils import Stream, Iteratorize from demo_asset.conversation import CONV_VISION_7132_v2, StoppingCriteriaSub from demo_asset.download import download_image_thread max_section = 60 no_change_btn = gr.Button.update() disable_btn = gr.Button.update(interactive=False) enable_btn = gr.Button.update(interactive=True) chat_stream_output = True article_stream_output = True def get_urls(caption, exclude): headers = {'Content-Type': 'application/json'} json_data = {'caption': caption, 'exclude': exclude, 'need_idxs': True} response = requests.post('https://lingbi.openxlab.org.cn/image/similar', headers=headers, json=json_data) urls = response.json()['data']['image_urls'] idx = response.json()['data']['indices'] return urls, idx class Demo_UI: def __init__(self, folder): self.llm_model = AutoModel.from_pretrained(folder, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(folder, trust_remote_code=True) self.llm_model.internlm_tokenizer = tokenizer self.llm_model.tokenizer = tokenizer self.llm_model.eval().to('cuda') self.device = 'cuda' print(f" load model done: ", type(self.llm_model)) self.eoh = self.llm_model.internlm_tokenizer.decode( torch.Tensor([103027]), skip_special_tokens=True) self.eoa = self.llm_model.internlm_tokenizer.decode( torch.Tensor([103028]), skip_special_tokens=True) self.soi_id = len(tokenizer) - 1 self.soi_token = '' self.vis_processor = self.llm_model.vis_processor self.device = 'cuda' stop_words_ids = [ torch.tensor([943]).to(self.device), torch.tensor([2917, 44930]).to(self.device), torch.tensor([45623]).to(self.device), ### new setting torch.tensor([46323]).to(self.device), ### new setting torch.tensor([103027]).to(self.device), ### new setting torch.tensor([103028]).to(self.device), ### new setting ] self.stopping_criteria = StoppingCriteriaList( [StoppingCriteriaSub(stops=stop_words_ids)]) self.r2 = re.compile(r'') self.max_txt_len = 1680 def reset(self): self.output_text = '' self.caps = {} self.show_caps = False self.show_ids = {} def get_images_xlab(self, caption, loc, exclude): urls, idxs = get_urls(caption.strip()[:53], exclude) print(urls[0]) print('download image with url') download_image_thread(urls, folder='articles/' + self.title, index=self.show_ids[loc] * 1000 + loc, num_processes=4) print('image downloaded') return idxs def generate(self, text, random, beam, max_length, repetition): input_tokens = self.llm_model.internlm_tokenizer( text, return_tensors="pt", add_special_tokens=True).to(self.llm_model.device) img_embeds = self.llm_model.internlm_model.model.embed_tokens( input_tokens.input_ids) with torch.no_grad(): with self.llm_model.maybe_autocast(): outputs = self.llm_model.internlm_model.generate( inputs_embeds=img_embeds, stopping_criteria=self.stopping_criteria, do_sample=random, num_beams=beam, max_length=max_length, repetition_penalty=float(repetition), ) output_text = self.llm_model.internlm_tokenizer.decode( outputs[0][1:], add_special_tokens=False) output_text = output_text.split('')[0] return output_text def generate_text(self, title, beam, repetition, text_num, random): text = ' <|User|>:根据给定标题写一个图文并茂,不重复的文章:{}\n'.format( title) + self.eoh + ' <|Bot|>:' print('random generate:{}'.format(random)) output_text = self.generate(text, random, beam, text_num, repetition) return output_text def generate_loc(self, text_sections, image_num, progress): full_txt = ''.join(text_sections) input_text = f' <|User|>:给定文章"{full_txt}" 根据上述文章,选择适合插入图像的{image_num}行' + ' \n <|Bot|>:适合插入图像的行是' for _ in progress.tqdm([1], desc="image spotting"): output_text = self.generate(input_text, random=False, beam=5, max_length=300, repetition=1.) inject_text = '适合插入图像的行是' + output_text print(inject_text) locs = [] for m in self.r2.findall(inject_text): locs.append(int(m[4:-1])) print(locs) return inject_text, locs def generate_cap(self, text_sections, pos, progress): pasts = '' caps = {} for idx, po in progress.tqdm(enumerate(pos), desc="image captioning"): full_txt = ''.join(text_sections[:po + 2]) if idx > 0: past = pasts[:-2] + '。' else: past = pasts input_text = f' <|User|>: 给定文章"{full_txt}" {past}给出适合在后插入的图像对应的标题。' + ' \n <|Bot|>: 标题是"' cap_text = self.generate(input_text, random=False, beam=1, max_length=100, repetition=5.) cap_text = cap_text.split('"')[0].strip() print(cap_text) caps[po] = cap_text if idx == 0: pasts = f'现在后插入图像对应的标题是"{cap_text}", ' else: pasts += f'后插入图像对应的标题是"{cap_text}", ' print(caps) return caps def generate_loc_cap(self, text_sections, image_num, progress): inject_text, locs = self.generate_loc(text_sections, image_num, progress) caps = self.generate_cap(text_sections, locs, progress) return caps def interleav_wrap(self, img_embeds, text): batch_size = img_embeds.shape[0] im_len = img_embeds.shape[1] text = text[0] text = text.replace('', '') text = text.replace('', '') parts = text.split('') assert batch_size + 1 == len(parts) warp_tokens = [] warp_embeds = [] warp_attns = [] soi = (torch.ones([1, 1]) * self.soi_id).long().to(img_embeds.device) soi_embeds = self.llm_model.internlm_model.model.embed_tokens(soi) temp_len = 0 for idx, part in enumerate(parts): if len(part) > 0: part_tokens = self.llm_model.internlm_tokenizer( part, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) part_embeds = self.llm_model.internlm_model.model.embed_tokens( part_tokens.input_ids) warp_tokens.append(part_tokens.input_ids) warp_embeds.append(part_embeds) temp_len += part_embeds.shape[1] if idx < batch_size: warp_tokens.append(soi.expand(-1, img_embeds[idx].shape[0])) # warp_tokens.append(soi.expand(-1, img_embeds[idx].shape[0] + 1)) # warp_embeds.append(soi_embeds) ### 1, 1, C warp_embeds.append(img_embeds[idx].unsqueeze(0)) ### 1, 34, C temp_len += im_len if temp_len > self.max_txt_len: break warp_embeds = torch.cat(warp_embeds, dim=1) return warp_embeds[:, :self.max_txt_len].to(img_embeds.device) def align_text(self, samples): text_new = [] text = [t + self.eoa + ' ' for t in samples["text_input"]] for i in range(len(text)): temp = text[i] temp = temp.replace('###Human', '<|User|>') temp = temp.replace('### Human', '<|User|>') temp = temp.replace('<|User|> :', '<|User|>:') temp = temp.replace('<|User|>: ', '<|User|>:') temp = temp.replace('<|User|>', ' <|User|>') temp = temp.replace('###Assistant', '<|Bot|>') temp = temp.replace('### Assistant', '<|Bot|>') temp = temp.replace('<|Bot|> :', '<|Bot|>:') temp = temp.replace('<|Bot|>: ', '<|Bot|>:') temp = temp.replace('<|Bot|>', self.eoh + ' <|Bot|>') if temp.find('<|User|>') > temp.find('<|Bot|>'): temp = temp.replace(' <|User|>', self.eoa + ' <|User|>') text_new.append(temp) #print (temp) return text_new def model_select_image(self, output_text, caps, root, progress): print('model_select_image') pre_text = '' pre_img = [] pre_text_list = [] ans2idx = {'A': 0, 'B': 1, 'C': 2, 'D': 3} selected = {k: 0 for k in caps.keys()} for i, text in enumerate(output_text.split('\n')): pre_text += text + '\n' if i in caps: images = copy.deepcopy(pre_img) for j in range(4): image = Image.open( os.path.join( root, f'temp_{self.show_ids[i] * 1000 + i}_{j}.png' )).convert("RGB") image = self.vis_processor(image) images.append(image) images = torch.stack(images, dim=0) pre_text_list.append(pre_text) pre_text = '' images = images.cuda() instruct = ' <|User|>:根据给定上下文和候选图像,选择合适的配图:' input_text = ''.join( pre_text_list ) + '\n\n候选图像包括: A.\nB.\nC.\nD.\n\n <|Bot|>:最合适的图是' input_text = instruct + input_text samples = {} samples['text_input'] = [input_text] self.llm_model.debug_flag = 0 with torch.no_grad(): with torch.cuda.amp.autocast(): img_embeds = self.llm_model.encode_img(images) input_text = self.align_text(samples) img_embeds = self.interleav_wrap( img_embeds, input_text) bos = torch.ones( [1, 1]) * self.llm_model.internlm_tokenizer.bos_token_id bos = bos.long().to(images.device) meta_embeds = self.llm_model.internlm_model.model.embed_tokens( bos) inputs_embeds = torch.cat([meta_embeds, img_embeds], dim=1) with torch.cuda.amp.autocast(): outputs = self.llm_model.internlm_model.generate( inputs_embeds=inputs_embeds[:, :-2], do_sample=False, num_beams=5, max_length=10, repetition_penalty=1., ) out_text = self.llm_model.internlm_tokenizer.decode( outputs[0][1:], add_special_tokens=False) try: answer = out_text[1] if out_text[0] == ' ' else out_text[0] pre_img.append(images[len(pre_img) + ans2idx[answer]].cpu()) except: print('Select fail, use first image') answer = 'A' pre_img.append(images[len(pre_img) + ans2idx[answer]].cpu()) selected[i] = ans2idx[answer] return selected def show_md(self, text_sections, title, caps, selected, show_cap=False): md_shows = [] ga_shows = [] btn_shows = [] cap_textboxs, cap_searchs = [], [] editers = [] for i in range(len(text_sections)): if i in caps: if show_cap: md = text_sections[ i] + '\n' + '
{}
'.format( title, self.show_ids[i] * 1000 + i, selected[i], caps[i]) else: md = text_sections[ i] + '\n' + '
'.format( title, self.show_ids[i] * 1000 + i, selected[i]) img_list = [('articles/{}/temp_{}_{}.png'.format( title, self.show_ids[i] * 1000 + i, j), 'articles/{}/temp_{}_{}.png'.format( title, self.show_ids[i] * 1000 + i, j)) for j in range(4)] ga_show = gr.Gallery.update(visible=True, value=img_list) ga_shows.append(ga_show) btn_show = gr.Button.update(visible=True, value='\U0001f5d1\uFE0F') cap_textboxs.append( gr.Textbox.update(visible=True, value=caps[i])) cap_searchs.append(gr.Button.update(visible=True)) else: md = text_sections[i] ga_show = gr.Gallery.update(visible=False, value=[]) ga_shows.append(ga_show) btn_show = gr.Button.update(visible=True, value='\u2795') cap_textboxs.append(gr.Textbox.update(visible=False)) cap_searchs.append(gr.Button.update(visible=False)) md_show = gr.Markdown.update(visible=True, value=md) md_shows.append(md_show) btn_shows.append(btn_show) editers.append(gr.update(visible=True)) print(i, md) md_hides = [] ga_hides = [] btn_hides = [] for i in range(max_section - len(text_sections)): md_hide = gr.Markdown.update(visible=False, value='') md_hides.append(md_hide) btn_hide = gr.Button.update(visible=False) btn_hides.append(btn_hide) editers.append(gr.update(visible=False)) for i in range(max_section - len(ga_shows)): ga_hide = gr.Gallery.update(visible=False, value=[]) ga_hides.append(ga_hide) cap_textboxs.append(gr.Textbox.update(visible=False)) cap_searchs.append(gr.Button.update(visible=False)) return md_shows + md_hides + ga_shows + ga_hides + btn_shows + btn_hides + cap_textboxs + cap_searchs + editers, md_shows def generate_article(self, title, beam, repetition, text_num, msi, random, progress=gr.Progress()): self.reset() self.title = title if article_stream_output: text = ' <|User|>:根据给定标题写一个图文并茂,不重复的文章:{}\n'.format( title) + self.eoh + ' <|Bot|>:' input_tokens = self.llm_model.internlm_tokenizer( text, return_tensors="pt", add_special_tokens=True).to(self.llm_model.device) img_embeds = self.llm_model.internlm_model.model.embed_tokens( input_tokens.input_ids) generate_params = dict( inputs_embeds=img_embeds, num_beams=beam, do_sample=random, stopping_criteria=self.stopping_criteria, repetition_penalty=float(repetition), max_length=text_num, bos_token_id=self.llm_model.internlm_tokenizer.bos_token_id, eos_token_id=self.llm_model.internlm_tokenizer.eos_token_id, pad_token_id=self.llm_model.internlm_tokenizer.pad_token_id, ) output_text = "▌" with self.generate_with_streaming(**generate_params) as generator: for output in generator: decoded_output = self.llm_model.internlm_tokenizer.decode( output[1:]) if output[-1] in [ self.llm_model.internlm_tokenizer.eos_token_id ]: break output_text = decoded_output.replace('\n', '\n\n') + "▌" yield (output_text,) + (gr.Markdown.update(visible=False),) * (max_section - 1) + (gr.Gallery.update(visible=False),) * max_section + \ (gr.Button.update(visible=False),) * max_section + (gr.Textbox.update(visible=False),) * max_section + (gr.Button.update(visible=False),) * max_section + \ (gr.update(visible=False),) * max_section + (disable_btn,) * 2 time.sleep(0.03) output_text = output_text[:-1] yield (output_text,) + (gr.Markdown.update(visible=False),) * (max_section - 1) + (gr.Gallery.update(visible=False),) * max_section + \ (gr.Button.update(visible=False),) * max_section + (gr.Textbox.update(visible=False),) * max_section + (gr.Button.update(visible=False),) * max_section +\ (gr.update(visible=False),) * max_section + (disable_btn,) * 2 else: output_text = self.generate_text(title, beam, repetition, text_num, random) print(output_text) output_text = re.sub(r'(\n[ \t]*)+', '\n', output_text) if output_text[-1] == '\n': output_text = output_text[:-1] print(output_text) output_text = '\n'.join(output_text.split('\n')[:max_section]) text_sections = output_text.split('\n') idx_text_sections = [ f'' + ' ' + it + '\n' for i, it in enumerate(text_sections) ] caps = self.generate_loc_cap(idx_text_sections, '', progress) #caps = {0: '成都的三日游路线图,包括春熙路、太古里、IFS国金中心、大慈寺、宽窄巷子、奎星楼街、九眼桥(酒吧一条街)、武侯祠、锦里、杜甫草堂、浣花溪公园、青羊宫、金沙遗址博物馆、文殊院、人民公园、熊猫基地、望江楼公园、东郊记忆、建设路小吃街、电子科大清水河校区、三圣乡万福花卉市场、龙湖滨江天街购物广场和返程。', 2: '春熙路的繁华景象,各种时尚潮流的品牌店和美食餐厅鳞次栉比。', 4: 'IFS国金中心的豪华购物中心,拥有众多国际知名品牌的旗舰店和专卖店,同时还有电影院、健身房 配套设施。', 6: '春熙路上的著名景点——太古里,是一个集购物、餐饮、娱乐于一体的高端时尚街区,也是成都著名的网红打卡地之一。', 8: '大慈寺的外观,是一座历史悠久的佛教寺庙,始建于唐朝,有着深厚的文化底蕴和历史价值。'} #self.show_ids = {k:0 for k in caps.keys()} self.show_ids = {k: 1 for k in caps.keys()} print(caps) self.ex_idxs = [] for loc, cap in progress.tqdm(caps.items(), desc="download image"): #self.show_ids[loc] += 1 idxs = self.get_images_xlab(cap, loc, self.ex_idxs) self.ex_idxs.extend(idxs) if msi: self.selected = self.model_select_image(output_text, caps, 'articles/' + title, progress) else: self.selected = {k: 0 for k in caps.keys()} components, md_shows = self.show_md(text_sections, title, caps, self.selected) self.show_caps = False self.output_text = output_text self.caps = caps if article_stream_output: yield components + [enable_btn] * 2 else: return components + [enable_btn] * 2 def adjust_img(self, img_num, progress=gr.Progress()): text_sections = self.output_text.split('\n') idx_text_sections = [ f'' + ' ' + it + '\n' for i, it in enumerate(text_sections) ] img_num = min(img_num, len(text_sections)) caps = self.generate_loc_cap(idx_text_sections, int(img_num), progress) #caps = {1:'318川藏线沿途的风景照片', 4:'泸定桥的全景照片', 6:'折多山垭口的全景照片', 8:'稻城亚丁机场的全景照片', 10:'姊妹湖的全景照片'} print(caps) sidxs = [] for loc, cap in caps.items(): if loc in self.show_ids: self.show_ids[loc] += 1 else: self.show_ids[loc] = 1 idxs = self.get_images_xlab(cap, loc, sidxs) sidxs.extend(idxs) self.sidxs = sidxs self.selected = {k: 0 for k in caps.keys()} components, md_shows = self.show_md(text_sections, self.title, caps, self.selected) self.caps = caps return components def add_delete_image(self, text, status, index): index = int(index) if status == '\U0001f5d1\uFE0F': if index in self.caps: self.caps.pop(index) self.selected.pop(index) md_show = gr.Markdown.update(value=text.split('\n')[0]) gallery = gr.Gallery.update(visible=False, value=[]) btn_show = gr.Button.update(value='\u2795') cap_textbox = gr.Textbox.update(visible=False) cap_search = gr.Button.update(visible=False) else: md_show = gr.Markdown.update() gallery = gr.Gallery.update(visible=True, value=[]) btn_show = gr.Button.update(value='\U0001f5d1\uFE0F') cap_textbox = gr.Textbox.update(visible=True) cap_search = gr.Button.update(visible=True) return md_show, gallery, btn_show, cap_textbox, cap_search def search_image(self, text, index): index = int(index) if text == '': return gr.Gallery.update() if index in self.show_ids: self.show_ids[index] += 1 else: self.show_ids[index] = 1 self.caps[index] = text idxs = self.get_images_xlab(text, index, self.ex_idxs) self.ex_idxs.extend(idxs) img_list = [('articles/{}/temp_{}_{}.png'.format( self.title, self.show_ids[index] * 1000 + index, j), 'articles/{}/temp_{}_{}.png'.format( self.title, self.show_ids[index] * 1000 + index, j)) for j in range(4)] ga_show = gr.Gallery.update(visible=True, value=img_list) return ga_show def replace_image(self, article, index, evt: gr.SelectData): index = int(index) self.selected[index] = evt.index if '
' in article: return re.sub(r'file=.*.png', 'file={}'.format(evt.value), article) else: return article + '\n' + '
'.format( evt.value) def add_delete_caption(self): self.show_caps = False if self.show_caps else True text_sections = self.output_text.split('\n') components, _ = self.show_md(text_sections, self.title, self.caps, selected=self.selected, show_cap=self.show_caps) return components def save(self): folder = 'save_articles/' + self.title if os.path.exists(folder): for item in os.listdir(folder): os.remove(os.path.join(folder, item)) os.makedirs(folder, exist_ok=True) save_text = '' count = 0 if len(self.output_text) > 0: text_sections = self.output_text.split('\n') for i in range(len(text_sections)): if i in self.caps: if self.show_caps: md = text_sections[ i] + '\n' + '
{}
'.format( self.show_ids[i] * 1000 + i, self.selected[i], self.caps[i]) else: md = text_sections[ i] + '\n' + '
'.format( self.show_ids[i] * 1000 + i, self.selected[i]) count += 1 else: md = text_sections[i] save_text += md + '\n\n' save_text = save_text[:-2] with open(os.path.join(folder, 'io.MD'), 'w') as f: f.writelines(save_text) for k in self.caps.keys(): shutil.copy( os.path.join( 'articles', self.title, f'temp_{self.show_ids[k] * 1000 + k}_{self.selected[k]}.png' ), folder) archived = shutil.make_archive(folder, 'zip', folder) return archived def get_context_emb(self, state, img_list): prompt = state.get_prompt() print(prompt) prompt_segs = prompt.split('') assert len(prompt_segs) == len( img_list ) + 1, "Unmatched numbers of image placeholders and images." seg_tokens = [ self.llm_model.internlm_tokenizer(seg, return_tensors="pt", add_special_tokens=i == 0).to( self.device).input_ids for i, seg in enumerate(prompt_segs) ] seg_embs = [ self.llm_model.internlm_model.model.embed_tokens(seg_t) for seg_t in seg_tokens ] mixed_embs = [ emb for pair in zip(seg_embs[:-1], img_list) for emb in pair ] + [seg_embs[-1]] mixed_embs = torch.cat(mixed_embs, dim=1) return mixed_embs def chat_ask(self, state, img_list, text, image): print(1111) state.skip_next = False if len(text) <= 0 and image is None: state.skip_next = True return (state, img_list, state.to_gradio_chatbot(), "", None) + (no_change_btn, ) * 2 if image is not None: image_pt = self.vis_processor(image).unsqueeze(0).to(0) image_emb = self.llm_model.encode_img(image_pt) img_list.append(image_emb) state.append_message(state.roles[0], ["", image]) if len(state.messages) > 0 and state.messages[-1][0] == state.roles[ 0] and isinstance(state.messages[-1][1], list): #state.messages[-1][1] = ' '.join([state.messages[-1][1], text]) state.messages[-1][1][0] = ' '.join( [state.messages[-1][1][0], text]) else: state.append_message(state.roles[0], text) print(state.messages) state.append_message(state.roles[1], None) return (state, img_list, state.to_gradio_chatbot(), "", None) + (disable_btn, ) * 2 def generate_with_callback(self, callback=None, **kwargs): kwargs.setdefault("stopping_criteria", transformers.StoppingCriteriaList()) kwargs["stopping_criteria"].append(Stream(callback_func=callback)) with torch.no_grad(): with self.llm_model.maybe_autocast(): self.llm_model.internlm_model.generate(**kwargs) def generate_with_streaming(self, **kwargs): return Iteratorize(self.generate_with_callback, kwargs, callback=None) def chat_answer(self, state, img_list, max_output_tokens, repetition_penalty, num_beams, do_sample): # text = '图片中是一幅油画,描绘了红军长征的场景。画面中,一群红军战士正在穿过一片草地,他们身后的旗帜在风中飘扬。' # for i in range(len(text)): # state.messages[-1][-1] = text[:i+1] + "▌" # yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2 # state.messages[-1][-1] = text[:i + 1] # yield (state, state.to_gradio_chatbot()) + (enable_btn, ) * 2 # return if state.skip_next: return (state, state.to_gradio_chatbot()) + (no_change_btn, ) * 2 embs = self.get_context_emb(state, img_list) if chat_stream_output: generate_params = dict( inputs_embeds=embs, num_beams=num_beams, do_sample=do_sample, stopping_criteria=self.stopping_criteria, repetition_penalty=float(repetition_penalty), max_length=max_output_tokens, bos_token_id=self.llm_model.internlm_tokenizer.bos_token_id, eos_token_id=self.llm_model.internlm_tokenizer.eos_token_id, pad_token_id=self.llm_model.internlm_tokenizer.pad_token_id, ) state.messages[-1][-1] = "▌" with self.generate_with_streaming(**generate_params) as generator: for output in generator: decoded_output = self.llm_model.internlm_tokenizer.decode( output[1:]) if output[-1] in [ self.llm_model.internlm_tokenizer.eos_token_id, 333, 497 ]: break state.messages[-1][-1] = decoded_output + "▌" yield (state, state.to_gradio_chatbot()) + (disable_btn, ) * 2 time.sleep(0.03) state.messages[-1][-1] = state.messages[-1][-1][:-1] yield (state, state.to_gradio_chatbot()) + (enable_btn, ) * 2 return else: outputs = self.llm_model.internlm_model.generate( inputs_embeds=embs, max_new_tokens=max_output_tokens, stopping_criteria=self.stopping_criteria, num_beams=num_beams, #temperature=float(temperature), do_sample=do_sample, repetition_penalty=float(repetition_penalty), bos_token_id=self.llm_model.internlm_tokenizer.bos_token_id, eos_token_id=self.llm_model.internlm_tokenizer.eos_token_id, pad_token_id=self.llm_model.internlm_tokenizer.pad_token_id, ) output_token = outputs[0] if output_token[0] == 0: output_token = output_token[1:] output_text = self.llm_model.internlm_tokenizer.decode( output_token, add_special_tokens=False) print(output_text) output_text = output_text.split('')[ 0] # remove the stop sign '###' output_text = output_text.split('Assistant:')[-1].strip() output_text = output_text.replace("", "") state.messages[-1][1] = output_text return (state, state.to_gradio_chatbot()) + (enable_btn, ) * 2 def clear_answer(self, state): state.messages[-1][-1] = None return (state, state.to_gradio_chatbot()) def chat_clear_history(self): state = CONV_VISION_7132_v2.copy() return (state, [], state.to_gradio_chatbot(), "", None) + (disable_btn, ) * 2 def load_demo(): state = CONV_VISION_7132_v2.copy() return (state, [], gr.Chatbot.update(visible=True), gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Row.update(visible=True), gr.Accordion.update(visible=True)) def change_language(lang): if lang == '中文': lang_btn = gr.update(value='English') title = gr.update(label='根据给定标题写一个图文并茂的文章:') btn = gr.update(value='生成') parameter_article = gr.update(label='高级设置') beam = gr.update(label='集束大小') repetition = gr.update(label='重复惩罚') text_num = gr.update(label='最多输出字数') msi = gr.update(label='模型选图') random = gr.update(label='采样') img_num = gr.update(label='生成文章后,可选择全文配图数量') adjust_btn = gr.update(value='固定数量配图') cap_searchs, editers = [], [] for _ in range(max_section): cap_searchs.append(gr.update(value='搜索')) editers.append(gr.update(label='编辑')) save_btn = gr.update(value='文章下载') save_file = gr.update(label='文章下载') parameter_chat = gr.update(label='参数') chat_text_num = gr.update(label='最多输出字数') chat_beam = gr.update(label='集束大小') chat_repetition = gr.update(label='重复惩罚') chat_random = gr.update(label='采样') chat_textbox = gr.update(placeholder='输入聊天内容并回车') submit_btn = gr.update(value='提交') regenerate_btn = gr.update(value='🔄 重新生成') clear_btn = gr.update(value='🗑️ 清空聊天框') elif lang == 'English': lang_btn = gr.update(value='中文') title = gr.update( label='Write an illustrated article based on the given title:') btn = gr.update(value='Submit') parameter_article = gr.update(label='Advanced Settings') beam = gr.update(label='Beam Size') repetition = gr.update(label='Repetition_penalty') text_num = gr.update(label='Max output tokens') msi = gr.update(label='Model selects images') random = gr.update(label='Do_sample') img_num = gr.update( label= 'Select the number of the inserted image after article generation.' ) adjust_btn = gr.update(value='Insert a fixed number of images') cap_searchs, editers = [], [] for _ in range(max_section): cap_searchs.append(gr.update(value='Search')) editers.append(gr.update(label='edit')) save_btn = gr.update(value='Save article') save_file = gr.update(label='Save article') parameter_chat = gr.update(label='Parameters') chat_text_num = gr.update(label='Max output tokens') chat_beam = gr.update(label='Beam Size') chat_repetition = gr.update(label='Repetition_penalty') chat_random = gr.update(label='Do_sample') chat_textbox = gr.update(placeholder='Enter text and press ENTER') submit_btn = gr.update(value='Submit') regenerate_btn = gr.update(value='🔄 Regenerate') clear_btn = gr.update(value='🗑️ Clear history') return [lang_btn, title, btn, parameter_article, beam, repetition, text_num, msi, random, img_num, adjust_btn] +\ cap_searchs + editers + [save_btn, save_file] +[parameter_chat, chat_text_num, chat_beam, chat_repetition, chat_random] + \ [chat_textbox, submit_btn, regenerate_btn, clear_btn] parser = argparse.ArgumentParser() parser.add_argument("--folder", default='internlm/internlm-xcomposer-7b') parser.add_argument("--private", default=False, action='store_true') args = parser.parse_args() demo_ui = Demo_UI(args.folder) with gr.Blocks(css=custom_css, title='浦语·灵笔 (InternLM-XComposer)') as demo: with gr.Row(): with gr.Column(scale=20): #gr.HTML("""

🤗 浦语·灵笔 (InternLM-XComposer)

""") gr.HTML( """

InternLM-XComposer

""" ) with gr.Column(scale=1, min_width=100): lang_btn = gr.Button("中文") with gr.Tabs(elem_classes="tab-buttons") as tabs: with gr.TabItem("📝 创作图文并茂文章 (Write Interleaved-text-image Article)"): with gr.Row(): title = gr.Textbox( label= 'Write an illustrated article based on the given title:', scale=2) btn = gr.Button("Submit", scale=1) with gr.Row(): img_num = gr.Slider( minimum=1.0, maximum=30.0, value=5.0, step=1.0, scale=2, label= 'Select the number of the inserted image after article generation.' ) adjust_btn = gr.Button('Insert a fixed number of images', interactive=False, scale=1) with gr.Row(): with gr.Column(scale=1): with gr.Accordion("Advanced Settings", open=False, visible=True) as parameter_article: beam = gr.Slider(minimum=1.0, maximum=6.0, value=5.0, step=1.0, label='Beam Size') repetition = gr.Slider(minimum=0.0, maximum=10.0, value=5.0, step=0.1, label='Repetition_penalty') text_num = gr.Slider(minimum=100.0, maximum=2000.0, value=1000.0, step=1.0, label='Max output tokens') msi = gr.Checkbox(value=True, label='Model selects images') random = gr.Checkbox(label='Do_sample') with gr.Column(scale=1): gr.Examples( examples=[["又见敦煌"], ["星链新闻稿"], ["如何养好一只宠物"], ["Shanghai Travel Guide in English"], ["Travel guidance of London in English"], ["Advertising for Genshin Impact in English"]], inputs=[title], ) articles = [] gallerys = [] add_delete_btns = [] cap_textboxs = [] cap_searchs = [] editers = [] with gr.Column(): for i in range(max_section): with gr.Row(): visible = True if i == 0 else False with gr.Column(scale=2): article = gr.Markdown(visible=visible, elem_classes='feedback') articles.append(article) with gr.Column(scale=1): with gr.Accordion('edit', open=False, visible=False) as editer: with gr.Row(): cap_textbox = gr.Textbox(show_label=False, interactive=True, scale=6, visible=False) cap_search = gr.Button(value="Search", visible=False, scale=1) with gr.Row(): gallery = gr.Gallery(visible=False, columns=2, height='auto') add_delete_btn = gr.Button(visible=False) gallery.select(demo_ui.replace_image, [ articles[i], gr.Number(value=i, visible=False) ], articles[i]) gallerys.append(gallery) add_delete_btns.append(add_delete_btn) cap_textboxs.append(cap_textbox) cap_searchs.append(cap_search) editers.append(editer) save_btn = gr.Button("Save article") save_file = gr.File(label="Save article") for i in range(max_section): add_delete_btns[i].click(demo_ui.add_delete_image, inputs=[ articles[i], add_delete_btns[i], gr.Number(value=i, visible=False) ], outputs=[ articles[i], gallerys[i], add_delete_btns[i], cap_textboxs[i], cap_searchs[i] ]) cap_searchs[i].click(demo_ui.search_image, inputs=[ cap_textboxs[i], gr.Number(value=i, visible=False) ], outputs=gallerys[i]) btn.click( demo_ui.generate_article, inputs=[title, beam, repetition, text_num, msi, random], outputs=articles + gallerys + add_delete_btns + cap_textboxs + cap_searchs + editers + [btn, adjust_btn]) # cap_btn.click(demo_ui.add_delete_caption, inputs=None, outputs=articles) save_btn.click(demo_ui.save, inputs=None, outputs=save_file) adjust_btn.click(demo_ui.adjust_img, inputs=img_num, outputs=articles + gallerys + add_delete_btns + cap_textboxs + cap_searchs + editers) with gr.TabItem("💬 多模态对话 (Multimodal Chat)", elem_id="chat", id=0): chat_state = gr.State() img_list = gr.State() with gr.Row(): with gr.Column(scale=3): imagebox = gr.Image(type="pil") with gr.Accordion("Parameters", open=True, visible=False) as parameter_row: chat_max_output_tokens = gr.Slider( minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens", ) chat_num_beams = gr.Slider( minimum=1, maximum=5, value=3, step=1, interactive=True, label="Beam Size", ) chat_repetition_penalty = gr.Slider( minimum=1, maximum=5, value=1, step=0.1, interactive=True, label="Repetition_penalty", ) # chat_temperature = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, interactive=True, # label="Temperature", ) chat_do_sample = gr.Checkbox(interactive=True, value=True, label="Do_sample") with gr.Column(scale=6): chatbot = grChatbot(elem_id="chatbot", visible=False, height=750) with gr.Row(): with gr.Column(scale=8): chat_textbox = gr.Textbox( show_label=False, placeholder="Enter text and press ENTER", visible=False).style(container=False) with gr.Column(scale=1, min_width=60): submit_btn = gr.Button(value="Submit", visible=False) with gr.Row(visible=True) as button_row: regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) clear_btn = gr.Button(value="🗑️ Clear history", interactive=False) btn_list = [regenerate_btn, clear_btn] parameter_list = [ chat_max_output_tokens, chat_repetition_penalty, chat_num_beams, chat_do_sample ] chat_textbox.submit( demo_ui.chat_ask, [chat_state, img_list, chat_textbox, imagebox], [chat_state, img_list, chatbot, chat_textbox, imagebox] + btn_list).then(demo_ui.chat_answer, [chat_state, img_list] + parameter_list, [chat_state, chatbot] + btn_list) submit_btn.click( demo_ui.chat_ask, [chat_state, img_list, chat_textbox, imagebox], [chat_state, img_list, chatbot, chat_textbox, imagebox] + btn_list).then(demo_ui.chat_answer, [chat_state, img_list] + parameter_list, [chat_state, chatbot] + btn_list) regenerate_btn.click(demo_ui.clear_answer, chat_state, [chat_state, chatbot]).then( demo_ui.chat_answer, [chat_state, img_list] + parameter_list, [chat_state, chatbot] + btn_list) clear_btn.click( demo_ui.chat_clear_history, None, [chat_state, img_list, chatbot, chat_textbox, imagebox] + btn_list) demo.load(load_demo, None, [ chat_state, img_list, chatbot, chat_textbox, submit_btn, parameter_row ]) lang_btn.click(change_language, inputs=lang_btn, outputs=[lang_btn, title, btn, parameter_article] +\ [beam, repetition, text_num, msi, random, img_num, adjust_btn] + cap_searchs + editers +\ [save_btn, save_file] + [parameter_row, chat_max_output_tokens, chat_num_beams, chat_repetition_penalty, chat_do_sample] +\ [chat_textbox, submit_btn, regenerate_btn, clear_btn]) demo.queue(concurrency_count=8, status_update_rate=10, api_open=False) if __name__ == "__main__": demo.launch()