File size: 3,552 Bytes
8a8d449
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import logging
from unimernet.common.registry import registry
from unimernet.datasets.builders.base_dataset_builder import BaseDatasetBuilder
from unimernet.datasets.datasets.formula import Im2LatexDataset
from unimernet.datasets.datasets.formula_multi_scale import MultiScaleIm2LatexDataset


@registry.register_builder("formula_rec_train")
class FormulaRecTrainBuilder(BaseDatasetBuilder):
    train_dataset_cls = Im2LatexDataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/formula/formula_train.yaml"
    }
    LOG_INFO = "Formula Recgnition Train"

    def build_datasets(self):
        logging.info(f"Building {self.LOG_INFO} datasets ...")
        self.build_processors()

        build_info = self.config.build_info
        anno_path = build_info.annotation,
        vis_root = build_info.images
        anno_path = [anno_path] if isinstance(anno_path, str) else anno_path
        vis_root = [vis_root] if isinstance(vis_root, str) else vis_root
        datasets = dict()

        # create datasets
        dataset_cls = self.train_dataset_cls
        datasets['train'] = dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            vis_root=vis_root,
            anno_path=anno_path,
        )
        print(datasets['train'][0])

        return datasets


@registry.register_builder("multi_scale_formula_rec_train")
class MultiScaleFormulaRecTrainBuilder(BaseDatasetBuilder):
    train_dataset_cls = MultiScaleIm2LatexDataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/formula/multi_scale_formula_train.yaml"
    }
    LOG_INFO = "Multi Scale Formula Recgnition Train"

    def build_datasets(self):
        logging.info(f"Building {self.LOG_INFO} datasets ...")
        self.build_processors()

        build_info = self.config.build_info
        anno_path = build_info.annotation,
        vis_root = build_info.images

        anno_path = [anno_path] if isinstance(anno_path, str) else anno_path
        vis_root = [vis_root] if isinstance(vis_root, str) else vis_root

        datasets = dict()

        # create datasets
        dataset_cls = self.train_dataset_cls
        datasets['train'] = dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            vis_root=vis_root,
            anno_path=anno_path,
        )
        print(datasets['train'][0])

        return datasets


@registry.register_builder("formula_rec_eval")
class FormulaRecEvalBuilder(BaseDatasetBuilder):
    eval_dataset_cls = Im2LatexDataset
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/formula/formula_eval.yaml"
    }
    LOG_INFO = "Formula Recgnition Eval"

    def build_datasets(self):
        logging.info(f"Building {self.LOG_INFO} datasets ...")
        self.build_processors()

        build_info = self.config.build_info
        anno_path = build_info.annotation,
        vis_root = build_info.images

        anno_path = [anno_path] if isinstance(anno_path, str) else anno_path
        vis_root = [vis_root] if isinstance(vis_root, str) else vis_root

        datasets = dict()

        # create datasets
        dataset_cls = self.eval_dataset_cls
        datasets['eval'] = dataset_cls(
            vis_processor=self.vis_processors["eval"],
            text_processor=self.text_processors["eval"],
            vis_root=vis_root,
            anno_path=anno_path,
        )
        print(datasets['eval'][0])

        return datasets