Building Your Own Music Separation Pipeline: A Developer's Guide
Transform theoretical knowledge into production-ready code. Build, evaluate, and deploy a complete music source separation system using modern tools and best practices.
From Theory to Practice
You've learned the mathematics, studied the architectures, and understood the trade-offs. Now it's time to build something real. This comprehensive guide will take you through creating a production-ready music source separation pipeline from scratch.
By the end of this tutorial, you'll have a complete system that can separate vocals from any song, with proper evaluation metrics, error handling, and deployment considerations. We'll use Python, PyTorch, and modern MLOps practices to create something you can actually ship.
🎯 What We'll Build
Core Components
- Data preprocessing pipeline
- Neural network training framework
- Evaluation and metrics system
- Real-time inference API
Production Features
- Docker containerization
- Model versioning and rollback
- Monitoring and logging
- Performance optimization
Project Setup and Architecture
music-separation-pipeline/
├── src/
│   ├── data/
│   │   ├── __init__.py
│   │   ├── dataset.py           # Dataset classes and loaders
│   │   ├── preprocessing.py     # Audio preprocessing utilities
│   │   └── augmentation.py      # Data augmentation techniques
│   ├── models/
│   │   ├── __init__.py
│   │   ├── unet.py              # U-Net architecture
│   │   ├── demucs.py            # Demucs implementation
│   │   └── base.py              # Base model class
│   ├── training/
│   │   ├── __init__.py
│   │   ├── trainer.py           # Training loop and utilities
│   │   └── losses.py            # Loss functions
│   ├── inference/
│   │   ├── __init__.py
│   │   ├── separator.py         # Inference engine
│   │   └── api.py               # REST API server
│   └── evaluation/
│       ├── __init__.py
│       ├── metrics.py           # SDR, SIR, SAR calculations
│       └── benchmark.py         # Evaluation scripts
├── configs/
│   ├── model_config.yaml        # Model hyperparameters
│   ├── training_config.yaml     # Training settings
│   └── api_config.yaml          # API configuration
├── docker/
│   ├── Dockerfile
│   └── docker-compose.yml
├── tests/
├── notebooks/                   # Jupyter notebooks for exploration
├── requirements.txt
├── setup.py
└── README.md
Design Principles
- Modular architecture for easy extension
- Configuration-driven development (see the sample config after this list)
- Clear separation of concerns
- Comprehensive testing coverage
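To make "configuration-driven" concrete, here is a minimal sketch of what configs/training_config.yaml could look like. The keys mirror what train.py and the trainer read later in this guide; the values (paths, batch size, learning rate) are illustrative placeholders, not recommendations.
# configs/training_config.yaml (illustrative values)
data_dir: /data/musdb18          # placeholder path to your dataset
checkpoint_dir: checkpoints/
batch_size: 8
num_workers: 4
num_epochs: 100
learning_rate: 0.001
weight_decay: 1.0e-5
grad_clip: 5.0
lr_patience: 5
save_every: 5
use_wandb: false
dataset:
  sample_rate: 44100
  segment_length: 6.0
  normalize: true
model:
  n_fft: 2048
  hop_length: 512
  n_sources: 2
  base_channels: 64
  n_layers: 4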
Key Dependencies
- PyTorch: Deep learning framework
- torchaudio: Audio processing
- librosa: Audio analysis
- FastAPI: High-performance web API
# Create virtual environment
python -m venv music-separation-env
source music-separation-env/bin/activate # On Windows: music-separation-env\Scripts\activate
# Install dependencies
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install librosa soundfile numpy scipy
pip install fastapi uvicorn pydantic
pip install wandb tensorboard # For experiment tracking
pip install pytest black flake8 # Development tools
# Install project in development mode
pip install -e .
# Verify installation
python -c "import torch; print(f'PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}')"
python -c "import librosa; print(f'librosa {librosa.__version__}')"
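For `pip install -e .` to work, the setup.py listed in the project layout needs to exist. A minimal sketch is below; the package name and version are placeholders, the dependency list just mirrors the installs above, and it assumes src/ contains an __init__.py so it is importable as a package (matching the `from src...` imports used throughout this guide).
# setup.py - minimal packaging stub (illustrative)
from setuptools import setup, find_packages

setup(
    name="music-separation-pipeline",  # placeholder name
    version="0.1.0",
    packages=find_packages(),          # picks up src/ and its subpackages
    install_requires=[
        "torch",
        "torchaudio",
        "librosa",
        "soundfile",
        "numpy",
        "scipy",
        "fastapi",
        "uvicorn",
        "pydantic",
    ],
)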
Data Pipeline: Processing Audio for ML
# src/data/dataset.py
import torch
from torch.utils.data import Dataset
import torchaudio
import librosa
import numpy as np
from pathlib import Path
from typing import Tuple, Optional
class MusicSeparationDataset(Dataset):
"""
Dataset for music source separation training.
Expects directory structure:
dataset/
├── train/
│ ├── mixture/
│ │ ├── song1.wav
│ │ └── song2.wav
│ ├── vocals/
│ │ ├── song1.wav
│ │ └── song2.wav
│ └── accompaniment/
│ ├── song1.wav
│ └── song2.wav
"""
def __init__(
self,
data_dir: str,
split: str = "train",
sample_rate: int = 44100,
segment_length: float = 6.0, # seconds
normalize: bool = True,
transforms=None
):
self.data_dir = Path(data_dir) / split
self.sample_rate = sample_rate
self.segment_samples = int(segment_length * sample_rate)
self.normalize = normalize
self.transforms = transforms
# Find all mixture files
self.mixture_files = list((self.data_dir / "mixture").glob("*.wav"))
self.mixture_files.sort()
# Verify corresponding files exist
self._verify_files()
def _verify_files(self):
"""Ensure all stems exist for each mixture"""
verified_files = []
for mix_file in self.mixture_files:
stem_name = mix_file.stem
vocals_path = self.data_dir / "vocals" / f"{stem_name}.wav"
accomp_path = self.data_dir / "accompaniment" / f"{stem_name}.wav"
if vocals_path.exists() and accomp_path.exists():
verified_files.append(mix_file)
else:
print(f"Warning: Missing stems for {stem_name}")
self.mixture_files = verified_files
print(f"Dataset loaded: {len(self.mixture_files)} songs")
def __len__(self) -> int:
return len(self.mixture_files)
def __getitem__(self, idx: int) -> dict:
mix_file = self.mixture_files[idx]
stem_name = mix_file.stem
# Load audio files
mixture, sr = torchaudio.load(mix_file)
vocals, _ = torchaudio.load(
self.data_dir / "vocals" / f"{stem_name}.wav"
)
accompaniment, _ = torchaudio.load(
self.data_dir / "accompaniment" / f"{stem_name}.wav"
)
# Resample if necessary
if sr != self.sample_rate:
resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
mixture = resampler(mixture)
vocals = resampler(vocals)
accompaniment = resampler(accompaniment)
# Convert to mono if stereo
if mixture.shape[0] > 1:
mixture = torch.mean(mixture, dim=0, keepdim=True)
vocals = torch.mean(vocals, dim=0, keepdim=True)
accompaniment = torch.mean(accompaniment, dim=0, keepdim=True)
# Extract random segment
audio_length = mixture.shape[1]
if audio_length > self.segment_samples:
start = torch.randint(0, audio_length - self.segment_samples, (1,))
mixture = mixture[:, start:start + self.segment_samples]
vocals = vocals[:, start:start + self.segment_samples]
accompaniment = accompaniment[:, start:start + self.segment_samples]
else:
# Pad if too short
pad_amount = self.segment_samples - audio_length
mixture = torch.nn.functional.pad(mixture, (0, pad_amount))
vocals = torch.nn.functional.pad(vocals, (0, pad_amount))
accompaniment = torch.nn.functional.pad(accompaniment, (0, pad_amount))
# Normalize
if self.normalize:
max_val = torch.max(torch.abs(mixture))
if max_val > 0:
mixture = mixture / max_val
vocals = vocals / max_val
accompaniment = accompaniment / max_val
# Apply transforms
if self.transforms:
mixture = self.transforms(mixture)
vocals = self.transforms(vocals)
accompaniment = self.transforms(accompaniment)
return {
"mixture": mixture.squeeze(0), # [samples]
"vocals": vocals.squeeze(0),
"accompaniment": accompaniment.squeeze(0),
"filename": stem_name
        }
# src/data/preprocessing.py
import torch
import torchaudio
import librosa
import numpy as np
from typing import Optional, Tuple
class AudioPreprocessor:
"""Handles STFT and spectrogram operations"""
def __init__(
self,
n_fft: int = 2048,
hop_length: int = 512,
win_length: Optional[int] = None,
window: str = "hann",
sample_rate: int = 44100
):
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length or n_fft
self.window = window
self.sample_rate = sample_rate
# Create STFT transform
self.stft = torchaudio.transforms.Spectrogram(
n_fft=n_fft,
hop_length=hop_length,
win_length=self.win_length,
window_fn=torch.hann_window,
power=None # Return complex spectrogram
)
# Create inverse STFT
self.istft = torchaudio.transforms.InverseSpectrogram(
n_fft=n_fft,
hop_length=hop_length,
win_length=self.win_length,
window_fn=torch.hann_window
)
def waveform_to_spectrogram(self, waveform: torch.Tensor) -> torch.Tensor:
"""Convert waveform to complex spectrogram"""
# waveform: [channels, samples] or [samples]
if waveform.dim() == 1:
waveform = waveform.unsqueeze(0)
# Apply STFT
        spec = self.stft(waveform)  # [channels, freq, time], complex-valued
return spec
def spectrogram_to_waveform(self, spectrogram: torch.Tensor) -> torch.Tensor:
"""Convert complex spectrogram back to waveform"""
# spectrogram: [channels, freq, time, 2]
waveform = self.istft(spectrogram)
return waveform
    def magnitude_phase_split(self, complex_spec: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split complex spectrogram into magnitude and phase"""
        # complex_spec: [channels, freq, time], complex-valued
        magnitude = torch.abs(complex_spec)
        phase = torch.angle(complex_spec)
        return magnitude, phase
    def magnitude_phase_combine(self, magnitude: torch.Tensor, phase: torch.Tensor) -> torch.Tensor:
        """Combine magnitude and phase into a complex spectrogram"""
        return torch.polar(magnitude, phase)
def create_masks(self, mixture_mag: torch.Tensor, source_mags: list) -> list:
"""Create soft masks for source separation"""
# Add small epsilon to avoid division by zero
eps = 1e-8
total_magnitude = sum(source_mags) + eps
masks = []
for source_mag in source_mags:
mask = source_mag / total_magnitude
masks.append(mask)
return masks
# Data augmentation transforms
class AudioAugmentation:
"""Audio data augmentation techniques"""
@staticmethod
def add_noise(waveform: torch.Tensor, noise_level: float = 0.01) -> torch.Tensor:
"""Add Gaussian noise to waveform"""
noise = torch.randn_like(waveform) * noise_level
return waveform + noise
@staticmethod
def time_stretch(waveform: torch.Tensor, rate: float, sample_rate: int) -> torch.Tensor:
"""Time stretching using librosa"""
waveform_np = waveform.numpy()
stretched = librosa.effects.time_stretch(waveform_np, rate=rate)
return torch.from_numpy(stretched)
@staticmethod
def pitch_shift(waveform: torch.Tensor, n_steps: int, sample_rate: int) -> torch.Tensor:
"""Pitch shifting using librosa"""
waveform_np = waveform.numpy()
shifted = librosa.effects.pitch_shift(waveform_np, sr=sample_rate, n_steps=n_steps)
        return torch.from_numpy(shifted)
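A quick, self-contained sanity check for the preprocessing utilities above: run an STFT round trip on random audio and confirm the reconstruction error is tiny. The script name is hypothetical, and the printed error should be near numerical precision rather than exactly zero because the inverse transform trims a partial final frame.
# round_trip_check.py - verify STFT <-> ISTFT consistency (illustrative)
import torch
from src.data.preprocessing import AudioPreprocessor

pre = AudioPreprocessor(n_fft=2048, hop_length=512)
waveform = torch.randn(1, 44100)  # one second of random "audio"

spec = pre.waveform_to_spectrogram(waveform)            # complex [1, freq, time]
magnitude, phase = pre.magnitude_phase_split(spec)
reconstructed = pre.spectrogram_to_waveform(pre.magnitude_phase_combine(magnitude, phase))

# Compare over the overlapping region (the ISTFT drops a partial last frame)
n = min(waveform.shape[-1], reconstructed.shape[-1])
error = (waveform[..., :n] - reconstructed[..., :n]).abs().max()
print(f"Max round-trip error: {error:.2e}")  # expect a very small value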
Model Implementation: U-Net for Source Separation
# src/models/unet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
class ConvBlock(nn.Module):
"""Basic convolutional block with normalization and activation"""
def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
def forward(self, x):
return self.conv(x)
class DownBlock(nn.Module):
"""Encoder block with max pooling"""
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.conv = ConvBlock(in_channels, out_channels)
self.pool = nn.MaxPool2d(2)
def forward(self, x):
skip = self.conv(x)
x = self.pool(skip)
return x, skip
class UpBlock(nn.Module):
"""Decoder block with skip connections"""
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, 2, stride=2)
self.conv = ConvBlock(in_channels, out_channels)
def forward(self, x, skip):
x = self.up(x)
# Handle size mismatch
diffY = skip.size()[2] - x.size()[2]
diffX = skip.size()[3] - x.size()[3]
x = F.pad(x, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
x = torch.cat([skip, x], dim=1)
return self.conv(x)
class UNet(nn.Module):
"""U-Net for music source separation"""
def __init__(
self,
in_channels: int = 1,
n_sources: int = 2,
base_channels: int = 64,
n_layers: int = 4
):
super().__init__()
self.n_sources = n_sources
# Encoder path
self.encoder = nn.ModuleList()
channels = [in_channels] + [base_channels * (2**i) for i in range(n_layers)]
for i in range(n_layers):
self.encoder.append(DownBlock(channels[i], channels[i+1]))
# Bottleneck
self.bottleneck = ConvBlock(channels[-1], channels[-1] * 2)
        # Decoder path
        # Channel progression mirrors the encoder, e.g. [1024, 512, 256, 128, 64]
        # for the defaults, so each UpBlock halves its channels and then doubles
        # them again when concatenating the matching skip connection.
        self.decoder = nn.ModuleList()
        decoder_channels = [channels[-1] * 2] + list(reversed(channels[1:]))
        for i in range(n_layers):
            self.decoder.append(UpBlock(decoder_channels[i], decoder_channels[i + 1]))
        # Final output layers - separate head for each source
        self.output_layers = nn.ModuleList([
            nn.Conv2d(decoder_channels[-1], 1, 1) for _ in range(n_sources)
        ])
# Activation for masks
self.activation = nn.Sigmoid()
def forward(self, x):
# x: [batch, 1, freq, time] - magnitude spectrogram
# Encoder path
skips = []
for encoder in self.encoder:
x, skip = encoder(x)
skips.append(skip)
# Bottleneck
x = self.bottleneck(x)
# Decoder path
skips.reverse()
for decoder, skip in zip(self.decoder, skips):
x = decoder(x, skip)
# Generate masks for each source
masks = []
for output_layer in self.output_layers:
mask = self.activation(output_layer(x))
masks.append(mask)
# Stack masks: [batch, n_sources, freq, time]
masks = torch.cat(masks, dim=1)
# Normalize masks to sum to 1
masks = masks / (torch.sum(masks, dim=1, keepdim=True) + 1e-8)
return masks
class MagnitudeUNet(nn.Module):
"""Complete model for magnitude spectrogram separation"""
def __init__(
self,
n_fft: int = 2048,
hop_length: int = 512,
n_sources: int = 2,
**unet_kwargs
):
super().__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.n_sources = n_sources
# STFT parameters
self.stft = torchaudio.transforms.Spectrogram(
n_fft=n_fft,
hop_length=hop_length,
power=None # Complex spectrogram
)
self.istft = torchaudio.transforms.InverseSpectrogram(
n_fft=n_fft,
hop_length=hop_length
)
# U-Net for mask prediction
self.unet = UNet(in_channels=1, n_sources=n_sources, **unet_kwargs)
    def forward(self, waveform):
        # waveform: [batch, samples]
        # Convert to a complex spectrogram: [batch, freq, time]
        complex_spec = self.stft(waveform)
        magnitude = torch.abs(complex_spec)
        phase = torch.angle(complex_spec)
        # Predict masks: [batch, n_sources, freq, time]
        masks = self.unet(magnitude.unsqueeze(1))
        # Apply masks to the mixture magnitude
        separated_mags = masks * magnitude.unsqueeze(1)  # [batch, n_sources, freq, time]
        # Reconstruct each source with the original mixture phase
        separated_waveforms = []
        for i in range(self.n_sources):
            source_spec = torch.polar(separated_mags[:, i], phase)  # complex [batch, freq, time]
            separated_waveforms.append(self.istft(source_spec, length=waveform.size(-1)))
        return torch.stack(separated_waveforms, dim=1)  # [batch, n_sources, samples]
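Before wiring up training, a shape smoke test helps catch channel or size mismatches early. This is a sketch with a hypothetical script name, run on CPU with random input and a deliberately small base_channels to keep it fast.
# smoke_test_model.py - check that the model produces the expected shapes (illustrative)
import torch
from src.models.unet import MagnitudeUNet

model = MagnitudeUNet(n_fft=2048, hop_length=512, n_sources=2, base_channels=16, n_layers=4)
mixture = torch.randn(2, 44100 * 3)  # batch of 2, three seconds of random audio each

with torch.no_grad():
    separated = model(mixture)

print(separated.shape)  # expected: torch.Size([2, 2, 132300]) -> [batch, n_sources, samples]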
Training Pipeline: From Data to Model
# src/training/trainer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from pathlib import Path
import wandb
from tqdm import tqdm
class SourceSeparationTrainer:
"""Training pipeline for source separation models"""
def __init__(
self,
model: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
config: dict,
device: str = "cuda"
):
self.model = model.to(device)
self.train_loader = train_loader
self.val_loader = val_loader
self.config = config
self.device = device
# Training setup
self.optimizer = optim.Adam(
model.parameters(),
lr=config["learning_rate"],
weight_decay=config.get("weight_decay", 1e-5)
)
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer,
mode="min",
patience=config.get("lr_patience", 5),
factor=0.5
)
self.criterion = nn.L1Loss() # Mean Absolute Error
# Tracking
self.best_val_loss = float("inf")
self.train_losses = []
self.val_losses = []
# Experiment tracking
if config.get("use_wandb", False):
wandb.init(project="music-separation", config=config)
def train_epoch(self, epoch: int) -> float:
"""Train for one epoch"""
self.model.train()
total_loss = 0.0
pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
for batch_idx, batch in enumerate(pbar):
# Move data to device
mixture = batch["mixture"].to(self.device)
vocals = batch["vocals"].to(self.device)
accompaniment = batch["accompaniment"].to(self.device)
# Target: [batch, n_sources, samples]
target = torch.stack([vocals, accompaniment], dim=1)
# Forward pass
self.optimizer.zero_grad()
output = self.model(mixture) # [batch, n_sources, samples]
# Compute loss
loss = self.criterion(output, target)
# Backward pass
loss.backward()
# Gradient clipping
if self.config.get("grad_clip", False):
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.config["grad_clip"]
)
self.optimizer.step()
# Update metrics
total_loss += loss.item()
# Update progress bar
pbar.set_postfix({
"loss": f"{loss.item():.4f}",
"avg_loss": f"{total_loss/(batch_idx+1):.4f}"
})
# Log to wandb
if self.config.get("use_wandb", False):
wandb.log({
"train/batch_loss": loss.item(),
"train/learning_rate": self.optimizer.param_groups[0]["lr"]
})
avg_loss = total_loss / len(self.train_loader)
return avg_loss
def validate(self, epoch: int) -> float:
"""Validate the model"""
self.model.eval()
total_loss = 0.0
with torch.no_grad():
for batch in tqdm(self.val_loader, desc="Validation"):
# Move data to device
mixture = batch["mixture"].to(self.device)
vocals = batch["vocals"].to(self.device)
accompaniment = batch["accompaniment"].to(self.device)
target = torch.stack([vocals, accompaniment], dim=1)
# Forward pass
output = self.model(mixture)
loss = self.criterion(output, target)
total_loss += loss.item()
avg_loss = total_loss / len(self.val_loader)
# Log validation metrics
if self.config.get("use_wandb", False):
wandb.log({
"val/loss": avg_loss,
"epoch": epoch
})
return avg_loss
def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
"""Save model checkpoint"""
checkpoint = {
"epoch": epoch,
"model_state_dict": self.model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"scheduler_state_dict": self.scheduler.state_dict(),
"val_loss": val_loss,
"config": self.config,
"train_losses": self.train_losses,
"val_losses": self.val_losses
}
# Save latest checkpoint
checkpoint_dir = Path(self.config["checkpoint_dir"])
checkpoint_dir.mkdir(parents=True, exist_ok=True)
torch.save(checkpoint, checkpoint_dir / "latest_checkpoint.pt")
# Save best model
if is_best:
torch.save(checkpoint, checkpoint_dir / "best_model.pt")
print(f"Saved new best model with val_loss: {val_loss:.4f}")
def train(self, num_epochs: int):
"""Complete training loop"""
print(f"Starting training for {num_epochs} epochs...")
for epoch in range(1, num_epochs + 1):
# Training phase
train_loss = self.train_epoch(epoch)
self.train_losses.append(train_loss)
# Validation phase
val_loss = self.validate(epoch)
self.val_losses.append(val_loss)
# Learning rate scheduling
self.scheduler.step(val_loss)
# Check for best model
is_best = val_loss < self.best_val_loss
if is_best:
self.best_val_loss = val_loss
# Save checkpoint
if epoch % self.config.get("save_every", 5) == 0:
self.save_checkpoint(epoch, val_loss, is_best)
print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, "
f"Val Loss: {val_loss:.4f}")
print("Training completed!")
return self.train_losses, self.val_losses
# Loss functions
class MultiScaleLoss(nn.Module):
"""Multi-scale L1 loss for better detail preservation"""
def __init__(self, scales: list = [2048, 1024, 512]):
super().__init__()
self.scales = scales
self.criterion = nn.L1Loss()
def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
total_loss = 0.0
# Original scale loss
total_loss += self.criterion(prediction, target)
# Multi-scale losses
for scale in self.scales:
# Average pooling to create different scales
pred_scaled = F.avg_pool1d(prediction, scale, stride=scale//2)
target_scaled = F.avg_pool1d(target, scale, stride=scale//2)
scale_loss = self.criterion(pred_scaled, target_scaled)
total_loss += 0.3 * scale_loss # Weight for multi-scale terms
        return total_loss
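A minimal sketch for exercising MultiScaleLoss on random tensors before swapping it in for the trainer's default L1 criterion; it only confirms that shapes line up and that gradients flow, and the script name is hypothetical.
# loss_check.py - quick numerical check of MultiScaleLoss (illustrative)
import torch
from src.training.trainer import MultiScaleLoss

criterion = MultiScaleLoss(scales=[2048, 1024, 512])
prediction = torch.randn(4, 2, 44100, requires_grad=True)  # [batch, n_sources, samples]
target = torch.randn(4, 2, 44100)

loss = criterion(prediction, target)
loss.backward()  # confirms the loss is differentiable end to end
print(f"Multi-scale loss: {loss.item():.4f}")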
# train.py
import yaml
import torch
from torch.utils.data import DataLoader
from src.data.dataset import MusicSeparationDataset
from src.models.unet import MagnitudeUNet
from src.training.trainer import SourceSeparationTrainer
def main():
# Load configuration
with open("configs/training_config.yaml", "r") as f:
config = yaml.safe_load(f)
# Setup device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Create datasets
train_dataset = MusicSeparationDataset(
data_dir=config["data_dir"],
split="train",
**config["dataset"]
)
val_dataset = MusicSeparationDataset(
data_dir=config["data_dir"],
split="val",
**config["dataset"]
)
# Create data loaders
train_loader = DataLoader(
train_dataset,
batch_size=config["batch_size"],
shuffle=True,
num_workers=config["num_workers"],
pin_memory=True
)
val_loader = DataLoader(
val_dataset,
batch_size=config["batch_size"],
shuffle=False,
num_workers=config["num_workers"],
pin_memory=True
)
# Create model
model = MagnitudeUNet(**config["model"])
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
# Create trainer
trainer = SourceSeparationTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
config=config,
device=device
)
# Start training
trainer.train(config["num_epochs"])
if __name__ == "__main__":
    main()
Evaluation and Metrics
# src/evaluation/metrics.py
import torch
import numpy as np
from typing import Dict, List
import museval
from tqdm import tqdm
class SeparationMetrics:
"""Comprehensive evaluation metrics for source separation"""
@staticmethod
def sdr_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Signal-to-Distortion Ratio loss (differentiable)
Higher values = better separation
"""
# Ensure same length
min_len = min(prediction.size(-1), target.size(-1))
prediction = prediction[..., :min_len]
target = target[..., :min_len]
# Compute SDR
target_energy = torch.sum(target**2, dim=-1) + 1e-8
noise_energy = torch.sum((prediction - target)**2, dim=-1) + 1e-8
sdr = 10 * torch.log10(target_energy / noise_energy)
return -sdr.mean() # Negative for loss (we want to minimize)
@staticmethod
def compute_sdr_sir_sar(
prediction: np.ndarray,
target: np.ndarray,
mixture: np.ndarray
) -> Dict[str, float]:
"""
Compute SDR, SIR, SAR using museval library
Args:
prediction: Estimated source [samples]
target: True source [samples]
mixture: Original mixture [samples]
Returns:
Dictionary with SDR, SIR, SAR values
"""
# Ensure same length
min_len = min(len(prediction), len(target), len(mixture))
prediction = prediction[:min_len]
target = target[:min_len]
mixture = mixture[:min_len]
        # museval expects arrays shaped [sources, samples, channels]
        estimates = prediction[np.newaxis, :, np.newaxis]
        references = target[np.newaxis, :, np.newaxis]
        # museval.evaluate returns (SDR, ISR, SIR, SAR), each shaped [sources, windows]
        sdr, _isr, sir, sar = museval.evaluate(references, estimates)
        return {
            "SDR": float(np.nanmedian(sdr[0])),
            "SIR": float(np.nanmedian(sir[0])),
            "SAR": float(np.nanmedian(sar[0]))
        }
@staticmethod
def batch_evaluate(
model: torch.nn.Module,
dataloader,
device: str = "cuda"
) -> Dict[str, List[float]]:
"""Evaluate model on entire dataset"""
model.eval()
all_metrics = {"SDR": [], "SIR": [], "SAR": []}
with torch.no_grad():
for batch in tqdm(dataloader, desc="Evaluating"):
mixture = batch["mixture"].to(device)
vocals = batch["vocals"].cpu().numpy()
accompaniment = batch["accompaniment"].cpu().numpy()
# Get predictions
output = model(mixture) # [batch, 2, samples]
pred_vocals = output[:, 0].cpu().numpy()
pred_accompaniment = output[:, 1].cpu().numpy()
# Compute metrics for each sample in batch
for i in range(mixture.size(0)):
# Vocals metrics
vocals_metrics = SeparationMetrics.compute_sdr_sir_sar(
pred_vocals[i], vocals[i], mixture[i].cpu().numpy()
)
# Accompaniment metrics
accomp_metrics = SeparationMetrics.compute_sdr_sir_sar(
pred_accompaniment[i], accompaniment[i], mixture[i].cpu().numpy()
)
# Average metrics across sources
for metric in ["SDR", "SIR", "SAR"]:
avg_metric = (vocals_metrics[metric] + accomp_metrics[metric]) / 2
all_metrics[metric].append(avg_metric)
# Compute summary statistics
summary = {}
for metric, values in all_metrics.items():
summary[f"{metric}_mean"] = np.mean(values)
summary[f"{metric}_std"] = np.std(values)
summary[f"{metric}_median"] = np.median(values)
return summary, all_metrics
class PerceptualEvaluator:
"""Tools for perceptual quality assessment"""
@staticmethod
def spectral_convergence(prediction: torch.Tensor, target: torch.Tensor) -> float:
"""Measure spectral similarity"""
# Compute spectrograms
pred_spec = torch.stft(prediction, n_fft=2048, return_complex=True)
target_spec = torch.stft(target, n_fft=2048, return_complex=True)
# Magnitude spectrograms
pred_mag = torch.abs(pred_spec)
target_mag = torch.abs(target_spec)
# Spectral convergence
numerator = torch.norm(target_mag - pred_mag, p="fro")
denominator = torch.norm(target_mag, p="fro")
return (numerator / (denominator + 1e-8)).item()
@staticmethod
def magnitude_loss(prediction: torch.Tensor, target: torch.Tensor) -> float:
"""Log magnitude loss"""
pred_spec = torch.stft(prediction, n_fft=2048, return_complex=True)
target_spec = torch.stft(target, n_fft=2048, return_complex=True)
pred_mag = torch.abs(pred_spec)
target_mag = torch.abs(target_spec)
# Log magnitude
pred_log = torch.log(pred_mag + 1e-8)
target_log = torch.log(target_mag + 1e-8)
        return torch.nn.functional.l1_loss(pred_log, target_log).item()
# benchmark.py - Evaluate trained model
import torch
import yaml
import json
from pathlib import Path
from torch.utils.data import DataLoader
from src.data.dataset import MusicSeparationDataset
from src.models.unet import MagnitudeUNet
from src.evaluation.metrics import SeparationMetrics
def benchmark_model(model_path: str, config_path: str, data_dir: str):
"""Run comprehensive benchmark on trained model"""
# Load config and model
with open(config_path, "r") as f:
config = yaml.safe_load(f)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load trained model
model = MagnitudeUNet(**config["model"])
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(device)
# Create test dataset
test_dataset = MusicSeparationDataset(
data_dir=data_dir,
split="test",
**config["dataset"]
)
test_loader = DataLoader(
test_dataset,
batch_size=1, # Process one song at a time for detailed analysis
shuffle=False,
num_workers=4
)
# Run evaluation
print("Running evaluation...")
summary, detailed_metrics = SeparationMetrics.batch_evaluate(
model, test_loader, device
)
# Print results
print("\n" + "="*50)
print("BENCHMARK RESULTS")
print("="*50)
for metric, value in summary.items():
print(f"{metric}: {value:.3f}")
# Save detailed results
results = {
"summary": summary,
"detailed_metrics": {k: [float(x) for x in v] for k, v in detailed_metrics.items()},
"model_path": model_path,
"config": config
}
results_path = Path(model_path).parent / "benchmark_results.json"
with open(results_path, "w") as f:
json.dump(results, f, indent=2)
print(f"\nDetailed results saved to: {results_path}")
return summary, detailed_metrics
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True, help="Path to trained model")
parser.add_argument("--config", required=True, help="Path to config file")
parser.add_argument("--data", required=True, help="Path to test data")
args = parser.parse_args()
    benchmark_model(args.model, args.config, args.data)
Deployment: Production-Ready API
# src/inference/api.py
import torch
import torchaudio
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
import io
import tempfile
import asyncio
from pathlib import Path
import logging
from typing import Optional
from src.models.unet import MagnitudeUNet
from src.inference.separator import AudioSeparator
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global model instance
separator: Optional[AudioSeparator] = None
app = FastAPI(
title="Music Source Separation API",
description="AI-powered vocal separation from music",
version="1.0.0"
)
@app.on_event("startup")
async def startup_event():
"""Load model on startup"""
global separator
try:
model_path = "models/best_model.pt"
separator = AudioSeparator(model_path)
logger.info("Model loaded successfully")
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"model_loaded": separator is not None,
"gpu_available": torch.cuda.is_available()
}
@app.post("/separate")
async def separate_audio(
audio_file: UploadFile = File(...),
output_format: str = "wav"
):
"""
Separate vocals from music
Args:
audio_file: Input audio file (wav, mp3, etc.)
output_format: Output format (wav, mp3)
Returns:
ZIP file containing separated stems
"""
if separator is None:
raise HTTPException(status_code=500, detail="Model not loaded")
# Validate file format
allowed_formats = {".wav", ".mp3", ".flac", ".m4a"}
file_suffix = Path(audio_file.filename).suffix.lower()
if file_suffix not in allowed_formats:
raise HTTPException(
status_code=400,
detail=f"Unsupported format: {file_suffix}"
)
try:
# Save uploaded file temporarily
with tempfile.NamedTemporaryFile(suffix=file_suffix, delete=False) as tmp_file:
content = await audio_file.read()
tmp_file.write(content)
tmp_path = tmp_file.name
# Perform separation
vocals, accompaniment = await asyncio.get_event_loop().run_in_executor(
None, separator.separate_file, tmp_path
)
# Create output streams
vocals_buffer = io.BytesIO()
accomp_buffer = io.BytesIO()
# Save separated audio
torchaudio.save(vocals_buffer, vocals, separator.sample_rate, format=output_format)
torchaudio.save(accomp_buffer, accompaniment, separator.sample_rate, format=output_format)
# Create ZIP response
import zipfile
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
vocals_buffer.seek(0)
accomp_buffer.seek(0)
zip_file.writestr(f"vocals.{output_format}", vocals_buffer.read())
zip_file.writestr(f"accompaniment.{output_format}", accomp_buffer.read())
zip_buffer.seek(0)
# Cleanup
Path(tmp_path).unlink()
return StreamingResponse(
io.BytesIO(zip_buffer.read()),
media_type="application/zip",
headers={"Content-Disposition": "attachment; filename=separated_audio.zip"}
)
except Exception as e:
logger.error(f"Separation failed: {e}")
raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")
@app.post("/separate/preview")
async def separate_preview(
audio_file: UploadFile = File(...),
duration: float = 30.0
):
"""
Quick preview separation (first N seconds)
"""
if separator is None:
raise HTTPException(status_code=500, detail="Model not loaded")
try:
# Process only first N seconds for preview
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
content = await audio_file.read()
tmp_file.write(content)
tmp_path = tmp_file.name
        # Load and trim audio
        waveform, sr = torchaudio.load(tmp_path)
        Path(tmp_path).unlink()  # temp file is no longer needed
        max_samples = int(duration * sr)
        if waveform.size(1) > max_samples:
            waveform = waveform[:, :max_samples]
# Perform separation
vocals, accompaniment = await asyncio.get_event_loop().run_in_executor(
None, separator.separate_tensor, waveform
)
# Return JSON with basic info
return {
"duration_processed": waveform.size(1) / sr,
"sample_rate": sr,
"vocals_energy": float(torch.mean(torch.abs(vocals))),
"accompaniment_energy": float(torch.mean(torch.abs(accompaniment))),
"separation_quality": "estimated_good" # Could implement quality metric
}
except Exception as e:
logger.error(f"Preview failed: {e}")
raise HTTPException(status_code=500, detail=f"Preview failed: {str(e)}")
if __name__ == "__main__":
import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
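Once the server is running, a client can exercise the endpoints with the requests library. A minimal sketch (the host, file path, and timeout below are placeholders):
# client_example.py - call the separation API (illustrative)
import requests

API_URL = "http://localhost:8000"

# Health check
print(requests.get(f"{API_URL}/health").json())

# Full separation: upload a song, save the returned ZIP of stems
with open("my_song.wav", "rb") as f:
    response = requests.post(
        f"{API_URL}/separate",
        files={"audio_file": ("my_song.wav", f, "audio/wav")},
        timeout=600,  # separation can take a while, especially on CPU
    )
response.raise_for_status()
with open("separated_audio.zip", "wb") as out:
    out.write(response.content)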
# Dockerfile
FROM python:3.9-slim
# Install system dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    curl \
    libsndfile1 \
    ffmpeg \
    && rm -rf /var/lib/apt/lists/*
# Set working directory
WORKDIR /app
# Copy requirements and install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy source code
COPY src/ src/
COPY configs/ configs/
COPY models/ models/
# Set environment variables
ENV PYTHONPATH=/app
ENV MODEL_PATH=/app/models/best_model.pt
# Expose port
EXPOSE 8000
# Health check
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# Run API server
CMD ["uvicorn", "src.inference.api:app", "--host", "0.0.0.0", "--port", "8000"]
# docker-compose.yml
version: '3.8'
services:
music-separation-api:
build: .
ports:
- "8000:8000"
volumes:
- ./models:/app/models
- ./logs:/app/logs
environment:
- MODEL_PATH=/app/models/best_model.pt
- LOG_LEVEL=INFO
restart: unless-stopped
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
nginx:
image: nginx:alpine
ports:
- "80:80"
volumes:
- ./nginx.conf:/etc/nginx/nginx.conf
depends_on:
- music-separation-api
restart: unless-stopped
# Build and run:
# docker-compose up --build -d
Production Best Practices
Model Versioning and Rollback
Implement proper model versioning with Blue-Green deployments. Always keep the previous working model available for instant rollback. Use semantic versioning for model releases.
Monitoring and Observability
Monitor inference latency, GPU utilization, and model accuracy drift. Use tools like Prometheus + Grafana for metrics, and implement proper logging with structured data for debugging.
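As one possible starting point, the prometheus_client package can expose request counters and latency histograms from the FastAPI app. This is a sketch with arbitrary metric names, shown on a standalone app for brevity; in practice you would attach the middleware and /metrics route to the app in src/inference/api.py.
# Sketch: basic Prometheus metrics for the API (metric names are arbitrary choices)
import time
from fastapi import FastAPI, Request, Response
from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST

app = FastAPI()
REQUEST_COUNT = Counter("separation_requests_total", "Total API requests", ["endpoint", "status"])
REQUEST_LATENCY = Histogram("separation_request_seconds", "Request latency in seconds", ["endpoint"])

@app.middleware("http")
async def record_metrics(request: Request, call_next):
    start = time.perf_counter()
    response = await call_next(request)
    REQUEST_LATENCY.labels(endpoint=request.url.path).observe(time.perf_counter() - start)
    REQUEST_COUNT.labels(endpoint=request.url.path, status=str(response.status_code)).inc()
    return response

@app.get("/metrics")
async def metrics():
    # Endpoint for Prometheus to scrape
    return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST)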
Resource Optimization
Use model quantization and ONNX Runtime for production inference. Implement request batching and GPU memory management. On NVIDIA GPUs, TensorRT can yield a further 2-5x speedup for supported models.
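A sketch of exporting the mask-predicting U-Net to ONNX. Exporting only the U-Net (rather than the full MagnitudeUNet) sidesteps the complex-valued STFT ops, which ONNX export handles poorly; the STFT/ISTFT would then run outside the ONNX graph. The file name, dummy shapes, and opset are illustrative choices.
# Sketch: export the mask-prediction U-Net to ONNX (illustrative shapes/opset)
import torch
from src.models.unet import UNet

model = UNet(in_channels=1, n_sources=2, base_channels=64, n_layers=4).eval()
dummy_spec = torch.randn(1, 1, 1025, 256)  # [batch, 1, freq_bins, time_frames] for n_fft=2048

torch.onnx.export(
    model,
    dummy_spec,
    "unet_masks.onnx",
    input_names=["magnitude"],
    output_names=["masks"],
    dynamic_axes={"magnitude": {0: "batch", 3: "time"}, "masks": {0: "batch", 3: "time"}},
    opset_version=17,
)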
Security and Rate Limiting
Implement proper authentication, input validation, and rate limiting. Use HTTPS in production and validate all audio inputs for malicious content. Consider implementing API keys for access control.
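A minimal API-key gate using FastAPI dependencies, as a sketch: the header name, environment-variable key storage, and endpoint wiring are placeholders, and a real deployment would use a secrets manager and a proper rate limiter in front of the service.
# Sketch: simple API-key check as a FastAPI dependency (header name and key storage are placeholders)
import os
from fastapi import Depends, Header, HTTPException

API_KEYS = set(filter(None, os.environ.get("API_KEYS", "").split(",")))  # e.g. API_KEYS="key1,key2"

async def require_api_key(x_api_key: str = Header(default="")):
    """Reject requests that do not carry a known X-API-Key header."""
    if x_api_key not in API_KEYS:
        raise HTTPException(status_code=401, detail="Invalid or missing API key")

# Attach to an endpoint, e.g.:
# @app.post("/separate", dependencies=[Depends(require_api_key)])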
Testing Strategy
Implement unit tests for data processing, integration tests for the API, and regression tests for model accuracy. Use golden datasets for consistent evaluation across model versions.
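For example, a couple of pytest sketches that run without a trained checkpoint: one checks that the U-Net's soft masks sum to one across sources, the other that the multi-scale loss produces a positive value on random tensors. The file name and tolerances are illustrative.
# tests/test_basics.py - illustrative unit tests (no trained model required)
import torch
from src.models.unet import UNet
from src.training.trainer import MultiScaleLoss

def test_masks_sum_to_one():
    model = UNet(in_channels=1, n_sources=2, base_channels=8, n_layers=2).eval()
    spec = torch.randn(1, 1, 513, 64).abs()
    with torch.no_grad():
        masks = model(spec)
    # Soft masks over the sources should sum to ~1 at every time-frequency bin
    assert torch.allclose(masks.sum(dim=1), torch.ones_like(masks.sum(dim=1)), atol=1e-4)

def test_multiscale_loss_is_positive():
    criterion = MultiScaleLoss(scales=[1024])
    pred = torch.randn(2, 2, 8192)
    target = torch.randn(2, 2, 8192)
    assert criterion(pred, target).item() > 0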
Scalability Considerations
Design for horizontal scaling with load balancers and multiple GPU instances. Implement proper queuing for long-running separation tasks. Consider using cloud services like AWS Batch or Kubernetes for orchestration.
Next Steps and Extensions
Model Improvements
- Implement Demucs waveform model
- Add 4-stem separation (drums, bass, vocals, other)
- Experiment with Transformer architectures
- Try diffusion models for generation
Production Features
- Real-time streaming separation
- Web interface for file uploads
- Batch processing capabilities
- Integration with cloud storage
Papers to Read
- "Music Source Separation in the Waveform Domain"
- "Conv-TasNet: Surpassing Ideal Time-Frequency Magnitude Masking for Speech Separation"
- "Hybrid Transformers for Music Source Separation"
Code Repositories
- facebook/demucs - Official Demucs implementation
- sigsep/open-unmix-pytorch - Open-Unmix baseline
- asteroid-team/asteroid - Toolkit for audio separation