philipp-zettl commited on
Commit
c51abf0
·
1 Parent(s): 3479f48

move to cpu runtime

Browse files
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import spaces
2
  import gradio as gr
3
  from model import DecoderTransformer, Tokenizer
4
  from huggingface_hub import hf_hub_download
@@ -12,7 +11,7 @@ n_layer=6
12
  n_head=6
13
  dropout=0.2
14
 
15
- device = 'cuda'
16
 
17
  model_id = "philipp-zettl/chessPT"
18
 
@@ -24,7 +23,6 @@ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
24
  model.to(device)
25
  tokenizer = Tokenizer.from_pretrained(tokenizer_path)
26
 
27
- @spaces.GPU
28
  def greet(prompt):
29
  model_input = torch.tensor(tokenizer.encode(prompt), dtype=torch.long, device=device).view((1, len(prompt)))
30
  return tokenizer.decode(model.generate(model_input, max_new_tokens=4, context_size=context_size)[0].tolist())
@@ -32,9 +30,10 @@ def greet(prompt):
32
 
33
  with gr.Blocks() as demo:
34
  gr.Markdown("""
 
35
  Welcome to ChessPT.
36
 
37
- The Chess-Pre-trained-Transformer.
38
 
39
  The rules are simple: provide a PGN string of your current game, the engine will predict the next token!
40
  """)
 
 
1
  import gradio as gr
2
  from model import DecoderTransformer, Tokenizer
3
  from huggingface_hub import hf_hub_download
 
11
  n_head=6
12
  dropout=0.2
13
 
14
+ device = 'cpu'
15
 
16
  model_id = "philipp-zettl/chessPT"
17
 
 
23
  model.to(device)
24
  tokenizer = Tokenizer.from_pretrained(tokenizer_path)
25
 
 
26
  def greet(prompt):
27
  model_input = torch.tensor(tokenizer.encode(prompt), dtype=torch.long, device=device).view((1, len(prompt)))
28
  return tokenizer.decode(model.generate(model_input, max_new_tokens=4, context_size=context_size)[0].tolist())
 
30
 
31
  with gr.Blocks() as demo:
32
  gr.Markdown("""
33
+ # ChessPT
34
  Welcome to ChessPT.
35
 
36
+ The **C**hess-**P**re-trained-**T**ransformer.
37
 
38
  The rules are simple: provide a PGN string of your current game, the engine will predict the next token!
39
  """)