File size: 6,944 Bytes
9d54f50
4a28811
 
 
 
 
9d54f50
bfc371b
5742893
b09d1ca
 
 
 
 
 
 
 
 
 
 
 
 
4a28811
b09d1ca
 
 
 
 
 
 
 
 
 
4a28811
 
5742893
4a28811
5742893
 
4a28811
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import gradio as gr

import subprocess

# Fetch the model weights and demo assets this script reads from
# ./L2_InternVL.  Commands are passed as argument lists so no shell is
# involved, and failures stay non-fatal (the original os.system calls also
# ignored exit codes).
try:
    subprocess.run(["git", "lfs", "install"], check=False)
    # `git clone` fails when the checkout already exists, so skip it on
    # restarts instead of issuing a command that is known to error out.
    if not os.path.isdir(os.path.join("L2_InternVL", ".git")):
        subprocess.run(
            ["git", "clone", "https://huggingface.co/Pluto0616/L2_InternVL"],
            check=False,
        )
except OSError as err:
    # Raised when the git executable itself cannot be started.
    print("warning: could not run git: {}".format(err))
import os
import json
import logging
from datetime import datetime

def load_json(file_name: str):
    """Load and return the parsed contents of a JSON file.

    Args:
        file_name: Path of the file to read. Must be a ``str`` ending in
            ``"json"``.

    Returns:
        The Python object produced by ``json.load``.

    Raises:
        ValueError: If ``file_name`` is not a string or does not end in
            ``"json"``.
    """
    if not (isinstance(file_name, str) and file_name.endswith("json")):
        raise ValueError("The file path you passed in is not a json file path.")

    # Decode explicitly as UTF-8: the default encoding is locale dependent and
    # would break on non-ASCII content (e.g. Chinese keys) on some platforms.
    with open(file_name, 'r', encoding="utf-8") as file:
        return json.load(file)

def init_logger(outputs_dir):
    """Configure root logging to the console and to a timestamped file.

    Creates ``<outputs_dir>/logs`` if needed and writes to
    ``<outputs_dir>/logs/<Mon><DD>_<HH>-<MM>-<SS>.txt``.

    Args:
        outputs_dir: Directory under which the ``logs`` folder is created.
    """
    log_dir = os.path.join(outputs_dir, "logs")
    os.makedirs(log_dir, exist_ok=True)

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    log_path = os.path.join(log_dir, "{}.txt".format(current_time))

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s || %(message)s",
        handlers=[
            logging.StreamHandler(),
            # Explicit UTF-8 so non-ASCII messages are written correctly
            # regardless of the platform's default locale encoding.
            logging.FileHandler(log_path, encoding="utf-8"),
        ],
    )
    
from demo import ConversationalAgent, CustomTheme

# JSON file loaded by main(); its top-level keys are the cuisine names whose
# values are fed to the example galleries.
FOOD_EXAMPLES = "./L2_InternVL/demo/food_for_demo.json"
# Alternative model location, currently disabled:
# MODEL_PATH = "/root/share/new_models/OpenGVLab/InternVL2-2B"
# Model directory handed to ConversationalAgent as ``model_path``.
MODEL_PATH = "./L2_InternVL/work_dirs/internvl_v2_internlm2_2b_lora_finetune_food/lr35_ep10"
# Directory passed to init_logger() and to ConversationalAgent as ``outputs_dir``.
OUTPUT_PATH = "./L2_InternVL/outputs"

def setup_seeds(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random`` and
            ``torch``. Defaults to 42, the value previously hard-coded.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn.benchmark = False
    cudnn.deterministic = True


def main():
    """Build the InternVL2 food-chat Gradio demo and serve it.

    Seeds the RNGs, configures logging under ``OUTPUT_PATH``, loads the
    per-cuisine example images from ``FOOD_EXAMPLES``, creates the
    ``ConversationalAgent``, wires its callbacks to the UI and finally
    launches the app on 127.0.0.1:1096 with ``share=True``.
    """
    setup_seeds()
    # logging
    init_logger(OUTPUT_PATH)
    # food examples
    food_examples = load_json(FOOD_EXAMPLES)

    agent = ConversationalAgent(model_path=MODEL_PATH,
                                outputs_dir=OUTPUT_PATH)

    theme = CustomTheme()

    # Two separate title lines.  The comma matters: without it Python joins
    # the adjacent string literals into a single list element.
    titles = [
        """<center><B><font face="Comic Sans MS" size=10>书生大模型实战营</font></B></center>""",  ## Kalam:wght@700
        """<center><B><font face="Courier" size=5>「进阶岛」InternVL 多模态模型部署微调实践</font></B></center>""",
    ]

    language = """Language: 中文 and English"""

    # Keys of ``food_examples`` grouped by the UI row they are shown in.
    cuisine_rows = (
        ("新疆菜", "川菜(四川,重庆)", "西北菜 (陕西,甘肃等地)"),
        ("黔菜 (贵州)", "苏菜(江苏)", "粤菜(广东等地)"),
        ("湘菜(湖南)", "闽菜(福建)", "浙菜(浙江)"),
        ("东北菜 (黑龙江等地)",),
    )

    # Avatar images are resolved relative to this file.
    here = os.path.dirname(__file__)

    with gr.Blocks(theme) as demo_chatbot:
        for title in titles:
            gr.Markdown(title)
        gr.Markdown(language)

        with gr.Row():
            # Left column: chat controls, image slot and sampling settings.
            with gr.Column(scale=3):
                start_btn = gr.Button("Start Chat", variant="primary", interactive=True)
                clear_btn = gr.Button("Clear Context", interactive=False)
                image = gr.Image(type="pil", interactive=False)
                upload_btn = gr.Button("🖼️ Upload Image", interactive=False)

                with gr.Accordion("Generation Settings"):
                    top_p = gr.Slider(minimum=0, maximum=1, step=0.1,
                                      value=0.8,
                                      interactive=True,
                                      label='top-p value',
                                      visible=True)

                    temperature = gr.Slider(minimum=0, maximum=1.5, step=0.1,
                                            value=0.8,
                                            interactive=True,
                                            label='temperature',
                                            visible=True)

            # Right column: conversation view, text input and prompt examples.
            with gr.Column(scale=7):
                chat_state = gr.State()
                chatbot = gr.Chatbot(
                    label='InternVL2',
                    height=800,
                    avatar_images=(
                        os.path.join(here, 'demo/user.png'),
                        os.path.join(here, "demo/bot.png"),
                    ),
                )
                text_input = gr.Textbox(label='User', placeholder="Please click the <Start Chat> button to start chat!", interactive=False)
                gr.Markdown("### 输入示例")

                def on_text_change(text):
                    # Any change to the textbox content makes it editable.
                    return gr.update(interactive=True)

                text_input.change(fn=on_text_change, inputs=text_input, outputs=text_input)
                gr.Examples(
                    examples=[["图片中的食物通常属于哪个菜系?"],
                              ["如果让你简单形容一下品尝图片中的食物的滋味,你会描述它"],
                              ["去哪个地方游玩时应该品尝当地的特色美食图片中的食物?"],
                              ["食用图片中的食物时,一般它上菜或摆盘时的特点是?"]],
                    inputs=[text_input]
                )

        with gr.Row():
            gr.Markdown("### 食物快捷栏")
        # One example gallery per cuisine, each filling the image component.
        for row in cuisine_rows:
            with gr.Row():
                for cuisine in row:
                    gr.Examples(examples=food_examples[cuisine], inputs=image, label=cuisine)

        start_btn.click(agent.start_chat, [chat_state], [text_input, start_btn, clear_btn, image, upload_btn, chat_state])
        clear_btn.click(agent.restart_chat, [chat_state], [chatbot, text_input, start_btn, clear_btn, image, upload_btn, chat_state], queue=False)
        upload_btn.click(agent.upload_image, [image, chatbot, chat_state], [image, chatbot, chat_state])
        text_input.submit(
            agent.respond,
            inputs=[text_input, image, chatbot, top_p, temperature, chat_state],
            outputs=[text_input, image, chatbot, chat_state]
        )

    # The queue must be configured before launch(); calling it afterwards is
    # too late to affect the app that is already being served.
    demo_chatbot.queue()
    demo_chatbot.launch(share=True, server_name="127.0.0.1", server_port=1096, allowed_paths=['./'])
    

# Build and launch the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()