zhengr commited on
Commit
b54cfaa
·
verified ·
1 Parent(s): fdf2c1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -394,9 +394,11 @@ def ask(symbol, weeks_before, withbasic):
394
  info, pt = get_all_prompts_online(symbol=symbol, weeks_before=weeks_before, with_basics=withbasic)
395
  # print(info)
396
 
397
- inputs = tokenizer(pt, return_tensors='pt')
398
- #print(model.device('cuda:0' if torch.cuda.is_available() else 'cpu'))
 
399
  print(model.device)
 
400
  inputs = {key: value.to(model.device) for key, value in inputs.items()}
401
  #inputs = {key: value.to('cuda:0') for key, value in inputs.items()}
402
  print("Inputs loaded onto devices.")
 
394
  info, pt = get_all_prompts_online(symbol=symbol, weeks_before=weeks_before, with_basics=withbasic)
395
  # print(info)
396
 
397
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
398
+ model=model.to(device)
399
+
400
  print(model.device)
401
+ inputs = tokenizer(pt, return_tensors='pt')
402
  inputs = {key: value.to(model.device) for key, value in inputs.items()}
403
  #inputs = {key: value.to('cuda:0') for key, value in inputs.items()}
404
  print("Inputs loaded onto devices.")