Deep Learning Revolution: U-Nets and Transformers for Music
How encoder-decoder architectures and attention mechanisms have transformed music source separation and transcription, achieving near-human performance.

The Paradigm Shift to Deep Learning
The limitations of classical methods like ICA and NMF—their rigid assumptions and inability to model complex, non-linear relationships—paved the way for deep learning. Neural networks learn these relationships directly from data, without requiring explicit mathematical priors.
• Non-linear modeling: Captures complex interactions between frequencies
• End-to-end learning: Optimizes directly for the task objective
• Data-driven priors: Learns patterns from massive datasets
• Hierarchical features: Builds understanding from low to high level
The U-Net Architecture: Computer Vision Meets Audio
Originally designed for biomedical image segmentation, U-Net's encoder-decoder structure with skip connections proved remarkably effective for audio source separation when applied to spectrograms.
• Encoder path: Progressively downsamples, capturing context at multiple scales
• Decoder path: Upsamples back to the original resolution, localizing features
• Skip connections: Preserve fine-grained details from the encoder
• Bottleneck: Compressed representation that forces the network to disentangle sources
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, features=[64, 128, 256, 512]):
        super(UNet, self).__init__()
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder path
        for feature in features:
            self.encoder.append(self._block(in_channels, feature))
            in_channels = feature

        # Bottleneck
        self.bottleneck = self._block(features[-1], features[-1] * 2)

        # Decoder path
        for feature in reversed(features):
            self.decoder.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.decoder.append(self._block(feature * 2, feature))

        # Final convolution
        self.final = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def _block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        skip_connections = []

        # Encoder
        for encoder in self.encoder:
            x = encoder(x)
            skip_connections.append(x)
            x = self.pool(x)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder with skip connections
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.decoder), 2):
            x = self.decoder[idx](x)  # Transpose convolution
            skip = skip_connections[idx // 2]

            # Handle size mismatch from odd input dimensions
            if x.shape != skip.shape:
                x = F.interpolate(x, size=skip.shape[2:])

            # Concatenate skip connection
            x = torch.cat([skip, x], dim=1)
            x = self.decoder[idx + 1](x)  # Conv block

        return torch.sigmoid(self.final(x))  # Output mask in [0, 1]


# Usage for source separation
class MusicSeparationUNet(nn.Module):
    def __init__(self, n_fft=2048):
        super().__init__()
        self.n_freq = n_fft // 2 + 1
        self.unet = UNet(in_channels=1, out_channels=4)  # 4 sources

    def forward(self, spectrogram):
        # Input: [batch, time, freq] magnitude spectrogram
        x = spectrogram.unsqueeze(1)  # Add channel dimension

        # Generate masks for each source
        masks = self.unet(x)  # [batch, 4, time, freq]

        # Apply masks to the input spectrogram
        separated = masks * spectrogram.unsqueeze(1)

        return separated  # [batch, 4, time, freq]
```
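As an end-to-end illustration, the sketch below runs the masking pipeline on a mono waveform: compute the STFT, predict per-source masks, apply them to the mixture magnitude, and resynthesize each stem with the mixture phase. It assumes a trained `MusicSeparationUNet`; the `separate_waveform` helper and its parameters are illustrative, not part of any library.

```python
def separate_waveform(model, waveform, n_fft=2048, hop=512):
    """Hypothetical inference helper for the mask-based U-Net above (a sketch)."""
    model.eval()
    window = torch.hann_window(n_fft, device=waveform.device)
    stft = torch.stft(waveform, n_fft=n_fft, hop_length=hop,
                      window=window, return_complex=True)      # [freq, time]
    magnitude = stft.abs().T.unsqueeze(0)                       # [1, time, freq]

    with torch.no_grad():
        separated_mag = model(magnitude)                        # [1, 4, time, freq]

    # Reuse the mixture phase for resynthesis (the usual masking shortcut)
    phase = torch.angle(stft)                                   # [freq, time]
    sources = []
    for i in range(separated_mag.shape[1]):
        mag_i = separated_mag[0, i].T                           # [freq, time]
        sources.append(torch.istft(torch.polar(mag_i, phase),
                                   n_fft=n_fft, hop_length=hop,
                                   window=window, length=waveform.shape[-1]))
    return torch.stack(sources)                                 # [4, samples]
```

Reusing the mixture phase is exactly the shortcut that motivates the waveform-domain models discussed next.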
Advanced Architectures: Wave-U-Net and Demucs
While standard U-Net operates on spectrograms, newer architectures work directly on raw waveforms, avoiding phase reconstruction issues.
Key innovations:
• Hybrid approach: Processes both waveform and spectrogram
• Bi-LSTM layers: Captures long-term temporal dependencies
• Transformer blocks: Models global context with attention
• Multi-resolution processing: Handles different time scales
```python
class Demucs(nn.Module):
    def __init__(self, sources=4, channels=64, depth=6):
        super().__init__()
        self.sources = sources
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        # Encoder layers: strided 1-D convolutions with growing channel width
        # (EncoderBlock / DecoderBlock are simple strided conv blocks; a sketch
        #  of both follows this listing)
        for i in range(depth):
            stride = 4 if i < 2 else 2
            self.encoder.append(
                EncoderBlock(
                    in_channels=1 if i == 0 else channels * (2 ** (i - 1)),
                    out_channels=channels * (2 ** i),
                    kernel_size=8,
                    stride=stride
                )
            )

        # Bidirectional LSTM for temporal modeling
        hidden_size = channels * (2 ** (depth - 1))
        self.lstm = nn.LSTM(
            hidden_size,
            hidden_size // 2,
            num_layers=2,
            bidirectional=True,
            batch_first=True
        )

        # Transformer for global context
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=8,
                dim_feedforward=hidden_size * 4,
                batch_first=True
            ),
            num_layers=2
        )

        # Decoder with additive skip connections, mirroring the encoder
        for i in range(depth):
            self.decoder.append(
                DecoderBlock(
                    in_channels=channels * (2 ** (depth - 1 - i)),
                    out_channels=channels * (2 ** (depth - 2 - i)) if i < depth - 1 else sources,
                    kernel_size=8,
                    stride=4 if i >= depth - 2 else 2,
                    activation=(i < depth - 1)  # no nonlinearity on the waveform output
                )
            )

    def forward(self, x):
        # x shape: [batch, samples]
        x = x.unsqueeze(1)  # [batch, 1, samples]

        # Encoding
        skips = []
        for encoder in self.encoder:
            x = encoder(x)
            skips.append(x)

        # Reshape for LSTM/Transformer
        b, c, t = x.shape
        x = x.transpose(1, 2)  # [batch, time, channels]

        # Temporal modeling
        x, _ = self.lstm(x)
        x = self.transformer(x)

        # Reshape back
        x = x.transpose(1, 2)  # [batch, channels, time]

        # Decoding with skip connections
        for i, decoder in enumerate(self.decoder):
            if i > 0:
                x = x + skips[-(i + 1)]  # Additive skip connection from the encoder
            x = decoder(x)

        # Split into sources
        return x.view(b, self.sources, -1)  # [batch, sources, samples]
```
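The `EncoderBlock` and `DecoderBlock` helpers referenced above are not defined in the listing. A minimal sketch is given below, assuming plain strided 1-D convolutions with ReLU (the real Demucs uses GLU-gated convolution blocks); the padding choice keeps every layer's length exactly divisible by the stride, so the input should be a multiple of the total stride (256 samples for the configuration above). The `activation` flag is an illustrative addition so the final block can emit a raw, signed waveform.

```python
class EncoderBlock(nn.Module):
    """Strided 1-D convolution + ReLU; downsamples time by `stride` (sketch)."""
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        # padding = (kernel - stride) / 2 makes output length = input length / stride
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=(kernel_size - stride) // 2)

    def forward(self, x):
        return torch.relu(self.conv(x))


class DecoderBlock(nn.Module):
    """Transposed 1-D convolution that exactly undoes EncoderBlock's downsampling (sketch)."""
    def __init__(self, in_channels, out_channels, kernel_size, stride, activation=True):
        super().__init__()
        self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size,
                                       stride=stride, padding=(kernel_size - stride) // 2)
        self.activation = activation  # disabled on the last block: waveforms are signed

    def forward(self, x):
        x = self.conv(x)
        return torch.relu(x) if self.activation else x
```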
Transformers: Attention is All You Need for Music
The self-attention mechanism of Transformers excels at modeling long-range dependencies in music, crucial for understanding musical structure and harmony.
Self-Attention Mechanism
Attention(Q, K, V) = softmax(QKᵀ / √d_k) V
Q: Queries, K: Keys, V: Values, d_k: Key dimension
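To make the formula concrete, here is a minimal single-head scaled dot-product attention in PyTorch (a from-scratch sketch for illustration; in practice `nn.MultiheadAttention` handles the heads, projections, and masking):

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    """Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V for [batch, seq, d_k] inputs."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # [batch, seq, seq] similarities
    weights = torch.softmax(scores, dim=-1)             # each row sums to 1
    return weights @ V                                   # weighted sum of the values

# Toy self-attention over 100 frames of 64-dim features: Q = K = V = x
x = torch.randn(1, 100, 64)
out = scaled_dot_product_attention(x, x, x)
print(out.shape)  # torch.Size([1, 100, 64])
```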
```python
class MusicTransformer(nn.Module):
    def __init__(self, d_model=512, nhead=8, num_layers=6, vocab_size=128):
        super().__init__()

        # Spectrogram encoder
        self.spec_encoder = nn.Sequential(
            nn.Conv2d(1, d_model // 4, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(d_model // 4, d_model // 2, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(d_model // 2, d_model, kernel_size=3, padding=1)
        )

        # Positional encoding
        self.pos_encoding = PositionalEncoding(d_model)

        # Transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=0.1,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

        # Output heads for multi-task learning
        self.pitch_head = nn.Linear(d_model, 88)      # 88 piano keys
        self.onset_head = nn.Linear(d_model, 88)
        self.velocity_head = nn.Linear(d_model, 128)  # MIDI velocity

    def forward(self, spectrogram):
        # Input: [batch, time, freq]
        x = self.spec_encoder(spectrogram.unsqueeze(1))  # [batch, d_model, time, freq]

        # Pool over the frequency axis so each time frame becomes one token
        x = x.mean(dim=3).transpose(1, 2)  # [batch, time, d_model]

        # Add positional encoding
        x = self.pos_encoding(x)

        # Apply transformer
        x = self.transformer(x)

        # Multi-task outputs, one prediction per time frame
        pitches = torch.sigmoid(self.pitch_head(x))
        onsets = torch.sigmoid(self.onset_head(x))
        velocities = self.velocity_head(x)

        return {
            'pitches': pitches,        # [batch, time, 88]
            'onsets': onsets,          # [batch, time, 88]
            'velocities': velocities   # [batch, time, 128]
        }


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()

        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(torch.log(torch.tensor(10000.0)) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
```
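A hedged inference sketch follows: it assumes a trained `MusicTransformer` and a magnitude spectrogram, and turns the frame-wise probabilities into (frame, MIDI pitch, velocity) events with a simple threshold. The `transcribe` helper, the 0.5 threshold, and the random input are illustrative choices, not part of the original model.

```python
def transcribe(model, spectrogram, threshold=0.5):
    """Convert frame-wise predictions into a list of (frame, midi_pitch, velocity)."""
    model.eval()
    with torch.no_grad():
        out = model(spectrogram.unsqueeze(0))       # add batch dimension

    onsets = out['onsets'][0] > threshold           # [time, 88] boolean onset map
    velocities = out['velocities'][0].argmax(-1)    # [time] most likely MIDI velocity

    notes = []
    for frame, key in onsets.nonzero(as_tuple=False).tolist():
        notes.append((frame, key + 21, int(velocities[frame])))  # key 0 -> MIDI 21 (A0)
    return notes

# Stand-in input: 200 frames x 229 frequency bins of random data
model = MusicTransformer()
notes = transcribe(model, torch.randn(200, 229))
```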
Training Strategies and Loss Functions
• Time-domain loss: Direct L1/L2 error between predicted and reference waveforms
• Multi-resolution STFT loss: Evaluates spectral error at multiple time-frequency resolutions
• Perceptual loss: Uses pre-trained networks to match perceptual features
In practice these terms are combined, for example waveform L1 plus a multi-resolution STFT loss, as shown after the listing below.
```python
class MultiResolutionSTFTLoss(nn.Module):
    def __init__(self, fft_sizes=[512, 1024, 2048],
                 hop_sizes=[50, 120, 240],
                 win_lengths=[240, 600, 1200]):
        super().__init__()
        self.fft_sizes = fft_sizes
        self.hop_sizes = hop_sizes
        self.win_lengths = win_lengths

    def forward(self, y_true, y_pred):
        loss = 0.0

        for fft_size, hop_size, win_length in zip(
            self.fft_sizes, self.hop_sizes, self.win_lengths
        ):
            window = torch.hann_window(win_length, device=y_true.device)

            # Compute STFTs of the target and the prediction
            Y_true = torch.stft(
                y_true,
                n_fft=fft_size,
                hop_length=hop_size,
                win_length=win_length,
                window=window,
                return_complex=True
            )
            Y_pred = torch.stft(
                y_pred,
                n_fft=fft_size,
                hop_length=hop_size,
                win_length=win_length,
                window=window,
                return_complex=True
            )

            # Magnitude loss
            mag_true = torch.abs(Y_true)
            mag_pred = torch.abs(Y_pred)
            loss += F.l1_loss(mag_pred, mag_true)

            # Log-magnitude loss
            log_mag_true = torch.log(mag_true + 1e-7)
            log_mag_pred = torch.log(mag_pred + 1e-7)
            loss += F.l1_loss(log_mag_pred, log_mag_true)

            # Phase loss (optional), compared via sin/cos to avoid wrap-around
            phase_true = torch.angle(Y_true)
            phase_pred = torch.angle(Y_pred)
            phase_loss = F.l1_loss(
                torch.sin(phase_pred), torch.sin(phase_true)
            ) + F.l1_loss(
                torch.cos(phase_pred), torch.cos(phase_true)
            )
            loss += 0.1 * phase_loss

        return loss / len(self.fft_sizes)
```
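A minimal combined objective, assuming the imports and the `MultiResolutionSTFTLoss` class above; the 0.5 weighting is an illustrative choice, not a recommended value:

```python
stft_loss = MultiResolutionSTFTLoss()

def separation_loss(y_pred, y_true, spectral_weight=0.5):
    """Time-domain L1 plus multi-resolution spectral loss (illustrative weighting)."""
    time_loss = F.l1_loss(y_pred, y_true)
    spec_loss = stft_loss(y_true, y_pred)
    return time_loss + spectral_weight * spec_loss

# Example: compare a slightly noisy estimate against the ground-truth stems
y_true = torch.randn(4, 44100)                      # batch of 1-second mono stems at 44.1 kHz
y_pred = y_true + 0.05 * torch.randn_like(y_true)
print(separation_loss(y_pred, y_true))
```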
Comparison with Classical Methods
| Aspect | ICA/NMF | Deep Learning |
|---|---|---|
| Assumptions | Strong (independence, linearity) | Minimal (learned from data) |
| Training data | Unsupervised | Requires large labeled datasets |
| Performance | Good for simple cases | State-of-the-art |
| Interpretability | High | Low (black box) |
| Computational cost | Low | High (GPU typically required) |
| Generalization | Limited | Strong, given diverse training data |
Real-World Performance Metrics
Signal-to-Distortion Ratio (SDR) in dB:
| Source | SDR (dB) |
|---|---|
| Vocals | 7.86 |
| Drums | 8.23 |
| Bass | 7.01 |
| Other | 5.42 |
Demucs v4 (Hybrid Transformer) results
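SDR compares the energy of the reference signal to the energy of the residual error. A minimal per-track computation (ignoring the permutation search and distortion filtering of the full BSS Eval / museval protocol) might look like this:

```python
import torch

def sdr(reference, estimate, eps=1e-8):
    """Signal-to-Distortion Ratio in dB: 10 * log10(||s||^2 / ||s - s_hat||^2)."""
    signal_power = torch.sum(reference ** 2)
    error_power = torch.sum((reference - estimate) ** 2)
    return 10 * torch.log10(signal_power / (error_power + eps))

# Example: a mildly corrupted copy of a reference stem scores well above 0 dB
reference = torch.randn(44100)
estimate = reference + 0.1 * torch.randn(44100)
print(f"SDR: {sdr(reference, estimate):.2f} dB")
```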
Conclusion
Deep learning has revolutionized music information retrieval, with U-Net and Transformer architectures achieving unprecedented performance. The ability to learn complex, non-linear relationships directly from data has overcome the fundamental limitations of classical statistical methods.
As we move forward, the combination of these powerful architectures with WebAssembly enables us to bring state-of-the-art music AI directly to the browser, democratizing access to professional-grade audio tools.