masharpe committed
Commit 66fa394 · Parent: 39e4f73

Time generation. Align responses.

Files changed (1)
  1. app.py (+21 -7)
app.py CHANGED
@@ -22,13 +22,17 @@ print(f'transformers=={transformers.__version__}')
 print(f'accelerate=={accelerate.__version__}')
 print()
 
+# Initialize logging.
+logging.basicConfig(format='%(levelname)s:%(name)s: %(message)s', level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 # gr.DataFrame is currently bugged for updating values,
 # so we must use raw HTML.
 # https://github.com/gradio-app/gradio/issues/8160
 def make_html_table(headers, data):
     rows = ['<tr>' + ''.join(f'<th style="width: 50%">{h}</th>' for h in headers) + '</tr>\n']
     for row in data:
-        rows.append('<tr>' + ''.join(f'<td style="width: 50%; font-family: monospace; white-space: pre-wrap;">{v}</td>' for v in row) + '</tr>\n')
+        rows.append('<tr>' + ''.join(f'<td style="width: 50%; font-family: monospace; white-space: pre-wrap; text-align: left; vertical-align: top;">{v}</td>' for v in row) + '</tr>\n')
     return '<table style="width: 100%; table-layout: fixed">\n' + ''.join(rows) + '</table>\n'
 
 def highlight_prefix(tokens, prefix_len):
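Note on the hunk above: gr.DataFrame cannot currently be updated in place (gradio issue 8160), so the app rebuilds a raw HTML table and re-yields it on every step. A minimal, self-contained sketch of how that string can drive a streaming gr.HTML output; the demo_fn generator and its sample text are illustrative stand-ins, not part of this commit:

import time
import gradio as gr

def make_html_table(headers, data):
    # Mirrors the helper in app.py: fixed-layout table, monospace pre-wrapped cells.
    rows = ['<tr>' + ''.join(f'<th style="width: 50%">{h}</th>' for h in headers) + '</tr>\n']
    for row in data:
        rows.append('<tr>' + ''.join(f'<td style="width: 50%; font-family: monospace; white-space: pre-wrap; text-align: left; vertical-align: top;">{v}</td>' for v in row) + '</tr>\n')
    return '<table style="width: 100%; table-layout: fixed">\n' + ''.join(rows) + '</table>\n'

def demo_fn():
    # Each yield replaces the rendered HTML, so the table appears to stream in.
    rows = [['', '']]
    for word in 'tokens arrive one at a time'.split():
        rows[0][0] += word + ' '
        rows[0][1] += word + ' '
        time.sleep(0.1)
        yield make_html_table(['Response (Left)', 'Response (Right)'], rows)

demo = gr.Interface(fn=demo_fn, inputs=[], outputs=gr.HTML())

if __name__ == '__main__':
    demo.launch()

The added text-align and vertical-align rules pin both cells to the top-left corner, which is presumably the "Align responses" part of the commit message: uneven responses stay visually aligned as they grow.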
@@ -56,7 +60,7 @@ def format_response_pair(tokens_a, tokens_b):
 HEADERS = ['Response (Left)', 'Response (Right)']
 repo_id = "Qwen/Qwen2-0.5B-Instruct"
 
-DRY_RUN = False
+DRY_RUN = True
 
 if DRY_RUN:
     from load import load_tokenizer
@@ -64,13 +68,16 @@ if DRY_RUN:
     tokenizer = load_tokenizer(repo_id)
 
     def fn(max_tokens, num_responses, prompt_x, prompt_y):
+        logger.info('Starting generation...')
+        generation_start = time.perf_counter()
+
         rows = [['']*2 for i in range(num_responses)]
 
         yield make_html_table(HEADERS, rows)
 
         for j in range(num_responses):
             response_raw_a = f'Sure!\n\n1 2 3 4 & 5.'
-            response_raw_b = f'Sure!\n\n1 2 3 4 5 & 6.'
+            response_raw_b = f'Sure!\n\n1 2 3 4 5 &\n\n\n\n6.'
 
             response_tok_a = tokenizer.encode(response_raw_a, add_special_tokens=False, return_tensors='np')[0]
             response_tok_b = tokenizer.encode(response_raw_b, add_special_tokens=False, return_tensors='np')[0]
@@ -78,7 +85,7 @@ if DRY_RUN:
             steps = 1 + max(len(response_tok_a), len(response_tok_b))
 
             for i in range(steps):
-                time.sleep(0.1)
+                time.sleep(0.01)
                 prefix_tok_a = response_tok_a[:i]
                 prefix_tok_b = response_tok_b[:i]
 
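Note: the dry-run branch fakes generation by pre-tokenizing a canned response and re-decoding an ever longer prefix of it, which is what the loop above steps through (now with a shorter 0.01 s delay per token). A standalone sketch of the idea, using transformers.AutoTokenizer directly instead of the repo's load_tokenizer helper and a plain decode() in place of the app's highlight_prefix/format_response_pair:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

response = 'Sure!\n\n1 2 3 4 & 5.'
tokens = tokenizer.encode(response, add_special_tokens=False)

# Step i re-decodes the first i tokens, imitating a response that streams in.
for i in range(len(tokens) + 1):
    print(repr(tokenizer.decode(tokens[:i])))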
 
@@ -88,12 +95,13 @@ if DRY_RUN:
                 rows[j][1] = content_b
 
                 yield make_html_table(HEADERS, rows)
+
+        generation_end = time.perf_counter()
+        logger.info(f'Generation took {(generation_end - generation_start):.3f} s')
 else:
     from load import load_model
     import algorithms
-
-    logging.basicConfig(format='%(levelname)s:%(name)s: %(message)s')
-    algorithms.logger.setLevel(logging.INFO)
+    #algorithms.logger.setLevel(logging.DEBUG)
 
     model, tokenizer = load_model(repo_id)
 
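Note: logging is now configured once at module level (INFO, with the levelname:name format), so the old per-branch basicConfig call is gone and the commented-out setLevel line above only exists to re-enable debug output from the algorithms module when needed. A small sketch of that level interplay, using a stand-in logger named 'algorithms' rather than the repo's real module:

import logging

logging.basicConfig(format='%(levelname)s:%(name)s: %(message)s', level=logging.INFO)

# Stand-in for algorithms.logger; the real one lives in the repo's algorithms module.
algo_logger = logging.getLogger('algorithms')

algo_logger.info('shown: INFO passes the root level set by basicConfig')
algo_logger.debug('hidden: DEBUG is below the inherited INFO level')

algo_logger.setLevel(logging.DEBUG)  # what un-commenting the line above would do
algo_logger.debug('shown: the algorithms logger now lets DEBUG through')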
 
@@ -112,6 +120,9 @@ else:
 
     @spaces.GPU
     def fn(max_tokens, num_responses, prompt_x, prompt_y):
+        logger.info('Starting generation...')
+        generation_start = time.perf_counter()
+
         # Is this necessary with ZeroGPU?
         torch.use_deterministic_algorithms(True)
 
@@ -151,6 +162,9 @@ else:
 
         yield make_html_table(HEADERS, rows)
 
+        generation_end = time.perf_counter()
+        logger.info(f'Generation took {(generation_end - generation_start):.3f} s')
+
 demo = gr.Interface(
     fn=fn,
     inputs=[
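Note: the same start/stop timing block is now duplicated in the dry-run and @spaces.GPU paths. If that duplication ever becomes a nuisance, one option (not part of this commit) is a small context manager over the same logger and time.perf_counter calls, sketched here with a hypothetical log_timing helper:

import logging
import time
from contextlib import contextmanager

logging.basicConfig(format='%(levelname)s:%(name)s: %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

@contextmanager
def log_timing(label='Generation'):
    # Log a start marker, run the wrapped block, then log elapsed wall time.
    logger.info(f'Starting {label.lower()}...')
    start = time.perf_counter()
    try:
        yield
    finally:
        end = time.perf_counter()
        logger.info(f'{label} took {(end - start):.3f} s')

# Usage: wrap whatever should be timed.
with log_timing():
    time.sleep(0.25)

With the format string used in app.py this prints INFO:__main__: Starting generation... and then the elapsed time; inside a generator such as fn, the with block would have to enclose the yields so the measurement covers the whole stream.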
 