import abc
import os
import time
import urllib.parse
from queue import Queue
from threading import Thread
from typing import Optional

from langchain.callbacks.tracers import LangChainTracer
from langchain.chains.base import Chain

from app_modules.llm_loader import LLMLoader, TextIteratorStreamer
from app_modules.utils import remove_extra_spaces


class LLMInference(metaclass=abc.ABCMeta):
    """Abstract base for running a LangChain chain with optional token streaming."""

    llm_loader: LLMLoader
    chain: Chain

    def __init__(self, llm_loader):
        self.llm_loader = llm_loader
        self.chain = None

    @abc.abstractmethod
    def create_chain(self) -> Chain:
        pass

    def get_chain(self, tracing: bool = False) -> Chain:
        # Build the chain lazily on first use and cache it.
        if self.chain is None:
            if tracing:
                # Set up a LangChain tracing session when requested.
                tracer = LangChainTracer()
                tracer.load_default_session()

            self.chain = self.create_chain()

        return self.chain

    def call_chain(
        self,
        inputs,
        streaming_handler,
        q: Optional[Queue] = None,
        tracing: bool = False,
    ):
        print(inputs)

        # Point the custom streamer at the caller's queue so generated tokens
        # can be consumed from another thread.
        if self.llm_loader.streamer is not None and isinstance(
            self.llm_loader.streamer, TextIteratorStreamer
        ):
            self.llm_loader.streamer.reset(q)

        chain = self.get_chain(tracing)
        # With a streaming handler, run the chain in a background thread and
        # forward tokens as they are generated; otherwise call it directly.
        result = (
            self._run_chain(
                chain,
                inputs,
                streaming_handler,
            )
            if streaming_handler is not None
            else chain(inputs)
        )

        if "answer" in result:
            result["answer"] = remove_extra_spaces(result["answer"])

            base_url = os.environ.get("PDF_FILE_BASE_URL")
            if base_url is not None and len(base_url) > 0:
                documents = result["source_documents"]
                for doc in documents:
                    source = doc.metadata["source"]
                    title = source.split("/")[-1]
                    doc.metadata["url"] = f"{base_url}{urllib.parse.quote(title)}"

        return result

    def _execute_chain(self, chain, inputs, q, sh):
        # Serialize access to the model: only one generation may hold the
        # loader's lock; the result is handed back through the queue.
        with self.llm_loader.lock:
            q.put(chain(inputs, callbacks=[sh]))

    def _run_chain(self, chain, inputs, streaming_handler):
        que = Queue()

        # Execute the chain in a background thread so that tokens can be
        # streamed from this thread while it runs.
        t = Thread(
            target=self._execute_chain,
            args=(chain, inputs, que, streaming_handler),
        )
        t.start()

        if self.llm_loader.streamer is not None and isinstance(
            self.llm_loader.streamer, TextIteratorStreamer
        ):
            # A conversational chain with chat history typically makes two LLM
            # calls (question condensation, then the answer), so the streamer
            # must be drained twice; otherwise once is enough.
            count = (
                2
                if "chat_history" in inputs and len(inputs.get("chat_history")) > 0
                else 1
            )

            while count > 0:
                try:
                    # Forward each generated token to the streaming handler.
                    for token in self.llm_loader.streamer:
                        streaming_handler.on_llm_new_token(token)

                    self.llm_loader.streamer.reset()
                    count -= 1
                except Exception:
                    # The streamer has not produced any tokens yet; back off briefly.
                    print("nothing generated yet - retry in 0.5s")
                    time.sleep(0.5)

        t.join()
        return que.get()
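
# Usage sketch (illustrative only, not part of this module): a minimal
# subclass showing how `create_chain` and `call_chain` are meant to be used.
# The `QAInference` name, the `load_qa_chain` helper and the `llm_loader.llm`
# attribute are assumptions for this example, not facts about this app.
#
#   from langchain.chains.question_answering import load_qa_chain
#
#   class QAInference(LLMInference):
#       def create_chain(self) -> Chain:
#           # Any LangChain chain built on the shared LLM instance works here.
#           return load_qa_chain(self.llm_loader.llm, chain_type="stuff")
#
#   inference = QAInference(llm_loader)
#   # Without a streaming handler, the chain is invoked synchronously.
#   result = inference.call_chain(
#       {"input_documents": docs, "question": "What does the document cover?"},
#       None,
#   )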