Spaces:
Runtime error
Runtime error
File size: 1,868 Bytes
79ec61a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
from textgen import ChatGlmModel, LlamaModel
from loguru import logger
class LLM(object):
    """Unified wrapper around supported generative chat models.

    Dispatches to :class:`ChatGlmModel` or :class:`LlamaModel` from ``textgen``
    behind a single ``chat`` / ``generate_answer`` interface.

    Args:
        gen_model_type: Backend family; must be ``"chatglm"`` or ``"llama"``.
        gen_model_name_or_path: HF hub id or local path of the base model.
        lora_model_name_or_path: Optional LoRA adapter name/path, or ``None``.

    Raises:
        ValueError: If ``gen_model_type`` is not a supported type.
    """

    # Fallback used by generate_answer() when no prompt_template is supplied.
    # Previously a None template crashed with AttributeError on .format().
    DEFAULT_PROMPT_TEMPLATE = "{context_str}\n\n{query_str}"

    def __init__(
            self,
            gen_model_type: str = "chatglm",
            gen_model_name_or_path: str = "THUDM/chatglm-6b-int4",
            lora_model_name_or_path: str = None,
    ):
        self.model_type = gen_model_type
        if gen_model_type == "chatglm":
            self.gen_model = ChatGlmModel(
                gen_model_type,
                gen_model_name_or_path,
                lora_name=lora_model_name_or_path,
            )
        elif gen_model_type == "llama":
            self.gen_model = LlamaModel(
                gen_model_type,
                gen_model_name_or_path,
                lora_name=lora_model_name_or_path,
            )
        else:
            raise ValueError('gen_model_type must be chatglm or llama.')
        # Kept for backward compatibility; not read by any method in this file.
        self.history = None

    def generate_answer(self, query_str, context_str, history=None, max_length=1024, prompt_template=None):
        """Generate an answer to ``query_str`` grounded in ``context_str``.

        Args:
            query_str: User question.
            context_str: Retrieved context to ground the answer in.
            history: Prior chat turns to condition on, or ``None``.
            max_length: Maximum generated sequence length.
            prompt_template: ``str.format`` template with ``{context_str}`` and
                ``{query_str}`` placeholders; falls back to
                :data:`DEFAULT_PROMPT_TEMPLATE` when ``None``.

        Returns:
            Tuple of ``(response, updated_history)``.
        """
        if self.model_type == "t5":
            # NOTE(review): __init__ never sets model_type to "t5", so this branch
            # is only reachable if model_type is patched externally — confirm
            # whether it is still needed.
            response = self.gen_model(query_str, max_length=max_length, do_sample=True)[0]['generated_text']
            return response, history
        if prompt_template is None:
            prompt_template = self.DEFAULT_PROMPT_TEMPLATE
        prompt = prompt_template.format(context_str=context_str, query_str=query_str)
        response, out_history = self.gen_model.chat(prompt, history, max_length=max_length)
        return response, out_history

    def chat(self, query_str, history=None, max_length=1024):
        """Free-form chat without retrieval context.

        Args:
            query_str: User message.
            history: Prior chat turns to condition on, or ``None``.
            max_length: Maximum generated sequence length.

        Returns:
            Tuple of ``(response, updated_history)``.
        """
        if self.model_type == "t5":
            # NOTE(review): unreachable via __init__ (see generate_answer).
            response = self.gen_model(query_str, max_length=max_length, do_sample=True)[0]['generated_text']
            logger.debug(response)
            return response, history
        response, out_history = self.gen_model.chat(query_str, history, max_length=max_length)
        return response, out_history
|