File size: 2,083 Bytes
eac1a45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# -*- coding:utf-8 -*-
# @FileName  :ctt-punctuator.py
# @Time      :2023/4/13 15:03
# @Author    :lovemefan
# @Email     :[email protected]


# Module metadata (author, copyright, license, version) exposed as dunder attributes.
__author__ = "lovemefan"
__copyright__ = "Copyright (C) 2023 lovemefan"
__license__ = "MIT"
__version__ = "v0.0.1"

import logging
import threading

from cttpunctuator.src.punctuator import CT_Transformer, CT_Transformer_VadRealtime

# Configure the root logger when this module is imported: INFO level, with
# timestamp, level, file name, line number, module and function in each record.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s %(levelname)s] [%(filename)s:%(lineno)d %(module)s.%(funcName)s] %(message)s",
)

# Re-entrant lock shared by all CttPunctuator instances; it guards the
# one-time creation of the class-level model objects.
lock = threading.RLock()


class CttPunctuator:
    """Punctuation restorer that shares one model object per mode.

    The offline and online models are stored on the class, so each is built
    at most once and reused by every instance. Per-instance state (the
    ``online`` flag and, in online mode, the ``param_dict`` cache) is kept on
    the instance itself.
    """

    # Class-level model caches, created lazily on first use of each mode.
    _offline_model = None
    _online_model = None

    def __init__(self, online: bool = False):
        """
        Create a punctuator bound to the shared model for the chosen mode.

        :param online: if True use the realtime (``CT_Transformer_VadRealtime``)
            model, otherwise use the offline ``CT_Transformer`` model.
        """
        self.online = online

        if online:
            # Check, lock, then check again so that only one thread builds the model.
            if CttPunctuator._online_model is None:
                with lock:
                    if CttPunctuator._online_model is None:
                        logging.info("Initializing punctuator model with online mode.")
                        CttPunctuator._online_model = CT_Transformer_VadRealtime()
                        logging.info("Online model initialized.")
            self.model = CttPunctuator._online_model
            # Every online instance needs its own cache dict. It used to be
            # created only by the instance that built the model, so later
            # instances crashed in punctuate() with AttributeError.
            self.param_dict = {"cache": []}

        else:
            if CttPunctuator._offline_model is None:
                with lock:
                    if CttPunctuator._offline_model is None:
                        logging.info("Initializing punctuator model with offline mode.")
                        CttPunctuator._offline_model = CT_Transformer()
                        logging.info("Offline model initialized.")
            self.model = CttPunctuator._offline_model

        logging.info("Model initialized.")

    def punctuate(self, text: str, param_dict=None):
        """
        Run the model on ``text`` and return the model's result unchanged.

        :param text: input text to punctuate.
        :param param_dict: online mode only. Dict passed to the model as its
            second argument; when omitted (or falsy) the instance's own
            ``self.param_dict`` is used. Ignored in offline mode.
        :return: whatever the underlying model call returns.
        """
        if self.online:
            # Honour the caller-supplied dict; previously it was computed and
            # then discarded in favour of self.param_dict.
            param_dict = param_dict or self.param_dict
            return self.model(text, param_dict)
        return self.model(text)