zamroni111 committed
Commit cb930ac · verified · 1 parent: bc8214f

Update onnxgenairun.py

Files changed (1)
  1. onnxgenairun.py +31 -23
onnxgenairun.py CHANGED
@@ -3,16 +3,23 @@ import argparse
 import time
 import re
 
-
 def main(args):
     if args.verbose: print("Loading model...")
     if args.timings:
         started_timestamp = 0
         first_token_timestamp = 0
 
-    model = og.Model(f'{args.model}')
-    ##########model = og.Model(".\\")
+    config = og.Config(args.model_path)
+    config.clear_providers()
+    # if args.execution_provider != "cpu":
+    #     if args.verbose: print(f"Setting model to {args.execution_provider}")
+    #     config.append_provider(args.execution_provider)
+
+    config.append_provider("dml")
+    model = og.Model(config)
+
     if args.verbose: print("Model loaded")
+
     tokenizer = og.Tokenizer(model)
     tokenizer_stream = tokenizer.create_stream()
     if args.verbose: print("Tokenizer created")
@@ -26,6 +33,10 @@ def main(args):
 
     chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
 
+    params = og.GeneratorParams(model)
+    params.set_search_options(**search_options)
+    # generator = og.Generator(model, params)
+
     # Keep asking for input prompts in a loop
     while True:
         text = input("Input: ")
@@ -40,10 +51,8 @@
 
         input_tokens = tokenizer.encode(prompt)
 
-        params = og.GeneratorParams(model)
-        params.set_search_options(**search_options)
-        # params.input_ids = input_tokens
         generator = og.Generator(model, params)
+        generator.append_tokens(input_tokens)
         if args.verbose: print("Generator created")
 
         if args.verbose: print("Running generation loop ...")
@@ -52,14 +61,13 @@
         new_tokens = []
 
         print()
-        print("Output:\n", end='', flush=True)
+        print("Output: ", end='', flush=True)
+
+        vPreviousDecoded = ""
+        vNewDecoded = ""
 
         try:
-            vPreviousDecoded = ""
-            vNewDecoded = ""
-            generator.append_tokens(input_tokens)
             while not generator.is_done():
-                # generator.compute_logits()
                 generator.generate_next_token()
                 if args.timings:
                     if first:
@@ -67,26 +75,25 @@ def main(args):
                         first = False
 
                 new_token = generator.get_next_tokens()[0]
-
-                ###print(tokenizer_stream.decode(new_token), end='', flush=True)
-
+                #print(tokenizer_stream.decode(new_token), end='', flush=True)
 
                 vNewDecoded = tokenizer_stream.decode(new_token)
-                if re.findall("^[\x2E\x3A\x3B]$", vPreviousDecoded) and vNewDecoded.startswith(" ") and (not vNewDecoded.startswith(" *")) :
-                    vNewDecoded = "\n" + vNewDecoded.replace(" ", "", 1)
+                #if re.findall("^[\x2E\x3A\x3B]$", vPreviousDecoded) and vNewDecoded.startswith(" ") and (not vNewDecoded.startswith(" *")) :
+                if re.fullmatch("^[\x2E\x3A\x3B]$", vPreviousDecoded) and vNewDecoded.startswith(" ") and (not vNewDecoded.startswith(" *")) :
 
-                print(vNewDecoded, end='', flush=True)
+                    # vNewDecoded = "\n" + vNewDecoded.replace(" ", "", 1)
+                    print("\n" + vNewDecoded.replace(" ", "", 1), end='', flush=True)
+                else :
+                    print(vNewDecoded, end='', flush=True)
+
                 vPreviousDecoded = vNewDecoded
-
+
                 if args.timings: new_tokens.append(new_token)
         except KeyboardInterrupt:
             print(" --control+c pressed, aborting generation--")
         print()
         print()
 
-        # Delete the generator to free the captured graph for the next generator, if graph capture is enabled
-        del generator
-
         if args.timings:
             prompt_time = first_token_timestamp - started_timestamp
             run_time = time.time() - first_token_timestamp
@@ -95,7 +102,8 @@ if __name__ == "__main__":
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, description="End-to-end AI Question/Answer example for gen-ai")
-    parser.add_argument('-m', '--model', type=str, required=True, help='Onnx model folder path (must contain config.json and model.onnx)')
+    parser.add_argument('-m', '--model_path', type=str, required=True, help='Onnx model folder path (must contain genai_config.json and model.onnx)')
+    # parser.add_argument('-e', '--execution_provider', type=str, required=True, choices=["cpu", "cuda", "dml"], help="Execution provider to run ONNX model with")
     parser.add_argument('-i', '--min_length', type=int, help='Min number of tokens to generate including the prompt')
     parser.add_argument('-l', '--max_length', type=int, help='Max number of tokens to generate including the prompt')
    parser.add_argument('-ds', '--do_sample', action='store_true', default=False, help='Do random sampling. When false, greedy or beam search are used to generate the output. Defaults to false')
@@ -106,4 +114,4 @@ if __name__ == "__main__":
     parser.add_argument('-v', '--verbose', action='store_true', default=False, help='Print verbose output and timing information. Defaults to false')
     parser.add_argument('-g', '--timings', action='store_true', default=False, help='Print timing information for each generation step. Defaults to false')
     args = parser.parse_args()
-    main(args)
+    main(args)
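
For reference, a minimal standalone sketch of the flow this commit moves to. It assumes onnxruntime-genai is imported as `og` (the import lines sit above the first hunk and are not visible in the diff) and a DirectML-capable build, since the provider is hard-coded to "dml"; the model path and search options below are placeholders, not values from the script.

    import re
    import onnxruntime_genai as og  # assumed import; not shown in the diff

    # Provider selection now goes through og.Config rather than a bare path to og.Model.
    config = og.Config("path/to/model_folder")  # placeholder; folder must contain genai_config.json and model.onnx
    config.clear_providers()
    config.append_provider("dml")               # hard-coded DirectML, as in this commit
    model = og.Model(config)

    tokenizer = og.Tokenizer(model)
    tokenizer_stream = tokenizer.create_stream()

    # GeneratorParams is built once, before the prompt loop, and reused per prompt.
    params = og.GeneratorParams(model)
    params.set_search_options(max_length=2048)  # illustrative; the script forwards its CLI options here

    # Per prompt: a fresh Generator with the tokens appended up front
    # (the commented-out params.input_ids path is gone).
    prompt = '<|user|>\nHello <|end|>\n<|assistant|>'
    generator = og.Generator(model, params)
    generator.append_tokens(tokenizer.encode(prompt))

    vPreviousDecoded = ""
    while not generator.is_done():
        generator.generate_next_token()
        vNewDecoded = tokenizer_stream.decode(generator.get_next_tokens()[0])
        # Formatting heuristic from the commit: after '.', ':' or ';' (\x2E, \x3A, \x3B),
        # turn a leading space into a newline, unless the token opens a ' *' bullet.
        if re.fullmatch("[\x2E\x3A\x3B]", vPreviousDecoded) and vNewDecoded.startswith(" ") and not vNewDecoded.startswith(" *"):
            print("\n" + vNewDecoded.replace(" ", "", 1), end='', flush=True)
        else:
            print(vNewDecoded, end='', flush=True)
        vPreviousDecoded = vNewDecoded
    print()

Invocation keeps the same shape, e.g. `python onnxgenairun.py -m path\to\model_folder -l 2048`, but `-m` now binds to `args.model_path` and the folder must contain `genai_config.json` rather than `config.json`. Two small notes on the committed code: `re.fullmatch` already anchors at both ends, so the `^...$` in the pattern is redundant (the sketch drops it), and the switch from `re.findall` to `re.fullmatch` only tightens edge cases (a token such as ".\n" no longer matches) while leaving the common single-character case unchanged.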