ehristoforu committed on
Commit
bda8280
·
1 Parent(s): f236837

Delete model.py

Browse files
Files changed (1) hide show
  1. model.py +0 -57
model.py DELETED
@@ -1,57 +0,0 @@
1
import os
from typing import Iterator

from text_generation import Client

# Model served by the Hugging Face hosted inference API.
model_id = 'mistralai/Mistral-7B-Instruct-v0.1'

# Endpoint URL for the model above; authentication uses a read-scoped token
# taken from the environment (may be None if unset).
API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
HF_TOKEN = os.environ.get("HF_READ_TOKEN")

# Streaming text-generation client, authenticated via bearer token.
client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
)

# Markers that signal the end of a generated sequence.
EOS_STRING = "</s>"
EOT_STRING = "<EOT>"
17
-
18
-
19
def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    """Assemble a Llama-2-style instruction prompt.

    The system prompt is wrapped in ``<<SYS>>`` tags inside the first
    ``[INST]`` block; each prior (user, assistant) turn is appended as an
    ``[INST] ... [/INST] ...`` pair, and the new message closes the prompt.

    Args:
        message: The new user message to append last.
        chat_history: Prior (user_input, response) pairs, oldest first.
        system_prompt: System instructions placed at the top of the prompt.

    Returns:
        The fully concatenated prompt string.
    """
    parts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    for turn_index, (user_text, assistant_text) in enumerate(chat_history):
        # The very first user turn is kept verbatim; every later user
        # input (including the final message) is whitespace-stripped.
        if turn_index > 0:
            user_text = user_text.strip()
        parts.append(f'{user_text} [/INST] {assistant_text.strip()} </s><s>[INST] ')
    final_message = message if not chat_history else message.strip()
    parts.append(f'{final_message} [/INST]')
    return ''.join(parts)
31
-
32
-
33
def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.1,
        top_p: float = 0.9,
        top_k: int = 50) -> Iterator[str]:
    """Stream a model response for *message*, yielding the growing text.

    Builds the full prompt from the chat history, requests a token stream
    from the hosted inference endpoint, and yields the accumulated output
    after each token. Generation stops early when an end-of-sequence
    marker (``EOS_STRING`` or ``EOT_STRING``) appears in a token.

    Args:
        message: New user message.
        chat_history: Prior (user_input, response) pairs, oldest first.
        system_prompt: System instructions for the model.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k sampling cutoff.

    Yields:
        The full output accumulated so far, once per streamed token.
    """
    prompt = get_prompt(message, chat_history, system_prompt)

    stream = client.generate_stream(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
    )

    accumulated = ""
    for chunk in stream:
        token_text = chunk.token.text
        # Stop streaming as soon as an end marker shows up in a token.
        if EOS_STRING in token_text or EOT_STRING in token_text:
            return accumulated
        accumulated += token_text
        yield accumulated
    return accumulated