Why does the model "explode" in memory?

#8
by bharatcoder - opened

When I try to run the model locally (Mac Mini, 24GB unified memory), its memory usage explodes, consuming all 24GB and triggering swap.
This does not happen with the original Parler-TTS model.

I see the following in the console:

Some weights of the model checkpoint at ai4bharat/indic-parler-tts were not used when initializing ParlerTTSForConditionalGeneration: ['audio_encoder.decoder.block.0.conv_t1.bias', 'audio_encoder.decoder.block.0.conv_t1.weight', 'audio_encoder.decoder.block.0.res_unit1.conv1.bias', 'audio_encoder.decoder.block.0.res_unit1.conv1.weight', 'audio_encoder.decoder.block.0.res_unit1.conv2.bias', 'audio_encoder.decoder.block.0.res_unit1.conv2.weight', 'audio_encoder.decoder.block.0.res_unit1.snake1.alpha', 'audio_encoder.decoder.block.0.res_unit1.snake2.alpha', 'audio_encoder.decoder.block.0.res_unit2.conv1.bias', 'audio_encoder.decoder.block.0.res_unit2.conv1.weight', 'audio_encoder.decoder.block.0.res_unit2.conv2.bias', 'audio_encoder.decoder.block.0.res_unit2.conv2.weight', 'audio_encoder.decoder.block.0.res_unit2.snake1.alpha', 'audio_encoder.decoder.block.0.res_unit2.snake2.alpha', 'audio_encoder.decoder.block.0.res_unit3.conv1.bias', 'audio_encoder.decoder.block.0.res_unit3.conv1.weight', 'audio_encoder.decoder.block.0.res_unit3.conv2.bias', 'audio_encoder.decoder.block.0.res_unit3.conv2.weight', 'audio_encoder.decoder.block.0.res_unit3.snake1.alpha', 'audio_encoder.decoder.block.0.res_unit3.snake2.alpha', 'audio_encoder.decoder.block.0.snake1.alpha', 'audio_encoder.decoder.block.1.conv_t1.bias', 'audio_encoder.decoder.block.1.conv_t1.weight', 'audio_encoder.decoder.block.1.res_unit1.conv1.bias', 'audio_encoder.decoder.block.1.res_unit1.conv1.weight', 'audio_encoder.decoder.block.1.res_unit1.conv2.bias', 'audio_encoder.decoder.block.1.res_unit1.conv2.weight', 'audio_encoder.decoder.block.1.res_unit1.snake1.alpha', 'audio_encoder.decoder.block.1.res_unit1.snake2.alpha', 'audio_encoder.decoder.block.1.res_unit2.conv1.bias', 'audio_encoder.decoder.block.1.res_unit2.conv1.weight', 'audio_encoder.decoder.block.1.res_unit2.conv2.bias', 'audio_encoder.decoder.block.1.res_unit2.conv2.weight', 'audio_encoder.decoder.block.1.res_unit2.snake1.alpha', 
'audio_encoder.decoder.block.1.res_unit2.snake2.alpha', 'audio_encoder.decoder.block.1.res_unit3.conv1.bias', 'audio_encoder.decoder.block.1.res_unit3.conv1.weight', 'audio_encoder.decoder.block.1.res_unit3.conv2.bias', 'audio_encoder.decoder.block.1.res_unit3.conv2.weight', 'audio_encoder.decoder.block.1.res_unit3.snake1.alpha', 'audio_encoder.decoder.block.1.res_unit3.snake2.alpha', 'audio_encoder.decoder.block.1.snake1.alpha', 'audio_encoder.decoder.block.2.conv_t1.bias', 'audio_encoder.decoder.block.2.conv_t1.weight', 'audio_encoder.decoder.block.2.res_unit1.conv1.bias', 'audio_encoder.decoder.block.2.res_unit1.conv1.weight', 'audio_encoder.decoder.block.2.res_unit1.conv2.bias', 'audio_encoder.decoder.block.2.res_unit1.conv2.weight', 'audio_encoder.decoder.block.2.res_unit1.snake1.alpha', 'audio_encoder.decoder.block.2.res_unit1.snake2.alpha', 'audio_encoder.decoder.block.2.res_unit2.conv1.bias', 'audio_encoder.decoder.block.2.res_unit2.conv1.weight', 'audio_encoder.decoder.block.2.res_unit2.conv2.bias', 'audio_encoder.decoder.block.2.res_unit2.conv2.weight', 'audio_encoder.decoder.block.2.res_unit2.snake1.alpha', 'audio_encoder.decoder.block.2.res_unit2.snake2.alpha', 'audio_encoder.decoder.block.2.res_unit3.conv1.bias', 'audio_encoder.decoder.block.2.res_unit3.conv1.weight', 'audio_encoder.decoder.block.2.res_unit3.conv2.bias', 'audio_encoder.decoder.block.2.res_unit3.conv2.weight', 'audio_encoder.decoder.block.2.res_unit3.snake1.alpha', 'audio_encoder.decoder.block.2.res_unit3.snake2.alpha', 'audio_encoder.decoder.block.2.snake1.alpha', 'audio_encoder.decoder.block.3.conv_t1.bias', 'audio_encoder.decoder.block.3.conv_t1.weight', 'audio_encoder.decoder.block.3.res_unit1.conv1.bias', 'audio_encoder.decoder.block.3.res_unit1.conv1.weight', 'audio_encoder.decoder.block.3.res_unit1.conv2.bias', 'audio_encoder.decoder.block.3.res_unit1.conv2.weight', 'audio_encoder.decoder.block.3.res_unit1.snake1.alpha', 'audio_encoder.decoder.block.3.res_unit1.snake2.alpha', 
'audio_encoder.decoder.block.3.res_unit2.conv1.bias', 'audio_encoder.decoder.block.3.res_unit2.conv1.weight', 'audio_encoder.decoder.block.3.res_unit2.conv2.bias', 'audio_encoder.decoder.block.3.res_unit2.conv2.weight', 'audio_encoder.decoder.block.3.res_unit2.snake1.alpha', 'audio_encoder.decoder.block.3.res_unit2.snake2.alpha', 'audio_encoder.decoder.block.3.res_unit3.conv1.bias', 'audio_encoder.decoder.block.3.res_unit3.conv1.weight', 'audio_encoder.decoder.block.3.res_unit3.conv2.bias', 'audio_encoder.decoder.block.3.res_unit3.conv2.weight', 'audio_encoder.decoder.block.3.res_unit3.snake1.alpha', 'audio_encoder.decoder.block.3.res_unit3.snake2.alpha', 'audio_encoder.decoder.block.3.snake1.alpha', 'audio_encoder.decoder.conv1.bias', 'audio_encoder.decoder.conv1.weight', 'audio_encoder.decoder.conv2.bias', 'audio_encoder.decoder.conv2.weight', 'audio_encoder.decoder.snake1.alpha', 'audio_encoder.encoder.block.0.conv1.bias', 'audio_encoder.encoder.block.0.conv1.weight', 'audio_encoder.encoder.block.0.res_unit1.conv1.bias', 'audio_encoder.encoder.block.0.res_unit1.conv1.weight', 'audio_encoder.encoder.block.0.res_unit1.conv2.bias', 'audio_encoder.encoder.block.0.res_unit1.conv2.weight', 'audio_encoder.encoder.block.0.res_unit1.snake1.alpha', 'audio_encoder.encoder.block.0.res_unit1.snake2.alpha', 'audio_encoder.encoder.block.0.res_unit2.conv1.bias', 'audio_encoder.encoder.block.0.res_unit2.conv1.weight', 'audio_encoder.encoder.block.0.res_unit2.conv2.bias', 'audio_encoder.encoder.block.0.res_unit2.conv2.weight', 'audio_encoder.encoder.block.0.res_unit2.snake1.alpha', 'audio_encoder.encoder.block.0.res_unit2.snake2.alpha', 'audio_encoder.encoder.block.0.res_unit3.conv1.bias', 'audio_encoder.encoder.block.0.res_unit3.conv1.weight', 'audio_encoder.encoder.block.0.res_unit3.conv2.bias', 'audio_encoder.encoder.block.0.res_unit3.conv2.weight', 'audio_encoder.encoder.block.0.res_unit3.snake1.alpha', 'audio_encoder.encoder.block.0.res_unit3.snake2.alpha', 
'audio_encoder.encoder.block.0.snake1.alpha', 'audio_encoder.encoder.block.1.conv1.bias', 'audio_encoder.encoder.block.1.conv1.weight', 'audio_encoder.encoder.block.1.res_unit1.conv1.bias', 'audio_encoder.encoder.block.1.res_unit1.conv1.weight', 'audio_encoder.encoder.block.1.res_unit1.conv2.bias', 'audio_encoder.encoder.block.1.res_unit1.conv2.weight', 'audio_encoder.encoder.block.1.res_unit1.snake1.alpha', 'audio_encoder.encoder.block.1.res_unit1.snake2.alpha', 'audio_encoder.encoder.block.1.res_unit2.conv1.bias', 'audio_encoder.encoder.block.1.res_unit2.conv1.weight', 'audio_encoder.encoder.block.1.res_unit2.conv2.bias', 'audio_encoder.encoder.block.1.res_unit2.conv2.weight', 'audio_encoder.encoder.block.1.res_unit2.snake1.alpha', 'audio_encoder.encoder.block.1.res_unit2.snake2.alpha', 'audio_encoder.encoder.block.1.res_unit3.conv1.bias', 'audio_encoder.encoder.block.1.res_unit3.conv1.weight', 'audio_encoder.encoder.block.1.res_unit3.conv2.bias', 'audio_encoder.encoder.block.1.res_unit3.conv2.weight', 'audio_encoder.encoder.block.1.res_unit3.snake1.alpha', 'audio_encoder.encoder.block.1.res_unit3.snake2.alpha', 'audio_encoder.encoder.block.1.snake1.alpha', 'audio_encoder.encoder.block.2.conv1.bias', 'audio_encoder.encoder.block.2.conv1.weight', 'audio_encoder.encoder.block.2.res_unit1.conv1.bias', 'audio_encoder.encoder.block.2.res_unit1.conv1.weight', 'audio_encoder.encoder.block.2.res_unit1.conv2.bias', 'audio_encoder.encoder.block.2.res_unit1.conv2.weight', 'audio_encoder.encoder.block.2.res_unit1.snake1.alpha', 'audio_encoder.encoder.block.2.res_unit1.snake2.alpha', 'audio_encoder.encoder.block.2.res_unit2.conv1.bias', 'audio_encoder.encoder.block.2.res_unit2.conv1.weight', 'audio_encoder.encoder.block.2.res_unit2.conv2.bias', 'audio_encoder.encoder.block.2.res_unit2.conv2.weight', 'audio_encoder.encoder.block.2.res_unit2.snake1.alpha', 'audio_encoder.encoder.block.2.res_unit2.snake2.alpha', 'audio_encoder.encoder.block.2.res_unit3.conv1.bias', 
'audio_encoder.encoder.block.2.res_unit3.conv1.weight', 'audio_encoder.encoder.block.2.res_unit3.conv2.bias', 'audio_encoder.encoder.block.2.res_unit3.conv2.weight', 'audio_encoder.encoder.block.2.res_unit3.snake1.alpha', 'audio_encoder.encoder.block.2.res_unit3.snake2.alpha', 'audio_encoder.encoder.block.2.snake1.alpha', 'audio_encoder.encoder.block.3.conv1.bias', 'audio_encoder.encoder.block.3.conv1.weight', 'audio_encoder.encoder.block.3.res_unit1.conv1.bias', 'audio_encoder.encoder.block.3.res_unit1.conv1.weight', 'audio_encoder.encoder.block.3.res_unit1.conv2.bias', 'audio_encoder.encoder.block.3.res_unit1.conv2.weight', 'audio_encoder.encoder.block.3.res_unit1.snake1.alpha', 'audio_encoder.encoder.block.3.res_unit1.snake2.alpha', 'audio_encoder.encoder.block.3.res_unit2.conv1.bias', 'audio_encoder.encoder.block.3.res_unit2.conv1.weight', 'audio_encoder.encoder.block.3.res_unit2.conv2.bias', 'audio_encoder.encoder.block.3.res_unit2.conv2.weight', 'audio_encoder.encoder.block.3.res_unit2.snake1.alpha', 'audio_encoder.encoder.block.3.res_unit2.snake2.alpha', 'audio_encoder.encoder.block.3.res_unit3.conv1.bias', 'audio_encoder.encoder.block.3.res_unit3.conv1.weight', 'audio_encoder.encoder.block.3.res_unit3.conv2.bias', 'audio_encoder.encoder.block.3.res_unit3.conv2.weight', 'audio_encoder.encoder.block.3.res_unit3.snake1.alpha', 'audio_encoder.encoder.block.3.res_unit3.snake2.alpha', 'audio_encoder.encoder.block.3.snake1.alpha', 'audio_encoder.encoder.conv1.bias', 'audio_encoder.encoder.conv1.weight', 'audio_encoder.encoder.conv2.bias', 'audio_encoder.encoder.conv2.weight', 'audio_encoder.encoder.snake1.alpha', 'audio_encoder.quantizer.quantizers.0.codebook.weight', 'audio_encoder.quantizer.quantizers.0.in_proj.bias', 'audio_encoder.quantizer.quantizers.0.in_proj.weight', 'audio_encoder.quantizer.quantizers.0.out_proj.bias', 'audio_encoder.quantizer.quantizers.0.out_proj.weight', 'audio_encoder.quantizer.quantizers.1.codebook.weight', 
'audio_encoder.quantizer.quantizers.1.in_proj.bias', 'audio_encoder.quantizer.quantizers.1.in_proj.weight', 'audio_encoder.quantizer.quantizers.1.out_proj.bias', 'audio_encoder.quantizer.quantizers.1.out_proj.weight', 'audio_encoder.quantizer.quantizers.2.codebook.weight', 'audio_encoder.quantizer.quantizers.2.in_proj.bias', 'audio_encoder.quantizer.quantizers.2.in_proj.weight', 'audio_encoder.quantizer.quantizers.2.out_proj.bias', 'audio_encoder.quantizer.quantizers.2.out_proj.weight', 'audio_encoder.quantizer.quantizers.3.codebook.weight', 'audio_encoder.quantizer.quantizers.3.in_proj.bias', 'audio_encoder.quantizer.quantizers.3.in_proj.weight', 'audio_encoder.quantizer.quantizers.3.out_proj.bias', 'audio_encoder.quantizer.quantizers.3.out_proj.weight', 'audio_encoder.quantizer.quantizers.4.codebook.weight', 'audio_encoder.quantizer.quantizers.4.in_proj.bias', 'audio_encoder.quantizer.quantizers.4.in_proj.weight', 'audio_encoder.quantizer.quantizers.4.out_proj.bias', 'audio_encoder.quantizer.quantizers.4.out_proj.weight', 'audio_encoder.quantizer.quantizers.5.codebook.weight', 'audio_encoder.quantizer.quantizers.5.in_proj.bias', 'audio_encoder.quantizer.quantizers.5.in_proj.weight', 'audio_encoder.quantizer.quantizers.5.out_proj.bias', 'audio_encoder.quantizer.quantizers.5.out_proj.weight', 'audio_encoder.quantizer.quantizers.6.codebook.weight', 'audio_encoder.quantizer.quantizers.6.in_proj.bias', 'audio_encoder.quantizer.quantizers.6.in_proj.weight', 'audio_encoder.quantizer.quantizers.6.out_proj.bias', 'audio_encoder.quantizer.quantizers.6.out_proj.weight', 'audio_encoder.quantizer.quantizers.7.codebook.weight', 'audio_encoder.quantizer.quantizers.7.in_proj.bias', 'audio_encoder.quantizer.quantizers.7.in_proj.weight', 'audio_encoder.quantizer.quantizers.7.out_proj.bias', 'audio_encoder.quantizer.quantizers.7.out_proj.weight', 'audio_encoder.quantizer.quantizers.8.codebook.weight', 'audio_encoder.quantizer.quantizers.8.in_proj.bias', 
'audio_encoder.quantizer.quantizers.8.in_proj.weight', 'audio_encoder.quantizer.quantizers.8.out_proj.bias', 'audio_encoder.quantizer.quantizers.8.out_proj.weight', 'decoder.lm_heads.weight']
- This IS expected if you are initializing ParlerTTSForConditionalGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ParlerTTSForConditionalGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ParlerTTSForConditionalGeneration were not initialized from the model checkpoint at ai4bharat/indic-parler-tts and are newly initialized: ['audio_encoder.model.decoder.model.0.bias', 'audio_encoder.model.decoder.model.0.weight_g', 'audio_encoder.model.decoder.model.0.weight_v', 'audio_encoder.model.decoder.model.1.block.0.alpha', 'audio_encoder.model.deco

My code is as follows:

import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf

device = "mps"

my_model = "ai4bharat/indic-parler-tts"
#my_model = "parler-tts/parler-tts-mini-v1"

model = ParlerTTSForConditionalGeneration.from_pretrained(my_model).to(device)
tokenizer = AutoTokenizer.from_pretrained(my_model)
#description_tokenizer = AutoTokenizer.from_pretrained(model.config.text_encoder._name_or_path)
description_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")

prompt = "अरे, तुम आज कैसे हो?"
description = "Divya's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."


inputs = description_tokenizer(description, return_tensors="pt").to(device)
#inputs = tokenizer(description, return_tensors="pt").to(device)
prompts = tokenizer(prompt, return_tensors="pt").to(device)


generation = model.generate(input_ids=inputs.input_ids, 
                            prompt_input_ids=prompts.input_ids, 
                            attention_mask=inputs.attention_mask,
                            decoder_attention_mask=prompts.attention_mask,
                            prompt_attention_mask=prompts.attention_mask)
'''
generation = model.generate(input_ids=inputs.input_ids, 
                            prompt_input_ids=prompts.input_ids, 
                            attention_mask=inputs.attention_mask)
'''
audio_arr = generation.cpu().numpy().squeeze()
sf.write("indic_tts_out.wav", audio_arr, model.config.sampling_rate)
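To narrow down which step the memory growth comes from, a small helper like the one below could be called before and after the load and generate steps. This is only a sketch using the Python standard library: the names `peak_rss_mib` and `report` are made up for illustration, and it assumes that peak resident set size is a reasonable proxy here. On Apple Silicon, memory allocated through the MPS backend may not be fully reflected in this number.

```python
import resource
import sys


def peak_rss_mib() -> float:
    """Return the peak resident set size of this process in MiB."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS and in kilobytes on Linux.
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor


def report(stage: str) -> float:
    """Print and return the peak RSS so far, labelled with a stage name."""
    mib = peak_rss_mib()
    print(f"[{stage}] peak RSS: {mib:.0f} MiB")
    return mib


# Hypothetical placement around the steps of the script above:
#   report("start")
#   model = ParlerTTSForConditionalGeneration.from_pretrained(my_model).to(device)
#   report("after load")
#   generation = model.generate(...)
#   report("after generate")
report("start")
```

If the jump shows up at "after load", the weights themselves are the likely cause. If it only shows up at "after generate", the growth happens during generation.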

[Attached screenshot: Screenshot 2025-03-16 at 10.18.00 AM.png]
