Commit cc0b62b: Initial commit
Parent(s): none

This view is limited to 50 files because it contains too many changes.
- .gitignore +108 -0
- LICENSE +21 -0
- README.md +82 -0
- data/.gitkeep +1 -0
- data/Olympic/raw.pickle +0 -0
- data/PsychExp/raw.pickle +0 -0
- data/SCv1/raw.pickle +0 -0
- data/SCv2-GEN/raw.pickle +0 -0
- data/SE0714/raw.pickle +0 -0
- data/SS-Twitter/raw.pickle +0 -0
- data/SS-Youtube/raw.pickle +0 -0
- data/filtering/wanted_emojis.csv +64 -0
- data/kaggle-insults/raw.pickle +0 -0
- emoji_overview.png +0 -0
- examples/.gitkeep +1 -0
- examples/README.md +31 -0
- examples/__init__.py +0 -0
- examples/create_twitter_vocab.py +13 -0
- examples/dataset_split.py +59 -0
- examples/encode_texts.py +41 -0
- examples/example_helper.py +6 -0
- examples/finetune_insults_chain-thaw.py +44 -0
- examples/finetune_semeval_class-avg_f1.py +50 -0
- examples/finetune_youtube_last.py +35 -0
- examples/score_texts_emojis.py +76 -0
- examples/tokenize_dataset.py +26 -0
- examples/vocab_extension.py +30 -0
- model/.gitkeep +1 -0
- model/vocabulary.json +0 -0
- scripts/analyze_all_results.py +40 -0
- scripts/analyze_results.py +39 -0
- scripts/calculate_coverages.py +85 -0
- scripts/convert_all_datasets.py +105 -0
- scripts/download_weights.py +64 -0
- scripts/finetune_dataset.py +109 -0
- scripts/results/.gitkeep +1 -0
- setup.py +16 -0
- tests/test_finetuning.py +235 -0
- tests/test_helper.py +6 -0
- tests/test_sentence_tokenizer.py +113 -0
- tests/test_tokenizer.py +167 -0
- tests/test_word_generator.py +73 -0
- torchmoji/.gitkeep +1 -0
- torchmoji/__init__.py +0 -0
- torchmoji/attlayer.py +69 -0
- torchmoji/class_avg_finetuning.py +315 -0
- torchmoji/create_vocab.py +271 -0
- torchmoji/filter_input.py +36 -0
- torchmoji/filter_utils.py +191 -0
- torchmoji/finetuning.py +661 -0
.gitignore
ADDED
@@ -0,0 +1,108 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

# Local data
/data/local

# Vim swapfiles
*.swp
*.swo

# nosetests
.noseids

# pyTorch model
pytorch_model.bin

# VSCODE
.vscode/*

# data
*.csv
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2017 Bjarke Felbo, Han Thi Nguyen, Thomas Wolf

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
ADDED
@@ -0,0 +1,82 @@
# torchMoji

TorchMoji is a [pyTorch](http://pytorch.org/) implementation of the [DeepMoji](https://github.com/bfelbo/DeepMoji) model developed by Bjarke Felbo, Alan Mislove, Anders Søgaard, Iyad Rahwan and Sune Lehmann.

This model was trained on 1.2 billion tweets with emojis to understand how language is used to express emotions. Through transfer learning, the model can obtain state-of-the-art performance on many emotion-related text modeling tasks.

Try the online demo of DeepMoji at [http://deepmoji.mit.edu](http://deepmoji.mit.edu/)! See the [paper](https://arxiv.org/abs/1708.00524), [blog post](https://medium.com/@bjarkefelbo/what-can-we-learn-from-emojis-6beb165a5ea0) or [FAQ](https://www.media.mit.edu/projects/deepmoji/overview/) for more details.

## Overview
* [torchmoji/](torchmoji) contains all the underlying code needed to convert a dataset to the vocabulary and use the model.
* [examples/](examples) contains short code snippets showing how to convert a dataset to the vocabulary, load up the model and run it on that dataset.
* [scripts/](scripts) contains code for processing and analysing datasets to reproduce results in the paper.
* [model/](model) contains the pretrained model and vocabulary.
* [data/](data) contains raw and processed datasets that we include in this repository for testing.
* [tests/](tests) contains unit tests for the codebase.

To start out with, have a look inside the [examples/](examples) directory. See [score_texts_emojis.py](examples/score_texts_emojis.py) for how to use DeepMoji to extract emoji predictions, [encode_texts.py](examples/encode_texts.py) for how to convert text into 2304-dimensional emotional feature vectors, or [finetune_youtube_last.py](examples/finetune_youtube_last.py) for how to use the model for transfer learning on a new dataset.

Please consider citing the [paper](https://arxiv.org/abs/1708.00524) of DeepMoji if you use the model or code (see below for citation).

## Installation

We assume that you're using [Python 2.7-3.5](https://www.python.org/downloads/) with [pip](https://pip.pypa.io/en/stable/installing/) installed.

First you need to install [pyTorch (version 0.2+)](http://pytorch.org/), currently via:
```bash
conda install pytorch -c soumith
```
At the present stage, the model can't make efficient use of CUDA. See details in the HuggingFace blog post.

Once pyTorch is installed, run the following in the root directory to install the remaining dependencies:

```bash
pip install -e .
```
This will install the following dependencies:
* [scikit-learn](https://github.com/scikit-learn/scikit-learn)
* [text-unidecode](https://github.com/kmike/text-unidecode)
* [emoji](https://github.com/carpedm20/emoji)

Then, run the download script to download the pretrained torchMoji weights (~85MB) from [here](https://www.dropbox.com/s/q8lax9ary32c7t9/pytorch_model.bin?dl=0) and put them in the model/ directory:

```bash
python scripts/download_weights.py
```

## Testing
To run the tests, install [nose](http://nose.readthedocs.io/en/latest/). After installing, navigate to the [tests/](tests) directory and run:

```bash
cd tests
nosetests -v
```

By default, this will also run finetuning tests. These tests train the model for one epoch and then check the resulting accuracy, which may take several minutes to finish. If you'd prefer to exclude those, run the following instead:

```bash
cd tests
nosetests -v -a '!slow'
```

## Disclaimer
This code has been tested to work with Python 2.7 and 3.5 on Ubuntu 16.04 and macOS Sierra machines. It has not been optimized for efficiency, but should be fast enough for most purposes. We do not give any guarantees that there are no bugs; use the code at your own risk!

## Contributions
We welcome pull requests if you feel like something could be improved. You can also greatly help us by telling us how you felt when writing your most recent tweets. Just click [here](http://deepmoji.mit.edu/contribute/) to contribute.

## License
This code and the pretrained model are licensed under the MIT license.

## Benchmark datasets
The benchmark datasets are uploaded to this repository for convenience purposes only. They were not released by us and we do not claim any rights on them. Use the datasets at your own responsibility and make sure you comply with the licenses they were released under. If you use any of the benchmark datasets, please consider citing the original authors.

## Citation
```
@inproceedings{felbo2017,
  title={Using millions of emoji occurrences to learn any-domain representations for detecting sentiment, emotion and sarcasm},
  author={Felbo, Bjarke and Mislove, Alan and S{\o}gaard, Anders and Rahwan, Iyad and Lehmann, Sune},
  booktitle={Conference on Empirical Methods in Natural Language Processing (EMNLP)},
  year={2017}
}
```
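Editor's note: as a quick orientation before the individual files below, here is a minimal sketch of the emoji-scoring flow that examples/score_texts_emojis.py (added later in this commit) walks through. It assumes the package has been installed with `pip install -e .` and the weights downloaded as described above; it is a condensed illustration, not additional code from the commit.

```python
# Minimal sketch of the scoring flow from examples/score_texts_emojis.py.
# Assumes torchmoji is installed and model/pytorch_model.bin has been downloaded.
import json
from torchmoji.sentence_tokenizer import SentenceTokenizer
from torchmoji.model_def import torchmoji_emojis
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH

with open(VOCAB_PATH, 'r') as f:
    vocabulary = json.load(f)

st = SentenceTokenizer(vocabulary, 30)            # pad/truncate to 30 tokens
tokenized, _, _ = st.tokenize_sentences(["I love mom's cooking"])

model = torchmoji_emojis(PRETRAINED_PATH)         # outputs a distribution over 64 emojis
prob = model(tokenized)                           # one row of probabilities per sentence
print(prob[0].argsort()[-5:][::-1])               # indices of the five most likely emojis
```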
data/.gitkeep
ADDED
@@ -0,0 +1 @@
data/Olympic/raw.pickle
ADDED
The diff for this file is too large to render.
data/PsychExp/raw.pickle
ADDED
The diff for this file is too large to render.
data/SCv1/raw.pickle
ADDED
The diff for this file is too large to render.
data/SCv2-GEN/raw.pickle
ADDED
The diff for this file is too large to render.
data/SE0714/raw.pickle
ADDED
The diff for this file is too large to render.
data/SS-Twitter/raw.pickle
ADDED
The diff for this file is too large to render.
data/SS-Youtube/raw.pickle
ADDED
The diff for this file is too large to render.
data/filtering/wanted_emojis.csv
ADDED
@@ -0,0 +1,64 @@
\U0001f602
\U0001f612
\U0001f629
\U0001f62d
\U0001f60d
\U0001f614
\U0001f44c
\U0001f60a
\u2764
\U0001f60f
\U0001f601
\U0001f3b6
\U0001f633
\U0001f4af
\U0001f634
\U0001f60c
\u263a
\U0001f64c
\U0001f495
\U0001f611
\U0001f605
\U0001f64f
\U0001f615
\U0001f618
\u2665
\U0001f610
\U0001f481
\U0001f61e
\U0001f648
\U0001f62b
\u270c
\U0001f60e
\U0001f621
\U0001f44d
\U0001f622
\U0001f62a
\U0001f60b
\U0001f624
\u270b
\U0001f637
\U0001f44f
\U0001f440
\U0001f52b
\U0001f623
\U0001f608
\U0001f613
\U0001f494
\u2661
\U0001f3a7
\U0001f64a
\U0001f609
\U0001f480
\U0001f616
\U0001f604
\U0001f61c
\U0001f620
\U0001f645
\U0001f4aa
\U0001f44a
\U0001f49c
\U0001f496
\U0001f499
\U0001f62c
\u2728
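Editor's note: the rows above are Python-style unicode escape sequences rather than raw emoji characters. The following check is an illustration of how such a row can be decoded in Python; it is an assumption about how the file is meant to be read, not code taken from the repository.

```python
# Decode one row of wanted_emojis.csv into the emoji character it names.
row = r'\U0001f602'                                    # first row of the file
emoji_char = row.encode('ascii').decode('unicode_escape')
print(emoji_char)                                      # the "face with tears of joy" emoji
```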
data/kaggle-insults/raw.pickle
ADDED
The diff for this file is too large to render.
emoji_overview.png
ADDED
examples/.gitkeep
ADDED
@@ -0,0 +1 @@
examples/README.md
ADDED
@@ -0,0 +1,31 @@
# torchMoji examples

## Initialization
[create_twitter_vocab.py](create_twitter_vocab.py)
Create a new vocabulary from a tsv file.

[tokenize_dataset.py](tokenize_dataset.py)
Tokenize a given dataset using the prebuilt vocabulary.

[vocab_extension.py](vocab_extension.py)
Extend the given vocabulary using dataset-specific words.

[dataset_split.py](dataset_split.py)
Split a given dataset into training, validation and testing.

## Use pretrained model/architecture
[score_texts_emojis.py](score_texts_emojis.py)
Use torchMoji to score texts for emoji distribution.

[encode_texts.py](encode_texts.py)
Use torchMoji to encode the text into 2304-dimensional feature vectors for further modeling/analysis.

## Transfer learning
[finetune_youtube_last.py](finetune_youtube_last.py)
Finetune the model on the SS-Youtube dataset using the 'last' method.

[finetune_insults_chain-thaw.py](finetune_insults_chain-thaw.py)
Finetune the model on the Kaggle insults dataset (from blog post) using the 'chain-thaw' method.

[finetune_semeval_class-avg_f1.py](finetune_semeval_class-avg_f1.py)
Finetune the model on the SemEval emotion dataset using the 'full' method and evaluate using the class average F1 metric.
examples/__init__.py
ADDED
File without changes
examples/create_twitter_vocab.py
ADDED
@@ -0,0 +1,13 @@
""" Creates a vocabulary from a tsv file.
"""

import codecs
import example_helper
from torchmoji.create_vocab import VocabBuilder
from torchmoji.word_generator import TweetWordGenerator

with codecs.open('../../twitterdata/tweets.2016-09-01', 'rU', 'utf-8') as stream:
    wg = TweetWordGenerator(stream)
    vb = VocabBuilder(wg)
    vb.count_all_words()
    vb.save_vocab()
examples/dataset_split.py
ADDED
@@ -0,0 +1,59 @@
'''
Split a given dataset into three different datasets: training, validation and
testing.

This is achieved by splitting the given list of sentences into three separate
lists according to either a given ratio (e.g. [0.7, 0.1, 0.2]) or by an
explicit enumeration. The sentences are also tokenised using the given
vocabulary.

Also splits a given list of dictionaries containing information about
each sentence.

An additional parameter can be set 'extend_with', which will extend the given
vocabulary with up to 'extend_with' tokens, taken from the training dataset.
'''
from __future__ import print_function, unicode_literals
import example_helper
import json

from torchmoji.sentence_tokenizer import SentenceTokenizer

DATASET = [
    'I am sentence 0',
    'I am sentence 1',
    'I am sentence 2',
    'I am sentence 3',
    'I am sentence 4',
    'I am sentence 5',
    'I am sentence 6',
    'I am sentence 7',
    'I am sentence 8',
    'I am sentence 9 newword',
]

INFO_DICTS = [
    {'label': 'sentence 0'},
    {'label': 'sentence 1'},
    {'label': 'sentence 2'},
    {'label': 'sentence 3'},
    {'label': 'sentence 4'},
    {'label': 'sentence 5'},
    {'label': 'sentence 6'},
    {'label': 'sentence 7'},
    {'label': 'sentence 8'},
    {'label': 'sentence 9'},
]

with open('../model/vocabulary.json', 'r') as f:
    vocab = json.load(f)
st = SentenceTokenizer(vocab, 30)

# Split using the default split ratio
print(st.split_train_val_test(DATASET, INFO_DICTS))

# Split explicitly
print(st.split_train_val_test(DATASET,
                              INFO_DICTS,
                              [[0, 1, 2, 4, 9], [5, 6], [7, 8, 3]],
                              extend_with=1))
examples/encode_texts.py
ADDED
@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-

""" Use torchMoji to encode texts into emotional feature vectors.
"""
from __future__ import print_function, division, unicode_literals
import json

from torchmoji.sentence_tokenizer import SentenceTokenizer
from torchmoji.model_def import torchmoji_feature_encoding
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH

TEST_SENTENCES = ['I love mom\'s cooking',
                  'I love how you never reply back..',
                  'I love cruising with my homies',
                  'I love messing with yo mind!!',
                  'I love you and now you\'re just gone..',
                  'This is shit',
                  'This is the shit']

maxlen = 30
batch_size = 32

print('Tokenizing using dictionary from {}'.format(VOCAB_PATH))
with open(VOCAB_PATH, 'r') as f:
    vocabulary = json.load(f)
st = SentenceTokenizer(vocabulary, maxlen)
tokenized, _, _ = st.tokenize_sentences(TEST_SENTENCES)

print('Loading model from {}.'.format(PRETRAINED_PATH))
model = torchmoji_feature_encoding(PRETRAINED_PATH)
print(model)

print('Encoding texts..')
encoding = model(tokenized)

print('First 5 dimensions for sentence: {}'.format(TEST_SENTENCES[0]))
print(encoding[0,:5])

# Now you could visualize the encodings to see differences,
# run a logistic regression classifier on top,
# or basically anything you'd like to do.
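Editor's note: the closing comment of the file above suggests running a classifier on top of the encodings. A minimal sketch of that idea, assuming the encodings form an array with one 2304-dimensional row per sentence as the README describes; the labels and stand-in features here are made up for illustration.

```python
# Hypothetical downstream use of torchMoji feature vectors.
import numpy as np
from sklearn.linear_model import LogisticRegression

features = np.random.rand(7, 2304)            # stand-in for the model(tokenized) output
labels = np.array([1, 1, 1, 1, 1, 0, 0])      # made-up binary sentiment labels
clf = LogisticRegression().fit(features, labels)
print(clf.score(features, labels))            # training accuracy of the toy classifier
```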
examples/example_helper.py
ADDED
@@ -0,0 +1,6 @@
""" Module import helper.
Modifies PATH in order to allow us to import the torchmoji directory.
"""
import sys
from os.path import abspath, dirname
sys.path.insert(0, dirname(dirname(abspath(__file__))))
examples/finetune_insults_chain-thaw.py
ADDED
@@ -0,0 +1,44 @@
"""Finetuning example.

Trains the torchMoji model on the kaggle insults dataset, using the 'chain-thaw'
finetuning method and the accuracy metric. See the blog post at
https://medium.com/@bjarkefelbo/what-can-we-learn-from-emojis-6beb165a5ea0
for more information. Note that results may differ a bit due to slight
changes in preprocessing and train/val/test split.

The 'chain-thaw' method does the following:
0) Load all weights except for the softmax layer. Extend the embedding layer if
   necessary, initialising the new weights with random values.
1) Freeze every layer except the last (softmax) layer and train it.
2) Freeze every layer except the first layer and train it.
3) Freeze every layer except the second etc., until the second last layer.
4) Unfreeze all layers and train entire model.
"""

from __future__ import print_function
import example_helper
import json
from torchmoji.model_def import torchmoji_transfer
from torchmoji.global_variables import PRETRAINED_PATH
from torchmoji.finetuning import (
    load_benchmark,
    finetune)


DATASET_PATH = '../data/kaggle-insults/raw.pickle'
nb_classes = 2

with open('../model/vocabulary.json', 'r') as f:
    vocab = json.load(f)

# Load dataset. Extend the existing vocabulary with up to 10000 tokens from
# the training dataset.
data = load_benchmark(DATASET_PATH, vocab, extend_with=10000)

# Set up model and finetune. Note that we have to extend the embedding layer
# with the number of tokens added to the vocabulary.
model = torchmoji_transfer(nb_classes, PRETRAINED_PATH, extend_embedding=data['added'])
print(model)
model, acc = finetune(model, data['texts'], data['labels'], nb_classes,
                      data['batch_size'], method='chain-thaw')
print('Acc: {}'.format(acc))
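Editor's note: for readers skimming the diff, the chain-thaw schedule described in the docstring above can be summarised in a few lines. This is a rough sketch, not the actual implementation in torchmoji/finetuning.py; the `train_one_stage` callback and the assumption that the model's children are its ordered layers are hypothetical.

```python
import torch.nn as nn

def chain_thaw_sketch(model: nn.Module, train_one_stage):
    """Train the last layer first, then each earlier layer in turn, then everything."""
    layers = list(model.children())
    stages = [layers[-1]] + layers[:-1] + [None]   # None means "train all layers"
    for stage in stages:
        for layer in layers:
            # Freeze every layer except the one being trained in this stage.
            trainable = stage is None or layer is stage
            for p in layer.parameters():
                p.requires_grad = trainable
        train_one_stage(model)                      # hypothetical training callback
    return model
```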
examples/finetune_semeval_class-avg_f1.py
ADDED
@@ -0,0 +1,50 @@
"""Finetuning example.

Trains the torchMoji model on the SemEval emotion dataset, using the 'last'
finetuning method and the class average F1 metric.

The 'last' method does the following:
0) Load all weights except for the softmax layer. Do not add tokens to the
   vocabulary and do not extend the embedding layer.
1) Freeze all layers except for the softmax layer.
2) Train.

The class average F1 metric does the following:
1) For each class, relabel the dataset into binary classification
   (belongs to/does not belong to this class).
2) Calculate F1 score for each class.
3) Compute the average of all F1 scores.
"""

from __future__ import print_function
import example_helper
import json
from torchmoji.finetuning import load_benchmark
from torchmoji.class_avg_finetuning import class_avg_finetune
from torchmoji.model_def import torchmoji_transfer
from torchmoji.global_variables import PRETRAINED_PATH

DATASET_PATH = '../data/SE0714/raw.pickle'
nb_classes = 3

with open('../model/vocabulary.json', 'r') as f:
    vocab = json.load(f)


# Load dataset. Extend the existing vocabulary with up to 10000 tokens from
# the training dataset.
data = load_benchmark(DATASET_PATH, vocab, extend_with=10000)

# Set up model and finetune. Note that we have to extend the embedding layer
# with the number of tokens added to the vocabulary.
#
# Also note that when using class average F1 to evaluate, the model has to be
# defined with two classes, since the model will be trained for each class
# separately.
model = torchmoji_transfer(2, PRETRAINED_PATH, extend_embedding=data['added'])
print(model)

# For finetuning however, pass in the actual number of classes.
model, f1 = class_avg_finetune(model, data['texts'], data['labels'],
                               nb_classes, data['batch_size'], method='last')
print('F1: {}'.format(f1))
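Editor's note: the class average F1 metric described in the docstring above is easy to state directly. A minimal sketch, assuming boolean label matrices of shape (n_samples, nb_classes); the real relabel/evaluation logic lives in torchmoji/class_avg_finetuning.py.

```python
import numpy as np
from sklearn.metrics import f1_score

def class_average_f1(y_true, y_pred):
    """Relabel each class as a binary problem, score it, and average the F1 scores."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    scores = [f1_score(y_true[:, c], y_pred[:, c]) for c in range(y_true.shape[1])]
    return float(np.mean(scores))
```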
examples/finetune_youtube_last.py
ADDED
@@ -0,0 +1,35 @@
"""Finetuning example.

Trains the torchMoji model on the SS-Youtube dataset, using the 'last'
finetuning method and the accuracy metric.

The 'last' method does the following:
0) Load all weights except for the softmax layer. Do not add tokens to the
   vocabulary and do not extend the embedding layer.
1) Freeze all layers except for the softmax layer.
2) Train.
"""

from __future__ import print_function
import example_helper
import json
from torchmoji.model_def import torchmoji_transfer
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH, ROOT_PATH
from torchmoji.finetuning import (
    load_benchmark,
    finetune)

DATASET_PATH = '{}/data/SS-Youtube/raw.pickle'.format(ROOT_PATH)
nb_classes = 2

with open(VOCAB_PATH, 'r') as f:
    vocab = json.load(f)

# Load dataset.
data = load_benchmark(DATASET_PATH, vocab)

# Set up model and finetune
model = torchmoji_transfer(nb_classes, PRETRAINED_PATH)
print(model)
model, acc = finetune(model, data['texts'], data['labels'], nb_classes, data['batch_size'], method='last')
print('Acc: {}'.format(acc))
examples/score_texts_emojis.py
ADDED
@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-

""" Use torchMoji to score texts for emoji distribution.

The resulting emoji ids (0-63) correspond to the mapping
in emoji_overview.png file at the root of the torchMoji repo.

Writes the result to a csv file.
"""
from __future__ import print_function, division, unicode_literals
import example_helper
import json
import csv
import numpy as np

from torchmoji.sentence_tokenizer import SentenceTokenizer
from torchmoji.model_def import torchmoji_emojis
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH

OUTPUT_PATH = 'test_sentences.csv'

TEST_SENTENCES = ['I love mom\'s cooking',
                  'I love how you never reply back..',
                  'I love cruising with my homies',
                  'I love messing with yo mind!!',
                  'I love you and now you\'re just gone..',
                  'This is shit',
                  'This is the shit']


def top_elements(array, k):
    ind = np.argpartition(array, -k)[-k:]
    return ind[np.argsort(array[ind])][::-1]

maxlen = 30

print('Tokenizing using dictionary from {}'.format(VOCAB_PATH))
with open(VOCAB_PATH, 'r') as f:
    vocabulary = json.load(f)

st = SentenceTokenizer(vocabulary, maxlen)

print('Loading model from {}.'.format(PRETRAINED_PATH))
model = torchmoji_emojis(PRETRAINED_PATH)
print(model)
print('Running predictions.')
tokenized, _, _ = st.tokenize_sentences(TEST_SENTENCES)
prob = model(tokenized)

for prob in [prob]:
    # Find top emojis for each sentence. Emoji ids (0-63)
    # correspond to the mapping in emoji_overview.png
    # at the root of the torchMoji repo.
    print('Writing results to {}'.format(OUTPUT_PATH))
    scores = []
    for i, t in enumerate(TEST_SENTENCES):
        t_tokens = tokenized[i]
        t_score = [t]
        t_prob = prob[i]
        ind_top = top_elements(t_prob, 5)
        t_score.append(sum(t_prob[ind_top]))
        t_score.extend(ind_top)
        t_score.extend([t_prob[ind] for ind in ind_top])
        scores.append(t_score)
        print(t_score)

    with open(OUTPUT_PATH, 'wb') as csvfile:
        writer = csv.writer(csvfile, delimiter=',', lineterminator='\n')
        writer.writerow(['Text', 'Top5%',
                         'Emoji_1', 'Emoji_2', 'Emoji_3', 'Emoji_4', 'Emoji_5',
                         'Pct_1', 'Pct_2', 'Pct_3', 'Pct_4', 'Pct_5'])
        for i, row in enumerate(scores):
            try:
                writer.writerow(row)
            except:
                print("Exception at row {}!".format(i))
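Editor's note: a small worked check of the `top_elements()` helper defined in the file above. `np.argpartition` picks the indices of the k largest probabilities, and the final `argsort` orders them from most to least likely; the probability values here are made up, not real model output.

```python
import numpy as np

def top_elements(array, k):
    ind = np.argpartition(array, -k)[-k:]
    return ind[np.argsort(array[ind])][::-1]

probs = np.array([0.05, 0.40, 0.10, 0.30, 0.15])   # assumed example distribution
print(top_elements(probs, 3))                       # -> [1 3 4], most likely first
```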
examples/tokenize_dataset.py
ADDED
@@ -0,0 +1,26 @@
"""
Take a given list of sentences and turn it into a numpy array, where each
number corresponds to a word. Padding is used (number 0) to ensure fixed length
of sentences.
"""

from __future__ import print_function, unicode_literals
import example_helper
import json
from torchmoji.sentence_tokenizer import SentenceTokenizer

with open('../model/vocabulary.json', 'r') as f:
    vocabulary = json.load(f)

st = SentenceTokenizer(vocabulary, 30)
test_sentences = [
    '\u2014 -- \u203c !!\U0001F602',
    'Hello world!',
    'This is a sample tweet #example',
]

tokens, infos, stats = st.tokenize_sentences(test_sentences)

print(tokens)
print(infos)
print(stats)
examples/vocab_extension.py
ADDED
@@ -0,0 +1,30 @@
"""
Extend the given vocabulary using dataset-specific words.

1. First create a vocabulary for the specific dataset.
2. Find all words not in our vocabulary, but in the dataset vocabulary.
3. Take top X (default=1000) of these words and add them to the vocabulary.
4. Save this combined vocabulary and embedding matrix, which can now be used.
"""

from __future__ import print_function, unicode_literals
import example_helper
import json
from torchmoji.create_vocab import extend_vocab, VocabBuilder
from torchmoji.word_generator import WordGenerator

new_words = ['#zzzzaaazzz', 'newword', 'newword']
word_gen = WordGenerator(new_words)
vb = VocabBuilder(word_gen)
vb.count_all_words()

with open('../model/vocabulary.json') as f:
    vocab = json.load(f)

print(len(vocab))
print(vb.word_counts)
extend_vocab(vocab, vb, max_tokens=1)

# 'newword' should be added because it's more frequent in the given vocab
print(vocab['newword'])
print(len(vocab))
model/.gitkeep
ADDED
@@ -0,0 +1 @@
model/vocabulary.json
ADDED
The diff for this file is too large to render.
scripts/analyze_all_results.py
ADDED
@@ -0,0 +1,40 @@
from __future__ import print_function

# allow us to import the codebase directory
import sys
import glob
import numpy as np
from os.path import dirname, abspath
sys.path.insert(0, dirname(dirname(abspath(__file__))))

DATASETS = ['SE0714', 'Olympic', 'PsychExp', 'SS-Twitter', 'SS-Youtube',
            'SCv1', 'SCv2-GEN']  # 'SE1604' excluded due to Twitter's ToS

def get_results(dset):
    METHOD = 'last'
    RESULTS_DIR = 'results/'
    RESULT_PATHS = glob.glob('{}/{}_{}_*_results.txt'.format(RESULTS_DIR, dset, METHOD))
    assert len(RESULT_PATHS)

    scores = []
    for path in RESULT_PATHS:
        with open(path) as f:
            score = f.readline().split(':')[1]
            scores.append(float(score))

    average = np.mean(scores)
    maximum = max(scores)
    minimum = min(scores)
    std = np.std(scores)

    print('Dataset: {}'.format(dset))
    print('Method: {}'.format(METHOD))
    print('Number of results: {}'.format(len(scores)))
    print('--------------------------')
    print('Average: {}'.format(average))
    print('Maximum: {}'.format(maximum))
    print('Minimum: {}'.format(minimum))
    print('Standard deviation: {}'.format(std))

for dset in DATASETS:
    get_results(dset)
scripts/analyze_results.py
ADDED
@@ -0,0 +1,39 @@
from __future__ import print_function

import sys
import glob
import numpy as np

DATASET = 'SS-Twitter'  # 'SE1604' excluded due to Twitter's ToS
METHOD = 'new'

# Optional usage: analyze_results.py <dataset> <method>
if len(sys.argv) == 3:
    DATASET = sys.argv[1]
    METHOD = sys.argv[2]

RESULTS_DIR = 'results/'
RESULT_PATHS = glob.glob('{}/{}_{}_*_results.txt'.format(RESULTS_DIR, DATASET, METHOD))

if not RESULT_PATHS:
    print('Could not find results for \'{}\' using \'{}\' in directory \'{}\'.'.format(DATASET, METHOD, RESULTS_DIR))
else:
    scores = []
    for path in RESULT_PATHS:
        with open(path) as f:
            score = f.readline().split(':')[1]
            scores.append(float(score))

    average = np.mean(scores)
    maximum = max(scores)
    minimum = min(scores)
    std = np.std(scores)

    print('Dataset: {}'.format(DATASET))
    print('Method: {}'.format(METHOD))
    print('Number of results: {}'.format(len(scores)))
    print('--------------------------')
    print('Average: {}'.format(average))
    print('Maximum: {}'.format(maximum))
    print('Minimum: {}'.format(minimum))
    print('Standard deviation: {}'.format(std))
scripts/calculate_coverages.py
ADDED
@@ -0,0 +1,85 @@
from __future__ import print_function
import pickle
import json
import csv
import sys
from io import open

# Allow us to import the torchmoji directory
from os.path import dirname, abspath
sys.path.insert(0, dirname(dirname(abspath(__file__))))

from torchmoji.sentence_tokenizer import SentenceTokenizer, coverage

IS_PYTHON2 = int(sys.version[0]) == 2

OUTPUT_PATH = 'coverage.csv'
DATASET_PATHS = [
    '../data/Olympic/raw.pickle',
    '../data/PsychExp/raw.pickle',
    '../data/SCv1/raw.pickle',
    '../data/SCv2-GEN/raw.pickle',
    '../data/SE0714/raw.pickle',
    #'../data/SE1604/raw.pickle', # Excluded due to Twitter's ToS
    '../data/SS-Twitter/raw.pickle',
    '../data/SS-Youtube/raw.pickle',
]

with open('../model/vocabulary.json', 'r') as f:
    vocab = json.load(f)

results = []
for p in DATASET_PATHS:
    coverage_result = [p]
    print('Calculating coverage for {}'.format(p))
    with open(p, 'rb') as f:
        if IS_PYTHON2:
            s = pickle.load(f)
        else:
            s = pickle.load(f, fix_imports=True)

    # Decode data
    try:
        s['texts'] = [unicode(x) for x in s['texts']]
    except UnicodeDecodeError:
        s['texts'] = [x.decode('utf-8') for x in s['texts']]

    # Own
    st = SentenceTokenizer({}, 30)
    tests, dicts, _ = st.split_train_val_test(s['texts'], s['info'],
                                              [s['train_ind'],
                                               s['val_ind'],
                                               s['test_ind']],
                                              extend_with=10000)
    coverage_result.append(coverage(tests[2]))

    # Last
    st = SentenceTokenizer(vocab, 30)
    tests, dicts, _ = st.split_train_val_test(s['texts'], s['info'],
                                              [s['train_ind'],
                                               s['val_ind'],
                                               s['test_ind']],
                                              extend_with=0)
    coverage_result.append(coverage(tests[2]))

    # Full
    st = SentenceTokenizer(vocab, 30)
    tests, dicts, _ = st.split_train_val_test(s['texts'], s['info'],
                                              [s['train_ind'],
                                               s['val_ind'],
                                               s['test_ind']],
                                              extend_with=10000)
    coverage_result.append(coverage(tests[2]))

    results.append(coverage_result)

with open(OUTPUT_PATH, 'wb') as csvfile:
    writer = csv.writer(csvfile, delimiter='\t', lineterminator='\n')
    writer.writerow(['Dataset', 'Own', 'Last', 'Full'])
    for i, row in enumerate(results):
        try:
            writer.writerow(row)
        except:
            print("Exception at row {}!".format(i))

print('Saved to {}'.format(OUTPUT_PATH))
scripts/convert_all_datasets.py
ADDED
@@ -0,0 +1,105 @@
from __future__ import print_function

import json
import math
import pickle
import sys
from io import open
import numpy as np
from os.path import abspath, dirname
sys.path.insert(0, dirname(dirname(abspath(__file__))))

from torchmoji.word_generator import WordGenerator
from torchmoji.create_vocab import VocabBuilder
from torchmoji.sentence_tokenizer import SentenceTokenizer, extend_vocab, coverage
from torchmoji.tokenizer import tokenize

IS_PYTHON2 = int(sys.version[0]) == 2

DATASETS = [
    'Olympic',
    'PsychExp',
    'SCv1',
    'SCv2-GEN',
    'SE0714',
    #'SE1604', # Excluded due to Twitter's ToS
    'SS-Twitter',
    'SS-Youtube',
]

DIR = '../data'
FILENAME_RAW = 'raw.pickle'
FILENAME_OWN = 'own_vocab.pickle'
FILENAME_OUR = 'twitter_vocab.pickle'
FILENAME_COMBINED = 'combined_vocab.pickle'


def roundup(x):
    return int(math.ceil(x / 10.0)) * 10


def format_pickle(dset, train_texts, val_texts, test_texts, train_labels, val_labels, test_labels):
    return {'dataset': dset,
            'train_texts': train_texts,
            'val_texts': val_texts,
            'test_texts': test_texts,
            'train_labels': train_labels,
            'val_labels': val_labels,
            'test_labels': test_labels}

def convert_dataset(filepath, extend_with, vocab):
    print('-- Generating {} '.format(filepath))
    sys.stdout.flush()
    st = SentenceTokenizer(vocab, maxlen)
    tokenized, dicts, _ = st.split_train_val_test(texts,
                                                  labels,
                                                  [data['train_ind'],
                                                   data['val_ind'],
                                                   data['test_ind']],
                                                  extend_with=extend_with)
    pick = format_pickle(dset, tokenized[0], tokenized[1], tokenized[2],
                         dicts[0], dicts[1], dicts[2])
    with open(filepath, 'w') as f:
        pickle.dump(pick, f)
    cover = coverage(tokenized[2])

    print('  done. Coverage: {}'.format(cover))

with open('../model/vocabulary.json', 'r') as f:
    vocab = json.load(f)

for dset in DATASETS:
    print('Converting {}'.format(dset))

    PATH_RAW = '{}/{}/{}'.format(DIR, dset, FILENAME_RAW)
    PATH_OWN = '{}/{}/{}'.format(DIR, dset, FILENAME_OWN)
    PATH_OUR = '{}/{}/{}'.format(DIR, dset, FILENAME_OUR)
    PATH_COMBINED = '{}/{}/{}'.format(DIR, dset, FILENAME_COMBINED)

    with open(PATH_RAW, 'rb') as dataset:
        if IS_PYTHON2:
            data = pickle.load(dataset)
        else:
            data = pickle.load(dataset, fix_imports=True)

    # Decode data
    try:
        texts = [unicode(x) for x in data['texts']]
    except UnicodeDecodeError:
        texts = [x.decode('utf-8') for x in data['texts']]

    wg = WordGenerator(texts)
    vb = VocabBuilder(wg)
    vb.count_all_words()

    # Calculate max length of sequences considered
    # Adjust batch_size accordingly to prevent GPU overflow
    lengths = [len(tokenize(t)) for t in texts]
    maxlen = roundup(np.percentile(lengths, 80.0))

    # Extract labels
    labels = [x['label'] for x in data['info']]

    convert_dataset(PATH_OWN, 50000, {})
    convert_dataset(PATH_OUR, 0, vocab)
    convert_dataset(PATH_COMBINED, 10000, vocab)
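Editor's note: the maxlen heuristic in the script above (round the 80th percentile of tokenised sentence lengths up to the nearest ten) is worth a quick worked example; the token lengths below are made up.

```python
import math
import numpy as np

def roundup(x):
    return int(math.ceil(x / 10.0)) * 10

lengths = [5, 8, 11, 14, 17, 20, 23, 26, 29, 44]   # hypothetical tokenised lengths
print(np.percentile(lengths, 80.0))                 # 26.6
print(roundup(np.percentile(lengths, 80.0)))        # 30, used as maxlen
```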
scripts/download_weights.py
ADDED
@@ -0,0 +1,64 @@
from __future__ import print_function
import os
from subprocess import call

curr_folder = os.path.basename(os.path.normpath(os.getcwd()))

weights_filename = 'pytorch_model.bin'
weights_folder = 'model'
weights_path = '{}/{}'.format(weights_folder, weights_filename)
if curr_folder == 'scripts':
    weights_path = '../' + weights_path
weights_download_link = 'https://www.dropbox.com/s/q8lax9ary32c7t9/pytorch_model.bin?dl=0#'


MB_FACTOR = float(1<<20)

def prompt():
    while True:
        valid = {
            'y': True,
            'ye': True,
            'yes': True,
            'n': False,
            'no': False,
        }
        choice = raw_input().lower()
        if choice in valid:
            return valid[choice]
        else:
            print('Please respond with \'y\' or \'n\' (or \'yes\' or \'no\')')

download = True
if os.path.exists(weights_path):
    print('Weight file already exists at {}. Would you like to redownload it anyway? [y/n]'.format(weights_path))
    download = prompt()
    already_exists = True
else:
    already_exists = False

if download:
    print('About to download the pretrained weights file from {}'.format(weights_download_link))
    if already_exists == False:
        print('The size of the file is roughly 85MB. Continue? [y/n]')
    else:
        os.unlink(weights_path)

    if already_exists or prompt():
        print('Downloading...')

        #urllib.urlretrieve(weights_download_link, weights_path)
        #with open(weights_path,'wb') as f:
        #    f.write(requests.get(weights_download_link).content)

        # downloading using wget due to issues with urlretrieve and requests
        sys_call = 'wget {} -O {}'.format(weights_download_link, os.path.abspath(weights_path))
        print("Running system call: {}".format(sys_call))
        call(sys_call, shell=True)

        if os.path.getsize(weights_path) / MB_FACTOR < 80:
            raise ValueError("Download finished, but the resulting file is too small! " +
                             "It\'s only {} bytes.".format(os.path.getsize(weights_path)))
        print('Downloaded weights to {}'.format(weights_path))
    else:
        print('Exiting.')
scripts/finetune_dataset.py
ADDED
@@ -0,0 +1,109 @@
""" Finetuning example.
"""
from __future__ import print_function
import sys
import numpy as np
from os.path import abspath, dirname
sys.path.insert(0, dirname(dirname(abspath(__file__))))

import json
import math
from torchmoji.model_def import torchmoji_transfer
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH
from torchmoji.finetuning import (
    load_benchmark,
    finetune)
from torchmoji.class_avg_finetuning import class_avg_finetune

def roundup(x):
    return int(math.ceil(x / 10.0)) * 10


# Format: (dataset_name,
#          path_to_dataset,
#          nb_classes,
#          use_f1_score)
DATASETS = [
    #('SE0714', '../data/SE0714/raw.pickle', 3, True),
    #('Olympic', '../data/Olympic/raw.pickle', 4, True),
    #('PsychExp', '../data/PsychExp/raw.pickle', 7, True),
    #('SS-Twitter', '../data/SS-Twitter/raw.pickle', 2, False),
    ('SS-Youtube', '../data/SS-Youtube/raw.pickle', 2, False),
    #('SE1604', '../data/SE1604/raw.pickle', 3, False), # Excluded due to Twitter's ToS
    #('SCv1', '../data/SCv1/raw.pickle', 2, True),
    #('SCv2-GEN', '../data/SCv2-GEN/raw.pickle', 2, True)
]

RESULTS_DIR = 'results'

# 'new' | 'last' | 'full' | 'chain-thaw'
FINETUNE_METHOD = 'last'
VERBOSE = 1

nb_tokens = 50000
nb_epochs = 1000
epoch_size = 1000

with open(VOCAB_PATH, 'r') as f:
    vocab = json.load(f)

for rerun_iter in range(5):
    for p in DATASETS:

        # debugging
        assert len(vocab) == nb_tokens

        dset = p[0]
        path = p[1]
        nb_classes = p[2]
        use_f1_score = p[3]

        if FINETUNE_METHOD == 'last':
            extend_with = 0
        elif FINETUNE_METHOD in ['new', 'full', 'chain-thaw']:
            extend_with = 10000
        else:
            raise ValueError('Finetuning method not recognised!')

        # Load dataset.
        data = load_benchmark(path, vocab, extend_with=extend_with)

        (X_train, y_train) = (data['texts'][0], data['labels'][0])
        (X_val, y_val) = (data['texts'][1], data['labels'][1])
        (X_test, y_test) = (data['texts'][2], data['labels'][2])

        weight_path = PRETRAINED_PATH if FINETUNE_METHOD != 'new' else None
        nb_model_classes = 2 if use_f1_score else nb_classes
        model = torchmoji_transfer(
            nb_model_classes,
            data['maxlen'], weight_path,
            extend_embedding=data['added'])
        model.summary()

        # Training
        print('Training: {}'.format(path))
        if use_f1_score:
            model, result = class_avg_finetune(model, data['texts'],
                                               data['labels'],
                                               nb_classes, data['batch_size'],
                                               FINETUNE_METHOD,
                                               verbose=VERBOSE)
        else:
            model, result = finetune(model, data['texts'], data['labels'],
                                     nb_classes, data['batch_size'],
                                     FINETUNE_METHOD, metric='acc',
                                     verbose=VERBOSE)

        # Write results
        if use_f1_score:
            print('Overall F1 score (dset = {}): {}'.format(dset, result))
            with open('{}/{}_{}_{}_results.txt'.
                      format(RESULTS_DIR, dset, FINETUNE_METHOD, rerun_iter),
                      "w") as f:
                f.write("F1: {}\n".format(result))
        else:
            print('Test accuracy (dset = {}): {}'.format(dset, result))
            with open('{}/{}_{}_{}_results.txt'.
                      format(RESULTS_DIR, dset, FINETUNE_METHOD, rerun_iter),
                      "w") as f:
                f.write("Acc: {}\n".format(result))
scripts/results/.gitkeep
ADDED
@@ -0,0 +1 @@
setup.py
ADDED
@@ -0,0 +1,16 @@
from setuptools import setup

setup(
    name='torchmoji',
    version='1.0',
    packages=['torchmoji'],
    description='torchMoji',
    include_package_data=True,
    install_requires=[
        'emoji==0.4.5',
        'numpy==1.13.1',
        'scipy==0.19.1',
        'scikit-learn==0.19.0',
        'text-unidecode==1.0',
    ],
)
tests/test_finetuning.py
ADDED
@@ -0,0 +1,235 @@
from __future__ import absolute_import, print_function, division, unicode_literals

import test_helper

from nose.plugins.attrib import attr
import json
import numpy as np

from torchmoji.class_avg_finetuning import relabel
from torchmoji.sentence_tokenizer import SentenceTokenizer

from torchmoji.finetuning import (
    calculate_batchsize_maxlen,
    freeze_layers,
    change_trainable,
    finetune,
    load_benchmark
)
from torchmoji.model_def import (
    torchmoji_transfer,
    torchmoji_feature_encoding,
    torchmoji_emojis
)
from torchmoji.global_variables import (
    PRETRAINED_PATH,
    NB_TOKENS,
    VOCAB_PATH,
    ROOT_PATH
)


def test_calculate_batchsize_maxlen():
    """ Batch size and max length are calculated properly.
    """
    texts = ['a b c d',
             'e f g h i']
    batch_size, maxlen = calculate_batchsize_maxlen(texts)

    assert batch_size == 250
    assert maxlen == 10, maxlen


def test_freeze_layers():
    """ Correct layers are frozen.
    """
    model = torchmoji_transfer(5)
    keyword = 'output_layer'

    model = freeze_layers(model, unfrozen_keyword=keyword)

    for name, module in model.named_children():
        trainable = keyword.lower() in name.lower()
        assert all(p.requires_grad == trainable for p in module.parameters())


def test_change_trainable():
    """ change_trainable() changes trainability of layers.
    """
    model = torchmoji_transfer(5)
    change_trainable(model.embed, False)
    assert not any(p.requires_grad for p in model.embed.parameters())
    change_trainable(model.embed, True)
    assert all(p.requires_grad for p in model.embed.parameters())


def test_torchmoji_transfer_extend_embedding():
    """ Defining torchmoji with extension.
    """
    extend_with = 50
    model = torchmoji_transfer(5, weight_path=PRETRAINED_PATH,
                               extend_embedding=extend_with)
    embedding_layer = model.embed
    assert embedding_layer.weight.size()[0] == NB_TOKENS + extend_with


def test_torchmoji_return_attention():
    seq_tensor = np.array([[1]])
    # test the output of the normal model
    model = torchmoji_emojis(weight_path=PRETRAINED_PATH)
    # check correct number of outputs
    assert len(model(seq_tensor)) == 1
    # repeat above described tests when returning attention weights
    model = torchmoji_emojis(weight_path=PRETRAINED_PATH, return_attention=True)
    assert len(model(seq_tensor)) == 2


def test_relabel():
    """ relabel() works with multi-class labels.
    """
    nb_classes = 3
    inputs = np.array([
        [True, False, False],
        [False, True, False],
        [True, False, True],
    ])
    expected_0 = np.array([True, False, True])
    expected_1 = np.array([False, True, False])
    expected_2 = np.array([False, False, True])

    assert np.array_equal(relabel(inputs, 0, nb_classes), expected_0)
    assert np.array_equal(relabel(inputs, 1, nb_classes), expected_1)
    assert np.array_equal(relabel(inputs, 2, nb_classes), expected_2)


def test_relabel_binary():
    """ relabel() works with binary classification (no changes to labels)
    """
    nb_classes = 2
    inputs = np.array([True, False, False])

    assert np.array_equal(relabel(inputs, 0, nb_classes), inputs)


@attr('slow')
def test_finetune_full():
    """ finetuning using 'full'.
    """
    DATASET_PATH = ROOT_PATH+'/data/SS-Youtube/raw.pickle'
    nb_classes = 2
    # Keras and pyTorch implementation of the Adam optimizer are slightly different and change a bit the results
    # We reduce the min accuracy needed here to pass the test
    # See e.g. https://discuss.pytorch.org/t/suboptimal-convergence-when-compared-with-tensorflow-model/5099/11
    min_acc = 0.68

    with open(VOCAB_PATH, 'r') as f:
        vocab = json.load(f)

    data = load_benchmark(DATASET_PATH, vocab, extend_with=10000)
    print('Loading pyTorch model from {}.'.format(PRETRAINED_PATH))
    model = torchmoji_transfer(nb_classes, PRETRAINED_PATH, extend_embedding=data['added'])
    print(model)
    model, acc = finetune(model, data['texts'], data['labels'], nb_classes,
                          data['batch_size'], method='full', nb_epochs=1)

    print("Finetune full SS-Youtube 1 epoch acc: {}".format(acc))
    assert acc >= min_acc


@attr('slow')
def test_finetune_last():
    """ finetuning using 'last'.
    """
    dataset_path = ROOT_PATH + '/data/SS-Youtube/raw.pickle'
    nb_classes = 2
    min_acc = 0.68

    with open(VOCAB_PATH, 'r') as f:
        vocab = json.load(f)

    data = load_benchmark(dataset_path, vocab)
    print('Loading model from {}.'.format(PRETRAINED_PATH))
    model = torchmoji_transfer(nb_classes, PRETRAINED_PATH)
    print(model)
    model, acc = finetune(model, data['texts'], data['labels'], nb_classes,
                          data['batch_size'], method='last', nb_epochs=1)

    print("Finetune last SS-Youtube 1 epoch acc: {}".format(acc))

    assert acc >= min_acc


def test_score_emoji():
    """ Emoji predictions make sense.
    """
    test_sentences = [
        'I love mom\'s cooking',
        'I love how you never reply back..',
        'I love cruising with my homies',
        'I love messing with yo mind!!',
        'I love you and now you\'re just gone..',
        'This is shit',
        'This is the shit'
    ]

    expected = [
        np.array([36, 4, 8, 16, 47]),
        np.array([1, 19, 55, 25, 46]),
        np.array([31, 6, 30, 15, 13]),
        np.array([54, 44, 9, 50, 49]),
        np.array([46, 5, 27, 35, 34]),
        np.array([55, 32, 27, 1, 37]),
        np.array([48, 11, 6, 31, 9])
    ]

    def top_elements(array, k):
        ind = np.argpartition(array, -k)[-k:]
        return ind[np.argsort(array[ind])][::-1]

    # Initialize by loading dictionary and tokenize texts
    with open(VOCAB_PATH, 'r') as f:
        vocabulary = json.load(f)

    st = SentenceTokenizer(vocabulary, 30)
    tokens, _, _ = st.tokenize_sentences(test_sentences)

    # Load model and run
    model = torchmoji_emojis(weight_path=PRETRAINED_PATH)
    prob = model(tokens)

    # Find top emojis for each sentence
    for i, t_prob in enumerate(list(prob)):
        assert np.array_equal(top_elements(t_prob, 5), expected[i])


def test_encode_texts():
    """ Text encoding is stable.
    """

    TEST_SENTENCES = ['I love mom\'s cooking',
                      'I love how you never reply back..',
                      'I love cruising with my homies',
                      'I love messing with yo mind!!',
                      'I love you and now you\'re just gone..',
                      'This is shit',
                      'This is the shit']


    maxlen = 30
    batch_size = 32

    with open(VOCAB_PATH, 'r') as f:
        vocabulary = json.load(f)

    st = SentenceTokenizer(vocabulary, maxlen)

    print('Loading model from {}.'.format(PRETRAINED_PATH))
|
227 |
+
model = torchmoji_feature_encoding(PRETRAINED_PATH)
|
228 |
+
print(model)
|
229 |
+
tokenized, _, _ = st.tokenize_sentences(TEST_SENTENCES)
|
230 |
+
encoding = model(tokenized)
|
231 |
+
|
232 |
+
avg_across_sentences = np.around(np.mean(encoding, axis=0)[:5], 3)
|
233 |
+
assert np.allclose(avg_across_sentences, np.array([-0.023, 0.021, -0.037, -0.001, -0.005]))
|
234 |
+
|
235 |
+
test_encode_texts()
|
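For reference, a minimal standalone sketch of the scoring flow exercised by test_score_emoji above (not part of the repository). It assumes the pretrained weights and vocabulary have already been downloaded to the default PRETRAINED_PATH / VOCAB_PATH locations, and that the model constructors live in torchmoji.model_def, as used by these tests.

# Illustrative sketch: score one sentence outside the test suite.
import json
import numpy as np
from torchmoji.sentence_tokenizer import SentenceTokenizer
from torchmoji.model_def import torchmoji_emojis
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH

with open(VOCAB_PATH, 'r') as f:
    vocabulary = json.load(f)

st = SentenceTokenizer(vocabulary, 30)
tokens, _, _ = st.tokenize_sentences(['I love mom\'s cooking'])

model = torchmoji_emojis(weight_path=PRETRAINED_PATH)
prob = model(tokens)[0]

# indices of the 5 most probable emoji classes, most probable first
top5 = np.argpartition(prob, -5)[-5:]
top5 = top5[np.argsort(prob[top5])][::-1]
print(top5)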
tests/test_helper.py
ADDED
@@ -0,0 +1,6 @@
""" Module import helper.
Modifies PATH in order to allow us to import the torchmoji directory.
"""
import sys
from os.path import abspath, dirname
sys.path.insert(0, dirname(dirname(abspath(__file__))))
tests/test_sentence_tokenizer.py
ADDED
@@ -0,0 +1,113 @@
from __future__ import absolute_import, print_function, division, unicode_literals
import test_helper
import json

from torchmoji.sentence_tokenizer import SentenceTokenizer

sentences = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J']

dicts = [
    {'label': 0},
    {'label': 1},
    {'label': 2},
    {'label': 3},
    {'label': 4},
    {'label': 5},
    {'label': 6},
    {'label': 7},
    {'label': 8},
    {'label': 9},
]

train_ind = [0, 5, 3, 6, 8]
val_ind = [9, 2, 1]
test_ind = [4, 7]

with open('../model/vocabulary.json', 'r') as f:
    vocab = json.load(f)

def test_dataset_split_parameter():
    """ Dataset is split in the desired ratios
    """
    split_parameter = [0.7, 0.1, 0.2]
    st = SentenceTokenizer(vocab, 30)

    result, result_dicts, _ = st.split_train_val_test(sentences, dicts,
                                                      split_parameter, extend_with=0)
    train = result[0]
    val = result[1]
    test = result[2]

    train_dicts = result_dicts[0]
    val_dicts = result_dicts[1]
    test_dicts = result_dicts[2]

    assert len(train) == len(sentences) * split_parameter[0]
    assert len(val) == len(sentences) * split_parameter[1]
    assert len(test) == len(sentences) * split_parameter[2]

    assert len(train_dicts) == len(dicts) * split_parameter[0]
    assert len(val_dicts) == len(dicts) * split_parameter[1]
    assert len(test_dicts) == len(dicts) * split_parameter[2]

def test_dataset_split_explicit():
    """ Dataset is split according to given indices
    """
    split_parameter = [train_ind, val_ind, test_ind]
    st = SentenceTokenizer(vocab, 30)
    tokenized, _, _ = st.tokenize_sentences(sentences)

    result, result_dicts, added = st.split_train_val_test(sentences, dicts, split_parameter, extend_with=0)
    train = result[0]
    val = result[1]
    test = result[2]

    train_dicts = result_dicts[0]
    val_dicts = result_dicts[1]
    test_dicts = result_dicts[2]

    tokenized = tokenized

    for i, sentence in enumerate(sentences):
        if i in train_ind:
            assert tokenized[i] in train
            assert dicts[i] in train_dicts
        elif i in val_ind:
            assert tokenized[i] in val
            assert dicts[i] in val_dicts
        elif i in test_ind:
            assert tokenized[i] in test
            assert dicts[i] in test_dicts

    assert len(train) == len(train_ind)
    assert len(val) == len(val_ind)
    assert len(test) == len(test_ind)
    assert len(train_dicts) == len(train_ind)
    assert len(val_dicts) == len(val_ind)
    assert len(test_dicts) == len(test_ind)

def test_id_to_sentence():
    """Tokenizing and converting back preserves the input.
    """
    vb = {'CUSTOM_MASK': 0,
          'aasdf': 1000,
          'basdf': 2000}

    sentence = 'aasdf basdf basdf basdf'
    st = SentenceTokenizer(vb, 30)
    token, _, _ = st.tokenize_sentences([sentence])
    assert st.to_sentence(token[0]) == sentence

def test_id_to_sentence_with_unknown():
    """Tokenizing and converting back preserves the input, except for unknowns.
    """
    vb = {'CUSTOM_MASK': 0,
          'CUSTOM_UNKNOWN': 1,
          'aasdf': 1000,
          'basdf': 2000}

    sentence = 'aasdf basdf ccc'
    expected = 'aasdf basdf CUSTOM_UNKNOWN'
    st = SentenceTokenizer(vb, 30)
    token, _, _ = st.tokenize_sentences([sentence])
    assert st.to_sentence(token[0]) == expected
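A small usage sketch of the tokenizer round trip covered by the tests above, using a toy vocabulary. The toy words are made up for illustration; the behaviour described in the comments follows from the tests, not from additional guarantees.

# Illustrative sketch of SentenceTokenizer with a toy vocabulary.
from torchmoji.sentence_tokenizer import SentenceTokenizer

toy_vocab = {'CUSTOM_MASK': 0, 'CUSTOM_UNKNOWN': 1, 'i': 2, 'like': 3, 'cake': 4}
st = SentenceTokenizer(toy_vocab, 10)  # ids are padded/truncated to length 10
tokens, infos, stats = st.tokenize_sentences(['i like cake', 'i like pie'])

print(st.to_sentence(tokens[0]))  # maps ids back to tokens: 'i like cake'
print(st.to_sentence(tokens[1]))  # 'pie' is out of vocabulary, so it comes back as CUSTOM_UNKNOWN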
tests/test_tokenizer.py
ADDED
@@ -0,0 +1,167 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
""" Tokenization tests.
|
3 |
+
"""
|
4 |
+
from __future__ import absolute_import, print_function, division, unicode_literals
|
5 |
+
|
6 |
+
import sys
|
7 |
+
from nose.tools import nottest
|
8 |
+
from os.path import dirname, abspath
|
9 |
+
sys.path.append(dirname(dirname(abspath(__file__))))
|
10 |
+
from torchmoji.tokenizer import tokenize
|
11 |
+
|
12 |
+
TESTS_NORMAL = [
|
13 |
+
('200K words!', ['200', 'K', 'words', '!']),
|
14 |
+
]
|
15 |
+
|
16 |
+
TESTS_EMOJIS = [
|
17 |
+
('i \U0001f496 you to the moon and back',
|
18 |
+
['i', '\U0001f496', 'you', 'to', 'the', 'moon', 'and', 'back']),
|
19 |
+
("i\U0001f496you to the \u2605's and back",
|
20 |
+
['i', '\U0001f496', 'you', 'to', 'the',
|
21 |
+
'\u2605', "'", 's', 'and', 'back']),
|
22 |
+
('~<3~', ['~', '<3', '~']),
|
23 |
+
('<333', ['<333']),
|
24 |
+
(':-)', [':-)']),
|
25 |
+
('>:-(', ['>:-(']),
|
26 |
+
('\u266b\u266a\u2605\u2606\u2665\u2764\u2661',
|
27 |
+
['\u266b', '\u266a', '\u2605', '\u2606',
|
28 |
+
'\u2665', '\u2764', '\u2661']),
|
29 |
+
]
|
30 |
+
|
31 |
+
TESTS_URLS = [
|
32 |
+
('www.sample.com', ['www.sample.com']),
|
33 |
+
('http://endless.horse', ['http://endless.horse']),
|
34 |
+
('https://github.mit.ed', ['https://github.mit.ed']),
|
35 |
+
]
|
36 |
+
|
37 |
+
TESTS_TWITTER = [
|
38 |
+
('#blacklivesmatter', ['#blacklivesmatter']),
|
39 |
+
('#99_percent.', ['#99_percent', '.']),
|
40 |
+
('the#99%', ['the', '#99', '%']),
|
41 |
+
('@golden_zenith', ['@golden_zenith']),
|
42 |
+
('@99_percent', ['@99_percent']),
|
43 |
+
('[email protected]', ['[email protected]']),
|
44 |
+
]
|
45 |
+
|
46 |
+
TESTS_PHONE_NUMS = [
|
47 |
+
('518)528-0252', ['518', ')', '528', '-', '0252']),
|
48 |
+
('1200-0221-0234', ['1200', '-', '0221', '-', '0234']),
|
49 |
+
('1200.0221.0234', ['1200', '.', '0221', '.', '0234']),
|
50 |
+
]
|
51 |
+
|
52 |
+
TESTS_DATETIME = [
|
53 |
+
('15:00', ['15', ':', '00']),
|
54 |
+
('2:00pm', ['2', ':', '00', 'pm']),
|
55 |
+
('9/14/16', ['9', '/', '14', '/', '16']),
|
56 |
+
]
|
57 |
+
|
58 |
+
TESTS_CURRENCIES = [
|
59 |
+
('517.933\xa3', ['517', '.', '933', '\xa3']),
|
60 |
+
('$517.87', ['$', '517', '.', '87']),
|
61 |
+
('1201.6598', ['1201', '.', '6598']),
|
62 |
+
('120,6', ['120', ',', '6']),
|
63 |
+
('10,00\u20ac', ['10', ',', '00', '\u20ac']),
|
64 |
+
('1,000', ['1', ',', '000']),
|
65 |
+
('1200pesos', ['1200', 'pesos']),
|
66 |
+
]
|
67 |
+
|
68 |
+
TESTS_NUM_SYM = [
|
69 |
+
('5162f', ['5162', 'f']),
|
70 |
+
('f5162', ['f', '5162']),
|
71 |
+
('1203(', ['1203', '(']),
|
72 |
+
('(1203)', ['(', '1203', ')']),
|
73 |
+
('1200/', ['1200', '/']),
|
74 |
+
('1200+', ['1200', '+']),
|
75 |
+
('1202o-east', ['1202', 'o-east']),
|
76 |
+
('1200r', ['1200', 'r']),
|
77 |
+
('1200-1400', ['1200', '-', '1400']),
|
78 |
+
('120/today', ['120', '/', 'today']),
|
79 |
+
('today/120', ['today', '/', '120']),
|
80 |
+
('120/5', ['120', '/', '5']),
|
81 |
+
("120'/5", ['120', "'", '/', '5']),
|
82 |
+
('120/5pro', ['120', '/', '5', 'pro']),
|
83 |
+
("1200's,)", ['1200', "'", 's', ',', ')']),
|
84 |
+
('120.76.218.207', ['120', '.', '76', '.', '218', '.', '207']),
|
85 |
+
]
|
86 |
+
|
87 |
+
TESTS_PUNCTUATION = [
|
88 |
+
("don''t", ['don', "''", 't']),
|
89 |
+
("don'tcha", ["don'tcha"]),
|
90 |
+
('no?!?!;', ['no', '?', '!', '?', '!', ';']),
|
91 |
+
('no??!!..', ['no', '??', '!!', '..']),
|
92 |
+
('a.m.', ['a.m.']),
|
93 |
+
('.s.u', ['.', 's', '.', 'u']),
|
94 |
+
('!!i..n__', ['!!', 'i', '..', 'n', '__']),
|
95 |
+
('lv(<3)w(3>)u Mr.!', ['lv', '(', '<3', ')', 'w', '(', '3',
|
96 |
+
'>', ')', 'u', 'Mr.', '!']),
|
97 |
+
('-->', ['--', '>']),
|
98 |
+
('->', ['-', '>']),
|
99 |
+
('<-', ['<', '-']),
|
100 |
+
('<--', ['<', '--']),
|
101 |
+
('hello (@person)', ['hello', '(', '@person', ')']),
|
102 |
+
]
|
103 |
+
|
104 |
+
|
105 |
+
def test_normal():
|
106 |
+
""" Normal/combined usage.
|
107 |
+
"""
|
108 |
+
test_base(TESTS_NORMAL)
|
109 |
+
|
110 |
+
|
111 |
+
def test_emojis():
|
112 |
+
""" Tokenizing emojis/emoticons/decorations.
|
113 |
+
"""
|
114 |
+
test_base(TESTS_EMOJIS)
|
115 |
+
|
116 |
+
|
117 |
+
def test_urls():
|
118 |
+
""" Tokenizing URLs.
|
119 |
+
"""
|
120 |
+
test_base(TESTS_URLS)
|
121 |
+
|
122 |
+
|
123 |
+
def test_twitter():
|
124 |
+
""" Tokenizing hashtags, mentions and emails.
|
125 |
+
"""
|
126 |
+
test_base(TESTS_TWITTER)
|
127 |
+
|
128 |
+
|
129 |
+
def test_phone_nums():
|
130 |
+
""" Tokenizing phone numbers.
|
131 |
+
"""
|
132 |
+
test_base(TESTS_PHONE_NUMS)
|
133 |
+
|
134 |
+
|
135 |
+
def test_datetime():
|
136 |
+
""" Tokenizing dates and times.
|
137 |
+
"""
|
138 |
+
test_base(TESTS_DATETIME)
|
139 |
+
|
140 |
+
|
141 |
+
def test_currencies():
|
142 |
+
""" Tokenizing currencies.
|
143 |
+
"""
|
144 |
+
test_base(TESTS_CURRENCIES)
|
145 |
+
|
146 |
+
|
147 |
+
def test_num_sym():
|
148 |
+
""" Tokenizing combinations of numbers and symbols.
|
149 |
+
"""
|
150 |
+
test_base(TESTS_NUM_SYM)
|
151 |
+
|
152 |
+
|
153 |
+
def test_punctuation():
|
154 |
+
""" Tokenizing punctuation and contractions.
|
155 |
+
"""
|
156 |
+
test_base(TESTS_PUNCTUATION)
|
157 |
+
|
158 |
+
|
159 |
+
@nottest
|
160 |
+
def test_base(tests):
|
161 |
+
""" Base function for running tests.
|
162 |
+
"""
|
163 |
+
for (test, expected) in tests:
|
164 |
+
actual = tokenize(test)
|
165 |
+
assert actual == expected, \
|
166 |
+
"Tokenization of \'{}\' failed, expected: {}, actual: {}"\
|
167 |
+
.format(test, expected, actual)
|
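A short usage sketch of tokenize(); the expected outputs shown in the comments are taken directly from the test tables above.

# Illustrative sketch of the tokenizer.
from torchmoji.tokenizer import tokenize

print(tokenize('hello (@person)'))  # ['hello', '(', '@person', ')']
print(tokenize('i \U0001f496 you to the moon and back'))
# ['i', '\U0001f496', 'you', 'to', 'the', 'moon', 'and', 'back']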
tests/test_word_generator.py
ADDED
@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
import sys
from os.path import dirname, abspath
sys.path.append(dirname(dirname(abspath(__file__))))
from nose.tools import raises
from torchmoji.word_generator import WordGenerator

IS_PYTHON2 = int(sys.version[0]) == 2

@raises(ValueError)
def test_only_unicode_accepted():
    """ Non-Unicode strings raise a ValueError.
    In Python 3 all strings are Unicode
    """
    if not IS_PYTHON2:
        raise ValueError("You are using python 3 so this test should always pass")

    sentences = [
        u'Hello world',
        u'I am unicode',
        'I am not unicode',
    ]

    wg = WordGenerator(sentences)
    for w in wg:
        pass


def test_unicode_sentences_ignored_if_set():
    """ Strings with Unicode characters tokenize to empty array if they're not allowed.
    """
    sentence = [u'Dobrý den, jak se máš?']
    wg = WordGenerator(sentence, allow_unicode_text=False)
    assert wg.get_words(sentence[0]) == []


def test_check_ascii():
    """ check_ascii recognises ASCII words properly.
    In Python 3 all strings are Unicode
    """
    if not IS_PYTHON2:
        return

    wg = WordGenerator([])
    assert wg.check_ascii('ASCII')
    assert not wg.check_ascii('ščřžýá')
    assert not wg.check_ascii('❤ ☀ ☆ ☂ ☻ ♞ ☯ ☭ ☢')


def test_convert_unicode_word():
    """ convert_unicode_word converts Unicode words correctly.
    """
    wg = WordGenerator([], allow_unicode_text=True)

    result = wg.convert_unicode_word(u'č')
    assert result == (True, u'\u010d'), '{}'.format(result)


def test_convert_unicode_word_ignores_if_set():
    """ convert_unicode_word ignores Unicode words if set.
    """
    wg = WordGenerator([], allow_unicode_text=False)

    result = wg.convert_unicode_word(u'č')
    assert result == (False, ''), '{}'.format(result)


def test_convert_unicode_chars():
    """ convert_unicode_word correctly converts accented characters.
    """
    wg = WordGenerator([], allow_unicode_text=True)
    result = wg.convert_unicode_word(u'ěščřžýáíé')
    assert result == (True, u'\u011b\u0161\u010d\u0159\u017e\xfd\xe1\xed\xe9'), '{}'.format(result)
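A brief sketch of how WordGenerator is typically consumed, mirroring how torchmoji/create_vocab.py iterates it; the sample sentences are illustrative only.

# Illustrative sketch: iterate a WordGenerator over a list of sentences.
from torchmoji.word_generator import WordGenerator

sentences = [u"I love mom's cooking", u'Dobrý den, jak se máš?']
wg = WordGenerator(sentences, allow_unicode_text=True)
for words, info in wg:      # yields (tokenized_words, info) pairs
    print(words)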
torchmoji/.gitkeep
ADDED
@@ -0,0 +1 @@
torchmoji/__init__.py
ADDED
File without changes
|
torchmoji/attlayer.py
ADDED
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
""" Define the Attention Layer of the model.
"""

from __future__ import print_function, division

import torch

from torch.autograd import Variable
from torch.nn import Module
from torch.nn.parameter import Parameter

class Attention(Module):
    """
    Computes a weighted average of the different channels across timesteps.
    Uses 1 parameter per channel to compute the attention value for a single timestep.
    """

    def __init__(self, attention_size, return_attention=False):
        """ Initialize the attention layer

        # Arguments:
            attention_size: Size of the attention vector.
            return_attention: If true, output will include the weight for each input token
                              used for the prediction

        """
        super(Attention, self).__init__()
        self.return_attention = return_attention
        self.attention_size = attention_size
        self.attention_vector = Parameter(torch.FloatTensor(attention_size))

    def __repr__(self):
        s = '{name}({attention_size}, return attention={return_attention})'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def forward(self, inputs, input_lengths):
        """ Forward pass.

        # Arguments:
            inputs (Torch.Variable): Tensor of input sequences
            input_lengths (torch.LongTensor): Lengths of the sequences

        # Return:
            Tuple with (representations and attentions if self.return_attention else None).
        """
        logits = inputs.matmul(self.attention_vector)
        unnorm_ai = (logits - logits.max()).exp()

        # Compute a mask for the attention on the padded sequences
        # See e.g. https://discuss.pytorch.org/t/self-attention-on-words-and-masking/5671/5
        max_len = unnorm_ai.size(1)
        idxes = torch.arange(0, max_len, out=torch.LongTensor(max_len)).unsqueeze(0)
        if torch.cuda.is_available():
            idxes = idxes.cuda()
        mask = Variable((idxes < input_lengths.unsqueeze(1)).float())

        # apply mask and renormalize attention scores (weights)
        masked_weights = unnorm_ai * mask
        att_sums = masked_weights.sum(dim=1, keepdim=True)  # sums per sequence
        attentions = masked_weights.div(att_sums)

        # apply attention weights
        weighted = torch.mul(inputs, attentions.unsqueeze(-1).expand_as(inputs))

        # get the final fixed vector representations of the sentences
        representations = weighted.sum(dim=1)

        return (representations, attentions if self.return_attention else None)
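A toy sketch (not from the repository) of calling the Attention layer on a padded batch. It assumes a PyTorch version where Variables and tensors interoperate, and it initialises the attention vector manually since in normal use the weights come from the pretrained model.

# Illustrative sketch: attention over a toy padded batch.
import torch
from torch.autograd import Variable
from torchmoji.attlayer import Attention

batch, timesteps, channels = 2, 5, 8
layer = Attention(attention_size=channels, return_attention=True)
layer.attention_vector.data.uniform_(-0.1, 0.1)  # stand-in for pretrained weights

inputs = Variable(torch.randn(batch, timesteps, channels))
lengths = torch.LongTensor([5, 3])  # second sequence is padding after 3 steps

representations, attentions = layer(inputs, lengths)
print(representations.size())  # (2, 8): one fixed-size vector per sequence
print(attentions.size())       # (2, 5): weights on padded timesteps are zero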
torchmoji/class_avg_finetuning.py
ADDED
@@ -0,0 +1,315 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
""" Class average finetuning functions. Before using any of these finetuning
|
3 |
+
functions, ensure that the model is set up with nb_classes=2.
|
4 |
+
"""
|
5 |
+
from __future__ import print_function
|
6 |
+
|
7 |
+
import uuid
|
8 |
+
from time import sleep
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.optim as optim
|
14 |
+
|
15 |
+
from torchmoji.global_variables import (
|
16 |
+
FINETUNING_METHODS,
|
17 |
+
WEIGHTS_DIR)
|
18 |
+
from torchmoji.finetuning import (
|
19 |
+
freeze_layers,
|
20 |
+
get_data_loader,
|
21 |
+
fit_model,
|
22 |
+
train_by_chain_thaw,
|
23 |
+
find_f1_threshold)
|
24 |
+
|
25 |
+
def relabel(y, current_label_nr, nb_classes):
|
26 |
+
""" Makes a binary classification for a specific class in a
|
27 |
+
multi-class dataset.
|
28 |
+
|
29 |
+
# Arguments:
|
30 |
+
y: Outputs to be relabelled.
|
31 |
+
current_label_nr: Current label number.
|
32 |
+
nb_classes: Total number of classes.
|
33 |
+
|
34 |
+
# Returns:
|
35 |
+
Relabelled outputs of a given multi-class dataset into a binary
|
36 |
+
classification dataset.
|
37 |
+
"""
|
38 |
+
|
39 |
+
# Handling binary classification
|
40 |
+
if nb_classes == 2 and len(y.shape) == 1:
|
41 |
+
return y
|
42 |
+
|
43 |
+
y_new = np.zeros(len(y))
|
44 |
+
y_cut = y[:, current_label_nr]
|
45 |
+
label_pos = np.where(y_cut == 1)[0]
|
46 |
+
y_new[label_pos] = 1
|
47 |
+
return y_new
|
48 |
+
|
49 |
+
|
50 |
+
def class_avg_finetune(model, texts, labels, nb_classes, batch_size,
|
51 |
+
method, epoch_size=5000, nb_epochs=1000, embed_l2=1E-6,
|
52 |
+
verbose=True):
|
53 |
+
""" Compiles and finetunes the given model.
|
54 |
+
|
55 |
+
# Arguments:
|
56 |
+
model: Model to be finetuned
|
57 |
+
texts: List of three lists, containing tokenized inputs for training,
|
58 |
+
validation and testing (in that order).
|
59 |
+
labels: List of three lists, containing labels for training,
|
60 |
+
validation and testing (in that order).
|
61 |
+
nb_classes: Number of classes in the dataset.
|
62 |
+
batch_size: Batch size.
|
63 |
+
method: Finetuning method to be used. For available methods, see
|
64 |
+
FINETUNING_METHODS in global_variables.py. Note that the model
|
65 |
+
should be defined accordingly (see docstring for torchmoji_transfer())
|
66 |
+
epoch_size: Number of samples in an epoch.
|
67 |
+
nb_epochs: Number of epochs. Doesn't matter much as early stopping is used.
|
68 |
+
embed_l2: L2 regularization for the embedding layer.
|
69 |
+
verbose: Verbosity flag.
|
70 |
+
|
71 |
+
# Returns:
|
72 |
+
Model after finetuning,
|
73 |
+
score after finetuning using the class average F1 metric.
|
74 |
+
"""
|
75 |
+
|
76 |
+
if method not in FINETUNING_METHODS:
|
77 |
+
raise ValueError('ERROR (class_avg_tune_trainable): '
|
78 |
+
'Invalid method parameter. '
|
79 |
+
'Available options: {}'.format(FINETUNING_METHODS))
|
80 |
+
|
81 |
+
(X_train, y_train) = (texts[0], labels[0])
|
82 |
+
(X_val, y_val) = (texts[1], labels[1])
|
83 |
+
(X_test, y_test) = (texts[2], labels[2])
|
84 |
+
|
85 |
+
checkpoint_path = '{}/torchmoji-checkpoint-{}.bin' \
|
86 |
+
.format(WEIGHTS_DIR, str(uuid.uuid4()))
|
87 |
+
|
88 |
+
f1_init_path = '{}/torchmoji-f1-init-{}.bin' \
|
89 |
+
.format(WEIGHTS_DIR, str(uuid.uuid4()))
|
90 |
+
|
91 |
+
if method in ['last', 'new']:
|
92 |
+
lr = 0.001
|
93 |
+
elif method in ['full', 'chain-thaw']:
|
94 |
+
lr = 0.0001
|
95 |
+
|
96 |
+
loss_op = nn.BCEWithLogitsLoss()
|
97 |
+
|
98 |
+
# Freeze layers if using last
|
99 |
+
if method == 'last':
|
100 |
+
model = freeze_layers(model, unfrozen_keyword='output_layer')
|
101 |
+
|
102 |
+
# Define optimizer, for chain-thaw we define it later (after freezing)
|
103 |
+
if method == 'last':
|
104 |
+
adam = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
|
105 |
+
elif method in ['full', 'new']:
|
106 |
+
# Add L2 regulation on embeddings only
|
107 |
+
special_params = [id(p) for p in model.embed.parameters()]
|
108 |
+
base_params = [p for p in model.parameters() if id(p) not in special_params and p.requires_grad]
|
109 |
+
embed_parameters = [p for p in model.parameters() if id(p) in special_params and p.requires_grad]
|
110 |
+
adam = optim.Adam([
|
111 |
+
{'params': base_params},
|
112 |
+
{'params': embed_parameters, 'weight_decay': embed_l2},
|
113 |
+
], lr=lr)
|
114 |
+
|
115 |
+
# Training
|
116 |
+
if verbose:
|
117 |
+
print('Method: {}'.format(method))
|
118 |
+
print('Classes: {}'.format(nb_classes))
|
119 |
+
|
120 |
+
if method == 'chain-thaw':
|
121 |
+
result = class_avg_chainthaw(model, nb_classes=nb_classes,
|
122 |
+
loss_op=loss_op,
|
123 |
+
train=(X_train, y_train),
|
124 |
+
val=(X_val, y_val),
|
125 |
+
test=(X_test, y_test),
|
126 |
+
batch_size=batch_size,
|
127 |
+
epoch_size=epoch_size,
|
128 |
+
nb_epochs=nb_epochs,
|
129 |
+
checkpoint_weight_path=checkpoint_path,
|
130 |
+
f1_init_weight_path=f1_init_path,
|
131 |
+
verbose=verbose)
|
132 |
+
else:
|
133 |
+
result = class_avg_tune_trainable(model, nb_classes=nb_classes,
|
134 |
+
loss_op=loss_op,
|
135 |
+
optim_op=adam,
|
136 |
+
train=(X_train, y_train),
|
137 |
+
val=(X_val, y_val),
|
138 |
+
test=(X_test, y_test),
|
139 |
+
epoch_size=epoch_size,
|
140 |
+
nb_epochs=nb_epochs,
|
141 |
+
batch_size=batch_size,
|
142 |
+
init_weight_path=f1_init_path,
|
143 |
+
checkpoint_weight_path=checkpoint_path,
|
144 |
+
verbose=verbose)
|
145 |
+
return model, result
|
146 |
+
|
147 |
+
|
148 |
+
def prepare_labels(y_train, y_val, y_test, iter_i, nb_classes):
|
149 |
+
# Relabel into binary classification
|
150 |
+
y_train_new = relabel(y_train, iter_i, nb_classes)
|
151 |
+
y_val_new = relabel(y_val, iter_i, nb_classes)
|
152 |
+
y_test_new = relabel(y_test, iter_i, nb_classes)
|
153 |
+
return y_train_new, y_val_new, y_test_new
|
154 |
+
|
155 |
+
def prepare_generators(X_train, y_train_new, X_val, y_val_new, batch_size, epoch_size):
|
156 |
+
# Create sample generators
|
157 |
+
# Make a fixed validation set to avoid fluctuations in validation
|
158 |
+
train_gen = get_data_loader(X_train, y_train_new, batch_size,
|
159 |
+
extended_batch_sampler=True)
|
160 |
+
val_gen = get_data_loader(X_val, y_val_new, epoch_size,
|
161 |
+
extended_batch_sampler=True)
|
162 |
+
X_val_resamp, y_val_resamp = next(iter(val_gen))
|
163 |
+
return train_gen, X_val_resamp, y_val_resamp
|
164 |
+
|
165 |
+
|
166 |
+
def class_avg_tune_trainable(model, nb_classes, loss_op, optim_op, train, val, test,
|
167 |
+
epoch_size, nb_epochs, batch_size,
|
168 |
+
init_weight_path, checkpoint_weight_path, patience=5,
|
169 |
+
verbose=True):
|
170 |
+
""" Finetunes the given model using the F1 measure.
|
171 |
+
|
172 |
+
# Arguments:
|
173 |
+
model: Model to be finetuned.
|
174 |
+
nb_classes: Number of classes in the given dataset.
|
175 |
+
train: Training data, given as a tuple of (inputs, outputs)
|
176 |
+
val: Validation data, given as a tuple of (inputs, outputs)
|
177 |
+
test: Testing data, given as a tuple of (inputs, outputs)
|
178 |
+
epoch_size: Number of samples in an epoch.
|
179 |
+
nb_epochs: Number of epochs.
|
180 |
+
batch_size: Batch size.
|
181 |
+
init_weight_path: Filepath where weights will be initially saved before
|
182 |
+
training each class. This file will be rewritten by the function.
|
183 |
+
checkpoint_weight_path: Filepath where weights will be checkpointed to
|
184 |
+
during training. This file will be rewritten by the function.
|
185 |
+
verbose: Verbosity flag.
|
186 |
+
|
187 |
+
# Returns:
|
188 |
+
F1 score of the trained model
|
189 |
+
"""
|
190 |
+
total_f1 = 0
|
191 |
+
nb_iter = nb_classes if nb_classes > 2 else 1
|
192 |
+
|
193 |
+
# Unpack args
|
194 |
+
X_train, y_train = train
|
195 |
+
X_val, y_val = val
|
196 |
+
X_test, y_test = test
|
197 |
+
|
198 |
+
# Save and reload initial weights after running for
|
199 |
+
# each class to avoid learning across classes
|
200 |
+
torch.save(model.state_dict(), init_weight_path)
|
201 |
+
for i in range(nb_iter):
|
202 |
+
if verbose:
|
203 |
+
print('Iteration number {}/{}'.format(i+1, nb_iter))
|
204 |
+
|
205 |
+
model.load_state_dict(torch.load(init_weight_path))
|
206 |
+
y_train_new, y_val_new, y_test_new = prepare_labels(y_train, y_val,
|
207 |
+
y_test, i, nb_classes)
|
208 |
+
train_gen, X_val_resamp, y_val_resamp = \
|
209 |
+
prepare_generators(X_train, y_train_new, X_val, y_val_new,
|
210 |
+
batch_size, epoch_size)
|
211 |
+
|
212 |
+
if verbose:
|
213 |
+
print("Training..")
|
214 |
+
fit_model(model, loss_op, optim_op, train_gen, [(X_val_resamp, y_val_resamp)],
|
215 |
+
nb_epochs, checkpoint_weight_path, patience, verbose=0)
|
216 |
+
|
217 |
+
# Reload the best weights found to avoid overfitting
|
218 |
+
# Wait a bit to allow proper closing of weights file
|
219 |
+
sleep(1)
|
220 |
+
model.load_state_dict(torch.load(checkpoint_weight_path))
|
221 |
+
|
222 |
+
# Evaluate
|
223 |
+
y_pred_val = model(X_val).cpu().numpy()
|
224 |
+
y_pred_test = model(X_test).cpu().numpy()
|
225 |
+
|
226 |
+
f1_test, best_t = find_f1_threshold(y_val_new, y_pred_val,
|
227 |
+
y_test_new, y_pred_test)
|
228 |
+
if verbose:
|
229 |
+
print('f1_test: {}'.format(f1_test))
|
230 |
+
print('best_t: {}'.format(best_t))
|
231 |
+
total_f1 += f1_test
|
232 |
+
|
233 |
+
return total_f1 / nb_iter
|
234 |
+
|
235 |
+
|
236 |
+
def class_avg_chainthaw(model, nb_classes, loss_op, train, val, test, batch_size,
|
237 |
+
epoch_size, nb_epochs, checkpoint_weight_path,
|
238 |
+
f1_init_weight_path, patience=5,
|
239 |
+
initial_lr=0.001, next_lr=0.0001, verbose=True):
|
240 |
+
""" Finetunes given model using chain-thaw and evaluates using F1.
|
241 |
+
For a dataset with multiple classes, the model is trained once for
|
242 |
+
each class, relabeling those classes into a binary classification task.
|
243 |
+
The result is an average of all F1 scores for each class.
|
244 |
+
|
245 |
+
# Arguments:
|
246 |
+
model: Model to be finetuned.
|
247 |
+
nb_classes: Number of classes in the given dataset.
|
248 |
+
train: Training data, given as a tuple of (inputs, outputs)
|
249 |
+
val: Validation data, given as a tuple of (inputs, outputs)
|
250 |
+
test: Testing data, given as a tuple of (inputs, outputs)
|
251 |
+
batch_size: Batch size.
|
252 |
+
loss_op: Loss function to be used during training.
|
253 |
+
epoch_size: Number of samples in an epoch.
|
254 |
+
nb_epochs: Number of epochs.
|
255 |
+
checkpoint_weight_path: Filepath where weights will be checkpointed to
|
256 |
+
during training. This file will be rewritten by the function.
|
257 |
+
f1_init_weight_path: Filepath where weights will be saved to and
|
258 |
+
reloaded from before training each class. This ensures that
|
259 |
+
each class is trained independently. This file will be rewritten.
|
260 |
+
initial_lr: Initial learning rate. Will only be used for the first
|
261 |
+
training step (i.e. the softmax layer)
|
262 |
+
next_lr: Learning rate for every subsequent step.
|
263 |
+
seed: Random number generator seed.
|
264 |
+
verbose: Verbosity flag.
|
265 |
+
|
266 |
+
# Returns:
|
267 |
+
Averaged F1 score.
|
268 |
+
"""
|
269 |
+
|
270 |
+
# Unpack args
|
271 |
+
X_train, y_train = train
|
272 |
+
X_val, y_val = val
|
273 |
+
X_test, y_test = test
|
274 |
+
|
275 |
+
total_f1 = 0
|
276 |
+
nb_iter = nb_classes if nb_classes > 2 else 1
|
277 |
+
|
278 |
+
torch.save(model.state_dict(), f1_init_weight_path)
|
279 |
+
|
280 |
+
for i in range(nb_iter):
|
281 |
+
if verbose:
|
282 |
+
print('Iteration number {}/{}'.format(i+1, nb_iter))
|
283 |
+
|
284 |
+
model.load_state_dict(torch.load(f1_init_weight_path))
|
285 |
+
y_train_new, y_val_new, y_test_new = prepare_labels(y_train, y_val,
|
286 |
+
y_test, i, nb_classes)
|
287 |
+
train_gen, X_val_resamp, y_val_resamp = \
|
288 |
+
prepare_generators(X_train, y_train_new, X_val, y_val_new,
|
289 |
+
batch_size, epoch_size)
|
290 |
+
|
291 |
+
if verbose:
|
292 |
+
print("Training..")
|
293 |
+
|
294 |
+
# Train using chain-thaw
|
295 |
+
train_by_chain_thaw(model=model, train_gen=train_gen,
|
296 |
+
val_gen=[(X_val_resamp, y_val_resamp)],
|
297 |
+
loss_op=loss_op, patience=patience,
|
298 |
+
nb_epochs=nb_epochs,
|
299 |
+
checkpoint_path=checkpoint_weight_path,
|
300 |
+
initial_lr=initial_lr, next_lr=next_lr,
|
301 |
+
verbose=verbose)
|
302 |
+
|
303 |
+
# Evaluate
|
304 |
+
y_pred_val = model(X_val).cpu().numpy()
|
305 |
+
y_pred_test = model(X_test).cpu().numpy()
|
306 |
+
|
307 |
+
f1_test, best_t = find_f1_threshold(y_val_new, y_pred_val,
|
308 |
+
y_test_new, y_pred_test)
|
309 |
+
|
310 |
+
if verbose:
|
311 |
+
print('f1_test: {}'.format(f1_test))
|
312 |
+
print('best_t: {}'.format(best_t))
|
313 |
+
total_f1 += f1_test
|
314 |
+
|
315 |
+
return total_f1 / nb_iter
|
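A tiny worked example of relabel() from the module above (illustrative only): it turns one-vs-rest labels for a chosen class of a 3-class problem into a binary vector.

# Illustrative sketch of relabel().
import numpy as np
from torchmoji.class_avg_finetuning import relabel

y = np.array([
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 1],
])
print(relabel(y, current_label_nr=1, nb_classes=3))  # [0. 1. 1.]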
torchmoji/create_vocab.py
ADDED
@@ -0,0 +1,271 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
from __future__ import print_function, division
|
3 |
+
|
4 |
+
import glob
|
5 |
+
import json
|
6 |
+
import uuid
|
7 |
+
from copy import deepcopy
|
8 |
+
from collections import defaultdict, OrderedDict
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from torchmoji.filter_utils import is_special_token
|
12 |
+
from torchmoji.word_generator import WordGenerator
|
13 |
+
from torchmoji.global_variables import SPECIAL_TOKENS, VOCAB_PATH
|
14 |
+
|
15 |
+
class VocabBuilder():
|
16 |
+
""" Create vocabulary with words extracted from sentences as fed from a
|
17 |
+
word generator.
|
18 |
+
"""
|
19 |
+
def __init__(self, word_gen):
|
20 |
+
# initialize any new key with value of 0
|
21 |
+
self.word_counts = defaultdict(lambda: 0, {})
|
22 |
+
self.word_length_limit=30
|
23 |
+
|
24 |
+
for token in SPECIAL_TOKENS:
|
25 |
+
assert len(token) < self.word_length_limit
|
26 |
+
self.word_counts[token] = 0
|
27 |
+
self.word_gen = word_gen
|
28 |
+
|
29 |
+
def count_words_in_sentence(self, words):
|
30 |
+
""" Generates word counts for all tokens in the given sentence.
|
31 |
+
|
32 |
+
# Arguments:
|
33 |
+
words: Tokenized sentence whose words should be counted.
|
34 |
+
"""
|
35 |
+
for word in words:
|
36 |
+
if 0 < len(word) and len(word) <= self.word_length_limit:
|
37 |
+
try:
|
38 |
+
self.word_counts[word] += 1
|
39 |
+
except KeyError:
|
40 |
+
self.word_counts[word] = 1
|
41 |
+
|
42 |
+
def save_vocab(self, path=None):
|
43 |
+
""" Saves the vocabulary into a file.
|
44 |
+
|
45 |
+
# Arguments:
|
46 |
+
path: Where the vocabulary should be saved. If not specified, a
|
47 |
+
randomly generated filename is used instead.
|
48 |
+
"""
|
49 |
+
dtype = ([('word','|S{}'.format(self.word_length_limit)),('count','int')])
|
50 |
+
np_dict = np.array(self.word_counts.items(), dtype=dtype)
|
51 |
+
|
52 |
+
# sort from highest to lowest frequency
|
53 |
+
np_dict[::-1].sort(order='count')
|
54 |
+
data = np_dict
|
55 |
+
|
56 |
+
if path is None:
|
57 |
+
path = str(uuid.uuid4())
|
58 |
+
|
59 |
+
np.savez_compressed(path, data=data)
|
60 |
+
print("Saved dict to {}".format(path))
|
61 |
+
|
62 |
+
def get_next_word(self):
|
63 |
+
""" Returns next tokenized sentence from the word geneerator.
|
64 |
+
|
65 |
+
# Returns:
|
66 |
+
List of strings, representing the next tokenized sentence.
|
67 |
+
"""
|
68 |
+
return self.word_gen.__iter__().next()
|
69 |
+
|
70 |
+
def count_all_words(self):
|
71 |
+
""" Generates word counts for all words in all sentences of the word
|
72 |
+
generator.
|
73 |
+
"""
|
74 |
+
for words, _ in self.word_gen:
|
75 |
+
self.count_words_in_sentence(words)
|
76 |
+
|
77 |
+
class MasterVocab():
|
78 |
+
""" Combines vocabularies.
|
79 |
+
"""
|
80 |
+
def __init__(self):
|
81 |
+
|
82 |
+
# initialize custom tokens
|
83 |
+
self.master_vocab = {}
|
84 |
+
|
85 |
+
def populate_master_vocab(self, vocab_path, min_words=1, force_appearance=None):
|
86 |
+
""" Populates the master vocabulary using all vocabularies found in the
|
87 |
+
given path. Vocabularies should be named *.npz. Expects the
|
88 |
+
vocabularies to be numpy arrays with counts. Normalizes the counts
|
89 |
+
and combines them.
|
90 |
+
|
91 |
+
# Arguments:
|
92 |
+
vocab_path: Path containing vocabularies to be combined.
|
93 |
+
min_words: Minimum amount of occurences a word must have in order
|
94 |
+
to be included in the master vocabulary.
|
95 |
+
force_appearance: Optional vocabulary filename that will be added
|
96 |
+
to the master vocabulary no matter what. This vocabulary must
|
97 |
+
be present in vocab_path.
|
98 |
+
"""
|
99 |
+
|
100 |
+
paths = glob.glob(vocab_path + '*.npz')
|
101 |
+
sizes = {path: 0 for path in paths}
|
102 |
+
dicts = {path: {} for path in paths}
|
103 |
+
|
104 |
+
# set up and get sizes of individual dictionaries
|
105 |
+
for path in paths:
|
106 |
+
np_data = np.load(path)['data']
|
107 |
+
|
108 |
+
for entry in np_data:
|
109 |
+
word, count = entry
|
110 |
+
if count < min_words:
|
111 |
+
continue
|
112 |
+
if is_special_token(word):
|
113 |
+
continue
|
114 |
+
dicts[path][word] = count
|
115 |
+
|
116 |
+
sizes[path] = sum(dicts[path].values())
|
117 |
+
print('Overall word count for {} -> {}'.format(path, sizes[path]))
|
118 |
+
print('Overall word number for {} -> {}'.format(path, len(dicts[path])))
|
119 |
+
|
120 |
+
vocab_of_max_size = max(sizes, key=sizes.get)
|
121 |
+
max_size = sizes[vocab_of_max_size]
|
122 |
+
print('Min: {}, {}, {}'.format(sizes, vocab_of_max_size, max_size))
|
123 |
+
|
124 |
+
# can force one vocabulary to always be present
|
125 |
+
if force_appearance is not None:
|
126 |
+
force_appearance_path = [p for p in paths if force_appearance in p][0]
|
127 |
+
force_appearance_vocab = deepcopy(dicts[force_appearance_path])
|
128 |
+
print(force_appearance_path)
|
129 |
+
else:
|
130 |
+
force_appearance_path, force_appearance_vocab = None, None
|
131 |
+
|
132 |
+
# normalize word counts before inserting into master dict
|
133 |
+
for path in paths:
|
134 |
+
normalization_factor = max_size / sizes[path]
|
135 |
+
print('Norm factor for path {} -> {}'.format(path, normalization_factor))
|
136 |
+
|
137 |
+
for word in dicts[path]:
|
138 |
+
if is_special_token(word):
|
139 |
+
print("SPECIAL - ", word)
|
140 |
+
continue
|
141 |
+
normalized_count = dicts[path][word] * normalization_factor
|
142 |
+
|
143 |
+
# can force one vocabulary to always be present
|
144 |
+
if force_appearance_vocab is not None:
|
145 |
+
try:
|
146 |
+
force_word_count = force_appearance_vocab[word]
|
147 |
+
except KeyError:
|
148 |
+
continue
|
149 |
+
#if force_word_count < 5:
|
150 |
+
#continue
|
151 |
+
|
152 |
+
if word in self.master_vocab:
|
153 |
+
self.master_vocab[word] += normalized_count
|
154 |
+
else:
|
155 |
+
self.master_vocab[word] = normalized_count
|
156 |
+
|
157 |
+
print('Size of master_dict {}'.format(len(self.master_vocab)))
|
158 |
+
print("Hashes for master dict: {}".format(
|
159 |
+
len([w for w in self.master_vocab if '#' in w[0]])))
|
160 |
+
|
161 |
+
def save_vocab(self, path_count, path_vocab, word_limit=100000):
|
162 |
+
""" Saves the master vocabulary into a file.
|
163 |
+
"""
|
164 |
+
|
165 |
+
# reserve space for 10 special tokens
|
166 |
+
words = OrderedDict()
|
167 |
+
for token in SPECIAL_TOKENS:
|
168 |
+
# store -1 instead of np.inf, which can overflow
|
169 |
+
words[token] = -1
|
170 |
+
|
171 |
+
# sort words by frequency
|
172 |
+
desc_order = OrderedDict(sorted(self.master_vocab.items(),
|
173 |
+
key=lambda kv: kv[1], reverse=True))
|
174 |
+
words.update(desc_order)
|
175 |
+
|
176 |
+
# use encoding of up to 30 characters (no token conversions)
|
177 |
+
# use float to store large numbers (we don't care about precision loss)
|
178 |
+
np_vocab = np.array(words.items(),
|
179 |
+
dtype=([('word','|S30'),('count','float')]))
|
180 |
+
|
181 |
+
# output count for debugging
|
182 |
+
counts = np_vocab[:word_limit]
|
183 |
+
np.savez_compressed(path_count, counts=counts)
|
184 |
+
|
185 |
+
# output the index of each word for easy lookup
|
186 |
+
final_words = OrderedDict()
|
187 |
+
for i, w in enumerate(words.keys()[:word_limit]):
|
188 |
+
final_words.update({w:i})
|
189 |
+
with open(path_vocab, 'w') as f:
|
190 |
+
f.write(json.dumps(final_words, indent=4, separators=(',', ': ')))
|
191 |
+
|
192 |
+
|
193 |
+
def all_words_in_sentences(sentences):
|
194 |
+
""" Extracts all unique words from a given list of sentences.
|
195 |
+
|
196 |
+
# Arguments:
|
197 |
+
sentences: List or word generator of sentences to be processed.
|
198 |
+
|
199 |
+
# Returns:
|
200 |
+
List of all unique words contained in the given sentences.
|
201 |
+
"""
|
202 |
+
vocab = []
|
203 |
+
if isinstance(sentences, WordGenerator):
|
204 |
+
sentences = [s for s, _ in sentences]
|
205 |
+
|
206 |
+
for sentence in sentences:
|
207 |
+
for word in sentence:
|
208 |
+
if word not in vocab:
|
209 |
+
vocab.append(word)
|
210 |
+
|
211 |
+
return vocab
|
212 |
+
|
213 |
+
|
214 |
+
def extend_vocab_in_file(vocab, max_tokens=10000, vocab_path=VOCAB_PATH):
|
215 |
+
""" Extends JSON-formatted vocabulary with words from vocab that are not
|
216 |
+
present in the current vocabulary. Adds up to max_tokens words.
|
217 |
+
Overwrites file in vocab_path.
|
218 |
+
|
219 |
+
# Arguments:
|
220 |
+
new_vocab: Vocabulary to be added. MUST have word_counts populated, i.e.
|
221 |
+
must have run count_all_words() previously.
|
222 |
+
max_tokens: Maximum number of words to be added.
|
223 |
+
vocab_path: Path to the vocabulary json which is to be extended.
|
224 |
+
"""
|
225 |
+
try:
|
226 |
+
with open(vocab_path, 'r') as f:
|
227 |
+
current_vocab = json.load(f)
|
228 |
+
except IOError:
|
229 |
+
print('Vocabulary file not found, expected at ' + vocab_path)
|
230 |
+
return
|
231 |
+
|
232 |
+
extend_vocab(current_vocab, vocab, max_tokens)
|
233 |
+
|
234 |
+
# Save back to file
|
235 |
+
with open(vocab_path, 'w') as f:
|
236 |
+
json.dump(current_vocab, f, sort_keys=True, indent=4, separators=(',',': '))
|
237 |
+
|
238 |
+
|
239 |
+
def extend_vocab(current_vocab, new_vocab, max_tokens=10000):
|
240 |
+
""" Extends current vocabulary with words from vocab that are not
|
241 |
+
present in the current vocabulary. Adds up to max_tokens words.
|
242 |
+
|
243 |
+
# Arguments:
|
244 |
+
current_vocab: Current dictionary of tokens.
|
245 |
+
new_vocab: Vocabulary to be added. MUST have word_counts populated, i.e.
|
246 |
+
must have run count_all_words() previously.
|
247 |
+
max_tokens: Maximum number of words to be added.
|
248 |
+
|
249 |
+
# Returns:
|
250 |
+
How many new tokens have been added.
|
251 |
+
"""
|
252 |
+
if max_tokens < 0:
|
253 |
+
max_tokens = 10000
|
254 |
+
|
255 |
+
words = OrderedDict()
|
256 |
+
|
257 |
+
# sort words by frequency
|
258 |
+
desc_order = OrderedDict(sorted(new_vocab.word_counts.items(),
|
259 |
+
key=lambda kv: kv[1], reverse=True))
|
260 |
+
words.update(desc_order)
|
261 |
+
|
262 |
+
base_index = len(current_vocab.keys())
|
263 |
+
added = 0
|
264 |
+
for word in words:
|
265 |
+
if added >= max_tokens:
|
266 |
+
break
|
267 |
+
if word not in current_vocab.keys():
|
268 |
+
current_vocab[word] = base_index + added
|
269 |
+
added += 1
|
270 |
+
|
271 |
+
return added
|
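A sketch of the vocabulary-extension flow built from VocabBuilder and extend_vocab above, along the lines of examples/vocab_extension.py; the sentences and the small starting vocabulary are made up for illustration.

# Illustrative sketch: count words in new sentences, then extend an existing vocabulary.
from torchmoji.create_vocab import VocabBuilder, extend_vocab
from torchmoji.word_generator import WordGenerator

new_sentences = [u'smile in the face of adversity', u'keep calm and carry on']
vb = VocabBuilder(WordGenerator(new_sentences))
vb.count_all_words()

vocab = {'CUSTOM_MASK': 0, 'CUSTOM_UNKNOWN': 1, 'smile': 2}
added = extend_vocab(vocab, vb, max_tokens=5)
print(added)   # up to 5 unseen words appended after the existing indices
print(vocab)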
torchmoji/filter_input.py
ADDED
@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division
import codecs
import csv
import numpy as np
from emoji import UNICODE_EMOJI

def read_english(path="english_words.txt", add_emojis=True):
    # read english words for filtering (includes emojis as part of set)
    english = set()
    with codecs.open(path, "r", "utf-8") as f:
        for line in f:
            line = line.strip().lower().replace('\n', '')
            if len(line):
                english.add(line)
    if add_emojis:
        for e in UNICODE_EMOJI:
            english.add(e)
    return english

def read_wanted_emojis(path="wanted_emojis.csv"):
    emojis = []
    with open(path, 'rb') as f:
        reader = csv.reader(f)
        for line in reader:
            line = line[0].strip().replace('\n', '')
            line = line.decode('unicode-escape')
            emojis.append(line)
    return emojis

def read_non_english_users(path="unwanted_users.npz"):
    try:
        neu_set = set(np.load(path)['userids'])
    except IOError:
        neu_set = set()
    return neu_set
torchmoji/filter_utils.py
ADDED
@@ -0,0 +1,191 @@
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
from __future__ import print_function, division, unicode_literals
|
4 |
+
import sys
|
5 |
+
import re
|
6 |
+
import string
|
7 |
+
import emoji
|
8 |
+
from itertools import groupby
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
from torchmoji.tokenizer import RE_MENTION, RE_URL
|
12 |
+
from torchmoji.global_variables import SPECIAL_TOKENS
|
13 |
+
|
14 |
+
IS_PYTHON2 = int(sys.version[0]) == 2
|
15 |
+
chr_ = unichr if IS_PYTHON2 else chr
|
16 |
+
|
17 |
+
AtMentionRegex = re.compile(RE_MENTION)
|
18 |
+
urlRegex = re.compile(RE_URL)
|
19 |
+
|
20 |
+
# from http://bit.ly/2rdjgjE (UTF-8 encodings and Unicode chars)
|
21 |
+
VARIATION_SELECTORS = [ '\ufe00',
|
22 |
+
'\ufe01',
|
23 |
+
'\ufe02',
|
24 |
+
'\ufe03',
|
25 |
+
'\ufe04',
|
26 |
+
'\ufe05',
|
27 |
+
'\ufe06',
|
28 |
+
'\ufe07',
|
29 |
+
'\ufe08',
|
30 |
+
'\ufe09',
|
31 |
+
'\ufe0a',
|
32 |
+
'\ufe0b',
|
33 |
+
'\ufe0c',
|
34 |
+
'\ufe0d',
|
35 |
+
'\ufe0e',
|
36 |
+
'\ufe0f']
|
37 |
+
|
38 |
+
# from https://stackoverflow.com/questions/92438/stripping-non-printable-characters-from-a-string-in-python
|
39 |
+
ALL_CHARS = (chr_(i) for i in range(sys.maxunicode))
|
40 |
+
CONTROL_CHARS = ''.join(map(chr_, list(range(0,32)) + list(range(127,160))))
|
41 |
+
CONTROL_CHAR_REGEX = re.compile('[%s]' % re.escape(CONTROL_CHARS))
|
42 |
+
|
43 |
+
def is_special_token(word):
|
44 |
+
equal = False
|
45 |
+
for spec in SPECIAL_TOKENS:
|
46 |
+
if word == spec:
|
47 |
+
equal = True
|
48 |
+
break
|
49 |
+
return equal
|
50 |
+
|
51 |
+
def mostly_english(words, english, pct_eng_short=0.5, pct_eng_long=0.6, ignore_special_tokens=True, min_length=2):
|
52 |
+
""" Ensure text meets threshold for containing English words """
|
53 |
+
|
54 |
+
n_words = 0
|
55 |
+
n_english = 0
|
56 |
+
|
57 |
+
if english is None:
|
58 |
+
return True, 0, 0
|
59 |
+
|
60 |
+
for w in words:
|
61 |
+
if len(w) < min_length:
|
62 |
+
continue
|
63 |
+
if punct_word(w):
|
64 |
+
continue
|
65 |
+
if ignore_special_tokens and is_special_token(w):
|
66 |
+
continue
|
67 |
+
n_words += 1
|
68 |
+
if w in english:
|
69 |
+
n_english += 1
|
70 |
+
|
71 |
+
if n_words < 2:
|
72 |
+
return True, n_words, n_english
|
73 |
+
if n_words < 5:
|
74 |
+
valid_english = n_english >= n_words * pct_eng_short
|
75 |
+
else:
|
76 |
+
valid_english = n_english >= n_words * pct_eng_long
|
77 |
+
return valid_english, n_words, n_english
|
78 |
+
|
79 |
+
def correct_length(words, min_words, max_words, ignore_special_tokens=True):
|
80 |
+
""" Ensure text meets threshold for containing English words
|
81 |
+
and that it's within the min and max words limits. """
|
82 |
+
|
83 |
+
if min_words is None:
|
84 |
+
min_words = 0
|
85 |
+
|
86 |
+
if max_words is None:
|
87 |
+
max_words = 99999
|
88 |
+
|
89 |
+
n_words = 0
|
90 |
+
for w in words:
|
91 |
+
if punct_word(w):
|
92 |
+
continue
|
93 |
+
if ignore_special_tokens and is_special_token(w):
|
94 |
+
continue
|
95 |
+
n_words += 1
|
96 |
+
valid = min_words <= n_words and n_words <= max_words
|
97 |
+
return valid
|
98 |
+
|
99 |
+
def punct_word(word, punctuation=string.punctuation):
|
100 |
+
return all([True if c in punctuation else False for c in word])
|
101 |
+
|
102 |
+
def load_non_english_user_set():
|
103 |
+
non_english_user_set = set(np.load('uids.npz')['data'])
|
104 |
+
return non_english_user_set
|
105 |
+
|
106 |
+
def non_english_user(userid, non_english_user_set):
|
107 |
+
neu_found = int(userid) in non_english_user_set
|
108 |
+
return neu_found
|
109 |
+
|
110 |
+
def separate_emojis_and_text(text):
|
111 |
+
emoji_chars = []
|
112 |
+
non_emoji_chars = []
|
113 |
+
for c in text:
|
114 |
+
if c in emoji.UNICODE_EMOJI:
|
115 |
+
emoji_chars.append(c)
|
116 |
+
else:
|
117 |
+
non_emoji_chars.append(c)
|
118 |
+
return ''.join(emoji_chars), ''.join(non_emoji_chars)
|
119 |
+
|
120 |
+
def extract_emojis(text, wanted_emojis):
|
121 |
+
text = remove_variation_selectors(text)
|
122 |
+
return [c for c in text if c in wanted_emojis]
|
123 |
+
|
124 |
+
def remove_variation_selectors(text):
|
125 |
+
""" Remove styling glyph variants for Unicode characters.
|
126 |
+
For instance, remove skin color from emojis.
|
127 |
+
"""
|
128 |
+
for var in VARIATION_SELECTORS:
|
129 |
+
text = text.replace(var, '')
|
130 |
+
return text
|
131 |
+
|
132 |
+
def shorten_word(word):
|
133 |
+
""" Shorten groupings of 3+ identical consecutive chars to 2, e.g. '!!!!' --> '!!'
|
134 |
+
"""
|
135 |
+
|
136 |
+
# only shorten ASCII words
|
137 |
+
try:
|
138 |
+
word.decode('ascii')
|
139 |
+
except (UnicodeDecodeError, UnicodeEncodeError, AttributeError) as e:
|
140 |
+
return word
|
141 |
+
|
142 |
+
# must have at least 3 char to be shortened
|
143 |
+
if len(word) < 3:
|
144 |
+
return word
|
145 |
+
|
146 |
+
# find groups of 3+ consecutive letters
|
147 |
+
letter_groups = [list(g) for k, g in groupby(word)]
|
148 |
+
triple_or_more = [''.join(g) for g in letter_groups if len(g) >= 3]
|
149 |
+
if len(triple_or_more) == 0:
|
150 |
+
return word
|
151 |
+
|
152 |
+
# replace letters to find the short word
|
153 |
+
short_word = word
|
154 |
+
for trip in triple_or_more:
|
155 |
+
short_word = short_word.replace(trip, trip[0]*2)
|
156 |
+
|
157 |
+
return short_word
|
158 |
+
|
159 |
+
def detect_special_tokens(word):
|
160 |
+
try:
|
161 |
+
int(word)
|
162 |
+
word = SPECIAL_TOKENS[4]
|
163 |
+
except ValueError:
|
164 |
+
if AtMentionRegex.findall(word):
|
165 |
+
word = SPECIAL_TOKENS[2]
|
166 |
+
elif urlRegex.findall(word):
|
167 |
+
word = SPECIAL_TOKENS[3]
|
168 |
+
return word
|
169 |
+
|
170 |
+
def process_word(word):
|
171 |
+
""" Shortening and converting the word to a special token if relevant.
|
172 |
+
"""
|
173 |
+
word = shorten_word(word)
|
174 |
+
word = detect_special_tokens(word)
|
175 |
+
return word
|
176 |
+
|
177 |
+
def remove_control_chars(text):
|
178 |
+
return CONTROL_CHAR_REGEX.sub('', text)
|
179 |
+
|
180 |
+
def convert_nonbreaking_space(text):
|
181 |
+
# ugly hack handling non-breaking space no matter how badly it's been encoded in the input
|
182 |
+
for r in ['\\\\xc2', '\\xc2', '\xc2', '\\\\xa0', '\\xa0', '\xa0']:
|
183 |
+
text = text.replace(r, ' ')
|
184 |
+
return text
|
185 |
+
|
186 |
+
def convert_linebreaks(text):
|
187 |
+
# ugly hack handling line breaks no matter how badly they've been encoded in the input
|
188 |
+
# space around to ensure proper tokenization
|
189 |
+
for r in ['\\\\n', '\\n', '\n', '\\\\r', '\\r', '\r', '<br>']:
|
190 |
+
text = text.replace(r, ' ' + SPECIAL_TOKENS[5] + ' ')
|
191 |
+
return text
|
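A quick usage sketch of the filtering helpers above; the special-token names mentioned in the comments are assumed to follow SPECIAL_TOKENS in torchmoji/global_variables.py.

# Illustrative sketch of the filtering helpers.
from torchmoji.filter_utils import punct_word, process_word

print(punct_word('!!!'))          # True: the word is punctuation only
print(process_word('@someone'))   # mentions collapse to the CUSTOM_AT special token
print(process_word('12345'))      # plain numbers become the CUSTOM_NUMBER special token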
torchmoji/finetuning.py
ADDED
@@ -0,0 +1,661 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
""" Finetuning functions for doing transfer learning to new datasets.
|
3 |
+
"""
|
4 |
+
from __future__ import print_function
|
5 |
+
|
6 |
+
import sys
|
7 |
+
import uuid
|
8 |
+
from time import sleep
|
9 |
+
from io import open
|
10 |
+
|
11 |
+
import math
|
12 |
+
import pickle
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.optim as optim
|
18 |
+
from torch.autograd import Variable
|
19 |
+
from torch.utils.data import Dataset, DataLoader
|
20 |
+
from torch.utils.data.sampler import BatchSampler, SequentialSampler
|
21 |
+
from torch.nn.utils import clip_grad_norm
|
22 |
+
|
23 |
+
from sklearn.metrics import f1_score
|
24 |
+
|
25 |
+
from torchmoji.global_variables import (FINETUNING_METHODS,
|
26 |
+
FINETUNING_METRICS,
|
27 |
+
WEIGHTS_DIR)
|
28 |
+
from torchmoji.tokenizer import tokenize
|
29 |
+
from torchmoji.sentence_tokenizer import SentenceTokenizer
|
30 |
+
|
31 |
+
IS_PYTHON2 = int(sys.version[0]) == 2
|
32 |
+
unicode_ = unicode if IS_PYTHON2 else str
|
33 |
+
|
34 |
+
def load_benchmark(path, vocab, extend_with=0):
|
35 |
+
""" Loads the given benchmark dataset.
|
36 |
+
|
37 |
+
Tokenizes the texts using the provided vocabulary, extending it with
|
38 |
+
words from the training dataset if extend_with > 0. Splits them into
|
39 |
+
three lists: training, validation and testing (in that order).
|
40 |
+
|
41 |
+
Also calculates the maximum length of the texts and the
|
42 |
+
suggested batch_size.
|
43 |
+
|
44 |
+
# Arguments:
|
45 |
+
path: Path to the dataset to be loaded.
|
46 |
+
vocab: Vocabulary to be used for tokenizing texts.
|
47 |
+
extend_with: If > 0, the vocabulary will be extended with up to
|
48 |
+
extend_with tokens from the training set before tokenizing.
|
49 |
+
|
50 |
+
# Returns:
|
51 |
+
A dictionary with the following fields:
|
52 |
+
texts: List of three lists, containing tokenized inputs for
|
53 |
+
training, validation and testing (in that order).
|
54 |
+
labels: List of three lists, containing labels for training,
|
55 |
+
validation and testing (in that order).
|
56 |
+
added: Number of tokens added to the vocabulary.
|
57 |
+
batch_size: Batch size.
|
58 |
+
maxlen: Maximum length of an input.
|
59 |
+
"""
|
60 |
+
# Pre-processing dataset
|
61 |
+
with open(path, 'rb') as dataset:
|
62 |
+
if IS_PYTHON2:
|
63 |
+
data = pickle.load(dataset)
|
64 |
+
else:
|
65 |
+
data = pickle.load(dataset, fix_imports=True)
|
66 |
+
|
67 |
+
# Decode data
|
68 |
+
try:
|
69 |
+
texts = [unicode_(x) for x in data['texts']]
|
70 |
+
except UnicodeDecodeError:
|
71 |
+
texts = [x.decode('utf-8') for x in data['texts']]
|
72 |
+
|
73 |
+
# Extract labels
|
74 |
+
labels = [x['label'] for x in data['info']]
|
75 |
+
|
76 |
+
batch_size, maxlen = calculate_batchsize_maxlen(texts)
|
77 |
+
|
78 |
+
st = SentenceTokenizer(vocab, maxlen)
|
79 |
+
|
80 |
+
# Split up dataset. Extend the existing vocabulary with up to extend_with
|
81 |
+
# tokens from the training dataset.
|
82 |
+
texts, labels, added = st.split_train_val_test(texts,
|
83 |
+
labels,
|
84 |
+
[data['train_ind'],
|
85 |
+
data['val_ind'],
|
86 |
+
data['test_ind']],
|
87 |
+
extend_with=extend_with)
|
88 |
+
return {'texts': texts,
|
89 |
+
'labels': labels,
|
90 |
+
'added': added,
|
91 |
+
'batch_size': batch_size,
|
92 |
+
'maxlen': maxlen}
|
93 |
+
|
94 |
+
|
95 |
+
def calculate_batchsize_maxlen(texts):
|
96 |
+
""" Calculates the maximum length in the provided texts and a suitable
|
97 |
+
texts. Maxlen is the 80th percentile of the token counts, rounded up to the nearest multiple of ten.
|
98 |
+
|
99 |
+
# Arguments:
|
100 |
+
texts: List of inputs.
|
101 |
+
|
102 |
+
# Returns:
|
103 |
+
Batch size,
|
104 |
+
max length
|
105 |
+
"""
|
106 |
+
def roundup(x):
|
107 |
+
return int(math.ceil(x / 10.0)) * 10
|
108 |
+
|
109 |
+
# Calculate max length of sequences considered
|
110 |
+
# Adjust batch_size accordingly to prevent GPU overflow
|
111 |
+
lengths = [len(tokenize(t)) for t in texts]
|
112 |
+
maxlen = roundup(np.percentile(lengths, 80.0))
|
113 |
+
batch_size = 250 if maxlen <= 100 else 50
|
114 |
+
return batch_size, maxlen
|
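A quick sanity check of the heuristic above (illustrative only, assuming the package is importable): with 85 short texts and 15 long ones, the 80th percentile of token counts is 5, so maxlen rounds up to 10 and the batch size stays at 250.
from torchmoji.finetuning import calculate_batchsize_maxlen
texts = [u'this is a short example'] * 85 + [u' '.join([u'word'] * 200)] * 15
batch_size, maxlen = calculate_batchsize_maxlen(texts)
print(batch_size, maxlen)  # expected: 250 10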
115 |
+
|
116 |
+
|
117 |
+
|
118 |
+
def freeze_layers(model, unfrozen_types=[], unfrozen_keyword=None):
|
119 |
+
""" Freezes all layers in the given model, except for ones that are
|
120 |
+
explicitly specified to not be frozen.
|
121 |
+
|
122 |
+
# Arguments:
|
123 |
+
model: Model whose layers should be modified.
|
124 |
+
unfrozen_types: List of layer types which shouldn't be frozen.
|
125 |
+
unfrozen_keyword: Name keywords of layers that shouldn't be frozen.
|
126 |
+
|
127 |
+
# Returns:
|
128 |
+
Model with the selected layers frozen.
|
129 |
+
"""
|
130 |
+
# Get trainable modules
|
131 |
+
trainable_modules = [(n, m) for n, m in model.named_children() if len([id(p) for p in m.parameters()]) != 0]
|
132 |
+
for name, module in trainable_modules:
|
133 |
+
trainable = (any(typ in str(module) for typ in unfrozen_types) or
|
134 |
+
(unfrozen_keyword is not None and unfrozen_keyword.lower() in name.lower()))
|
135 |
+
change_trainable(module, trainable, verbose=False)
|
136 |
+
return model
|
137 |
+
|
138 |
+
|
139 |
+
def change_trainable(module, trainable, verbose=False):
|
140 |
+
""" Helper method that freezes or unfreezes a given layer.
|
141 |
+
|
142 |
+
# Arguments:
|
143 |
+
module: Module to be modified.
|
144 |
+
trainable: Whether the layer should be frozen or unfrozen.
|
145 |
+
verbose: Verbosity flag.
|
146 |
+
"""
|
147 |
+
|
148 |
+
if verbose: print('Changing MODULE', module, 'to trainable =', trainable)
|
149 |
+
for name, param in module.named_parameters():
|
150 |
+
if verbose: print('Setting weight', name, 'to trainable =', trainable)
|
151 |
+
param.requires_grad = trainable
|
152 |
+
|
153 |
+
if verbose:
|
154 |
+
action = 'Unfroze' if trainable else 'Froze'
|
155 |
+
if verbose: print("{} {}".format(action, module))
|
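An illustrative sketch of the two freezing helpers on a toy model. The module names 'embed' and 'output_layer' below are just stand-ins chosen for this example, not the real torchMoji architecture.
from collections import OrderedDict
import torch.nn as nn
from torchmoji.finetuning import freeze_layers
toy = nn.Sequential(OrderedDict([('embed', nn.Embedding(10, 4)),
                                 ('output_layer', nn.Linear(4, 2))]))
toy = freeze_layers(toy, unfrozen_keyword='output_layer')
for name, module in toy.named_children():
    # expected: embed False, output_layer True
    print(name, all(p.requires_grad for p in module.parameters()))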
156 |
+
|
157 |
+
|
158 |
+
def find_f1_threshold(model, val_gen, test_gen, average='binary'):
|
159 |
+
""" Choose a threshold for F1 based on the validation dataset
|
160 |
+
(see https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4442797/
|
161 |
+
for details on why to find another threshold than simply 0.5)
|
162 |
+
|
163 |
+
# Arguments:
|
164 |
+
model: pyTorch model
|
165 |
+
val_gen: Validation set dataloader.
|
166 |
+
test_gen: Testing set dataloader.
|
167 |
+
|
168 |
+
# Returns:
|
169 |
+
F1 score for the given data and
|
170 |
+
the corresponding F1 threshold
|
171 |
+
"""
|
172 |
+
thresholds = np.arange(0.01, 0.5, step=0.01)
|
173 |
+
f1_scores = []
|
174 |
+
|
175 |
+
model.eval()
|
176 |
+
val_out = [(y, model(X)) for X, y in val_gen]
|
177 |
+
y_val, y_pred_val = (list(t) for t in zip(*val_out))
|
178 |
+
|
179 |
+
test_out = [(y, model(X)) for X, y in test_gen]
|
180 |
+
y_test, y_pred_test = (list(t) for t in zip(*test_out))
|
181 |
+
|
182 |
+
for t in thresholds:
|
183 |
+
y_pred_val_ind = (y_pred_val > t)
|
184 |
+
f1_val = f1_score(y_val, y_pred_val_ind, average=average)
|
185 |
+
f1_scores.append(f1_val)
|
186 |
+
|
187 |
+
best_t = thresholds[np.argmax(f1_scores)]
|
188 |
+
y_pred_ind = (y_pred_test > best_t)
|
189 |
+
f1_test = f1_score(y_test, y_pred_ind, average=average)
|
190 |
+
return f1_test, best_t
|
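The same threshold-selection idea on synthetic scores, purely to illustrate the procedure (sweep thresholds on validation data, pick the best one, report F1 on test data). This sketch uses plain numpy arrays instead of a model and data loaders.
import numpy as np
from sklearn.metrics import f1_score
rng = np.random.RandomState(0)
y_val, p_val = rng.randint(0, 2, 200), rng.rand(200)    # toy labels and predicted probabilities
y_test, p_test = rng.randint(0, 2, 200), rng.rand(200)
thresholds = np.arange(0.01, 0.5, step=0.01)
f1_scores = [f1_score(y_val, p_val > t) for t in thresholds]
best_t = thresholds[np.argmax(f1_scores)]
print('F1 on test at best threshold:', f1_score(y_test, p_test > best_t), best_t)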
191 |
+
|
192 |
+
|
193 |
+
def finetune(model, texts, labels, nb_classes, batch_size, method,
|
194 |
+
metric='acc', epoch_size=5000, nb_epochs=1000, embed_l2=1E-6,
|
195 |
+
verbose=1):
|
196 |
+
""" Compiles and finetunes the given pytorch model.
|
197 |
+
|
198 |
+
# Arguments:
|
199 |
+
model: Model to be finetuned
|
200 |
+
texts: List of three lists, containing tokenized inputs for training,
|
201 |
+
validation and testing (in that order).
|
202 |
+
labels: List of three lists, containing labels for training,
|
203 |
+
validation and testing (in that order).
|
204 |
+
nb_classes: Number of classes in the dataset.
|
205 |
+
batch_size: Batch size.
|
206 |
+
method: Finetuning method to be used. For available methods, see
|
207 |
+
FINETUNING_METHODS in global_variables.py.
|
208 |
+
metric: Evaluation metric to be used. For available metrics, see
|
209 |
+
FINETUNING_METRICS in global_variables.py.
|
210 |
+
epoch_size: Number of samples in an epoch.
|
211 |
+
nb_epochs: Number of epochs. Doesn't matter much as early stopping is used.
|
212 |
+
embed_l2: L2 regularization for the embedding layer.
|
213 |
+
verbose: Verbosity flag.
|
214 |
+
|
215 |
+
# Returns:
|
216 |
+
Model after finetuning,
|
217 |
+
score after finetuning using the provided metric.
|
218 |
+
"""
|
219 |
+
|
220 |
+
if method not in FINETUNING_METHODS:
|
221 |
+
raise ValueError('ERROR (finetune): Invalid method parameter. '
|
222 |
+
'Available options: {}'.format(FINETUNING_METHODS))
|
223 |
+
if metric not in FINETUNING_METRICS:
|
224 |
+
raise ValueError('ERROR (finetune): Invalid metric parameter. '
|
225 |
+
'Available options: {}'.format(FINETUNING_METRICS))
|
226 |
+
|
227 |
+
train_gen = get_data_loader(texts[0], labels[0], batch_size,
|
228 |
+
extended_batch_sampler=True, epoch_size=epoch_size)
|
229 |
+
val_gen = get_data_loader(texts[1], labels[1], batch_size,
|
230 |
+
extended_batch_sampler=False)
|
231 |
+
test_gen = get_data_loader(texts[2], labels[2], batch_size,
|
232 |
+
extended_batch_sampler=False)
|
233 |
+
|
234 |
+
checkpoint_path = '{}/torchmoji-checkpoint-{}.bin' \
|
235 |
+
.format(WEIGHTS_DIR, str(uuid.uuid4()))
|
236 |
+
|
237 |
+
if method in ['last', 'new']:
|
238 |
+
lr = 0.001
|
239 |
+
elif method in ['full', 'chain-thaw']:
|
240 |
+
lr = 0.0001
|
241 |
+
|
242 |
+
loss_op = nn.BCEWithLogitsLoss() if nb_classes <= 2 \
|
243 |
+
else nn.CrossEntropyLoss()
|
244 |
+
|
245 |
+
# Freeze layers if using last
|
246 |
+
if method == 'last':
|
247 |
+
model = freeze_layers(model, unfrozen_keyword='output_layer')
|
248 |
+
|
249 |
+
# Define optimizer, for chain-thaw we define it later (after freezing)
|
250 |
+
if method == 'last':
|
251 |
+
adam = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
|
252 |
+
elif method in ['full', 'new']:
|
253 |
+
# Add L2 regulation on embeddings only
|
254 |
+
embed_params_id = [id(p) for p in model.embed.parameters()]
|
255 |
+
output_layer_params_id = [id(p) for p in model.output_layer.parameters()]
|
256 |
+
base_params = [p for p in model.parameters()
|
257 |
+
if id(p) not in embed_params_id and id(p) not in output_layer_params_id and p.requires_grad]
|
258 |
+
embed_params = [p for p in model.parameters() if id(p) in embed_params_id and p.requires_grad]
|
259 |
+
output_layer_params = [p for p in model.parameters() if id(p) in output_layer_params_id and p.requires_grad]
|
260 |
+
adam = optim.Adam([
|
261 |
+
{'params': base_params},
|
262 |
+
{'params': embed_params, 'weight_decay': embed_l2},
|
263 |
+
{'params': output_layer_params, 'lr': 0.001},
|
264 |
+
], lr=lr)
|
265 |
+
|
266 |
+
# Training
|
267 |
+
if verbose:
|
268 |
+
print('Method: {}'.format(method))
|
269 |
+
print('Metric: {}'.format(metric))
|
270 |
+
print('Classes: {}'.format(nb_classes))
|
271 |
+
|
272 |
+
if method == 'chain-thaw':
|
273 |
+
result = chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op, embed_l2=embed_l2,
|
274 |
+
evaluate=metric, verbose=verbose)
|
275 |
+
else:
|
276 |
+
result = tune_trainable(model, loss_op, adam, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path,
|
277 |
+
evaluate=metric, verbose=verbose)
|
278 |
+
return model, result
|
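A hedged usage sketch of the finetune entry point, modelled on examples/finetune_youtube_last.py elsewhere in this commit. torchmoji_transfer, PRETRAINED_PATH and the data/model paths are assumptions about other files in the repository, not guarantees.
import json
from torchmoji.finetuning import load_benchmark, finetune
from torchmoji.model_def import torchmoji_transfer        # assumed helper from this commit
from torchmoji.global_variables import PRETRAINED_PATH    # assumed path constant
with open('model/vocabulary.json', 'r') as f:
    vocab = json.load(f)
data = load_benchmark('data/SS-Youtube/raw.pickle', vocab)    # a binary benchmark
model = torchmoji_transfer(2, weight_path=PRETRAINED_PATH)    # 2 classes
model, acc = finetune(model, data['texts'], data['labels'], 2,
                      data['batch_size'], method='last')
print('Test accuracy:', acc)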
279 |
+
|
280 |
+
|
281 |
+
def tune_trainable(model, loss_op, optim_op, train_gen, val_gen, test_gen,
|
282 |
+
nb_epochs, checkpoint_path, patience=5, evaluate='acc',
|
283 |
+
verbose=2):
|
284 |
+
""" Finetunes the given model using the accuracy measure.
|
285 |
+
|
286 |
+
# Arguments:
|
287 |
+
model: Model to be finetuned.
|
288 |
+
loss_op: Loss function to be used (e.g. BCEWithLogitsLoss).
|
289 |
+
optim_op: Optimizer to be used (e.g. Adam).
|
290 |
+
train_gen: Training data iterator (DataLoader).
|
291 |
+
val_gen: Validation data iterator (DataLoader).
|
292 |
+
test_gen: Testing data iterator (DataLoader).
|
293 |
+
nb_epochs: Number of epochs.
|
294 |
+
Early stopping usually ends training much earlier.
|
295 |
+
checkpoint_path: Filepath where weights will be checkpointed to
|
296 |
+
during training. This file will be rewritten by the function.
|
297 |
+
patience: Patience for the early stopping.
|
298 |
+
evaluate: Evaluation method to use. Can be 'acc' or 'weighted_f1'.
|
299 |
+
verbose: Verbosity flag.
|
300 |
+
|
301 |
+
# Returns:
|
302 |
+
Test-set score of the trained model, measured with the selected metric.
|
303 |
+
"""
|
304 |
+
if verbose:
|
305 |
+
print("Trainable weights: {}".format([n for n, p in model.named_parameters() if p.requires_grad]))
|
306 |
+
print("Training...")
|
307 |
+
if evaluate == 'acc':
|
308 |
+
print("Evaluation on test set prior training:", evaluate_using_acc(model, test_gen))
|
309 |
+
elif evaluate == 'weighted_f1':
|
310 |
+
print("Evaluation on test set prior training:", evaluate_using_weighted_f1(model, test_gen, val_gen))
|
311 |
+
|
312 |
+
fit_model(model, loss_op, optim_op, train_gen, val_gen, nb_epochs, checkpoint_path, patience)
|
313 |
+
|
314 |
+
# Reload the best weights found to avoid overfitting
|
315 |
+
# Wait a bit to allow proper closing of weights file
|
316 |
+
sleep(1)
|
317 |
+
model.load_state_dict(torch.load(checkpoint_path))
|
318 |
+
if verbose >= 2:
|
319 |
+
print("Loaded weights from {}".format(checkpoint_path))
|
320 |
+
|
321 |
+
if evaluate == 'acc':
|
322 |
+
return evaluate_using_acc(model, test_gen)
|
323 |
+
elif evaluate == 'weighted_f1':
|
324 |
+
return evaluate_using_weighted_f1(model, test_gen, val_gen)
|
325 |
+
|
326 |
+
|
327 |
+
def evaluate_using_weighted_f1(model, test_gen, val_gen):
|
328 |
+
""" Evaluation function using macro weighted F1 score.
|
329 |
+
|
330 |
+
# Arguments:
|
331 |
+
model: Model to be evaluated.
|
332 |
+
test_gen: Testing data iterator (DataLoader).
|
333 |
+
val_gen: Validation data iterator (DataLoader). The validation set
|
334 |
+
is used to pick the prediction threshold
|
335 |
+
(see find_f1_threshold) and the F1 score is then
|
336 |
+
computed on the test set with that threshold.
|
337 |
+
|
338 |
+
# Returns:
|
339 |
+
Weighted F1 score of the given model.
|
340 |
+
"""
|
341 |
+
# Evaluate on test and val data
|
342 |
+
f1_test, _ = find_f1_threshold(model, val_gen, test_gen, average='weighted')
|
343 |
+
return f1_test
|
344 |
+
|
345 |
+
|
346 |
+
def evaluate_using_acc(model, test_gen):
|
347 |
+
""" Evaluation function using accuracy.
|
348 |
+
|
349 |
+
# Arguments:
|
350 |
+
model: Model to be evaluated.
|
351 |
+
test_gen: Testing data iterator (DataLoader)
|
352 |
+
|
353 |
+
# Returns:
|
354 |
+
Accuracy of the given model.
|
355 |
+
"""
|
356 |
+
|
357 |
+
# Validate on test_data
|
358 |
+
model.eval()
|
359 |
+
correct_count = 0.0
|
360 |
+
total_y = sum(len(y) for _, y in test_gen)
|
361 |
+
for i, data in enumerate(test_gen):
|
362 |
+
x, y = data
|
363 |
+
outs = model(x)
|
364 |
+
pred = (outs >= 0).long()
|
365 |
+
added_counts = (pred == y).double().sum()
|
366 |
+
correct_count += added_counts
|
367 |
+
return correct_count/total_y
|
368 |
+
|
369 |
+
|
370 |
+
def chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op,
|
371 |
+
patience=5, initial_lr=0.001, next_lr=0.0001, embed_l2=1E-6, evaluate='acc', verbose=1):
|
372 |
+
""" Finetunes given model using chain-thaw and evaluates using accuracy.
|
373 |
+
|
374 |
+
# Arguments:
|
375 |
+
model: Model to be finetuned.
|
376 |
+
train_gen: Training data iterator (DataLoader).
|
377 |
+
val_gen: Validation data iterator (DataLoader).
|
378 |
+
test_gen: Testing data iterator (DataLoader).
|
379 |
+
nb_epochs: Number of epochs.
|
380 |
+
checkpoint_path: Filepath where weights will be checkpointed to
|
381 |
+
during training. This file will be rewritten by the function.
|
382 |
+
loss_op: Loss function to be used during training.
|
383 |
+
patience: Number of epochs with no improvement on the validation
|
384 |
+
loss before a training step is stopped early.
|
385 |
+
initial_lr: Initial learning rate. Will only be used for the first
|
386 |
+
training step (i.e. the output_layer layer)
|
387 |
+
next_lr: Learning rate for every subsequent step.
|
388 |
+
embed_l2: L2 regularization for the embedding layer.
|
389 |
+
verbose: Verbosity flag.
|
390 |
+
evaluate: Evaluation method to use. Can be 'acc' or 'weighted_f1'.
|
391 |
+
|
392 |
+
# Returns:
|
393 |
+
Accuracy of the finetuned model.
|
394 |
+
"""
|
395 |
+
if verbose:
|
396 |
+
print('Training..')
|
397 |
+
|
398 |
+
# Train using chain-thaw
|
399 |
+
train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs, checkpoint_path,
|
400 |
+
initial_lr, next_lr, embed_l2, verbose)
|
401 |
+
|
402 |
+
if evaluate == 'acc':
|
403 |
+
return evaluate_using_acc(model, test_gen)
|
404 |
+
elif evaluate == 'weighted_f1':
|
405 |
+
return evaluate_using_weighted_f1(model, test_gen, val_gen)
|
406 |
+
|
407 |
+
|
408 |
+
def train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs, checkpoint_path,
|
409 |
+
initial_lr=0.001, next_lr=0.0001, embed_l2=1E-6, verbose=1):
|
410 |
+
""" Finetunes model using the chain-thaw method.
|
411 |
+
|
412 |
+
This is done as follows:
|
413 |
+
1) Freeze every layer except the last (output_layer) layer and train it.
|
414 |
+
2) Freeze every layer except the first layer and train it.
|
415 |
+
3) Freeze every layer except the second etc., until the second last layer.
|
416 |
+
4) Unfreeze all layers and train entire model.
|
417 |
+
|
418 |
+
# Arguments:
|
419 |
+
model: Model to be trained.
|
420 |
+
train_gen: Training sample generator.
|
421 |
+
val_gen: Validation data iterator (DataLoader).
|
422 |
+
loss_op: Loss function to be used.
|
423 |
+
patience: Patience for the early stopping.
|
424 |
+
nb_epochs: Number of epochs.
|
425 |
+
checkpoint_path: Filepath where weights will be checkpointed to
|
426 |
+
during training. This file will be rewritten by the function.
|
427 |
+
embed_l2: L2 regularization for the embedding layer.
|
428 |
+
initial_lr: Initial learning rate. Will only be used for the first
|
429 |
+
training step (i.e. the output_layer layer)
|
430 |
+
next_lr: Learning rate for every subsequent step.
|
431 |
+
verbose: Verbosity flag.
|
432 |
+
"""
|
433 |
+
# Get trainable layers
|
434 |
+
layers = [m for m in model.children() if len([id(p) for p in m.parameters()]) != 0]
|
435 |
+
|
436 |
+
# Bring last layer to front
|
437 |
+
layers.insert(0, layers.pop(len(layers) - 1))
|
438 |
+
|
439 |
+
# Add None to the end to signify finetuning all layers
|
440 |
+
layers.append(None)
|
441 |
+
|
442 |
+
lr = None
|
443 |
+
# Finetune each layer one by one and finetune all of them at once
|
444 |
+
# at the end
|
445 |
+
for layer in layers:
|
446 |
+
if lr is None:
|
447 |
+
lr = initial_lr
|
448 |
+
elif lr == initial_lr:
|
449 |
+
lr = next_lr
|
450 |
+
|
451 |
+
# Freeze all except current layer
|
452 |
+
for _layer in layers:
|
453 |
+
if _layer is not None:
|
454 |
+
trainable = _layer == layer or layer is None
|
455 |
+
change_trainable(_layer, trainable=trainable, verbose=False)
|
456 |
+
|
457 |
+
# Verify we froze the right layers
|
458 |
+
for _layer in model.children():
|
459 |
+
assert all(p.requires_grad == (_layer == layer) for p in _layer.parameters()) or layer is None
|
460 |
+
|
461 |
+
if verbose:
|
462 |
+
if layer is None:
|
463 |
+
print('Finetuning all layers')
|
464 |
+
else:
|
465 |
+
print('Finetuning {}'.format(layer))
|
466 |
+
|
467 |
+
special_params = [id(p) for p in model.embed.parameters()]
|
468 |
+
base_params = [p for p in model.parameters() if id(p) not in special_params and p.requires_grad]
|
469 |
+
embed_parameters = [p for p in model.parameters() if id(p) in special_params and p.requires_grad]
|
470 |
+
adam = optim.Adam([
|
471 |
+
{'params': base_params},
|
472 |
+
{'params': embed_parameters, 'weight_decay': embed_l2},
|
473 |
+
], lr=lr)
|
474 |
+
|
475 |
+
fit_model(model, loss_op, adam, train_gen, val_gen, nb_epochs,
|
476 |
+
checkpoint_path, patience)
|
477 |
+
|
478 |
+
# Reload the best weights found to avoid overfitting
|
479 |
+
# Wait a bit to allow proper closing of weights file
|
480 |
+
sleep(1)
|
481 |
+
model.load_state_dict(torch.load(checkpoint_path))
|
482 |
+
if verbose >= 2:
|
483 |
+
print("Loaded weights from {}".format(checkpoint_path))
|
484 |
+
|
485 |
+
def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
|
486 |
+
checkpoint_path, patience):
|
487 |
+
""" Analog to Keras fit_generator function.
|
488 |
+
|
489 |
+
# Arguments:
|
490 |
+
model: Model to be finetuned.
|
491 |
+
loss_op: loss operation (BCEWithLogitsLoss or CrossEntropy for e.g.)
|
492 |
+
optim_op: optimization operation (Adam e.g.)
|
493 |
+
train_gen: Training data iterator (DataLoader)
|
494 |
+
val_gen: Validation data iterator (DataLoader)
|
495 |
+
epochs: Number of epochs.
|
496 |
+
checkpoint_path: Filepath where weights will be checkpointed to
|
497 |
+
during training. This file will be rewritten by the function.
|
498 |
+
patience: Patience for the early stopping.
|
499 |
+
verbose: Verbosity flag.
|
500 |
+
|
501 |
+
# Returns:
|
502 |
+
Nothing. The best weights (by validation loss) are saved to checkpoint_path.
|
503 |
+
"""
|
504 |
+
# Save original checkpoint
|
505 |
+
torch.save(model.state_dict(), checkpoint_path)
|
506 |
+
|
507 |
+
model.eval()
|
508 |
+
best_loss = np.mean([loss_op(model(Variable(xv)).squeeze(), Variable(yv.float()).squeeze()).data.cpu().numpy()[0] for xv, yv in val_gen])
|
509 |
+
print("original val loss", best_loss)
|
510 |
+
|
511 |
+
epoch_without_impr = 0
|
512 |
+
for epoch in range(epochs):
|
513 |
+
for i, data in enumerate(train_gen):
|
514 |
+
X_train, y_train = data
|
515 |
+
X_train = Variable(X_train, requires_grad=False)
|
516 |
+
y_train = Variable(y_train, requires_grad=False)
|
517 |
+
if torch.cuda.is_available():
|
518 |
+
X_train = X_train.cuda()
|
519 |
+
y_train = y_train.cuda()
|
520 |
+
model.train()
|
521 |
+
optim_op.zero_grad()
|
522 |
+
output = model(X_train)
|
523 |
+
loss = loss_op(output, y_train.float())
|
524 |
+
loss.backward()
|
525 |
+
clip_grad_norm(model.parameters(), 1)
|
526 |
+
optim_op.step()
|
527 |
+
|
528 |
+
acc = evaluate_using_acc(model, [(X_train.data, y_train.data)])
|
529 |
+
print("== Epoch", epoch, "step", i, "train loss", loss.data.cpu().numpy()[0], "train acc", acc)
|
530 |
+
|
531 |
+
model.eval()
|
532 |
+
acc = evaluate_using_acc(model, val_gen)
|
533 |
+
print("val acc", acc)
|
534 |
+
|
535 |
+
val_loss = np.mean([loss_op(model(Variable(xv)).squeeze(), Variable(yv.float()).squeeze()).data.cpu().numpy()[0] for xv, yv in val_gen])
|
536 |
+
print("val loss", val_loss)
|
537 |
+
if best_loss is not None and val_loss >= best_loss:
|
538 |
+
epoch_without_impr += 1
|
539 |
+
print('No improvement over previous best loss: ', best_loss)
|
540 |
+
|
541 |
+
# Save checkpoint
|
542 |
+
if best_loss is None or val_loss < best_loss:
|
543 |
+
best_loss = val_loss
|
544 |
+
torch.save(model.state_dict(), checkpoint_path)
|
545 |
+
print('Saving model at', checkpoint_path)
|
546 |
+
|
547 |
+
# Early stopping
|
548 |
+
if epoch_without_impr >= patience:
|
549 |
+
break
|
550 |
+
|
551 |
+
def get_data_loader(X_in, y_in, batch_size, extended_batch_sampler=True, epoch_size=25000, upsample=False, seed=42):
|
552 |
+
""" Returns a dataloader that enables larger epochs on small datasets and
|
553 |
+
has upsampling functionality.
|
554 |
+
|
555 |
+
# Arguments:
|
556 |
+
X_in: Inputs of the given dataset.
|
557 |
+
y_in: Outputs of the given dataset.
|
558 |
+
batch_size: Batch size.
|
559 |
+
epoch_size: Number of samples in an epoch.
|
560 |
+
upsample: Whether upsampling should be done. This flag should only be
|
561 |
+
set on binary class problems.
|
562 |
+
|
563 |
+
# Returns:
|
564 |
+
DataLoader.
|
565 |
+
"""
|
566 |
+
dataset = DeepMojiDataset(X_in, y_in)
|
567 |
+
|
568 |
+
if extended_batch_sampler:
|
569 |
+
batch_sampler = DeepMojiBatchSampler(y_in, batch_size, epoch_size=epoch_size, upsample=upsample, seed=seed)
|
570 |
+
else:
|
571 |
+
batch_sampler = BatchSampler(SequentialSampler(y_in), batch_size, drop_last=False)
|
572 |
+
|
573 |
+
return DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0)
|
574 |
+
|
575 |
+
class DeepMojiDataset(Dataset):
|
576 |
+
""" A simple Dataset class.
|
577 |
+
|
578 |
+
# Arguments:
|
579 |
+
X_in: Inputs of the given dataset.
|
580 |
+
y_in: Outputs of the given dataset.
|
581 |
+
|
582 |
+
# __getitem__ output:
|
583 |
+
(torch.LongTensor, torch.LongTensor)
|
584 |
+
"""
|
585 |
+
def __init__(self, X_in, y_in):
|
586 |
+
# Check if we have Torch.LongTensor inputs (assume Numpy array otherwise)
|
587 |
+
if not isinstance(X_in, torch.LongTensor):
|
588 |
+
X_in = torch.from_numpy(X_in.astype('int64')).long()
|
589 |
+
if not isinstance(y_in, torch.LongTensor):
|
590 |
+
y_in = torch.from_numpy(y_in.astype('int64')).long()
|
591 |
+
|
592 |
+
self.X_in = torch.split(X_in, 1, dim=0)
|
593 |
+
self.y_in = torch.split(y_in, 1, dim=0)
|
594 |
+
|
595 |
+
def __len__(self):
|
596 |
+
return len(self.X_in)
|
597 |
+
|
598 |
+
def __getitem__(self, idx):
|
599 |
+
return self.X_in[idx].squeeze(), self.y_in[idx].squeeze()
|
600 |
+
|
601 |
+
class DeepMojiBatchSampler(object):
|
602 |
+
"""A Batch sampler that enables larger epochs on small datasets and
|
603 |
+
has upsampling functionality.
|
604 |
+
|
605 |
+
# Arguments:
|
606 |
+
y_in: Labels of the dataset.
|
607 |
+
batch_size: Batch size.
|
608 |
+
epoch_size: Number of samples in an epoch.
|
609 |
+
upsample: Whether upsampling should be done. This flag should only be
|
610 |
+
set on binary class problems.
|
611 |
+
seed: Random number generator seed.
|
612 |
+
|
613 |
+
# __iter__ output:
|
614 |
+
iterator of lists (batches) of indices in the dataset
|
615 |
+
"""
|
616 |
+
|
617 |
+
def __init__(self, y_in, batch_size, epoch_size, upsample, seed):
|
618 |
+
self.batch_size = batch_size
|
619 |
+
self.epoch_size = epoch_size
|
620 |
+
self.upsample = upsample
|
621 |
+
|
622 |
+
np.random.seed(seed)
|
623 |
+
|
624 |
+
if upsample:
|
625 |
+
# Should only be used on binary class problems
|
626 |
+
assert len(y_in.shape) == 1
|
627 |
+
neg = np.where(y_in.numpy() == 0)[0]
|
628 |
+
pos = np.where(y_in.numpy() == 1)[0]
|
629 |
+
assert epoch_size % 2 == 0
|
630 |
+
samples_pr_class = int(epoch_size / 2)
|
631 |
+
else:
|
632 |
+
ind = range(len(y_in))
|
633 |
+
|
634 |
+
if not upsample:
|
635 |
+
# Randomly sample observations with replacement to fill the epoch
|
636 |
+
self.sample_ind = np.random.choice(ind, epoch_size, replace=True)
|
637 |
+
else:
|
638 |
+
# Randomly sample observations in a balanced way
|
639 |
+
sample_neg = np.random.choice(neg, samples_pr_class, replace=True)
|
640 |
+
sample_pos = np.random.choice(pos, samples_pr_class, replace=True)
|
641 |
+
concat_ind = np.concatenate((sample_neg, sample_pos), axis=0)
|
642 |
+
|
643 |
+
# Shuffle to avoid labels being in specific order
|
644 |
+
# (all negative then positive)
|
645 |
+
p = np.random.permutation(len(concat_ind))
|
646 |
+
self.sample_ind = concat_ind[p]
|
647 |
+
|
648 |
+
label_dist = np.mean(y_in.numpy()[self.sample_ind])
|
649 |
+
assert(label_dist > 0.45)
|
650 |
+
assert(label_dist < 0.55)
|
651 |
+
|
652 |
+
def __iter__(self):
|
653 |
+
# Hand-off data using batch_size
|
654 |
+
for i in range((self.epoch_size + self.batch_size - 1) // self.batch_size):
|
655 |
+
start = i * self.batch_size
|
656 |
+
end = min(start + self.batch_size, self.epoch_size)
|
657 |
+
yield self.sample_ind[start:end]
|
658 |
+
|
659 |
+
def __len__(self):
|
660 |
+
# Take care of the last (maybe incomplete) batch
|
661 |
+
return (self.epoch_size + self.batch_size - 1) // self.batch_size
|
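A small end-to-end sketch of the data pipeline above with toy tensors (illustrative only, not part of the diff). With upsample=True each resampled epoch contains the same number of positives and negatives, so every batch is roughly balanced even though the underlying dataset is not.
import numpy as np
import torch
from torchmoji.finetuning import get_data_loader
# 90 negative / 10 positive toy samples, each a sequence of 20 token ids
X = np.random.randint(0, 1000, size=(100, 20))
y = torch.from_numpy(np.array([0] * 90 + [1] * 10)).long()
loader = get_data_loader(X, y, batch_size=32, extended_batch_sampler=True,
                         epoch_size=200, upsample=True)
X_batch, y_batch = next(iter(loader))
print(X_batch.shape, float(y_batch.float().mean()))  # torch.Size([32, 20]) and roughly 0.5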