File size: 1,491 Bytes
8b414b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import traceback

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig
from transformers import logging as transformer_log

from src.data_reader import load_train_test_df
from src.utils import (get_x_columns, pretty_cfg, report_to_telegram,
                       save_experiment, seed_everything)

seed_everything()
transformer_log.set_verbosity_error()


def run(cfg):
    predictor = instantiate(cfg.predictor)
    validator = instantiate(cfg.validator)

    train_df, test_df = load_train_test_df()
    x_columns = get_x_columns()
    train_x, train_y = train_df[x_columns], train_df.drop(columns=['full_text'])

    results = validator.fit(predictor, train_x, train_y)

    cv_mean = results.iloc[len(results) - 1].mean()
    print("CV results")
    print(results)
    print(f"CV mean: {cv_mean}")

    submission_df = validator.predict(test_df)

    save_experiment(cfg, submission_df, results, validator.saving_dir)

    return cv_mean


@hydra.main(version_base=None, config_path="config/conf", config_name="config")
def main(cfg: DictConfig):
    try:
        metric = run(cfg)
        message = f"βœ… Successful run from {cfg.timestamp}!\n\n"
        message += f"Metric: {metric:.5f} MCRMSE\n\n"
        message += f"Configuration:\n{pretty_cfg(cfg)}"
    except Exception:
        message = f"🚫 Run from {cfg.timestamp} failed!\n\n"
        message += traceback.format_exc()
    report_to_telegram(message)


if __name__ == "__main__":
    main()