Inspired by the human hearing mechanism, which uses contextual information to trace and distinguish speech streams in a mixture, a novel contextual understanding framework based on contextual embeddings is proposed for overlapped speech processing. This section begins with an introduction to the baseline models for speech separation and multi-talker speech recognition. Following this, the model for predicting contextual embeddings from the speech mixture is introduced. Strategies for incorporating these embeddings into speech separation and multi-talker automatic speech recognition (ASR) are then described. Additionally, advanced optimization techniques are developed to further enhance system performance. Finally, the detailed configuration of the entire system is outlined.
Traditional model for multi-talker speech processing
There are two widely used approaches to speech separation: the T-F masking approach8,9,25,26,27,28,29,30 and the time-domain approach31,32. In this paper, we select three model architectures from these two categories as baselines to validate the effectiveness of the proposed method in speech separation.
1) The T-F masking method usually consists of three steps:
$${\bf{Y}}=\mathrm{STFT}\,({\bf{y}})$$
(1)
$$[{\hat{{\bf{M}}}}_{1},\cdots \,,{\hat{{\bf{M}}}}_{J}]=\mathrm{Separator}\,(| {\bf{Y}}| )$$
(2)
$${\hat{{\bf{x}}}}_{j}={\mathrm{iSTFT}}\,({\hat{{\bf{M}}}}_{j}\odot | {\bf{Y}}| ,\angle {\bf{Y}})$$
(3)
where \({\bf{y}}\in {{\mathbb{R}}}^{N}\) is the input speech mixture of length N. First, y is transformed to the T-F domain \({\bf{Y}}\in {{\mathbb{C}}}^{T\times F}\) via the short-time Fourier transform (STFT). The magnitude spectrum \(| {\bf{Y}}| \in {{\mathbb{R}}}^{T\times F}\) is then processed by a neural Separator to estimate T-F masks \({\hat{{\bf{M}}}}_{j}\in {{\mathbb{R}}}^{T\times F}\) for each speaker j. The magnitude spectrum of each speaker is predicted by element-wise multiplication between the estimated mask \({\hat{{\bf{M}}}}_{j}\) and the mixed magnitude spectrum ∣Y∣. Finally, the estimated magnitude \({\hat{{\bf{M}}}}_{j}\odot | {\bf{Y}}|\) is combined with the noisy phase spectrum of the input, ∠Y, to reconstruct the STFT spectrum of each speaker j, and the inverse STFT (iSTFT) converts the estimated spectrum back to the time domain.
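To make these three steps concrete, the following sketch shows the T-F masking pipeline of Eqs. (1)-(3) in PyTorch. The Separator is left abstract, and the STFT parameters and tensor shapes are illustrative assumptions rather than the exact configuration used in our experiments.

```python
# A minimal sketch of the T-F masking pipeline (Eqs. (1)-(3)); the Separator
# and all parameter values here are illustrative assumptions.
import torch

def tf_masking_separation(y, separator, n_fft=512, hop=128, num_spk=2):
    """y: (batch, samples) mixture waveform; separator: |Y| -> per-speaker masks."""
    window = torch.hann_window(n_fft, device=y.device)
    # Eq. (1): STFT of the mixture
    Y = torch.stft(y, n_fft, hop_length=hop, window=window, return_complex=True)
    mag, phase = Y.abs(), Y.angle()                 # |Y| and the mixture phase
    masks = separator(mag)                          # Eq. (2): (batch, J, F, T) masks
    estimates = []
    for j in range(num_spk):
        # Eq. (3): mask the magnitude, reuse the noisy phase, and invert
        S_j = torch.polar(masks[:, j] * mag, phase)
        x_j = torch.istft(S_j, n_fft, hop_length=hop, window=window,
                          length=y.shape[-1])
        estimates.append(x_j)
    return estimates
```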
There are various deep learning models for estimating the T-F mask; we adopt two popular and powerful ones, the bidirectional long short-term memory (BLSTM) network and the residual network (ResNet), as baselines in this paper. The structure and specific configurations of these two models can be found in “Model configuration”. We predict the T-F mask on the magnitude spectrum and ignore the phase of the spectrum, although it is worth mentioning that some existing works estimate a complex mask or consider phase reconstruction28,33,34,35.
2) The separation procedure of the time-domain method can be formulated as follows:
$${\bf{W}}=\,\text{Encoder}\,({\bf{y}})$$
(4)
$$[{\tilde{{\bf{M}}}}_{1},\cdots \,,{\tilde{{\bf{M}}}}_{J}]=\,\text{Separator}\,({\bf{W}})$$
(5)
$${\hat{{\bf{x}}}}_{j}=\,\text{Decoder}\,({\tilde{{\bf{M}}}}_{j}\odot {\bf{W}})$$
(6)
The mixed speech y is encoded into a high-level audio representation \({\bf{W}}\in {{\mathbb{R}}}^{L\times C}\) by a 1D convolutional Encoder network, where L is the number of frames and C is the number of convolution channels. Then, the Separator network predicts masks \({\tilde{{\bf{M}}}}_{1},\cdots \,,{\tilde{{\bf{M}}}}_{J}\in {{\mathbb{R}}}^{L\times C}\) for each speaker in the representation space. Each mask \({\tilde{{\bf{M}}}}_{j}\) is element-wise multiplied with the high-level audio representation W. Finally, a 1D transposed convolutional Decoder reconstructs the estimated audio \({\hat{{\bf{x}}}}_{j}\) of each speaker in the time domain.
In this paper, we also utilize the TasNet framework as a baseline to evaluate the proposed method; the structure of the model can be found in ref. 13, and the configuration is detailed in “Model configuration”.
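As a parallel sketch for Eqs. (4)-(6), the module below wires a learnable 1D convolutional Encoder and a transposed-convolution Decoder around an abstract Separator; the kernel size and channel counts are placeholders, not the conv-TasNet setting given in “Model configuration”.

```python
# A minimal sketch of the time-domain pipeline (Eqs. (4)-(6)); the Separator
# is abstract and the channel/kernel sizes are illustrative placeholders.
import torch
import torch.nn as nn

class TimeDomainSeparation(nn.Module):
    def __init__(self, separator, num_spk=2, channels=512, kernel=20):
        super().__init__()
        stride = kernel // 2
        # Eq. (4): waveform -> high-level representation W (C channels x L frames)
        self.encoder = nn.Conv1d(1, channels, kernel, stride=stride, bias=False)
        self.separator = separator        # Eq. (5): predicts one mask per speaker
        # Eq. (6): masked representation -> waveform
        self.decoder = nn.ConvTranspose1d(channels, 1, kernel, stride=stride, bias=False)
        self.num_spk = num_spk

    def forward(self, y):                 # y: (batch, samples)
        W = torch.relu(self.encoder(y.unsqueeze(1)))       # (batch, C, L)
        masks = self.separator(W)                          # (batch, J, C, L)
        return [self.decoder(masks[:, j] * W).squeeze(1)   # estimated x_j
                for j in range(self.num_spk)]
```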
As the baseline model for multi-talker speech recognition, we use the end-to-end multi-speaker speech recognition architecture proposed in refs. 11, 22, which extends the joint CTC/attention-based single-speaker end-to-end ASR system36,37,38.
The architecture of this multi-speaker speech recognition model is depicted in Fig. 5. It is composed of three main modules: the multi-speaker encoder, the CTC module with permutation invariant training, and the attention-based decoder.
Fig. 5: End-to-end multi-speaker speech recognition system. The multi-speaker encoder separates the input mixture of multiple speakers’ voices into distinct features in the feature space. These separated features are then processed by the shared attention-based decoder and CTC modules, which generate a text transcription for each speaker.
The multi-speaker encoder module performs separation in the feature space implicitly and extracts high-level representations of all speakers for later recognition. It can be decomposed into three stages: EncoderMix, EncoderSD, and EncoderRec. First, the input mixture O of J speakers is processed by the mixture encoder EncoderMix to generate an intermediate representation Hasr. Then, J parallel speaker-differentiating (SD) encoders EncoderSD extract speaker-related representations \({{\bf{H}}}_{\,\text{asr}\,}^{j}(j=1,\cdots \,,J)\) from Hasr simultaneously. Finally, each representation is encoded by the shared-weight recognition encoder EncoderRec to generate a high-level representation \({{\bf{G}}}^{j}\) for the final decoding. The encoding process can be formulated as below:
$${{\bf{H}}}_{{\rm{asr}}}={\mathrm{Encoder}\,}_{{\rm{Mix}}}({\bf{O}})\,,$$
(7)
$${{\bf{H}}}_{\,\text{asr}}^{j}={\mathrm{Encoder}\,}_{\text{SD}\,}^{j}({{\bf{H}}}_{{\rm{asr}}}),j=1,\cdots \,,J\,,$$
(8)
$${{\bf{G}}}^{j}={\mathrm{Encoder}\,}_{{\rm{Rec}}}({{\bf{H}}}_{\,\text{asr}\,}^{j}),j=1,\cdots \,,J\,.$$
(9)
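The encoder stages in Eqs. (7)-(9) can be sketched as the following composition of modules, assuming EncoderMix, the J SD encoders, and the shared EncoderRec are provided as ordinary PyTorch modules; the code is illustrative and not the ESPnet implementation.

```python
# A minimal sketch of Eqs. (7)-(9); enc_mix, enc_sd_list, and enc_rec are
# assumed to be ordinary nn.Module instances with matching shapes.
import torch.nn as nn

class MultiSpeakerEncoder(nn.Module):
    def __init__(self, enc_mix, enc_sd_list, enc_rec):
        super().__init__()
        self.enc_mix = enc_mix                    # shared mixture encoder
        self.enc_sd = nn.ModuleList(enc_sd_list)  # one SD encoder per speaker
        self.enc_rec = enc_rec                    # shared recognition encoder

    def forward(self, O):
        H_asr = self.enc_mix(O)                             # Eq. (7)
        H_j = [enc(H_asr) for enc in self.enc_sd]           # Eq. (8)
        return [self.enc_rec(h) for h in H_j]               # Eq. (9): G^1, ..., G^J
```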
The CTC module performs decoding on each high-level representation \({{\bf{G}}}^{j}\); its loss is used as an auxiliary task36,37,38 to train the encoder part jointly with the attention-based decoder. Meanwhile, the CTC loss is also used to solve the permutation problem between the multiple outputs and the labels by applying the permutation invariant training technique:
$$\hat{\pi }=\arg \mathop{\min }\limits_{\pi \in {\mathcal{P}}}\mathop{\sum }\limits_{j=1}^{J}{\mathrm{Loss}\,}_{{\rm{ctc}}}({\hat{{\bf{R}}}}_{\,\text{ctc}\,}^{j},{{\bf{R}}}^{\pi (j)})\,,$$
(10)
where \(\pi \in {\mathcal{P}}\) enumerates all possible permutations on J (output, label) pairs, \({\hat{{\bf{R}}}}_{\,\text{ctc}\,}^{j}\) is the output sequence decoded from the representation Gj, π(j) is the j-th label index in a permutation π, and R is the set of reference labels corresponding to J speakers. Later, the best permutation \(\hat{\pi }\) that minimizes the overall CTC loss is chosen for reordering the reference labels in the attention-based decoder.
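The permutation search in Eq. (10) amounts to evaluating the total CTC loss for every assignment of output streams to reference labels and keeping the smallest one, as in the following sketch (the ctc_loss callable is an assumed placeholder):

```python
# A minimal sketch of the permutation search in Eq. (10); ctc_loss(output, ref)
# is assumed to return a scalar CTC loss for one (output, label) pair.
from itertools import permutations

def best_permutation(ctc_loss, outputs, refs):
    """outputs: J CTC output streams; refs: J reference label sequences."""
    best_pi, best_total = None, float("inf")
    for pi in permutations(range(len(refs))):       # all J! assignments
        total = sum(ctc_loss(outputs[j], refs[pi[j]]) for j in range(len(refs)))
        if total < best_total:
            best_pi, best_total = pi, total
    return best_pi          # used to reorder the reference labels for the decoder
```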
The attention-based decoder module decodes each stream Gj separately and generates the corresponding output sequence \({\hat{{\bf{R}}}}_{\,\text{dec}\,}^{j}\). For each decoding output \({\hat{{\bf{R}}}}_{\,\text{dec}\,}^{j}\) and the corresponding reference label \({{\bf{R}}}^{\hat{\pi }(j)}\), the decoding process can be described as follows:
$${\hat{r}}_{\,\text{dec},n}^{j} \sim \text{Attention-Decoder}\,({{\bf{G}}}^{j},{z}_{n-1}^{j})$$
(11)
where \({\hat{r}}_{\,\text{dec}\,,n}^{j}\) denotes the decoding output at the n-th time step, and \({z}_{n-1}^{j}\) is the (n − 1)-th element in either the reference label sequence \({{\bf{R}}}^{\hat{\pi }(j)}\) or the predicted label sequence \({\hat{{\bf{R}}}}_{\,\text{dec}\,}^{j}\). The technique of choosing the source of \({z}_{n-1}^{j}\) during training is also known as scheduled sampling11,39.
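The following sketch illustrates a single decoding step with scheduled sampling for Eq. (11); the decoder interface (encoded stream, previous token, recurrent state) and the sampling probability are assumptions for illustration, not the exact ESPnet API.

```python
# A minimal sketch of one decoding step with scheduled sampling (Eq. (11));
# the decoder signature and the recurrent state handling are assumptions.
import torch

def decoder_step(decoder, G_j, prev_ref, prev_pred, state, p_sample=0.1):
    # Choose the conditioning token z^j_{n-1}: the previous prediction with
    # probability p_sample, otherwise the previous reference token.
    z_prev = prev_pred if torch.rand(()).item() < p_sample else prev_ref
    logits, state = decoder(G_j, z_prev, state)     # attend over G^j
    return logits.argmax(dim=-1), state             # \hat{r}^j_{dec,n}
```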
The final loss function \({{\mathcal{L}}}_{{\rm{mtl}}}\) for multi-talker ASR is defined as the combination of losses in the CTC module and the attention-based decoder:
$${{\mathcal{L}}}_{{\rm{mtl}}}=\lambda {{\mathcal{L}}}_{{\rm{ctc}}}+(1-\lambda ){{\mathcal{L}}}_{{\rm{att}}}\,,$$
(12)
where λ is the interpolation factor with 0 ≤ λ ≤ 1, and \({{\mathcal{L}}}_{{\rm{ctc}}}\) and \({{\mathcal{L}}}_{{\rm{att}}}\) are based on the cross entropy (CE) loss functions. In the inference stage, we follow the joint CTC/attention decoding proposed in ref. 38.
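For completeness, the multi-task combination in Eq. (12) reduces to a one-line interpolation; the value λ = 0.2 anticipates the setting reported in “Model configuration”.

```python
# A minimal sketch of Eq. (12); loss_ctc and loss_att are assumed to already be
# aggregated over the J output streams under the best permutation from Eq. (10).
def multitask_loss(loss_ctc, loss_att, lam=0.2):
    return lam * loss_ctc + (1.0 - lam) * loss_att
```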
Contextual embedding learning
Motivated by the use of contextual cues14,15 in the human selective hearing mechanism, a contextual embedding learning framework is designed to exploit contextual information directly from the input speech mixture, which can then be integrated into speech-related tasks to improve the performance of speech processing systems under the cocktail party scenario.
A knowledge distillation framework is proposed to construct the contextual embedding learning module, which is illustrated in Fig. 6. Knowledge distillation has previously been used in speech recognition and other machine learning tasks40,41,42, but this is the first time it is applied to contextual embedding learning, with a single-talker speech system teaching a multi-talker speech system.
Fig. 6: The proposed knowledge distillation architecture, which transfers knowledge from a single-speaker ASR system to a multi-speaker system for contextual embedding learning.
The left part of Fig. 6 is the architecture of the contextual embedding predictor, which extracts contextual embeddings from the mixture. It adopts a structure similar to the encoder of the end-to-end multi-talker ASR system introduced in section “Traditional model for multi-talker speech processing”, and it is composed of three main modules, i.e., EncoderMix, EncoderSD, and EncoderCtx. The input speech mixture O of J speakers is first fed into the mixture encoder EncoderMix to obtain the intermediate representation Hctx. This intermediate representation Hctx is then separated into J representations \({{\bf{H}}}_{\,\text{ctx}\,}^{j}\,(j=1,\cdots \,,J)\) by J independent speaker-differentiating (SD) encoders EncoderSD. Each representation corresponds to one speaker and is encoded by the shared-weight context encoder EncoderCtx to predict the final contextual embedding \({\hat{{\bf{E}}}}^{j}\). The entire encoding process can be formulated as below:
$${{\bf{H}}}_{{\rm{ctx}}}={\mathrm{Encoder}\,}_{{\rm{Mix}}}({\bf{O}})\,,$$
(13)
$${{\bf{H}}}_{\,\text{ctx}}^{j}={\mathrm{Encoder}\,}_{\text{SD}\,}^{j}({{\bf{H}}}_{{\rm{ctx}}}),\quad j=1,\cdots \,,J\,,$$
(14)
$${\hat{{\bf{E}}}}^{j}={\mathrm{Encoder}\,}_{{\rm{Ctx}}}({{\bf{H}}}_{\,\text{ctx}\,}^{j}),\,\,j=1,\cdots \,,J\,.$$
(15)
To train the contextual embedding predictor, the encoder of a pre-trained single-talker E2E ASR model is utilized as the teacher model. More specifically, the teacher model is first constructed with the joint CTC/attention-based encoder-decoder architecture36,37,38 for single-speaker end-to-end ASR, as illustrated in Fig. 7. Once the single-speaker E2E ASR system is built, the encoder part of the ASR model is used as the contextual embedding teacher. It takes the parallel clean speech of each speaker in the speech mixture as input and outputs the corresponding representations Ej(j = 1, ⋯, J), named the oracle contextual embeddings, which serve as the supervision for training the student model.
Fig. 7: Architecture of the single-speaker oracle contextual embedding teacher, a joint CTC/attention-based encoder-decoder ASR model.
Then the contextual embedding predictor, as the student, learns to mimic the outputs of the teacher encoder by minimizing the loss between the predicted contextual embeddings and oracle contextual embeddings:
$${{\mathcal{L}}}_{{\rm{ctx}}}=\sum _{j}{\text{Loss}}_{{\rm{norm}}}\left({\hat{{\bf{E}}}}^{j},{{\bf{E}}}^{{\hat{\pi }}_{{\rm{ctx}}}(j)}\right)\,,$$
(16)
where Lossnorm can be either the ℓ1- or ℓ2-norm loss. \((j,{\hat{\pi }}_{{\rm{ctx}}}(j))\) denotes the best permutation of indexes of the predicted embedding and label pair, which minimizes the overall loss \({{\mathcal{L}}}_{{\rm{ctx}}}\) via permutation invariant training.
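The distillation objective of Eq. (16) can be sketched as below, using a mean-squared-error term as a stand-in for the ℓ2 loss and an explicit search over permutations for the PIT assignment; the per-speaker embeddings are assumed to be tensors of matching shape.

```python
# A minimal sketch of Eq. (16) with permutation invariant training; F.mse_loss
# is used as a stand-in for the l2 loss, and E_hat / E_oracle are lists of
# per-speaker embedding tensors of identical shape.
from itertools import permutations
import torch.nn.functional as F

def context_distillation_loss(E_hat, E_oracle):
    J = len(E_oracle)
    pair = [[F.mse_loss(E_hat[j], E_oracle[k]) for k in range(J)] for j in range(J)]
    # Keep the (prediction, oracle) assignment with the smallest total loss.
    return min(sum(pair[j][pi[j]] for j in range(J))
               for pi in permutations(range(J)))
```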
We expect this teacher-student learning framework to distill accurate contextual knowledge from the clean single-speaker system into the multi-speaker system, so that the designed contextual embedding predictor can capture the contextual information of all speakers directly from the speech mixture.
Speech separation with contextual understanding
In this section, we incorporate the proposed contextual embedding into the speech separation architecture. Both the oracle contextual embeddings Ej and the predicted contextual embeddings \({\hat{{\bf{E}}}}^{j}\) can be used to train speech separation models, but only the predicted embeddings \({\hat{{\bf{E}}}}^{j}\) can be obtained in the testing stage, because only the mixed speech is available in real applications. It is worth noting that the oracle contextual embeddings derived from clean speech are only available with simulated data.
Figure 8 illustrates how the contextual language embedding is incorporated into the speech separation models. We first concatenate the contextual embeddings of the different speakers (Ej or \({\hat{{\bf{E}}}}^{j}\)) frame-wise into a single contextual embedding Econtext. We denote by U the processed feature of the input speech mixture, which is the magnitude spectrum ∣Y∣ in the T-F masking models and the convolutional encoded audio representation W in TasNet, respectively. The contextual embedding Econtext is then resampled so that it has the same length as U along the frame dimension, and incorporated with U by frame-wise concatenation. The incorporation is described by the following formulas:
$${{\bf{E}}}_{{\rm{context}}}={\mathrm{Concat}}\,({{\bf{E}}}^{1},\cdots \,,{{\bf{E}}}^{J})$$
(17)
$${{\bf{E}}}_{\,\text{context}\,}^{{\prime} }=\mathrm{Resample}\,({{\bf{E}}}_{{\rm{context}}})$$
(18)
$${{\bf{U}}}^{{\prime} }=\mathrm{Concat}\,({\bf{U}},{{\bf{E}}}_{\,\text{context}\,}^{{\prime} })$$
(19)
The Separator takes the context-aware feature \({{\bf{U}}}^{{\prime} }\) as input to predict the masks Mj for different speakers.
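A sketch of Eqs. (17)-(19) is given below, using linear interpolation along the frame axis as one possible Resample operation; the tensor layouts and the interpolation choice are assumptions for illustration.

```python
# A minimal sketch of Eqs. (17)-(19); linear interpolation is used as one
# possible Resample, and all tensor layouts are illustrative.
import torch
import torch.nn.functional as F

def add_context(U, embeddings):
    """U: (batch, T, D_u) mixture feature; embeddings: J tensors (batch, T_ctx, D_e)."""
    E_context = torch.cat(embeddings, dim=-1)                       # Eq. (17)
    # Eq. (18): resample along the frame axis to match the length of U
    E_resampled = F.interpolate(E_context.transpose(1, 2),
                                size=U.shape[1], mode="linear",
                                align_corners=False).transpose(1, 2)
    return torch.cat([U, E_resampled], dim=-1)                      # Eq. (19): U'
```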
Note that we concatenate the contextual embeddings of the different speakers in an arbitrary order in Equation (17). The label assignment between the reference speech and the predicted speech can then be determined by the concatenation order of the contextual embeddings. Thus, in theory, permutation invariant training (PIT) can be removed from the proposed speech separation with contextual understanding; in other words, PIT is only needed when training the contextual embedding predictor described above. We have compared the proposed context-aware speech separation system trained with and without PIT, and the two achieve almost the same performance.
Multi-talker ASR with contextual understanding
The end-to-end multi-talker ASR model introduced in the section “Traditional model for multi-talker speech processing” is used as the basic architecture for multi-talker ASR, and it is further revised to integrate the proposed contextual embeddings. Different from the speech separation model introduced in the last subsection, which processes the input mixture with a shared encoder before separation, our multi-speaker ASR system performs feature-level separation inside the encoder module, while the subsequent modules are shared among the separated streams. Therefore, a different integration strategy for the contextual embeddings is used in this section, as illustrated in Fig. 9. We first concatenate the contextual embeddings corresponding to the J speakers into a single embedding Econtext, thus avoiding a new permutation problem between the ASR encoder outputs and the contextual embeddings. Then, the combined contextual embedding Econtext can be integrated either (1) after the mixture encoder or (2) after the recognition encoder. The integration after the recognition encoder can be formulated as follows:
$${{\bf{G}}}^{j{\prime} }=\mathrm{Concat}\,({{\bf{G}}}^{j},{{\bf{E}}}_{{\rm{context}}}),\,\,\,j=1,\cdots \,,J\,,$$
(20)
where Gj is the high-level representation output by EncoderRec in the ASR model, and Concat( ⋅ ) denotes the frame-wise concatenation.
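Because the same combined embedding is appended to every stream, the integration of Eq. (20) is a plain frame-wise concatenation, as in the sketch below (Econtext is assumed to be already resampled to the frame rate of Gj):

```python
# A minimal sketch of Eq. (20); E_context is assumed to be already resampled
# to the same number of frames T as the recognition-encoder outputs.
import torch

def integrate_after_rec_encoder(G_streams, E_context):
    """G_streams: J tensors (batch, T, D_g); E_context: (batch, T, D_e)."""
    return [torch.cat([G_j, E_context], dim=-1) for G_j in G_streams]   # G^j'
```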
Fig. 9: Integration of contextual embeddings into the end-to-end multi-talker ASR model. The contextual embeddings can be integrated either (1) after the mixture encoder or (2) after the recognition encoder.
The whole architecture in Fig. 9 is then optimized with the joint CTC and attention loss in Equation (12), the same as for the basic end-to-end multi-talker ASR model. As in speech separation, either the oracle embeddings or the predicted embeddings can be utilized as Econtext during training, but only the predicted contextual embeddings are applied during evaluation.
Embedding sampling for model optimization
To better utilize the proposed contextual embeddings for speech separation or multi-talker ASR, two advanced training strategies are designed to optimize the system further. The first training strategy is inspired by the commonly used scheduled sampling technique11,39, which can mitigate the discrepancy between training and testing.
As described in the above subsections, the new framework uses either predicted contextual embeddings or oracle contextual embeddings during training for multi-talker speech separation and recognition, but only the predicted contextual embeddings can be utilized in real applications. Here, we design a new strategy that makes it possible to utilize both embedding sources during training. Similar to the scheduled sampling described in section “Traditional model for multi-talker speech processing”, we randomly sample from a Bernoulli distribution to choose the source of the contextual embeddings. More specifically, we use the oracle contextual embeddings with probability p and the predicted contextual embeddings with probability (1 − p), with p = 0.7 in our experiments:
$${b}_{{\rm{ctx}}} \sim \mathrm{Bernoulli}\,(p)\,,$$
(21)
$${{\bf{E}}}_{{\rm{context}}}=\left\{\begin{array}{ll}{\mathrm{Concat}}\,({{\bf{E}}}^{1},\cdots \,,{{\bf{E}}}^{J}),&\,{\text{if}}\,{b}_{{\rm{ctx}}}=1\,,\\ {\mathrm{Concat}}\,({\hat{{\bf{E}}}}^{1},\cdots \,,{\hat{{\bf{E}}}}^{J}),&\,{\text{if}}\,{b}_{{\rm{ctx}}}=0\,.\end{array}\right.$$
(22)
We refer to this strategy as embedding sampling, as it shares some similarities with the scheduled sampling technique39. This strategy can also mitigate the embedding mismatch between training and testing, and enhance the generalization capability of our proposed new architecture for speech separation and multi-talker ASR.
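A sketch of the embedding-sampling rule in Eqs. (21)-(22) is shown below; the per-speaker embeddings are assumed to be lists of tensors, and only the predicted embeddings are used when the model is not in training mode.

```python
# A minimal sketch of Eqs. (21)-(22): draw b_ctx ~ Bernoulli(p) and feed the
# oracle embeddings with probability p (p = 0.7 here), otherwise the predicted
# ones; only predictions are available outside of training.
import torch

def sample_context(oracle, predicted, p=0.7, training=True):
    """oracle, predicted: lists of J per-speaker embedding tensors."""
    use_oracle = training and torch.rand(()).item() < p   # b_ctx = 1 with prob. p
    chosen = oracle if use_oracle else predicted
    return torch.cat(chosen, dim=-1)                       # E_context
```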
Algorithm 1: Two-stage training strategy for end-to-end multi-talker ASR with contextual embeddings.
Two-stage training for model optimization
For multi-speaker ASR, we find the model relatively difficult to optimize well, so a two-stage training strategy is further proposed to improve the multi-speaker ASR with contextual embeddings; it reduces the risk of over-relying on or overemphasizing the contextual information during optimization. Since the contextual embeddings already provide enough information for speech recognition, integrating them at an early training stage may lead to an underfitted multi-speaker encoder, which can be sub-optimal. We mitigate this problem by dividing the training process into two stages. As illustrated in Algorithm 1, in the first stage, the multi-speaker ASR model is trained for K epochs without contextual embeddings, so that it is moderately trained before the contextual embeddings are integrated. In the second stage, we integrate the contextual embeddings as described in Fig. 9 and continue training the model until convergence.
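The two-stage schedule of Algorithm 1 can be sketched as a simple switch on the epoch index; train_one_epoch and the use_contextual_embeddings flag are illustrative placeholders rather than parts of the actual implementation.

```python
# A minimal sketch of the two-stage training strategy in Algorithm 1;
# train_one_epoch is passed in as a callable and the context switch is a
# hypothetical model attribute used only for illustration.
def two_stage_training(model, train_data, train_one_epoch, total_epochs, K):
    for epoch in range(total_epochs):
        # Stage 1 (epoch < K): train the multi-speaker ASR model without context.
        # Stage 2 (epoch >= K): integrate the contextual embeddings (Fig. 9)
        # and continue training until convergence.
        model.use_contextual_embeddings = (epoch >= K)
        train_one_epoch(model, train_data)
```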
Model configuration
In our experiments, we adopted the recurrent structure for the contextual embedding predictor described in section “Contextual embedding learning”.
The mixture encoder is composed of two VGG-motivated CNN blocks, which downsample the input mixed speech sequence by a factor of four. The SD encoders and the context encoder use one and two bidirectional long short-term memory layers with projection (BLSTMP), respectively, with 1024 hidden units per direction in each layer. In Eq. (16), we adopted the ℓ2 loss to train the contextual embedding predictor for both speech separation and multi-talker ASR.
For the T-F masking approach in Fig. 8, both ResNet- and BLSTM-based Separators are used. The ResNet-based Separator contains ten convolution layers, each with 1024 channels. The BLSTM-based Separator is a three-layer BLSTM, each layer of which contains 896 hidden units. For the time-domain approach, we implemented the fully convolutional TasNet (conv-TasNet) following ref. 13 as our baseline model. In our implementation, the number of channels in each convolution block is 512, and the bottleneck channel number of the residual connections is 256. The convolution stack has three repeats, and each repeat contains five convolution blocks (the detailed configuration is {N = 256, L = 20, B = 256, H = 512, P = 3, X = 5, R = 3, gLN}; readers can refer to Table 1 in ref. 13 for further details on the hyperparameters).
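For readability, the conv-TasNet setting quoted above can be written out with the symbol meanings of Table 1 in ref. 13; the comments below only restate the values listed in the text.

```python
# The conv-TasNet hyperparameters used in this paper, in the notation of ref. 13.
conv_tasnet_config = {
    "N": 256,       # number of encoder filters (channels of the encoded representation)
    "L": 20,        # length of each encoder filter, in samples
    "B": 256,       # bottleneck channels of the residual connections
    "H": 512,       # channels in each convolution block
    "P": 3,         # kernel size inside the convolution blocks
    "X": 5,         # convolution blocks per repeat
    "R": 3,         # number of repeats in the convolution stack
    "norm": "gLN",  # global layer normalization
}
```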
The encoder of the multi-talker ASR model in Fig. 9 has the same architecture as the contextual embedding predictor with a recurrent structure, as described above. The encoder of the single-talker ASR model in Fig. 6 has a similar structure, with two VGG-motivated CNN blocks followed by three BLSTMP layers. The decoders of both the multi-talker and single-talker ASR models consist of a single unidirectional long short-term memory (LSTM) layer with 300 cells. All encoder layers have 1024-dimensional hidden units. In the training phase, the interpolation factor λ in Eq. (12) is set to 0.2. In the decoding phase, we combine the joint CTC/attention score with the score of a word-level RNN language model (RNNLM), a one-layer LSTM with 1000 cells trained on the transcriptions of WSJ0 SI-84, in a shallow fusion manner. The beam width of the beam search is 30.
All models above were built with the ESPnet toolkit43.