import math
from dataclasses import dataclass
import numpy as np
import torch
from torch import nn
from transformers.modeling_outputs import BaseModelOutput, ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
import json
import os
import re
from typing import Any, Dict, List, Optional, Tuple
from transformers.tokenization_utils import PreTrainedTokenizer
import phonemizer
import uroman as ur
import torch.nn.functional as F


def has_non_roman_characters(input_string):
    # Find any character outside the ASCII range
    non_roman_pattern = re.compile(r"[^\x00-\x7F]")

    # Search the input string for non-Roman characters
    match = non_roman_pattern.search(input_string)
    has_non_roman = match is not None
    return has_non_roman


class VitsConfig(PretrainedConfig):
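    """Configuration for this trimmed-down VITS / MMS-TTS model.

    Defaults are set up for a character-level, 16 kHz checkpoint (``vocab_size=38``,
    ``sampling_rate=16_000``). The stochastic duration predictor and posterior encoder of
    the original VITS are not used by the model below, so several of the duration- and
    dropout-related arguments are kept only for checkpoint compatibility.
    """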

    model_type = "vits"

    def __init__(
        self,
        vocab_size=38,
        hidden_size=192,
        num_hidden_layers=6,
        num_attention_heads=2,
        window_size=4,
        use_bias=True,
        ffn_dim=768,
        layerdrop=0.1,
        ffn_kernel_size=3,
        flow_size=192,
        spectrogram_bins=513,
        # hidden_act="relu",
        hidden_dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.1,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        use_stochastic_duration_prediction=True,
        num_speakers=1,
        speaker_embedding_size=0,
        upsample_initial_channel=512,
        upsample_rates=[8, 8, 2, 2],
        upsample_kernel_sizes=[16, 16, 4, 4],
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        leaky_relu_slope=0.1,
        depth_separable_channels=2,
        depth_separable_num_layers=3,
        duration_predictor_flow_bins=10,
        duration_predictor_tail_bound=5.0,
        duration_predictor_kernel_size=3,
        duration_predictor_dropout=0.5,
        duration_predictor_num_flows=4,
        duration_predictor_filter_channels=256,
        prior_encoder_num_flows=4,
        prior_encoder_num_wavenet_layers=4,
        posterior_encoder_num_wavenet_layers=16,
        wavenet_kernel_size=5,
        wavenet_dilation_rate=1,
        wavenet_dropout=0.0,
        speaking_rate=1.0,  # unused
        noise_scale=0.667,
        noise_scale_duration=0.8,
        sampling_rate=16_000,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.window_size = window_size
        self.use_bias = use_bias
        self.ffn_dim = ffn_dim
        self.layerdrop = layerdrop
        self.ffn_kernel_size = ffn_kernel_size
        self.flow_size = flow_size
        self.spectrogram_bins = spectrogram_bins
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        # self.use_stochastic_duration_prediction = use_stochastic_duration_prediction
        self.num_speakers = num_speakers
        self.speaker_embedding_size = speaker_embedding_size
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_rates = upsample_rates
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.leaky_relu_slope = leaky_relu_slope
        self.depth_separable_channels = depth_separable_channels
        self.depth_separable_num_layers = depth_separable_num_layers
        self.duration_predictor_flow_bins = duration_predictor_flow_bins
        self.duration_predictor_tail_bound = duration_predictor_tail_bound
        self.duration_predictor_kernel_size = duration_predictor_kernel_size
        self.duration_predictor_num_flows = duration_predictor_num_flows
        self.duration_predictor_filter_channels = duration_predictor_filter_channels
        self.prior_encoder_num_flows = prior_encoder_num_flows
        self.prior_encoder_num_wavenet_layers = prior_encoder_num_wavenet_layers
        self.posterior_encoder_num_wavenet_layers = posterior_encoder_num_wavenet_layers
        self.wavenet_kernel_size = wavenet_kernel_size
        self.wavenet_dilation_rate = wavenet_dilation_rate
        self.noise_scale = noise_scale
        self.noise_scale_duration = noise_scale_duration
        self.sampling_rate = sampling_rate

        if len(upsample_kernel_sizes) != len(upsample_rates):
            raise ValueError(
                f"The length of `upsample_kernel_sizes` ({len(upsample_kernel_sizes)}) must match the length of "
                f"`upsample_rates` ({len(upsample_rates)})"
            )

        super().__init__(**kwargs)



@dataclass
class VitsTextEncoderOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    prior_means: torch.FloatTensor = None
    prior_log_variances: torch.FloatTensor = None
    hidden_states: torch.FloatTensor = None
    attentions: torch.FloatTensor = None



class VitsWaveNet(torch.nn.Module):
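    """WaveNet-style stack of dilated 1-D convolutions with gated (tanh * sigmoid) activations
    and residual/skip connections; used inside the coupling layers of the flow (no conditioning here)."""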
    def __init__(self, config, num_layers):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_layers = num_layers
        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        # use the parametrizations API for weight norm (the legacy nn.utils.weight_norm is deprecated)
        weight_norm = nn.utils.parametrizations.weight_norm
        for i in range(num_layers):
            dilation = config.wavenet_dilation_rate**i
            padding = (config.wavenet_kernel_size * dilation - dilation) // 2
            in_layer = torch.nn.Conv1d(
                in_channels=config.hidden_size,
                out_channels=2 * config.hidden_size,
                kernel_size=config.wavenet_kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # last one is not necessary
            if i < num_layers - 1:
                res_skip_channels = 2 * config.hidden_size
            else:
                res_skip_channels = config.hidden_size
            res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
            res_skip_layer = weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, inputs):
        outputs = torch.zeros_like(inputs)
        num_channels = self.hidden_size
        for i in range(self.num_layers):
            in_act = self.in_layers[i](inputs)
            # gated activation: tanh over the first half of the channels, sigmoid over the second half
            t_act = torch.tanh(in_act[:, :num_channels, :])
            s_act = torch.sigmoid(in_act[:, num_channels:, :])
            acts = t_act * s_act
            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.num_layers - 1:
                res_acts = res_skip_acts[:, : self.hidden_size, :]
                inputs = inputs + res_acts
                outputs = outputs + res_skip_acts[:, self.hidden_size :, :]
            else:
                outputs = outputs + res_skip_acts
        return outputs









# Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock
class HifiGanResidualBlock(nn.Module):
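    """Residual block of paired 1-D convolutions (dilated, then undilated) with leaky-ReLU activations."""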
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1):
        super().__init__()
        self.leaky_relu_slope = leaky_relu_slope

        self.convs1 = nn.ModuleList(
            [
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    stride=1,
                    dilation=dilation[i],
                    padding=self.get_padding(kernel_size, dilation[i]),
                )
                for i in range(len(dilation))
            ]
        )
        self.convs2 = nn.ModuleList(
            [
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    stride=1,
                    dilation=1,
                    padding=self.get_padding(kernel_size, 1),
                )
                for _ in range(len(dilation))
            ]
        )

    def get_padding(self, kernel_size, dilation=1):
        return (kernel_size * dilation - dilation) // 2

    def forward(self, hidden_states):
        for conv1, conv2 in zip(self.convs1, self.convs2):
            residual = hidden_states
            hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
            hidden_states = conv1(hidden_states)
            hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
            hidden_states = conv2(hidden_states)
            hidden_states = hidden_states + residual
        return hidden_states


class VitsHifiGan(nn.Module):
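    """HiFi-GAN style vocoder: a pre-convolution, a stack of transposed-convolution upsamplers with
    multi-kernel residual blocks, and a final convolution + tanh producing the waveform."""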
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_kernels = len(config.resblock_kernel_sizes)
        self.num_upsamples = len(config.upsample_rates)
        self.conv_pre = nn.Conv1d(
            config.flow_size,
            config.upsample_initial_channel,
            kernel_size=7,
            stride=1,
            padding=3,
        )

        self.upsampler = nn.ModuleList()
        for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
            self.upsampler.append(
                nn.ConvTranspose1d(
                    config.upsample_initial_channel // (2**i),
                    config.upsample_initial_channel // (2 ** (i + 1)),
                    kernel_size=kernel_size,
                    stride=upsample_rate,
                    padding=(kernel_size - upsample_rate) // 2,
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.upsampler)):
            channels = config.upsample_initial_channel // (2 ** (i + 1))
            for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
                self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))
        self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False)

    def forward(self,
                spectrogram):
        hidden_states = self.conv_pre(spectrogram)
        for i in range(self.num_upsamples):
            hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope)
            hidden_states = self.upsampler[i](hidden_states)
            res_state = self.resblocks[i * self.num_kernels](hidden_states)
            for j in range(1, self.num_kernels):
                res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
            hidden_states = res_state / self.num_kernels
        hidden_states = nn.functional.leaky_relu(hidden_states)
        hidden_states = self.conv_post(hidden_states)
        waveform = torch.tanh(hidden_states)
        return waveform


class VitsResidualCouplingLayer(nn.Module):
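    """Mean-only coupling layer: the first half of the channels conditions a WaveNet that predicts
    a shift for the second half. Only the inverse pass (used at inference) is implemented below."""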
    def __init__(self, config):
        super().__init__()
        self.half_channels = config.flow_size // 2
        self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
        self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
        self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)

    def forward(self, x, reverse=False):
        # Only the inverse direction (used at inference) is implemented: the first half of the
        # channels conditions a WaveNet that predicts a shift subtracted from the second half.
        first_half, second_half = torch.split(x, [self.half_channels] * 2, dim=1)
        hidden_states = self.conv_pre(first_half)
        hidden_states = self.wavenet(hidden_states)
        mean = self.conv_post(hidden_states)
        second_half = second_half - mean
        outputs = torch.cat([first_half, second_half], dim=1)
        return outputs


class VitsResidualCouplingBlock(nn.Module):
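    """Normalizing flow built from coupling layers; at inference the layers are traversed in reverse
    with a channel flip between them."""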
    def __init__(self, config):
        super().__init__()
        self.flows = nn.ModuleList()
        for _ in range(config.prior_encoder_num_flows):
            self.flows.append(VitsResidualCouplingLayer(config))

    def forward(self, x, reverse=False):
        # x: [batch, flow_size, time]; inference traverses the flows in reverse order,
        # reversing the channel order (so the two halves swap roles) before each coupling layer
        for flow in reversed(self.flows):
            x = torch.flip(x, [1])
            x = flow(x, reverse=True)
        return x






class VitsAttention(nn.Module):
    """has no positional info"""

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        
        self.window_size = config.window_size

        self.head_dim = self.embed_dim // self.num_heads
        self.scaling = self.head_dim**-0.5

        if (self.head_dim * self.num_heads) != self.embed_dim:
            raise ValueError
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)

    def _shape(self, tensor, seq_len, bsz):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states,
        layer_head_mask = None,
        output_attentions = False,
    ):
        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling

        # self_attention
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_output = torch.bmm(attn_weights, 
                                value_states)
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, None  # attention weights are not returned in this trimmed-down version


class VitsFeedForward(nn.Module):
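    """Position-wise feed-forward block implemented with 1-D convolutions, ReLU, and manual asymmetric padding."""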
    def __init__(self, config):
        super().__init__()
        self.conv_1 = nn.Conv1d(config.hidden_size, config.ffn_dim, config.ffn_kernel_size)
        self.conv_2 = nn.Conv1d(config.ffn_dim, config.hidden_size, config.ffn_kernel_size)
        self.act_fn = nn.ReLU()
        

        if config.ffn_kernel_size > 1:
            pad_left = (config.ffn_kernel_size - 1) // 2
            pad_right = config.ffn_kernel_size // 2
            self.padding = [pad_left, pad_right, 0, 0, 0, 0]
        else:
            self.padding = None

    def forward(self, hidden_states):
        hidden_states = hidden_states.permute(0, 2, 1)
        if self.padding is not None:
            hidden_states = nn.functional.pad(hidden_states, self.padding)
        hidden_states = self.conv_1(hidden_states)
        hidden_states = self.act_fn(hidden_states)
        if self.padding is not None:
            hidden_states = nn.functional.pad(hidden_states, self.padding)
        hidden_states = self.conv_2(hidden_states)
        hidden_states = hidden_states.permute(0, 2, 1)
        return hidden_states


class VitsEncoderLayer(nn.Module):
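    """Transformer encoder layer (post-norm): self-attention and a convolutional feed-forward block,
    each wrapped in a residual connection followed by LayerNorm."""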
    def __init__(self, config):
        super().__init__()
        self.attention = VitsAttention(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = VitsFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        output_attentions = False,
    ):
        residual = hidden_states
        hidden_states, attn_weights = self.attention(
            hidden_states=hidden_states,
            output_attentions=output_attentions,  # no attention_mask is applied in this trimmed-down version
        )

        hidden_states = self.layer_norm(residual + hidden_states)

        residual = hidden_states
        hidden_states = self.feed_forward(hidden_states)

        hidden_states = self.final_layer_norm(residual + hidden_states)

        outputs = (hidden_states,)

        return outputs


class VitsEncoder(nn.Module):
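    """Stack of `num_hidden_layers` encoder layers; attention masks and layerdrop are not applied
    in this trimmed-down version."""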
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([VitsEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False
        self.layerdrop = config.layerdrop

    def forward(
        self,
        hidden_states,
        output_attentions = None,
        output_hidden_states = None,
        return_dict = None,
    ):
        for _layer in self.layers:
            layer_outputs = _layer(hidden_states)
            hidden_states = layer_outputs[0]
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            # hidden_states=all_hidden_states,
            # attentions=all_self_attentions,
        )


class VitsTextEncoder(nn.Module):
    """
    Has VitsEncoder
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
        self.encoder = VitsEncoder(config)  # `num_hidden_layers` layers of self-attention + feed-forward
        self.project = nn.Conv1d(config.hidden_size, config.flow_size * 2, kernel_size=1)

    def forward(self, input_ids):
        hidden_states = self.embed_tokens(input_ids) * math.sqrt(self.config.hidden_size)
        last_hidden_state = self.encoder(hidden_states=hidden_states).last_hidden_state

        stats = self.project(last_hidden_state.transpose(1, 2)).transpose(1, 2)
        prior_means, prior_log_variances = torch.split(stats, self.config.flow_size, dim=2)

        return VitsTextEncoderOutput(
            last_hidden_state=last_hidden_state,
            prior_means=prior_means,
            # prior_log_variances=prior_log_variances,
            # hidden_states=encoder_outputs.hidden_states,
            # attentions=encoder_outputs.attentions,
        )


class VitsPreTrainedModel(PreTrainedModel):
    config_class = VitsConfig
    base_model_prefix = "vits"
    main_input_name = "input_ids"
    supports_gradient_checkpointing = True



class VitsModel(VitsPreTrainedModel):
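    """End-to-end inference pipeline: text encoder -> fixed per-language duration pattern
    (in place of the stochastic duration predictor) -> residual coupling flow run in reverse ->
    HiFi-GAN decoder. `forward` returns a waveform tensor of shape [batch, num_samples]."""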
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.text_encoder = VitsTextEncoder(config)  # embeddings + `num_hidden_layers` self-attention layers
        self.flow = VitsResidualCouplingBlock(config)
        self.decoder = VitsHifiGan(config)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids = None,
        attention_mask = None,
        speaker_id = None,
        output_attentions = None,
        output_hidden_states = None,
        return_dict = None,
        labels = None,
        speed = None,
        lang_code = 'deu',  # selects the per-voice/per-language duration oscillation pattern
    ):
        mask_dtype = self.text_encoder.embed_tokens.weight.dtype
        if attention_mask is not None:
            input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype)
        else:
            input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(mask_dtype)
        out = self.text_encoder(input_ids=input_ids)
        hidden_states = out.last_hidden_state.transpose(1, 2)
        input_padding_mask = input_padding_mask.transpose(1, 2)
        prior_means = out.prior_means
        bs, _, in_len = hidden_states.shape
        # Fixed per-language duration pattern (stands in for the stochastic duration predictor):
        # each entry is the number of output frames assigned to one input token, so the pattern
        # controls the speed oscillation of the voice.
        if lang_code == 'deu':
            pattern = [1, 2, 1]
        elif lang_code == 'rmc-script_latin':
            pattern = [2, 2, 1, 2, 2]   # alternative: [2, 2, 2, 1, 2]
        elif lang_code == 'hun':
            pattern = [1, 2, 1, 1, 1]   # alternative with a valley-pause: [1, 2, 2, 1, 1, 1]
        else:
            pattern = [1, 2, 1]
        # tile the pattern to cover the whole input, then lengthen the first and last tokens
        duration = torch.tensor(pattern, device=hidden_states.device).repeat(int(in_len / len(pattern)) + 2)[None, None, :in_len]
        duration[:, :, 0] = 4
        duration[:, :, -1] = 3
        # Build the hard monotonic alignment matrix that maps each input token to `duration` output frames
        predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()
        indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
        output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
        output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)
        attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
        batch_size, _, output_length, input_length = attn_mask.shape
        cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
        indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
        valid_indices = indices.unsqueeze(0) < cum_duration
        valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
        padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
        attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
        attn = attn[:, 0, :, :]


        # add a little noise so every row is non-zero, then normalize each output frame over the input tokens
        attn = attn + 1e-4 * torch.rand_like(attn)
        attn /= attn.sum(2, keepdim=True)
        # expand the prior means along time: each mean is repeated for its assigned duration
        # (soft values such as .5/.5 in `attn` would smoothly interpolate between neighbouring means instead)
        prior_means = torch.matmul(attn, prior_means)

        # alternative for slower speech: stretch the prior means by linear interpolation, e.g.
        # prior_means = F.interpolate(prior_means.transpose(1, 2), int(1.74 * prior_means.shape[1]), mode='linear').transpose(1, 2)



        # The prior means have now been repeated along time according to each token's duration.
        # Run the flow in reverse to map them to latents, then vocode the latents into a waveform.
        latents = self.flow(prior_means.transpose(1, 2),  # optionally add noise: + torch.randn_like(prior_means) * .94
                            reverse=True)

        waveform = self.decoder(latents)  # [bs, 1, num_samples]

        return waveform[:, 0, :]


class VitsTokenizer(PreTrainedTokenizer):
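    """Character-level tokenizer for VITS / MMS-TTS: optionally lowercases and normalizes the text,
    romanizes non-Latin scripts with `uroman`, phonemizes with `phonemizer`/espeak, and intersperses
    a blank token between characters."""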
    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        pad_token="<pad>",
        unk_token="<unk>",
        language=None,
        add_blank=True,
        normalize=True,
        phonemize=True,
        is_uroman=False,
        **kwargs,
    ) -> None:
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)

        self.decoder = {v: k for k, v in self.encoder.items()}
        self.language = language
        self.add_blank = add_blank
        self.normalize = normalize
        self.phonemize = phonemize

        self.is_uroman = is_uroman

        super().__init__(
            pad_token=pad_token,
            unk_token=unk_token,
            language=language,
            add_blank=add_blank,
            normalize=normalize,
            phonemize=phonemize,
            is_uroman=is_uroman,
            **kwargs,
        )

    @property
    def vocab_size(self):
        return len(self.encoder)

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def normalize_text(self, input_string):
        """Lowercase the input string, respecting any special token ids that may be part or entirely upper-cased."""
        all_vocabulary = list(self.encoder.keys()) + list(self.added_tokens_encoder.keys())
        filtered_text = ""

        i = 0
        while i < len(input_string):
            found_match = False
            for word in all_vocabulary:
                if input_string[i : i + len(word)] == word:
                    filtered_text += word
                    i += len(word)
                    found_match = True
                    break

            if not found_match:
                filtered_text += input_string[i].lower()
                i += 1

        return filtered_text

    def _preprocess_char(self, text):
        """Special treatment of characters in certain languages"""
        if self.language == "ron":
            text = text.replace("ț", "ţ")
        return text

    def prepare_for_tokenization(
        self, text: str, is_split_into_words: bool = False, normalize = None, **kwargs):

        normalize = normalize if normalize is not None else self.normalize

        if normalize:
            # normalise for casing
            text = self.normalize_text(text)

        filtered_text = self._preprocess_char(text)

        if has_non_roman_characters(filtered_text) and self.is_uroman:
            # `uroman` is imported unconditionally at the top of this module, so it is always available here
            uroman = ur.Uroman()
            filtered_text = uroman.romanize_string(filtered_text)

        if self.phonemize:
            # `phonemizer` is likewise imported unconditionally at the top of this module
            filtered_text = phonemizer.phonemize(
                filtered_text,
                language="en-us",
                backend="espeak",
                strip=True,
                preserve_punctuation=True,
                with_stress=True,
            )
            filtered_text = re.sub(r"\s+", " ", filtered_text)
        elif normalize:
            # strip any chars outside of the vocab (punctuation)
            filtered_text = "".join(list(filter(lambda char: char in self.encoder, filtered_text))).strip()

        return filtered_text, kwargs

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize a string by inserting the `<pad>` token at the boundary between adjacent characters."""
        tokens = list(text)

        if self.add_blank:
            # intersperse the blank token (id 0) between characters: with no blank the speech sounds garbled,
            # with more than one blank per character it sounds disconnected
            interspersed = [self._convert_id_to_token(0)] * (len(tokens) * 2)
            interspersed[::2] = tokens
            # append one trailing blank (the `len(tokens) * 2 + 1` variant raises a slice-length mismatch when the token count is odd)
            tokens = interspersed + [self._convert_id_to_token(0)]

        return tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.decoder.get(index)
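

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model code). It
# assumes a compatible VITS / MMS-TTS checkpoint with a `vocab.json`; the repo
# id below is an example, and the weight names of this trimmed-down
# implementation may differ from the checkpoint, in which case
# `from_pretrained` will warn about missing or unexpected keys.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    repo = "facebook/mms-tts-deu"  # assumed checkpoint id; substitute your own

    tokenizer = VitsTokenizer.from_pretrained(repo)
    model = VitsModel.from_pretrained(repo).eval()

    inputs = tokenizer("Hallo, wie geht es dir?", return_tensors="pt")
    with torch.no_grad():
        waveform = model(input_ids=inputs["input_ids"],
                         attention_mask=inputs["attention_mask"],
                         lang_code="deu")  # -> [batch, num_samples]

    print(waveform.shape, "samples at", model.config.sampling_rate, "Hz")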