import os

from langchain.chains import LLMChain
from langchain.chains.base import Chain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

from app_modules.llm_inference import LLMInference


def get_llama_2_prompt_template():
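    """Build a Llama 2 chat prompt: the system prompt is wrapped in <<SYS>>
    tags and combined with the chat-history/question instruction inside the
    [INST] ... [/INST] markers expected by Llama 2 chat models."""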
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

    instruction = "Chat History:\n\n{chat_history} \n\nUser: {question}"
    system_prompt = "You are a helpful assistant. Always answer only as the assistant, then stop. Read the chat history to get context."

    SYSTEM_PROMPT = B_SYS + system_prompt + E_SYS
    prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
    return prompt_template


class ChatChain(LLMInference):
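    """Conversational LLM chain backed by buffer memory for multi-turn chat."""
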
    def __init__(self, llm_loader):
        super().__init__(llm_loader)

    def create_chain(self) -> Chain:
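        # Pick the Llama 2 chat template when USE_LLAMA_2_PROMPT_TEMPLATE=true;
        # otherwise fall back to a generic Human/Chatbot prompt.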
        template = (
            get_llama_2_prompt_template()
            if os.environ.get("USE_LLAMA_2_PROMPT_TEMPLATE") == "true"
            else """You are a chatbot having a conversation with a human.
{chat_history}
Human: {question}
Chatbot:"""
        )

        print(f"template: {template}")

        prompt = PromptTemplate(
            input_variables=["chat_history", "question"], template=template
        )

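        # Buffer memory records every turn under "chat_history", the same key
        # the prompt template interpolates, so each call sees the full history.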
        memory = ConversationBufferMemory(memory_key="chat_history")

        llm_chain = LLMChain(
            llm=self.llm_loader.llm,
            prompt=prompt,
            verbose=True,
            memory=memory,
        )

        return llm_chain
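

# Minimal usage sketch, not part of the original module. Assumptions: the real
# loader in app_modules.llm_loader is not shown here, so a stand-in object
# fakes the one attribute ChatChain reads (`llm_loader.llm`), and
# LLMInference.__init__ is assumed to simply store the loader. FakeListLLM is
# LangChain's canned-response test LLM, used here only to exercise the chain.
if __name__ == "__main__":
    from types import SimpleNamespace

    from langchain.llms.fake import FakeListLLM

    fake_loader = SimpleNamespace(llm=FakeListLLM(responses=["Hello there!"]))

    chain = ChatChain(fake_loader).create_chain()
    # Memory supplies {chat_history}; only {question} is passed per call.
    print(chain.predict(question="Hi, who are you?"))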