Makhinur committed
Commit 790a199 · verified · 1 Parent(s): 4e8290f

Update model.py

Files changed (1)
  1. model.py +21 -39
model.py CHANGED
@@ -1,25 +1,19 @@
-from threading import Thread
+import os
 from typing import Iterator
 
-import torch
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from text_generation import Client
 
-model_id = 'codellama/CodeLlama-13b-Instruct-hf'
+model_id = 'codellama/CodeLlama-34b-Instruct-hf'
 
-if torch.cuda.is_available():
-    config = AutoConfig.from_pretrained(model_id)
-    config.pretraining_tp = 1
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        config=config,
-        torch_dtype=torch.float16,
-        load_in_4bit=True,
-        device_map='auto',
-        use_safetensors=False,
-    )
-else:
-    model = None
-tokenizer = AutoTokenizer.from_pretrained(model_id)
+API_URL = "https://api-inference.huggingface.co/models/" + model_id
+HF_TOKEN = os.environ.get("HF_TOKEN", None)
+
+client = Client(
+    API_URL,
+    headers={"Authorization": f"Bearer {HF_TOKEN}"},
+)
+EOS_STRING = "</s>"
+EOT_STRING = "<EOT>"
 
 
 def get_prompt(message: str, chat_history: list[tuple[str, str]],
@@ -36,12 +30,6 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]],
     return ''.join(texts)
 
 
-def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
-    prompt = get_prompt(message, chat_history, system_prompt)
-    input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
-    return input_ids.shape[-1]
-
-
 def run(message: str,
         chat_history: list[tuple[str, str]],
         system_prompt: str,
@@ -50,26 +38,20 @@ def run(message: str,
         top_p: float = 0.9,
         top_k: int = 50) -> Iterator[str]:
     prompt = get_prompt(message, chat_history, system_prompt)
-    inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')
 
-    streamer = TextIteratorStreamer(tokenizer,
-                                    timeout=10.,
-                                    skip_prompt=True,
-                                    skip_special_tokens=True)
     generate_kwargs = dict(
-        inputs,
-        streamer=streamer,
         max_new_tokens=max_new_tokens,
        do_sample=True,
         top_p=top_p,
         top_k=top_k,
         temperature=temperature,
-        num_beams=1,
     )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield ''.join(outputs)
+    stream = client.generate_stream(prompt, **generate_kwargs)
+    output = ""
+    for response in stream:
+        if any([end_token in response.token.text for end_token in [EOS_STRING, EOT_STRING]]):
+            return output
+        else:
+            output += response.token.text
+            yield output
+    return output
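
For context, the sketch below shows one way a caller (for example, the Space's UI handler) might consume the new streaming run() generator after this change. It is not part of the commit: the prompt text, system prompt, and sampling values are illustrative, and it assumes model.py is on the import path and that HF_TOKEN is exported so the Inference API request is authorized.

    # Illustrative only -- not part of this commit.
    # Assumes: `pip install text-generation`, model.py importable, HF_TOKEN set.
    from model import run

    chat_history: list[tuple[str, str]] = []               # prior (user, assistant) turns
    system_prompt = "You are a helpful coding assistant."  # example value (assumption)

    final_text = ""
    for partial in run(
            "Write a Python function that checks whether a string is a palindrome.",
            chat_history,
            system_prompt,
            max_new_tokens=512,   # parameters visible in generate_kwargs in the diff;
            temperature=0.1,      # their defaults are elided, so they are passed explicitly here
            top_p=0.9,
            top_k=50,
    ):
        final_text = partial      # each yield is the cumulative text generated so far
    print(final_text)

Because run() yields the full accumulated text on every iteration rather than per-token deltas, the caller only needs to keep the most recent value.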