Deep Learning
State of the Art

Deep Learning Revolution: U-Nets and Transformers for Music

How encoder-decoder architectures and attention mechanisms have transformed music source separation and transcription, achieving near-human performance.

Dr. Alex Rivera, ML Research
January 27, 2025
22 min read
Neural network visualization

The Paradigm Shift to Deep Learning

The limitations of classical methods like ICA and NMF—their rigid assumptions and inability to model complex, non-linear relationships—paved the way for deep learning. Neural networks learn these relationships directly from data, without requiring explicit mathematical priors.

Why Deep Learning Dominates

  • Non-linear modeling: Captures complex interactions between frequencies
  • End-to-end learning: Optimizes directly for the task objective
  • Data-driven priors: Learns patterns from massive datasets
  • Hierarchical features: Builds understanding from low to high level

The U-Net Architecture: Computer Vision Meets Audio

U-Net was originally designed for biomedical image segmentation, but its encoder-decoder structure with skip connections proved a natural fit for audio source separation when applied to spectrograms.

U-Net Architecture Components

  • Encoder path: Progressively downsamples, capturing context at multiple scales
  • Decoder path: Upsamples back to the original resolution, localizing features
  • Skip connections: Preserve fine-grained details from the encoder
  • Bottleneck: Compressed representation that forces disentanglement

U-Net for Music Source Separation
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, features=[64, 128, 256, 512]):
        super(UNet, self).__init__()
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder path
        for feature in features:
            self.encoder.append(self._block(in_channels, feature))
            in_channels = feature

        # Bottleneck
        self.bottleneck = self._block(features[-1], features[-1] * 2)

        # Decoder path
        for feature in reversed(features):
            self.decoder.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.decoder.append(self._block(feature * 2, feature))

        # Final convolution
        self.final = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def _block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        skip_connections = []

        # Encoder
        for encoder in self.encoder:
            x = encoder(x)
            skip_connections.append(x)
            x = self.pool(x)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder with skip connections
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.decoder), 2):
            x = self.decoder[idx](x)  # Transpose convolution
            skip = skip_connections[idx // 2]

            # Handle size mismatch
            if x.shape != skip.shape:
                x = F.interpolate(x, size=skip.shape[2:])

            # Concatenate skip connection
            x = torch.cat([skip, x], dim=1)
            x = self.decoder[idx + 1](x)  # Conv block

        return torch.sigmoid(self.final(x))  # Output mask [0, 1]

# Usage for source separation
class MusicSeparationUNet(nn.Module):
    def __init__(self, n_fft=2048):
        super().__init__()
        self.n_freq = n_fft // 2 + 1
        self.unet = UNet(in_channels=1, out_channels=4)  # 4 sources

    def forward(self, spectrogram):
        # Input: [batch, time, freq]
        x = spectrogram.unsqueeze(1)  # Add channel dimension

        # Generate masks for each source
        masks = self.unet(x)  # [batch, 4, time, freq]

        # Apply masks to input spectrogram
        separated = masks * spectrogram.unsqueeze(1)

        return separated  # [batch, 4, time, freq]
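
As a rough end-to-end sketch (the hop size, window choice, and five-second test signal below are illustrative, not values from the article), the predicted masks are applied to the mixture magnitude and each source is resynthesized with the mixture phase:

import torch

n_fft, hop = 2048, 512
model = MusicSeparationUNet(n_fft=n_fft)
mixture = torch.randn(1, 5 * 44100)  # stand-in for 5 s of mono audio at 44.1 kHz

window = torch.hann_window(n_fft)
stft = torch.stft(mixture, n_fft=n_fft, hop_length=hop, window=window, return_complex=True)
magnitude, phase = stft.abs(), stft.angle()       # [batch, freq, time]

masks_input = magnitude.transpose(1, 2)           # the model above expects [batch, time, freq]
separated_mags = model(masks_input)               # [batch, 4, time, freq]

# Reuse the mixture phase for resynthesis (a common spectrogram-domain shortcut)
sources = []
for i in range(separated_mags.shape[1]):
    mag_i = separated_mags[:, i].transpose(1, 2)  # back to [batch, freq, time]
    sources.append(torch.istft(torch.polar(mag_i, phase),
                               n_fft=n_fft, hop_length=hop, window=window))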

Advanced Architectures: Wave-U-Net and Demucs

While standard U-Net operates on spectrograms, newer architectures work directly on raw waveforms, avoiding phase reconstruction issues.

Demucs v4: State-of-the-Art Architecture

Key innovations:

  • Hybrid approach: Processes both waveform and spectrogram
  • Bi-LSTM layers: Captures long-term temporal dependencies
  • Transformer blocks: Models global context with attention
  • Multi-resolution processing: Handles different time scales
Simplified Demucs Architecture
class EncoderBlock(nn.Module):
    """Strided 1-D convolution + GELU: a minimal stand-in for the real Demucs encoder block."""
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride)
        self.act = nn.GELU()

    def forward(self, x):
        return self.act(self.conv(x))

class DecoderBlock(nn.Module):
    """Transposed 1-D convolution + GELU: a minimal stand-in for the real Demucs decoder block."""
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
        self.act = nn.GELU()

    def forward(self, x):
        return self.act(self.conv(x))

class Demucs(nn.Module):
    def __init__(self, sources=4, channels=64, depth=6):
        super().__init__()
        self.sources = sources
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        # Encoder layers with increasing channel width (larger strides in the first two layers)
        for i in range(depth):
            stride = 4 if i < 2 else 2
            self.encoder.append(
                EncoderBlock(
                    in_channels=1 if i == 0 else channels * (2 ** (i - 1)),
                    out_channels=channels * (2 ** i),
                    kernel_size=8,
                    stride=stride
                )
            )

        # Bidirectional LSTM for temporal modeling
        hidden_size = channels * (2 ** (depth - 1))
        self.lstm = nn.LSTM(
            hidden_size,
            hidden_size // 2,
            num_layers=2,
            bidirectional=True,
            batch_first=True
        )

        # Transformer for global context
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=8,
                dim_feedforward=hidden_size * 4,
                batch_first=True
            ),
            num_layers=2
        )

        # Decoder mirrors the encoder: channel counts shrink and strides invert
        for i in range(depth):
            self.decoder.append(
                DecoderBlock(
                    in_channels=channels * (2 ** (depth - 1 - i)),
                    out_channels=channels * (2 ** (depth - 2 - i)) if i < depth - 1 else sources,
                    kernel_size=8,
                    stride=4 if i >= depth - 2 else 2
                )
            )

    def forward(self, x):
        # x shape: [batch, samples]
        x = x.unsqueeze(1)  # [batch, 1, samples]

        # Encoding
        skips = []
        for encoder in self.encoder:
            x = encoder(x)
            skips.append(x)

        # Reshape for LSTM/Transformer
        b, c, t = x.shape
        x = x.transpose(1, 2)  # [batch, time, channels]

        # Temporal modeling
        x, _ = self.lstm(x)
        x = self.transformer(x)

        # Reshape back
        x = x.transpose(1, 2)  # [batch, channels, time]

        # Decoding with skip connections (trim to a common length before adding,
        # since the strided convolutions do not invert exactly)
        for i, decoder in enumerate(self.decoder):
            if i > 0:
                skip = skips[-(i + 1)]
                length = min(x.shape[-1], skip.shape[-1])
                x = x[..., :length] + skip[..., :length]  # Residual connection
            x = decoder(x)

        # Split into sources
        return x.view(b, self.sources, -1)  # [batch, sources, samples]
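
A quick shape check of this simplified model (untrained weights, placeholder four-second input):

model = Demucs(sources=4, channels=64, depth=6)
waveform = torch.randn(2, 4 * 44100)   # two mono clips of ~4 s at 44.1 kHz
with torch.no_grad():
    estimates = model(waveform)        # [batch, sources, samples'], slightly shorter than the input
print(estimates.shape)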

Transformers: Attention is All You Need for Music

The self-attention mechanism of Transformers excels at modeling long-range dependencies in music, crucial for understanding musical structure and harmony.

Self-Attention Mechanism

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

Q: Queries, K: Keys, V: Values, d_k: Key dimension
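
The formula maps almost line-for-line onto code. The minimal single-head sketch below (no masking, no learned projections) makes the roles of Q, K, and V concrete:

import math
import torch

def scaled_dot_product_attention(Q, K, V):
    """Minimal single-head attention: Q, K, V are [batch, seq_len, d_k]."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # [batch, seq, seq] similarity scores
    weights = torch.softmax(scores, dim=-1)            # each row sums to 1
    return weights @ V                                  # weighted sum of value vectors

# Toy example: 100 time frames with 64-dimensional features, self-attention (Q = K = V = x)
x = torch.randn(1, 100, 64)
out = scaled_dot_product_attention(x, x, x)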

Music Transformer for Transcription
class MusicTransformer(nn.Module):
    def __init__(self, d_model=512, nhead=8, num_layers=6):
        super().__init__()

        # Spectrogram encoder
        self.spec_encoder = nn.Sequential(
            nn.Conv2d(1, d_model // 4, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(d_model // 4, d_model // 2, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(d_model // 2, d_model, kernel_size=3, padding=1)
        )

        # Positional encoding
        self.pos_encoding = PositionalEncoding(d_model)

        # Transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=0.1,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

        # Output heads for multi-task learning
        self.pitch_head = nn.Linear(d_model, 88)  # 88 piano keys
        self.onset_head = nn.Linear(d_model, 88)
        self.velocity_head = nn.Linear(d_model, 128)  # MIDI velocity

    def forward(self, spectrogram):
        # Encode spectrogram: [batch, time, freq] -> [batch, d_model, time, freq]
        x = self.spec_encoder(spectrogram.unsqueeze(1))

        # Collapse the frequency axis so each time frame becomes one token
        x = x.mean(dim=3).transpose(1, 2)  # [batch, time, d_model]

        # Add positional encoding
        x = self.pos_encoding(x)

        # Apply transformer
        x = self.transformer(x)

        # Multi-task outputs (one prediction per time frame)
        pitches = torch.sigmoid(self.pitch_head(x))
        onsets = torch.sigmoid(self.onset_head(x))
        velocities = self.velocity_head(x)

        return {
            'pitches': pitches,      # [batch, time, 88]
            'onsets': onsets,        # [batch, time, 88]
            'velocities': velocities  # [batch, time, 128]
        }

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()

        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(torch.log(torch.tensor(10000.0)) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
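
Going from these frame-wise probabilities to discrete notes requires a post-processing step the model itself does not provide. A simple greedy decoder (the thresholds and hop duration below are illustrative defaults, not tuned values) might look like this:

def decode_notes(pitches, onsets, onset_threshold=0.5, frame_threshold=0.5, hop_seconds=0.032):
    """Greedy decoding of frame-wise predictions into (MIDI pitch, start, end) events.

    pitches, onsets: [time, 88] probability tensors for a single example.
    """
    notes = []
    active = {}  # piano-key index -> start frame
    num_frames = pitches.shape[0]
    for t in range(num_frames):
        for p in range(88):
            if onsets[t, p] > onset_threshold and p not in active:
                active[p] = t  # a note starts on an onset spike
            elif p in active and pitches[t, p] <= frame_threshold:
                # Note ends once the frame probability drops; MIDI 21 is A0
                notes.append((p + 21, active.pop(p) * hop_seconds, t * hop_seconds))
    # Close any notes still sounding at the end
    for p, start in active.items():
        notes.append((p + 21, start * hop_seconds, num_frames * hop_seconds))
    return notes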

Training Strategies and Loss Functions

Advanced Training Techniques

Time-Domain Loss

\mathcal{L}_{time} = \|y - \hat{y}\|_1 + \lambda \|\text{STFT}(y) - \text{STFT}(\hat{y})\|_1

Combines waveform and spectral losses
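
A direct reading of this loss in PyTorch; the weighting λ and the STFT settings below are placeholder choices, and the spectral term is taken on magnitudes for simplicity:

import torch
import torch.nn.functional as F

def time_domain_loss(y_true, y_pred, lam=0.5, n_fft=1024, hop=256):
    # L1 on the raw waveform
    wave_loss = F.l1_loss(y_pred, y_true)

    # L1 on STFT magnitudes (illustrative parameters; lam weights the spectral term)
    window = torch.hann_window(n_fft, device=y_true.device)
    S_true = torch.stft(y_true, n_fft=n_fft, hop_length=hop, window=window, return_complex=True)
    S_pred = torch.stft(y_pred, n_fft=n_fft, hop_length=hop, window=window, return_complex=True)
    spec_loss = F.l1_loss(S_pred.abs(), S_true.abs())

    return wave_loss + lam * spec_loss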

Multi-Resolution STFT Loss

\mathcal{L}_{mr\text{-}stft} = \sum_{s \in S} \mathcal{L}_{stft}^{(s)}

Evaluates at multiple time-frequency resolutions

Perceptual Loss

Uses pre-trained networks to match perceptual features
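
Perceptual losses are typically implemented as feature matching against a frozen, pre-trained network. In the sketch below, the feature extractor is an assumed, caller-supplied module that returns a list of feature maps; it does not refer to a specific published model:

class PerceptualLoss(nn.Module):
    def __init__(self, feature_extractor, layer_weights=None):
        super().__init__()
        # The extractor is assumed to be pre-trained and to return a list of feature maps
        self.features = feature_extractor.eval()
        for p in self.features.parameters():
            p.requires_grad_(False)  # keep the reference network frozen
        self.layer_weights = layer_weights

    def forward(self, y_true, y_pred):
        feats_true = self.features(y_true)
        feats_pred = self.features(y_pred)
        weights = self.layer_weights or [1.0] * len(feats_true)
        # Match intermediate representations rather than raw samples
        return sum(w * F.l1_loss(fp, ft)
                   for w, fp, ft in zip(weights, feats_pred, feats_true))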

Multi-Resolution STFT Loss
class MultiResolutionSTFTLoss(nn.Module):
    def __init__(self, fft_sizes=[512, 1024, 2048],
                 hop_sizes=[50, 120, 240],
                 win_lengths=[240, 600, 1200]):
        super().__init__()
        self.fft_sizes = fft_sizes
        self.hop_sizes = hop_sizes
        self.win_lengths = win_lengths

    def forward(self, y_true, y_pred):
        loss = 0

        for fft_size, hop_size, win_length in zip(
            self.fft_sizes, self.hop_sizes, self.win_lengths
        ):
            # Compute STFT (Hann window to reduce spectral leakage)
            window = torch.hann_window(win_length, device=y_true.device)

            Y_true = torch.stft(
                y_true,
                n_fft=fft_size,
                hop_length=hop_size,
                win_length=win_length,
                window=window,
                return_complex=True
            )

            Y_pred = torch.stft(
                y_pred,
                n_fft=fft_size,
                hop_length=hop_size,
                win_length=win_length,
                window=window,
                return_complex=True
            )

            # Magnitude loss
            mag_true = torch.abs(Y_true)
            mag_pred = torch.abs(Y_pred)
            loss += F.l1_loss(mag_pred, mag_true)

            # Log magnitude loss
            log_mag_true = torch.log(mag_true + 1e-7)
            log_mag_pred = torch.log(mag_pred + 1e-7)
            loss += F.l1_loss(log_mag_pred, log_mag_true)

            # Phase loss (optional)
            phase_true = torch.angle(Y_true)
            phase_pred = torch.angle(Y_pred)
            phase_loss = F.l1_loss(
                torch.sin(phase_pred), torch.sin(phase_true)
            ) + F.l1_loss(
                torch.cos(phase_pred), torch.cos(phase_true)
            )
            loss += 0.1 * phase_loss

        return loss / len(self.fft_sizes)
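
To tie these pieces together, a training step for the waveform-domain model might combine a plain L1 waveform loss with the multi-resolution STFT loss above. The optimizer, learning rate, and 0.5 weighting are illustrative choices, not values from the article:

model = Demucs(sources=4)
mr_stft = MultiResolutionSTFTLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

def training_step(mixture, targets):
    # mixture: [batch, samples]; targets: [batch, sources, samples]
    estimates = model(mixture)  # [batch, sources, samples']

    # Trim to a common length: the strided convolutions shorten the output slightly
    length = min(estimates.shape[-1], targets.shape[-1])
    estimates, targets = estimates[..., :length], targets[..., :length]

    wave_loss = F.l1_loss(estimates, targets)
    spec_loss = mr_stft(targets.reshape(-1, length), estimates.reshape(-1, length))
    loss = wave_loss + 0.5 * spec_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()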

Comparison with Classical Methods

Aspect | ICA/NMF | Deep Learning
Assumptions | Strong (independence, linearity) | Minimal (learned from data)
Training Data | Unsupervised | Requires large labeled datasets
Performance | Good for simple cases | State-of-the-art
Interpretability | High | Low (black box)
Computational Cost | Low | High (GPU required)
Generalization | Limited | Excellent

Real-World Performance Metrics

MUSDB18 Benchmark Results

Signal-to-Distortion Ratio (SDR) in dB:

  • Vocals: 7.86
  • Drums: 8.23
  • Bass: 7.01
  • Other: 5.42

Demucs v4 (Hybrid Transformer) results

Conclusion

Deep learning has revolutionized music information retrieval, with U-Net and Transformer architectures achieving unprecedented performance. The ability to learn complex, non-linear relationships directly from data has overcome the fundamental limitations of classical statistical methods.

As we move forward, the combination of these powerful architectures with WebAssembly enables us to bring state-of-the-art music AI directly to the browser, democratizing access to professional-grade audio tools.

Experience Deep Learning in Action

Try our browser-based implementations powered by these architectures.