Attempt to add the jittor local model

- .gitignore  +1 -0
- request_llm/bridge_jittorllms.py  +153 -0
- request_llm/requirements_jittorllms.txt  +4 -0
- request_llm/test_llms.py  +26 -0
.gitignore
CHANGED
@@ -146,3 +146,4 @@ debug*
 private*
 crazy_functions/test_project/pdf_and_word
 crazy_functions/test_samples
+request_llm/jittorllms

request_llm/bridge_jittorllms.py
ADDED
@@ -0,0 +1,153 @@
+
+from transformers import AutoModel, AutoTokenizer
+import time
+import threading
+import importlib
+from toolbox import update_ui, get_conf
+from multiprocessing import Process, Pipe
+
+load_message = "jittorllms has not been loaded yet; loading takes a while. Note that, depending on the settings in `config.py`, jittorllms consumes a large amount of memory (CPU) or VRAM (GPU), which may freeze low-spec machines ..."
+
+#################################################################################
+class GetGLMHandle(Process):
+    def __init__(self):
+        super().__init__(daemon=True)
+        self.parent, self.child = Pipe()
+        self.jittorllms_model = None
+        self.info = ""
+        self.success = True
+        self.check_dependency()
+        self.start()
+        self.threadLock = threading.Lock()
+
+    def check_dependency(self):
+        try:
+            import jittor
+            from .jittorllms.models import get_model
+            self.info = "Dependency check passed"
+            self.success = True
+        except Exception:
+            self.info = r"Missing jittorllms dependencies. Besides the base pip requirements, using jittorllms also requires running " + \
+                        r"`pip install -r request_llm/requirements_jittorllms.txt` and `git clone https://gitlink.org.cn/jittor/JittorLLMs.git --depth 1 request_llm/jittorllms` (run both commands from the project root)."
+            self.success = False
+
+    def ready(self):
+        return self.jittorllms_model is not None
+
+    def run(self):
+        # Runs in the child process.
+        # On first run, load the model parameters.
+        def load_model():
+            import types
+            try:
+                if self.jittorllms_model is None:
+                    device, = get_conf('LOCAL_MODEL_DEVICE')
+                    from .jittorllms.models import get_model
+                    # available_models = ["chatglm", "pangualpha", "llama", "chatrwkv"]
+                    args_dict = {'model': 'chatglm', 'RUN_DEVICE': 'cpu'}
+                    self.jittorllms_model = get_model(types.SimpleNamespace(**args_dict))
+            except Exception:
+                self.child.send('[Local Message] Call jittorllms fail: could not load the jittorllms parameters.')
+                raise RuntimeError("Could not load the jittorllms parameters!")
+
+        load_model()
+
+        # Enter the task-waiting loop.
+        while True:
+            # Wait for the next request.
+            kwargs = self.child.recv()
+            # A message arrived; run the query.
+            try:
+                for response, history in self.jittorllms_model.run_web_demo(kwargs['query'], kwargs['history']):
+                    self.child.send(response)
+            except Exception:
+                self.child.send('[Local Message] Call jittorllms fail.')
+            # This request is done; signal completion and start the next loop.
+            self.child.send('[Finish]')
+
+    def stream_chat(self, **kwargs):
+        # Runs in the main process.
+        self.threadLock.acquire()
+        self.parent.send(kwargs)
+        while True:
+            res = self.parent.recv()
+            if res != '[Finish]':
+                yield res
+            else:
+                break
+        self.threadLock.release()
+
+global glm_handle
+glm_handle = None
+#################################################################################
+def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=[], console_slience=False):
+    """
+        Multi-threaded method.
+        See request_llm/bridge_all.py for the function's documentation.
+    """
+    global glm_handle
+    if glm_handle is None:
+        glm_handle = GetGLMHandle()
+        if len(observe_window) >= 1: observe_window[0] = load_message + "\n\n" + glm_handle.info
+        if not glm_handle.success:
+            error = glm_handle.info
+            glm_handle = None
+            raise RuntimeError(error)
+
+    # jittorllms has no sys_prompt interface, so the prompt is folded into the history.
+    history_feedin = []
+    history_feedin.append(["What can I do?", sys_prompt])
+    for i in range(len(history)//2):
+        history_feedin.append([history[2*i], history[2*i+1]])
+
+    watch_dog_patience = 5 # watchdog patience: 5 seconds is enough
+    response = ""
+    for response in glm_handle.stream_chat(query=inputs, history=history_feedin, max_length=llm_kwargs['max_length'], top_p=llm_kwargs['top_p'], temperature=llm_kwargs['temperature']):
+        if len(observe_window) >= 1:  observe_window[0] = response
+        if len(observe_window) >= 2:
+            if (time.time()-observe_window[1]) > watch_dog_patience:
+                raise RuntimeError("Program terminated.")
+    return response
+
+
+
+def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream=True, additional_fn=None):
+    """
+        Single-threaded method.
+        See request_llm/bridge_all.py for the function's documentation.
+    """
+    chatbot.append((inputs, ""))
+
+    global glm_handle
+    if glm_handle is None:
+        glm_handle = GetGLMHandle()
+        chatbot[-1] = (inputs, load_message + "\n\n" + glm_handle.info)
+        yield from update_ui(chatbot=chatbot, history=[])
+        if not glm_handle.success:
+            glm_handle = None
+            return
+
+    if additional_fn is not None:
+        import core_functional
+        importlib.reload(core_functional)    # hot-reload the prompts
+        core_functional = core_functional.get_core_functions()
+        if "PreProcess" in core_functional[additional_fn]: inputs = core_functional[additional_fn]["PreProcess"](inputs)  # apply the preprocessing function, if one exists
+        inputs = core_functional[additional_fn]["Prefix"] + inputs + core_functional[additional_fn]["Suffix"]
+
+    # Prepare the chat history.
+    history_feedin = []
+    history_feedin.append(["What can I do?", system_prompt])
+    for i in range(len(history)//2):
+        history_feedin.append([history[2*i], history[2*i+1]])
+
+    # Start receiving the jittorllms reply.
+    response = "[Local Message]: Waiting for jittorllms response ..."
+    for response in glm_handle.stream_chat(query=inputs, history=history_feedin, max_length=llm_kwargs['max_length'], top_p=llm_kwargs['top_p'], temperature=llm_kwargs['temperature']):
+        chatbot[-1] = (inputs, response)
+        yield from update_ui(chatbot=chatbot, history=history)
+
+    # Summarize the output.
+    if response == "[Local Message]: Waiting for jittorllms response ...":
+        response = "[Local Message]: jittorllms response error ..."
+    history.extend([inputs, response])
+    yield from update_ui(chatbot=chatbot, history=history)
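bridge_jittorllms.py isolates the jittor runtime in a daemon subprocess and streams partial replies back to the main process over a Pipe, ending every request with the '[Finish]' sentinel that stream_chat watches for. Below is a minimal, self-contained sketch of that same pattern; the echo "model" is a stand-in and nothing here is part of the commit:

from multiprocessing import Process, Pipe

class StreamWorker(Process):
    # Same shape as GetGLMHandle: the worker lives in a daemon process,
    # streams chunks over a Pipe, and ends every request with '[Finish]'.
    def __init__(self):
        super().__init__(daemon=True)
        self.parent, self.child = Pipe()
        self.start()

    def run(self):
        # Child process: serve requests forever.
        while True:
            query = self.child.recv()
            for chunk in query.split():   # stand-in for jittorllms_model.run_web_demo(...)
                self.child.send(chunk)
            self.child.send('[Finish]')   # end-of-stream sentinel

    def stream_chat(self, query):
        # Main process: forward the request, yield chunks until the sentinel.
        self.parent.send(query)
        while True:
            res = self.parent.recv()
            if res == '[Finish]':
                break
            yield res

if __name__ == '__main__':
    worker = StreamWorker()
    for piece in worker.stream_chat("streamed one word at a time"):
        print(piece)

The sentinel keeps stream_chat's generator semantics intact: callers iterate until '[Finish]' arrives, and in the real bridge the threading.Lock serializes concurrent callers onto the single worker process.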
    	
request_llm/requirements_jittorllms.txt
ADDED
@@ -0,0 +1,4 @@
+jittor >= 1.3.7.9
+jtorch >= 0.1.3
+torch
+torchvision
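A quick way to confirm the pinned jittor build resolved correctly before starting the bridge is an import check. This snippet is illustrative only and assumes jittor exposes the usual __version__ attribute:

# Illustrative only: confirm jittor imports and report its version
# against the >= 1.3.7.9 pin above.
import jittor
print(jittor.__version__)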
    	
request_llm/test_llms.py
ADDED
@@ -0,0 +1,26 @@
+"""
+Unit tests for the individual LLM models.
+"""
+def validate_path():
+    import os, sys
+    dir_name = os.path.dirname(__file__)
+    root_dir_assume = os.path.abspath(os.path.dirname(__file__) + '/..')
+    os.chdir(root_dir_assume)
+    sys.path.append(root_dir_assume)
+
+validate_path() # validate path so you can run from base directory
+
+from request_llm.bridge_jittorllms import predict_no_ui_long_connection
+
+llm_kwargs = {
+    'max_length': 512,
+    'top_p': 1,
+    'temperature': 1,
+}
+
+result = predict_no_ui_long_connection(inputs="你好",
+                                       llm_kwargs=llm_kwargs,
+                                       history=[],
+                                       sys_prompt="")
+
+print(result)
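test_llms.py exercises the blocking path without a watchdog. In the bridge, passing an observe_window with a second element turns the watchdog on: predict_no_ui_long_connection raises RuntimeError once observe_window[1] goes stale for more than watch_dog_patience (5) seconds. A sketch of driving that mechanism follows; the heartbeat thread is illustrative and not part of the commit:

import time
import threading
from request_llm.bridge_jittorllms import predict_no_ui_long_connection

# observe_window[0] receives partial responses; observe_window[1] is the
# heartbeat timestamp the bridge checks against its 5-second patience.
observe_window = ["", time.time()]

def heartbeat():
    # A real caller refreshes the timestamp while the user is still waiting;
    # stop refreshing and the bridge aborts the stream with RuntimeError.
    for _ in range(60):
        observe_window[1] = time.time()
        time.sleep(1)

threading.Thread(target=heartbeat, daemon=True).start()
result = predict_no_ui_long_connection(inputs="你好",
                                       llm_kwargs={'max_length': 512, 'top_p': 1, 'temperature': 1},
                                       history=[],
                                       sys_prompt="",
                                       observe_window=observe_window)
print(result)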