zhengr commited on
Commit
1502095
·
1 Parent(s): 5063e55
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -394,12 +394,14 @@ def ask(symbol, weeks_before, withbasic):
394
  info, pt = get_all_prompts_online(symbol=symbol, weeks_before=weeks_before, with_basics=withbasic)
395
  # print(info)
396
 
397
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
398
  model=model.to(device)
399
 
400
  print(model.device)
401
  inputs = tokenizer(pt, return_tensors='pt')
402
- inputs = {key: value.to(model.device) for key, value in inputs.items()}
 
 
403
  #inputs = {key: value.to('cuda:0') for key, value in inputs.items()}
404
  print("Inputs loaded onto devices.")
405
 
 
394
  info, pt = get_all_prompts_online(symbol=symbol, weeks_before=weeks_before, with_basics=withbasic)
395
  # print(info)
396
 
397
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
398
  model=model.to(device)
399
 
400
  print(model.device)
401
  inputs = tokenizer(pt, return_tensors='pt')
402
+
403
+ inputs = {key: value.to(model.device('cuda:0')) for key, value in inputs.items()}
404
+ #inputs = {key: value.to(model.device) for key, value in inputs.items()}
405
  #inputs = {key: value.to('cuda:0') for key, value in inputs.items()}
406
  print("Inputs loaded onto devices.")
407