kcarnold committed on
Commit 38826eb · 1 Parent(s): cb63b67

copy in the backend code

Files changed (2)
  1. custom_llm.py +170 -0
  2. custom_llm_inference.py +193 -0
custom_llm.py ADDED
@@ -0,0 +1,170 @@
+ import argparse
+ import os
+ import time
+ from contextlib import asynccontextmanager
+ from pathlib import Path
+ from typing import Dict, List, Optional
+
+ import torch
+ import uvicorn
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.testclient import TestClient
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ from custom_llm_inference import get_highlights_inner, get_next_token_predictions_inner
+
+ ml_models = {}
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--gpu", action="store_true", help="Enable GPU usage")
+ args = parser.parse_args()
+
+ USE_GPU = args.gpu
+
+ if not USE_GPU:
+     print("Running without GPU. To enable GPU, run with the --gpu flag.")
+
+ @asynccontextmanager
+ async def models_lifespan(app: FastAPI):
+     #model_name = 'google/gemma-1.1-7b-it'
+     #model_name = 'google/gemma-1.1-2b-it'
+     model_name = 'google/gemma-2-9b-it'
+
+     dtype = torch.bfloat16 if USE_GPU else torch.float16
+
+     ml_models["llm"] = llm = {
+         'tokenizer': AutoTokenizer.from_pretrained(model_name),
+         'model': AutoModelForCausalLM.from_pretrained(model_name, device_map="auto" if USE_GPU else "cpu", torch_dtype=dtype)
+     }
+     print("Loaded llm with device map:")
+     print(llm['model'].hf_device_map)
+
+     # Print timing info for each endpoint
+     print("\nRunning endpoint tests...")
+
+     test_doc = "This is a test document that needs to be revised for clarity and conciseness."
+     test_prompt = "Make this more clear and concise."
+
+     client = TestClient(app)
+
+     start = time.time()
+     response = client.get("/api/highlights",
+                           params={"doc": test_doc, "prompt": test_prompt})
+     print(f"Highlights endpoint: {time.time() - start:.2f}s")
+
+     start = time.time()
+     response = client.get("/api/next_token",
+                           params={"original_doc": test_doc, "prompt": test_prompt, "doc_in_progress": "This is"})
+     print(f"Next token endpoint: {time.time() - start:.2f}s")
+
+     start = time.time()
+     response = client.get("/api/gen_revisions",
+                           params={"doc": test_doc, "prompt": test_prompt, "n": 1})
+     print(f"Gen revisions endpoint: {time.time() - start:.2f}s")
+
+     yield
+
+     # Release resources on exit
+     ml_models.clear()
+
+ DEBUG = os.getenv("DEBUG") or False
+ PORT = int(os.getenv("PORT") or "19570")
+
+ app = FastAPI(lifespan=models_lifespan)
+
+ origins = [
+     "*",
+ ]
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=origins,
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+
+ @app.get("/api/highlights")
+ def get_highlights(doc: str, prompt: Optional[str] = None, updated_doc: Optional[str] = '', k: Optional[int] = 5):
+     ''' Example of using this in JavaScript:
+
+     let url = new URL('http://localhost:8000/api/highlights')
+     url.searchParams.append('doc', 'This is a test document. It is a test document because it is a test document.')
+     url.searchParams.append('prompt', 'Rewrite this document to be more concise.')
+     url.searchParams.append('updated_doc', 'This is a test document.')
+     let response = await fetch(url)
+     '''
+     llm = ml_models['llm']
+     model = llm['model']
+     tokenizer = llm['tokenizer']
+
+     if prompt is None:
+         prompt = "Rewrite this document to be more concise."
+
+     highlights = get_highlights_inner(model, tokenizer, doc, prompt, updated_doc, k)
+
+     return {'highlights': highlights}
+
+
+ @app.get('/api/next_token')
+ def get_next_token_predictions(original_doc: str,
+                                prompt: str,
+                                doc_in_progress: str,
+                                k: Optional[int] = 5):
+     model = ml_models['llm']['model']
+     tokenizer = ml_models['llm']['tokenizer']
+
+     decoded_next_tokens, next_token_logits = get_next_token_predictions_inner(
+         model, tokenizer, original_doc, prompt, doc_in_progress, k)
+
+     return {
+         'next_tokens': decoded_next_tokens
+     }
+
+
+ @app.get('/api/gen_revisions')
+ def gen_revisions(
+         prompt: str,
+         doc: str,
+         n: Optional[int] = 5):
+     model = ml_models['llm']['model']
+     tokenizer = ml_models['llm']['tokenizer']
+
+     messages = [
+         {
+             "role": "user",
+             "content": f"{prompt}\n\n{doc}",
+         },
+     ]
+     tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
+
+     generations = model.generate(
+         tokenized_chat, num_return_sequences=n,
+         max_length=1024, do_sample=True, top_k=50, top_p=0.95, temperature=0.5,
+         return_dict_in_generate=True, output_scores=True)
+     generated_docs = tokenizer.batch_decode(generations.sequences, skip_special_tokens=True)
+     #print(generations.scores)
+
+     # Remove prompt text. see https://github.com/huggingface/transformers/blob/v4.46.2/src/transformers/pipelines/text_generation.py#L37
+     prompt_length = len(
+         tokenizer.decode(
+             tokenized_chat[0],
+             skip_special_tokens=True,
+             clean_up_tokenization_spaces=True,
+         ))
+
+     return {
+         'revised_docs': [dict(doc_text=doc[prompt_length:]) for doc in generated_docs]
+     }
+
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="localhost", port=PORT)
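
For reference, a minimal sketch of calling these endpoints from Python once the server is running. It assumes the default PORT of 19570 set above and the third-party requests package; parameter names and response keys follow the endpoint definitions in custom_llm.py.

import requests

BASE = "http://localhost:19570"  # default PORT above; adjust if PORT is overridden

# Character-aligned highlights for a document under a revision prompt
r = requests.get(f"{BASE}/api/highlights", params={
    "doc": "This is a test document. It is a test document because it is a test document.",
    "prompt": "Rewrite this document to be more concise.",
})
print(r.json()["highlights"][:3])

# Next-token suggestions while a revision is being typed
r = requests.get(f"{BASE}/api/next_token", params={
    "original_doc": "This is a test document.",
    "prompt": "Rewrite this document to be more concise.",
    "doc_in_progress": "This is",
})
print(r.json()["next_tokens"])

# One full revised draft
r = requests.get(f"{BASE}/api/gen_revisions", params={
    "doc": "This is a test document.",
    "prompt": "Rewrite this document to be more concise.",
    "n": 1,
})
print(r.json()["revised_docs"][0]["doc_text"])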
custom_llm_inference.py ADDED
@@ -0,0 +1,193 @@
+ import torch
+ from transformers.cache_utils import DynamicCache
+
+
+ def get_tokenized_chat(tokenizer, prompt, doc):
+     messages = [
+         {
+             "role": "user",
+             "content": f"{prompt}\n\n{doc}",
+         },
+     ]
+     tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")[0]
+     return tokenized_chat
+
+
+ def tokenize_doc_in_progress(tokenizer, doc_in_progress):
+     if len(doc_in_progress) == 0:
+         # Some tokenizers give tensors of the wrong dtype if the input is empty
+         return torch.empty(0, dtype=torch.int64)
+
+     doc_in_progress_ids = tokenizer(
+         doc_in_progress, return_tensors='pt')['input_ids'][0]
+
+     # strip the first token, the "beginning of document" token
+     # TODO: make this robust to switching models
+     # since some models will use different special tokens
+     doc_in_progress_ids = doc_in_progress_ids[1:]
+     return doc_in_progress_ids
+
+
+ def get_highlights_inner(model, tokenizer, doc, prompt, updated_doc, k):
+     tokenized_chat = get_tokenized_chat(tokenizer, prompt, doc)
+     assert len(tokenized_chat.shape) == 1
+
+     if updated_doc is None or len(updated_doc.strip()) == 0:
+         updated_doc = doc
+     updated_doc_ids = tokenize_doc_in_progress(tokenizer, updated_doc)
+
+     joined_ids = torch.cat([tokenized_chat, updated_doc_ids])
+     # Call the model
+     with torch.no_grad():
+         logits = model(joined_ids[None].to(model.device)).logits[0].cpu()
+
+     highlights = []
+     length_so_far = 0
+     for idx in range(len(tokenized_chat), len(joined_ids)):
+         probs = logits[idx - 1].softmax(dim=-1)
+         token_id = joined_ids[idx]
+         token = tokenizer.decode(token_id)
+         token_loss = -probs[token_id].log().item()
+         topk_tokens = probs.topk(k).indices.cpu().numpy().tolist()
+         topk_tokens_decoded = tokenizer.batch_decode(topk_tokens, skip_special_tokens=True)
+         highlights.append(dict(
+             start=length_so_far,
+             end=length_so_far + len(token),
+             token=token,
+             token_loss=token_loss,
+             most_likely_token=topk_tokens_decoded[0],
+             topk_tokens=topk_tokens_decoded,
+         ))
+         length_so_far += len(token)
+     return highlights
+
+
+ def get_next_token_predictions_inner(
+         model, tokenizer, original_doc, prompt, doc_in_progress, k):
+     tokenized_chat = get_tokenized_chat(tokenizer, prompt, original_doc)
+     doc_in_progress_ids = tokenize_doc_in_progress(tokenizer, doc_in_progress)
+
+     device = model.device
+
+     joined_ids = torch.cat([tokenized_chat, doc_in_progress_ids])
+     hypotheses = joined_ids[None].to(model.device)
+
+     # For each of the k next tokens, generate most-likely next tokens and append back on until we
+     # reach a token with a space
+     past_key_values = DynamicCache()
+
+     with torch.no_grad():
+         model_outs_onestep = model(hypotheses, output_hidden_states=True, past_key_values=past_key_values)
+
+     branch_tokens = model_outs_onestep.logits[0, -1].topk(k).indices
+
+     # split the cache into k reps. We pretend we're doing a "Beam search"...
+     past_key_values.reorder_cache(torch.zeros((k,), dtype=torch.long, device=device))
+
+     # Now call the model again, passing the kv cache, so we can continue generating.
+     # Each of the k next tokens will be considered as one sequence in a "batch".
+     next_tokens_as_batch = branch_tokens.unsqueeze(1)
+     assert next_tokens_as_batch.shape == (k, 1)
+
+     position_id_for_final_token = joined_ids.shape[0]
+     cache_position = torch.full((1,), position_id_for_final_token, dtype=int, device=device)
+     with torch.no_grad():
+         model_outs = model(
+             next_tokens_as_batch,
+             past_key_values=past_key_values,
+             output_hidden_states=True,
+             use_cache=True,
+             # the cache surprisingly doesn't know the position of the last token
+             cache_position=cache_position
+         )
+
+     # Grab the single most likely token from each of the k sequences
+     next_token_logits = model_outs.logits[:, -1]
+     vocab_size = model.config.vocab_size
+     assert next_token_logits.shape == (k, vocab_size), f"{next_token_logits.shape=}, {k=}, {vocab_size=}"
+     most_likely_token_ids = next_token_logits.argmax(dim=-1)
+
+     # Stick them at the end of the branch tokens.
+     assert most_likely_token_ids.shape == (k,)
+     lookahead_sequences = torch.cat([
+         branch_tokens.unsqueeze(1),
+         most_likely_token_ids.unsqueeze(1)
+     ], dim=1)
+     assert lookahead_sequences.shape == (k, 2)
+
+     decoded_next_tokens = tokenizer.batch_decode(lookahead_sequences, skip_special_tokens=True)
+     return decoded_next_tokens, next_token_logits
+
+
+ def get_next_token_predictions_generate(
+         model, tokenizer, original_doc, prompt, doc_in_progress, k):
+     tokenized_chat = get_tokenized_chat(tokenizer, prompt, original_doc)
+     doc_in_progress_ids = tokenize_doc_in_progress(tokenizer, doc_in_progress)
+
+     joined_ids = torch.cat([tokenized_chat, doc_in_progress_ids])
+     # Use decode (not batch_decode) here: prefix_length must be the character length
+     # of the decoded context so that the string slice below strips exactly the prompt.
+     context_without_special_tokens = tokenizer.decode(joined_ids, skip_special_tokens=True)
+     prefix_length = len(context_without_special_tokens)
+     hypotheses = joined_ids[None].to(model.device)
+
+     generation_output = model.generate(
+         hypotheses,
+         return_dict_in_generate=True,
+         output_scores=True,
+         num_beams=5, num_beam_groups=5, max_new_tokens=10, do_sample=False, diversity_penalty=1e5, top_k=None, num_return_sequences=5)  #, token_healing=True, tokenizer=tokenizer)
+     sequences = [
+         decoded[prefix_length:]
+         for decoded in tokenizer.batch_decode(generation_output.sequences, skip_special_tokens=True)
+     ]
+     return sequences,
+
+
+ def get_next_token_predictions_slow(
+         model, tokenizer, original_doc, prompt, doc_in_progress, k):
+     tokenized_chat = get_tokenized_chat(tokenizer, prompt, original_doc)
+     doc_in_progress_ids = tokenize_doc_in_progress(tokenizer, doc_in_progress)
+
+     joined_ids = torch.cat([tokenized_chat, doc_in_progress_ids])
+     hypotheses = joined_ids[None].to(model.device)
+
+     # For each of the k next tokens, generate most-likely next tokens and append back on until we
+     # reach a token with a space
+     with torch.no_grad():
+         model_outs = model(hypotheses, output_hidden_states=True)
+
+     next_token_logits = model_outs.logits[0, -1]
+     branch_tokens = next_token_logits.topk(k).indices
+
+     # Slow mode: concat the branch tokens to the hypotheses.
+     # Then call the model on the full sequence.
+     # This is slow because the beginning of the sequence is re-processed each time.
+     hypotheses_with_next_tokens = torch.cat([
+         torch.repeat_interleave(hypotheses, k, dim=0),
+         branch_tokens.unsqueeze(1)
+     ], dim=1)
+     assert hypotheses_with_next_tokens.shape == (k, len(joined_ids) + 1)
+
+     with torch.no_grad():
+         model_outs = model(hypotheses_with_next_tokens)
+
+     # Grab the single most likely token from each of the k sequences
+     next_token_logits = model_outs.logits[:, -1]
+     vocab_size = model.config.vocab_size
+     assert next_token_logits.shape == (k, vocab_size), f"{next_token_logits.shape=}, {k=}, {vocab_size=}"
+     most_likely_token_ids = next_token_logits.argmax(dim=-1)
+
+     # Stick them at the end of the branch tokens.
+     assert most_likely_token_ids.shape == (k,)
+     lookahead_sequences = torch.cat([
+         branch_tokens.unsqueeze(1),
+         most_likely_token_ids.unsqueeze(1)
+     ], dim=1)
+     assert lookahead_sequences.shape == (k, 2)
+
+     decoded_next_tokens = tokenizer.batch_decode(lookahead_sequences, skip_special_tokens=True)
+     return decoded_next_tokens, next_token_logits
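
The inference helpers can also be exercised directly, without the FastAPI layer. A rough sketch, assuming the same google/gemma-2-9b-it checkpoint that custom_llm.py loads and enough GPU memory to hold it:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from custom_llm_inference import get_highlights_inner, get_next_token_predictions_inner

model_name = "google/gemma-2-9b-it"  # same checkpoint as custom_llm.py
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16)

doc = "This is a test document that needs to be revised for clarity and conciseness."
prompt = "Make this more clear and concise."

# Per-token losses and top-k alternatives for the document under the chat prompt
highlights = get_highlights_inner(model, tokenizer, doc, prompt, updated_doc=None, k=5)
print(highlights[:3])

# k two-token lookahead continuations for a partially typed revision
next_tokens, _logits = get_next_token_predictions_inner(
    model, tokenizer, doc, prompt, doc_in_progress="This is", k=5)
print(next_tokens)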