tomxxie commited on
Commit
7580011
·
1 Parent(s): 3660ae8

适配zeroGPU

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -53,7 +53,7 @@ TASK_PROMPT_MAPPING = {
53
  "STTC (Speech to Text Chat)": "首先将语音转录为文字,然后对语音内容进行回复,转录和文字之间使用<开始回答>分割。"
54
  }
55
 
56
- @spaces.GPU
57
  def init_model_my():
58
  logging.basicConfig(level=logging.DEBUG,
59
  format='%(asctime)s %(levelname)s %(message)s')
@@ -71,7 +71,7 @@ def init_model_my():
71
  print(model)
72
  return model, tokenizer
73
 
74
- model, tokenizer = init_model_my()
75
  print("model init success")
76
  def do_resample(input_wav_path, output_wav_path):
77
  """"""
@@ -87,6 +87,7 @@ def do_resample(input_wav_path, output_wav_path):
87
  makedir_for_file(output_wav_path)
88
  torchaudio.save(output_wav_path, waveform, 16000)
89
 
 
90
  def true_decode_fuc(input_wav_path, input_prompt):
91
  # input_prompt = TASK_PROMPT_MAPPING.get(input_prompt, "未知任务类型")
92
  print(f"wav_path: {input_wav_path}, prompt:{input_prompt}")
@@ -120,11 +121,12 @@ def true_decode_fuc(input_wav_path, input_prompt):
120
  feat = feat.unsqueeze(0).cuda()
121
  # feat = feat.half()
122
  # feat_lens = feat_lens.half()
123
- model = None
 
124
  res_text = model.generate(wavs=feat, wavs_len=feat_lens, prompt=input_prompt)[0]
125
  print("耿雪龙哈哈:", res_text)
126
  return res_text
127
- @spaces.GPU
128
  def do_decode(input_wav_path, input_prompt):
129
  print(f'input_wav_path= {input_wav_path}, input_prompt= {input_prompt}')
130
  # 省略处理逻辑
 
53
  "STTC (Speech to Text Chat)": "首先将语音转录为文字,然后对语音内容进行回复,转录和文字之间使用<开始回答>分割。"
54
  }
55
 
56
+
57
  def init_model_my():
58
  logging.basicConfig(level=logging.DEBUG,
59
  format='%(asctime)s %(levelname)s %(message)s')
 
71
  print(model)
72
  return model, tokenizer
73
 
74
+ global_model, tokenizer = init_model_my()
75
  print("model init success")
76
  def do_resample(input_wav_path, output_wav_path):
77
  """"""
 
87
  makedir_for_file(output_wav_path)
88
  torchaudio.save(output_wav_path, waveform, 16000)
89
 
90
+ @spaces.GPU
91
  def true_decode_fuc(input_wav_path, input_prompt):
92
  # input_prompt = TASK_PROMPT_MAPPING.get(input_prompt, "未知任务类型")
93
  print(f"wav_path: {input_wav_path}, prompt:{input_prompt}")
 
121
  feat = feat.unsqueeze(0).cuda()
122
  # feat = feat.half()
123
  # feat_lens = feat_lens.half()
124
+ model = global_model.cuda()
125
+ model.eval()
126
  res_text = model.generate(wavs=feat, wavs_len=feat_lens, prompt=input_prompt)[0]
127
  print("耿雪龙哈哈:", res_text)
128
  return res_text
129
+
130
  def do_decode(input_wav_path, input_prompt):
131
  print(f'input_wav_path= {input_wav_path}, input_prompt= {input_prompt}')
132
  # 省略处理逻辑