About IndexTTS

IndexTTS (arXiv:2502.05512) is an open-source text-to-speech (TTS) model from Bilibili's Index team that supports zero-shot voice cloning for Chinese and English. Its selling points are a small parameter count and the ability to control the pronunciation of Chinese heteronyms with pinyin tone annotations. The architecture is based on Tortoise TTS / XTTS, with BigVGAN as the vocoder. Although the official report mentions support for synthesizing audio with controllable emotion, the code and usage for that capability have not been released yet.

This post documents my process of studying and fine-tuning IndexTTS to generate speech audio with controllable emotion.

Fine-tuning Results

Below are Chinese and English speech samples generated by IndexTTS after roughly half an hour of fine-tuning on an NVIDIA GeForce RTX 4070:

Reference audio    Text of the synthesized sample
Elise-1            Hey there my name is Elise, and I’m a speech generation model that can sound like a person.
Elise-1            你好,我是 ELISE,一个语音生成模型,我的声音听起来跟真人一样。
Female-1           Seriously? <giggles> That’s the cutest thing I’ve ever heard!
Female-1           真的吗? <giggles> 这也太可爱了吧!
Male-1             Wha—? Cute? <giggles> You think I’m cute?! Well, uh, thanks, I guess?
Male-1             哎呀! 忘了他还在那等我们呢! <giggles> 我们两个动作得快点了!

The complete experiment notebooks can be found in the repository yrom/finetune-index-tts.


Model Architecture and Fine-tuning Approach

IndexTTS is based on Tortoise TTS and relies on several modules working together. Its main pipeline is as follows:

(Diagram drawn by the author)

Reaching the fine-tuning goal mainly involves two core components:

  • BPE tokenizer: based on sentencepiece, it encodes the text and the newly added emotion tags (e.g. <GIGGLES>) into a sequence of vocabulary IDs (see the sketch after this list).
  • GPT2 autoregressive module: fine-tuned to learn to generate suitable audio latent representations conditioned on the target emotion tag IDs.
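
The sketch below (assuming the bpe.model and bpe_new.model files produced later in this post) shows why the vocabulary expansion matters: the original model breaks an emotion tag into generic sub-word pieces, while the expanded model maps it to a single dedicated token ID.

import sentencepiece as spm

sp_old = spm.SentencePieceProcessor(model_file="bpe.model")      # original IndexTTS BPE model
sp_new = spm.SentencePieceProcessor(model_file="bpe_new.model")  # expanded with the emotion tags

text = "REALLY? <GIGGLES> THAT IS SO CUTE!"
print(sp_old.encode(text, out_type=str))  # "<GIGGLES>" is split into several ordinary pieces
print(sp_new.encode(text, out_type=str))  # "<GIGGLES>" comes out as one piece with its own ID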

Overview of the Fine-tuning Workflow

  1. Prepare the fine-tuning dataset: collect and organize text-audio aligned data annotated with emotion tags.
  2. Expand the BPE vocabulary: add the custom emotion tags (e.g. <LAUGHS>, <GIGGLES>, <SIGHS>, <CHUCKLES>) to the text tokenizer's vocabulary.
  3. Adjust the model parameters: resize the GPT2 text embedding layer and output head layer to the new vocabulary size.
  4. Preprocess the data: trim long silences from the audio, resample to 24 kHz, compute the log-Mel spectrogram, and encode it with the DiscreteVAE to obtain each target sample's mel code IDs, i.e. the Ground Truth mel code sequence that the GPT2 autoregressive model is trained to generate.
  5. LoRA fine-tuning: use Hugging Face PEFT to fine-tune the GPT2 self-attention and feed-forward (MLP) layers.

Dataset Preparation and Audio Processing

I use the MrDragonFox/Elise dataset directly; it contains text-audio aligned samples annotated with multiple emotions.

Data loading example:

from datasets import load_dataset
import re

ds = load_dataset("MrDragonFox/Elise", split="train")
emotion_tag_pattern = re.compile(r"<[a-z]+>", re.IGNORECASE)
texts = ds["text"]
emotion_tags = dict()
for text in texts:
    m = re.search(emotion_tag_pattern, text)
    if m:
        tag = m.group(0)
        tag = tag.upper()
        if tag not in emotion_tags:
            emotion_tags[tag] = 0
            print(f"Found emotion tag: {tag}, text: {text}")
        emotion_tags[tag] += 1

print("Emotion tags:")
for tag, count in emotion_tags.items():
    print(f"{tag}: {count}")

Output:

Found emotion tag: <LAUGHS>, text: That's so sweet. And I hadn't even promoted it, I just like put in the descriptions of stuff and whatnot, and it just <laughs>. I was so surprised.
Found emotion tag: <GIGGLES>, text: Oh, God. <giggles> I'm just so happy. Oh, and it's all your fault. <giggles>
Found emotion tag: <SIGHS>, text: Deal with it. I will. I'll just scowl and watch TV by myself <sighs>.
Found emotion tag: <SNIFFS>, text: Wait a minute. No, that-that man over there, he's dressed different. <sniffs> Oh, he smells different.
Found emotion tag: <CHUCKLES>, text: I knew you two would get close quickly. Score another one for me! <chuckles> Sheesh!
Found emotion tag: <CHUCKLE>, text: <Chuckle> Hmm. Huh? Oh, no, no, I was just...
...
Emotion tags:
<LAUGHS>: 157
<GIGGLES>: 33
<SIGHS>: 72
<SNIFFS>: 3
<CHUCKLES>: 10
<CHUCKLE>: 1
...

The meanings of the emotion tags:

<LAUGHS>     laughing out loud, "ha-ha-ha"
<GIGGLES>    giggling, "hee-hee"
<CHUCKLES>   a light chuckle, "heh-heh"
<SIGHS>      sighing
<SNIFFS>     sniffing
<YAWNING>    yawning
<SINGING>    singing
<CHEWING>    chewing
<GASPS>      gasping
<SCOFFS>     scoffing
<SMOOCHES>   a kissing sound
<WHISPERS>   whispering
<EXHALES>    exhaling
<MOANS>      moaning
<COUGHS>     coughing

Expanding the Text Tokenizer Vocabulary to Support Emotion Tags

To support the custom emotion tags, the sentencepiece BPE model's vocabulary needs to be expanded, following the official sentencepiece example.

As an example, only the four tags with the most samples in the dataset are added: <LAUGHS>, <GIGGLES>, <SIGHS>, <CHUCKLES>:

import sentencepiece.sentencepiece_model_pb2 as sp_model

def expand_vocab(path, new_path, additional_special_tokens=[]):
    m = sp_model.ModelProto()
    with open(path, "rb") as f:
        m.ParseFromString(f.read())
    for token in additional_special_tokens:
        new_token = sp_model.ModelProto().SentencePiece()
        new_token.piece = token
        new_token.score = 0
        m.pieces.append(new_token)

    with open(new_path, "wb") as f:
        f.write(m.SerializeToString())
    print(f"Expanded BPE model saved to: {new_path}")

expand_vocab(
    "bpe.model",
    "bpe_new.model",
    additional_special_tokens=["<LAUGHS>", "<GIGGLES>", "<SIGHS>", "<CHUCKLES>"],
)

Check that the newly added emotion tags can be encoded correctly:


import sentencepiece as spm

additional_special_tokens = ["<LAUGHS>", "<GIGGLES>", "<SIGHS>", "<CHUCKLES>"]
new_bpe_model = spm.SentencePieceProcessor()
new_bpe_model.load("bpe_new.model")

for tag in additional_special_tokens:
    token_id = new_bpe_model.piece_to_id(tag)
    if token_id == new_bpe_model.unk_id():
        print(f"Tag '{tag}' not found in the new BPE model!")
    else:
        print(f"Tag '{tag}' ID: {token_id}")

Output:

Tag '<LAUGHS>' ID: 12000
Tag '<GIGGLES>' ID: 12001
Tag '<SIGHS>' ID: 12002
Tag '<CHUCKLES>' ID: 12003

Data Preprocessing: Obtaining the DVAE-encoded Mel Spectrogram IDs

IndexTTS uses a VQ-VAE (concretely, DiscreteVAE) to turn the audio's mel spectrogram into discrete codes.

import os
import torch
from indextts.vqvae.xtts_dvae import DiscreteVAE
from omegaconf import OmegaConf

config_path = os.path.join(model_dir, "config.yaml")
config = OmegaConf.load(config_path)

dvae = DiscreteVAE(**config.vqvae)
pre_trained_dvae = torch.load(
    os.path.join(model_dir, config.dvae_checkpoint), map_location="cpu", weights_only=True
)
dvae.load_state_dict(pre_trained_dvae["model"] if "model" in pre_trained_dvae else pre_trained_dvae, strict=True)
dvae.eval()
import torch
import torchaudio
import numpy as np

def process_text(text: str, text_tokenizer):
    return text_tokenizer.EncodeAsIds(text.upper())

@torch.no_grad()
def process_audio_data(audio: torch.Tensor | np.ndarray, sr, dvae: DiscreteVAE, mel_config):
    """Generate discrete codes from mel spectrograms using the DiscreteVAE."""
    from indextts.utils.feature_extractors import MelSpectrogramFeatures

    mel_feature = MelSpectrogramFeatures(**mel_config)
    if isinstance(audio, np.ndarray):
        audio = torch.from_numpy(audio)
    if sr != mel_config.sample_rate:
        audio = torchaudio.transforms.Resample(sr, mel_config.sample_rate)(audio)
    mel = mel_feature(audio)
    codes = dvae.get_codebook_indices(mel)
    if audio.ndim > 1 and audio.shape[0] == 1:
        audio = audio.squeeze(0)
    return audio, mel, codes

def process_sample(sample: dict, text_tokenizer, dvae, mel_config):
    if "text_ids" not in sample or sample["text_ids"] is None:
        text_ids = process_text(sample["text"], text_tokenizer)
        sample["text_ids"] = text_ids
    audio_value = sample["audio"]
    sr = audio_value["sampling_rate"]
    with torch.no_grad():
        audio, mel, codes = process_audio_data(audio_value["array"], sr, dvae, mel_config)
    sample["mel"] = mel.squeeze(0) if mel.ndim == 3 else mel
    sample["codes"] = codes.squeeze(0) if codes.ndim == 2 else codes
    return sample

Display the audio of a few random samples:

ids: IterableDataset = ds.to_iterable_dataset()
supported_tags_pattern = re.compile("|".join([re.escape(tag) for tag in additional_special_tokens]), re.IGNORECASE)
tags_count = {tag: 0 for tag in additional_special_tokens}

def filter_tags(sample, tags_count=tags_count):
    text = sample["text"].upper()
    if m := re.search(supported_tags_pattern, text):
        tag = m.group(0)
        tag = tag.upper()
        tags_count[tag] += 1
        # Limit the number of samples per tag
        return tags_count[tag] <= 1
    return False

random_samples = (
    ids.shuffle(seed=33, buffer_size=10)
    .filter(filter_tags)
    .take(10)
    .with_format("numpy")
    .map(
        process_sample,
        fn_kwargs={"text_tokenizer": new_bpe_model, "dvae": dvae, "mel_config": config.mel},
    )
)

for i, sample in enumerate(random_samples):
    print(f"Text: {sample['text']}")
    print(f"Text IDs shape: {sample['text_ids'].shape}")
    print(f"Mel shape: {sample['mel'].shape}")
    print(f"Codes: {sample['codes'].shape}")
    plot_audio_sample(
        sample["mel"],
        sample["codes"],
        num_codes=config.vqvae.num_tokens,
        title=f"Sample {i + 1}",
    )
    display(Audio(sample["audio"]["array"], rate=sample["audio"]["sampling_rate"].item()))
    print("-" * 40)
Click to expand the implementation of plot_audio_sample...
from IPython.display import display, Audio
from datasets import IterableDataset
import re
import matplotlib.pyplot as plt
import numpy as np


def plot_audio_sample(
    mel: np.ndarray,
    codes_indices: np.ndarray = None,
    num_codes=8192,
    title="Sample Audio spectrogram",
):
    # Visualize the mel spectrogram and (optionally) the discrete code indices
    fig, ax1 = plt.subplots(figsize=(10, 4))

    # Mel spectrogram on left y-axis
    im = ax1.imshow(
        mel,
        aspect="auto",
        origin="lower",
        interpolation="none",
        cmap="magma",
    )
    ax1.set_xlabel("Time (frames)")
    ax1.set_ylabel("Mel Frequency Channels")
    ax1.set_title(title)
    # show colorbar next to the image
    cbar = plt.colorbar(im, ax=ax1, pad=0.1)
    cbar.set_label("Amplitude (dB)")
    # Code indices on right y-axis
    if codes_indices is not None:
        downsample_factor = 4  # every 4 mel frames correspond to one mel code
        code_time_axis = np.arange(0, len(codes_indices) * downsample_factor, downsample_factor)
        ax2 = ax1.twinx()
        scatter = ax2.scatter(code_time_axis, codes_indices, c="cyan", marker="*", zorder=10, alpha=0.8, s=20)
        ax2.set_ylim(0, num_codes - 1)
        ax2.set_ylabel("Codebook Index")
    plt.tight_layout()
    plt.show()

Output:

Text: Cutie. <giggles> They’d impale you if I tried to take a bite, and obviously I want your blood.
Text IDs shape: (30,)
Mel shape: (100, 539)
Codes: (135,)
audio sample1

Text: Running through the grass, playing under the falling leaves. <laughs> My sweet little kit, the-
Text IDs shape: (28,)
Mel shape: (100, 769)
Codes: (193,)
audio sample2


To the right of each plotted mel spectrogram are the corresponding discrete code indices (one code per 4 mel frames, about 25 Hz).
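
This 4:1 ratio can be sanity-checked against the shapes printed above (assuming the DVAE downsamples by exactly 4x):

import math

for mel_frames, n_codes in [(539, 135), (769, 193)]:
    assert math.ceil(mel_frames / 4) == n_codes
    print(f"{mel_frames} mel frames -> {n_codes} codes")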

The plots show that the raw audio in the dataset contains very long silent segments, which is harmful for this fine-tuning experiment because the model would learn incorrect pauses from them.

A curiosity: the code ID of silent frames

With the IndexTTS 1.0 dvae.pth weights, silent frames encode to code ID 52, while with the 1.5 weights they encode to 428. (audio sample2 with 1.0)
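
One way to check this yourself is to encode pure silence with the DVAE loaded above and inspect the resulting codes (a minimal sketch, not from the original notebook):

import torch
from indextts.utils.feature_extractors import MelSpectrogramFeatures

silence = torch.zeros(1, config.mel.sample_rate)  # one second of silence at 24 kHz
mel = MelSpectrogramFeatures(**config.mel)(silence)
with torch.no_grad():
    codes = dvae.get_codebook_indices(mel)
print(codes.unique(return_counts=True))  # expect code 52 for the 1.0 dvae.pth, 428 for 1.5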

Audio Preprocessing: Trimming Overly Long Silences

Because the fine-tuning dataset is small, audio quality matters, so every audio sample is silence-trimmed with the librosa library.

About top_db

Silence trimming here is based on an energy threshold (top_db), so it can misjudge.
top_db means how many decibels below the maximum energy a segment must fall to be considered silence; it is set to 20 dB here. Adjust the top_db parameter to suit your data.
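
The toy example below (not from the original post) shows what librosa.effects.split returns: sample-index intervals of the non-silent regions, judged against top_db.

import numpy as np
import librosa

sr = 24000
tone = 0.5 * np.sin(2 * np.pi * 440 * np.arange(sr) / sr).astype(np.float32)
y = np.concatenate([np.zeros(sr, dtype=np.float32), tone, np.zeros(sr, dtype=np.float32)])
print(librosa.effects.split(y, top_db=20))  # roughly [[24000, 48000]]: only the 1 s tone in the middle is kept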

The core code is as follows:

import librosa
import torch
import numpy as np

def shrink_audio_long_silence(audio: torch.Tensor | np.ndarray, sr, top_db=20.0):
    """Shrink long silence in audio. mono audio shape: (T,)"""
    import librosa.effects
    audio_np = audio.numpy() if torch.is_tensor(audio) else audio
    if audio_np.ndim > 1:
        audio_np = audio_np[0]
    nonsilence_idx = librosa.effects.split(audio_np, top_db=top_db)
    if nonsilence_idx.size == 0:
        return audio
    # Concatenate non-silence segments
    pad = max(1024 * 2, sr // 10)
    padded_nonsilence_idx = []
    for start, end in nonsilence_idx:
        if len(padded_nonsilence_idx) == 0:
            padded_nonsilence_idx.append((start if start < pad else start - pad, end))
        else:
            last_start, last_end = padded_nonsilence_idx[-1]
            if start - last_end <= pad:
                padded_nonsilence_idx[-1] = (last_start, end)  # Extend the last segment
            else:
                if start - last_end > pad * 4:
                    start = start - pad
                padded_nonsilence_idx[-1] = (last_start, last_end + pad)
                padded_nonsilence_idx.append((start - pad, end))
    nonsilence_audio = np.concatenate([audio_np[start:end] for start, end in padded_nonsilence_idx])
    if torch.is_tensor(audio):
        nonsilence_audio = torch.tensor(nonsilence_audio, dtype=audio.dtype, device=audio.device)
    return nonsilence_audio.unsqueeze(0) if audio.ndim > 1 else nonsilence_audio

After this processing, the silent segments are trimmed away while the main speech content is preserved. The figure below shows a processed audio sample:

shrinked audio sample2

Expanding the Text Embedding and Head Layers

Since new text tokens have been added, the number of channels of the text embedding layer and the output head layer must be expanded accordingly.

If this were a PreTrainedModel from transformers, you could simply call its resize_token_embeddings method:

from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained('xxx')
new_vocab_size = ...  # e.g. 12004
model.resize_token_embeddings(new_vocab_size)

However, IndexTTS reuses Tortoise TTS's UnifiedVoice model structure, which assembles inputs_embeds itself and runs GPT2's forward pass on them, as shown below:

UnifiedVoice forward
The forward pass of UnifiedVoice
(Diagram drawn by the author)

The core code is as follows:


self.gpt = GPT2Model(...)

cond_mel_spec = ...
text_ids = ...
mel_codes = ...
speech_conditioning_latent = self.perceiver_encoder(self.conditioning_encoder(cond_mel_spec))
text_embs = self.text_embedding(text_ids) + self.text_pos_embedding(text_ids)
mel_embs = self.mel_embedding(mel_codes) + self.mel_pos_embedding(mel_codes)
# Concatenate speech_conditioning_latent, text_embs, and mel_embs into one input
embs = torch.cat([speech_conditioning_latent, text_embs, mel_embs], dim=1)
# GPT2 forward pass
gpt_out = self.gpt(inputs_embeds=embs, return_dict=True)
# Take the last hidden state (the output of ln_f)
hidden_state = gpt_out.last_hidden_state
text_len = text_embs.shape[1]
mel_codes_len = mel_codes.shape[1]
offset = speech_conditioning_latent.shape[1]
# Take the latent representation of the text and mel codes (i is the batch index)
h = hidden_state[i, offset:].unsqueeze(0)  # (1, T+S, D)
# Apply the final normalization layer
latent = self.final_norm(h)
text_latent, mel_latent = latent[:, :text_len], latent[:, -mel_codes_len:]
# Compute the text and mel logits
text_logits, mel_logits = self.text_head(text_latent), self.mel_head(mel_latent)
...

Here we apply a similar treatment to the text_embedding and text_head parameters in the pretrained weight file:

added_num_tokens = 4  # number of newly added emotion tags
config_path = os.path.join(model_dir, "config.yaml")
config = OmegaConf.load(config_path)
# Load the pretrained UnifiedVoice weights (mainly the GPT2 part)
pre_trained_gpt_path = os.path.join(model_dir, config.gpt_checkpoint)
pre_trained_gpt = torch.load(pre_trained_gpt_path, map_location="cpu", weights_only=True)
# Resize the text embedding and head weights to the new vocabulary size
resized_text_state_dict = resize_text_embedding_weights(
    pre_trained_gpt["model"] if "model" in pre_trained_gpt else pre_trained_gpt,
    config.gpt.number_text_tokens,  # original number of text tokens
    added_num_tokens,
)
# Save the resized text embedding and head weights
finetune_model_dir = "finetune_models"
os.makedirs(finetune_model_dir, exist_ok=True)
resized_model_path = os.path.join(finetune_model_dir, "gpt_resized.pth")
torch.save(resized_text_state_dict, resized_model_path)
# Save the new config
new_config = config.copy()
new_config.gpt_checkpoint = "gpt_resized.pth"
new_gpt_config = config.gpt.copy()
new_gpt_config.number_text_tokens = config.gpt.number_text_tokens + added_num_tokens
new_config.gpt = new_gpt_config
new_config.dataset.bpe_model = "bpe_new.model"
new_config.dataset.additional_special_tokens = {
    tag: new_bpe_model.piece_to_id(tag) for tag in ["<LAUGHS>", "<GIGGLES>", "<SIGHS>", "<CHUCKLES>"]
}
new_config_path = os.path.join(finetune_model_dir, "config.yaml")
OmegaConf.save(new_config, new_config_path)

The resize_text_embedding_weights(...) function follows the implementation of resize_token_embeddings in transformers.

Click to view the implementation of resize_text_embedding_weights...
import torch.nn.functional as F
from torch.distributions import constraints, multivariate_normal

def init_resized_embeddings_weights_with_mean(old_embedding_weight: torch.Tensor, old_num_tokens, added_num_tokens):
    # follow transformers/modeling_utils.py: _init_added_embeddings_weights_with_mean
    otype = old_embedding_weight.dtype
    reserved_tokens = old_embedding_weight.shape[0] - old_num_tokens
    text_embed_dim = old_embedding_weight.shape[-1]
    assert reserved_tokens >= 0, "The number of old tokens must be greater than or equal to the number of new tokens."
    if reserved_tokens > 0:
        reserved_embedding_weight = old_embedding_weight[old_num_tokens:, :]
        old_embedding_weight = old_embedding_weight[:old_num_tokens, :]
        new_embedding_weight = torch.cat(
            [
                old_embedding_weight,
                torch.zeros((added_num_tokens, text_embed_dim), dtype=otype, device=old_embedding_weight.device),
                reserved_embedding_weight,
            ],
            dim=0,
        )
    else:
        new_embedding_weight = torch.zeros(
            (old_num_tokens + added_num_tokens, text_embed_dim), dtype=otype, device=old_embedding_weight.device
        )
        new_embedding_weight[:old_num_tokens, :] = old_embedding_weight
    old_embed_weight_f32 = old_embedding_weight.to(torch.float32)
    mean_embed = torch.mean(old_embed_weight_f32, axis=0)
    old_centered_embed = old_embed_weight_f32 - mean_embed
    covariance = old_centered_embed.T @ old_centered_embed / old_num_tokens

    epsilon = 1e-9
    is_covariance_psd = constraints.positive_definite.check(epsilon * covariance).all()
    if is_covariance_psd:
        distribution = multivariate_normal.MultivariateNormal(mean_embed, covariance_matrix=epsilon * covariance)
        new_embedding_weight[old_num_tokens : old_num_tokens + added_num_tokens, :] = distribution.sample(
            sample_shape=(added_num_tokens,)
        ).to(otype)
    else:
        new_embedding_weight[old_num_tokens : old_num_tokens + added_num_tokens, :] = (
            mean_embed[None, :].repeat(added_num_tokens, 1).to(otype)
        )
    return new_embedding_weight


def resize_text_embedding_weights(state_dict: dict, old_number_text_tokens, added_num_tokens):
    old_text_embedding_weight = state_dict["text_embedding.weight"]
    new_text_embedding_weight = init_resized_embeddings_weights_with_mean(
        old_text_embedding_weight, old_number_text_tokens, added_num_tokens
    )

    # Resize the text head to match the new number of text tokens
    old_text_head_weight = state_dict["text_head.weight"]
    new_text_head_weight = init_resized_embeddings_weights_with_mean(
        old_text_head_weight, old_number_text_tokens, added_num_tokens
    )
    old_text_head_bias = state_dict.get("text_head.bias", None)
    new_state_dict = state_dict.copy()  # shallow copy
    new_state_dict["text_embedding.weight"] = new_text_embedding_weight
    new_state_dict["text_head.weight"] = new_text_head_weight
    if old_text_head_bias is not None:
        new_state_dict["text_head.bias"] = init_added_text_head_bias_with_mean(
            old_text_head_bias, old_number_text_tokens, added_num_tokens
        )
    return new_state_dict

The full code is in the Jupyter notebook preprocess_mel_dataset.ipynb.

At this point the new UnifiedVoice weights, the BPE model, and the config file are all ready, and fine-tuning can begin.

Fine-tuning the UnifiedVoice Model

The core fine-tuning steps are:

  1. Train the embedding and output head entries for the newly added text tokens so the model can handle the new emotion tags.
  2. LoRA-fine-tune the GPT2 model used inside UnifiedVoice so it learns the relationship between the emotion tags and the latent representations of the corresponding audio.

Load the new model config and weights:

import os
from omegaconf import OmegaConf
finetune_model_dir = "finetune_models"

new_config_path = os.path.join(finetune_model_dir, "config.yaml")
config = OmegaConf.load(new_config_path)
print(config.dataset)

The output is:

bpe_model: bpe_new.model
sample_rate: 24000
squeeze: false
mel:
    sample_rate: 24000
    n_fft: 1024
    hop_length: 256
    win_length: 1024
    n_mels: 100
    mel_fmin: 0
    normalize: false
additional_special_tokens:
    <LAUGHS>: 12000
    <GIGGLES>: 12001
    <SIGHS>: 12002
    <CHUCKLES>: 12003
from indextts.gpt.model import UnifiedVoice

def load_UnifiedVoice(gpt_config, gpt_checkpoint_path, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    state_dict = torch.load(gpt_checkpoint_path, map_location=device, weights_only=True)
    state_dict = state_dict["model"] if "model" in state_dict else state_dict
    model = UnifiedVoice(**gpt_config)
    model.load_state_dict(state_dict, strict=True)
    model.post_init_gpt2_config()
    del state_dict
    return model.to(device)

model = load_UnifiedVoice(config.gpt, os.path.join(finetune_model_dir, config.gpt_checkpoint))

Defining the Training Functions for UnifiedVoice

Because UnifiedVoice customizes the forward logic around GPT2, it cannot be fine-tuned directly with the transformers Trainer, so the training loop has to be written by hand.

The essence of training

Each training iteration consists of one forward pass, a loss computation, and backpropagation; the model parameters are then updated by gradient descent.

Keep in mind that an autoregressive model predicts the next (the (n+1)-th) token from the current n input tokens, as illustrated below:


Input: <s>hell  Output: o
(Diagram drawn by the author)

Therefore, when computing the loss, the model's output logits have to be shifted by one position before being compared with the target labels.
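
A minimal illustration of the shift with toy tensors (shapes unrelated to the real model): with logits of shape (B, C, T), position t predicts the token at position t + 1, so the last logit and the first target are dropped before the cross-entropy.

import torch
import torch.nn.functional as F

B, C, T = 1, 10, 6
logits = torch.randn(B, C, T)          # predictions made at positions 0..T-1
tokens = torch.randint(0, C, (B, T))   # the token sequence itself

shifted_logits = logits[:, :, :-1]     # keep predictions at positions 0..T-2
shifted_targets = tokens[:, 1:]        # the tokens they should predict: 1..T-1
print(F.cross_entropy(shifted_logits, shifted_targets))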

Click to view the implementation of train_UnifiedVoice and forward_UnifiedVoice...
from typing import Optional

def train_UnifiedVoice(
    model: UnifiedVoice,
    mel_spec: torch.FloatTensor,
    mel_codes: torch.LongTensor,
    text_ids: torch.LongTensor,
    output_loss=True,
    output_logits=False,
    add_mel_stop_token=True,
    loss_reduction="mean",
):
    model.train()
    model.inference_model.kv_cache = False  # disable kv cache for training
    return forward_UnifiedVoice(
        model,
        mel_spec=mel_spec,
        mel_codes=mel_codes,
        text_ids=text_ids,
        add_mel_stop_token=add_mel_stop_token,
        output_loss=output_loss,
        output_logits=output_logits,
        output_latent=False,
        loss_reduction=loss_reduction,
    )

def forward_UnifiedVoice(
    model: UnifiedVoice,
    mel_spec: torch.FloatTensor,
    mel_codes: torch.LongTensor,
    text_ids: torch.LongTensor,
    add_mel_stop_token=True,
    output_loss=True,
    output_logits=True,
    output_latent=False,
    loss_reduction="mean",
    device=None,
):
    """Forward pass for UnifiedVoice model.
    mel_spec: (1, 100, T)
    mel_codes: (1, s)
    text_ids: (1, t)
    """
    if device is None:
        device = model.final_norm.weight.device
    mel_spec = mel_spec.to(device)
    mel_codes = mel_codes.to(device)
    text_ids = text_ids.to(device=device)
    cond_mel_lengths = torch.tensor([mel_spec.shape[-1]], device=device)
    conditioning_latent = model.get_conditioning(mel_spec, cond_mel_lengths)
    text_inputs = F.pad(text_ids, (1, 1), value=model.start_text_token)
    text_inputs[:, -1] = model.stop_text_token
    # shift labels
    text_targets = text_inputs[:, 1:].clone().contiguous()
    mel_codes = F.pad(mel_codes, (1, 0), value=model.start_mel_token)
    if add_mel_stop_token:
        mel_codes = F.pad(mel_codes, (0, 1), value=model.stop_mel_token)
    mel_targets = mel_codes[:, 1:].clone().contiguous()

    text_emb = model.text_embedding(text_inputs) + model.text_pos_embedding(text_inputs)
    mel_emb = model.mel_embedding(mel_codes) + model.mel_pos_embedding(mel_codes)
    inputs_embeds = torch.cat([conditioning_latent, text_emb, mel_emb], dim=1)
    gpt2_outputs = forward_gpt2(
        model,
        inputs_embeds,
        torch.tensor([text_inputs.shape[-1]], device=device),
        torch.tensor([mel_codes.shape[-1]], device=device),
        attention_mask=None,
        output_latent=output_latent,
        output_logits=output_logits or output_loss,
    )
    assert isinstance(gpt2_outputs, dict), "gpt2_outputs should be a dict"

    outputs = {}
    if output_logits or output_loss:
        text_logits, mel_logits = gpt2_outputs["logits"]
        # shift logits, n tokens predict n+1
        text_logits = text_logits[:, :, :-1].contiguous()
        mel_logits = mel_logits[:, :, :-1].contiguous()
        if output_loss:
            loss_text = F.cross_entropy(text_logits, text_targets.long(), reduction=loss_reduction)
            loss_mel = F.cross_entropy(mel_logits, mel_targets.long(), reduction=loss_reduction)
            outputs["loss"] = (loss_text, loss_mel)
        if output_logits:
            outputs["logits"] = (text_logits, mel_logits)
            outputs["targets"] = (text_targets, mel_targets)
        del text_logits, mel_logits, text_targets, mel_targets
    if output_latent:
        text_latent, mel_latent = gpt2_outputs["latent"]
        outputs["latent"] = (text_latent, mel_latent)
    del gpt2_outputs, text_emb, mel_emb, conditioning_latent, text_inputs, mel_codes
    if device.type == "cuda":
        torch.cuda.empty_cache()
    return outputs

def forward_gpt2(
    model: UnifiedVoice,
    inputs_embeds: torch.FloatTensor,
    text_lengths: torch.LongTensor,
    codes_lengths: torch.LongTensor,
    attention_mask: Optional[torch.Tensor] = None,
    output_latent: bool = False,
    output_logits: bool = True,
):
    """Forward pass for the GPT2Model of UnifiedVoice."""
    b = inputs_embeds.shape[0]
    gpt_out = model.gpt(inputs_embeds=inputs_embeds, attention_mask=attention_mask, return_dict=True)
    hidden_state = gpt_out.last_hidden_state

    outputs = []
    for i in range(b):
        text_len = text_lengths[i].item()
        mel_codes_len = codes_lengths[i].item()
        offset = inputs_embeds.shape[1] - text_len - mel_codes_len
        h = hidden_state[i, offset:].unsqueeze(0)  # (1, T+S, D)
        latent = model.final_norm(h)
        text_latent, mel_latent = latent[:, :text_len], latent[:, -mel_codes_len:]
        text_logits = model.text_head(text_latent)
        mel_logits = model.mel_head(mel_latent)
        text_logits = text_logits.permute(0, 2, 1)
        mel_logits = mel_logits.permute(0, 2, 1)
        output = {}
        if output_logits:
            output["logits"] = (text_logits, mel_logits)
        if output_latent:
            output["latent"] = (text_latent, mel_latent)
        outputs.append(output)
    if b == 1:
        return outputs[0]
    return outputs

Targeted Training of the Text Embedding and Head Layers

The essence of the embedding layer

An embedding layer is essentially a lookup table that maps each token ID to a vector in a high-dimensional space. This lets the model map discrete token IDs into a continuous vector space and learn semantic relationships between tokens. The head, in turn, maps the model's output back to logits over the target tokens.
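
A tiny illustration of the lookup-table view, with toy sizes unrelated to the real model:

import torch

emb = torch.nn.Embedding(num_embeddings=6, embedding_dim=4)  # 6 token IDs -> 4-dim vectors
head = torch.nn.Linear(4, 6)                                 # 4-dim hidden state -> logits over 6 tokens

ids = torch.tensor([[0, 3, 5]])
vectors = emb(ids)      # shape (1, 3, 4): one row of emb.weight per ID
logits = head(vectors)  # shape (1, 3, 6): a score for every token at every position
print(torch.equal(vectors[0, 1], emb.weight[3]))  # True: the embedding is plain row indexing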

Use t-SNE dimensionality reduction to inspect the distribution of the newly added emotion tag token embeddings:

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np

def visualize_embeddings(
    text_embedding, new_token_ids, token_labels=None, sample_ids=None, title="Text embeddings", show=True
):
    """Visualize the text embeddings using t-SNE."""
    embeddings = text_embedding.weight.detach().cpu().numpy()
    # Select the embeddings for the new tokens and some random samples
    sample_indices = new_token_ids.cpu().numpy()
    if sample_ids is not None:
        sample_indices = np.concatenate((sample_indices, sample_ids))

    tsne = TSNE(n_components=2, init="random", random_state=42, perplexity=30, max_iter=1000)
    reduced = tsne.fit_transform(embeddings[sample_indices])

    plt.figure(figsize=(6, 4))

    if sample_ids is not None:
        offset = len(new_token_ids)
        new_token_idx = reduced[:offset, :]
        plt.scatter(
            reduced[offset:, 0],
            reduced[offset:, 1],
            alpha=0.5,
            label="Original tokens",
        )
    else:
        new_token_idx = reduced
    plt.scatter(new_token_idx[:, 0], new_token_idx[:, 1], alpha=0.8, label="New tokens", c="red")
    if token_labels is not None:
        for i, (x, y) in enumerate(new_token_idx):
            plt.annotate(
                token_labels[i],
                (x, y),
                fontsize=12,
            )
    plt.legend()
    plt.title(title)
    if not show:
        plt.savefig(f"{'_'.join(title.split())}.png")
    if show:
        plt.show()
Click to view the visualize_text_embeddings code...
import functools
import numpy as np

new_token_ids = torch.tensor(
    list(config.dataset.additional_special_tokens.values()),
    dtype=torch.long,
)

np.random.seed(100)
sample_ids = list(np.random.randint(10, 12000, size=300))

visualize_text_embeddings = functools.partial(
    visualize_embeddings,
    new_token_ids=new_token_ids,
    token_labels=list(config.dataset.additional_special_tokens.keys()),
    sample_ids=sample_ids,
    show=True,
)
visualize_text_embeddings(
    model.text_embedding,
    title="Text embeddings before training",
)
visualize_text_embeddings(
    model.text_head,
    title="Text head weight before training",
)

Text embeddings before training

Text head weight before training

NewTokensTrainableAdapter

Here a NewTokensTrainableAdapter class is introduced to wrap the original text embedding and head layers. It freezes the weights $W_0$ belonging to the original tokens and trains only the new tokens' $\Delta W$:

$$W' \in \mathbb{R}^{(k+\delta) \times d} = \begin{bmatrix} W_0 \in \mathbb{R}^{k \times d} \\ \Delta W \in \mathbb{R}^{\delta \times d} \end{bmatrix}$$

Click to view the implementation of `NewTokensTrainableAdapter`...
import torch
import torch.nn.functional as F

class NewTokensTrainableAdapter(torch.nn.Module):
    """Wrapper for the new tokens training"""
    merged: bool = False
    base_layer: torch.nn.Embedding | torch.nn.Linear
    new_token_ids: torch.LongTensor

    def __init__(self, base_layer: torch.nn.Embedding | torch.nn.Linear, new_token_ids: torch.Tensor):
        super().__init__()
        self.base_layer = base_layer
        self.new_token_ids = new_token_ids.long()
        self.register_parameter(
            "new_tokens_weight",
            torch.nn.Parameter(base_layer.weight.data[new_token_ids].clone().detach(), requires_grad=True),
        )
        # linear layer bias
        if isinstance(base_layer, torch.nn.Linear):
            if base_layer.bias is not None:
                self.register_parameter(
                    "new_tokens_bias",
                    torch.nn.Parameter(base_layer.bias.data[new_token_ids].clone().detach(), requires_grad=True),
                )
            else:
                self.register_parameter("new_tokens_bias", None)
        self.base_layer.requires_grad_(False)  # freeze the base layer

    def requires_grad_(self, requires_grad: bool):
        """Set the requires_grad attribute for the new tokens weights and bias."""
        self.new_tokens_weight.requires_grad_(requires_grad)
        if isinstance(self.base_layer, torch.nn.Linear) and self.new_tokens_bias is not None:
            self.new_tokens_bias.requires_grad_(requires_grad)
        return self

    def __getattr__(self, name):
        if name == "weight":
            if self.merged:
                # if merged, return the merged weight
                return self.base_layer.weight
            return self.get_merged_weight()
        return super().__getattr__(name)

    def get_merged_weight(self):
        W = self.base_layer.weight
        index = self.new_token_ids.to(W.device)
        deltas = self.new_tokens_weight.to(W)
        return W.index_copy(dim=0, index=index, source=deltas)

    def forward(self, x):
        """Forward method to replace the base layer with new tokens."""
        if self.merged:
            return self.base_layer(x)
        W = self.get_merged_weight()
        if isinstance(self.base_layer, torch.nn.Linear):
            b = self.base_layer.bias
            if b is not None:
                bias_deltas = self.new_tokens_bias.to(b)
                b = b.index_copy(dim=0, index=self.new_token_ids, source=bias_deltas)
            else:
                b = None
            return F.linear(x, weight=W, bias=b)

        return F.embedding(
            x,
            weight=W,
            padding_idx=self.base_layer.padding_idx,
            max_norm=self.base_layer.max_norm,
            norm_type=self.base_layer.norm_type,
            scale_grad_by_freq=self.base_layer.scale_grad_by_freq,
            sparse=self.base_layer.sparse,
        )

    def merge(self):
        """Merge the new tokens weights into the base layer."""
        if not self.merged:
            self.merged = True
            self.base_layer.weight.copy_(self.get_merged_weight())
            if isinstance(self.base_layer, torch.nn.Linear) and self.new_tokens_bias is not None:
                b = self.base_layer.bias
                if b is not None:
                    b.index_copy_(dim=0, index=self.new_token_ids.to(b.device), source=self.new_tokens_bias.to(b))

    def unmerge(self):
        if self.merged:
            self.merged = False

Replace UnifiedVoice's text embedding and head layers with NewTokensTrainableAdapter, and freeze all other parameter weights:

model.requires_grad_(False)  # freeze the model parameters

model.text_embedding = NewTokensTrainableAdapter(model.text_embedding, new_token_ids)
model.text_head = NewTokensTrainableAdapter(model.text_head, new_token_ids)
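
To confirm that only the adapter parameters remain trainable, they can be listed with a quick loop (a sketch, not the notebook's exact code):

for name, p in model.named_parameters():
    if p.requires_grad:
        print(name, tuple(p.shape), p.dtype, p.numel())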

The trainable parameters are:

name                               shape        dtype           params
text_embedding.new_tokens_weight   (4, 1280)    torch.float32   5120
text_head.new_tokens_weight        (4, 1280)    torch.float32   5120
text_head.new_tokens_bias          (4,)         torch.float32   4

Training the Embeddings of the New Tokens

The NewTokensTrainableAdapter parameters are trained with the AdamW optimizer.

BATCH_SIZE is set to 1, and GRAD_ACCUMULATION_STEPS is set to 2, meaning gradients are accumulated over every 2 samples before an optimizer step.

LEARNING_RATE is set to 1e-4, raised 4x for the weight parameters, with a weight_decay of 0.1 to guard against overfitting.

In addition, the loss at the positions of the newly added text tokens is up-weighted, mainly so that the new_tokens_weight and new_tokens_bias parameters of NewTokensTrainableAdapter converge faster.

Click to view the training code for NewTokensTrainableAdapter...
from torch.optim import AdamW

EPOCHS = 2
BATCH_SIZE = 1  # always 1 here
GRAD_ACCUMULATION_STEPS = 2  # accumulate gradients for 2 steps
WEIGHT_DECAY = 0.1  # weight decay for the new tokens training
MAX_GRAD_NORM = 3.0  # max gradient norm for clipping
LEARNING_RATE = 1e-4  # learning rate for the new tokens training

optimizer = AdamW(
    [
        {
            "params": [
                ("text_embed.weight", model.text_embedding.new_tokens_weight),
                ("text_head.weight", model.text_head.new_tokens_weight),
            ],
            "weight_decay": WEIGHT_DECAY,
            "lr": LEARNING_RATE * 4,
        },
        {
            "params": [
                ("text_head.bias", model.text_head.new_tokens_bias),
            ],
            "weight_decay": 0.0,
            "lr": LEARNING_RATE,
        },
    ],
)
proc_seq_count = 0
batched_loss = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
new_token_weight = 5  # weight for the new tokens loss
for epoch in range(EPOCHS):
    acc_loss = torch.tensor(0, device=device, dtype=torch.float32)

    for idx in range(len(processed_ds)):
        mels, mel_codes, text_ids = as_torch_tensor(processed_ds[idx], device=device)
        outputs = train_UnifiedVoice(
            model,
            mels,
            mel_codes[:, :5],  # only take the first 5 mel codes; this stage only trains the text-related layers
            text_ids,
            add_mel_stop_token=False,  # do not append the mel stop token
            output_loss=False,
            output_logits=True,
        )
        text_logits = outputs["logits"][0]
        text_targets = outputs["targets"][0]
        # We only want to train the new token positions, so up-weight the loss at those positions:
        loss_text = F.cross_entropy(text_logits, text_targets, reduction="none")
        new_token_mask = torch.isin(text_targets, new_token_ids)
        weight_mask = torch.ones_like(text_targets, dtype=torch.float32, device=device)
        weight_mask[new_token_mask] = new_token_weight
        weighted_loss = (loss_text * weight_mask.view(-1)).mean()
        if torch.isnan(weighted_loss):
            continue  # skip NaN losses
        proc_seq_count += 1
        weighted_loss /= BATCH_SIZE * GRAD_ACCUMULATION_STEPS
        weighted_loss.backward()
        acc_loss += weighted_loss.detach()
        if proc_seq_count % (BATCH_SIZE * GRAD_ACCUMULATION_STEPS) == 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                max_norm=MAX_GRAD_NORM,
            )
            optimizer.step()
            optimizer.zero_grad()
            acc_loss.zero_()

The loss curve:

Training loss for new tokens

The distribution of the text embeddings after training:

Text embeddings after training

Data Augmentation (Cross-Pairing): Increasing the Number of Emotion-tagged Samples

The dataset contains only a few samples for any given emotion tag (for example, a mere 33 samples contain <giggles>), far too few for the model to learn the correspondence between an emotion tag and its mel codes. To address this, a data augmentation strategy is used:

  • For each target emotion tag, pick n random texts carrying the same tag and pair each of them with a reference audio, creating new samples that speed up learning;
  • For ordinary audio without tags, also pair each clip with m * n texts carrying the target emotion tags, enriching the training samples and improving generalization (m is the number of target emotion tags).

Why augment the data by cross-pairing?

Since IndexTTS is mainly used for zero-shot voice cloning, i.e. synthesizing speech whose timbre matches a reference audio, the dataset's audio column can serve not only as the source of the Ground Truth training mel codes but also as the reference speech audio.
So as long as one sample's text contains the emotion tag we care about (such as <giggles>), it can be paired with the audio of another sample acting as the reference, forming a new sample: the audio and mel spectrogram are kept as the reference data, while the text and mel codes come from the other sample.
This helps the model learn the correspondence between emotion tags and mel codes.
It also has a side effect: if trained for too many epochs, the model may well ignore the reference audio's emotion, speaking rate, pitch, and other timbre features, and attend only to the text-to-mel-codes mapping.

Examples of augmented samples:

ref audio        target text                                mel codes ground truth
audio sample1    …AREN’T YOU? <giggles> I LOVE YOU TOO …    (audio sample1) mel codes
audio sample2    …AREN’T YOU? <giggles> I LOVE YOU TOO …    (audio sample1) mel codes
audio sample2    Cutie. <giggles> They’d impale you …       (audio sample2) mel codes
audio sample3    Cutie. <giggles> They’d impale you …       (audio sample2) mel codes

Core code snippets:

Definition of class AugmentMelDataset
class AugmentMelDataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset, augmented_pairs):
        self.dataset = dataset
        self.augmented_pairs = augmented_pairs

    def __len__(self):
        return len(self.augmented_pairs)

    def __getitem__(self, idx):
        idx1, idx2 = self.augmented_pairs[idx]
        if idx2 is None:
            return self.dataset[idx1]
        ref, target = self.dataset[idx1], self.dataset[idx2]
        return {
            "audio": ref["audio"],
            "mel": ref["mel"],
            "codes": target["codes"],
            "text": target["text"],
            "text_ids": target["text_ids"],
        }

    def __iter__(self):
        return (self.__getitem__(i) for i in range(len(self.augmented_pairs)))
def create_augment_mel_datasets(
    dataset, max_new_pairs=2, emotion_tag_ids=[], augment_emotion_tag_ids=None, seed=333,
):
    from collections import defaultdict
    import random
    random.seed(seed)
    emotion_samples = defaultdict(list)

    # 1. Group samples by emotion tag
    def find_emotion_tag(text_ids):
        for tag_id in emotion_tag_ids:
            if tag_id in text_ids:
                return tag_id
        return -1

    for idx in range(len(dataset)):
        sample = dataset[idx]
        text_ids = sample["text_ids"]
        emotion_tag_id = find_emotion_tag(text_ids)
        emotion_samples[emotion_tag_id].append(idx)
    non_emotion_samples = emotion_samples.pop(-1, [])
    if not augment_emotion_tag_ids:
        augment_emotion_samples = list(emotion_samples.values())
    else:
        augment_emotion_samples = [
            emotion_samples[tag_id]
            for tag_id in augment_emotion_tag_ids
            if tag_id in emotion_samples and len(emotion_samples[tag_id]) > 0
        ]
    augment_emotion_pairs = []
    # 2. For the target emotion tags, cross-pair samples to create new ones
    for sample_indices in augment_emotion_samples:
        if len(sample_indices) == 1:
            augment_emotion_pairs.append((sample_indices[0], None))
            continue
        for idx in sample_indices:
            other_indices = [i for i in sample_indices if i != idx]
            augment_ids = random.sample(other_indices, min(max_new_pairs, len(other_indices)))
            for other_idx in augment_ids:
                augment_emotion_pairs.append((idx, other_idx))
    # 3. Handle samples without emotion tags
    for idx in non_emotion_samples:
        for sample_indices in augment_emotion_samples:
            if len(sample_indices) == 0:
                continue
            augment_ids = random.sample(sample_indices, min(max_new_pairs, len(sample_indices)))
            for other_idx in augment_ids:
                augment_emotion_pairs.append((idx, other_idx))
    # 4. Build the augmented dataset object
    return AugmentMelDataset(dataset, augment_emotion_pairs)

Take part of the augmented data as validation and test sets:

test_pairs = []
validation_pairs = []
...
for idx in non_emotion_samples:
    sample = dataset[idx]
    if sample["mel"].shape[1] < 250:
        continue
    for sample_indices in augment_emotion_samples:
        if len(sample_indices) == 1:
            augment_emotion_pairs.append((idx, sample_indices[0]))
            continue
        augment_ids = random.sample(other_indices, min(max_new_pairs, len(sample_indices)))
        if len(augment_ids) == len(other_indices):
            validation_ids = []
        else:
            validation_ids = [i for i in other_indices if i not in augment_ids]
            validation_ids = random.sample(validation_ids, min(max_new_pairs, len(validation_ids)))
        for other_idx in augment_ids:
            augment_emotion_pairs.append((idx, other_idx))
        for other_idx in validation_ids:
            validation_pairs.append((idx, other_idx))
        if len(other_indices) - len(augment_ids) - len(validation_ids) > 0:
            remaining_indices = [i for i in other_indices if i not in augment_ids and i not in validation_ids]
            remaining_indices = random.sample(remaining_indices, min(max_new_pairs, len(remaining_indices)))
            for other_idx in remaining_indices:
                test_pairs.append((idx, other_idx))
emotion_tag_ids = list(config.dataset.additional_special_tokens.values())
# select one or more emotion tags to augment mel dataset
augment_emotion_tag_ids = [
    config.dataset.additional_special_tokens["<GIGGLES>"],
]
train_augment_ds, valid_ds, test_ds = create_augment_mel_datasets(
    processed_ds,
    max_new_pairs=2,
    emotion_tag_ids=emotion_tag_ids,
    augment_emotion_tag_ids=augment_emotion_tag_ids,
    seed=233,
)
print(f"Augmented mel dataset size: {len(train_augment_ds)}, for target emotion tags: {augment_emotion_tag_ids}")
print(f"Validation dataset size: {len(valid_ds)}")
print(f"Test dataset size: {len(test_ds)}")

Output:

Augmented mel dataset size: 1762, for target emotion tags: [12001]
Validation dataset size: 293
Test dataset size: 10

LoRA Fine-tuning of GPT2

After the targeted training of UnifiedVoice's new text tokens, the next step is to apply LoRA (Low-Rank Adaptation[1]) fine-tuning to the GPT2Model using 🤗 PEFT.

lora
LoRA illustration, from the LoRA paper

As shown above, the core idea of LoRA is to freeze the target layer's original pretrained weight matrix $W_0$ and approximate the new weight matrix $\Delta W$ to be learned as the product of two low-rank matrices $B$ and $A$:

$$W_\text{new} = W_0 + \Delta W = W_0 + B A$$

During the forward pass: $h = W_0 x + \Delta W x = W_0 x + B A x$

During training, $\Delta W$ is additionally multiplied by a scaling factor $\alpha / r$, where $r$ is the rank of the low-rank matrices; both $r$ and $\alpha$ are LoRA hyperparameters (i.e., constants):

$$\Delta W = \frac{\alpha}{r} B A$$

where $W_0 \in \mathbb{R}^{d \times k}$, $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times k}$, and $r \ll \min(d, k)$.

This reduces the number of trainable parameters of the target layer from $d \times k$ to $d \times r + r \times k$.
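
As a concrete worked example with this post's numbers (GPT2's attn.c_attn maps 1280 to 3840 channels, and r = 16):

# Parameter count for one LoRA-adapted layer
d, k, r = 3840, 1280, 16
full = d * k              # 4,915,200 parameters if the layer were trained directly
lora = d * r + r * k      # 61,440 (B) + 20,480 (A) = 81,920 trainable parameters
print(full, lora, f"{lora / full:.1%}")  # ~1.7% of the full layer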

For GPT2, fine-tuning only the self-attention layers and the feed-forward (MLP) layers is enough:

from peft import get_peft_model, LoraConfig, TaskType

adapter_name = "new_tokens"
gpt_lora_config = LoraConfig(
    r=16,
    target_modules=["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"],
    task_type=TaskType.CAUSAL_LM,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
)
model.requires_grad_(False)
model.inference_model = get_peft_model(model.inference_model, gpt_lora_config, adapter_name=adapter_name)

GPT2 LoRA
Structure of the GPT2 LoRA adapters
Annotated by the author; original figure from FLUID-GPT (Yang et al., 2023)[2]

The trainable parameters are:

     name                                            shape         dtype           params
0    text_embedding.new_tokens_weight                (4, 1280)     torch.float32   5120
1    gpt.h.0.attn.c_attn.lora_A.new_tokens.weight    (16, 1280)    torch.float32   20480
2    gpt.h.0.attn.c_attn.lora_B.new_tokens.weight    (3840, 16)    torch.float32   61440
3    gpt.h.0.attn.c_proj.lora_A.new_tokens.weight    (16, 1280)    torch.float32   20480
4    gpt.h.0.attn.c_proj.lora_B.new_tokens.weight    (1280, 16)    torch.float32   20480
...  ...                                             ...           ...             ...
190  gpt.h.23.mlp.c_fc.lora_B.new_tokens.weight      (5120, 16)    torch.float32   81920
191  gpt.h.23.mlp.c_proj.lora_A.new_tokens.weight    (16, 5120)    torch.float32   81920
192  gpt.h.23.mlp.c_proj.lora_B.new_tokens.weight    (1280, 16)    torch.float32   20480
193  text_head.new_tokens_weight                     (4, 1280)     torch.float32   5120
194  text_head.new_tokens_bias                       (4,)          torch.float32   4

The LoRA+[3] method is used to set different learning rates for the adapters' lora_A and lora_B matrices, further improving training efficiency:

from peft.optimizers import create_loraplus_optimizer

LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
# Efficient Low Rank Adaptation of Large Models: https://arxiv.org/abs/2402.12354
optimizer = create_loraplus_optimizer(
    model=model,
    optimizer_cls=AdamW,
    lr=LEARNING_RATE,
    loraplus_lr_ratio=8,  # lr ratio for lora_B weights
    loraplus_weight_decay=WEIGHT_DECAY,
    # loraplus_lr_embedding=2e-5,  # not used, since we are training text embedding with NewTokensTrainableAdapter
)

The training loop is similar to the previous one and is not repeated here; see the fine_tune_indextts.ipynb notebook for details.

The training loss curve for fine-tuning GPT2 on the 1200 cross-paired samples:

GPT2 training loss

After fine-tuning, save the new weights and config.


@torch.no_grad()
def merge_lora_weights(model: UnifiedVoice, unload=False):
    model.text_embedding.merge()
    model.text_head.merge()
    if unload:
        model.inference_model = model.inference_model.merge_and_unload()
        model.text_embedding = model.text_embedding.base_layer
        model.text_head = model.text_head.base_layer
    else:
        model.inference_model.merge_adapter()
    print("Merged LoRA weights into the model.")
    return model


@torch.no_grad()
def unmerge_lora_weights(model: UnifiedVoice):
    model.inference_model.unmerge_adapter()
    model.text_embedding.unmerge()
    model.text_head.unmerge()
    print("Unmerged LoRA adapters")
    return model


def save_checkpoint(model: UnifiedVoice, checkpoint_path, merge_lora=True, unload_after_merge=False):
    """Save the model checkpoint."""
    from collections import OrderedDict

    checkpoint_state_dict = OrderedDict()
    model.eval()
    if merge_lora:
        model = merge_lora_weights(model, unload=unload_after_merge)
    state_dict = model.state_dict()
    for key, value in state_dict.items():
        if not key.startswith(("gpt.wte", "inference_model.")) and "new_tokens" not in key:
            checkpoint_state_dict[key] = value
    torch.save(checkpoint_state_dict, checkpoint_path)
    model.train()
    if merge_lora and not unload_after_merge:
        unmerge_lora_weights(model)
    del checkpoint_state_dict, state_dict
    print(f"UnifiedVoice checkpoint saved to: {checkpoint_path}")
    return checkpoint_path

final_checkpoint_path = os.path.join(finetune_model_dir, "gpt_finetuned.pth")
save_checkpoint(model, final_checkpoint_path, merge_lora=True, unload_after_merge=True)
new_config_path = os.path.join(finetune_model_dir, "config_finetuned.yaml")
new_config = config.copy()
new_config.gpt_checkpoint = "gpt_finetuned.pth"
OmegaConf.save(new_config, new_config_path)

Testing the Fine-tuned Model

Test with audio and text that never appeared during fine-tuning, to verify that the model can correctly generate speech for text containing the emotion tags.

import urllib.request
import os

audio_prompts = [
    {
        "id": "Female-1",
        "name": "Cyndi",
        "url": "https://bytedancespeech.github.io/seedtts_tech_report/audios/SpeechFactorization_samples/prompt/prompt1/4813840990459345930.wav",
        "lang": "en",
    },
    {
        "id": "Male-1",
        "name": "Michael",
        "url": "https://bytedancespeech.github.io/seedtts_tech_report/audios/SpeechFactorization_samples/source/2188769758301752050.wav",
        "lang": "en",
    },
]

prompts_dir = os.path.join("tests", "prompts")
os.makedirs(prompts_dir, exist_ok=True)

for audio_prompt in audio_prompts:
    audio_path = os.path.join(prompts_dir, f"{audio_prompt['id']}.wav")
    if not os.path.exists(audio_path):
        urllib.request.urlretrieve(
            audio_prompt["url"],
            audio_path,
        )

output_dir = os.path.join("tests", "output")
if os.path.exists(output_dir):
    import shutil

    shutil.rmtree(output_dir)
os.makedirs(output_dir, exist_ok=True)

welcome_texts = [
    r"Hey there my name is {name}, <giggles> and I'm a speech generation model that can sound like a person.",
    r"大家好,我是 {name},一个语音生成模型,<giggles> 我的声音听起来跟真人一样。",
]
en_texts = [
    "Seriously? <giggles> That's the cutest thing I've ever heard!",
    "Oh, my gosh! <giggles> She called my name many times! <giggles> I— I’m so excited!",
    "Wha-? Cute? <giggles> You think I'm cute? Well, uh, thanks, I guess?",
]
zh_texts = [
    "真的吗? <giggles> 这也太可爱了吧!",
    "我的天啊!<giggles> 她叫了好几次我的名字。<giggles> 我…我好兴奋啊!",
    "你说什么?可爱?<giggles> 你觉得我可爱?<giggles> 嗯…呃… 谢谢。",
]

Load the fine-tuned weights and config with IndexTTS's inference API:


import torch
import transformers
from indextts.infer import IndexTTS

seed = 42
transformers.set_seed(seed)

# path to the fine-tuned model config
finetune_model_dir = "finetune_models"
finetuned_config_path = os.path.join(finetune_model_dir, "config_finetuned.yaml")

tts = IndexTTS(
    cfg_path=finetuned_config_path,
    model_dir=finetune_model_dir,
    use_cuda_kernel=False,
)

def to_safe_file_name(text: str):
    import re
    return re.sub(r"[,'\"?.<>:,。?!“’”‘&*#@~/\\\s]", "", text).strip()

all_texts = en_texts + zh_texts
for audio_prompt in audio_prompts:
    audio_path = os.path.join(prompts_dir, f"{audio_prompt['id']}.wav")
    print(f"**Audio Prompt**: {audio_prompt['id']} ({audio_prompt['lang']})")
    print(audio_path)
    for text in welcome_texts + all_texts:
        if r'{name}' in text:
            text = text.format(name=audio_prompt['name'])
        print("**Text**: `{text}`".format(text=text))
        file_name = to_safe_file_name(text[:40])
        output_path = os.path.join(output_dir, f"{audio_prompt['id']}_{file_name}.wav")
        tts.infer(
            audio_prompt=audio_path,
            text=text,
            verbose=False,
            output_path=output_path,
        )
        print(output_path)

You can clone the notebook code locally and open fine_tune_indextts.ipynb, or open it directly in Google Colab, then scroll to the test section at the end to listen to the results.

Summary

With the fine-tuning workflow above, IndexTTS's core module UnifiedVoice learned to understand text carrying the special tag <GIGGLES> and to generate the corresponding giggling mel-code latent representation.

However, because the fine-tuning dataset is very small and contains only English samples, the fine-tuned model overfits.

When synthesizing English text with the fine-tuned model, the generated timbre clearly deviates from the expected input (a proper comparison against the unmodified model would need objective metrics; listening alone is subjective) and leans toward imitating the timbre characteristics of the Elise fine-tuning data (speaking rate, pitch, breathing, and so on). The model seems to have forgotten how to clone the reference speaker's timbre for English text; it has practically become an Elise-only model.

In a real fine-tuning scenario, unless the goal is precisely to sound closer to one specific speaker, a much larger dataset is needed to preserve the model's generalization ability.

You can also try changing the training hyperparameters in the notebook (learning rate, GRADIENT_ACCUMULATION_STEPS, MAX_GRAD_NORM, etc.) or the data augmentation scheme, and see whether you get better results.

Possible directions for the next round of fine-tuning experiments (future work):

  • Collect a larger dataset for fine-tuning to improve the model's generalization;
  • Try other post-training methods for the GPT model, such as reinforcement learning (RL), to support richer controllability (speaking rate, pitch, emotion, etc.);
  • Try speech datasets in other languages, such as Cantonese or Japanese, to improve the model's multilingual ability;
  • Use knowledge distillation to replace GPT2 with a more modern model, such as Qwen3 or MiniCPM4, and see whether it performs better.

Refs


  1. Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, and Weizhu Chen. LoRA: Low-Rank Adaptation of Large Language Models. arXiv:2106.09685, 2021.
  2. Yang S, Ali Z, Wong B. FLUID-GPT (Fast Learning to Understand and Investigate Dynamics with a Generative Pre-Trained Transformer): Efficient Predictions of Particle Trajectories and Erosion. ChemRxiv, 2023. doi:10.26434/chemrxiv-2023-ppk9s
  3. Soufiane Hayou, Nikhil Ghosh, Bin Yu. LoRA+: Efficient Low Rank Adaptation of Large Models. arXiv:2402.12354, 2024.