Kevin Hu committed on
Commit
8ce7a30
·
1 Parent(s): 196c662

rm unused file (#3205)

Browse files

### What problem does this PR solve?


### Type of change

- [x] Refactoring

Files changed (1) hide show
  1. rag/llm/rpc_server.py +0 -171
rag/llm/rpc_server.py DELETED
@@ -1,171 +0,0 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
-
17
- import argparse
18
- import pickle
19
- import random
20
- import time
21
- from copy import deepcopy
22
- from multiprocessing.connection import Listener
23
- from threading import Thread
24
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
25
-
26
-
27
def torch_gc():
    """Ask PyTorch to drop cached accelerator memory, ignoring any failure.

    When CUDA is available, ``torch.cuda.empty_cache()`` and
    ``torch.cuda.ipc_collect()`` are called. Otherwise, when the MPS backend
    reports itself available, ``torch.mps.empty_cache()`` is called.

    Every exception (including ``torch`` not being importable) is swallowed
    on purpose: this is opportunistic cleanup and must never fail a request.
    """
    try:
        import torch

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            return
        if not torch.backends.mps.is_available():
            return
        # Imported lazily: only reached when the MPS backend is usable.
        from torch.mps import empty_cache as mps_empty_cache

        mps_empty_cache()
    except Exception:
        pass
42
-
43
-
44
class RPCHandler:
    """Dispatch pickled ``(func_name, args, kwargs)`` requests to registered callables."""

    def __init__(self):
        # Maps each registered function's ``__name__`` to the callable.
        self._functions = {}

    def register_function(self, func):
        """Expose *func* to clients under its ``__name__``."""
        self._functions[func.__name__] = func

    def handle_connection(self, connection):
        """Serve requests on *connection* until the peer disconnects.

        Each received message is unpickled into ``(func_name, args, kwargs)``.
        The named function is called and its pickled result is sent back. If
        the lookup, the call, or pickling the result raises, the pickled
        exception object is sent instead, so the client can tell failures
        apart from results.

        The connection is always closed when this method returns, so a
        disconnected client does not leak the underlying socket.

        Args:
            connection: object providing ``recv()``, ``send(bytes)`` and
                ``close()`` (e.g. what ``Listener.accept()`` returns).
        """
        try:
            while True:
                # SECURITY: pickle.loads executes arbitrary code embedded in
                # the payload. The listener's authkey is the only barrier, so
                # this must only ever be reachable by fully trusted clients.
                func_name, args, kwargs = pickle.loads(connection.recv())
                try:
                    result = self._functions[func_name](*args, **kwargs)
                    connection.send(pickle.dumps(result))
                except Exception as e:
                    # Covers unknown function names (KeyError) as well as
                    # errors raised by the function or by pickling its result.
                    connection.send(pickle.dumps(e))
        except EOFError:
            # Peer closed the connection: normal end of the session.
            pass
        finally:
            connection.close()
64
-
65
-
66
def rpc_server(hdlr, address, authkey):
    """Accept clients on *address* forever, serving each on a daemon thread.

    Args:
        hdlr: object whose ``handle_connection(connection)`` serves one client.
        address: address passed to ``multiprocessing.connection.Listener``.
        authkey: bytes secret passed to the ``Listener`` as ``authkey``.

    Never returns. Any exception raised while accepting a client or starting
    its thread is printed and the loop carries on.
    """
    listener = Listener(address, authkey=authkey)
    while True:
        try:
            conn = listener.accept()
            worker = Thread(
                target=hdlr.handle_connection,
                args=(conn,),
                daemon=True,
            )
            worker.start()
        except Exception as err:
            print("【EXCEPTION】:", str(err))
76
-
77
-
78
# Loaded causal-LM instances; populated by the ``__main__`` block at startup.
models: list = []
# Tokenizer shared by the chat functions; assigned by the ``__main__`` block.
tokenizer = None
80
-
81
-
82
def chat(messages, gen_conf):
    """Generate one complete reply to *messages* and return it as a string.

    Args:
        messages: conversation passed to ``tokenizer.apply_chat_template``.
        gen_conf: mapping; ``"max_tokens"`` (default 256) and
            ``"temperature"`` (default 0.1) are read from it.

    Returns:
        The decoded completion with the prompt tokens stripped, or, if
        anything inside the generation path raises, ``str(exception)``.
    """
    model = Model()
    try:
        torch_gc()
        max_new_tokens = int(gen_conf.get("max_tokens", 256))
        temperature = float(gen_conf.get("temperature", 0.1))
        conf = {"max_new_tokens": max_new_tokens, "temperature": temperature}
        print(messages, conf)

        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
        prompt_ids = model_inputs.input_ids

        output_ids = model.generate(prompt_ids, **conf)
        # Each generated sequence starts with its prompt; keep only the tail.
        completions = [
            full[len(prefix):] for prefix, full in zip(prompt_ids, output_ids)
        ]

        decoded = tokenizer.batch_decode(completions, skip_special_tokens=True)
        return decoded[0]
    except Exception as e:
        return str(e)
113
-
114
-
115
def chat_streamly(messages, gen_conf):
    """Generate a reply to *messages* and yield it piece by piece.

    Generation runs on a background thread while this generator yields text
    chunks as the streamer makes them available.

    Args:
        messages: conversation passed to ``tokenizer.apply_chat_template``.
        gen_conf: mapping of generation options. ``"max_tokens"`` (default
            256, matching ``chat``) is translated to ``max_new_tokens``; all
            other entries are forwarded to ``model.generate`` unchanged.

    Yields:
        Successive text chunks of the reply. If setting up or consuming the
        stream raises, a single ``"**ERROR**: <message>"`` string is yielded.
    """
    model = Model()
    try:
        # TextStreamer only prints to stdout and cannot be iterated; the
        # iterator variant is required to hand chunks back to the caller.
        # Imported locally so this block does not depend on the file-level
        # import list.
        from transformers import TextIteratorStreamer

        torch_gc()
        conf = deepcopy(gen_conf)
        print(messages, conf)
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        # Skip the prompt and special tokens so the streamed text matches
        # what chat() returns for the same request.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True)
        conf["inputs"] = model_inputs.input_ids
        conf["streamer"] = streamer
        # A missing "max_tokens" used to raise KeyError here; default it the
        # same way chat() does.
        conf["max_new_tokens"] = int(conf.pop("max_tokens", 256))
        thread = Thread(target=model.generate, kwargs=conf)
        thread.start()
        for new_text in streamer:
            yield new_text
    except Exception as e:
        yield "**ERROR**: " + str(e)
139
-
140
-
141
def Model():
    """Return one of the loaded models, chosen at random.

    The module-level ``random`` generator is used as-is. It is already seeded
    by the interpreter, so it is not reseeded here; reseeding from the clock
    on every call would overwrite the process-wide RNG state for all other
    users of ``random``.

    Raises:
        IndexError: if ``models`` is empty (no model has been loaded).
    """
    return random.choice(models)
145
-
146
-
147
if __name__ == "__main__":
    # Script entry point: parse options, load the model and tokenizer, then
    # serve chat RPCs until the process is killed.
    parser = argparse.ArgumentParser()
    # Required: without it from_pretrained(None) fails later and less clearly.
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="Model name")
    parser.add_argument(
        "--port",
        default=7860,
        type=int,
        help="RPC serving port")
    # The default keeps existing clients working; deployments should pass
    # their own secret instead of relying on the built-in one.
    parser.add_argument(
        "--authkey",
        default="infiniflow-token4kevinhu",
        type=str,
        help="Shared secret clients must present "
             "(defaults to the legacy built-in token)")
    args = parser.parse_args()

    handler = RPCHandler()
    handler.register_function(chat)
    handler.register_function(chat_streamly)

    # Rebinds the module-level globals read by Model(), chat() and
    # chat_streamly().
    models = [
        AutoModelForCausalLM.from_pretrained(args.model_name,
                                             device_map="auto",
                                             torch_dtype='auto')
    ]
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Run the server.
    # SECURITY: this listens on every interface and the handler unpickles
    # whatever authenticated clients send; only expose it on trusted networks.
    rpc_server(handler, ('0.0.0.0', args.port),
               authkey=args.authkey.encode("utf-8"))