#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Build a label vocabulary from the train/valid Excel datasets.

Reads the ``category``, ``global_labels`` and ``country_labels`` columns of
both spreadsheets and saves a ``Vocabulary`` containing a ``global_labels``
namespace plus one namespace per distinct ``category`` value (holding that
category's ``country_labels``).
"""
import argparse
import os
from pathlib import Path
import sys

pwd = os.path.abspath(os.path.dirname(__file__))
# Make the repository root importable so the project-local ``toolbox`` package resolves.
sys.path.append(os.path.join(pwd, "../../"))

import pandas as pd

from toolbox.torch.utils.data.vocabulary import Vocabulary


def get_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace with string attributes ``vocabulary_dir``,
        ``train_dataset`` and ``valid_dataset``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
    args = parser.parse_args()
    return args


def _collect_categories(datasets):
    """Return the distinct ``category`` values across *datasets* in a stable order."""
    categories = set()
    for dataset in datasets:
        categories.update(dataset["category"])
    # Sort so the namespace list does not depend on set iteration order;
    # key=str avoids a TypeError if the column mixes types (e.g. str and NaN).
    return sorted(categories, key=str)


def _add_labels(vocabulary, dataset):
    """Register every row's label tokens from *dataset* into *vocabulary*.

    Each ``global_labels`` value goes into the "global_labels" namespace; each
    ``country_labels`` value goes into the namespace named by that row's
    ``category``.
    """
    # Zip the columns rather than using DataFrame.iterrows(): iterrows builds a
    # Series per row and does not preserve per-column dtypes (an int label can
    # be upcast to float when another column in the row is float), which would
    # change the token that ends up in the vocabulary.
    rows = zip(dataset["global_labels"], dataset["country_labels"], dataset["category"])
    for global_label, country_label, category in rows:
        vocabulary.add_token_to_namespace(global_label, "global_labels")
        vocabulary.add_token_to_namespace(country_label, category)


def main():
    """Build the label vocabulary from the train and valid datasets and save it."""
    args = get_args()

    train_dataset = pd.read_excel(args.train_dataset)
    valid_dataset = pd.read_excel(args.valid_dataset)
    # Train before valid, so tokens are added in the same order as before.
    datasets = [train_dataset, valid_dataset]

    # Every namespace used here is listed as non-padded.
    # NOTE(review): presumably this means label namespaces get no padding/OOV
    # entries — confirm against toolbox's Vocabulary implementation.
    vocabulary = Vocabulary(
        non_padded_namespaces=["global_labels", *_collect_categories(datasets)]
    )
    for dataset in datasets:
        _add_labels(vocabulary, dataset)

    vocabulary.save_to_files(args.vocabulary_dir)
    return


if __name__ == "__main__":
    main()