File size: 1,686 Bytes
7f9d16c
815128e
7f9d16c
815128e
7f9d16c
 
815128e
7f9d16c
 
815128e
7f9d16c
815128e
 
 
 
 
 
 
 
 
 
 
 
7f9d16c
815128e
 
 
 
 
 
7f9d16c
 
 
 
 
 
 
815128e
7f9d16c
 
 
 
 
 
815128e
7f9d16c
 
815128e
7f9d16c
 
815128e
7f9d16c
 
815128e
7f9d16c
 
815128e
 
7f9d16c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# project/test.py

import unittest
from timeit import default_timer as timer
from typing import Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import HumanMessage

from app_modules.llm_loader import LLMLoader

# Prompt sent to every backend exercised by the TestLLMLoader cases below.
USER_QUESTION = "What's the capital city of Malaysia?"


class MyCustomHandler(BaseCallbackHandler):
    """LangChain callback handler that records the text of each LLM generation.

    Each time an LLM call finishes, the response is printed and the text of its
    first generation is appended to ``self.texts``.
    """

    def __init__(self):
        super().__init__()
        self.reset()

    def reset(self):
        """Discard every generation text recorded so far."""
        self.texts = []

    def get_standalone_question(self) -> Optional[str]:
        """Return the first recorded generation text, stripped.

        Returns:
            The stripped text of the first recorded generation, or ``None`` if
            no LLM call has finished since construction / the last ``reset()``.
        """
        # NOTE(review): the name suggests the first LLM call in the chain
        # produces a standalone question; confirm against the chain using this.
        return self.texts[0].strip() if self.texts else None

    def on_llm_end(self, response, **kwargs) -> None:
        """Run when an LLM call ends: print the response and record its text.

        Args:
            response: the LLM result object; only
                ``response.generations[0][0].text`` is read.
        """
        print("\non_llm_end - response:")
        print(response)
        # Only the first generation of the first prompt is kept; assumes the
        # result always contains at least one — TODO confirm for all backends.
        self.texts.append(response.generations[0][0].text)


class TestLLMLoader(unittest.TestCase):
    """Smoke tests: load an LLM backend through ``LLMLoader`` and ask it one question.

    Methods prefixed ``xtest_`` do not start with ``test``, so unittest discovery
    does not collect them; they are effectively disabled but can still be run
    explicitly by naming the method on the command line.
    """

    def run_test_case(self, llm_model_type, query):
        """Load the ``llm_model_type`` backend, send it ``query``, and check the reply.

        Load time and inference time are printed separately.

        Args:
            llm_model_type: backend identifier passed to ``LLMLoader``
                (e.g. "openai", "llamacpp", "gpt4all-j", "huggingface").
            query: the user question sent to the model.
        """
        llm_loader = LLMLoader(llm_model_type)

        load_start = timer()
        # "n_threds" is passed exactly as the original did; presumably it is
        # LLMLoader.init's own keyword spelling — TODO confirm before renaming.
        llm_loader.init(n_threds=8, hf_pipeline_device_type="cpu")
        load_end = timer()
        print(f"Model loaded in {load_end - load_start:.3f}s")

        # The "openai" backend is called with a list of chat messages; every
        # other backend receives the raw prompt string.
        prompt = [HumanMessage(content=query)] if llm_model_type == "openai" else query
        result = llm_loader.llm(prompt)
        inference_end = timer()
        print(f"Inference completed in {inference_end - load_end:.3f}s")
        print(result)

        # Without an assertion this test could only fail by raising; require
        # that the backend actually produced a reply.
        self.assertIsNotNone(result)

    def xtest_openai(self):
        self.run_test_case("openai", USER_QUESTION)

    def xtest_llamacpp(self):
        self.run_test_case("llamacpp", USER_QUESTION)

    def xtest_gpt4all_j(self):
        self.run_test_case("gpt4all-j", USER_QUESTION)

    def test_huggingface(self):
        self.run_test_case("huggingface", USER_QUESTION)


if __name__ == "__main__":
    # Script entry point: discover and run the test* methods defined in this module.
    unittest.main(module="__main__")