AliMc2021 commited on
Commit
cdf4596
·
1 Parent(s): 3dc7882

resolve a error

Browse files
Files changed (1) hide show
  1. generate.py +2 -2
generate.py CHANGED
@@ -11,7 +11,7 @@ def generate_text(model_data, input_text, max_new_token):
11
  model_pipeline = model_data["pipeline"]
12
  generated_text = model_pipeline(
13
  input_text,
14
- max_length=max_new_token,
15
  do_sample=False, # غیرفعال کردن نمونه‌گیری (حالت حریصانه)
16
  truncation=True # فعال کردن truncation
17
  )[0]["generated_text"]
@@ -40,7 +40,7 @@ def generate_text(model_data, input_text, max_new_token):
40
  outputs = model.generate(
41
  input_ids=input_ids,
42
  attention_mask=attention_mask,
43
- max_new_tokens=max_new_token,
44
  do_sample=False, # غیرفعال کردن نمونه‌گیری (حالت حریصانه)
45
  pad_token_id=tokenizer.eos_token_id,
46
  repetition_penalty=1.2,
 
11
  model_pipeline = model_data["pipeline"]
12
  generated_text = model_pipeline(
13
  input_text,
14
+ max_length=max_new_token + len(input_text.split()), # افزایش max_length
15
  do_sample=False, # غیرفعال کردن نمونه‌گیری (حالت حریصانه)
16
  truncation=True # فعال کردن truncation
17
  )[0]["generated_text"]
 
40
  outputs = model.generate(
41
  input_ids=input_ids,
42
  attention_mask=attention_mask,
43
+ max_new_tokens=max_new_token, # استفاده از max_new_tokens
44
  do_sample=False, # غیرفعال کردن نمونه‌گیری (حالت حریصانه)
45
  pad_token_id=tokenizer.eos_token_id,
46
  repetition_penalty=1.2,