ak0327 committed on
Commit
4e9335a
·
verified ·
1 Parent(s): 0f3cf73

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +46 -0
README.md CHANGED
@@ -53,6 +53,52 @@ def load_model(model_name):
53
  return model, tokenizer
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  model_name = "ak0327/llm-jp-3-13b-ft-5"
57
 
58
  model, tokenizer = load_model(model_name)
 
53
  return model, tokenizer
54
 
55
 
56
def inference(datasets, model, tokenizer):
    """Run greedy-decoding inference over *datasets* and collect the answers.

    Parameters
    ----------
    datasets : iterable of dict
        Each item must carry ``"task_id"`` and ``"input"`` keys.
        (Assumed from usage below — confirm against the caller.)
    model : a Hugging Face causal-LM with ``.generate`` and ``.device``
    tokenizer : the matching tokenizer (must define ``pad_token_id``)

    Returns
    -------
    list of dict
        One ``{"task_id", "input", "output"}`` record per input item.
    """
    results = []
    for data in tqdm(datasets):
        # Renamed from `input` — the original shadowed the builtin.
        input_text = data["input"]

        prompt = f"""### 指示
{input_text}
### 回答:
"""

        # encode_plus returns the attention_mask alongside input_ids, so
        # generate() can distinguish real tokens from padding.
        # NOTE(review): encode_plus is legacy API; the modern equivalent is
        # calling tokenizer(prompt, ...) directly — behavior is the same.
        encoded_input = tokenizer.encode_plus(
            prompt,
            add_special_tokens=False,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(model.device)

        input_ids = encoded_input["input_ids"]
        attention_mask = encoded_input["attention_mask"]

        # Greedy decoding (do_sample=False); pad_token_id is passed
        # explicitly so generate() does not warn/guess when padding.
        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=100,
                do_sample=False,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.pad_token_id,
            )[0]

        # Drop the prompt tokens so only the newly generated answer is decoded.
        output = tokenizer.decode(
            outputs[input_ids.size(1):],
            skip_special_tokens=True,
        )

        results.append({
            "task_id": data["task_id"],
            "input": input_text,
            "output": output,
        })
    return results
100
+
101
+
102
  model_name = "ak0327/llm-jp-3-13b-ft-5"
103
 
104
  model, tokenizer = load_model(model_name)