resolve a error
Browse files- generate.py +2 -2
generate.py
CHANGED
@@ -11,7 +11,7 @@ def generate_text(model_data, input_text, max_new_token):
|
|
11 |
model_pipeline = model_data["pipeline"]
|
12 |
generated_text = model_pipeline(
|
13 |
input_text,
|
14 |
-
max_length=max_new_token,
|
15 |
do_sample=False, # غیرفعال کردن نمونهگیری (حالت حریصانه)
|
16 |
truncation=True # فعال کردن truncation
|
17 |
)[0]["generated_text"]
|
@@ -40,7 +40,7 @@ def generate_text(model_data, input_text, max_new_token):
|
|
40 |
outputs = model.generate(
|
41 |
input_ids=input_ids,
|
42 |
attention_mask=attention_mask,
|
43 |
-
max_new_tokens=max_new_token,
|
44 |
do_sample=False, # غیرفعال کردن نمونهگیری (حالت حریصانه)
|
45 |
pad_token_id=tokenizer.eos_token_id,
|
46 |
repetition_penalty=1.2,
|
|
|
11 |
model_pipeline = model_data["pipeline"]
|
12 |
generated_text = model_pipeline(
|
13 |
input_text,
|
14 |
+
max_length=max_new_token + len(input_text.split()), # افزایش max_length
|
15 |
do_sample=False, # غیرفعال کردن نمونهگیری (حالت حریصانه)
|
16 |
truncation=True # فعال کردن truncation
|
17 |
)[0]["generated_text"]
|
|
|
40 |
outputs = model.generate(
|
41 |
input_ids=input_ids,
|
42 |
attention_mask=attention_mask,
|
43 |
+
max_new_tokens=max_new_token, # استفاده از max_new_tokens
|
44 |
do_sample=False, # غیرفعال کردن نمونهگیری (حالت حریصانه)
|
45 |
pad_token_id=tokenizer.eos_token_id,
|
46 |
repetition_penalty=1.2,
|