dx2102 commited on
Commit
b6316c1
·
verified ·
1 Parent(s): 83feed6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -17,7 +17,7 @@ print('\n\n\n')
17
  print('Loading model...')
18
  pipe = transformers.pipeline(
19
  "text-generation",
20
- model="dx2102/llama-midi-2",
21
  # revision="c303c108399aba837146e893375849b918f413b3",
22
  torch_dtype="bfloat16",
23
  device="cuda",
@@ -50,7 +50,7 @@ example_prefix = '''pitch duration wait velocity instrument
50
 
51
  def postprocess(txt, path):
52
  # remove prefix
53
- txt = txt.split('---\n')[-1]
54
  # track = symusic.core.TrackSecond()
55
  tracks = {}
56
 
@@ -58,7 +58,7 @@ def postprocess(txt, path):
58
  for line in txt.split('\n'):
59
  # we need to ignore the invalid output by the model
60
  try:
61
- pitch, duration, wait, velocity, instrument = line.split(',')
62
  pitch, duration, wait, velocity = [int(x) for x in [pitch, duration, wait, velocity]]
63
  if instrument not in tracks:
64
  tracks[instrument] = symusic.core.TrackSecond()
 
17
  print('Loading model...')
18
  pipe = transformers.pipeline(
19
  "text-generation",
20
+ model="dx2102/llama-midi",
21
  # revision="c303c108399aba837146e893375849b918f413b3",
22
  torch_dtype="bfloat16",
23
  device="cuda",
 
50
 
51
  def postprocess(txt, path):
52
  # remove prefix
53
+ txt = txt.split('\n\n')[-1]
54
  # track = symusic.core.TrackSecond()
55
  tracks = {}
56
 
 
58
  for line in txt.split('\n'):
59
  # we need to ignore the invalid output by the model
60
  try:
61
+ pitch, duration, wait, velocity, instrument = line.split(' ')
62
  pitch, duration, wait, velocity = [int(x) for x in [pitch, duration, wait, velocity]]
63
  if instrument not in tracks:
64
  tracks[instrument] = symusic.core.TrackSecond()