Update run_eval.py
Browse files- run_eval.py +14 -1
run_eval.py
CHANGED
|
@@ -22,6 +22,7 @@ from logging import getLogger
|
|
| 22 |
from pathlib import Path
|
| 23 |
from typing import Dict, List
|
| 24 |
|
|
|
|
| 25 |
from tqdm import tqdm
|
| 26 |
|
| 27 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
@@ -33,12 +34,13 @@ from utils import (
|
|
| 33 |
use_task_specific_params,
|
| 34 |
)
|
| 35 |
|
|
|
|
| 36 |
|
| 37 |
|
| 38 |
logger = getLogger(__name__)
|
| 39 |
|
| 40 |
|
| 41 |
-
DEFAULT_DEVICE = "cpu"
|
| 42 |
|
| 43 |
|
| 44 |
def generate_summaries_or_translations(
|
|
@@ -206,6 +208,17 @@ def run_generate(
|
|
| 206 |
if scor_path:
|
| 207 |
args.score_path = scor_path
|
| 208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
if parsed_args and verbose:
|
| 210 |
print(f"parsed the following generate kwargs: {parsed_args}")
|
| 211 |
examples = [
|
|
|
|
| 22 |
from pathlib import Path
|
| 23 |
from typing import Dict, List
|
| 24 |
|
| 25 |
+
import torch
|
| 26 |
from tqdm import tqdm
|
| 27 |
|
| 28 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
|
|
| 34 |
use_task_specific_params,
|
| 35 |
)
|
| 36 |
|
| 37 |
+
from evaluate_gpt import gpt_eval
|
| 38 |
|
| 39 |
|
| 40 |
logger = getLogger(__name__)
|
| 41 |
|
| 42 |
|
| 43 |
+
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 44 |
|
| 45 |
|
| 46 |
def generate_summaries_or_translations(
|
|
|
|
| 208 |
if scor_path:
|
| 209 |
args.score_path = scor_path
|
| 210 |
|
| 211 |
+
if args.model_name[-3:] == 'gpt':
|
| 212 |
+
gpt_eval(
|
| 213 |
+
model_name_path=args.model_name,
|
| 214 |
+
src_txt=args.input_path,
|
| 215 |
+
tar_txt=args.reference_path,
|
| 216 |
+
gen_path=args.save_path,
|
| 217 |
+
scor_path=args.score_path,
|
| 218 |
+
batch_size=args.bs
|
| 219 |
+
)
|
| 220 |
+
return None
|
| 221 |
+
|
| 222 |
if parsed_args and verbose:
|
| 223 |
print(f"parsed the following generate kwargs: {parsed_args}")
|
| 224 |
examples = [
|