|
|
|
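"""基本方法 (basic utilities)

- 创建中文数字系统 (create a Chinese numbering system): create_system
- 中文字符串 => 数字串 (Chinese numeral string to digit string): chn2num
- 数字串 => 中文字符串 (digit string to Chinese numeral string): num2chn

"""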
|
|
|
__author__ = "Zhiyang Zhou <[email protected]>"

# Release date of this module. ``__date__`` is the conventional dunder name;
# ``__data__`` (the original, misspelled name) is kept as an alias so that any
# existing references keep working.
__date__ = "2019-05-02"
__data__ = __date__
|
|
|
from fish_speech.text.chn_text_norm.basic_class import * |
|
from fish_speech.text.chn_text_norm.basic_constant import * |
|
|
|
|
|
def create_system(numbering_type=NUMBERING_TYPES[1]):
    """Build the Chinese numbering system for ``numbering_type`` (default: mid).

    ``NUMBERING_TYPES = ['low', 'mid', 'high']`` selects how the large units grow:

    - low:  '兆' = '亿' * '十' = $10^{9}$,  '京' = '兆' * '十', etc.
    - mid:  '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
    - high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.

    Returns:
        A ``NumberSystem`` whose ``units`` are the small units followed by the
        large units, whose ``digits`` are the ten digits (zero, one and two
        carry alternative spellings), and whose ``math`` groups the positive,
        negative and decimal-point symbols.
    """
    # Large units: CNU.create derives their powers from `numbering_type`.
    larger_units = []
    for position, names in enumerate(
        zip(
            LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
            LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL,
        )
    ):
        larger_units.append(CNU.create(position, names, numbering_type, False))

    # Small units are created without reference to the numbering type.
    small_names = zip(
        SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
        SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL,
    )
    smaller_units = [
        CNU.create(position, names, small_unit=True)
        for position, names in enumerate(small_names)
    ]

    # Digit forms: (simplified, traditional, big simplified, big traditional);
    # the plain form is shared by simplified and traditional text.
    digit_names = zip(
        CHINESE_DIGIS,
        CHINESE_DIGIS,
        BIG_CHINESE_DIGIS_SIMPLIFIED,
        BIG_CHINESE_DIGIS_TRADITIONAL,
    )
    digits = [CND.create(value, names) for value, names in enumerate(digit_names)]

    # Alternative spellings for zero, one and two.
    zero, one, two = digits[0], digits[1], digits[2]
    zero.alt_s = ZERO_ALT
    zero.alt_t = ZERO_ALT
    one.alt_s = ONE_ALT
    one.alt_t = ONE_ALT
    two.alt_s = TWO_ALTS[0]
    two.alt_t = TWO_ALTS[1]

    system = NumberSystem()
    system.units = smaller_units + larger_units
    system.digits = digits
    system.math = MathSymbol(
        CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x),
        CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x),
        CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y))),
    )
    return system
|
|
|
|
|
def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
    """Convert a Chinese numeral string into an Arabic-digit string.

    Args:
        chinese_string: Chinese digits and units, optionally split by one
            decimal point (simplified or traditional form). Characters that
            are not recognised, and math symbols such as signs, are skipped
            in the integer part.
        numbering_type: one of ``NUMBERING_TYPES``; decides the powers of the
            large units (see ``create_system``).

    Returns:
        The integer value as a string, or ``"<int>.<digits>"`` when the input
        contains a decimal point. The fractional part is read digit by digit.

    Raises:
        ValueError: if the decimal point occurs more than once.

    Examples (default numbering type)::

        一万零四百零三点八零五             -> "10403.805"
        一百八                             -> "180"   (read as 一百八十)
        一亿二千三百四十五万六千七百八十九 -> "123456789"
    """

    def get_symbol(char, system):
        """Return the unit, digit or math symbol spelled by ``char``, or None."""
        for u in system.units:
            if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
                return u
        for d in system.digits:
            if char in [
                d.traditional,
                d.simplified,
                d.big_s,
                d.big_t,
                d.alt_s,
                d.alt_t,
            ]:
                return d
        for m in system.math:
            if char in [m.traditional, m.simplified]:
                return m
        return None

    def string2symbols(chinese_string, system):
        """Split off the decimal part and map every character to a symbol."""
        int_string, dec_string = chinese_string, ""
        for p in [system.math.point.simplified, system.math.point.traditional]:
            if p in chinese_string:
                # More than one point makes this unpacking raise ValueError.
                int_string, dec_string = chinese_string.split(p)
                break
        return (
            [get_symbol(c, system) for c in int_string],
            [get_symbol(c, system) for c in dec_string],
        )

    def compute_value(integer_symbols):
        """Evaluate the integer-part symbols.

        Completed terms are kept on a stack. When a unit of power ``p``
        arrives, the pending digit plus every trailing term smaller than
        ``10**p`` are summed and scaled together. A unit therefore multiplies
        exactly the group it closes, never a term already scaled by a larger
        unit:

            两千万         -> (2000) * 10**4                 = 20000000
            两亿三百零四万 -> 2*10**8 + (300 + 4) * 10**4    = 203040000
            三百零四万亿   -> ((300 + 4) * 10**4) * 10**8    = 304000000000000
        """
        if not integer_symbols:
            return 0

        # 十五 / 十万: a leading 十 is read as 一十.
        first = integer_symbols[0]
        pending = 1 if isinstance(first, CNU) and first.power == 1 else 0

        # 一百八 == 一百八十, 一亿六 == 一亿六千万: a final digit written right
        # after a unit counts at one power below that unit.
        trailing_power = 0
        if (
            len(integer_symbols) > 1
            and isinstance(integer_symbols[-1], CND)
            and isinstance(integer_symbols[-2], CNU)
        ):
            trailing_power = integer_symbols[-2].power - 1

        terms = []
        for s in integer_symbols:
            if isinstance(s, CND):
                pending = s.value
            elif isinstance(s, CNU):
                scale = pow(10, s.power)
                group = pending
                # Absorb the trailing terms this unit governs (e.g. 三百零四 + 万).
                while terms and terms[-1] < scale:
                    group += terms.pop()
                # A unit with nothing to scale (e.g. a repeated unit) adds nothing.
                if group:
                    terms.append(group * scale)
                pending = 0
            # Anything else (unknown chars, sign symbols) is skipped.

        return sum(terms) + pending * pow(10, trailing_power)

    system = create_system(numbering_type)
    int_part, dec_part = string2symbols(chinese_string, system)
    int_str = str(compute_value(int_part))
    if not dec_part:
        return int_str
    dec_str = "".join([str(d.value) for d in dec_part])
    return "{0}.{1}".format(int_str, dec_str)
|
|
|
|
|
def num2chn(
    number_string,
    numbering_type=NUMBERING_TYPES[1],
    big=False,
    traditional=False,
    alt_zero=False,
    alt_one=False,
    alt_two=True,
    use_zeros=True,
    use_units=True,
):
    """Convert an Arabic-digit string into Chinese numerals.

    Args:
        number_string: digits with at most one ".", e.g. "10260.03".
        numbering_type: one of ``NUMBERING_TYPES`` (see ``create_system``).
        big: use the "big" digit/unit forms (``big_s`` / ``big_t``).
        traditional: use traditional instead of simplified characters.
        alt_zero: replace every zero with the alternative zero form.
        alt_one: replace every one with the alternative one form.
        alt_two: write a 2 in its alternative form (e.g. 两千) when it is
            followed by a unit other than 十 and is either the first symbol or
            directly preceded by a unit other than 十.
        use_zeros: when spelling with units, insert a zero for skipped
            magnitudes (一万零二百六十); when False the gap is left unmarked.
        use_units: spell the integer part with units; when False (or when the
            integer part has a single digit) it is read digit by digit.

    Returns:
        The Chinese numeral string. A leading "一十" is shortened to "十".

    Raises:
        ValueError: if ``number_string`` contains more than one dot; non-digit
            characters generally make ``int()`` raise ValueError as well.
    """

    def get_value(value_string, use_zeros=True):
        """Spell the digit string ``value_string`` as a list of symbols."""
        striped_string = value_string.lstrip("0")

        # All zeros: nothing to say.
        if not striped_string:
            return []

        # One significant digit, prefixed with zero when leading zeros were stripped.
        if len(striped_string) == 1:
            if use_zeros and len(value_string) != len(striped_string):
                return [system.digits[0], system.digits[int(striped_string)]]
            return [system.digits[int(striped_string)]]

        # Split around the first unit (scanning system.units from the end)
        # whose power is below the significant digit count: head + unit + tail.
        result_unit = next(
            u for u in reversed(system.units) if u.power < len(striped_string)
        )
        result_string = value_string[: -result_unit.power]
        return (
            get_value(result_string, use_zeros)
            + [result_unit]
            + get_value(striped_string[-result_unit.power :], use_zeros)
        )

    system = create_system(numbering_type)

    int_dec = number_string.split(".")
    if len(int_dec) == 1:
        int_string, dec_string = int_dec[0], ""
    elif len(int_dec) == 2:
        int_string, dec_string = int_dec
    else:
        raise ValueError(
            "invalid input num string with more than one dot: {}".format(number_string)
        )

    if use_units and len(int_string) > 1:
        # Pass use_zeros through; previously it was silently ignored.
        result_symbols = get_value(int_string, use_zeros)
    else:
        result_symbols = [system.digits[int(c)] for c in int_string]
    dec_symbols = [system.digits[int(c)] for c in dec_string]
    if dec_string:
        result_symbols += [system.math.point] + dec_symbols

    if alt_two:
        liang = CND(
            2,
            system.digits[2].alt_s,
            system.digits[2].alt_t,
            system.digits[2].big_s,
            system.digits[2].big_t,
        )
        for i, v in enumerate(result_symbols):
            if isinstance(v, CND) and v.value == 2:
                next_symbol = (
                    result_symbols[i + 1] if i < len(result_symbols) - 1 else None
                )
                previous_symbol = result_symbols[i - 1] if i > 0 else None
                # Only before a unit other than 十, and only at the start or
                # right after a unit other than 十 (not after another digit).
                if isinstance(next_symbol, CNU) and isinstance(
                    previous_symbol, (CNU, type(None))
                ):
                    if next_symbol.power != 1 and (
                        (previous_symbol is None) or (previous_symbol.power != 1)
                    ):
                        result_symbols[i] = liang

    # Pick which spelling attribute of each symbol to emit.
    if big:
        attr_name = "big_t" if traditional else "big_s"
    else:
        attr_name = "traditional" if traditional else "simplified"

    result = "".join([getattr(s, attr_name) for s in result_symbols])

    if alt_zero:
        result = result.replace(
            getattr(system.digits[0], attr_name), system.digits[0].alt_s
        )

    if alt_one:
        result = result.replace(
            getattr(system.digits[1], attr_name), system.digits[1].alt_s
        )

    # A bare fraction such as ".5" gets a leading zero.
    if any(result.startswith(p) for p in POINT):
        return CHINESE_DIGIS[0] + result

    # "一十..." is conventionally written "十...".
    if (
        len(result) >= 2
        and result[1]
        in [
            SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
            SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
        ]
        and result[0]
        in [
            CHINESE_DIGIS[1],
            BIG_CHINESE_DIGIS_SIMPLIFIED[1],
            BIG_CHINESE_DIGIS_TRADITIONAL[1],
        ]
    ):
        result = result[1:]

    return result
|
|
|
|
|
if __name__ == "__main__":

    # Every character this module recognises (digits, units and symbols).
    all_chinese_number_string = (
        CHINESE_DIGIS
        + BIG_CHINESE_DIGIS_SIMPLIFIED
        + BIG_CHINESE_DIGIS_TRADITIONAL
        + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED
        + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
        + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED
        + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
        + ZERO_ALT
        + ONE_ALT
        + "".join(TWO_ALTS + POSITIVE + NEGATIVE + POINT)
    )

    print("num:", chn2num("一万零四百零三点八零五"))
    print("num:", chn2num("一亿六点三"))
    print("num:", chn2num("一亿零六点三"))
    print("num:", chn2num("两千零一亿六点三"))

    print("txt:", num2chn("10260.03", alt_zero=True))
    print("txt:", num2chn("20037.090", numbering_type="low", traditional=True))
    print("txt:", num2chn("100860001.77", numbering_type="high", big=True))
    # num2chn has a single `use_zeros` flag; the former `use_lzeros` /
    # `use_rzeros` keywords do not exist and raised TypeError.
    print(
        "txt:",
        num2chn(
            "059523810880",
            alt_one=True,
            alt_two=False,
            use_zeros=True,
            use_units=False,
        ),
    )

    print(all_chinese_number_string)
|
|