Makhinur committed
Commit 265cb8e · verified · 1 Parent(s): 3f2eae3

Update app.py

Files changed (1): app.py (+50 -8)
app.py CHANGED
@@ -1,13 +1,25 @@
 import os
 from typing import Iterator
 import gradio as gr
-from model import run
+from text_generation import Client
 
 # Ensure the HF_TOKEN environment variable is set
 HF_TOKEN = os.environ.get("HF_TOKEN")
 if HF_TOKEN is None:
     raise ValueError("Please set the HF_TOKEN environment variable.")
 
+# Model and API setup
+model_id = 'codellama/CodeLlama-34b-Instruct-hf'
+API_URL = "https://api-inference.huggingface.co/models/" + model_id
+
+client = Client(
+    API_URL,
+    headers={"Authorization": f"Bearer {HF_TOKEN}"},
+)
+
+EOS_STRING = "</s>"
+EOT_STRING = "<EOT>"
+
 HF_PUBLIC = os.environ.get("HF_PUBLIC", False)
 
 DEFAULT_SYSTEM_PROMPT = """\
@@ -36,17 +48,52 @@ As a derivate work of Code Llama by Meta,
 this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/codellama-2-34b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/codellama-2-34b-chat/blob/main/USE_POLICY.md).
 """
 
+def get_prompt(message: str, chat_history: list[tuple[str, str]],
+               system_prompt: str) -> str:
+    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
+    do_strip = False
+    for user_input, response in chat_history:
+        user_input = user_input.strip() if do_strip else user_input
+        do_strip = True
+        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
+    message = message.strip() if do_strip else message
+    texts.append(f'{message} [/INST]')
+    return ''.join(texts)
+
+def run(message: str,
+        chat_history: list[tuple[str, str]],
+        system_prompt: str,
+        max_new_tokens: int = 1024,
+        temperature: float = 0.1,
+        top_p: float = 0.9,
+        top_k: int = 50) -> Iterator[str]:
+    prompt = get_prompt(message, chat_history, system_prompt)
+
+    generate_kwargs = dict(
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+    )
+    stream = client.generate_stream(prompt, **generate_kwargs)
+    output = ""
+    for response in stream:
+        if any([end_token in response.token.text for end_token in [EOS_STRING, EOT_STRING]]):
+            return output
+        else:
+            output += response.token.text
+            yield output
+    return output
 
 def clear_and_save_textbox(message: str) -> tuple[str, str]:
     return '', message
 
-
 def display_input(message: str,
                   history: list[tuple[str, str]]) -> list[tuple[str, str]]:
     history.append((message, ''))
     return history
 
-
 def delete_prev_fn(
         history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
     try:
@@ -77,20 +124,17 @@ def generate(
     for response in generator:
         yield history + [(message, response)]
 
-
 def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
     generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
     for x in generator:
         pass
     return '', x
 
-
 def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
     input_token_length = len(message) + len(chat_history)
    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
        raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
 
-
 with gr.Blocks(css='style.css') as demo:
     gr.Markdown(DESCRIPTION)
     gr.DuplicateButton(value='Duplicate Space for private use',
@@ -148,8 +192,6 @@ with gr.Blocks(css='style.css') as demo:
                 step=1,
                 value=10,
             )
-
-
 
     gr.Markdown(LICENSE)
 
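For reference, the get_prompt helper added above assembles the Llama-2 chat template ([INST] / <<SYS>> markers) that CodeLlama-34b-Instruct-hf expects. A minimal sketch of what it produces, with invented example messages and a stand-in system prompt; the function body is copied verbatim from the diff so the snippet runs on its own:

def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    # Copied from the commit above.
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)

# One completed turn plus a new user message (both invented for illustration):
prompt = get_prompt(
    message='And in Rust?',
    chat_history=[('Write hello world in C.', 'printf("hello world");')],
    system_prompt='Be concise.',
)
assert prompt == (
    '<s>[INST] <<SYS>>\nBe concise.\n<</SYS>>\n\n'
    'Write hello world in C. [/INST] printf("hello world"); </s><s>[INST] '
    'And in Rust? [/INST]'
)

Note that only the first user turn keeps its original whitespace (do_strip starts False); every later turn, including the new message, is stripped.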
 
 
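The rewired generation path can also be exercised outside Gradio. A rough sketch under stated assumptions: HF_TOKEN is set in the environment, the hosted Inference API endpoint for this model is reachable, and the text_generation client package is installed. The prompt string is an invented example in the format shown above:

import os
from text_generation import Client

# Same endpoint and auth header the commit sets up at module level.
API_URL = 'https://api-inference.huggingface.co/models/codellama/CodeLlama-34b-Instruct-hf'
client = Client(API_URL, headers={'Authorization': f"Bearer {os.environ['HF_TOKEN']}"})

prompt = '<s>[INST] <<SYS>>\nBe concise.\n<</SYS>>\n\nWrite hello world in Rust. [/INST]'

# Mirrors run() in the commit: stream token by token and stop as soon as an
# end-of-sequence marker appears in the token text.
output = ''
for response in client.generate_stream(prompt, max_new_tokens=256, do_sample=True,
                                       temperature=0.1, top_p=0.9, top_k=50):
    if any(t in response.token.text for t in ('</s>', '<EOT>')):
        break
    output += response.token.text
    print(response.token.text, end='', flush=True)

Streaming matters here because the Gradio generate() handler (see the @@ -77,20 +124,17 @@ hunk) yields the partial history on every token, so the chatbot updates incrementally instead of waiting for the full completion.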