import abc
import os
import time
import urllib
from queue import Queue
from threading import Thread
from typing import List, Optional
from urllib.parse import quote, urlparse, urlunparse

from langchain.chains.base import Chain

from app_modules.llm_loader import LLMLoader, TextIteratorStreamer
from app_modules.utils import remove_extra_spaces


class LLMInference(metaclass=abc.ABCMeta):
    """Abstract base class that wraps an LLMLoader and a lazily built LangChain chain."""

    def __init__(self, llm_loader):
        self.llm_loader = llm_loader
        self.chain = None

    @abc.abstractmethod
    def create_chain(self) -> Chain:
        pass

    def get_chain(self) -> Chain:
        # Build the chain on first use and cache it for subsequent calls.
        if self.chain is None:
            self.chain = self.create_chain()

        return self.chain

    def reset(self) -> None:
        self.chain = None

    def run_chain(self, chain, inputs, callbacks: Optional[List] = []):
        result = chain.invoke(inputs, {"callbacks": callbacks})

        # Normalize the output key: expose "text" as "response".
        if "text" in result:
            result["response"] = result["text"]
            del result["text"]

        return result

    def call_chain(
        self,
        inputs,
        streaming_handler,
        q: Queue = None,
        testing: bool = False,
    ):
        print(inputs)

        # Streaming through a Hugging Face pipeline uses a shared streamer,
        # so serialize access with a lock.
        if self.llm_loader.streamer.for_huggingface:
            self.llm_loader.lock.acquire()

        try:
            self.llm_loader.streamer.reset(q)

            chain = self.get_chain()
            result = (
                self._run_chain_with_streaming_handler(
                    chain, inputs, streaming_handler, testing
                )
                if streaming_handler is not None
                else self.run_chain(chain, inputs)
            )

            if "answer" in result:
                result["answer"] = remove_extra_spaces(result["answer"])

            # Rewrite source-document metadata into user-facing URLs,
            # driven by the SOURCE_PATH / PDF_FILE_BASE_URL environment variables.
            source_path = os.environ.get("SOURCE_PATH")
            base_url = os.environ.get("PDF_FILE_BASE_URL")

            if base_url is not None and len(base_url) > 0:
                documents = result["source_documents"]
                for doc in documents:
                    source = doc.metadata["source"]
                    title = source.split("/")[-1]
                    doc.metadata["url"] = f"{base_url}{urllib.parse.quote(title)}"
            elif source_path is not None and len(source_path) > 0:
                documents = result["source_documents"]
                for doc in documents:
                    source = doc.metadata["source"]
                    url = source.replace(source_path, "https://")
                    url = url.replace(".html", "")
                    parsed_url = urlparse(url)

                    # Encode path, query, and fragment
                    encoded_path = quote(parsed_url.path)
                    encoded_query = quote(parsed_url.query)
                    encoded_fragment = quote(parsed_url.fragment)

                    # Construct the encoded URL
                    doc.metadata["url"] = urlunparse(
                        (
                            parsed_url.scheme,
                            parsed_url.netloc,
                            encoded_path,
                            parsed_url.params,
                            encoded_query,
                            encoded_fragment,
                        )
                    )

            return result
        finally:
            if self.llm_loader.streamer.for_huggingface:
                self.llm_loader.lock.release()

    def _execute_chain(self, chain, inputs, q, sh):
        q.put(self.run_chain(chain, inputs, callbacks=[sh]))

    def _run_chain_with_streaming_handler(
        self, chain, inputs, streaming_handler, testing
    ):
        que = Queue()

        # Run the chain in a worker thread so tokens can be consumed here
        # while the chain is still generating.
        t = Thread(
            target=self._execute_chain,
            args=(chain, inputs, que, streaming_handler),
        )
        t.start()

        if self.llm_loader.streamer.for_huggingface:
            # With chat history the underlying chain typically generates twice
            # (a condensed question plus the answer), so drain the streamer
            # that many times.
            count = (
                2
                if "chat_history" in inputs and len(inputs.get("chat_history")) > 0
                else 1
            )

            while count > 0:
                try:
                    for token in self.llm_loader.streamer:
                        if not testing:
                            streaming_handler.on_llm_new_token(token)

                    self.llm_loader.streamer.reset()
                    count -= 1
                except Exception:
                    if not testing:
                        print("nothing generated yet - retry in 0.5s")
                    time.sleep(0.5)

        t.join()
        return que.get()
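

# ---------------------------------------------------------------------------
# Usage sketch: a minimal hypothetical subclass showing how create_chain and
# call_chain fit together. `QAChain`, `vectorstore`, and `llm_loader.llm` are
# stand-ins; the concrete subclasses and loader configuration live elsewhere
# in the project.
# ---------------------------------------------------------------------------
#
# class QAChain(LLMInference):
#     def __init__(self, vectorstore, llm_loader):
#         super().__init__(llm_loader)
#         self.vectorstore = vectorstore
#
#     def create_chain(self) -> Chain:
#         # Assumes a classic LangChain ConversationalRetrievalChain built
#         # from the loader's LLM and the vector store's retriever.
#         from langchain.chains import ConversationalRetrievalChain
#
#         return ConversationalRetrievalChain.from_llm(
#             self.llm_loader.llm,
#             self.vectorstore.as_retriever(),
#             return_source_documents=True,
#         )
#
# qa = QAChain(vectorstore, llm_loader)
# result = qa.call_chain(
#     {"question": "What is this document about?", "chat_history": []},
#     streaming_handler=None,
# )
# print(result["answer"])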