|
import numpy as np
|
|
import json
|
|
|
|
class LabelEncoder:
    """Label encoder for tag labels.

    Maintains a two-way mapping between class labels and integer indices.

    Attributes:
        class_to_index: dict mapping each label to its integer index.
        index_to_class: inverse of ``class_to_index`` (index -> label).
        classes: list of the known labels, in mapping order.
    """

    def __init__(self, class_to_index=None):
        """Create an encoder, optionally from an existing mapping.

        Args:
            class_to_index: dict mapping labels to integer indices. When
                omitted, the encoder starts with a fresh empty mapping.
        """
        # A literal ``{}`` default would be created once and shared by every
        # instance; ``fit`` mutates the mapping in place, so each encoder
        # must own its own dict.
        if class_to_index is None:
            class_to_index = {}
        self.class_to_index = class_to_index
        self._refresh()

    def _refresh(self):
        """Rebuild the attributes derived from ``class_to_index``."""
        self.index_to_class = {
            index: label for label, index in self.class_to_index.items()
        }
        self.classes = list(self.class_to_index)

    def __len__(self):
        """Return the number of known classes."""
        return len(self.class_to_index)

    def __str__(self):
        """Return a short summary including the number of classes."""
        return f"<LabelEncoder(num_classes={len(self)})>"

    def fit(self, y):
        """Assign an index to every distinct label in ``y``.

        Labels are indexed in the sorted order produced by ``np.unique``.
        Entries already present in the mapping are not removed; labels that
        reappear in ``y`` have their index overwritten.

        Args:
            y: array-like of labels.

        Returns:
            The encoder itself, so calls can be chained.
        """
        for index, label in enumerate(np.unique(y)):
            self.class_to_index[label] = index
        self._refresh()
        return self

    def encode(self, y):
        """Convert labels to their integer indices.

        Args:
            y: iterable of labels known to the encoder.

        Returns:
            A 1-D integer ``numpy`` array with one index per label.

        Raises:
            KeyError: if a label is not present in ``class_to_index``.
        """
        return np.array([self.class_to_index[label] for label in y], dtype=int)

    def decode(self, y):
        """Convert integer indices back to their labels.

        Args:
            y: iterable of indices known to the encoder.

        Returns:
            A list with one label per index.

        Raises:
            KeyError: if an index is not present in ``index_to_class``.
        """
        return [self.index_to_class[index] for index in y]

    def save(self, fp):
        """Write the label mapping to a JSON file.

        Args:
            fp: path of the file to write.
        """
        contents = {'class_to_index': self.class_to_index}
        # Use a separate name for the handle so the path argument is not
        # shadowed.
        with open(fp, "w", encoding="utf-8") as file:
            json.dump(contents, file, indent=4, sort_keys=False)

    @classmethod
    def load(cls, fp):
        """Build an encoder from a JSON file written by ``save``.

        Args:
            fp: path of the file to read.

        Returns:
            A new encoder constructed from the stored keyword arguments.
        """
        with open(fp, "r", encoding="utf-8") as file:
            kwargs = json.load(file)
        return cls(**kwargs)