add support for LocalLLM (#1744)
### What problem does this PR solve?
Add support for LocalLLM: a new Jina-based gRPC server (`rag/svr/jina_server.py`) serves a locally hosted Hugging Face model, and the `LocalLLM` chat model in `rag/llm/chat_model.py` is rewired to stream completions from it.
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
---------
Co-authored-by: Zhedong Cen <[email protected]>
- rag/llm/chat_model.py +36 -23
- rag/svr/jina_server.py +93 -0
    	
**rag/llm/chat_model.py** (CHANGED, +36 -23)

The import block at the top of the file gains the Jina client dependencies:

```python
import os
import json
import requests
import asyncio
from rag.svr.jina_server import Prompt, Generation
```

Inside `LocalLLM`, the `multiprocessing.connection` RPC helpers are kept but reflowed:

```python
    def __conn(self):
        from multiprocessing.connection import Client

        self._connection = Client(
            (self.host, self.port), authkey=b"infiniflow-token4kevinhu"
        )

    def __getattr__(self, name):
        import pickle

        def do_rpc(*args, **kwargs):
            for _ in range(3):
                try:
                    self._connection.send(pickle.dumps((name, args, kwargs)))
                    return pickle.loads(self._connection.recv())
                except Exception as e:
                    self.__conn()

        return do_rpc
```

The previous `__init__`, `chat()` and `chat_streamly()`, which pushed `self.client.chat(history, gen_conf)` over that RPC connection, are replaced by a Jina gRPC client plus a shared prompt builder and streaming helper:

```python
    def __init__(self, key, model_name):
        from jina import Client

        self.client = Client(port=12345, protocol="grpc", asyncio=True)

    def _prepare_prompt(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
            gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
        return Prompt(message=history, gen_conf=gen_conf)

    def _stream_response(self, endpoint, prompt):
        answer = ""
        try:
            res = self.client.stream_doc(
                on=endpoint, inputs=prompt, return_type=Generation
            )
            loop = asyncio.get_event_loop()
            try:
                while True:
                    answer = loop.run_until_complete(res.__anext__()).text
                    yield answer
            except StopAsyncIteration:
                pass
        except Exception as e:
            yield answer + "\n**ERROR**: " + str(e)
        yield num_tokens_from_string(answer)

    def chat(self, system, history, gen_conf):
        prompt = self._prepare_prompt(system, history, gen_conf)
        chat_gen = self._stream_response("/chat", prompt)
        ans = next(chat_gen)
        total_tokens = next(chat_gen)
        return ans, total_tokens

    def chat_streamly(self, system, history, gen_conf):
        prompt = self._prepare_prompt(system, history, gen_conf)
        return self._stream_response("/stream", prompt)
```

(`class VolcEngineChat(Base):`, which follows, is unchanged context.)
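For orientation, here is a minimal usage sketch of the reworked `LocalLLM` (not part of the PR). It assumes `rag/svr/jina_server.py` is already serving on gRPC port 12345, the port hard-coded in `__init__` above; the model name, system prompt, and `gen_conf` values are illustrative.

```python
# Sketch only, not part of this PR. Assumes the Jina server from this PR is
# already running on port 12345; key/model_name are accepted but unused by the
# new __init__, which always connects to that local gRPC port.
from rag.llm.chat_model import LocalLLM

mdl = LocalLLM(key="", model_name="local-llm")  # placeholder arguments

system = "You are a helpful assistant."
gen_conf = {"max_tokens": 64}  # _prepare_prompt renames this to max_new_tokens

# Blocking call: the first item from the "/chat" stream is the full answer,
# the second is the token estimate.
ans, tokens = mdl.chat(system, [{"role": "user", "content": "Say hello."}], gen_conf)
print(tokens, ans)

# Streaming call: "/stream" yields the accumulated answer after each new token;
# the generator's final item is the token count rather than text.
for delta in mdl.chat_streamly(system, [{"role": "user", "content": "Say hello."}], gen_conf):
    if isinstance(delta, str):
        print(delta)
```

Note that `_prepare_prompt()` mutates the `history` and `gen_conf` it receives (it inserts the system message and renames `max_tokens`), which is why fresh message lists are passed to each call above.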
    	
**rag/svr/jina_server.py** (ADDED, +93 -0)

A new standalone Jina gRPC service that loads a Hugging Face causal LM and exposes two endpoints: `/chat` returns the whole completion at once, while `/stream` calls `generate()` one token at a time and yields the text accumulated so far after each step.

```python
from jina import Deployment
from docarray import BaseDoc
from jina import Executor, requests
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import argparse
import torch


class Prompt(BaseDoc):
    message: list[dict]
    gen_conf: dict


class Generation(BaseDoc):
    text: str


tokenizer = None
model_name = ""


class TokenStreamingExecutor(Executor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype="auto"
        )

    @requests(on="/chat")
    async def generate(self, doc: Prompt, **kwargs) -> Generation:
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        inputs = tokenizer([text], return_tensors="pt")
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )
        generated_ids = self.model.generate(
            inputs.input_ids, generation_config=generation_config
        )
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]

        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        yield Generation(text=response)

    @requests(on="/stream")
    async def task(self, doc: Prompt, **kwargs) -> Generation:
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        input = tokenizer([text], return_tensors="pt")
        input_len = input["input_ids"].shape[1]
        max_new_tokens = 512
        if "max_new_tokens" in doc.gen_conf:
            max_new_tokens = doc.gen_conf.pop("max_new_tokens")
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )
        for _ in range(max_new_tokens):
            output = self.model.generate(
                **input, max_new_tokens=1, generation_config=generation_config
            )
            if output[0][-1] == tokenizer.eos_token_id:
                break
            yield Generation(
                text=tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
            )
            input = {
                "input_ids": output,
                "attention_mask": torch.ones(1, len(output[0])),
            }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="Model name or path")
    parser.add_argument("--port", default=12345, type=int, help="Jina serving port")
    args = parser.parse_args()
    model_name = args.model_name
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    with Deployment(
        uses=TokenStreamingExecutor, port=args.port, protocol="grpc"
    ) as dep:
        dep.block()
```
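A matching smoke test for the server itself, again a sketch rather than part of the PR. It assumes the deployment was started with `python rag/svr/jina_server.py --model_name <hf-model-or-path> --port 12345`, and it reuses the same `Client`/`stream_doc` pattern that `LocalLLM` uses above; the prompt content is illustrative.

```python
# Sketch only, not part of this PR. Start the server first, e.g.:
#   python rag/svr/jina_server.py --model_name <hf-model-or-path> --port 12345
import asyncio

from jina import Client

from rag.svr.jina_server import Generation, Prompt


async def main():
    client = Client(port=12345, protocol="grpc", asyncio=True)
    prompt = Prompt(
        message=[{"role": "user", "content": "Hello!"}],
        gen_conf={"max_new_tokens": 32},
    )
    # /stream yields one Generation per decoding step; each .text carries the
    # answer accumulated so far, so only the last one needs to be kept.
    answer = ""
    async for doc in client.stream_doc(
        on="/stream", inputs=prompt, return_type=Generation
    ):
        answer = doc.text
    print(answer)


if __name__ == "__main__":
    asyncio.run(main())
```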