File size: 1,868 Bytes
79ec61a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from typing import Optional

from loguru import logger
from textgen import ChatGlmModel, LlamaModel


class LLM(object):
    """Thin wrapper around a ``textgen`` chat model for context-grounded QA.

    Supported backends are ``"chatglm"`` (``textgen.ChatGlmModel``) and
    ``"llama"`` (``textgen.LlamaModel``); any other ``gen_model_type`` is
    rejected at construction time.
    """

    # Prompt used by ``generate_answer`` when the caller passes no template.
    # It is formatted with ``str.format`` and must contain the
    # ``{context_str}`` and ``{query_str}`` placeholders.
    DEFAULT_PROMPT_TEMPLATE = (
        "Answer the question based on the context below.\n\n"
        "Context:\n{context_str}\n\n"
        "Question:\n{query_str}\n"
    )

    def __init__(
            self,
            gen_model_type: str = "chatglm",
            gen_model_name_or_path: str = "THUDM/chatglm-6b-int4",
            lora_model_name_or_path: Optional[str] = None,
    ):
        """Load the generation model.

        Args:
            gen_model_type: Backend to use, either ``"chatglm"`` or ``"llama"``.
            gen_model_name_or_path: Model name or local path handed to the backend.
            lora_model_name_or_path: Optional LoRA weights, forwarded as ``lora_name``.

        Raises:
            ValueError: If ``gen_model_type`` is not a supported backend.
        """
        if gen_model_type == "chatglm":
            model_cls = ChatGlmModel
        elif gen_model_type == "llama":
            model_cls = LlamaModel
        else:
            raise ValueError('gen_model_type must be chatglm or llama.')

        self.model_type = gen_model_type
        self.gen_model = model_cls(
            gen_model_type,
            gen_model_name_or_path,
            lora_name=lora_model_name_or_path,
        )
        # Not read anywhere in this class; kept in case external code uses it.
        self.history = None

    def generate_answer(self, query_str, context_str, history=None, max_length=1024, prompt_template=None):
        """Generate an answer to ``query_str`` grounded in ``context_str``.

        Args:
            query_str: The user's question.
            context_str: Context text inserted into the prompt.
            history: Conversation history forwarded to the backend's ``chat``.
            max_length: Maximum length forwarded to the backend's ``chat``.
            prompt_template: ``str.format`` template with ``{context_str}`` and
                ``{query_str}`` placeholders. Falls back to
                ``DEFAULT_PROMPT_TEMPLATE`` when None (previously None crashed
                with ``AttributeError`` despite being the declared default).

        Returns:
            The ``(response, history)`` pair returned by the backend's ``chat``.
        """
        if prompt_template is None:
            prompt_template = self.DEFAULT_PROMPT_TEMPLATE
        prompt = prompt_template.format(context_str=context_str, query_str=query_str)
        return self.chat(prompt, history, max_length=max_length)

    def chat(self, query_str, history=None, max_length=1024):
        """Send ``query_str`` to the backend model as one chat turn.

        Args:
            query_str: Text sent to the model.
            history: Conversation history forwarded to the backend's ``chat``.
            max_length: Maximum length forwarded to the backend's ``chat``.

        Returns:
            The ``(response, history)`` pair returned by the backend's ``chat``.
        """
        response, out_history = self.gen_model.chat(query_str, history, max_length=max_length)
        return response, out_history