abancp committed (verified)
Commit ec30812 · 1 Parent(s): 5647494

Update inference_fine_tune.py

Files changed (1)
  inference_fine_tune.py  +19 -17
inference_fine_tune.py CHANGED
@@ -1,49 +1,51 @@
 import torch
-
 from tokenizers import Tokenizer
-
-
 from pathlib import Path
 from config import get_config, get_weights_file_path
 from train import get_model
 
-def get_tokenizer(config)->Tokenizer:
+# Load tokenizer
+def get_tokenizer(config) -> Tokenizer:
     tokenizers_path = Path(config['tokenizer_file'])
     if Path.exists(tokenizers_path):
-        print("Loading tokenizer from ", tokenizers_path)
+        print("Loading tokenizer from", tokenizers_path)
         tokenizer = Tokenizer.from_file(str(tokenizers_path))
         return tokenizer
     else:
-        raise FileNotFoundError("Cant find tokenizer file : ",tokenizers_path)
-
+        raise FileNotFoundError("Can't find tokenizer file:", tokenizers_path)
 
+# Setup config
 config = get_config("./openweb.config.json")
 device = "cuda" if torch.cuda.is_available() else "cpu"
 tokenizer = get_tokenizer(config)
+
+# Token IDs
 pad_token_id = tokenizer.token_to_id("<pad>")
 eos_token_id = tokenizer.token_to_id("</s>")
-user_token_id = tokenizer.token_to_id("<user>")
+user_token_id = tokenizer.token_to_id("<user>")
 ai_token_id = tokenizer.token_to_id("<ai>")
 
-model = get_model(config, tokenizer.get_vocab_size()).to(device)
-model_path = get_weights_file_path(config,config['preload'])
-model.eval()
-state = torch.load(model_path,map_location=torch.device('cpu'))
+# Load model
+model = get_model(config, tokenizer.get_vocab_size()).to(device)
+model_path = get_weights_file_path(config, config['preload'])
+model.eval()
+state = torch.load(model_path, map_location=torch.device('cpu'))
 model.load_state_dict(state['model_state_dict'])
 
+# Streaming text generation
 def generate_response(prompt: str):
     input_tokens = tokenizer.encode(prompt).ids
     input_tokens = [user_token_id] + input_tokens + [ai_token_id]
 
     if len(input_tokens) > config['seq_len']:
-        yield gr.Textbox.update(value="Prompt too long.")
+        yield "Prompt too long."
        return
 
     input_tokens = torch.tensor(input_tokens).unsqueeze(0).to(device)
     temperature = 0.7
     top_k = 50
-    i = 0
     generated_text = ""
+    i = 0
 
     while input_tokens.shape[1] < 2000:
         out = model.decode(input_tokens)
@@ -56,13 +58,13 @@ def generate_response(prompt: str):
 
         word = tokenizer.decode([next_token.item()])
         generated_text += word
-
-        yield gr.Textbox.update(value=generated_text)
+        yield generated_text  # ✅ plain string for ChatInterface
 
         input_tokens = torch.cat([input_tokens, next_token], dim=1)
+
         if input_tokens.shape[1] > config['seq_len']:
             input_tokens = input_tokens[:, -config['seq_len']:]
 
         if next_token.item() == eos_token_id or i >= 1024:
             break
-        i += 1
+        i += 1
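
The lines between the two hunks (old 50–55 / new 52–57) are unchanged and therefore not shown; they are where `out = model.decode(input_tokens)` becomes `next_token`. Since `temperature = 0.7` and `top_k = 50` are set just above the loop, a plausible reading is a temperature-scaled top-k sampling step. The sketch below is only an illustration under that assumption; `sample_next_token` and the logits shape are hypothetical, not the file's actual code.

import torch

def sample_next_token(out: torch.Tensor, temperature: float = 0.7, top_k: int = 50) -> torch.Tensor:
    # Assumes `out` holds logits shaped (batch=1, seq_len, vocab_size).
    logits = out[:, -1, :] / temperature              # scale the last position's logits
    top_values, top_indices = torch.topk(logits, top_k, dim=-1)
    probs = torch.softmax(top_values, dim=-1)         # renormalise over the top-k candidates
    choice = torch.multinomial(probs, num_samples=1)  # sample one index within the top-k
    return top_indices.gather(-1, choice)             # token id shaped (1, 1), like next_token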
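
The switch from `yield gr.Textbox.update(value=...)` to yielding plain strings matches how `gr.ChatInterface` consumes a generator: each yielded string is streamed as the partial reply. A minimal wiring sketch, assuming a recent Gradio version and that the app lives in a separate file; `chat_fn` is a hypothetical helper that just forwards the latest message:

import gradio as gr

from inference_fine_tune import generate_response  # the generator defined in this commit

def chat_fn(message, history):
    # history is unused here; generate_response only sees the newest prompt
    yield from generate_response(message)

demo = gr.ChatInterface(fn=chat_fn)
demo.launch()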