API Reference

oobss

oobss public API.

__all__ module-attribute

__all__ = ['AuxIVA', 'ILRMA', 'OnlineAuxIVA', 'OnlineILRMA', 'OnlineISNMF', 'BatchRequest', 'StreamRequest', 'OnlineFrameRequest', 'SeparationOutput', 'SeparatorState', 'StreamingSeparatorState', 'load_yaml', 'save_yaml', 'JsonlLogger', 'log_steps_jsonl']

AuxIVA

Bases: BaseIterativeSeparator

Base class for auxiliary-function-based independent vector analysis.

Procedure
   input: {x_{f,t}}_{f=1..F, t=1..T}, iterations I
   initialize W_f = I_M
   for i = 1..I:
       y_{f,t} <- W_f x_{f,t}
       compute r_{k,t} and phi(r_{k,t}) (Gauss or Laplace)
       V_{k,f} <- (1/T) * sum_t phi(r_{k,t}) x_{f,t} x_{f,t}^H
       update w_{k,f} by IP1/IP2 for each k
   return y_{f,t}
Update Equations

Indices are \(k=1,\dots,K\) (source), \(m=1,\dots,M\) (channel), \(f=1,\dots,F\) (frequency), and \(t=1,\dots,T\) (time frame). In the determined case, \(K=M\).

Notation:

\[ \bm{x}_{f,t} \in \mathbb{C}^{M}, \quad \bm{y}_{f,t} = W_f \bm{x}_{f,t} \in \mathbb{C}^{M}, \quad y_{k,f,t} = \bm{w}_{k,f}^{\mathsf{H}}\bm{x}_{f,t} \]
\[ r_{k,t} = \left(\sum_{f=1}^{F}|y_{k,f,t}|^2\right)^{1/2}, \quad \varphi_{\mathrm{Laplace}}(r) = \frac{1}{\max(2r,\varepsilon)}, \quad \varphi_{\mathrm{Gauss}}(r) = \frac{1}{\max(r^2/F,\varepsilon)} \]

The weighted covariance for source \(k\) is:

\[ V_{k,f} = \frac{1}{T}\sum_{t=1}^{T} \varphi(r_{k,t})\,\bm{x}_{f,t}\bm{x}_{f,t}^{\mathsf{H}} \]
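The weights and the weighted covariance above can be sketched in NumPy as follows. This is a minimal illustration, not the library's implementation: the function name `weighted_covariance`, its `model` argument, and the default `eps` are assumptions made for this example, and the array layout follows the frame-first `(n_frame, n_freq, n_mic)` convention used on this page.

```python
import numpy as np


def weighted_covariance(X_tfm, Y_tfm, model="Laplace", eps=1e-10):
    """Return V with shape (n_src, n_freq, n_mic, n_mic)."""
    n_frame, n_freq, _ = X_tfm.shape
    # r[t, k]: norm of source k over all frequency bins of frame t.
    r = np.linalg.norm(Y_tfm, axis=1)
    if model == "Laplace":
        phi = 1.0 / np.maximum(2.0 * r, eps)
    elif model == "Gauss":
        phi = 1.0 / np.maximum(r**2 / n_freq, eps)
    else:
        raise ValueError(f"unknown source model: {model!r}")
    # V[k, f] = (1/T) * sum_t phi[t, k] * x[t, f] x[t, f]^H
    return np.einsum("tk,tfm,tfn->kfmn", phi, X_tfm, X_tfm.conj()) / n_frame


rng = np.random.default_rng(0)
X = rng.standard_normal((40, 5, 2)) + 1j * rng.standard_normal((40, 5, 2))
# With W_f = I the estimates equal the observations, so Y = X here.
V = weighted_covariance(X, X, model="Gauss")
print(V.shape)  # (2, 5, 2, 2)
```

Each `V[k, f]` is Hermitian by construction, which is what the IP1 normalization relies on.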

IP1 update for the \(k\)-th demixing row:

\[ \tilde{\bm{w}}_{k,f} = (W_f V_{k,f})^{-1}\bm{e}_k, \quad \bm{w}_{k,f} = \frac{\tilde{\bm{w}}_{k,f}} {\sqrt{ \tilde{\bm{w}}_{k,f}^{\mathsf{H}} V_{k,f} \tilde{\bm{w}}_{k,f} }} \]
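As a rough sketch of the IP1 rule above, the following NumPy function updates one demixing row for all frequency bins at once. The name `ip1_update_row` and the convention that row \(k\) of `W` stores \(\bm{w}_{k,f}^{\mathsf{H}}\) are assumptions for this example; the library's `IP1SpatialStrategy` may organize the computation differently.

```python
import numpy as np


def ip1_update_row(W, V_k, k):
    """Update row k of W (n_freq, n_src, n_src) given V_k (n_freq, n_src, n_src)."""
    n_freq, n_src, _ = W.shape
    # Right-hand side e_k, repeated for every frequency bin.
    e_k = np.zeros((n_freq, n_src, 1), dtype=complex)
    e_k[:, k, 0] = 1.0
    # w_tilde = (W_f V_{k,f})^{-1} e_k, solved without forming the inverse.
    w_tilde = np.linalg.solve(W @ V_k, e_k)[..., 0]
    # Normalize so that w^H V w = 1.
    scale = np.sqrt(np.einsum("fm,fmn,fn->f", w_tilde.conj(), V_k, w_tilde).real)
    w = w_tilde / scale[:, None]
    W_new = W.copy()
    W_new[:, k, :] = w.conj()  # rows of W_f hold w_{k,f}^H
    return W_new


rng = np.random.default_rng(0)
n_freq, n_src, n_frame = 4, 2, 50
A = rng.standard_normal((n_freq, n_src, n_frame)) + 1j * rng.standard_normal(
    (n_freq, n_src, n_frame)
)
V0 = A @ A.conj().swapaxes(-1, -2) / n_frame  # Hermitian positive definite
W0 = np.tile(np.eye(n_src, dtype=complex), (n_freq, 1, 1))
W1 = ip1_update_row(W0, V0, k=0)
```

After the update, the new row satisfies \(\bm{w}_{k,f}^{\mathsf{H}} V_{k,f} \bm{w}_{k,f} = 1\), and the untouched rows are orthogonal to \(V_{k,f}\bm{w}_{k,f}\).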
Common projection-back post-processing

AuxIVA, ILRMA, and their online variants use the same projection-back rule. With \(a_{k,f}\) denoting the \(k\)-th column of the estimated mixing matrix \(A_f\), and for a reference microphone \(m_{\mathrm{ref}}\):

\[ A_f = W_f^{-1}, \quad \hat{y}_{k,f,t} = a_{k,f}[m_{\mathrm{ref}}] y_{k,f,t} \]

The shared implementation lives in oobss.separators.utils.projection_back.
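The projection-back rule can be illustrated with a few lines of NumPy. This is a sketch under stated assumptions: `projection_back_sketch` is a hypothetical name and signature chosen for this example, not the signature of the library helper.

```python
import numpy as np


def projection_back_sketch(Y_tfk, W, ref_mic=0):
    """Rescale Y (n_frame, n_freq, n_src) using W (n_freq, n_src, n_mic)."""
    A = np.linalg.inv(W)      # mixing-matrix estimate, (n_freq, n_mic, n_src)
    scale = A[:, ref_mic, :]  # a_{k,f}[m_ref] for every f and k
    return Y_tfk * scale[None, :, :]


rng = np.random.default_rng(0)
n_frame, n_freq, n_mic = 30, 4, 2
X = rng.standard_normal((n_frame, n_freq, n_mic)) + 1j * rng.standard_normal(
    (n_frame, n_freq, n_mic)
)
W = 0.3 * rng.standard_normal((n_freq, n_mic, n_mic)) + np.eye(n_mic)
Y = np.einsum("fkm,tfm->tfk", W, X)  # y_{f,t} = W_f x_{f,t}
Z = projection_back_sketch(Y, W, ref_mic=0)
```

A useful sanity check in the determined case is that the rescaled sources sum back to the reference channel, i.e. `Z.sum(axis=-1)` matches `X[:, :, 0]`.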

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| observations | ndarray of shape (n_frame, n_freq, n_src) | |
| spatial | SpatialUpdateStrategy | Demixing-matrix strategy (e.g., IP1/IP2). |
| source | SourceModelStrategy | Source-model strategy (e.g., Gauss/Laplace). |
| covariance | CovarianceUpdateStrategy | Weighted covariance strategy. |
| estimated | ndarray of shape (n_frame, n_freq, n_src) | |
| source_model | ndarray of shape (n_frame, n_freq, n_src) | |
| demix_filter | ndarray of shape (n_freq, n_src, n_src) | |
| loss | list[float] | |

Examples:

Basic TF-domain usage:

   import numpy as np
   from scipy.signal import ShortTimeFFT, get_window
   from oobss import AuxIVA

   fs = 16000
   fft_size = 2048
   hop_size = 512
   win = get_window("hann", fft_size, fftbins=True)
   stft = ShortTimeFFT(win=win, hop=hop_size, fs=fs)

   # mixture_time: (n_samples, n_mic)
   mixture_time = np.random.randn(fs * 2, 2)
   # channel-first STFT: (n_mic, n_freq, n_frame)
   X_cft = stft.stft(mixture_time.T)
   # AuxIVA input must be frame-first: (n_frame, n_freq, n_mic)
   X_tfm = X_cft.transpose(2, 1, 0)

   model = AuxIVA(X_tfm)
   out = model.fit_transform_tf(X_tfm, n_iter=30)
   Y_tfm = out.estimate_tf
   if Y_tfm is None:
       raise ValueError("AuxIVA did not return TF estimates.")

   # Reconstruct separated waveforms: (n_src, n_samples)
   y_time = np.real(stft.istft(Y_tfm.transpose(2, 1, 0)))

Strategy plug-and-play (fix source model, swap spatial update):

   from oobss.separators.strategies import (
       BatchCovarianceStrategy,
       GaussSourceStrategy,
       IP2SpatialStrategy,
   )

   model = AuxIVA(
       X_tfm,
       source=GaussSourceStrategy(),          # fixed source model
       covariance=BatchCovarianceStrategy(),  # fixed covariance update
       spatial=IP2SpatialStrategy(),          # swapped demixing update
   )
   model.run(20)
   Y_tfm = model.get_estimate()
References

[1] N. Ono, "Stable and fast update rules for independent vector analysis based on auxiliary function technique," in Proc. IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (WASPAA), pp. 189-192, Oct. 2011, doi: 10.1109/ASPAA.2011.6082320.

Source code in src/oobss/separators/auxiva.py
class AuxIVA(BaseIterativeSeparator):
    """
    Base class for auxiliary-function-based independent vector analysis.

    Procedure
    ---------
    ```text

       input: {x_{f,t}}_{f=1..F, t=1..T}, iterations I
       initialize W_f = I_M
       for i = 1..I:
           y_{f,t} <- W_f x_{f,t}
           compute r_{k,t} and phi(r_{k,t}) (Gauss or Laplace)
           V_{k,f} <- (1/T) * sum_t phi(r_{k,t}) x_{f,t} x_{f,t}^H
           update w_{k,f} by IP1/IP2 for each k
       return y_{f,t}
    ```

    Update Equations
    ----------------
    Indices are $k=1,\\dots,K$ (source), $m=1,\\dots,M$
    (channel), $f=1,\\dots,F$ (frequency), and
    $t=1,\\dots,T$ (time frame). In the determined case, $K=M$.

    Notation:

    $$
       \\bm{x}_{f,t} \\in \\mathbb{C}^{M}, \\quad
       \\bm{y}_{f,t} = W_f \\bm{x}_{f,t} \\in \\mathbb{C}^{M},
       \\quad
       y_{k,f,t} = \\bm{w}_{k,f}^{\\mathsf{H}}\\bm{x}_{f,t}
    $$

    $$
       r_{k,t} = \\left(\\sum_{f=1}^{F}|y_{k,f,t}|^2\\right)^{1/2}, \\quad
       \\varphi_{\\mathrm{Laplace}}(r) = \\frac{1}{\\max(2r,\\varepsilon)}, \\quad
       \\varphi_{\\mathrm{Gauss}}(r) = \\frac{1}{\\max(r^2/F,\\varepsilon)}
    $$

    The weighted covariance for source $k$ is:

    $$
       V_{k,f}
       = \\frac{1}{T}\\sum_{t=1}^{T}
       \\varphi(r_{k,t})\\,\\bm{x}_{f,t}\\bm{x}_{f,t}^{\\mathsf{H}}
    $$

    IP1 update for the $k$-th demixing row:

    $$
       \\tilde{\\bm{w}}_{k,f}
       = (W_f V_{k,f})^{-1}\\bm{e}_k, \\quad
       \\bm{w}_{k,f}
       = \\frac{\\tilde{\\bm{w}}_{k,f}}
       {\\sqrt{
       \\tilde{\\bm{w}}_{k,f}^{\\mathsf{H}}
       V_{k,f}
       \\tilde{\\bm{w}}_{k,f}
       }}
    $$

    Common projection-back post-processing
    --------------------------------------
    AuxIVA/ILRMA and their online variants use the same projection-back rule.
    For a reference microphone $m_{\\mathrm{ref}}$:

    $$
       A_f = W_f^{-1}, \\quad
       \\hat{y}_{k,f,t} = a_{k,f}[m_{\\mathrm{ref}}] y_{k,f,t}
    $$

    The shared implementation lives in
    :func:`oobss.separators.utils.projection_back`.

    Attributes
    ----------
    observations : ndarray of shape (n_frame, n_freq, n_src)
    spatial : SpatialUpdateStrategy
        Demixing-matrix strategy (e.g., IP1/IP2).
    source : SourceModelStrategy
        Source-model strategy (e.g., Gauss/Laplace).
    covariance : CovarianceUpdateStrategy
        Weighted covariance strategy.

    estimated : ndarray of shape (n_frame, n_freq, n_src)
    source_model : ndarray of shape (n_frame, n_freq, n_src)
    demix_filter : ndarray of shape (n_freq, n_src, n_src)
    loss : list[float]

    Examples
    --------
    Basic TF-domain usage:

    ```python

       import numpy as np
       from scipy.signal import ShortTimeFFT, get_window
       from oobss import AuxIVA

       fs = 16000
       fft_size = 2048
       hop_size = 512
       win = get_window("hann", fft_size, fftbins=True)
       stft = ShortTimeFFT(win=win, hop=hop_size, fs=fs)

       # mixture_time: (n_samples, n_mic)
       mixture_time = np.random.randn(fs * 2, 2)
       # channel-first STFT: (n_mic, n_freq, n_frame)
       X_cft = stft.stft(mixture_time.T)
       # AuxIVA input must be frame-first: (n_frame, n_freq, n_mic)
       X_tfm = X_cft.transpose(2, 1, 0)

       model = AuxIVA(X_tfm)
       out = model.fit_transform_tf(X_tfm, n_iter=30)
       Y_tfm = out.estimate_tf
       if Y_tfm is None:
           raise ValueError("AuxIVA did not return TF estimates.")

       # Reconstruct separated waveforms: (n_src, n_samples)
       y_time = np.real(stft.istft(Y_tfm.transpose(2, 1, 0)))
    ```

    Strategy plug-and-play (fix source model, swap spatial update):

    ```python

       from oobss.separators.strategies import (
           BatchCovarianceStrategy,
           GaussSourceStrategy,
           IP2SpatialStrategy,
       )

       model = AuxIVA(
           X_tfm,
           source=GaussSourceStrategy(),          # fixed source model
           covariance=BatchCovarianceStrategy(),  # fixed covariance update
           spatial=IP2SpatialStrategy(),          # swapped demixing update
       )
       model.run(20)
       Y_tfm = model.get_estimate()
    ```

    References
    ----------
    [1] N. Ono, "Stable and fast update rules for independent vector analysis
    based on auxiliary function technique," in *Proc. IEEE Workshop on
    Applications of Signal Processing to Audio and Acoustics (WASPAA)*, pp.
    189-192, Oct. 2011, doi: 10.1109/ASPAA.2011.6082320.
    """

    def __init__(
        self,
        observations,
        *,
        spatial: SpatialUpdateStrategy | None = None,
        source: SourceModelStrategy | None = None,
        covariance: CovarianceUpdateStrategy | None = None,
        reconstruction_strategy: ReconstructionStrategy | None = None,
    ):
        """Initialize parameters in AuxIVA."""
        # Setup
        self.observations = observations
        self.spatial_strategy = spatial if spatial is not None else IP1SpatialStrategy()
        self.source_strategy = (
            source if source is not None else GaussSourceStrategy(eps=float(eps))
        )
        self.covariance_strategy = (
            covariance if covariance is not None else BatchCovarianceStrategy()
        )
        self.reconstruction_strategy = (
            reconstruction_strategy
            if reconstruction_strategy is not None
            else DemixReconstructionStrategy()
        )
        # Results
        self.bind_mixture_tf(np.asarray(observations))

    def step(self):
        """Update parameters one step."""
        n_src = self.observations.shape[-1]
        # 1. Update source model
        self.source_model = self.calc_source_model()

        # 2. Update covariance
        self.covariance = self.covariance_strategy.update(
            CovarianceRequest(
                observed=self.observations,
                source_model=self.source_model,
            )
        )

        # 3. Update demixing filter
        for row_idx in self.spatial_strategy.row_groups(n_src):
            self.demix_filter[:, :, :] = self.spatial_strategy.update(
                self.covariance,
                self.demix_filter,
                row_idx=row_idx,
            )

        # 4. Update estimated sources
        recon = self.reconstruction_strategy.reconstruct(
            ReconstructionRequest(
                mixture=self.observations,
                demix_filter=self.demix_filter,
            )
        )
        self.estimated = recon.estimate

        # 5. Update loss function value
        self.loss = self.calc_loss()

    def init_demix(self):
        """Initialize demixing matrix."""
        _, n_freq, n_src = self.observations.shape
        W0 = np.zeros((n_freq, n_src, n_src), dtype=complex)
        W0[:, :, :n_src] = np.tile(np.eye(n_src, dtype=complex), (n_freq, 1, 1))
        return W0

    def bind_mixture_tf(self, mixture_tf: np.ndarray) -> None:
        """Bind a TF-domain mixture and reset internal iterative state."""
        observations = np.asarray(mixture_tf)
        if observations.ndim != 3:
            raise ValueError(
                "AuxIVA expects mixture_tf with shape (n_frame, n_freq, n_mic)."
            )

        self.observations = observations
        self.demix_filter = self.init_demix()
        recon = self.reconstruction_strategy.reconstruct(
            ReconstructionRequest(
                mixture=self.observations,
                demix_filter=self.demix_filter,
            )
        )
        self.estimated = recon.estimate
        self.source_model = None
        self.covariance = None
        self.loss = self.calc_loss()

    def calc_source_model(self):
        """
        Calculate source model.

        Returns
        -------
        ndarray of shape (n_frame, n_freq, n_src)
        """
        updated = self.source_strategy.update(
            SourceModelRequest(
                estimated=self.estimated,
                n_freq=int(self.observations.shape[1]),
            )
        )
        if updated.source_model is None:
            raise ValueError("source strategy did not return source_model.")
        return updated.source_model

    def _source_model_name_for_loss(self) -> str:
        name = getattr(self.source_strategy, "model", "Gauss")
        if not isinstance(name, str):
            return "Gauss"
        normalized = name.capitalize()
        if normalized not in {"Gauss", "Laplace"}:
            raise ValueError(
                "calc_loss supports Gauss/Laplace source models. "
                f"Got source_strategy.model={name!r}."
            )
        return normalized

    def calc_loss(self):
        """Calculate loss function value."""
        n_frames, _, _ = self.estimated.shape

        def f_norm(y):
            return np.linalg.norm(y, axis=1)

        contrast_func = {
            "Laplace": lambda y: np.sum(f_norm(y)),
            "Gauss": lambda y: np.sum(np.log(1.0 / np.maximum(eps, f_norm(y)))),
        }[self._source_model_name_for_loss()]
        target_loss = contrast_func(self.estimated)

        tfn_fnt = [1, 2, 0]
        XX = self.observations.transpose(tfn_fnt)
        YY = self.estimated.transpose(tfn_fnt)
        W_H = np.linalg.solve(XX @ tensor_H(XX), XX @ tensor_H(YY))
        _, logdet = np.linalg.slogdet(W_H)
        demix_loss = -2 * n_frames * np.sum(logdet)

        return target_loss + demix_loss

    @property
    def n_sources(self) -> int:
        """Return number of separated sources."""
        return int(self.observations.shape[-1])

    def get_estimate(self) -> np.ndarray:
        """Return current TF-domain estimate."""
        return self.estimated

n_sources property

n_sources

Return number of separated sources.

__call__

__call__(*args, **kwargs)

Alias for forward() to provide a torch-like call style.

Source code in src/oobss/separators/core/base.py
def __call__(self, *args: Any, **kwargs: Any) -> SeparationOutput:
    """Alias for :meth:`forward` to provide a torch-like call style."""
    return self.forward(*args, **kwargs)

__init__

__init__(observations, *, spatial=None, source=None, covariance=None, reconstruction_strategy=None)

Initialize parameters in AuxIVA.

Source code in src/oobss/separators/auxiva.py
def __init__(
    self,
    observations,
    *,
    spatial: SpatialUpdateStrategy | None = None,
    source: SourceModelStrategy | None = None,
    covariance: CovarianceUpdateStrategy | None = None,
    reconstruction_strategy: ReconstructionStrategy | None = None,
):
    """Initialize parameters in AuxIVA."""
    # Setup
    self.observations = observations
    self.spatial_strategy = spatial if spatial is not None else IP1SpatialStrategy()
    self.source_strategy = (
        source if source is not None else GaussSourceStrategy(eps=float(eps))
    )
    self.covariance_strategy = (
        covariance if covariance is not None else BatchCovarianceStrategy()
    )
    self.reconstruction_strategy = (
        reconstruction_strategy
        if reconstruction_strategy is not None
        else DemixReconstructionStrategy()
    )
    # Results
    self.bind_mixture_tf(np.asarray(observations))

bind_mixture_tf

bind_mixture_tf(mixture_tf)

Bind a TF-domain mixture and reset internal iterative state.

Source code in src/oobss/separators/auxiva.py
def bind_mixture_tf(self, mixture_tf: np.ndarray) -> None:
    """Bind a TF-domain mixture and reset internal iterative state."""
    observations = np.asarray(mixture_tf)
    if observations.ndim != 3:
        raise ValueError(
            "AuxIVA expects mixture_tf with shape (n_frame, n_freq, n_mic)."
        )

    self.observations = observations
    self.demix_filter = self.init_demix()
    recon = self.reconstruction_strategy.reconstruct(
        ReconstructionRequest(
            mixture=self.observations,
            demix_filter=self.demix_filter,
        )
    )
    self.estimated = recon.estimate
    self.source_model = None
    self.covariance = None
    self.loss = self.calc_loss()

bind_mixture_time

bind_mixture_time(mixture_time, sample_rate)

Bind time-domain input before iterative updates.

Source code in src/oobss/separators/core/base.py
def bind_mixture_time(
    self, mixture_time: np.ndarray, sample_rate: int | None
) -> None:
    """Bind time-domain input before iterative updates."""
    del sample_rate
    raise ValueError(
        f"{self.__class__.__name__} does not support time-domain input."
    )

calc_loss

calc_loss()

Calculate loss function value.

Source code in src/oobss/separators/auxiva.py
def calc_loss(self):
    """Calculate loss function value."""
    n_frames, _, _ = self.estimated.shape

    def f_norm(y):
        return np.linalg.norm(y, axis=1)

    contrast_func = {
        "Laplace": lambda y: np.sum(f_norm(y)),
        "Gauss": lambda y: np.sum(np.log(1.0 / np.maximum(eps, f_norm(y)))),
    }[self._source_model_name_for_loss()]
    target_loss = contrast_func(self.estimated)

    tfn_fnt = [1, 2, 0]
    XX = self.observations.transpose(tfn_fnt)
    YY = self.estimated.transpose(tfn_fnt)
    W_H = np.linalg.solve(XX @ tensor_H(XX), XX @ tensor_H(YY))
    _, logdet = np.linalg.slogdet(W_H)
    demix_loss = -2 * n_frames * np.sum(logdet)

    return target_loss + demix_loss
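The demixing term in `calc_loss` does not read `demix_filter` directly. It recovers \(W_f^{\mathsf{H}}\) from the observations and the current estimates by solving \(X X^{\mathsf{H}} W^{\mathsf{H}} = X Y^{\mathsf{H}}\) per frequency bin. The sketch below checks that identity on synthetic data. The helper `herm` is a stand-in written for this example and assumes that `tensor_H` conjugate-transposes the last two axes.

```python
import numpy as np


def herm(a):
    """Conjugate-transpose the last two axes (assumed behaviour of tensor_H)."""
    return a.conj().swapaxes(-1, -2)


rng = np.random.default_rng(0)
n_frame, n_freq, n_src = 60, 3, 2
X = rng.standard_normal((n_frame, n_freq, n_src)) + 1j * rng.standard_normal(
    (n_frame, n_freq, n_src)
)
W = 0.3 * rng.standard_normal((n_freq, n_src, n_src)) + np.eye(n_src)
Y = np.einsum("fkm,tfm->tfk", W, X)  # y_{f,t} = W_f x_{f,t}

XX = X.transpose(1, 2, 0)  # (n_freq, n_src, n_frame)
YY = Y.transpose(1, 2, 0)
# Since YY = W @ XX, the normal equations return W^H for each frequency bin.
W_H = np.linalg.solve(XX @ herm(XX), XX @ herm(YY))
_, logdet = np.linalg.slogdet(W_H)
```

Because \(|\det W_f^{\mathsf{H}}| = |\det W_f|\), the resulting `logdet` equals \(\log|\det W_f|\), which is the quantity the loss sums over frequencies.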

calc_source_model

calc_source_model()

Calculate source model.

Returns:

Type Description
ndarray of shape (n_frame, n_freq, n_src)
Source code in src/oobss/separators/auxiva.py
def calc_source_model(self):
    """
    Calculate source model.

    Returns
    -------
    ndarray of shape (n_frame, n_freq, n_src)
    """
    updated = self.source_strategy.update(
        SourceModelRequest(
            estimated=self.estimated,
            n_freq=int(self.observations.shape[1]),
        )
    )
    if updated.source_model is None:
        raise ValueError("source strategy did not return source_model.")
    return updated.source_model

fit_transform_tf

fit_transform_tf(mixture_tf, *, n_iter=0, request=None)

Bind TF input, run iterations, and return TF-domain estimate.

Source code in src/oobss/separators/core/base.py
def fit_transform_tf(
    self,
    mixture_tf: np.ndarray,
    *,
    n_iter: int = 0,
    request: BatchRequest | None = None,
) -> SeparationOutput:
    """Bind TF input, run iterations, and return TF-domain estimate."""
    del request
    self.bind_mixture_tf(np.asarray(mixture_tf))
    if n_iter > 0:
        self.run(int(n_iter))
    return SeparationOutput(estimate_tf=self.get_estimate())

fit_transform_time

fit_transform_time(mixture_time, *, n_iter=0, request=None)

Bind time-domain input if supported, then run iterations.

Source code in src/oobss/separators/core/base.py
def fit_transform_time(
    self,
    mixture_time: np.ndarray,
    *,
    n_iter: int = 0,
    request: BatchRequest | None = None,
) -> SeparationOutput:
    """Bind time-domain input if supported, then run iterations."""
    sample_rate = None if request is None else request.sample_rate
    self.bind_mixture_time(np.asarray(mixture_time), sample_rate)
    if n_iter > 0:
        self.run(int(n_iter))
    return SeparationOutput(estimate_tf=self.get_estimate())

forward

forward(mixture, *, n_iter=0, request=None, is_time_input=None)

Run batch separation from TF-domain or time-domain input.

Source code in src/oobss/separators/core/base.py
def forward(
    self,
    mixture: np.ndarray,
    *,
    n_iter: int = 0,
    request: BatchRequest | None = None,
    is_time_input: bool | None = None,
) -> SeparationOutput:
    """Run batch separation from TF-domain or time-domain input."""
    mixture_arr = np.asarray(mixture)
    if is_time_input is None:
        is_time = not (np.iscomplexobj(mixture_arr) or mixture_arr.ndim >= 3)
    else:
        is_time = bool(is_time_input)
    if is_time:
        return self.fit_transform_time(
            mixture_arr,
            n_iter=int(n_iter),
            request=request,
        )
    return self.fit_transform_tf(
        mixture_arr,
        n_iter=int(n_iter),
        request=request,
    )
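The input-type heuristic in `forward` can be isolated as a small predicate, which makes its edge cases easier to see. The function name `looks_like_time_input` is hypothetical; the logic mirrors the listing above.

```python
import numpy as np


def looks_like_time_input(mixture, is_time_input=None):
    """Return True when the array would be routed to the time-domain path."""
    arr = np.asarray(mixture)
    if is_time_input is not None:
        return bool(is_time_input)  # an explicit flag always wins
    # Complex data or 3-D (and higher) arrays are treated as TF-domain input.
    return not (np.iscomplexobj(arr) or arr.ndim >= 3)


waveform = np.zeros((16000, 2))                 # real, 2-D
spectrogram = np.zeros((10, 5, 2), dtype=complex)
magnitude = np.zeros((10, 5, 2))                # real, but 3-D

print(looks_like_time_input(waveform))     # True
print(looks_like_time_input(spectrogram))  # False
print(looks_like_time_input(magnitude))    # False
```

Note that a real-valued 3-D array is still treated as TF-domain input. For AuxIVA, the time-domain path ends in `bind_mixture_time`, which raises `ValueError` in the base-class listing shown on this page, so a 2-D real waveform passed to `forward` is expected to fail unless a subclass overrides that method.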

get_estimate

get_estimate()

Return current TF-domain estimate.

Source code in src/oobss/separators/auxiva.py
def get_estimate(self) -> np.ndarray:
    """Return current TF-domain estimate."""
    return self.estimated

init_demix

init_demix()

Initialize demixing matrix.

Source code in src/oobss/separators/auxiva.py
def init_demix(self):
    """Initialize demixing matrix."""
    _, n_freq, n_src = self.observations.shape
    W0 = np.zeros((n_freq, n_src, n_src), dtype=complex)
    W0[:, :, :n_src] = np.tile(np.eye(n_src, dtype=complex), (n_freq, 1, 1))
    return W0

reset

reset()

Reset internal state (override in subclasses when needed).

Source code in src/oobss/separators/core/base.py
def reset(self) -> None:
    """Reset internal state (override in subclasses when needed)."""

run

run(n_iter)

Execute n_iter update steps and return final estimate.

Source code in src/oobss/separators/core/base.py
def run(self, n_iter: int) -> np.ndarray:
    """Execute ``n_iter`` update steps and return final estimate."""
    if n_iter < 0:
        raise ValueError("n_iter must be non-negative")
    for _ in range(n_iter):
        self.step()
    return self.get_estimate()

step

step()

Update parameters one step.

Source code in src/oobss/separators/auxiva.py
def step(self):
    """Update parameters one step."""
    n_src = self.observations.shape[-1]
    # 1. Update source model
    self.source_model = self.calc_source_model()

    # 2. Update covariance
    self.covariance = self.covariance_strategy.update(
        CovarianceRequest(
            observed=self.observations,
            source_model=self.source_model,
        )
    )

    # 3. Update demixing filter
    for row_idx in self.spatial_strategy.row_groups(n_src):
        self.demix_filter[:, :, :] = self.spatial_strategy.update(
            self.covariance,
            self.demix_filter,
            row_idx=row_idx,
        )

    # 4. Update estimated sources
    recon = self.reconstruction_strategy.reconstruct(
        ReconstructionRequest(
            mixture=self.observations,
            demix_filter=self.demix_filter,
        )
    )
    self.estimated = recon.estimate

    # 5. Update loss function value
    self.loss = self.calc_loss()

BatchRequest dataclass

Execution options for batch separators.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| reference_mic | int | Reference microphone index for scale restoration/evaluation. | 0 |
| sample_rate | int \| None | Sampling rate in Hz when time-domain input is supplied. | None |
| metadata | dict[str, Any] | Additional method-specific options. | dict() |

Source code in src/oobss/separators/core/io_models.py
@dataclass(slots=True)
class BatchRequest:
    """Execution options for batch separators.

    Parameters
    ----------
    reference_mic:
        Reference microphone index for scale restoration/evaluation.
    sample_rate:
        Sampling rate in Hz when time-domain input is supplied.
    metadata:
        Additional method-specific options.
    """

    reference_mic: int = 0
    sample_rate: int | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
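The `field(default_factory=dict)` default matters because it gives every request its own `metadata` dictionary. The sketch below demonstrates this with `BatchRequestSketch`, a hypothetical stand-in that mirrors the three fields above so the example runs without the package installed (it omits `slots=True` for compatibility with older Python versions).

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class BatchRequestSketch:
    """Hypothetical stand-in mirroring the fields of BatchRequest."""

    reference_mic: int = 0
    sample_rate: Optional[int] = None
    metadata: dict[str, Any] = field(default_factory=dict)


a = BatchRequestSketch(sample_rate=16000)
b = BatchRequestSketch(reference_mic=1)
a.metadata["note"] = "first run"  # mutating one request...
print(b.metadata)                 # ...leaves the other untouched: {}
```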

ILRMA

Bases: BaseIterativeSeparator

Base class for independent low-rank matrix analysis.

Procedure
   input: {x_{f,t}}_{f=1..F, t=1..T}, NMF rank L, iterations I
   initialize W_f = I_M, b_{k,f,\ell}, c_{k,\ell,t}
   for i = 1..I:
       y_{f,t} <- W_f x_{f,t}
       for each source k:
           r_{k,f,t} <- sum_\ell b_{k,f,\ell} c_{k,\ell,t}
           update b_{k,f,\ell}, c_{k,\ell,t} by MU using |y_{k,f,t}|^2
       V_{k,f} <- (1/T) * sum_t x_{f,t} x_{f,t}^H / r_{k,f,t}
       update w_{k,f} by IP1/IP2
   return y_{f,t}
Update Equations

Indices are \(k=1,\dots,K\) (source), \(m=1,\dots,M\) (channel), \(f=1,\dots,F\) (frequency), \(t=1,\dots,T\) (time frame), and \(\ell=1,\dots,L\) (NMF basis index). In the determined case, \(K=M\).

Notation:

\[ \bm{x}_{f,t} \in \mathbb{C}^{M}, \quad \bm{y}_{f,t} = W_f\bm{x}_{f,t}, \quad y_{k,f,t} = \bm{w}_{k,f}^{\mathsf{H}}\bm{x}_{f,t} \]

Source variance model with NMF basis entries \(b_{k,f,\ell}\) and activation entries \(c_{k,\ell,t}\):

\[ r_{k,f,t} = \sum_{\ell=1}^{L} b_{k,f,\ell}c_{k,\ell,t} \]

For each source \(k\), the multiplicative updates are:

\[ b_{k,f,\ell} \leftarrow b_{k,f,\ell} \frac{ \sum_{t=1}^{T}|y_{k,f,t}|^2r_{k,f,t}^{-2}c_{k,\ell,t} }{ \sum_{t=1}^{T}r_{k,f,t}^{-1}c_{k,\ell,t} }, \quad c_{k,\ell,t} \leftarrow c_{k,\ell,t} \frac{ \sum_{f=1}^{F}|y_{k,f,t}|^2r_{k,f,t}^{-2}b_{k,f,\ell} }{ \sum_{f=1}^{F}r_{k,f,t}^{-1}b_{k,f,\ell} } \]
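The multiplicative updates above translate directly into matrix products. The following is a minimal single-source sketch, assuming a `(n_freq, n_basis)` basis and a `(n_basis, n_frame)` activation layout chosen for readability here; the function name `mu_update`, the `eps` floor, and the refresh of the model between the two updates are assumptions of this example rather than confirmed details of the library.

```python
import numpy as np


def mu_update(P, B, C, eps=1e-10):
    """One multiplicative-update step for a single source k.

    P: (n_freq, n_frame) power |y_{k,f,t}|^2
    B: (n_freq, n_basis) basis b_{k,f,l}
    C: (n_basis, n_frame) activation c_{k,l,t}
    """
    R = B @ C  # r_{k,f,t}
    B = B * ((P / R**2) @ C.T) / ((1.0 / R) @ C.T)
    B = np.maximum(B, eps)  # keep factors strictly positive
    R = B @ C  # refresh the model before updating the activations
    C = C * (B.T @ (P / R**2)) / (B.T @ (1.0 / R))
    C = np.maximum(C, eps)
    return B, C


rng = np.random.default_rng(0)
B0 = rng.random((6, 3)) + 0.1
C0 = rng.random((3, 20)) + 0.1
P = rng.random((6, 20)) + 0.1
B1, C1 = mu_update(P, B0, C0)
```

When the power already matches the model exactly (`P == B0 @ C0`), both ratios equal one and the factors are left unchanged, which is a quick way to check an implementation.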

Weighted covariance:

\[ V_{k,f} = \frac{1}{T}\sum_{t=1}^{T} \frac{ \bm{x}_{f,t}\bm{x}_{f,t}^{\mathsf{H}} }{ r_{k,f,t} } \]

IP1 update for source \(k\):

\[ \tilde{\bm{w}}_{k,f} = (W_fV_{k,f})^{-1}\bm{e}_{k}, \quad \bm{w}_{k,f} = \frac{\tilde{\bm{w}}_{k,f}} {\sqrt{ \tilde{\bm{w}}_{k,f}^{\mathsf{H}} V_{k,f} \tilde{\bm{w}}_{k,f} }} \]
Common projection-back post-processing

AuxIVA, ILRMA, and their online variants use the same projection-back rule. With \(a_{k,f}\) denoting the \(k\)-th column of the estimated mixing matrix \(A_f\), and for a reference microphone \(m_{\mathrm{ref}}\):

\[ A_f = W_f^{-1}, \quad \hat{y}_{k,f,t} = a_{k,f}[m_{\mathrm{ref}}] y_{k,f,t} \]

The shared implementation lives in oobss.separators.utils.projection_back.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| observations | ndarray of shape (n_frame, n_freq, n_src) | |
| spatial | SpatialUpdateStrategy | Demixing-matrix strategy (e.g., IP1/IP2). |
| source | SourceModelStrategy | Source-model strategy (typically ILRMA NMF MU). |
| covariance | CovarianceUpdateStrategy | Weighted covariance strategy. |
| estimated | ndarray of shape (n_frame, n_freq, n_src) | |
| source_model | ndarray of shape (n_frame, n_freq, n_src) | |
| demix_filter | ndarray of shape (n_freq, n_src, n_src) | |
| loss | list[float] | |

Examples:

Basic TF-domain usage:

   import numpy as np
   from scipy.signal import ShortTimeFFT, get_window
   from oobss import ILRMA

   fs = 16000
   fft_size = 2048
   hop_size = 512
   win = get_window("hann", fft_size, fftbins=True)
   stft = ShortTimeFFT(win=win, hop=hop_size, fs=fs)

   mixture_time = np.random.randn(fs * 2, 2)  # (n_samples, n_mic)
   X_tfm = stft.stft(mixture_time.T).transpose(2, 1, 0)  # (T, F, M)

   model = ILRMA(X_tfm, n_basis=8, random_state=0)
   out = model.fit_transform_tf(X_tfm, n_iter=50)
   Y_tfm = out.estimate_tf
   if Y_tfm is None:
       raise ValueError("ILRMA did not return TF estimates.")

   y_time = np.real(stft.istft(Y_tfm.transpose(2, 1, 0)))

Warm-start NMF factors and demixing:

   # Initial factors: basis0=(n_src, n_freq, n_basis),
   # activ0=(n_src, n_frame, n_basis)
   basis0 = np.abs(np.random.randn(2, X_tfm.shape[1], 8)) + 1e-6
   activ0 = np.abs(np.random.randn(2, X_tfm.shape[0], 8)) + 1e-6

   model = ILRMA(
       X_tfm,
       n_basis=8,
       basis0=basis0,
       activ0=activ0,
       random_state=0,
   )
   model.run(30)
   Y_tfm = model.get_estimate()
References

[1] D. Kitamura, N. Ono, H. Sawada, H. Kameoka, and H. Saruwatari, "Determined blind source separation unifying independent vector analysis and nonnegative matrix factorization," IEEE/ACM Trans. Audio, Speech, and Language Processing, vol. 24, no. 9, pp. 1622-1637, Sep. 2016, doi: 10.1109/TASLP.2016.2577880.

Source code in src/oobss/separators/ilrma.py
class ILRMA(BaseIterativeSeparator):
    """
    Base class for independent low-rank matrix analysis.

    Procedure
    ---------
    ```text

       input: {x_{f,t}}_{f=1..F, t=1..T}, NMF rank L, iterations I
       initialize W_f = I_M, b_{k,f,\\ell}, c_{k,\\ell,t}
       for i = 1..I:
           y_{f,t} <- W_f x_{f,t}
           for each source k:
               r_{k,f,t} <- sum_\\ell b_{k,f,\\ell} c_{k,\\ell,t}
               update b_{k,f,\\ell}, c_{k,\\ell,t} by MU using |y_{k,f,t}|^2
           V_{k,f} <- (1/T) * sum_t x_{f,t} x_{f,t}^H / r_{k,f,t}
           update w_{k,f} by IP1/IP2
       return y_{f,t}
    ```

    Update Equations
    ----------------
    Indices are $k=1,\\dots,K$ (source), $m=1,\\dots,M$
    (channel), $f=1,\\dots,F$ (frequency), $t=1,\\dots,T$
    (time frame), and $\\ell=1,\\dots,L$ (NMF basis index). In the
    determined case, $K=M$.

    Notation:

    $$
       \\bm{x}_{f,t} \\in \\mathbb{C}^{M}, \\quad
       \\bm{y}_{f,t} = W_f\\bm{x}_{f,t}, \\quad
       y_{k,f,t} = \\bm{w}_{k,f}^{\\mathsf{H}}\\bm{x}_{f,t}
    $$

    Source variance model with NMF basis entries $b_{k,f,\\ell}$ and
    activation entries $c_{k,\\ell,t}$:

    $$
       r_{k,f,t} = \\sum_{\\ell=1}^{L} b_{k,f,\\ell}c_{k,\\ell,t}
    $$

    For each source $k$, the multiplicative updates are:

    $$
       b_{k,f,\\ell} \\leftarrow b_{k,f,\\ell}
       \\frac{
       \\sum_{t=1}^{T}|y_{k,f,t}|^2r_{k,f,t}^{-2}c_{k,\\ell,t}
       }{
       \\sum_{t=1}^{T}r_{k,f,t}^{-1}c_{k,\\ell,t}
       },
       \\quad
       c_{k,\\ell,t} \\leftarrow c_{k,\\ell,t}
       \\frac{
       \\sum_{f=1}^{F}|y_{k,f,t}|^2r_{k,f,t}^{-2}b_{k,f,\\ell}
       }{
       \\sum_{f=1}^{F}r_{k,f,t}^{-1}b_{k,f,\\ell}
       }
    $$

    Weighted covariance:

    $$
       V_{k,f} = \\frac{1}{T}\\sum_{t=1}^{T}
       \\frac{
       \\bm{x}_{f,t}\\bm{x}_{f,t}^{\\mathsf{H}}
       }{
       r_{k,f,t}
       }
    $$

    IP1 update for source $k$:

    $$
       \\tilde{\\bm{w}}_{k,f}
       = (W_fV_{k,f})^{-1}\\bm{e}_{k}, \\quad
       \\bm{w}_{k,f}
       = \\frac{\\tilde{\\bm{w}}_{k,f}}
       {\\sqrt{
       \\tilde{\\bm{w}}_{k,f}^{\\mathsf{H}}
       V_{k,f}
       \\tilde{\\bm{w}}_{k,f}
       }}
    $$

    Common projection-back post-processing
    --------------------------------------
    AuxIVA/ILRMA and their online variants use the same projection-back rule.
    For a reference microphone $m_{\\mathrm{ref}}$:

    $$
       A_f = W_f^{-1}, \\quad
       \\hat{y}_{k,f,t} = a_{k,f}[m_{\\mathrm{ref}}] y_{k,f,t}
    $$

    The shared implementation lives in
    :func:`oobss.separators.utils.projection_back`.

    Attributes
    ----------
    observations : ndarray of shape (n_frame, n_freq, n_src)
    spatial : SpatialUpdateStrategy
        Demixing-matrix strategy (e.g., IP1/IP2).
    source : SourceModelStrategy
        Source-model strategy (typically ILRMA NMF MU).
    covariance : CovarianceUpdateStrategy
        Weighted covariance strategy.

    estimated : ndarray of shape (n_frame, n_freq, n_src)
    source_model : ndarray of shape (n_frame, n_freq, n_src)
    demix_filter : ndarray of shape (n_freq, n_src, n_src)
    loss : list[float]

    Examples
    --------
    Basic TF-domain usage:

    ```python

       import numpy as np
       from scipy.signal import ShortTimeFFT, get_window
       from oobss import ILRMA

       fs = 16000
       fft_size = 2048
       hop_size = 512
       win = get_window("hann", fft_size, fftbins=True)
       stft = ShortTimeFFT(win=win, hop=hop_size, fs=fs)

       mixture_time = np.random.randn(fs * 2, 2)  # (n_samples, n_mic)
       X_tfm = stft.stft(mixture_time.T).transpose(2, 1, 0)  # (T, F, M)

       model = ILRMA(X_tfm, n_basis=8, random_state=0)
       out = model.fit_transform_tf(X_tfm, n_iter=50)
       Y_tfm = out.estimate_tf
       if Y_tfm is None:
           raise ValueError("ILRMA did not return TF estimates.")

       y_time = np.real(stft.istft(Y_tfm.transpose(2, 1, 0)))
    ```

    Warm-start NMF factors and demixing:

    ```python

       # Initial factors: basis0=(n_src, n_freq, n_basis),
       # activ0=(n_src, n_frame, n_basis)
       basis0 = np.abs(np.random.randn(2, X_tfm.shape[1], 8)) + 1e-6
       activ0 = np.abs(np.random.randn(2, X_tfm.shape[0], 8)) + 1e-6

       model = ILRMA(
           X_tfm,
           n_basis=8,
           basis0=basis0,
           activ0=activ0,
           random_state=0,
       )
       model.run(30)
       Y_tfm = model.get_estimate()
    ```

    References
    ----------
    [1] D. Kitamura, N. Ono, H. Sawada, H. Kameoka, and H. Saruwatari,
    "Determined blind source separation unifying independent vector analysis and
    nonnegative matrix factorization," *IEEE/ACM Trans. Audio, Speech, and
    Language Processing*, vol. 24, no. 9, pp. 1622-1637, Sep. 2016,
    doi: 10.1109/TASLP.2016.2577880.
    """

    def __init__(
        self,
        observations,
        *,
        n_basis: int = 10,
        basis0=None,
        activ0=None,
        random_state: int | None = None,
        rng: np.random.Generator | None = None,
        spatial: SpatialUpdateStrategy | None = None,
        source: SourceModelStrategy | None = None,
        covariance: CovarianceUpdateStrategy | None = None,
        reconstruction_strategy: ReconstructionStrategy | None = None,
    ):
        """Initialize parameters in ILRMA."""
        # Setup
        self.observations = observations
        self.n_basis = int(n_basis)
        self.rng = np.random.default_rng(random_state) if rng is None else rng
        self._basis0 = None if basis0 is None else np.array(basis0, copy=True)
        self._activ0 = None if activ0 is None else np.array(activ0, copy=True)
        self.spatial_strategy = spatial if spatial is not None else IP1SpatialStrategy()
        self.source_strategy = (
            source if source is not None else ILRMANMFSourceStrategy(eps=float(eps))
        )
        self.covariance_strategy = (
            covariance if covariance is not None else BatchCovarianceStrategy()
        )
        self.reconstruction_strategy = (
            reconstruction_strategy
            if reconstruction_strategy is not None
            else DemixReconstructionStrategy()
        )

        # Results
        self.bind_mixture_tf(np.asarray(observations))

    def step(self):
        """Update parameters by one step."""
        n_src = self.observations.shape[-1]
        y_power = np.square(np.abs(self.estimated))

        row_groups = self.spatial_strategy.row_groups(n_src)
        pairwise_mode = any(np.ndim(row_idx) > 0 for row_idx in row_groups)

        if pairwise_mode:
            for row_idx in row_groups:
                for src_idx in np.atleast_1d(np.asarray(row_idx, dtype=np.int64)):
                    b = self.basis[src_idx]
                    a = self.activ[src_idx]
                    self.basis[src_idx], self.activ[src_idx] = self.calc_source_model(
                        b, a, y_power[:, :, src_idx]
                    )
                    self.source_model[:, :, src_idx] = (
                        self.activ[src_idx] @ self.basis[src_idx].T
                    )
                self.covariance = self.covariance_strategy.update(
                    CovarianceRequest(
                        observed=self.observations,
                        source_model=self.source_model,
                    )
                )
                self.demix_filter[:, :, :] = self.spatial_strategy.update(
                    self.covariance,
                    self.demix_filter,
                    row_idx=row_idx,
                )
        else:
            for s in range(n_src):
                b = self.basis[s]
                a = self.activ[s]
                self.basis[s], self.activ[s] = self.calc_source_model(
                    b, a, y_power[:, :, s]
                )
                self.source_model[:, :, s] = self.activ[s] @ self.basis[s].T
            self.covariance = self.covariance_strategy.update(
                CovarianceRequest(
                    observed=self.observations,
                    source_model=self.source_model,
                )
            )
            for row_idx in row_groups:
                self.demix_filter[:, :, :] = self.spatial_strategy.update(
                    self.covariance,
                    self.demix_filter,
                    row_idx=row_idx,
                )

        # 3. Update estimated sources
        recon = self.reconstruction_strategy.reconstruct(
            ReconstructionRequest(
                mixture=self.observations,
                demix_filter=self.demix_filter,
            )
        )
        self.estimated = recon.estimate

        # 4. Update loss function value
        self.loss = self.calc_loss()

    def init_demix(self):
        """Initialize demixing matrix."""
        _, n_freq, n_src = self.observations.shape
        W0 = np.zeros((n_freq, n_src, n_src), dtype=complex)
        W0[:, :, :n_src] = np.tile(np.eye(n_src, dtype=complex), (n_freq, 1, 1))
        return W0

    def bind_mixture_tf(self, mixture_tf: np.ndarray) -> None:
        """Bind a TF-domain mixture and reset internal iterative state."""
        observations = np.asarray(mixture_tf)
        if observations.ndim != 3:
            raise ValueError(
                "ILRMA expects mixture_tf with shape (n_frame, n_freq, n_mic)."
            )

        self.observations = observations
        self.demix_filter = self.init_demix()
        recon = self.reconstruction_strategy.reconstruct(
            ReconstructionRequest(
                mixture=self.observations,
                demix_filter=self.demix_filter,
            )
        )
        self.estimated = recon.estimate

        if self._basis0 is None:
            self.basis = self.init_basis()
        else:
            self.basis = np.array(self._basis0, copy=True)
        if self._activ0 is None:
            self.activ = self.init_activ()
        else:
            self.activ = np.array(self._activ0, copy=True)

        if self.basis.shape[:2] != (
            self.observations.shape[-1],
            self.observations.shape[1],
        ):
            raise ValueError(
                "basis0 shape mismatch: expected (n_src, n_freq, n_basis) for mixture_tf."
            )
        if self.activ.shape[:2] != (
            self.observations.shape[-1],
            self.observations.shape[0],
        ):
            raise ValueError(
                "activ0 shape mismatch: expected (n_src, n_frame, n_basis) for mixture_tf."
            )

        self.source_model = self.init_source_model()
        self.covariance = None
        self.loss = self.calc_loss()

    def init_basis(self):
        """Initialize basis matrix."""
        _, n_freq, n_src = self.observations.shape
        return np.ones((n_src, n_freq, self.n_basis))

    def init_activ(self):
        """Initialize activation matrix."""
        n_frame, _, n_src = self.observations.shape
        return self.rng.uniform(
            low=0.1,
            high=1.0,
            size=(n_src, n_frame, self.n_basis),
        )

    def init_source_model(self):
        """Initialize source variance model ``R`` with shape ``(n_frame, n_freq, n_src)``."""
        return self._compose_source_model()

    def _compose_source_model(self) -> np.ndarray:
        """Compose source variance model from basis/activation factors.

        Returns
        -------
        np.ndarray
            Source variance model ``R`` with shape ``(n_frame, n_freq, n_src)``.
        """
        return np.einsum(
            "sfk,stk->tfs",
            self.basis,
            self.activ,
            optimize=True,
        )

    def calc_source_model(self, B, A, y_power):
        """
        Calculate source model.
        By overriding this method, various source models (e.g., Student t, ILRMA-T, generalized Kullback-Leibler divergence, or IDLMA) can be applied.

        Parameters
        ----------
        B : ndarray of shape (n_freq, n_basis)
            Basis matrix
        A : ndarray of shape (n_frame, n_basis)
            Activation matrix
        y_power : ndarray of shape (n_frame, n_freq)
            Power spectrograms of estimated source

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            Updated ``(basis, activation)`` tuple where:
            - basis has shape ``(n_freq, n_basis)``
            - activation has shape ``(n_frame, n_basis)``
        """
        updated = self.source_strategy.update(
            SourceModelRequest(
                basis=B,
                activ=A,
                y_power=y_power,
            )
        )
        if updated.basis is None or updated.activ is None:
            raise ValueError("source strategy did not return basis/activ.")
        return updated.basis, updated.activ

    def calc_loss(self, axis=None):
        """
        Calculate loss function value of ILRMA.

        Parameters
        ----------
        axis : int or None, default=None

        Raises
        ------
        ValueError:
            If `cost` is infinite or not a number.
        """
        # (n_frame, n_freq, n_src)
        y_power = np.square(np.abs(self.estimated))

        # basis: (n_src, n_freq, n_basis)
        # activ: (n_src, n_frame, n_basis)

        src_var = self._compose_source_model()

        # (n_freq,)
        target_loss = -2 * np.linalg.slogdet(self.demix_filter)[1]

        # (n_frame, n_freq)
        demix_loss = np.sum(y_power / src_var + np.log(src_var), axis=2)

        cost = np.sum(demix_loss + target_loss[None, :], axis=axis)
        if np.isinf(cost).any() or np.isnan(cost).any():
            raise ValueError("Cost cannot be calculated.")
        else:
            return cost

    @property
    def n_sources(self) -> int:
        """Return number of separated sources."""
        return int(self.observations.shape[-1])

    def get_estimate(self) -> np.ndarray:
        """Return current TF-domain estimate."""
        return self.estimated
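
The weighted covariance from the class docstring, \(V_{k,f} = \frac{1}{T}\sum_t \bm{x}_{f,t}\bm{x}_{f,t}^{\mathsf{H}} / r_{k,f,t}\), can be sketched in plain NumPy. This is only an illustration under assumed array layouts: `(n_frame, n_freq, n_mic)` for the observations and `(n_frame, n_freq, n_src)` for the source model. The helper name `weighted_covariance` is hypothetical and is not the library's `BatchCovarianceStrategy`.

```python
import numpy as np


def weighted_covariance(X, R):
    """Return V with shape (n_src, n_freq, n_mic, n_mic).

    X: observations, shape (n_frame, n_freq, n_mic), complex.
    R: source variances r_{k,f,t}, shape (n_frame, n_freq, n_src), positive.
    """
    n_frame = X.shape[0]
    # V[k, f] = (1/T) * sum_t x_{f,t} x_{f,t}^H / r_{k,f,t}
    return np.einsum("tfk,tfm,tfn->kfmn", 1.0 / R, X, X.conj()) / n_frame


rng = np.random.default_rng(0)
X = rng.standard_normal((20, 5, 2)) + 1j * rng.standard_normal((20, 5, 2))
R = rng.uniform(0.5, 2.0, size=(20, 5, 2))
V = weighted_covariance(X, R)  # one Hermitian matrix per (source, frequency)
```

Each `V[k, f]` is Hermitian by construction, which is what the IP1 normalization term relies on.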

n_sources property

n_sources

Return number of separated sources.

__call__

__call__(*args, **kwargs)

Alias for `forward` to provide a torch-like call style.

Source code in src/oobss/separators/core/base.py
def __call__(self, *args: Any, **kwargs: Any) -> SeparationOutput:
    """Alias for :meth:`forward` to provide a torch-like call style."""
    return self.forward(*args, **kwargs)

__init__

__init__(observations, *, n_basis=10, basis0=None, activ0=None, random_state=None, rng=None, spatial=None, source=None, covariance=None, reconstruction_strategy=None)

Initialize parameters in ILRMA.

Source code in src/oobss/separators/ilrma.py
def __init__(
    self,
    observations,
    *,
    n_basis: int = 10,
    basis0=None,
    activ0=None,
    random_state: int | None = None,
    rng: np.random.Generator | None = None,
    spatial: SpatialUpdateStrategy | None = None,
    source: SourceModelStrategy | None = None,
    covariance: CovarianceUpdateStrategy | None = None,
    reconstruction_strategy: ReconstructionStrategy | None = None,
):
    """Initialize parameters in ILRMA."""
    # Setup
    self.observations = observations
    self.n_basis = int(n_basis)
    self.rng = np.random.default_rng(random_state) if rng is None else rng
    self._basis0 = None if basis0 is None else np.array(basis0, copy=True)
    self._activ0 = None if activ0 is None else np.array(activ0, copy=True)
    self.spatial_strategy = spatial if spatial is not None else IP1SpatialStrategy()
    self.source_strategy = (
        source if source is not None else ILRMANMFSourceStrategy(eps=float(eps))
    )
    self.covariance_strategy = (
        covariance if covariance is not None else BatchCovarianceStrategy()
    )
    self.reconstruction_strategy = (
        reconstruction_strategy
        if reconstruction_strategy is not None
        else DemixReconstructionStrategy()
    )

    # Results
    self.bind_mixture_tf(np.asarray(observations))

bind_mixture_tf

bind_mixture_tf(mixture_tf)

Bind a TF-domain mixture and reset internal iterative state.

Source code in src/oobss/separators/ilrma.py
def bind_mixture_tf(self, mixture_tf: np.ndarray) -> None:
    """Bind a TF-domain mixture and reset internal iterative state."""
    observations = np.asarray(mixture_tf)
    if observations.ndim != 3:
        raise ValueError(
            "ILRMA expects mixture_tf with shape (n_frame, n_freq, n_mic)."
        )

    self.observations = observations
    self.demix_filter = self.init_demix()
    recon = self.reconstruction_strategy.reconstruct(
        ReconstructionRequest(
            mixture=self.observations,
            demix_filter=self.demix_filter,
        )
    )
    self.estimated = recon.estimate

    if self._basis0 is None:
        self.basis = self.init_basis()
    else:
        self.basis = np.array(self._basis0, copy=True)
    if self._activ0 is None:
        self.activ = self.init_activ()
    else:
        self.activ = np.array(self._activ0, copy=True)

    if self.basis.shape[:2] != (
        self.observations.shape[-1],
        self.observations.shape[1],
    ):
        raise ValueError(
            "basis0 shape mismatch: expected (n_src, n_freq, n_basis) for mixture_tf."
        )
    if self.activ.shape[:2] != (
        self.observations.shape[-1],
        self.observations.shape[0],
    ):
        raise ValueError(
            "activ0 shape mismatch: expected (n_src, n_frame, n_basis) for mixture_tf."
        )

    self.source_model = self.init_source_model()
    self.covariance = None
    self.loss = self.calc_loss()

bind_mixture_time

bind_mixture_time(mixture_time, sample_rate)

Bind time-domain input before iterative updates.

Source code in src/oobss/separators/core/base.py
def bind_mixture_time(
    self, mixture_time: np.ndarray, sample_rate: int | None
) -> None:
    """Bind time-domain input before iterative updates."""
    del sample_rate
    raise ValueError(
        f"{self.__class__.__name__} does not support time-domain input."
    )

calc_loss

calc_loss(axis=None)

Calculate loss function value of ILRMA.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| axis | int or None | | None |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If cost is infinite or not a number. |

Source code in src/oobss/separators/ilrma.py
def calc_loss(self, axis=None):
    """
    Calculate loss function value of ILRMA.

    Parameters
    ----------
    axis : int or None, default=None

    Raises
    ------
    ValueError:
        If `cost` is infinite or not a number.
    """
    # (n_frame, n_freq, n_src)
    y_power = np.square(np.abs(self.estimated))

    # basis: (n_src, n_freq, n_basis)
    # activ: (n_src, n_frame, n_basis)

    src_var = self._compose_source_model()

    # (n_freq,)
    target_loss = -2 * np.linalg.slogdet(self.demix_filter)[1]

    # (n_frame, n_freq)
    demix_loss = np.sum(y_power / src_var + np.log(src_var), axis=2)

    cost = np.sum(demix_loss + target_loss[None, :], axis=axis)
    if np.isinf(cost).any() or np.isnan(cost).any():
        raise ValueError("Cost cannot be calculated.")
    else:
        return cost
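
The cost computed by `calc_loss` can be reproduced on small arrays to make the two terms explicit. The sketch below is a standalone re-statement of the listing above, not a call into the library; the function name `ilrma_cost` and the toy inputs are assumptions made for illustration.

```python
import numpy as np


def ilrma_cost(Y, R, W):
    """Y: (T, F, K) estimates, R: (T, F, K) variances, W: (F, K, K) demixing."""
    y_power = np.abs(Y) ** 2
    # slogdet returns (sign, log|det|); only the log-magnitude enters the cost.
    logdet = np.linalg.slogdet(W)[1]                   # (F,)
    per_bin = np.sum(y_power / R + np.log(R), axis=2)  # (T, F)
    return float(np.sum(per_bin - 2.0 * logdet[None, :]))


T, F, K = 4, 3, 2
Y = np.ones((T, F, K), dtype=complex)
R = np.ones((T, F, K))
W = np.tile(np.eye(K, dtype=complex), (F, 1, 1))

# With |y|^2 = r = 1 and W = I, every (t, f) bin contributes K, so the total is T*F*K.
cost = ilrma_cost(Y, R, W)
# Scaling W by 2 lowers the log-determinant term by 2*K*log(2) per (t, f) bin.
cost_scaled = ilrma_cost(Y, R, 2.0 * W)
```

The second call holds `Y` fixed on purpose, so only the log-determinant term changes.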

calc_source_model

calc_source_model(B, A, y_power)

Calculate source model. By overriding this method, various source models (e.g., Student t, ILRMA-T, generalized Kullback-Leibler divergence, or IDLMA) can be applied.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| B | ndarray of shape (n_freq, n_basis) | Basis matrix | required |
| A | ndarray of shape (n_frame, n_basis) | Activation matrix | required |
| y_power | ndarray of shape (n_frame, n_freq) | Power spectrograms of estimated source | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[ndarray, ndarray] | Updated (basis, activation) tuple, where basis has shape (n_freq, n_basis) and activation has shape (n_frame, n_basis). |

Source code in src/oobss/separators/ilrma.py
def calc_source_model(self, B, A, y_power):
    """
    Calculate source model.
    By overriding this method, various source models (e.g., Student t, ILRMA-T, generalized Kullback-Leibler divergence, or IDLMA) can be applied.

    Parameters
    ----------
    B : ndarray of shape (n_freq, n_basis)
        Basis matrix
    A : ndarray of shape (n_frame, n_basis)
        Activation matrix
    y_power : ndarray of shape (n_frame, n_freq)
        Power spectrograms of estimated source

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        Updated ``(basis, activation)`` tuple where:
        - basis has shape ``(n_freq, n_basis)``
        - activation has shape ``(n_frame, n_basis)``
    """
    updated = self.source_strategy.update(
        SourceModelRequest(
            basis=B,
            activ=A,
            y_power=y_power,
        )
    )
    if updated.basis is None or updated.activ is None:
        raise ValueError("source strategy did not return basis/activ.")
    return updated.basis, updated.activ
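
For orientation, the multiplicative updates written in the ILRMA docstring can be sketched for a single source with the same argument shapes that `calc_source_model` receives. This follows the docstring equations literally; the actual `ILRMANMFSourceStrategy` may differ (for example in flooring or normalization), and the name `mu_update` and the `eps` floor are assumptions.

```python
import numpy as np


def mu_update(B, A, y_power, eps=1e-12):
    """One multiplicative-update pass for a single source.

    B: (n_freq, n_basis), A: (n_frame, n_basis), y_power: (n_frame, n_freq).
    """
    R = np.maximum(A @ B.T, eps)  # r_{f,t} stored as (n_frame, n_freq)
    # b <- b * sum_t |y|^2 r^-2 c / sum_t r^-1 c
    B = B * ((y_power / R**2).T @ A) / np.maximum((1.0 / R).T @ A, eps)
    R = np.maximum(A @ B.T, eps)  # refresh the model with the new basis
    # c <- c * sum_f |y|^2 r^-2 b / sum_f r^-1 b
    A = A * ((y_power / R**2) @ B) / np.maximum((1.0 / R) @ B, eps)
    return np.maximum(B, eps), np.maximum(A, eps)


rng = np.random.default_rng(0)
B0 = rng.uniform(0.5, 1.5, size=(6, 3))
A0 = rng.uniform(0.5, 1.5, size=(9, 3))

# If the model already matches the power spectrogram exactly, both ratios equal 1
# and the factors stay put (a fixed point of the updates).
B1, A1 = mu_update(B0, A0, A0 @ B0.T)
```

The fixed-point check is a convenient sanity test for any replacement source strategy that keeps the same model.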

fit_transform_tf

fit_transform_tf(mixture_tf, *, n_iter=0, request=None)

Bind TF input, run iterations, and return TF-domain estimate.

Source code in src/oobss/separators/core/base.py
def fit_transform_tf(
    self,
    mixture_tf: np.ndarray,
    *,
    n_iter: int = 0,
    request: BatchRequest | None = None,
) -> SeparationOutput:
    """Bind TF input, run iterations, and return TF-domain estimate."""
    del request
    self.bind_mixture_tf(np.asarray(mixture_tf))
    if n_iter > 0:
        self.run(int(n_iter))
    return SeparationOutput(estimate_tf=self.get_estimate())

fit_transform_time

fit_transform_time(mixture_time, *, n_iter=0, request=None)

Bind time-domain input if supported, then run iterations.

Source code in src/oobss/separators/core/base.py
def fit_transform_time(
    self,
    mixture_time: np.ndarray,
    *,
    n_iter: int = 0,
    request: BatchRequest | None = None,
) -> SeparationOutput:
    """Bind time-domain input if supported, then run iterations."""
    sample_rate = None if request is None else request.sample_rate
    self.bind_mixture_time(np.asarray(mixture_time), sample_rate)
    if n_iter > 0:
        self.run(int(n_iter))
    return SeparationOutput(estimate_tf=self.get_estimate())

forward

forward(mixture, *, n_iter=0, request=None, is_time_input=None)

Run batch separation from TF-domain or time-domain input.

Source code in src/oobss/separators/core/base.py
def forward(
    self,
    mixture: np.ndarray,
    *,
    n_iter: int = 0,
    request: BatchRequest | None = None,
    is_time_input: bool | None = None,
) -> SeparationOutput:
    """Run batch separation from TF-domain or time-domain input."""
    mixture_arr = np.asarray(mixture)
    if is_time_input is None:
        is_time = not (np.iscomplexobj(mixture_arr) or mixture_arr.ndim >= 3)
    else:
        is_time = bool(is_time_input)
    if is_time:
        return self.fit_transform_time(
            mixture_arr,
            n_iter=int(n_iter),
            request=request,
        )
    return self.fit_transform_tf(
        mixture_arr,
        n_iter=int(n_iter),
        request=request,
    )
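
When `is_time_input` is left as `None`, `forward` decides between the two paths with a simple heuristic. The snippet below mirrors that one line from the listing in a standalone helper (the name `looks_like_time_input` is hypothetical) so the outcome for typical inputs is easy to see.

```python
import numpy as np


def looks_like_time_input(mixture):
    """Mirror the default heuristic: real-valued arrays with fewer than 3 dims."""
    arr = np.asarray(mixture)
    return not (np.iscomplexobj(arr) or arr.ndim >= 3)


time_like = looks_like_time_input(np.zeros((320, 2)))                # real, 2-D
tf_like = looks_like_time_input(np.zeros((6, 9, 2), dtype=complex))  # complex, 3-D
real_3d = looks_like_time_input(np.zeros((6, 9, 2)))                 # real, 3-D
```

Here `time_like` is `True` and the other two are `False`. For ILRMA, which inherits the base `bind_mixture_time` shown above, a time-like input ends in a `ValueError`, so the STFT has to be computed by the caller.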

get_estimate

get_estimate()

Return current TF-domain estimate.

Source code in src/oobss/separators/ilrma.py
def get_estimate(self) -> np.ndarray:
    """Return current TF-domain estimate."""
    return self.estimated

init_activ

init_activ()

Initialize activation matrix.

Source code in src/oobss/separators/ilrma.py
def init_activ(self):
    """Initialize activation matrix."""
    n_frame, _, n_src = self.observations.shape
    return self.rng.uniform(
        low=0.1,
        high=1.0,
        size=(n_src, n_frame, self.n_basis),
    )

init_basis

init_basis()

Initialize basis matrix.

Source code in src/oobss/separators/ilrma.py
def init_basis(self):
    """Initialize basis matrix."""
    _, n_freq, n_src = self.observations.shape
    return np.ones((n_src, n_freq, self.n_basis))

init_demix

init_demix()

Initialize demixing matrix.

Source code in src/oobss/separators/ilrma.py
def init_demix(self):
    """Initialize demixing matrix."""
    _, n_freq, n_src = self.observations.shape
    W0 = np.zeros((n_freq, n_src, n_src), dtype=complex)
    W0[:, :, :n_src] = np.tile(np.eye(n_src, dtype=complex), (n_freq, 1, 1))
    return W0

init_source_model

init_source_model()

Initialize source variance model R with shape (n_frame, n_freq, n_src).

Source code in src/oobss/separators/ilrma.py
def init_source_model(self):
    """Initialize source variance model ``R`` with shape ``(n_frame, n_freq, n_src)``."""
    return self._compose_source_model()
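
The source model is built in two ways in the listings: `_compose_source_model` uses a single `einsum`, while `step` fills one source at a time with `activ[s] @ basis[s].T`. A small check, using made-up sizes and random factors, suggests the two forms agree:

```python
import numpy as np

rng = np.random.default_rng(0)
n_src, n_freq, n_frame, n_basis = 2, 5, 7, 3
basis = rng.uniform(0.1, 1.0, size=(n_src, n_freq, n_basis))
activ = rng.uniform(0.1, 1.0, size=(n_src, n_frame, n_basis))

# Same contraction as _compose_source_model: sum over the basis index k.
R_einsum = np.einsum("sfk,stk->tfs", basis, activ)

# Per-source form used inside step(): (T, L) @ (L, F) -> (T, F).
R_loop = np.stack([activ[s] @ basis[s].T for s in range(n_src)], axis=-1)
```

Both arrays have shape `(n_frame, n_freq, n_src)`, matching the layout of `estimated`.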

reset

reset()

Reset internal state (override in subclasses when needed).

Source code in src/oobss/separators/core/base.py
def reset(self) -> None:
    """Reset internal state (override in subclasses when needed)."""

run

run(n_iter)

Execute n_iter update steps and return final estimate.

Source code in src/oobss/separators/core/base.py
def run(self, n_iter: int) -> np.ndarray:
    """Execute ``n_iter`` update steps and return final estimate."""
    if n_iter < 0:
        raise ValueError("n_iter must be non-negative")
    for _ in range(n_iter):
        self.step()
    return self.get_estimate()

step

step()

Update parameters by one step.

Source code in src/oobss/separators/ilrma.py
def step(self):
    """Update parameters by one step."""
    n_src = self.observations.shape[-1]
    y_power = np.square(np.abs(self.estimated))

    row_groups = self.spatial_strategy.row_groups(n_src)
    pairwise_mode = any(np.ndim(row_idx) > 0 for row_idx in row_groups)

    if pairwise_mode:
        for row_idx in row_groups:
            for src_idx in np.atleast_1d(np.asarray(row_idx, dtype=np.int64)):
                b = self.basis[src_idx]
                a = self.activ[src_idx]
                self.basis[src_idx], self.activ[src_idx] = self.calc_source_model(
                    b, a, y_power[:, :, src_idx]
                )
                self.source_model[:, :, src_idx] = (
                    self.activ[src_idx] @ self.basis[src_idx].T
                )
            self.covariance = self.covariance_strategy.update(
                CovarianceRequest(
                    observed=self.observations,
                    source_model=self.source_model,
                )
            )
            self.demix_filter[:, :, :] = self.spatial_strategy.update(
                self.covariance,
                self.demix_filter,
                row_idx=row_idx,
            )
    else:
        for s in range(n_src):
            b = self.basis[s]
            a = self.activ[s]
            self.basis[s], self.activ[s] = self.calc_source_model(
                b, a, y_power[:, :, s]
            )
            self.source_model[:, :, s] = self.activ[s] @ self.basis[s].T
        self.covariance = self.covariance_strategy.update(
            CovarianceRequest(
                observed=self.observations,
                source_model=self.source_model,
            )
        )
        for row_idx in row_groups:
            self.demix_filter[:, :, :] = self.spatial_strategy.update(
                self.covariance,
                self.demix_filter,
                row_idx=row_idx,
            )

    # 3. Update estimated sources
    recon = self.reconstruction_strategy.reconstruct(
        ReconstructionRequest(
            mixture=self.observations,
            demix_filter=self.demix_filter,
        )
    )
    self.estimated = recon.estimate

    # 4. Update loss function value
    self.loss = self.calc_loss()
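
The spatial update that `step` delegates to `spatial_strategy.update` defaults to IP1. A minimal sketch of the IP1 rule from the class docstring, for one frequency bin and one row, is given below. It assumes the rows of `W` store \(\bm{w}_k^{\mathsf{H}}\), as implied by \(\bm{y} = W\bm{x}\); the function name `ip1_update_row` is hypothetical and is not the library's `IP1SpatialStrategy`.

```python
import numpy as np


def ip1_update_row(W, V_k, k):
    """Update row k of one frequency bin's demixing matrix.

    W: (M, M) with rows w_j^H; V_k: (M, M) Hermitian positive definite.
    """
    e_k = np.zeros(W.shape[0], dtype=complex)
    e_k[k] = 1.0
    w = np.linalg.solve(W @ V_k, e_k)             # (W V_k)^{-1} e_k
    w = w / np.sqrt(np.real(w.conj() @ V_k @ w))  # scale so w^H V_k w = 1
    W_new = W.copy()
    W_new[k] = w.conj()                           # store as row w_k^H
    return W_new


rng = np.random.default_rng(0)
M = 3
Z = rng.standard_normal((50, M)) + 1j * rng.standard_normal((50, M))
V = Z.T @ Z.conj() / 50  # (1/T) sum_t z_t z_t^H, Hermitian positive definite
W = np.eye(M, dtype=complex)
W1 = ip1_update_row(W, V, k=1)
w1 = W1[1].conj()
```

After the update, `W1 @ V @ w1` equals the unit vector \(\bm{e}_1\): the new row is normalized under `V` and the other rows are orthogonal to it under `V`.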

JsonlLogger

Append JSON-serializable records to a JSON Lines file.

Source code in src/oobss/logging_utils.py
class JsonlLogger:
    """Append JSON-serializable records to a JSON Lines file."""

    def __init__(self, path: str | Path) -> None:
        self.path = Path(path)
        self.path.parent.mkdir(parents=True, exist_ok=True)

    def write(self, record: Mapping[str, Any]) -> None:
        with self.path.open("a", encoding="utf-8") as handle:
            handle.write(json.dumps(dict(record), ensure_ascii=False) + "\n")
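
The file format is one JSON object per line, so records can be read back with the standard library alone. The sketch below repeats the append logic of `write` inline rather than importing the class, and the record fields (`step`, `loss`) are made up for illustration.

```python
import json
import tempfile
from pathlib import Path

records = [{"step": 0, "loss": 12.5}, {"step": 1, "loss": 9.75}]

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "logs" / "run.jsonl"
    path.parent.mkdir(parents=True, exist_ok=True)
    # Append one JSON object per line, as JsonlLogger.write does.
    for record in records:
        with path.open("a", encoding="utf-8") as handle:
            handle.write(json.dumps(record, ensure_ascii=False) + "\n")
    # Read the file back line by line, skipping any blank lines.
    with path.open(encoding="utf-8") as handle:
        loaded = [json.loads(line) for line in handle if line.strip()]
```

Because the file is opened in append mode for every record, a partially written run still leaves a valid prefix of complete lines.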

OnlineAuxIVA

Bases: BaseStreamingSeparator

Online auxiliary-function-based independent vector analysis.

Procedure
   input: frame sequence {x_{f,t}}_{f=1..F, t=1..T}
   initialize W_{f,0} = I_M and V_{k,f,0}
   for each frame t:
       repeat inner_iter times:
           y_{f,t} <- W_{f,t} x_{f,t}
           compute r_{k,t} and phi(r_{k,t}) (Gauss)
           V_{k,f,t} <- (1-alpha) * phi(r_{k,t}) x_{f,t} x_{f,t}^H
                      + alpha * V_{k,f,t-1}
           update w_{k,f,t} by IP1
       projection-back by reference microphone
       emit separated frame

Update Equations

Indices are \(k=1,\dots,K\) (source), \(m=1,\dots,M\) (channel), \(f=1,\dots,F\) (frequency), and \(t=1,\dots,T\) (time frame). In the determined case, \(K=M\).

Per-frame demixing:

\[ \bm{x}_{f,t} \in \mathbb{C}^{M}, \quad \bm{y}_{f,t} = W_{f,t}\bm{x}_{f,t}, \quad y_{k,f,t} = \bm{w}_{k,f,t}^{\mathsf{H}}\bm{x}_{f,t} \]

The default source model is AuxIVA Gauss:

\[ r_{k,t} = \left(\sum_{f=1}^{F}|y_{k,f,t}|^2\right)^{1/2}, \quad \varphi(r_{k,t}) = \frac{1}{\max(r_{k,t}^2/F,\varepsilon)} \]

Online covariance recursion:

\[ \widehat{V}_{k,f,t} = \varphi(r_{k,t}) \bm{x}_{f,t}\bm{x}_{f,t}^{\mathsf{H}}, \quad V_{k,f,t} \leftarrow (1-\alpha)\widehat{V}_{k,f,t} + \alpha V_{k,f,t-1} \]
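
A minimal NumPy sketch of the Gauss weight and the covariance recursion above, for a single frame, may help when reading the equations. The array layouts, the helper names `gauss_weight` and `update_covariance`, and the initial value `1e-6 * I` are assumptions for illustration; the library's own strategy classes are not used here.

```python
import numpy as np


def gauss_weight(y, eps=1e-12):
    """y: (F, K) separated frame; returns phi with shape (K,)."""
    n_freq = y.shape[0]
    r_sq = np.sum(np.abs(y) ** 2, axis=0)  # r_{k,t}^2
    return 1.0 / np.maximum(r_sq / n_freq, eps)


def update_covariance(V_prev, x, phi, alpha):
    """V_prev: (K, F, M, M); x: (F, M) current frame; phi: (K,) weights."""
    # Instantaneous weighted outer product phi_k * x_f x_f^H for each (k, f).
    outer = np.einsum("fm,fn->fmn", x, x.conj())[None]
    V_hat = phi[:, None, None, None] * outer
    return (1.0 - alpha) * V_hat + alpha * V_prev


rng = np.random.default_rng(0)
F, M = 4, 2
x = rng.standard_normal((F, M)) + 1j * rng.standard_normal((F, M))
W = np.tile(np.eye(M, dtype=complex), (F, 1, 1))
y = np.einsum("fkm,fm->fk", W, x)  # y_{k,f} with rows of W_f as w_{k,f}^H
phi = gauss_weight(y)
V0 = 1e-6 * np.tile(np.eye(M, dtype=complex), (M, F, 1, 1))
V1 = update_covariance(V0, x, phi, alpha=0.9)
```

With `alpha` close to 1 the previous covariance dominates, so the statistics adapt slowly; smaller values track the current frame more closely.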

Demixing row update (IP1):

\[ \tilde{\bm{w}}_{k,f,t} = \left(W_{f,t}V_{k,f,t}\right)^{-1}\bm{e}_k, \quad \bm{w}_{k,f,t} = \frac{\tilde{\bm{w}}_{k,f,t}} {\sqrt{ \tilde{\bm{w}}_{k,f,t}^{\mathsf{H}} V_{k,f,t} \tilde{\bm{w}}_{k,f,t} }} \]

Common projection-back post-processing:

\[ A_{f,t} = W_{f,t}^{-1}, \quad \hat{y}_{k,f,t} = a_{k,f,t}[m_{\mathrm{ref}}] y_{k,f,t} \]

This is shared with batch AuxIVA/ILRMA via `oobss.separators.utils.projection_back`.
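
The projection-back rule can be illustrated for one frame with a few lines of NumPy. This is a sketch of the equation above, not the library function: the name `project_back_sketch`, its signature, and the `(n_freq, n_src)` frame layout are assumptions.

```python
import numpy as np


def project_back_sketch(Y, W, ref_mic=0):
    """Y: (F, K) separated frame; W: (F, K, K) demixing matrices."""
    A = np.linalg.inv(W)      # A_f = W_f^{-1}, shape (F, M, K)
    scale = A[:, ref_mic, :]  # a_{k,f}[m_ref] for every source k
    return scale * Y


rng = np.random.default_rng(0)
F, M = 5, 3
x = rng.standard_normal((F, M)) + 1j * rng.standard_normal((F, M))
W = (
    rng.standard_normal((F, M, M))
    + 1j * rng.standard_normal((F, M, M))
    + 3.0 * np.eye(M)
)
y = np.einsum("fkm,fm->fk", W, x)  # y_f = W_f x_f
y_hat = project_back_sketch(y, W, ref_mic=0)
```

Since \(\bm{x} = A\bm{y}\), the rescaled sources sum back to the reference channel: `y_hat.sum(axis=1)` matches `x[:, 0]`. This removes the scale ambiguity of the demixing matrices.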

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_mic | int | Number of microphones / separated sources. | required |
| n_freq | int | Number of frequency bins. | required |
| ref_mic | int | Reference microphone for projection-back reconstruction. | 0 |
| forget | float | Forgetting factor in covariance smoothing. | 0.9 |
| inner_iter | int | Number of per-frame inner updates. | 30 |
| eps | float | Numerical stability constant. | 1e-12 |
| cov_scale | float | Initial diagonal covariance scale. | 1e-06 |
| spatial | SpatialUpdateStrategy \| None | Strategy used to update demixing filters. | None |
| source | SourceModelStrategy \| None | Strategy used to compute source model per frame. | None |
| covariance | CovarianceUpdateStrategy \| None | Strategy used to update weighted covariance matrices. | None |
| reconstruction_strategy | ReconstructionStrategy \| None | Strategy used to reconstruct output spectra. | None |

Examples:

Process a full stream at once:

   import numpy as np
   from scipy.signal import ShortTimeFFT, get_window
   from oobss import OnlineAuxIVA, StreamRequest

   fs = 16000
   fft_size = 2048
   hop_size = 512
   win = get_window("hann", fft_size, fftbins=True)
   stft = ShortTimeFFT(win=win, hop=hop_size, fs=fs)

   mixture_time = np.random.randn(fs * 2, 2)  # (n_samples, n_mic)
   # channel-first STFT: (n_mic, n_freq, n_frame)
   X_cft = stft.stft(mixture_time.T)
   # online input: (n_freq, n_mic, n_frame)
   X_fmt = X_cft.transpose(1, 0, 2)

   model = OnlineAuxIVA(
       n_mic=2,
       n_freq=X_fmt.shape[0],
       ref_mic=0,
       forget=0.99,
       inner_iter=5,
   )
   out = model.process_stream_tf(
       X_fmt,
       request=StreamRequest(frame_axis=2, reference_mic=0),
   )
   Y_fmt = out.estimate_tf
   if Y_fmt is None:
       raise ValueError("OnlineAuxIVA did not return TF estimates.")

   # inverse STFT expects channel-first axes
   y_time = np.real(stft.istft(Y_fmt, f_axis=0, t_axis=2)).T

Frame-by-frame update with explicit state carry:

   from oobss import StreamingSeparatorState

   state: StreamingSeparatorState | None = None
   outputs = []
   for t in range(X_fmt.shape[2]):
       frame = X_fmt[:, :, t]  # (n_freq, n_mic)
       y_frame, state = model.forward_streaming(frame, state=state)
       outputs.append(y_frame)
   Y_fmt = np.stack(outputs, axis=2)
References

[1] T. Taniguchi, N. Ono, A. Kawamura, and S. Sagayama, "An auxiliary-function approach to online independent vector analysis for real-time blind source separation," in Proc. Joint Workshop on Hands-free Speech Communication and Microphone Arrays (HSCMA), pp. 107-111, May 2014, doi: 10.1109/HSCMA.2014.6843261.

Source code in src/oobss/separators/online_auxiva.py
class OnlineAuxIVA(BaseStreamingSeparator):
    """Online auxiliary-function-based independent vector analysis.

    Procedure
    ---------
    ```text

       input: frame sequence {x_{f,t}}_{f=1..F, t=1..T}
       initialize W_{f,0} = I_M and V_{k,f,0}
       for each frame t:
           repeat inner_iter times:
               y_{f,t} <- W_{f,t} x_{f,t}
               compute r_{k,t} and phi(r_{k,t}) (Gauss)
               V_{k,f,t} <- (1-alpha) * phi(r_{k,t}) x_{f,t} x_{f,t}^H
                          + alpha * V_{k,f,t-1}
               update w_{k,f,t} by IP1
           projection-back by reference microphone
           emit separated frame
    ```

    Update Equations
    ----------------
    Indices are $k=1,\\dots,K$ (source), $m=1,\\dots,M$
    (channel), $f=1,\\dots,F$ (frequency), and
    $t=1,\\dots,T$ (time frame). In the determined case, $K=M$.

    Per-frame demixing:

    $$
       \\bm{x}_{f,t} \\in \\mathbb{C}^{M}, \\quad
       \\bm{y}_{f,t} = W_{f,t}\\bm{x}_{f,t}, \\quad
       y_{k,f,t} = \\bm{w}_{k,f,t}^{\\mathsf{H}}\\bm{x}_{f,t}
    $$

    The default source model is AuxIVA Gauss:

    $$
       r_{k,t} = \\left(\\sum_{f=1}^{F}|y_{k,f,t}|^2\\right)^{1/2}, \\quad
       \\varphi(r_{k,t}) = \\frac{1}{\\max(r_{k,t}^2/F,\\varepsilon)}
    $$

    Online covariance recursion:

    $$
       \\widehat{V}_{k,f,t}
       = \\varphi(r_{k,t})
       \\bm{x}_{f,t}\\bm{x}_{f,t}^{\\mathsf{H}},
       \\quad
       V_{k,f,t}
       \\leftarrow
       (1-\\alpha)\\widehat{V}_{k,f,t} + \\alpha V_{k,f,t-1}
    $$

    Demixing row update (IP1):

    $$
       \\tilde{\\bm{w}}_{k,f,t}
       = \\left(W_{f,t}V_{k,f,t}\\right)^{-1}\\bm{e}_k, \\quad
       \\bm{w}_{k,f,t}
       = \\frac{\\tilde{\\bm{w}}_{k,f,t}}
       {\\sqrt{
       \\tilde{\\bm{w}}_{k,f,t}^{\\mathsf{H}}
       V_{k,f,t}
       \\tilde{\\bm{w}}_{k,f,t}
       }}
    $$

    Common projection-back post-processing:

    $$
       A_{f,t} = W_{f,t}^{-1}, \\quad
       \\hat{y}_{k,f,t} = a_{k,f,t}[m_{\\mathrm{ref}}] y_{k,f,t}
    $$

    This is shared with batch AuxIVA/ILRMA via
    :func:`oobss.separators.utils.projection_back`.

    Parameters
    ----------
    n_mic:
        Number of microphones / separated sources.
    n_freq:
        Number of frequency bins.
    ref_mic:
        Reference microphone for projection-back reconstruction.
    forget:
        Forgetting factor in covariance smoothing.
    inner_iter:
        Number of per-frame inner updates.
    eps:
        Numerical stability constant.
    cov_scale:
        Initial diagonal covariance scale.
    spatial:
        Strategy used to update demixing filters.
    source:
        Strategy used to compute source model per frame.
    covariance:
        Strategy used to update weighted covariance matrices.
    reconstruction_strategy:
        Strategy used to reconstruct output spectra.

    Examples
    --------
    Process a full stream at once:

    ```python

       import numpy as np
       from scipy.signal import ShortTimeFFT, get_window
       from oobss import OnlineAuxIVA, StreamRequest

       fs = 16000
       fft_size = 2048
       hop_size = 512
       win = get_window("hann", fft_size, fftbins=True)
       stft = ShortTimeFFT(win=win, hop=hop_size, fs=fs)

       mixture_time = np.random.randn(fs * 2, 2)  # (n_samples, n_mic)
       # channel-first STFT: (n_mic, n_freq, n_frame)
       X_cft = stft.stft(mixture_time.T)
       # online input: (n_freq, n_mic, n_frame)
       X_fmt = X_cft.transpose(1, 0, 2)

       model = OnlineAuxIVA(
           n_mic=2,
           n_freq=X_fmt.shape[0],
           ref_mic=0,
           forget=0.99,
           inner_iter=5,
       )
       out = model.process_stream_tf(
           X_fmt,
           request=StreamRequest(frame_axis=2, reference_mic=0),
       )
       Y_fmt = out.estimate_tf
       if Y_fmt is None:
           raise ValueError("OnlineAuxIVA did not return TF estimates.")

       # inverse STFT: frequency is on axis 0 and frames are on axis 2
       y_time = np.real(stft.istft(Y_fmt, f_axis=0, t_axis=2)).T
    ```

    Frame-by-frame update with explicit state carry:

    ```python

       from oobss import StreamingSeparatorState

       state: StreamingSeparatorState | None = None
       outputs = []
       for t in range(X_fmt.shape[2]):
           frame = X_fmt[:, :, t]  # (n_freq, n_mic)
           y_frame, state = model.forward_streaming(frame, state=state)
           outputs.append(y_frame)
       Y_fmt = np.stack(outputs, axis=2)
    ```

    References
    ----------
    [1] T. Taniguchi, N. Ono, A. Kawamura, and S. Sagayama, "An
    auxiliary-function approach to online independent vector analysis for
    real-time blind source separation," in *Proc. Joint Workshop on Hands-free
    Speech Communication and Microphone Arrays (HSCMA)*, pp. 107-111, May 2014,
    doi: 10.1109/HSCMA.2014.6843261.
    """

    def __init__(
        self,
        n_mic: int,
        n_freq: int,
        *,
        ref_mic: int = 0,
        forget: float = 0.9,
        inner_iter: int = 30,
        eps: float = 1.0e-12,
        cov_scale: float = 1.0e-6,
        spatial: SpatialUpdateStrategy | None = None,
        source: SourceModelStrategy | None = None,
        covariance: CovarianceUpdateStrategy | None = None,
        reconstruction_strategy: ReconstructionStrategy | None = None,
    ) -> None:
        self.n_mic = int(n_mic)
        self.n_freq = int(n_freq)
        self.ref_mic = int(ref_mic)
        self.alpha = float(forget)
        self.inner_iter = int(inner_iter)
        self.eps = float(eps)
        self.cov_scale = float(cov_scale)

        self.spatial_strategy = spatial if spatial is not None else IP1SpatialStrategy()
        self.source_strategy = (
            source if source is not None else GaussSourceStrategy(eps=self.eps)
        )
        self.covariance_strategy = (
            covariance
            if covariance is not None
            else EMACovarianceStrategy(alpha=self.alpha)
        )
        self.reconstruction_strategy = (
            reconstruction_strategy
            if reconstruction_strategy is not None
            else ProjectionBackDemixReconstructionStrategy(ref_mic=self.ref_mic)
        )

        self.reset()

    def reset(self) -> None:
        """Reset online state to its initial values."""
        self.demix = np.tile(np.eye(self.n_mic, dtype=complex), (self.n_freq, 1, 1))
        self.cov = (
            np.tile(np.eye(self.n_mic, dtype=complex), (self.n_mic, self.n_freq, 1, 1))
            * self.cov_scale
        )
        self.source_model = np.ones((self.n_freq, self.n_mic), dtype=np.float64)
        self._t = 0

    def partial_fit(
        self,
        x: np.ndarray,
        *,
        reference_mic: int | None = None,
    ) -> np.ndarray:
        """Update model with one frame and return separated spectra."""
        if x.shape != (self.n_freq, self.n_mic):
            raise ValueError(
                "x must have shape (n_freq, n_mic) "
                f"= ({self.n_freq}, {self.n_mic}), got {x.shape}"
            )

        source_model = np.ones((self.n_freq, self.n_mic), dtype=np.float64)
        for _ in range(self.inner_iter):
            demixed = (self.demix @ x[:, :, None])[:, :, 0]
            source_model_result = self.source_strategy.update(
                SourceModelRequest(
                    estimated=demixed[None, :, :],
                    n_freq=self.n_freq,
                )
            )
            if source_model_result.source_model is None:
                raise ValueError("source strategy must return source_model")
            source_model = np.maximum(source_model_result.source_model[0], self.eps)

            self.cov = self.covariance_strategy.update(
                CovarianceRequest(
                    observed=x[None, :, :],
                    source_model=source_model[None, :, :],
                    prev_cov=self.cov,
                    alpha=self.alpha,
                )
            )

            for row_idx in self.spatial_strategy.row_groups(self.n_mic):
                self.demix = self.spatial_strategy.update(
                    self.cov,
                    self.demix,
                    row_idx=row_idx,
                )

        self.source_model = source_model
        ref = self.ref_mic if reference_mic is None else int(reference_mic)
        recon_strategy = self.reconstruction_strategy
        if isinstance(recon_strategy, ProjectionBackDemixReconstructionStrategy):
            recon_strategy = ProjectionBackDemixReconstructionStrategy(ref_mic=ref)

        output = recon_strategy.reconstruct(
            ReconstructionRequest(
                mixture=x,
                demix_filter=self.demix,
            )
        )
        self._t += 1
        return output.estimate

    def get_state(self) -> StreamingSeparatorState:
        """Return a typed snapshot of the current online state."""
        mix_filter: np.ndarray | None
        try:
            mix_filter = np.linalg.inv(self.demix)
        except np.linalg.LinAlgError:
            mix_filter = None
        return StreamingSeparatorState(
            source_model=np.array(self.source_model, copy=True),
            demix_filter=np.array(self.demix, copy=True),
            mix_filter=None if mix_filter is None else np.array(mix_filter, copy=True),
            frame_index=int(self._t),
            metadata={"covariance": np.array(self.cov, copy=True)},
        )

    def set_state(
        self,
        state: SeparatorState | StreamingSeparatorState,
    ) -> None:
        """Restore online state from :class:`StreamingSeparatorState`."""
        if not isinstance(state, StreamingSeparatorState):
            raise TypeError(
                "OnlineAuxIVA.set_state expects StreamingSeparatorState, "
                f"got {type(state).__name__}"
            )
        demix = state.demix_filter
        if demix is None:
            raise ValueError("state.demix_filter must be provided")
        if demix.shape != (self.n_freq, self.n_mic, self.n_mic):
            raise ValueError(
                "state.demix_filter must have shape "
                f"({self.n_freq}, {self.n_mic}, {self.n_mic}), got {demix.shape}"
            )
        self.demix = np.array(demix, copy=True)

        cov = state.metadata.get("covariance")
        if cov is None:
            cov = (
                np.tile(
                    np.eye(self.n_mic, dtype=complex), (self.n_mic, self.n_freq, 1, 1)
                )
                * self.cov_scale
            )
        cov_arr = np.asarray(cov)
        if cov_arr.shape != (self.n_mic, self.n_freq, self.n_mic, self.n_mic):
            raise ValueError(
                "state.metadata['covariance'] must have shape "
                f"({self.n_mic}, {self.n_freq}, {self.n_mic}, {self.n_mic}), "
                f"got {cov_arr.shape}"
            )
        self.cov = np.array(cov_arr, copy=True)

        if state.source_model is None:
            self.source_model = np.ones((self.n_freq, self.n_mic), dtype=np.float64)
        else:
            if state.source_model.shape != (self.n_freq, self.n_mic):
                raise ValueError(
                    "state.source_model must have shape "
                    f"({self.n_freq}, {self.n_mic}), got {state.source_model.shape}"
                )
            self.source_model = np.array(state.source_model, copy=True)
        self._t = int(state.frame_index)

    def fit(self, spectrogram: np.ndarray) -> None:
        """Process a full spectrogram with shape ``(n_freq, n_frames, n_mic)``."""
        for idx in range(spectrogram.shape[1]):
            self.partial_fit(spectrogram[:, idx])

    @property
    def n_sources(self) -> int:
        """Return number of separated sources."""
        return int(self.n_mic)

    def process_frame(
        self,
        frame: np.ndarray,
        request: OnlineFrameRequest | None = None,
    ) -> np.ndarray:
        """Process one TF frame and return separated frame."""
        ref_mic = None if request is None else request.reference_mic
        return self.partial_fit(frame, reference_mic=ref_mic)
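
The IP1 row update from the docstring can be checked numerically for a single frequency bin. The sketch below is a minimal NumPy version under the assumption that the rows of W store w_k^H; the function name ip1_update_row is hypothetical and this is not the library's IP1SpatialStrategy.

```python
import numpy as np

def ip1_update_row(W, V_k, k):
    """Hypothetical helper: return a copy of W with row k refreshed by one IP1 step.

    W   : (M, M) demixing matrix whose rows are w_k^H
    V_k : (M, M) Hermitian positive-definite weighted covariance of source k
    """
    e_k = np.zeros(W.shape[0], dtype=complex)
    e_k[k] = 1.0
    w = np.linalg.solve(W @ V_k, e_k)              # (W V_k)^{-1} e_k
    w = w / np.sqrt(np.real(w.conj() @ V_k @ w))   # scale so that w^H V_k w = 1
    W_new = W.copy()
    W_new[k] = w.conj()                            # row k stores w_k^H
    return W_new

rng = np.random.default_rng(0)
M, T = 3, 200
X = rng.standard_normal((M, T)) + 1j * rng.standard_normal((M, T))
V_k = X @ X.conj().T / T                           # sample covariance
W = np.eye(M, dtype=complex) + 0.1 * rng.standard_normal((M, M))

k = 1
W_new = ip1_update_row(W, V_k, k)
w_k = W_new[k].conj()
quad = np.real(w_k.conj() @ V_k @ w_k)             # normalization term
cross = W[0] @ V_k @ w_k                           # w_0^H V_k w_k for another row
```

After the step, quad equals 1 up to rounding and cross vanishes, which are the two conditions the IP1 update is built to satisfy for the refreshed row.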

n_sources property

n_sources

Return the number of separated sources (equal to n_mic).

__call__

__call__(*args, **kwargs)

Alias for forward(), providing a torch-like call style.

Source code in src/oobss/separators/core/base.py
def __call__(self, *args: Any, **kwargs: Any) -> SeparationOutput:
    """Alias for :meth:`forward` to provide a torch-like call style."""
    return self.forward(*args, **kwargs)

fit

fit(spectrogram)

Process a full spectrogram with shape (n_freq, n_frames, n_mic).

Source code in src/oobss/separators/online_auxiva.py
def fit(self, spectrogram: np.ndarray) -> None:
    """Process a full spectrogram with shape ``(n_freq, n_frames, n_mic)``."""
    for idx in range(spectrogram.shape[1]):
        self.partial_fit(spectrogram[:, idx])
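
Note that fit indexes frames on axis 1, whereas the class examples build a frame-last array of shape (n_freq, n_mic, n_frame). A small sketch of the axis change, using only NumPy and made-up sizes:

```python
import numpy as np

n_freq, n_mic, n_frames = 4, 2, 6
rng = np.random.default_rng(0)
X_fmt = rng.standard_normal((n_freq, n_mic, n_frames)) + 0j   # (F, M, T)

# fit() iterates over axis 1, so move the frame axis there: (F, T, M)
X_ftm = np.moveaxis(X_fmt, 2, 1)

# This slice is what fit() would hand to partial_fit for idx = 3
frame_3 = X_ftm[:, 3]
```

Each slice then has the (n_freq, n_mic) shape that partial_fit checks for.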

forward

forward(stream_tf, *, request=None)

Torch-like alias for process_stream_tf: processes a full stream in one call.

Source code in src/oobss/separators/core/base.py
def forward(
    self,
    stream_tf: np.ndarray,
    *,
    request: StreamRequest | None = None,
) -> SeparationOutput:
    """Torch-like forward alias for full streaming input."""
    return self.process_stream_tf(stream_tf, request=request)

forward_streaming

forward_streaming(frame, *, state=None, request=None)

Process one frame and return (separated_frame, updated_state).

Source code in src/oobss/separators/core/base.py
def forward_streaming(
    self,
    frame: np.ndarray,
    *,
    state: SeparatorState | StreamingSeparatorState | None = None,
    request: OnlineFrameRequest | None = None,
) -> tuple[np.ndarray, SeparatorState | StreamingSeparatorState]:
    """Process one frame and return ``(separated_frame, updated_state)``."""
    if state is not None:
        self.set_state(state)
    output = self.process_frame(frame, request=request)
    return output, self.get_state()
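
The restore, process, snapshot pattern above can be illustrated without the library. ToyStreamer below is a hypothetical stand-in whose only state is a running sum; it is not an oobss class, and its forward_streaming merely mirrors the control flow of the method shown above.

```python
import numpy as np

class ToyStreamer:
    """Hypothetical stand-in: keeps a running sum as its only state."""

    def __init__(self):
        self.total = 0.0

    def get_state(self):
        return {"total": self.total}

    def set_state(self, state):
        self.total = state["total"]

    def process_frame(self, frame):
        self.total += float(np.sum(frame))
        return frame / (1.0 + abs(self.total))

    def forward_streaming(self, frame, *, state=None):
        # Same control flow as the base-class method: restore, process, snapshot.
        if state is not None:
            self.set_state(state)
        return self.process_frame(frame), self.get_state()

frames = [np.full(3, float(t + 1)) for t in range(4)]

# One object processing all frames in order.
ref = ToyStreamer()
expected = [ref.process_frame(f) for f in frames]

# A fresh object per frame, with the state handed over explicitly.
state = None
carried = []
for f in frames:
    y, state = ToyStreamer().forward_streaming(f, state=state)
    carried.append(y)
```

Both runs produce the same outputs, so the returned snapshot is all a caller needs in order to resume. When state is None, the call continues from whatever the object currently holds.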

get_state

get_state()

Return a typed snapshot of the current online state.

Source code in src/oobss/separators/online_auxiva.py
def get_state(self) -> StreamingSeparatorState:
    """Return a typed snapshot of the current online state."""
    mix_filter: np.ndarray | None
    try:
        mix_filter = np.linalg.inv(self.demix)
    except np.linalg.LinAlgError:
        mix_filter = None
    return StreamingSeparatorState(
        source_model=np.array(self.source_model, copy=True),
        demix_filter=np.array(self.demix, copy=True),
        mix_filter=None if mix_filter is None else np.array(mix_filter, copy=True),
        frame_index=int(self._t),
        metadata={"covariance": np.array(self.cov, copy=True)},
    )
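
The try/except around np.linalg.inv is what allows mix_filter to be None. A minimal sketch of that fallback on a stack of per-frequency matrices (safe_inverse is a hypothetical name, not part of oobss):

```python
import numpy as np

def safe_inverse(demix):
    """Hypothetical helper: invert a stack of (M, M) matrices, or return None."""
    try:
        return np.linalg.inv(demix)
    except np.linalg.LinAlgError:
        return None

good = np.tile(np.eye(2, dtype=complex), (4, 1, 1))   # (n_freq, M, M)
bad = good.copy()
bad[2] = 0.0                                          # one exactly singular bin

mix_good = safe_inverse(good)
mix_bad = safe_inverse(bad)                           # None: one bin cannot be inverted
```

A single singular bin makes the whole stack fail. NumPy raises only for exactly singular matrices, so an ill-conditioned demixing matrix still yields an inverse, possibly an inaccurate one.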

partial_fit

partial_fit(x, *, reference_mic=None)

Update the model with one frame of shape (n_freq, n_mic) and return the separated spectra for that frame.

Source code in src/oobss/separators/online_auxiva.py
def partial_fit(
    self,
    x: np.ndarray,
    *,
    reference_mic: int | None = None,
) -> np.ndarray:
    """Update model with one frame and return separated spectra."""
    if x.shape != (self.n_freq, self.n_mic):
        raise ValueError(
            "x must have shape (n_freq, n_mic) "
            f"= ({self.n_freq}, {self.n_mic}), got {x.shape}"
        )

    source_model = np.ones((self.n_freq, self.n_mic), dtype=np.float64)
    for _ in range(self.inner_iter):
        demixed = (self.demix @ x[:, :, None])[:, :, 0]
        source_model_result = self.source_strategy.update(
            SourceModelRequest(
                estimated=demixed[None, :, :],
                n_freq=self.n_freq,
            )
        )
        if source_model_result.source_model is None:
            raise ValueError("source strategy must return source_model")
        source_model = np.maximum(source_model_result.source_model[0], self.eps)

        self.cov = self.covariance_strategy.update(
            CovarianceRequest(
                observed=x[None, :, :],
                source_model=source_model[None, :, :],
                prev_cov=self.cov,
                alpha=self.alpha,
            )
        )

        for row_idx in self.spatial_strategy.row_groups(self.n_mic):
            self.demix = self.spatial_strategy.update(
                self.cov,
                self.demix,
                row_idx=row_idx,
            )

    self.source_model = source_model
    ref = self.ref_mic if reference_mic is None else int(reference_mic)
    recon_strategy = self.reconstruction_strategy
    if isinstance(recon_strategy, ProjectionBackDemixReconstructionStrategy):
        recon_strategy = ProjectionBackDemixReconstructionStrategy(ref_mic=ref)

    output = recon_strategy.reconstruct(
        ReconstructionRequest(
            mixture=x,
            demix_filter=self.demix,
        )
    )
    self._t += 1
    return output.estimate
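
The first two steps of the inner loop, batched demixing and the Gauss source weight, can be sketched directly from the docstring equation phi(r) = 1 / max(r^2 / F, eps). The function gauss_weights is hypothetical; the library's GaussSourceStrategy may lay out its result differently (the source above stores source_model with shape (n_freq, n_mic)).

```python
import numpy as np

def gauss_weights(demix, x, eps=1e-12):
    """Hypothetical helper: phi(r_k) = 1 / max(r_k**2 / F, eps) for one frame.

    demix : (F, M, M) demixing matrices
    x     : (F, M) observed frame
    """
    n_freq = x.shape[0]
    y = (demix @ x[:, :, None])[:, :, 0]          # batched W_f x_f -> (F, M)
    r_sq = np.sum(np.abs(y) ** 2, axis=0)         # r_k^2, summed over frequency
    return 1.0 / np.maximum(r_sq / n_freq, eps)   # one weight per source, (M,)

n_freq, n_mic = 8, 2
demix = np.tile(np.eye(n_mic, dtype=complex), (n_freq, 1, 1))
x = np.ones((n_freq, n_mic), dtype=complex)
x[:, 1] *= 2.0                                    # channel 1 has twice the amplitude

phi = gauss_weights(demix, x)
phi_silent = gauss_weights(demix, np.zeros((n_freq, n_mic), dtype=complex))
```

Louder sources get smaller weights, and the eps floor keeps the weight finite for an all-zero frame.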

process_frame

process_frame(frame, request=None)

Process one time-frequency frame of shape (n_freq, n_mic) and return the separated frame.

Source code in src/oobss/separators/online_auxiva.py
def process_frame(
    self,
    frame: np.ndarray,
    request: OnlineFrameRequest | None = None,
) -> np.ndarray:
    """Process one TF frame and return separated frame."""
    ref_mic = None if request is None else request.reference_mic
    return self.partial_fit(frame, reference_mic=ref_mic)

process_stream

process_stream(stream, *, frame_axis=-1, request=None)

Process all frames in stream and stack outputs on the last axis.

Source code in src/oobss/separators/core/base.py
def process_stream(
    self,
    stream: np.ndarray,
    *,
    frame_axis: int = -1,
    request: OnlineFrameRequest | None = None,
) -> np.ndarray:
    """Process all frames in ``stream`` and stack outputs on the last axis."""
    frames = np.moveaxis(stream, frame_axis, 0)
    if frames.shape[0] == 0:
        raise ValueError("stream must contain at least one frame")

    outputs = [self.process_frame(frames[0], request=request)]
    for idx in range(1, frames.shape[0]):
        outputs.append(self.process_frame(frames[idx], request=request))
    return np.stack(outputs, axis=-1)
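
One consequence of the moveaxis/stack pair is that the output is always frame-last, whatever frame_axis the input used. A small sketch with an identity frame processor (stack_frames is a hypothetical name that mirrors the method body above):

```python
import numpy as np

def stack_frames(stream, frame_axis, process_frame):
    """Hypothetical helper: iterate over frame_axis, then stack on the last axis."""
    frames = np.moveaxis(stream, frame_axis, 0)
    outputs = [process_frame(frame) for frame in frames]
    return np.stack(outputs, axis=-1)

n_frames, n_freq, n_mic = 5, 4, 2
stream_tfm = np.arange(n_frames * n_freq * n_mic, dtype=float).reshape(
    n_frames, n_freq, n_mic
)                                                   # frames on axis 0

out = stack_frames(stream_tfm, frame_axis=0, process_frame=lambda f: f)
```

Here the input is (n_frames, n_freq, n_mic) but the result is (n_freq, n_mic, n_frames), so callers that pass a non-default frame_axis should not expect the input layout back.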

process_stream_tf

process_stream_tf(stream_tf, *, request=None)

Process all frames in stream_tf using a typed stream request.

Source code in src/oobss/separators/core/base.py
def process_stream_tf(
    self,
    stream_tf: np.ndarray,
    *,
    request: StreamRequest | None = None,
) -> SeparationOutput:
    """Process all frames in ``stream_tf`` using a typed stream request."""
    request_obj = StreamRequest() if request is None else request
    frame_request = OnlineFrameRequest(
        n_sources=request_obj.n_sources,
        component_to_source=request_obj.component_to_source,
        return_mask=request_obj.return_mask,
        reference_mic=request_obj.reference_mic,
        metadata=dict(request_obj.metadata),
    )
    output = self.process_stream(
        stream_tf,
        frame_axis=int(request_obj.frame_axis),
        request=frame_request,
    )
    return SeparationOutput(estimate_tf=output, state=self.get_state())

reset

reset()

Reset online state to its initial values.

Source code in src/oobss/separators/online_auxiva.py
def reset(self) -> None:
    """Reset online state to its initial values."""
    self.demix = np.tile(np.eye(self.n_mic, dtype=complex), (self.n_freq, 1, 1))
    self.cov = (
        np.tile(np.eye(self.n_mic, dtype=complex), (self.n_mic, self.n_freq, 1, 1))
        * self.cov_scale
    )
    self.source_model = np.ones((self.n_freq, self.n_mic), dtype=np.float64)
    self._t = 0
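
The two np.tile calls fix the axis order of the state arrays, which set_state later validates. The sketch below repeats them with small, made-up sizes to show the resulting layouts:

```python
import numpy as np

n_mic, n_freq, cov_scale = 2, 3, 1e-6

# One identity demixing matrix per frequency bin: (n_freq, n_mic, n_mic)
demix = np.tile(np.eye(n_mic, dtype=complex), (n_freq, 1, 1))

# One scaled identity per (source, frequency) pair: (n_mic, n_freq, n_mic, n_mic)
cov = np.tile(np.eye(n_mic, dtype=complex), (n_mic, n_freq, 1, 1)) * cov_scale
```

The covariance array is source-major, with frequency on the second axis, while the demixing array is frequency-major.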

set_state

set_state(state)

Restore online state from a StreamingSeparatorState snapshot.

Source code in src/oobss/separators/online_auxiva.py
def set_state(
    self,
    state: SeparatorState | StreamingSeparatorState,
) -> None:
    """Restore online state from :class:`StreamingSeparatorState`."""
    if not isinstance(state, StreamingSeparatorState):
        raise TypeError(
            "OnlineAuxIVA.set_state expects StreamingSeparatorState, "
            f"got {type(state).__name__}"
        )
    demix = state.demix_filter
    if demix is None:
        raise ValueError("state.demix_filter must be provided")
    if demix.shape != (self.n_freq, self.n_mic, self.n_mic):
        raise ValueError(
            "state.demix_filter must have shape "
            f"({self.n_freq}, {self.n_mic}, {self.n_mic}), got {demix.shape}"
        )
    self.demix = np.array(demix, copy=True)

    cov = state.metadata.get("covariance")
    if cov is None:
        cov = (
            np.tile(
                np.eye(self.n_mic, dtype=complex), (self.n_mic, self.n_freq, 1, 1)
            )
            * self.cov_scale
        )
    cov_arr = np.asarray(cov)
    if cov_arr.shape != (self.n_mic, self.n_freq, self.n_mic, self.n_mic):
        raise ValueError(
            "state.metadata['covariance'] must have shape "
            f"({self.n_mic}, {self.n_freq}, {self.n_mic}, {self.n_mic}), "
            f"got {cov_arr.shape}"
        )
    self.cov = np.array(cov_arr, copy=True)

    if state.source_model is None:
        self.source_model = np.ones((self.n_freq, self.n_mic), dtype=np.float64)
    else:
        if state.source_model.shape != (self.n_freq, self.n_mic):
            raise ValueError(
                "state.source_model must have shape "
                f"({self.n_freq}, {self.n_mic}), got {state.source_model.shape}"
            )
        self.source_model = np.array(state.source_model, copy=True)
    self._t = int(state.frame_index)

OnlineFrameRequest dataclass

Per-frame runtime options for streaming separators.

Source code in src/oobss/separators/core/strategy_models.py
@dataclass(slots=True)
class OnlineFrameRequest:
    """Per-frame runtime options for streaming separators."""

    n_sources: int | None = None
    component_to_source: np.ndarray | None = None
    return_mask: bool = False
    reference_mic: int | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

OnlineILRMA

Bases: BaseStreamingSeparator

Online independent low-rank matrix analysis.

Procedure
   input: frame sequence {x_{f,t}}_{f=1..F, t=1..T}
   initialize W_{f,0} = I_M, V_{k,f,0}
   initialize source-wise NMF parameters b_{k,f,\ell}, c_{k,\ell,t}
   for each frame t:
       repeat inner_iter times:
           y_{f,t} <- W_{f,t} x_{f,t}
           for each source k:
               update c_{k,\ell,t}, sufficient stats, and b_{k,f,\ell}
               r_{k,f,t} <- sum_\ell b_{k,f,\ell} c_{k,\ell,t}
           V_{k,f,t} <- (1-alpha) * x_{f,t} x_{f,t}^H / r_{k,f,t}
                      + alpha * V_{k,f,t-1}
           update w_{k,f,t} by IP1
       projection-back by reference microphone
       emit separated frame
Update Equations

Indices are \(k=1,\dots,K\) (source), \(m=1,\dots,M\) (channel), \(f=1,\dots,F\) (frequency), \(t=1,\dots,T\) (time frame), and \(\ell=1,\dots,L\) (NMF basis index). In the determined case, \(K=M\).

Demixing:

\[ \bm{x}_{f,t} \in \mathbb{C}^{M}, \quad \bm{y}_{f,t} = W_{f,t}\bm{x}_{f,t}, \quad y_{k,f,t} = \bm{w}_{k,f,t}^{\mathsf{H}}\bm{x}_{f,t} \]

For each source \(k\), define:

\[ v_{k,f,t} = |y_{k,f,t}|, \quad m_{k,f,t} = \sum_{\ell=1}^{L}b_{k,f,\ell}c_{k,\ell,t} + \varepsilon \]

The online multiplicative update (MU) for the activations \(c_{k,\ell,t}\) is:

\[ c_{k,\ell,t} \leftarrow c_{k,\ell,t} \frac{ \sum_{f=1}^{F} b_{k,f,\ell}v_{k,f,t}m_{k,f,t}^{-2} }{ \max\left( \sum_{f=1}^{F}b_{k,f,\ell}m_{k,f,t}^{-1}, \varepsilon \right) } \]
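
As a sanity check on the update above, the following NumPy sketch applies it to a single source and frame. The function name and array layout are assumptions made for illustration; this is not the library's NMF strategy.

```python
import numpy as np

def mu_update_activation(B, c, v, eps=1e-12):
    """Hypothetical helper: one multiplicative update of c for one source and frame.

    B : (F, L) nonnegative bases b_{f,l}
    c : (L,)   nonnegative activations c_l for the current frame
    v : (F,)   observed magnitudes v_f = |y_f|
    """
    m = B @ c + eps                               # model m_f
    numer = B.T @ (v / m**2)                      # sum_f b_{f,l} v_f m_f^{-2}
    denom = np.maximum(B.T @ (1.0 / m), eps)      # sum_f b_{f,l} m_f^{-1}
    return c * numer / denom

rng = np.random.default_rng(0)
F, L = 16, 4
B = rng.uniform(0.1, 1.0, size=(F, L))
c_true = rng.uniform(0.1, 1.0, size=L)

# If the observation already equals the model, the update leaves c unchanged.
v_exact = B @ c_true + 1e-12
c_fixed = mu_update_activation(B, c_true, v_exact)

# From a perturbed start the update keeps every activation nonnegative.
c_next = mu_update_activation(B, 2.0 * c_true, v_exact)
```

When v equals m the numerator and denominator coincide, so the true activations are a fixed point; nonnegativity is preserved because every factor in the update is nonnegative.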

Sufficient statistics and basis update:

\[ P_{k,f,\ell,t} \leftarrow P_{k,f,\ell,t-1} + v_{k,f,t}m_{k,f,t}^{-2}c_{k,\ell,t}b_{k,f,\ell}^2, \quad Q_{k,f,\ell,t} \leftarrow Q_{k,f,\ell,t-1} + m_{k,f,t}^{-1}c_{k,\ell,t} \]
\[ \rho_{k,t} = \alpha^{\beta/t_k}, \quad P_{k,f,\ell,t} \leftarrow \rho_{k,t} P_{k,f,\ell,t}, \; Q_{k,f,\ell,t} \leftarrow \rho_{k,t} Q_{k,f,\ell,t}, \quad b_{k,f,\ell} \leftarrow \sqrt{P_{k,f,\ell,t} \oslash Q_{k,f,\ell,t}} \]

Source variance and covariance update:

\[ r_{k,f,t} = \max\left( \sum_{\ell=1}^{L}b_{k,f,\ell}c_{k,\ell,t}, \varepsilon \right) \]
\[ V_{k,f,t} \leftarrow (1-\alpha) \frac{ \bm{x}_{f,t}\bm{x}_{f,t}^{\mathsf{H}} }{r_{k,f,t}} + \alpha V_{k,f,t-1} \]

Demixing and common projection back:

\[ \tilde{\bm{w}}_{k,f,t} = \left(W_{f,t}V_{k,f,t}\right)^{-1}\bm{e}_k, \quad \bm{w}_{k,f,t} = \frac{\tilde{\bm{w}}_{k,f,t}} {\sqrt{ \tilde{\bm{w}}_{k,f,t}^{\mathsf{H}} V_{k,f,t} \tilde{\bm{w}}_{k,f,t} }} \]
\[ A_{f,t} = W_{f,t}^{-1}, \quad \hat{y}_{k,f,t} = a_{k,f,t}[m_{\mathrm{ref}}] y_{k,f,t} \]

This projection-back rule is shared with batch AuxIVA/ILRMA via oobss.separators.utils.projection_back.
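
The projection-back rule can be sketched for a single frame as follows, reading a_{k,f,t}[m_ref] as entry (m_ref, k) of A_{f,t}. The function below is a hypothetical NumPy illustration and does not reproduce the signature of the library's projection_back.

```python
import numpy as np

def projection_back_frame(demix, x, ref_mic=0):
    """Hypothetical helper: scale each source to its image at the reference mic.

    demix : (F, M, M) demixing matrices W_f
    x     : (F, M) observed frame
    """
    y = (demix @ x[:, :, None])[:, :, 0]       # (F, M) separated frame
    mix = np.linalg.inv(demix)                 # A_f = W_f^{-1}
    scale = mix[:, ref_mic, :]                 # A_f[m_ref, k] for every source k
    return scale * y

rng = np.random.default_rng(0)
F, M = 6, 3
demix = rng.standard_normal((F, M, M)) + 1j * rng.standard_normal((F, M, M))
x = rng.standard_normal((F, M)) + 1j * rng.standard_normal((F, M))

y_hat = projection_back_frame(demix, x, ref_mic=1)
```

Because x = A y, the rescaled sources sum to the observation at the reference microphone, which resolves the per-frequency scale ambiguity of the demixing filters.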

Examples:

Process a stream with online ILRMA:

   import numpy as np
   from scipy.signal import ShortTimeFFT, get_window
   from oobss import OnlineILRMA, StreamRequest

   fs = 16000
   fft_size = 2048
   hop_size = 512
   win = get_window("hann", fft_size, fftbins=True)
   stft = ShortTimeFFT(win=win, hop=hop_size, fs=fs)

   mixture_time = np.random.randn(fs * 2, 2)  # (n_samples, n_mic)
   X_fmt = stft.stft(mixture_time.T).transpose(1, 0, 2)  # (F, M, T)

   model = OnlineILRMA(
       n_mic=2,
       n_freq=X_fmt.shape[0],
       n_bases=8,
       ref_mic=0,
       beta=1,
       forget=0.99,
       inner_iter=5,
       random_state=0,
   )
   out = model.process_stream_tf(
       X_fmt,
       request=StreamRequest(frame_axis=2, reference_mic=0),
   )
   Y_fmt = out.estimate_tf
   if Y_fmt is None:
       raise ValueError("OnlineILRMA did not return TF estimates.")

   y_time = np.real(stft.istft(Y_fmt, f_axis=0, t_axis=2)).T

Swap in a different NMF updater while keeping the spatial update fixed:

   from oobss.separators.strategies import MultiplicativeNMFStrategy

   model = OnlineILRMA(
       n_mic=2,
       n_freq=X_fmt.shape[0],
       n_bases=8,
       nmf=MultiplicativeNMFStrategy(),
   )
References

[1] T. Nakashima and N. Ono, "Online independent low-rank matrix analysis as a lightweight and trainable model for real-time multichannel music source separation," in Proc. AAAI 2026 Workshop on Audio-Centric AI: Towards Real-World Multimodal Reasoning and Application Use Cases (Audio-AAAI), Jan. 2026.

Source code in src/oobss/separators/online_ilrma.py
class OnlineILRMA(BaseStreamingSeparator):
    """Online independent low-rank matrix analysis.

    Procedure
    ---------
    ```text

       input: frame sequence {x_{f,t}}_{f=1..F, t=1..T}
       initialize W_{f,0} = I_M, V_{k,f,0}
       initialize source-wise NMF parameters b_{k,f,\\ell}, c_{k,\\ell,t}
       for each frame t:
           repeat inner_iter times:
               y_{f,t} <- W_{f,t} x_{f,t}
               for each source k:
                   update c_{k,\\ell,t}, sufficient stats, and b_{k,f,\\ell}
                   r_{k,f,t} <- sum_\\ell b_{k,f,\\ell} c_{k,\\ell,t}
               V_{k,f,t} <- (1-alpha) * x_{f,t} x_{f,t}^H / r_{k,f,t}
                          + alpha * V_{k,f,t-1}
               update w_{k,f,t} by IP1
           projection-back by reference microphone
           emit separated frame
    ```

    Update Equations
    ----------------
    Indices are $k=1,\\dots,K$ (source), $m=1,\\dots,M$
    (channel), $f=1,\\dots,F$ (frequency), $t=1,\\dots,T$
    (time frame), and $\\ell=1,\\dots,L$ (NMF basis index). In the
    determined case, $K=M$.

    Demixing:

    $$
       \\bm{x}_{f,t} \\in \\mathbb{C}^{M}, \\quad
       \\bm{y}_{f,t} = W_{f,t}\\bm{x}_{f,t}, \\quad
       y_{k,f,t} = \\bm{w}_{k,f,t}^{\\mathsf{H}}\\bm{x}_{f,t}
    $$

    For each source $k$, define:

    $$
       v_{k,f,t} = |y_{k,f,t}|, \\quad
       m_{k,f,t} =
       \\sum_{\\ell=1}^{L}b_{k,f,\\ell}c_{k,\\ell,t} + \\varepsilon
    $$

    The online MU update for $C$ is:

    $$
       c_{k,\\ell,t}
       \\leftarrow
       c_{k,\\ell,t}
       \\frac{
       \\sum_{f=1}^{F}
       b_{k,f,\\ell}v_{k,f,t}m_{k,f,t}^{-2}
       }{
       \\max\\left(
       \\sum_{f=1}^{F}b_{k,f,\\ell}m_{k,f,t}^{-1},
       \\varepsilon
       \\right)
       }
    $$

    Sufficient statistics (stored as `A` and `B` in the implementation) and basis update:

    $$
       P_{k,f,\\ell,t}
       \\leftarrow
       P_{k,f,\\ell,t-1}
       +
       v_{k,f,t}m_{k,f,t}^{-2}c_{k,\\ell,t}b_{k,f,\\ell}^2,
       \\quad
       Q_{k,f,\\ell,t}
       \\leftarrow
       Q_{k,f,\\ell,t-1} + m_{k,f,t}^{-1}c_{k,\\ell,t}
    $$

    $$
       \\rho_{k,t} = \\alpha^{\\beta/t_k}, \\quad
       P_{k,f,\\ell,t} \\leftarrow \\rho_{k,t} P_{k,f,\\ell,t}, \\;
       Q_{k,f,\\ell,t} \\leftarrow \\rho_{k,t} Q_{k,f,\\ell,t}, \\quad
       b_{k,f,\\ell} \\leftarrow
       \\sqrt{P_{k,f,\\ell,t} \\oslash Q_{k,f,\\ell,t}}
    $$

    Source variance and covariance update:

    $$
       r_{k,f,t}
       = \\max\\left(
       \\sum_{\\ell=1}^{L}b_{k,f,\\ell}c_{k,\\ell,t},
       \\varepsilon
       \\right)
    $$

    $$
       V_{k,f,t}
       \\leftarrow
       (1-\\alpha)
       \\frac{
       \\bm{x}_{f,t}\\bm{x}_{f,t}^{\\mathsf{H}}
       }{r_{k,f,t}}
       + \\alpha V_{k,f,t-1}
    $$

    Demixing and common projection back:

    $$
       \\tilde{\\bm{w}}_{k,f,t}
       = \\left(W_{f,t}V_{k,f,t}\\right)^{-1}\\bm{e}_k, \\quad
       \\bm{w}_{k,f,t}
       = \\frac{\\tilde{\\bm{w}}_{k,f,t}}
       {\\sqrt{
       \\tilde{\\bm{w}}_{k,f,t}^{\\mathsf{H}}
       V_{k,f,t}
       \\tilde{\\bm{w}}_{k,f,t}
       }}
    $$

    $$
       A_{f,t} = W_{f,t}^{-1}, \\quad
       \\hat{y}_{k,f,t} = a_{k,f,t}[m_{\\mathrm{ref}}] y_{k,f,t}
    $$

    This is shared with batch AuxIVA/ILRMA via
    :func:`oobss.separators.utils.projection_back`.

    Examples
    --------
    Process a stream with online ILRMA:

    ```python

       import numpy as np
       from scipy.signal import ShortTimeFFT, get_window
       from oobss import OnlineILRMA, StreamRequest

       fs = 16000
       fft_size = 2048
       hop_size = 512
       win = get_window("hann", fft_size, fftbins=True)
       stft = ShortTimeFFT(win=win, hop=hop_size, fs=fs)

       mixture_time = np.random.randn(fs * 2, 2)  # (n_samples, n_mic)
       X_fmt = stft.stft(mixture_time.T).transpose(1, 0, 2)  # (F, M, T)

       model = OnlineILRMA(
           n_mic=2,
           n_freq=X_fmt.shape[0],
           n_bases=8,
           ref_mic=0,
           beta=1,
           forget=0.99,
           inner_iter=5,
           random_state=0,
       )
       out = model.process_stream_tf(
           X_fmt,
           request=StreamRequest(frame_axis=2, reference_mic=0),
       )
       Y_fmt = out.estimate_tf
       if Y_fmt is None:
           raise ValueError("OnlineILRMA did not return TF estimates.")

       y_time = np.real(stft.istft(Y_fmt, f_axis=0, t_axis=2)).T
    ```

    Swap in a custom NMF updater while keeping the spatial update fixed:

    ```python

       from oobss.separators.strategies import MultiplicativeNMFStrategy

       model = OnlineILRMA(
           n_mic=2,
           n_freq=X_fmt.shape[0],
           n_bases=8,
           nmf=MultiplicativeNMFStrategy(),
       )
    ```

    References
    ----------
    [1] T. Nakashima and N. Ono, "Online independent low-rank matrix analysis
    as a lightweight and trainable model for real-time multichannel music
    source separation," in *Proc. AAAI 2026 Workshop on Audio-Centric AI:
    Towards Real-World Multimodal Reasoning and Application Use Cases
    (Audio-AAAI)*, Jan. 2026.
    """

    def __init__(
        self,
        n_mic: int,
        n_freq: int,
        n_bases: int,
        *,
        ref_mic: int = 0,
        beta: int = 1,
        forget: float = 0.9,
        inner_iter: int = 30,
        keep_h: bool | str = False,
        eps: float = 1.0e-12,
        cov_scale: float = 1.0e-6,
        random_state: Optional[int] = None,
        spatial: SpatialUpdateStrategy | None = None,
        covariance: CovarianceUpdateStrategy | None = None,
        nmf: OnlineNMFUpdateStrategy | None = None,
        reconstruction_strategy: ReconstructionStrategy | None = None,
    ) -> None:
        self.n_mic = int(n_mic)
        self.n_freq = int(n_freq)
        self.n_bases = int(n_bases)
        self._random_state = random_state

        self.ref_mic = int(ref_mic)
        self.beta = int(beta)
        self.alpha = float(forget)
        self.inner_iter = int(inner_iter)
        self.eps = float(eps)
        self._cov_scale = float(cov_scale)

        self.spatial_strategy = spatial if spatial is not None else IP1SpatialStrategy()
        self.covariance_strategy = (
            covariance
            if covariance is not None
            else EMACovarianceStrategy(alpha=self.alpha)
        )
        self.nmf_strategy = nmf if nmf is not None else MultiplicativeNMFStrategy()
        self.reconstruction_strategy = (
            reconstruction_strategy
            if reconstruction_strategy is not None
            else ProjectionBackDemixReconstructionStrategy(ref_mic=self.ref_mic)
        )

        if keep_h == "auto":
            self.keep_h = self.beta < 1000
        else:
            self.keep_h = bool(keep_h)
        self.reset()

    def reset(self) -> None:
        """Reset online state and NMF statistics to initial values."""
        self.rng = np.random.default_rng(self._random_state)
        self.demix = np.tile(np.eye(self.n_mic, dtype=complex), (self.n_freq, 1, 1))
        self.cov = (
            np.tile(np.eye(self.n_mic, dtype=complex), (self.n_mic, self.n_freq, 1, 1))
            * self._cov_scale
        )

        self.basis = self.rng.random((self.n_mic, self.n_freq, self.n_bases)) + self.eps
        self.A = np.zeros_like(self.basis)
        self.B = np.zeros_like(self.basis)
        self._l1_normalise_W()

        self._t = np.zeros((self.n_mic,), dtype=np.int64)
        self._batch_counter = np.zeros((self.n_mic,), dtype=np.int64)
        self._H_store: list[list[np.ndarray]] = [[] for _ in range(self.n_mic)]

    def partial_fit(
        self,
        x: np.ndarray,
        *,
        reference_mic: int | None = None,
    ) -> np.ndarray:
        """Update model with one frame and return separated spectra."""
        if x.shape != (self.n_freq, self.n_mic):
            raise ValueError(
                "x must have shape (n_freq, n_mic) "
                f"= ({self.n_freq}, {self.n_mic}), got {x.shape}"
            )

        source_model = np.ones((self.n_freq, self.n_mic), dtype=np.float64)
        for _ in range(self.inner_iter):
            demixed = (self.demix @ x[:, :, None])[:, :, 0]
            for k in range(self.n_mic):
                prev_h = (
                    self._H_store[k][-1].copy()
                    if self.keep_h and self._H_store[k]
                    else None
                )
                result = self.nmf_strategy.update(
                    NMFUpdateRequest(
                        v=np.abs(demixed[:, k]),
                        basis=self.basis[k],
                        stat_a=self.A[k],
                        stat_b=self.B[k],
                        inner_iter=1,
                        beta=self.beta,
                        alpha=self.alpha,
                        batch_counter=int(self._batch_counter[k]),
                        t=int(self._t[k]),
                        eps=self.eps,
                        h_prev=prev_h,
                    )
                )
                self.basis[k] = result.basis
                self.A[k] = result.stat_a
                self.B[k] = result.stat_b
                self._batch_counter[k] = result.batch_counter
                self._t[k] = result.t
                if self.keep_h:
                    self._H_store[k].append(result.h)
                source_model[:, k] = np.maximum(result.basis @ result.h, self.eps)

            self.cov = self.covariance_strategy.update(
                CovarianceRequest(
                    observed=x[None, :, :],
                    source_model=source_model[None, :, :],
                    prev_cov=self.cov,
                    alpha=self.alpha,
                )
            )
            for row_idx in self.spatial_strategy.row_groups(self.n_mic):
                self.demix = self.spatial_strategy.update(
                    self.cov,
                    self.demix,
                    row_idx=row_idx,
                )

        ref = self.ref_mic if reference_mic is None else int(reference_mic)
        recon_strategy = self.reconstruction_strategy
        if isinstance(recon_strategy, ProjectionBackDemixReconstructionStrategy):
            recon_strategy = ProjectionBackDemixReconstructionStrategy(ref_mic=ref)

        output = recon_strategy.reconstruct(
            ReconstructionRequest(
                mixture=x,
                demix_filter=self.demix,
            )
        )
        return output.estimate

    def fit(self, spectrogram: np.ndarray) -> None:
        """Process a full spectrogram sequentially."""
        for n in range(spectrogram.shape[1]):
            self.partial_fit(spectrogram[:, n])

    def _l1_normalise_W(self) -> None:
        colsum = np.maximum(self.basis.sum(axis=0, keepdims=True), self.eps)
        self.basis /= colsum
        self.A /= colsum
        self.B *= colsum

    @property
    def n_sources(self) -> int:
        """Return number of separated sources."""
        return int(self.n_mic)

    def process_frame(
        self,
        frame: np.ndarray,
        request: OnlineFrameRequest | None = None,
    ) -> np.ndarray:
        """Process one TF frame and return separated frame."""
        ref_mic = None if request is None else request.reference_mic
        return self.partial_fit(frame, reference_mic=ref_mic)

n_sources property

n_sources

Return number of separated sources.

__call__

__call__(*args, **kwargs)

Alias for forward() to provide a torch-like call style.

Source code in src/oobss/separators/core/base.py
def __call__(self, *args: Any, **kwargs: Any) -> SeparationOutput:
    """Alias for :meth:`forward` to provide a torch-like call style."""
    return self.forward(*args, **kwargs)

fit

fit(spectrogram)

Process a full spectrogram sequentially.

Source code in src/oobss/separators/online_ilrma.py
def fit(self, spectrogram: np.ndarray) -> None:
    """Process a full spectrogram sequentially."""
    for n in range(spectrogram.shape[1]):
        self.partial_fit(spectrogram[:, n])
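
A minimal sketch of the array layout `fit` appears to expect, inferred from the indexing above rather than from a stated contract: `partial_fit` requires a `(n_freq, n_mic)` frame and `fit` passes `spectrogram[:, n]`, so the frame axis has to be axis 1. The sizes and the name `X_fmt` are illustrative assumptions.

```python
import numpy as np

n_freq, n_mic, n_frames = 5, 2, 7  # illustrative sizes, not library defaults
rng = np.random.default_rng(0)

# STFT laid out as (n_freq, n_mic, n_frames), as in the class-level example.
X_fmt = rng.standard_normal((n_freq, n_mic, n_frames)) + 0j

# Move the frame axis to position 1 so that spectrogram[:, n] is one frame.
spectrogram = X_fmt.transpose(0, 2, 1)  # (n_freq, n_frames, n_mic)

# This is the slice that fit would hand to partial_fit for frame n = 3.
frame = spectrogram[:, 3]
print(frame.shape)  # (5, 2)
```

With the `(n_freq, n_mic, n_frames)` layout used elsewhere on this page, `process_stream_tf` with `frame_axis=2` avoids the transpose.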

forward

forward(stream_tf, *, request=None)

Torch-like forward alias for full streaming input.

Source code in src/oobss/separators/core/base.py
def forward(
    self,
    stream_tf: np.ndarray,
    *,
    request: StreamRequest | None = None,
) -> SeparationOutput:
    """Torch-like forward alias for full streaming input."""
    return self.process_stream_tf(stream_tf, request=request)

forward_streaming

forward_streaming(frame, *, state=None, request=None)

Process one frame and return (separated_frame, updated_state).

Source code in src/oobss/separators/core/base.py
def forward_streaming(
    self,
    frame: np.ndarray,
    *,
    state: SeparatorState | StreamingSeparatorState | None = None,
    request: OnlineFrameRequest | None = None,
) -> tuple[np.ndarray, SeparatorState | StreamingSeparatorState]:
    """Process one frame and return ``(separated_frame, updated_state)``."""
    if state is not None:
        self.set_state(state)
    output = self.process_frame(frame, request=request)
    return output, self.get_state()

get_state

get_state()

Return current separator state snapshot.

Source code in src/oobss/separators/core/base.py
def get_state(self) -> SeparatorState | StreamingSeparatorState:
    """Return current separator state snapshot."""
    return SeparatorState()

partial_fit

partial_fit(x, *, reference_mic=None)

Update model with one frame and return separated spectra.

Source code in src/oobss/separators/online_ilrma.py
def partial_fit(
    self,
    x: np.ndarray,
    *,
    reference_mic: int | None = None,
) -> np.ndarray:
    """Update model with one frame and return separated spectra."""
    if x.shape != (self.n_freq, self.n_mic):
        raise ValueError(
            "x must have shape (n_freq, n_mic) "
            f"= ({self.n_freq}, {self.n_mic}), got {x.shape}"
        )

    source_model = np.ones((self.n_freq, self.n_mic), dtype=np.float64)
    for _ in range(self.inner_iter):
        demixed = (self.demix @ x[:, :, None])[:, :, 0]
        for k in range(self.n_mic):
            prev_h = (
                self._H_store[k][-1].copy()
                if self.keep_h and self._H_store[k]
                else None
            )
            result = self.nmf_strategy.update(
                NMFUpdateRequest(
                    v=np.abs(demixed[:, k]),
                    basis=self.basis[k],
                    stat_a=self.A[k],
                    stat_b=self.B[k],
                    inner_iter=1,
                    beta=self.beta,
                    alpha=self.alpha,
                    batch_counter=int(self._batch_counter[k]),
                    t=int(self._t[k]),
                    eps=self.eps,
                    h_prev=prev_h,
                )
            )
            self.basis[k] = result.basis
            self.A[k] = result.stat_a
            self.B[k] = result.stat_b
            self._batch_counter[k] = result.batch_counter
            self._t[k] = result.t
            if self.keep_h:
                self._H_store[k].append(result.h)
            source_model[:, k] = np.maximum(result.basis @ result.h, self.eps)

        self.cov = self.covariance_strategy.update(
            CovarianceRequest(
                observed=x[None, :, :],
                source_model=source_model[None, :, :],
                prev_cov=self.cov,
                alpha=self.alpha,
            )
        )
        for row_idx in self.spatial_strategy.row_groups(self.n_mic):
            self.demix = self.spatial_strategy.update(
                self.cov,
                self.demix,
                row_idx=row_idx,
            )

    ref = self.ref_mic if reference_mic is None else int(reference_mic)
    recon_strategy = self.reconstruction_strategy
    if isinstance(recon_strategy, ProjectionBackDemixReconstructionStrategy):
        recon_strategy = ProjectionBackDemixReconstructionStrategy(ref_mic=ref)

    output = recon_strategy.reconstruct(
        ReconstructionRequest(
            mixture=x,
            demix_filter=self.demix,
        )
    )
    return output.estimate
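
The covariance and demixing steps delegated to the strategy objects above can be sketched for a single frequency bin directly from the documented update equations. This is a sketch, not the code of `EMACovarianceStrategy` or `IP1SpatialStrategy`; the helper names `ema_covariance` and `ip1_row_update` are hypothetical, and the convention that row `k` of `W` stores `w_k^H` is an assumption taken from `y = W x`.

```python
import numpy as np


def ema_covariance(prev_cov, x, r, alpha):
    """V <- (1 - alpha) * x x^H / r + alpha * V_prev for one frequency bin."""
    return (1.0 - alpha) * np.outer(x, x.conj()) / r + alpha * prev_cov


def ip1_row_update(W, V_k, k):
    """Return a copy of W whose k-th row is replaced by the IP1 update."""
    e_k = np.zeros(W.shape[0], dtype=complex)
    e_k[k] = 1.0
    w_tilde = np.linalg.solve(W @ V_k, e_k)  # (W V_k)^{-1} e_k
    scale = np.sqrt(np.real(w_tilde.conj() @ V_k @ w_tilde))
    W_new = W.copy()
    W_new[k, :] = (w_tilde / scale).conj()  # rows of W hold w_k^H (assumed)
    return W_new


rng = np.random.default_rng(0)
M = 2
x = rng.standard_normal(M) + 1j * rng.standard_normal(M)

W = np.eye(M, dtype=complex)
V = np.eye(M, dtype=complex) * 1.0e-6  # same scale as the default cov_scale
V = ema_covariance(V, x, r=0.5, alpha=0.9)
W = ip1_row_update(W, V, k=0)

# After the update the new row satisfies w^H V w = 1 up to rounding.
w0 = W[0].conj()
print(np.real(w0.conj() @ V @ w0))
```

Only row 0 changes here; the method above repeats the update for every row group returned by `row_groups`, once per inner iteration.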

process_frame

process_frame(frame, request=None)

Process one TF frame and return separated frame.

Source code in src/oobss/separators/online_ilrma.py
def process_frame(
    self,
    frame: np.ndarray,
    request: OnlineFrameRequest | None = None,
) -> np.ndarray:
    """Process one TF frame and return separated frame."""
    ref_mic = None if request is None else request.reference_mic
    return self.partial_fit(frame, reference_mic=ref_mic)
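
The frame returned here has passed through the reconstruction strategy, which by default applies projection back to the reference microphone. A minimal NumPy sketch of that rule, written from the documented equation and not copied from `oobss.separators.utils.projection_back`, might look as follows; `projection_back_sketch` is a hypothetical name.

```python
import numpy as np


def projection_back_sketch(demix, x, ref_mic=0):
    """Scale each demixed source by the mixing coefficient a_k[ref_mic].

    demix: (n_freq, n_src, n_mic) demixing matrices W_f
    x:     (n_freq, n_mic) mixture frame
    """
    y = np.einsum("fkm,fm->fk", demix, x)  # y_{k,f} from W_f x_f
    mixing = np.linalg.inv(demix)          # A_f = W_f^{-1}
    return mixing[:, ref_mic, :] * y       # (n_freq, n_src)


rng = np.random.default_rng(0)
F, M = 4, 2
noise = rng.standard_normal((F, M, M)) + 1j * rng.standard_normal((F, M, M))
demix = np.eye(M) + 0.1 * noise            # well-conditioned test matrices
x = rng.standard_normal((F, M)) + 1j * rng.standard_normal((F, M))

y_hat = projection_back_sketch(demix, x, ref_mic=0)

# The scaled sources add up to the reference-microphone observation.
print(np.allclose(y_hat.sum(axis=1), x[:, 0]))
```

The check at the end holds because summing `a_k[m_ref] * y_k` over sources is row `m_ref` of `A_f W_f x_f = x_f`.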

process_stream

process_stream(stream, *, frame_axis=-1, request=None)

Process all frames in stream and stack outputs on the last axis.

Source code in src/oobss/separators/core/base.py
def process_stream(
    self,
    stream: np.ndarray,
    *,
    frame_axis: int = -1,
    request: OnlineFrameRequest | None = None,
) -> np.ndarray:
    """Process all frames in ``stream`` and stack outputs on the last axis."""
    frames = np.moveaxis(stream, frame_axis, 0)
    if frames.shape[0] == 0:
        raise ValueError("stream must contain at least one frame")

    outputs = [self.process_frame(frames[0], request=request)]
    for idx in range(1, frames.shape[0]):
        outputs.append(self.process_frame(frames[idx], request=request))
    return np.stack(outputs, axis=-1)
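
The axis handling above can be reproduced in isolation. The sketch below uses an identity function in place of `process_frame` (an assumption for illustration only) to show that the frame axis always ends up last in the output, wherever it was in the input.

```python
import numpy as np


def stack_frames(stream, frame_axis, process_frame):
    """Iterate over frame_axis and stack per-frame outputs on the last axis."""
    frames = np.moveaxis(stream, frame_axis, 0)
    outputs = [process_frame(frame) for frame in frames]
    return np.stack(outputs, axis=-1)


# Frames-first input: (n_frames, n_freq, n_mic).
stream = np.zeros((7, 5, 2))
out = stack_frames(stream, frame_axis=0, process_frame=lambda frame: frame)
print(out.shape)  # (5, 2, 7)
```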

process_stream_tf

process_stream_tf(stream_tf, *, request=None)

Process all frames in stream_tf using a typed stream request.

Source code in src/oobss/separators/core/base.py
def process_stream_tf(
    self,
    stream_tf: np.ndarray,
    *,
    request: StreamRequest | None = None,
) -> SeparationOutput:
    """Process all frames in ``stream_tf`` using a typed stream request."""
    request_obj = StreamRequest() if request is None else request
    frame_request = OnlineFrameRequest(
        n_sources=request_obj.n_sources,
        component_to_source=request_obj.component_to_source,
        return_mask=request_obj.return_mask,
        reference_mic=request_obj.reference_mic,
        metadata=dict(request_obj.metadata),
    )
    output = self.process_stream(
        stream_tf,
        frame_axis=int(request_obj.frame_axis),
        request=frame_request,
    )
    return SeparationOutput(estimate_tf=output, state=self.get_state())

reset

reset()

Reset online state and NMF statistics to initial values.

Source code in src/oobss/separators/online_ilrma.py
def reset(self) -> None:
    """Reset online state and NMF statistics to initial values."""
    self.rng = np.random.default_rng(self._random_state)
    self.demix = np.tile(np.eye(self.n_mic, dtype=complex), (self.n_freq, 1, 1))
    self.cov = (
        np.tile(np.eye(self.n_mic, dtype=complex), (self.n_mic, self.n_freq, 1, 1))
        * self._cov_scale
    )

    self.basis = self.rng.random((self.n_mic, self.n_freq, self.n_bases)) + self.eps
    self.A = np.zeros_like(self.basis)
    self.B = np.zeros_like(self.basis)
    self._l1_normalise_W()

    self._t = np.zeros((self.n_mic,), dtype=np.int64)
    self._batch_counter = np.zeros((self.n_mic,), dtype=np.int64)
    self._H_store: list[list[np.ndarray]] = [[] for _ in range(self.n_mic)]
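
For orientation, the two `np.tile` calls above produce one identity demixing matrix per frequency bin and one scaled identity covariance per source and frequency bin. The sizes below are illustrative; `cov_scale` uses the constructor default.

```python
import numpy as np

n_mic, n_freq, cov_scale = 2, 5, 1.0e-6

# One identity demixing matrix per frequency bin.
demix = np.tile(np.eye(n_mic, dtype=complex), (n_freq, 1, 1))

# One scaled identity covariance per (source, frequency) pair.
cov = np.tile(np.eye(n_mic, dtype=complex), (n_mic, n_freq, 1, 1)) * cov_scale

print(demix.shape)  # (5, 2, 2)
print(cov.shape)    # (2, 5, 2, 2)
```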

set_state

set_state(state)

Restore separator state from a snapshot.

Source code in src/oobss/separators/core/base.py
def set_state(self, state: SeparatorState | StreamingSeparatorState) -> None:
    """Restore separator state from a snapshot."""
    del state

OnlineISNMF

Bases: BaseStreamingSeparator

Online Itakura-Saito NMF with source-wise ratio-mask reconstruction.

Procedure
   input: STFT frame sequence {x_t}_{t=1..T}, x_t in C^F
   initialize NMF basis W, statistics A/B
   for each frame x_t:
       v_t <- |x_t|^2
       update activation h_t and basis statistics by online MU
       if update timing reached: refresh W from A/B with forgetting
       compute source power p_{s,f} by component-to-source assignment
       compute ratio mask m_{s,f} = p_{s,f} / sum_s p_{s,f}
       emit separated sources y_{s,f} = m_{s,f} x_f
Update Equations

With \(m = Wh + \varepsilon\), the online NMF updates are:

\[ h \leftarrow h \odot \frac{W^{\top}(v \oslash m^2)} {\max\left(W^{\top}(1 \oslash m), \varepsilon\right)}, \quad A \leftarrow A + \left((v \oslash m^2)h^{\top}\right) \odot W^2, \quad B \leftarrow B + (1 \oslash m)h^{\top} \]

Basis refresh (every \(\beta\) frames):

\[ \rho = \alpha^{\beta / t}, \quad A \leftarrow \rho A, \; B \leftarrow \rho B, \quad W \leftarrow \sqrt{A \oslash B} \]
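
The activation, statistics, and basis-refresh equations above can be written out in NumPy as a sketch. It is not the code of `MultiplicativeNMFStrategy`: the function names are hypothetical, and accumulating `A` and `B` once after the activation iterations is an assumption, since the equations do not fix that ordering.

```python
import numpy as np


def isnmf_frame_update(v, W, A, B, h, eps=1e-12, n_iter=10):
    """Run MU steps on h for one frame, then accumulate statistics A and B."""
    for _ in range(n_iter):
        m = W @ h + eps
        num = W.T @ (v / m**2)
        den = np.maximum(W.T @ (1.0 / m), eps)
        h = h * num / den
    m = W @ h + eps
    A = A + np.outer(v / m**2, h) * W**2
    B = B + np.outer(1.0 / m, h)
    return h, A, B


def refresh_basis(A, B, alpha, beta, t, eps=1e-12):
    """Apply the forgetting factor rho = alpha**(beta / t) and rebuild W."""
    rho = alpha ** (beta / t)
    A, B = rho * A, rho * B
    W = np.sqrt(A / np.maximum(B, eps))
    return W, A, B


rng = np.random.default_rng(0)
F, K = 6, 3
W = rng.random((F, K)) + 1e-3
v = rng.random(F) + 1e-3  # one power-spectrum frame
h = np.ones(K)
A = np.zeros((F, K))
B = np.zeros((F, K))

h, A, B = isnmf_frame_update(v, W, A, B, h)
W_new, A, B = refresh_basis(A, B, alpha=0.99, beta=1, t=1)
print(h.shape, W_new.shape)
```

All quantities stay nonnegative because every step multiplies or adds nonnegative terms, which is the usual reason for choosing multiplicative updates.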

Source-wise power aggregation and Wiener-style ratio mask:

\[ p_{s,f} = \sum_{k:\pi(k)=s} W_{f,k} h_k, \quad m_{s,f} = \frac{p_{s,f}}{\sum_{s'} p_{s',f} + \varepsilon}, \quad \hat{y}_{s,f} = m_{s,f} x_f \]
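
A minimal sketch of the aggregation and masking step, assuming the component-to-source map is `k % n_sources`. That rule is a guess at what the default `ModuloAssignmentStrategy` does, and the function name `ratio_mask_separate` is hypothetical.

```python
import numpy as np


def ratio_mask_separate(x, W, h, n_sources, eps=1e-12):
    """Group component powers by source and apply a ratio mask to frame x."""
    assignment = np.arange(W.shape[1]) % n_sources  # assumed modulo rule
    component_power = W * h[None, :]                # (F, K)
    source_power = np.zeros((n_sources, W.shape[0]))
    np.add.at(source_power, assignment, component_power.T)
    mask = source_power / (source_power.sum(axis=0, keepdims=True) + eps)
    return mask * x[None, :], mask


rng = np.random.default_rng(0)
F, K = 6, 4
W = rng.random((F, K)) + 1e-3
h = rng.random(K) + 1e-3
x = rng.standard_normal(F) + 1j * rng.standard_normal(F)

estimate, mask = ratio_mask_separate(x, W, h, n_sources=2)

# Masks sum to one per frequency bin, so the estimates add back up to x.
print(np.allclose(mask.sum(axis=0), 1.0), np.allclose(estimate.sum(axis=0), x))
```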

Examples:

Separate a one-channel STFT stream into two sources:

   import numpy as np
   from scipy.signal import ShortTimeFFT, get_window
   from oobss import OnlineISNMF, StreamRequest

   fs = 16000
   fft_size = 1024
   hop_size = 256
   win = get_window("hann", fft_size, fftbins=True)
   stft = ShortTimeFFT(win=win, hop=hop_size, fs=fs)

   mixture_time = np.random.randn(fs * 2, 1)  # mono mixture
   X_ft = stft.stft(mixture_time[:, 0])       # (n_freq, n_frame)

   model = OnlineISNMF(
       n_components=16,
       n_features=X_ft.shape[0],
       n_sources=2,
       beta=2,
       forget=0.99,
       inner_iter=10,
       random_state=0,
   )
   out = model.process_stream_tf(
       X_ft,
       request=StreamRequest(frame_axis=1, n_sources=2),
   )
   Y_sft = out.estimate_tf  # (n_sources, n_freq, n_frame)
   if Y_sft is None:
       raise ValueError("OnlineISNMF did not return TF estimates.")

Request masks instead of separated spectra:

   out = model.process_stream_tf(
       X_ft,
       request=StreamRequest(frame_axis=1, n_sources=2, return_mask=True),
   )
   mask = out.estimate_tf  # (n_sources, n_freq, n_frame)
Source code in src/oobss/separators/online_isnmf.py
class OnlineISNMF(BaseStreamingSeparator):
    """Online Itakura-Saito NMF with source-wise ratio-mask reconstruction.

    Procedure
    ---------
    ```text

       input: STFT frame sequence {x_t}_{t=1..T}, x_t in C^F
       initialize NMF basis W, statistics A/B
       for each frame x_t:
           v_t <- |x_t|^2
           update activation h_t and basis statistics by online MU
           if update timing reached: refresh W from A/B with forgetting
           compute source power p_{s,f} by component-to-source assignment
           compute ratio mask m_{s,f} = p_{s,f} / sum_s p_{s,f}
           emit separated sources y_{s,f} = m_{s,f} x_f
    ```

    Update Equations
    ----------------
    With $m = Wh + \\varepsilon$, the online NMF updates are:

    $$
       h \\leftarrow h \\odot
       \\frac{W^{\\top}(v \\oslash m^2)}
       {\\max\\left(W^{\\top}(1 \\oslash m), \\varepsilon\\right)}, \\quad
       A \\leftarrow A + \\left((v \\oslash m^2)h^{\\top}\\right) \\odot W^2,
       \\quad
       B \\leftarrow B + (1 \\oslash m)h^{\\top}
    $$

    Basis refresh (every $\\beta$ frames):

    $$
       \\rho = \\alpha^{\\beta / t}, \\quad
       A \\leftarrow \\rho A, \\; B \\leftarrow \\rho B, \\quad
       W \\leftarrow \\sqrt{A \\oslash B}
    $$

    Source-wise power aggregation and Wiener-style ratio mask:

    $$
       p_{s,f} = \\sum_{k:\\pi(k)=s} W_{f,k} h_k, \\quad
       m_{s,f} = \\frac{p_{s,f}}{\\sum_{s'} p_{s',f} + \\varepsilon},
       \\quad
       \\hat{y}_{s,f} = m_{s,f} x_f
    $$

    Examples
    --------
    Separate a one-channel STFT stream into two sources:

    ```python

       import numpy as np
       from scipy.signal import ShortTimeFFT, get_window
       from oobss import OnlineISNMF, StreamRequest

       fs = 16000
       fft_size = 1024
       hop_size = 256
       win = get_window("hann", fft_size, fftbins=True)
       stft = ShortTimeFFT(win=win, hop=hop_size, fs=fs)

       mixture_time = np.random.randn(fs * 2, 1)  # mono mixture
       X_ft = stft.stft(mixture_time[:, 0])       # (n_freq, n_frame)

       model = OnlineISNMF(
           n_components=16,
           n_features=X_ft.shape[0],
           n_sources=2,
           beta=2,
           forget=0.99,
           inner_iter=10,
           random_state=0,
       )
       out = model.process_stream_tf(
           X_ft,
           request=StreamRequest(frame_axis=1, n_sources=2),
       )
       Y_sft = out.estimate_tf  # (n_sources, n_freq, n_frame)
       if Y_sft is None:
           raise ValueError("OnlineISNMF did not return TF estimates.")
    ```

    Request masks instead of separated spectra:

    ```python

       out = model.process_stream_tf(
           X_ft,
           request=StreamRequest(frame_axis=1, n_sources=2, return_mask=True),
       )
       mask = out.estimate_tf  # (n_sources, n_freq, n_frame)
    ```
    """

    def __init__(
        self,
        n_components: int,
        n_features: int,
        *,
        beta: int = 50,
        forget: float = 0.9,
        inner_iter: int = 30,
        keep_h: bool | str = "auto",
        eps: float = 1.0e-12,
        n_sources: Optional[int] = None,
        component_to_source: np.ndarray | None = None,
        random_state: Optional[int] = None,
        nmf: OnlineNMFUpdateStrategy | None = None,
        assignment: ComponentAssignmentStrategy | None = None,
        reconstruction_strategy: ReconstructionStrategy | None = None,
    ) -> None:
        self.F, self.K = int(n_features), int(n_components)
        self.beta = int(beta)
        self.r = float(forget)
        self.inner_iter = int(inner_iter)
        self.eps = float(eps)
        self._random_state = random_state

        self._default_n_sources = n_sources
        self._default_component_to_source = component_to_source

        self.reconstruction_strategy = (
            reconstruction_strategy
            if reconstruction_strategy is not None
            else RatioMaskReconstructionStrategy(eps=self.eps)
        )
        self.nmf_strategy = nmf if nmf is not None else MultiplicativeNMFStrategy()
        self.assignment_strategy = (
            assignment if assignment is not None else ModuloAssignmentStrategy()
        )

        if keep_h == "auto":
            self.keep_h = self.beta < 1000
        else:
            self.keep_h = bool(keep_h)
        self.reset()

    def reset(self) -> None:
        """Reset online NMF parameters and sufficient statistics."""
        rng = np.random.default_rng(self._random_state)
        self.W = rng.random((self.F, self.K)) + self.eps
        self.A = np.zeros_like(self.W)
        self.B = np.zeros_like(self.W)
        self._l1_normalise_W()
        self._t = 0
        self._batch_counter = 0
        self._H_store: list[np.ndarray] = []

    def partial_fit(self, v: np.ndarray) -> np.ndarray:
        """Update online NMF model from one power spectrum frame."""
        if v.ndim != 1 or v.shape[0] != self.F:
            raise ValueError("v must be 1-D array of length F")

        h_prev = self._H_store[-1].copy() if self.keep_h and self._H_store else None
        result = self.nmf_strategy.update(
            NMFUpdateRequest(
                v=v,
                basis=self.W,
                stat_a=self.A,
                stat_b=self.B,
                inner_iter=self.inner_iter,
                beta=self.beta,
                alpha=self.r,
                batch_counter=self._batch_counter,
                t=self._t,
                eps=self.eps,
                h_prev=h_prev,
            )
        )

        self.W = result.basis
        self.A = result.stat_a
        self.B = result.stat_b
        self._batch_counter = result.batch_counter
        self._t = result.t

        if self.keep_h:
            self._H_store.append(result.h)
        return result.h

    def _l1_normalise_W(self) -> None:
        colsum = np.maximum(self.W.sum(axis=0, keepdims=True), self.eps)
        self.W /= colsum
        self.A /= colsum
        self.B *= colsum

    def source_power(
        self,
        h: np.ndarray,
        *,
        n_sources: int,
        component_to_source: np.ndarray | None = None,
    ) -> np.ndarray:
        """Compute source-wise power model from component activation."""
        if h.ndim != 1 or h.shape[0] != self.K:
            raise ValueError("h must be a 1-D array of length K")

        assignment = self.assignment_strategy.resolve(
            ComponentAssignmentRequest(
                n_components=self.K,
                n_sources=n_sources,
                component_to_source=component_to_source,
            )
        )
        component_power = self.W * h[None, :]
        source_power = np.zeros((n_sources, self.F), dtype=component_power.dtype)
        np.add.at(source_power, assignment, component_power.T)
        return source_power

    def separate_frame(
        self,
        x: np.ndarray,
        *,
        n_sources: int,
        component_to_source: np.ndarray | None = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Separate one complex STFT frame via source-wise ratio masking."""
        if x.ndim != 1 or x.shape[0] != self.F:
            raise ValueError("x must be a 1-D complex STFT frame of length F")

        h = self.partial_fit(np.abs(x) ** 2)
        source_power = self.source_power(
            h,
            n_sources=n_sources,
            component_to_source=component_to_source,
        )
        recon = self.reconstruction_strategy.reconstruct(
            ReconstructionRequest(
                mixture=x,
                source_power=source_power,
            )
        )
        if recon.mask is None:
            raise ValueError("OnlineISNMF reconstruction strategy must return a mask")
        return recon.estimate, recon.mask

    @property
    def n_sources(self) -> int:
        """Return configured source count (fallback to 1 when unspecified)."""
        return int(self._default_n_sources) if self._default_n_sources else 1

    def process_frame(
        self,
        frame: np.ndarray,
        request: OnlineFrameRequest | None = None,
    ) -> np.ndarray:
        """Process one frame using default or explicitly provided settings."""
        request_obj = request if request is not None else OnlineFrameRequest()
        n_sources = (
            request_obj.n_sources
            if request_obj.n_sources is not None
            else self._default_n_sources
        )
        if n_sources is None:
            raise ValueError("n_sources must be provided for OnlineISNMF separation")

        component_to_source = (
            request_obj.component_to_source
            if request_obj.component_to_source is not None
            else self._default_component_to_source
        )
        separated, mask = self.separate_frame(
            frame,
            n_sources=int(n_sources),
            component_to_source=component_to_source,
        )
        return mask if request_obj.return_mask else separated

n_sources property

n_sources

Return configured source count (fallback to 1 when unspecified).

__call__

__call__(*args, **kwargs)

Alias for forward() to provide a torch-like call style.

Source code in src/oobss/separators/core/base.py
def __call__(self, *args: Any, **kwargs: Any) -> SeparationOutput:
    """Alias for :meth:`forward` to provide a torch-like call style."""
    return self.forward(*args, **kwargs)

forward

forward(stream_tf, *, request=None)

Torch-like forward alias for full streaming input.

Source code in src/oobss/separators/core/base.py
def forward(
    self,
    stream_tf: np.ndarray,
    *,
    request: StreamRequest | None = None,
) -> SeparationOutput:
    """Torch-like forward alias for full streaming input."""
    return self.process_stream_tf(stream_tf, request=request)

forward_streaming

forward_streaming(frame, *, state=None, request=None)

Process one frame and return (separated_frame, updated_state).

Source code in src/oobss/separators/core/base.py
def forward_streaming(
    self,
    frame: np.ndarray,
    *,
    state: SeparatorState | StreamingSeparatorState | None = None,
    request: OnlineFrameRequest | None = None,
) -> tuple[np.ndarray, SeparatorState | StreamingSeparatorState]:
    """Process one frame and return ``(separated_frame, updated_state)``."""
    if state is not None:
        self.set_state(state)
    output = self.process_frame(frame, request=request)
    return output, self.get_state()
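
The restore, process, snapshot order used by forward_streaming can be sketched with a toy separator. `ToyState` and `ToySeparator` below are hypothetical stand-ins and not part of oobss; only the call order mirrors the method above. The base-class get_state and set_state listed in this section return an empty SeparatorState and discard their argument, so a subclass presumably needs to override both before a round trip like this carries any information.

```python
from dataclasses import dataclass, field


@dataclass
class ToyState:
    """Hypothetical stand-in for SeparatorState: a bag of counters."""

    counters: dict[str, int] = field(default_factory=dict)


class ToySeparator:
    """Toy separator that doubles each sample and counts frames (not an oobss class)."""

    def __init__(self) -> None:
        self._frames_seen = 0

    def get_state(self) -> ToyState:
        return ToyState(counters={"frames_seen": self._frames_seen})

    def set_state(self, state: ToyState) -> None:
        self._frames_seen = state.counters.get("frames_seen", 0)

    def process_frame(self, frame: list[float]) -> list[float]:
        self._frames_seen += 1
        return [2.0 * value for value in frame]

    def forward_streaming(self, frame, *, state=None):
        # Same order as the method above: restore, process, snapshot.
        if state is not None:
            self.set_state(state)
        output = self.process_frame(frame)
        return output, self.get_state()


first = ToySeparator()
out1, state1 = first.forward_streaming([1.0, 2.0])

# A fresh instance resumes from the snapshot returned by the first call.
resumed = ToySeparator()
out2, state2 = resumed.forward_streaming([3.0], state=state1)
```

Passing the returned state back in on the next call keeps the caller in charge of where state lives, which is convenient when frames are handled by different workers or instances.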

get_state

get_state()

Return current separator state snapshot.

Source code in src/oobss/separators/core/base.py
def get_state(self) -> SeparatorState | StreamingSeparatorState:
    """Return current separator state snapshot."""
    return SeparatorState()

partial_fit

partial_fit(v)

Update online NMF model from one power spectrum frame.

Source code in src/oobss/separators/online_isnmf.py
def partial_fit(self, v: np.ndarray) -> np.ndarray:
    """Update online NMF model from one power spectrum frame."""
    if v.ndim != 1 or v.shape[0] != self.F:
        raise ValueError("v must be 1-D array of length F")

    h_prev = self._H_store[-1].copy() if self.keep_h and self._H_store else None
    result = self.nmf_strategy.update(
        NMFUpdateRequest(
            v=v,
            basis=self.W,
            stat_a=self.A,
            stat_b=self.B,
            inner_iter=self.inner_iter,
            beta=self.beta,
            alpha=self.r,
            batch_counter=self._batch_counter,
            t=self._t,
            eps=self.eps,
            h_prev=h_prev,
        )
    )

    self.W = result.basis
    self.A = result.stat_a
    self.B = result.stat_b
    self._batch_counter = result.batch_counter
    self._t = result.t

    if self.keep_h:
        self._H_store.append(result.h)
    return result.h

process_frame

process_frame(frame, request=None)

Process one frame using default or explicitly provided settings.

Source code in src/oobss/separators/online_isnmf.py
def process_frame(
    self,
    frame: np.ndarray,
    request: OnlineFrameRequest | None = None,
) -> np.ndarray:
    """Process one frame using default or explicitly provided settings."""
    request_obj = request if request is not None else OnlineFrameRequest()
    n_sources = (
        request_obj.n_sources
        if request_obj.n_sources is not None
        else self._default_n_sources
    )
    if n_sources is None:
        raise ValueError("n_sources must be provided for OnlineISNMF separation")

    component_to_source = (
        request_obj.component_to_source
        if request_obj.component_to_source is not None
        else self._default_component_to_source
    )
    separated, mask = self.separate_frame(
        frame,
        n_sources=int(n_sources),
        component_to_source=component_to_source,
    )
    return mask if request_obj.return_mask else separated

process_stream

process_stream(stream, *, frame_axis=-1, request=None)

Process all frames in stream and stack outputs on the last axis.

Source code in src/oobss/separators/core/base.py
def process_stream(
    self,
    stream: np.ndarray,
    *,
    frame_axis: int = -1,
    request: OnlineFrameRequest | None = None,
) -> np.ndarray:
    """Process all frames in ``stream`` and stack outputs on the last axis."""
    frames = np.moveaxis(stream, frame_axis, 0)
    if frames.shape[0] == 0:
        raise ValueError("stream must contain at least one frame")

    outputs = [self.process_frame(frames[0], request=request)]
    for idx in range(1, frames.shape[0]):
        outputs.append(self.process_frame(frames[idx], request=request))
    return np.stack(outputs, axis=-1)

process_stream_tf

process_stream_tf(stream_tf, *, request=None)

Process all frames in stream_tf using a typed stream request.

Source code in src/oobss/separators/core/base.py
def process_stream_tf(
    self,
    stream_tf: np.ndarray,
    *,
    request: StreamRequest | None = None,
) -> SeparationOutput:
    """Process all frames in ``stream_tf`` using a typed stream request."""
    request_obj = StreamRequest() if request is None else request
    frame_request = OnlineFrameRequest(
        n_sources=request_obj.n_sources,
        component_to_source=request_obj.component_to_source,
        return_mask=request_obj.return_mask,
        reference_mic=request_obj.reference_mic,
        metadata=dict(request_obj.metadata),
    )
    output = self.process_stream(
        stream_tf,
        frame_axis=int(request_obj.frame_axis),
        request=frame_request,
    )
    return SeparationOutput(estimate_tf=output, state=self.get_state())

reset

reset()

Reset online NMF parameters and sufficient statistics.

Source code in src/oobss/separators/online_isnmf.py
def reset(self) -> None:
    """Reset online NMF parameters and sufficient statistics."""
    rng = np.random.default_rng(self._random_state)
    self.W = rng.random((self.F, self.K)) + self.eps
    self.A = np.zeros_like(self.W)
    self.B = np.zeros_like(self.W)
    self._l1_normalise_W()
    self._t = 0
    self._batch_counter = 0
    self._H_store: list[np.ndarray] = []

separate_frame

separate_frame(x, *, n_sources, component_to_source=None)

Separate one complex STFT frame via source-wise ratio masking.

Source code in src/oobss/separators/online_isnmf.py
def separate_frame(
    self,
    x: np.ndarray,
    *,
    n_sources: int,
    component_to_source: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Separate one complex STFT frame via source-wise ratio masking."""
    if x.ndim != 1 or x.shape[0] != self.F:
        raise ValueError("x must be a 1-D complex STFT frame of length F")

    h = self.partial_fit(np.abs(x) ** 2)
    source_power = self.source_power(
        h,
        n_sources=n_sources,
        component_to_source=component_to_source,
    )
    recon = self.reconstruction_strategy.reconstruct(
        ReconstructionRequest(
            mixture=x,
            source_power=source_power,
        )
    )
    if recon.mask is None:
        raise ValueError("OnlineISNMF reconstruction strategy must return a mask")
    return recon.estimate, recon.mask
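
The reconstruction strategy itself is not listed here, so the following is only a sketch of what source-wise ratio masking usually means: each source's mask is its share of the total modelled power in a bin, and the estimate is that mask applied to the complex mixture. The function name `ratio_mask_frame`, the `eps` floor, and the sample values are assumptions made for illustration, with plain lists standing in for NumPy arrays.

```python
def ratio_mask_frame(
    mixture: list[complex],
    source_power: list[list[float]],
    eps: float = 1e-12,
) -> tuple[list[list[complex]], list[list[float]]]:
    """Return (estimate, mask), each indexed as [source][frequency]."""
    n_freq = len(mixture)
    # Total modelled power per frequency bin, summed over sources.
    total = [sum(row[f] for row in source_power) for f in range(n_freq)]
    # Each source's mask is its fraction of the total (floored to avoid 0/0).
    mask = [
        [row[f] / max(total[f], eps) for f in range(n_freq)]
        for row in source_power
    ]
    # Apply the real-valued mask to the complex mixture bin by bin.
    estimate = [
        [m * x for m, x in zip(mask_row, mixture)]
        for mask_row in mask
    ]
    return estimate, mask


mixture = [2 + 2j, 4 + 0j]               # one STFT frame with F = 2 bins
source_power = [[3.0, 1.0], [1.0, 1.0]]  # two sources
estimate, mask = ratio_mask_frame(mixture, source_power)
```

Because the masks sum to one in every bin with non-zero power, the per-source estimates add back up to the mixture.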

set_state

set_state(state)

Restore separator state from a snapshot.

Source code in src/oobss/separators/core/base.py
def set_state(self, state: SeparatorState | StreamingSeparatorState) -> None:
    """Restore separator state from a snapshot."""
    del state

source_power

source_power(h, *, n_sources, component_to_source=None)

Compute source-wise power model from component activation.

Source code in src/oobss/separators/online_isnmf.py
def source_power(
    self,
    h: np.ndarray,
    *,
    n_sources: int,
    component_to_source: np.ndarray | None = None,
) -> np.ndarray:
    """Compute source-wise power model from component activation."""
    if h.ndim != 1 or h.shape[0] != self.K:
        raise ValueError("h must be a 1-D array of length K")

    assignment = self.assignment_strategy.resolve(
        ComponentAssignmentRequest(
            n_components=self.K,
            n_sources=n_sources,
            component_to_source=component_to_source,
        )
    )
    component_power = self.W * h[None, :]
    source_power = np.zeros((n_sources, self.F), dtype=component_power.dtype)
    np.add.at(source_power, assignment, component_power.T)
    return source_power
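
The aggregation at the end of source_power can be illustrated without NumPy. In this sketch nested lists stand in for the arrays: `W` is F x K, `h` has length K, and `assignment[k]` is the source index of component k. The helper name `aggregate_source_power` and the sample values are made up for illustration; the loop is meant to compute the same sums as `np.add.at(source_power, assignment, component_power.T)`.

```python
def aggregate_source_power(
    W: list[list[float]],
    h: list[float],
    assignment: list[int],
    n_sources: int,
) -> list[list[float]]:
    """Sum per-component power W[f][k] * h[k] into its assigned source."""
    n_freq = len(W)
    source_power = [[0.0] * n_freq for _ in range(n_sources)]
    for k, src in enumerate(assignment):
        for f in range(n_freq):
            # Components that share a source index accumulate into one row.
            source_power[src][f] += W[f][k] * h[k]
    return source_power


# F = 2 frequency bins, K = 3 components, 2 sources.
W = [
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
]
h = [1.0, 0.5, 2.0]
assignment = [0, 1, 0]  # components 0 and 2 belong to source 0

power = aggregate_source_power(W, h, assignment, n_sources=2)
```

The accumulating form matters because several components normally map to the same source; a plain fancy-indexed assignment would keep only the last component written to each row.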

SeparationOutput dataclass

Unified separation result container.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| estimate_time | ndarray \| None | Optional separated signals in time domain (n_src, n_samples). | None |
| estimate_tf | ndarray \| None | Optional separated signals in TF domain. | None |
| mask | ndarray \| None | Optional source mask values. | None |
| demix_filter | ndarray \| None | Optional demixing filter/matrix. | None |
| permutation | ndarray \| None | Optional source permutation indices. | None |
| state | SeparatorState \| StreamingSeparatorState \| None | Optional internal runtime state snapshot. | None |
| metadata | dict[str, Any] | Free-form method metadata for logging/benchmarking. | dict() |

Source code in src/oobss/separators/core/io_models.py
@dataclass(slots=True)
class SeparationOutput:
    """Unified separation result container.

    Parameters
    ----------
    estimate_time:
        Optional separated signals in time domain ``(n_src, n_samples)``.
    estimate_tf:
        Optional separated signals in TF domain.
    mask:
        Optional source mask values.
    demix_filter:
        Optional demixing filter/matrix.
    permutation:
        Optional source permutation indices.
    state:
        Optional internal runtime state snapshot.
    metadata:
        Free-form method metadata for logging/benchmarking.
    """

    estimate_time: np.ndarray | None = None
    estimate_tf: np.ndarray | None = None
    mask: np.ndarray | None = None
    demix_filter: np.ndarray | None = None
    permutation: np.ndarray | None = None
    state: SeparatorState | StreamingSeparatorState | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        if (
            self.estimate_time is None
            and self.estimate_tf is None
            and self.mask is None
        ):
            raise ValueError(
                "SeparationOutput requires estimate_time, estimate_tf, or mask."
            )
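
The `__post_init__` check means a result must carry at least one of estimate_time, estimate_tf, or mask; the other fields alone are not enough. A minimal sketch of that rule, using a hypothetical `MiniOutput` stand-in with list-typed fields rather than the real class:

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class MiniOutput:
    """Hypothetical stand-in that copies only the validation rule."""

    estimate_time: Optional[list[float]] = None
    estimate_tf: Optional[list[complex]] = None
    mask: Optional[list[float]] = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        if (
            self.estimate_time is None
            and self.estimate_tf is None
            and self.mask is None
        ):
            raise ValueError(
                "MiniOutput requires estimate_time, estimate_tf, or mask."
            )


ok = MiniOutput(mask=[0.25, 0.75])  # a mask alone is enough

try:
    MiniOutput(metadata={"method": "demo"})  # metadata alone is rejected
    rejected = False
except ValueError:
    rejected = True
```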

SeparatorState dataclass

Mutable runtime state used across iterative or online updates.

Source code in src/oobss/separators/core/io_models.py
@dataclass(slots=True)
class SeparatorState:
    """Mutable runtime state used across iterative or online updates."""

    arrays: dict[str, np.ndarray] = field(default_factory=dict)
    counters: dict[str, int] = field(default_factory=dict)
    scalars: dict[str, float] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)

StreamRequest dataclass

Execution options for streaming separators.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| frame_axis | int | Axis in input tensor that corresponds to frame/time index. | -1 |
| reference_mic | int \| None | Optional reference microphone index. | None |
| n_sources | int \| None | Optional source count override for methods that require it. | None |
| component_to_source | ndarray \| None | Optional NMF component-to-source assignment. | None |
| return_mask | bool | If True, separator may return masks instead of separated spectra. | False |
| metadata | dict[str, Any] | Additional method-specific options. | dict() |

Source code in src/oobss/separators/core/io_models.py
@dataclass(slots=True)
class StreamRequest:
    """Execution options for streaming separators.

    Parameters
    ----------
    frame_axis:
        Axis in input tensor that corresponds to frame/time index.
    reference_mic:
        Optional reference microphone index.
    n_sources:
        Optional source count override for methods that require it.
    component_to_source:
        Optional NMF component-to-source assignment.
    return_mask:
        If ``True``, separator may return masks instead of separated spectra.
    metadata:
        Additional method-specific options.
    """

    frame_axis: int = -1
    reference_mic: int | None = None
    n_sources: int | None = None
    component_to_source: np.ndarray | None = None
    return_mask: bool = False
    metadata: dict[str, Any] = field(default_factory=dict)

StreamingSeparatorState dataclass

Typed runtime state for streaming BSS-style separators.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| source_model | ndarray \| None | Optional source model state snapshot at the current frame. | None |
| demix_filter | ndarray \| None | Optional demixing filter snapshot. | None |
| mix_filter | ndarray \| None | Optional mixing filter snapshot (typically inverse of demix_filter). | None |
| frame_index | int | Number of processed frames associated with this state. | 0 |
| metadata | dict[str, Any] | Additional algorithm-specific state values. | dict() |

Source code in src/oobss/separators/core/io_models.py
@dataclass(slots=True)
class StreamingSeparatorState:
    """Typed runtime state for streaming BSS-style separators.

    Parameters
    ----------
    source_model:
        Optional source model state snapshot at the current frame.
    demix_filter:
        Optional demixing filter snapshot.
    mix_filter:
        Optional mixing filter snapshot (typically inverse of ``demix_filter``).
    frame_index:
        Number of processed frames associated with this state.
    metadata:
        Additional algorithm-specific state values.
    """

    source_model: np.ndarray | None = None
    demix_filter: np.ndarray | None = None
    mix_filter: np.ndarray | None = None
    frame_index: int = 0
    metadata: dict[str, Any] = field(default_factory=dict)

load_yaml

load_yaml(path, *, overrides=None)

Load a YAML file into a dictionary, with optional dotlist overrides.

Source code in src/oobss/configs.py
def load_yaml(
    path: str | Path,
    *,
    overrides: Iterable[str] | None = None,
) -> dict[str, Any]:
    """Load a YAML file into a dictionary, with optional dotlist overrides."""
    override_list = [item for item in (overrides or []) if item]
    cfg = OmegaConf.load(Path(path))
    if override_list:
        cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(override_list))
    loaded = OmegaConf.to_container(cfg, resolve=True)
    return _as_str_key_dict(loaded, context=str(path))
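
A dotlist override is a string such as `section.key=value` that addresses a nested entry by its dotted path. The sketch below imitates the merge with plain dictionaries so the effect is easy to see; `merge_dotlist` and the config keys are hypothetical, and values are kept as strings here, whereas OmegaConf also interprets the value text (for example as numbers or booleans).

```python
import copy
from typing import Any


def merge_dotlist(base: dict[str, Any], overrides: list[str]) -> dict[str, Any]:
    """Return a copy of ``base`` with ``a.b=value`` style overrides applied."""
    merged = copy.deepcopy(base)
    for item in overrides:
        if not item:
            continue  # load_yaml likewise drops empty override strings
        dotted_key, _, value = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = merged
        for name in parents:
            # Walk down, creating intermediate mappings when missing.
            node = node.setdefault(name, {})
        node[leaf] = value
    return merged


config = {"stft": {"n_fft": 1024, "hop": 256}, "method": "online_ilrma"}
updated = merge_dotlist(config, ["stft.hop=128", "", "method=online_auxiva"])
```

Only the addressed leaves change; sibling keys such as `stft.n_fft` are left as loaded, and the input dictionary is not modified.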

log_steps_jsonl

log_steps_jsonl(path, steps)

Write many step dictionaries to JSONL.

Source code in src/oobss/logging_utils.py
def log_steps_jsonl(path: str | Path, steps: Iterable[Mapping[str, Any]]) -> None:
    """Write many step dictionaries to JSONL."""
    logger = JsonlLogger(path)
    for step in steps:
        logger.write(step)
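
JSONL stores one JSON object per line, which makes step logs easy to append to and to read back incrementally. JsonlLogger's source is not listed in this section, so the sketch below uses only the standard library; `write_steps_jsonl` and `read_jsonl` are hypothetical helpers, and the choice to overwrite the file is an assumption, since whether JsonlLogger appends or truncates is not shown here.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Iterable, Mapping


def write_steps_jsonl(path: Path, steps: Iterable[Mapping[str, Any]]) -> None:
    """Write each step as one JSON object per line (overwrites ``path``)."""
    with path.open("w", encoding="utf-8") as handle:
        for step in steps:
            handle.write(json.dumps(dict(step)) + "\n")


def read_jsonl(path: Path) -> list[dict[str, Any]]:
    """Parse every non-empty line of ``path`` as a JSON object."""
    with path.open("r", encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]


with tempfile.TemporaryDirectory() as tmp:
    log_path = Path(tmp) / "steps.jsonl"
    write_steps_jsonl(
        log_path,
        [{"step": 0, "loss": 1.5}, {"step": 1, "loss": 0.75}],
    )
    records = read_jsonl(log_path)
```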

save_yaml

save_yaml(path, data)

Write dictionary data to a YAML file.

Source code in src/oobss/configs.py
def save_yaml(path: str | Path, data: dict[str, Any]) -> None:
    """Write dictionary data to a YAML file."""
    out = Path(path)
    out.parent.mkdir(parents=True, exist_ok=True)
    with out.open("w", encoding="utf-8") as handle:
        yaml.safe_dump(data, handle, sort_keys=False)

Benchmark

oobss.benchmark

Benchmark orchestration APIs.

__all__ module-attribute

__all__ = ['ExperimentEngine', 'MethodRunnerRegistry', 'default_method_runner_registry', 'validate_builtin_method_params']

ExperimentEngine dataclass

Orchestrate planning and execution of benchmark experiments.

Source code in src/oobss/benchmark/engine.py
@dataclass
class ExperimentEngine:
    """Orchestrate planning and execution of benchmark experiments."""

    build_tasks_fn: BuildTasksFn = build_tasks_from_recipe
    run_task_fn: RunTaskFn = _run_task_with_registry
    write_outputs_fn: WriteOutputsFn | None = None

    def __post_init__(self) -> None:
        if self.write_outputs_fn is None:
            self.write_outputs_fn = _write_outputs_adapter

    def build_tasks(
        self,
        recipe: ExperimentRecipe,
        methods: Sequence[MethodConfig],
        *,
        method_grid: Mapping[str, Mapping[str, Sequence[object]]] | None = None,
    ) -> list[ExperimentTask]:
        """Build tasks from recipe and method definitions."""
        expanded_methods = expand_method_grids(methods, method_grid)
        return self.build_tasks_fn(recipe, methods=expanded_methods)

    def run(
        self,
        *,
        recipe: ExperimentRecipe,
        methods: Sequence[MethodConfig],
        output_root: Path,
        workers: int | None,
        overwrite: bool,
        save_framewise: bool,
        summary_precision: int,
        save_audio: bool,
        method_grid: Mapping[str, Mapping[str, Sequence[object]]] | None = None,
        runner_registry: MethodRunnerRegistry | None = None,
    ) -> EngineRunArtifacts:
        """Execute planned experiments and persist outputs."""
        tasks = self.build_tasks(recipe, methods, method_grid=method_grid)
        run_root = generate_run_directory(output_root, overwrite=overwrite)
        outputs = self._execute_tasks(
            tasks,
            workers=workers,
            runner_registry=runner_registry,
        )

        assert self.write_outputs_fn is not None  # assigned in __post_init__
        results_path = self.write_outputs_fn(
            outputs,
            run_root,
            save_framewise,
            summary_precision,
            save_audio,
        )
        method_count = len({task.method.id for task in tasks})
        return EngineRunArtifacts(
            run_root=run_root,
            results_path=results_path,
            task_count=len(tasks),
            method_count=method_count,
        )

    def _execute_tasks(
        self,
        tasks: Sequence[ExperimentTask],
        *,
        workers: int | None,
        runner_registry: MethodRunnerRegistry | None,
    ) -> list[ExperimentOutput]:
        worker_count = workers or (os.cpu_count() or 1)
        worker_count = max(1, int(worker_count))
        outputs: list[ExperimentOutput] = []

        if worker_count > 1 and runner_registry is not None:
            LOGGER.warning(
                "Custom runner registry was provided; falling back to single-process execution."
            )
            worker_count = 1

        if worker_count == 1:
            for idx, task in enumerate(tasks, start=1):
                outputs.append(self.run_task_fn(task, runner_registry))
                LOGGER.info("Completed %d/%d", idx, len(tasks))
            return outputs

        with ProcessPoolExecutor(max_workers=worker_count) as pool:
            futures = {
                pool.submit(self.run_task_fn, task, runner_registry): task
                for task in tasks
            }
            for idx, future in enumerate(as_completed(futures), start=1):
                outputs.append(future.result())
                LOGGER.info("Completed %d/%d", idx, len(tasks))
        return outputs
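
The worker-count rules in `_execute_tasks` are easy to misread, so here they are isolated in a small function. `resolve_worker_count` is a hypothetical helper written for this illustration; its body follows the three steps in the listing above.

```python
import os
from typing import Optional


def resolve_worker_count(workers: Optional[int], has_custom_registry: bool) -> int:
    """Mirror the worker-count decision made before tasks are executed."""
    # None and 0 are both falsy, so either falls back to the CPU count.
    count = workers or (os.cpu_count() or 1)
    # Negative values are clamped to a single worker.
    count = max(1, int(count))
    # A custom runner registry forces single-process execution.
    if count > 1 and has_custom_registry:
        count = 1
    return count


parallel = resolve_worker_count(4, has_custom_registry=False)
forced_serial = resolve_worker_count(4, has_custom_registry=True)
clamped = resolve_worker_count(-3, has_custom_registry=False)
auto = resolve_worker_count(None, has_custom_registry=False)
```

Two consequences follow: `workers=0` does not mean "no parallelism" but behaves like `None`, and passing any `runner_registry` to `run` makes the `workers` setting irrelevant.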

build_tasks

build_tasks(recipe, methods, *, method_grid=None)

Build tasks from recipe and method definitions.

Source code in src/oobss/benchmark/engine.py
def build_tasks(
    self,
    recipe: ExperimentRecipe,
    methods: Sequence[MethodConfig],
    *,
    method_grid: Mapping[str, Mapping[str, Sequence[object]]] | None = None,
) -> list[ExperimentTask]:
    """Build tasks from recipe and method definitions."""
    expanded_methods = expand_method_grids(methods, method_grid)
    return self.build_tasks_fn(recipe, methods=expanded_methods)

run

run(*, recipe, methods, output_root, workers, overwrite, save_framewise, summary_precision, save_audio, method_grid=None, runner_registry=None)

Execute planned experiments and persist outputs.

Source code in src/oobss/benchmark/engine.py
def run(
    self,
    *,
    recipe: ExperimentRecipe,
    methods: Sequence[MethodConfig],
    output_root: Path,
    workers: int | None,
    overwrite: bool,
    save_framewise: bool,
    summary_precision: int,
    save_audio: bool,
    method_grid: Mapping[str, Mapping[str, Sequence[object]]] | None = None,
    runner_registry: MethodRunnerRegistry | None = None,
) -> EngineRunArtifacts:
    """Execute planned experiments and persist outputs."""
    tasks = self.build_tasks(recipe, methods, method_grid=method_grid)
    run_root = generate_run_directory(output_root, overwrite=overwrite)
    outputs = self._execute_tasks(
        tasks,
        workers=workers,
        runner_registry=runner_registry,
    )

    assert self.write_outputs_fn is not None  # assigned in __post_init__
    results_path = self.write_outputs_fn(
        outputs,
        run_root,
        save_framewise,
        summary_precision,
        save_audio,
    )
    method_count = len({task.method.id for task in tasks})
    return EngineRunArtifacts(
        run_root=run_root,
        results_path=results_path,
        task_count=len(tasks),
        method_count=method_count,
    )

MethodRunnerRegistry dataclass

Registry that resolves method IDs to runner instances.

Source code in src/oobss/benchmark/methods.py
@dataclass
class MethodRunnerRegistry:
    """Registry that resolves method IDs to runner instances."""

    _runners: dict[str, MethodRunner] = field(default_factory=dict)

    def register(
        self,
        method_type: str,
        runner: MethodRunner | MethodRunnerFn,
        *,
        overwrite: bool = False,
    ) -> None:
        if not overwrite and method_type in self._runners:
            raise ValueError(f"Method runner already registered: {method_type}")
        if callable(runner) and not hasattr(runner, "run"):
            self._runners[method_type] = FunctionalMethodRunner(
                cast(MethodRunnerFn, runner)
            )
        else:
            self._runners[method_type] = cast(MethodRunner, runner)

    def resolve(self, method_type: str) -> MethodRunner:
        try:
            return self._runners[method_type]
        except KeyError as exc:
            available = ", ".join(sorted(self._runners))
            raise ValueError(
                f"Unknown method type: {method_type}. Available methods: {available}"
            ) from exc

    def available(self) -> list[str]:
        return sorted(self._runners)
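
The register method accepts either an object with a `run` method or a bare callable, which it wraps. The sketch below reproduces that behaviour with hypothetical stand-ins: `FunctionRunner` plays the role of FunctionalMethodRunner, whose source is not listed here, and the string-in, string-out runner signature is invented for the example.

```python
from typing import Callable


class FunctionRunner:
    """Hypothetical wrapper giving a plain function a ``run`` method."""

    def __init__(self, fn: Callable[[str], str]) -> None:
        self._fn = fn

    def run(self, task: str) -> str:
        return self._fn(task)


class MiniRegistry:
    """Stand-in registry following the register/resolve logic above."""

    def __init__(self) -> None:
        self._runners: dict[str, object] = {}

    def register(self, method_type: str, runner, *, overwrite: bool = False) -> None:
        if not overwrite and method_type in self._runners:
            raise ValueError(f"Method runner already registered: {method_type}")
        # Bare callables are wrapped; objects exposing ``run`` are stored as-is.
        if callable(runner) and not hasattr(runner, "run"):
            runner = FunctionRunner(runner)
        self._runners[method_type] = runner

    def resolve(self, method_type: str):
        try:
            return self._runners[method_type]
        except KeyError as exc:
            available = ", ".join(sorted(self._runners))
            raise ValueError(
                f"Unknown method type: {method_type}. Available methods: {available}"
            ) from exc


registry = MiniRegistry()
registry.register("echo", lambda task: f"ran {task}")
result = registry.resolve("echo").run("task-1")

try:
    registry.register("echo", lambda task: task)  # same key, no overwrite
    duplicate_rejected = False
except ValueError:
    duplicate_rejected = True
```

Re-registering an existing key requires `overwrite=True`, which guards against a plugin silently replacing a built-in runner.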

default_method_runner_registry

default_method_runner_registry()

Create the built-in method runner registry.

Source code in src/oobss/benchmark/methods.py
def default_method_runner_registry() -> MethodRunnerRegistry:
    """Create the built-in method runner registry."""
    registry = MethodRunnerRegistry()
    registry.register("batch_auxiva", BatchAuxIVARunner())
    registry.register("batch_ilrma", BatchILRMARunner())
    registry.register("blockbatch_auxiva", BlockBatchAuxIVARunner())
    registry.register("blockbatch_ilrma", BlockBatchILRMARunner())
    registry.register("online_auxiva", OnlineAuxIVARunner())
    registry.register("online_ilrma", OnlineILRMARunner())
    registry.register("online_isnmf", OnlineISNMFRunner())
    return registry

validate_builtin_method_params

validate_builtin_method_params(method_type, params)

Validate params for built-in method types.

Unknown method types are ignored so external plugin types remain supported.

Source code in src/oobss/benchmark/methods.py
def validate_builtin_method_params(
    method_type: str, params: Mapping[str, object]
) -> None:
    """Validate params for built-in method types.

    Unknown method types are ignored so external plugin types remain supported.
    """
    schema = _METHOD_PARAM_SCHEMAS.get(method_type)
    if schema is None:
        return
    _decode_params(params, schema)

Dataloaders

oobss.dataloaders

Dataset adapter and loader APIs.

AdapterFactory module-attribute

AdapterFactory = Callable[[dict[str, Any]], BaseDatasetAdapter]

DatasetLoader module-attribute

DatasetLoader = BaseDatasetAdapter

__all__ module-attribute

__all__ = ['AdapterFactory', 'BaseDatasetAdapter', 'DatasetLoader', 'TorchrirDynamicDatasetAdapter', 'TrackAudio', 'TrackHandle', 'build_torchrir_dynamic_adapter', 'create_loader', 'loader_registry', 'TorchrirDynamicDataset', 'build_torchrir_dynamic_dataloader']

BaseDatasetAdapter

Bases: ABC

Abstract dataset adapter used by experiment pipelines.

Source code in src/oobss/dataloaders/base.py
class BaseDatasetAdapter(ABC):
    """Abstract dataset adapter used by experiment pipelines."""

    @abstractmethod
    def discover(
        self,
        *,
        include: Iterable[str] | None = None,
        sample_limit: int | None = None,
    ) -> list[TrackHandle]:
        """Discover available track handles."""

    @abstractmethod
    def load(
        self,
        handle: TrackHandle,
        *,
        duration_sec: float | None = None,
    ) -> TrackAudio:
        """Load one track represented by ``handle``."""

    @abstractmethod
    def stem_names(self) -> list[str]:
        """Return ordered stem labels used in reporting outputs."""

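A concrete adapter only has to supply the three abstract methods. The sketch below shows the shape of such a subclass with an in-memory dataset. `Handle`, `Adapter`, and `InMemoryAdapter` are hypothetical stand-ins rather than oobss classes, `load` returns a plain dictionary instead of a TrackAudio, and the meaning given to `include` (filter by track ID) and `sample_limit` (truncate after sorting) is an assumption.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Handle:
    """Stand-in for TrackHandle: a track ID plus a free-form payload."""

    track_id: str
    payload: dict[str, Any] = field(default_factory=dict)


class Adapter(ABC):
    """Stand-in mirroring the three abstract methods listed above."""

    @abstractmethod
    def discover(self, *, include=None, sample_limit=None) -> list[Handle]: ...

    @abstractmethod
    def load(self, handle: Handle, *, duration_sec=None) -> dict[str, Any]: ...

    @abstractmethod
    def stem_names(self) -> list[str]: ...


class InMemoryAdapter(Adapter):
    def __init__(self, tracks: dict[str, dict[str, list[float]]], sample_rate: int) -> None:
        self._tracks = tracks
        self._sample_rate = sample_rate

    def discover(self, *, include=None, sample_limit=None) -> list[Handle]:
        wanted = None if include is None else set(include)
        ids = sorted(t for t in self._tracks if wanted is None or t in wanted)
        if sample_limit is not None:
            ids = ids[:sample_limit]
        return [Handle(track_id=t) for t in ids]

    def load(self, handle: Handle, *, duration_sec=None) -> dict[str, Any]:
        stems = self._tracks[handle.track_id]
        if duration_sec is not None:
            n_samples = int(duration_sec * self._sample_rate)
            stems = {name: samples[:n_samples] for name, samples in stems.items()}
        # The mixture is the sample-wise sum of all stems.
        mix = [sum(values) for values in zip(*stems.values())]
        return {"track_id": handle.track_id, "stems": stems, "mix": mix}

    def stem_names(self) -> list[str]:
        first = next(iter(self._tracks.values()), {})
        return list(first)


tracks = {
    "scene_b": {"voice": [1.0, 2.0, 3.0], "drums": [0.5, 0.5, 0.5]},
    "scene_a": {"voice": [4.0, 5.0, 6.0], "drums": [1.0, 1.0, 1.0]},
}
adapter = InMemoryAdapter(tracks, sample_rate=2)
handles = adapter.discover(sample_limit=1)
audio = adapter.load(handles[0], duration_sec=1.0)
```

Splitting discovery from loading lets a pipeline enumerate cheap handles first and defer reading audio until a task actually runs.
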
discover abstractmethod

discover(*, include=None, sample_limit=None)

Discover available track handles.

Source code in src/oobss/dataloaders/base.py
@abstractmethod
def discover(
    self,
    *,
    include: Iterable[str] | None = None,
    sample_limit: int | None = None,
) -> list[TrackHandle]:
    """Discover available track handles."""

load abstractmethod

load(handle, *, duration_sec=None)

Load one track represented by handle.

Source code in src/oobss/dataloaders/base.py
@abstractmethod
def load(
    self,
    handle: TrackHandle,
    *,
    duration_sec: float | None = None,
) -> TrackAudio:
    """Load one track represented by ``handle``."""

stem_names abstractmethod

stem_names()

Return ordered stem labels used in reporting outputs.

Source code in src/oobss/dataloaders/base.py
@abstractmethod
def stem_names(self) -> list[str]:
    """Return ordered stem labels used in reporting outputs."""

TorchrirDynamicDataset

Dataset with __len__ / __getitem__ for dynamic torchrir scenes.

Source code in src/oobss/dataloaders/torchrir_dynamic.py
class TorchrirDynamicDataset:
    """Dataset with ``__len__`` / ``__getitem__`` for dynamic torchrir scenes."""

    def __init__(
        self,
        root: Path | str,
        *,
        return_type: str = "torch",
        include: Sequence[str] | None = None,
        sample_limit: int | None = None,
        duration_sec: float | None = None,
        dtype: Any | None = None,
        include_metadata: bool = False,
    ) -> None:
        self.root = Path(root).expanduser().resolve()
        self.return_type = str(return_type).lower()
        if self.return_type not in {"torch", "numpy"}:
            raise ValueError("return_type must be either 'torch' or 'numpy'.")
        if self.return_type == "torch" and torch is None:
            raise RuntimeError(
                "TorchrirDynamicDataset(return_type='torch') requires torch."
            )

        self.duration_sec = duration_sec
        self.include_metadata = bool(include_metadata)
        default_dtype = None if torch is None else torch.float32
        self.dtype = default_dtype if dtype is None else dtype
        self._scene_paths = discover_torchrir_scene_paths(
            self.root,
            include=include,
            sample_limit=sample_limit,
        )

    def __len__(self) -> int:
        return len(self._scene_paths)

    def __getitem__(self, index: int) -> TorchrirSample:
        scene = load_torchrir_dynamic_scene(
            self._scene_paths[index],
            duration_sec=self.duration_sec,
            include_metadata=self.include_metadata,
        )
        if self.return_type == "numpy":
            return scene
        if torch is None:  # pragma: no cover - guarded in __init__
            raise RuntimeError("torch is required for return_type='torch'.")

        mix_np = np.asarray(scene["mix"])
        stems_np = np.asarray(scene["stems"])
        mix = torch.as_tensor(mix_np, dtype=self.dtype)
        stems = torch.as_tensor(stems_np, dtype=self.dtype)

        output: TorchrirSample = dict(scene)
        output["mix"] = mix
        output["stems"] = stems
        return output

TorchrirDynamicDatasetAdapter dataclass

Bases: BaseDatasetAdapter

Adapter for dynamic torchrir scene directories.

Source code in src/oobss/dataloaders/base.py
@dataclass(frozen=True)
class TorchrirDynamicDatasetAdapter(BaseDatasetAdapter):
    """Adapter for dynamic torchrir scene directories."""

    root: Path
    duration_sec: float | None = None

    def discover(
        self,
        *,
        include: Iterable[str] | None = None,
        sample_limit: int | None = None,
    ) -> list[TrackHandle]:
        """Return discovered dynamic-scene handles sorted by scene ID."""
        scenes = discover_torchrir_scene_paths(
            self.root,
            include=include,
            sample_limit=sample_limit,
        )
        return [
            TrackHandle(
                track_id=path.name,
                payload={"path": str(path)},
            )
            for path in scenes
        ]

    def load(
        self,
        handle: TrackHandle,
        *,
        duration_sec: float | None = None,
    ) -> TrackAudio:
        """Load one dynamic scene and return canonical track audio."""
        raw_path = handle.payload.get("path")
        if not isinstance(raw_path, str):
            raise ValueError(f"Invalid track handle payload for {handle.track_id}")
        scene = load_torchrir_dynamic_scene(
            Path(raw_path),
            duration_sec=self.duration_sec if duration_sec is None else duration_sec,
        )
        stems = np.asarray(scene["stems"], dtype=np.float64)
        mix = np.asarray(scene["mix"], dtype=np.float64)
        sample_rate = int(scene["sample_rate"])
        return TrackAudio(
            track_id=str(scene["scene_id"]),
            path=Path(str(scene["path"])),
            stems=stems,
            mix=mix,
            sample_rate=sample_rate,
        )

    def stem_names(self) -> list[str]:
        """Return ordered source names from the first discovered scene."""
        scenes = discover_torchrir_scene_paths(self.root, sample_limit=1)
        if not scenes:
            return []
        sample = load_torchrir_dynamic_scene(
            scenes[0],
            duration_sec=self.duration_sec,
        )
        source_names_obj = sample.get("source_names", [])
        if not isinstance(source_names_obj, list):
            return []
        return [str(name) for name in source_names_obj]

discover

discover(*, include=None, sample_limit=None)

Return discovered dynamic-scene handles sorted by scene ID.

Source code in src/oobss/dataloaders/base.py
def discover(
    self,
    *,
    include: Iterable[str] | None = None,
    sample_limit: int | None = None,
) -> list[TrackHandle]:
    """Return discovered dynamic-scene handles sorted by scene ID."""
    scenes = discover_torchrir_scene_paths(
        self.root,
        include=include,
        sample_limit=sample_limit,
    )
    return [
        TrackHandle(
            track_id=path.name,
            payload={"path": str(path)},
        )
        for path in scenes
    ]

load

load(handle, *, duration_sec=None)

Load one dynamic scene and return canonical track audio.

Source code in src/oobss/dataloaders/base.py
def load(
    self,
    handle: TrackHandle,
    *,
    duration_sec: float | None = None,
) -> TrackAudio:
    """Load one dynamic scene and return canonical track audio."""
    raw_path = handle.payload.get("path")
    if not isinstance(raw_path, str):
        raise ValueError(f"Invalid track handle payload for {handle.track_id}")
    scene = load_torchrir_dynamic_scene(
        Path(raw_path),
        duration_sec=self.duration_sec if duration_sec is None else duration_sec,
    )
    stems = np.asarray(scene["stems"], dtype=np.float64)
    mix = np.asarray(scene["mix"], dtype=np.float64)
    sample_rate = int(scene["sample_rate"])
    return TrackAudio(
        track_id=str(scene["scene_id"]),
        path=Path(str(scene["path"])),
        stems=stems,
        mix=mix,
        sample_rate=sample_rate,
    )

stem_names

stem_names()

Return ordered source names from the first discovered scene.

Source code in src/oobss/dataloaders/base.py
def stem_names(self) -> list[str]:
    """Return ordered source names from the first discovered scene."""
    scenes = discover_torchrir_scene_paths(self.root, sample_limit=1)
    if not scenes:
        return []
    sample = load_torchrir_dynamic_scene(
        scenes[0],
        duration_sec=self.duration_sec,
    )
    source_names_obj = sample.get("source_names", [])
    if not isinstance(source_names_obj, list):
        return []
    return [str(name) for name in source_names_obj]

TrackAudio dataclass

Container holding stems, mixture, and metadata for one track.

Attributes:

track_id (str): Unique identifier of the track in the dataset.
path (Path): Backing track path, if available.
stems (ndarray): Time-domain stem signals with shape (n_src, n_samples, n_mic).
mix (ndarray): Time-domain mixture with shape (n_samples, n_mic).
sample_rate (int): Sampling rate in Hz.

Source code in src/oobss/dataloaders/base.py
@dataclass(frozen=True)
class TrackAudio:
    """Container holding stems, mixture, and metadata for one track.

    Attributes
    ----------
    track_id:
        Unique identifier of the track in the dataset.
    path:
        Backing track path, if available.
    stems:
        Time-domain stem signals with shape ``(n_src, n_samples, n_mic)``.
    mix:
        Time-domain mixture with shape ``(n_samples, n_mic)``.
    sample_rate:
        Sampling rate in Hz.
    """

    track_id: str
    path: Path
    stems: np.ndarray  # shape: (n_src, n_samples, n_mic)
    mix: np.ndarray  # shape: (n_samples, n_mic)
    sample_rate: int

    @property
    def n_src(self) -> int:
        """Return number of sources."""
        return int(self.stems.shape[0])

    @property
    def n_mic(self) -> int:
        """Return number of microphones."""
        return int(self.stems.shape[2])

    @property
    def n_samples(self) -> int:
        """Return number of samples."""
        return int(self.stems.shape[1])

    @property
    def duration(self) -> float:
        """Return track duration in seconds."""
        return self.n_samples / float(self.sample_rate)

duration property

duration

Return track duration in seconds.

n_mic property

n_mic

Return number of microphones.

n_samples property

n_samples

Return number of samples.

n_src property

n_src

Return number of sources.

TrackHandle dataclass

Opaque reference to a track discovered by a dataset loader.

The payload is loader-specific metadata required to retrieve the actual audio for the track. It must remain pickle-safe because tasks can run in subprocess workers.

Source code in src/oobss/dataloaders/base.py
@dataclass(frozen=True)
class TrackHandle:
    """Opaque reference to a track discovered by a dataset loader.

    The payload is loader-specific metadata required to retrieve the actual
    audio for the track. It must remain pickle-safe because tasks can run in
    subprocess workers.
    """

    track_id: str
    payload: dict[str, Any]
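
Because payloads may cross process boundaries, a round trip through pickle is one way to check a payload before handing it to workers. The sketch below is illustrative only: `payload_is_pickle_safe` is a hypothetical helper, not part of oobss, and the example paths are made up.

```python
import pickle
from typing import Any


def payload_is_pickle_safe(payload: dict[str, Any]) -> bool:
    """Return True if the payload survives a pickle round trip unchanged.

    Hypothetical helper for illustration; oobss does not ship this function.
    """
    try:
        restored = pickle.loads(pickle.dumps(payload))
    except Exception:  # the exact pickling error depends on the object type
        return False
    return restored == payload


# Plain strings, in the style of the {"path": str(path)} payload built by discover().
good_payload = {"path": "/data/scenes/scene_0001"}
# A lambda cannot be pickled by the standard-library pickler.
bad_payload = {"opener": lambda: None}

print(payload_is_pickle_safe(good_payload))
print(payload_is_pickle_safe(bad_payload))
```

Storing the path as a string rather than an open file handle or a callable is what keeps the payload safe to send to a subprocess.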

build_torchrir_dynamic_adapter

build_torchrir_dynamic_adapter(dataset_cfg)

Create a TorchrirDynamicDatasetAdapter from the dataset configuration.

Source code in src/oobss/dataloaders/base.py
def build_torchrir_dynamic_adapter(
    dataset_cfg: dict[str, Any],
) -> BaseDatasetAdapter:
    """Create :class:`TorchrirDynamicDatasetAdapter` from dataset configuration."""
    root = (
        Path(dataset_cfg.get("root", "outputs/cmu_arctic_torchrir_dynamic_dataset"))
        .expanduser()
        .resolve()
    )
    duration_obj = dataset_cfg.get("duration_sec")
    duration_sec = None if duration_obj is None else float(duration_obj)
    return TorchrirDynamicDatasetAdapter(root=root, duration_sec=duration_sec)

build_torchrir_dynamic_dataloader

build_torchrir_dynamic_dataloader(*, root, return_type='torch', include=None, sample_limit=None, duration_sec=None, dtype=None, include_metadata=False, batch_size=1, shuffle=False, num_workers=0, pin_memory=False, drop_last=False, collate_fn=None)

Build a torch.utils.data.DataLoader for dynamic torchrir scenes.

Source code in src/oobss/dataloaders/torchrir_dynamic.py
def build_torchrir_dynamic_dataloader(
    *,
    root: Path | str,
    return_type: str = "torch",
    include: Sequence[str] | None = None,
    sample_limit: int | None = None,
    duration_sec: float | None = None,
    dtype: Any | None = None,
    include_metadata: bool = False,
    batch_size: int = 1,
    shuffle: bool = False,
    num_workers: int = 0,
    pin_memory: bool = False,
    drop_last: bool = False,
    collate_fn: CollateFn | None = None,
) -> Any:
    """Build a ``torch.utils.data.DataLoader`` for dynamic torchrir scenes."""
    if DataLoader is None:
        raise RuntimeError("build_torchrir_dynamic_dataloader requires torch.")

    dataset = TorchrirDynamicDataset(
        root=root,
        return_type=return_type,
        include=include,
        sample_limit=sample_limit,
        duration_sec=duration_sec,
        dtype=dtype,
        include_metadata=include_metadata,
    )
    resolved_collate = collate_scene_batch_as_list if collate_fn is None else collate_fn
    return DataLoader(
        dataset,
        batch_size=int(batch_size),
        shuffle=bool(shuffle),
        num_workers=int(num_workers),
        pin_memory=bool(pin_memory),
        drop_last=bool(drop_last),
        collate_fn=resolved_collate,
    )

create_loader

create_loader(dataset_cfg, *, registry_overrides=None)

Instantiate a dataset loader from dataset_cfg.

Parameters:

dataset_cfg (dict[str, Any], required): Dataset configuration. Must include type when using non-default adapters. The default type is torchrir_dynamic.
registry_overrides (Mapping[str, AdapterFactory] | None, default None): Optional injected adapter factories for tests or external integrations.

Source code in src/oobss/dataloaders/base.py
def create_loader(
    dataset_cfg: dict[str, Any],
    *,
    registry_overrides: Mapping[str, AdapterFactory] | None = None,
) -> BaseDatasetAdapter:
    """Instantiate a dataset loader from ``dataset_cfg``.

    Parameters
    ----------
    dataset_cfg:
        Dataset configuration. Must include ``type`` when using non-default
        adapters. The default type is ``torchrir_dynamic``.
    registry_overrides:
        Optional injected adapter factories for tests or external integrations.
    """
    cfg = dict(dataset_cfg)
    loader_type = str(cfg.get("type", "torchrir_dynamic"))
    registry = loader_registry(overrides=registry_overrides)
    try:
        factory = registry[loader_type]
    except KeyError as exc:
        available = ", ".join(sorted(registry))
        raise ValueError(
            f"Unknown dataset loader type: {loader_type}. Available: {available}"
        ) from exc
    return factory(cfg)

loader_registry

loader_registry(overrides=None)

Return registry mapping loader type names to factories.

Source code in src/oobss/dataloaders/base.py
def loader_registry(
    overrides: Mapping[str, AdapterFactory] | None = None,
) -> dict[str, AdapterFactory]:
    """Return registry mapping loader type names to factories."""
    registry: dict[str, AdapterFactory] = {
        "torchrir_dynamic": build_torchrir_dynamic_adapter,
    }
    if overrides:
        registry.update(overrides)
    return registry
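
The registry-with-overrides pattern used by `create_loader` and `loader_registry` can be tried in isolation. The following is a minimal sketch with stand-in factories that return strings; `build_default` and `build_in_memory` are hypothetical names introduced here, and real factories return a `BaseDatasetAdapter`.

```python
from typing import Any, Callable, Mapping, Optional

# Stand-in type alias: real factories return a BaseDatasetAdapter.
AdapterFactory = Callable[[dict], Any]


def build_default(cfg: dict) -> str:
    """Placeholder for build_torchrir_dynamic_adapter."""
    return "torchrir_dynamic adapter"


def build_in_memory(cfg: dict) -> str:
    """Hypothetical factory injected by a test."""
    return f"in_memory adapter with {cfg.get('n_tracks', 0)} tracks"


def loader_registry(
    overrides: Optional[Mapping[str, AdapterFactory]] = None,
) -> dict:
    registry = {"torchrir_dynamic": build_default}
    if overrides:
        registry.update(overrides)  # overrides win on name clashes
    return registry


def create_loader(cfg: dict, *, registry_overrides=None) -> Any:
    loader_type = str(cfg.get("type", "torchrir_dynamic"))
    registry = loader_registry(registry_overrides)
    try:
        factory = registry[loader_type]
    except KeyError as exc:
        available = ", ".join(sorted(registry))
        raise ValueError(
            f"Unknown dataset loader type: {loader_type}. Available: {available}"
        ) from exc
    return factory(dict(cfg))


# No "type" key falls back to the default loader.
print(create_loader({}))
# An injected factory is selected by its registry name.
print(create_loader({"type": "in_memory", "n_tracks": 3},
                    registry_overrides={"in_memory": build_in_memory}))
```

Since overrides are merged with `dict.update`, an override registered under `torchrir_dynamic` replaces the built-in factory for that call only; the base registry is rebuilt on every call.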

Evaluation

oobss.evaluation

Evaluation utilities for separation outputs.

__all__ module-attribute

__all__ = ['Framing', 'align_lengths', 'calc_si_sdr_framewise', 'framewise_si_sdr_summary', 'summarize_framewise_si_sdr', 'si_bss_eval', 'MetricsBundle', 'compute_metrics', 'normalize_framewise_metrics']

Framing

Iterator for overlapping window slices along the last axis.

Source code in src/oobss/evaluation/framewise.py
class Framing:
    """Iterator for overlapping window slices along the last axis."""

    def __init__(self, window: int, hop: int, length: int) -> None:
        if window <= 0 or hop <= 0:
            raise ValueError("window and hop must be positive integers.")
        self.window = window
        self.hop = hop
        self.length = length
        self._index = 0

    def __iter__(self) -> Iterator[slice]:
        return self

    def __next__(self) -> slice:
        if self._index >= self.nwin:
            raise StopIteration
        start = self._index * self.hop
        stop = min(start + self.window, self.length)
        self._index += 1
        return slice(start, stop)

    @property
    def nwin(self) -> int:
        if self.window >= self.length:
            return 1
        return int(np.floor((self.length - self.window + self.hop) / self.hop))
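
The window arithmetic in `Framing` can be reproduced without NumPy to see which samples each window covers. This is a sketch under the assumption that integer division matches the `np.floor` expression for positive inputs; `frame_slices` is a hypothetical helper name.

```python
def frame_slices(window: int, hop: int, length: int) -> list[tuple[int, int]]:
    """Return (start, stop) pairs using the same rule as Framing.nwin."""
    if window <= 0 or hop <= 0:
        raise ValueError("window and hop must be positive integers.")
    # One window when the signal is shorter than the window; otherwise count
    # only windows that fit completely inside the signal.
    n_win = 1 if window >= length else (length - window + hop) // hop
    return [(i * hop, min(i * hop + window, length)) for i in range(n_win)]


print(frame_slices(4, 2, 10))   # 50% overlap, signal fully covered
print(frame_slices(4, 3, 11))   # the final sample is not covered by any window
print(frame_slices(20, 5, 10))  # window longer than the signal
```

Note that trailing samples that do not fill a complete window are dropped, which is visible in the second call.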

MetricsBundle dataclass

Source code in src/oobss/evaluation/metrics.py
@dataclass
class MetricsBundle:
    sdr_mix: np.ndarray
    sdr_est: np.ndarray
    sir_est: np.ndarray
    sar_est: np.ndarray
    permutation: np.ndarray
    framewise: dict[str, np.ndarray] | None

    def to_summary(self) -> dict[str, object]:
        sdr_imp = np.asarray(self.sdr_est) - np.asarray(self.sdr_mix)
        stats = {
            "sdr_mix_mean": float(np.nanmean(self.sdr_mix)),
            "sdr_est_mean": float(np.nanmean(self.sdr_est)),
            "sdr_imp_mean": float(np.nanmean(sdr_imp)),
            "sdr_est_median": float(np.nanmedian(self.sdr_est)),
            "sdr_imp_median": float(np.nanmedian(sdr_imp)),
            "sir_mean": float(np.nanmean(self.sir_est)),
            "sar_mean": float(np.nanmean(self.sar_est)),
        }
        stats["sdr_mix_channels"] = [float(v) for v in np.ravel(self.sdr_mix)]
        stats["sdr_est_channels"] = [float(v) for v in np.ravel(self.sdr_est)]
        stats["sdr_imp_channels"] = [float(v) for v in np.ravel(sdr_imp)]
        stats["sir_channels"] = [float(v) for v in np.ravel(self.sir_est)]
        stats["sar_channels"] = [float(v) for v in np.ravel(self.sar_est)]
        if self.framewise:
            for key in (
                "mean_si_sdr",
                "median_si_sdr",
                "mean_si_sdr_imp",
                "median_si_sdr_imp",
                "mean_si_sdr_mix",
                "median_si_sdr_mix",
            ):
                if key in self.framewise:
                    stats[key] = [float(v) for v in np.ravel(self.framewise[key])]
        return stats
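
`to_summary` reports SDR improvement as the per-channel difference `sdr_est - sdr_mix` and aggregates with NaN-aware statistics. The plain-Python sketch below mirrors that arithmetic for a single 1-D case; `nanmean` and `sdr_improvement` are hypothetical helpers and the numbers are made-up inputs, not measured results.

```python
import math


def nanmean(values: list[float]) -> float:
    """Mean that skips NaN entries, as numpy.nanmean does for 1-D input."""
    finite = [v for v in values if not math.isnan(v)]
    return sum(finite) / len(finite) if finite else math.nan


def sdr_improvement(sdr_est: list[float], sdr_mix: list[float]) -> list[float]:
    """Per-channel improvement of the estimate over the mixture baseline."""
    return [e - m for e, m in zip(sdr_est, sdr_mix)]


sdr_mix = [0.5, -1.0]
sdr_est = [8.5, math.nan]  # pretend the second channel could not be scored
imp = sdr_improvement(sdr_est, sdr_mix)

print(imp)           # NaN propagates through the subtraction
print(nanmean(imp))  # the aggregate ignores the NaN channel
```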

align_lengths

align_lengths(arrays)

Trim all arrays to the shortest shared length along the last axis.

Source code in src/oobss/evaluation/framewise.py
def align_lengths(arrays: Iterable[np.ndarray]) -> tuple[list[np.ndarray], int]:
    """Trim all arrays to the shortest shared length along the last axis."""
    array_list = list(arrays)
    min_len = min(arr.shape[-1] for arr in array_list)
    trimmed = [arr[..., :min_len] for arr in array_list]
    return trimmed, min_len
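
For intuition, the same trimming can be written with nested lists in place of arrays. This sketch assumes each signal is a channels-by-samples list of lists; `align_lengths_lists` is a hypothetical name used to avoid confusion with the NumPy version above.

```python
def align_lengths_lists(signals: list[list[list[float]]]) -> tuple[list, int]:
    """Trim every channel of every signal to the shortest sample count."""
    min_len = min(len(sig[0]) for sig in signals)
    trimmed = [[row[:min_len] for row in sig] for sig in signals]
    return trimmed, min_len


ref = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]  # 2 channels, 4 samples
est = [[0.9, 2.1, 3.2], [5.1, 5.8, 7.1]]            # 2 channels, 3 samples

aligned, n = align_lengths_lists([ref, est])
print(n)
print(aligned[0])
```

Trimming always cuts from the end, so any fixed delay between reference and estimate is left untouched.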

calc_si_sdr_framewise

calc_si_sdr_framewise(ref, est, window, hop, *, scaling=True, compute_permutation=True)

Compute frame-wise SI-SDR with sliding windows.

Source code in src/oobss/evaluation/framewise.py
def calc_si_sdr_framewise(
    ref: np.ndarray,
    est: np.ndarray,
    window: int,
    hop: int,
    *,
    scaling: bool = True,
    compute_permutation: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute frame-wise SI-SDR with sliding windows."""
    if ref.ndim != 2 or est.ndim != 2:
        raise ValueError("ref and est must be 2-D arrays (n_src, n_samples)")
    if ref.shape[-1] != est.shape[-1]:
        raise ValueError(
            f"Shape mismatch between ref {ref.shape} and est {est.shape} (sample length)"
        )
    if not compute_permutation and ref.shape[0] != est.shape[0]:
        raise ValueError(
            "compute_permutation=False requires equal number of channels in ref and est"
        )

    _, n_samples = ref.shape
    n_est = est.shape[0]
    window = min(window, n_samples)
    hop = min(hop, window) if window > 0 else hop

    windows = Framing(window, hop, n_samples)
    n_win = windows.nwin
    si_sdr = np.zeros((n_est, n_win), dtype=float)
    perm = np.zeros((n_est, n_win), dtype=int)

    for t, slc in enumerate(windows):
        ref_seg = ref[:, slc].T
        est_seg = est[:, slc].T
        try:
            if compute_permutation:
                si_sdr[:, t], _, _, perm[:, t] = si_bss_eval(
                    ref_seg,
                    est_seg,
                    scaling=scaling,
                )
            else:
                for ch in range(n_est):
                    sdr_ch, _, _, _ = si_bss_eval(
                        ref_seg[:, [ch]],
                        est_seg[:, [ch]],
                        scaling=scaling,
                    )
                    si_sdr[ch, t] = sdr_ch[0]
                perm[:, t] = np.arange(n_est)
        except ValueError as exc:
            LOGGER.error("SI-SDR evaluation failed at frame %s: %s", t, exc)
            si_sdr[:, t] = np.nan
            perm[:, t] = np.arange(n_est)

    return si_sdr, perm

compute_metrics

compute_metrics(reference, estimate, mixture, sample_rate, *, filter_length, frame_cfg, compute_permutation=True, permutation_strategy=None)

Compute batch and optional frame-wise separation metrics.

Parameters:

reference (ndarray, required): Reference signals with shape (n_ref, n_samples).
estimate (ndarray, required): Estimated signals with shape (n_est, n_samples).
mixture (ndarray, required): Mixture baseline with shape (n_mix, n_samples) or (n_samples,). For reference-mic baseline evaluation, pass a single channel.
compute_permutation (bool, default True): Whether to allow permutation solving in fast_bss_eval. This should be enabled for under/over-determined cases where channel counts differ.

Source code in src/oobss/evaluation/metrics.py
def compute_metrics(
    reference: np.ndarray,
    estimate: np.ndarray,
    mixture: np.ndarray,
    sample_rate: int,
    *,
    filter_length: int,
    frame_cfg: FrameEvalLike | None,
    compute_permutation: bool = True,
    permutation_strategy: PermutationStrategy | None = None,
) -> MetricsBundle:
    """Compute batch and optional frame-wise separation metrics.

    Parameters
    ----------
    reference:
        Reference signals with shape ``(n_ref, n_samples)``.
    estimate:
        Estimated signals with shape ``(n_est, n_samples)``.
    mixture:
        Mixture baseline with shape ``(n_mix, n_samples)`` or ``(n_samples,)``.
        For reference-mic baseline evaluation, pass a single channel.
    compute_permutation:
        Whether to allow permutation solving in ``fast_bss_eval``. This should
        be enabled for under/over-determined cases where channel counts differ.
    """
    if reference.ndim != 2 or estimate.ndim != 2:
        raise ValueError("reference and estimate must be 2-D arrays (n_src, n_samples)")
    if reference.shape[-1] != estimate.shape[-1]:
        raise ValueError(
            f"reference and estimate must share sample length, got {reference.shape} and {estimate.shape}"
        )
    mixture_eval = np.asarray(mixture)
    if mixture_eval.ndim == 1:
        mixture_eval = mixture_eval[None, :]
    if mixture_eval.ndim != 2:
        raise ValueError("mixture must be 1-D or 2-D array")
    if mixture_eval.shape[-1] != reference.shape[-1]:
        raise ValueError(
            "mixture must share sample length with reference/estimate, "
            f"got {mixture_eval.shape[-1]} and {reference.shape[-1]}"
        )
    if not compute_permutation and estimate.shape[0] != reference.shape[0]:
        raise ValueError(
            "compute_permutation=False requires matching channels for reference and estimate."
        )
    if not compute_permutation and mixture_eval.shape[0] != reference.shape[0]:
        if mixture_eval.shape[0] == 1:
            mixture_eval = np.repeat(mixture_eval, reference.shape[0], axis=0)
        else:
            raise ValueError(
                "compute_permutation=False requires matching channels for reference and mixture."
            )

    sdr_mix, _, _, _ = cast(
        BssevalMetrics,
        bss_eval_sources(
            reference,
            mixture_eval,
            filter_length=filter_length,
            compute_permutation=compute_permutation,
        ),
    )
    if compute_permutation:
        sdr_est, sir_est, sar_est, perm_default = cast(
            BssevalMetrics,
            bss_eval_sources(
                reference,
                estimate,
                filter_length=filter_length,
                compute_permutation=True,
            ),
        )
        strategy = (
            permutation_strategy
            if permutation_strategy is not None
            else BssEvalPermutationStrategy(filter_length=filter_length)
        )
        perm_idx = strategy.solve(
            PermutationRequest(
                score=np.asarray(sdr_est),
                reference=reference,
                estimate=estimate,
                filter_length=filter_length,
                default_perm=np.asarray(perm_default, dtype=np.int64),
            )
        )
    else:
        sdr_est, sir_est, sar_est = cast(
            tuple[np.ndarray, np.ndarray, np.ndarray],
            bss_eval_sources(
                reference,
                estimate,
                filter_length=filter_length,
                compute_permutation=False,
            ),
        )
        perm_idx = np.arange(estimate.shape[0], dtype=np.int64)

    if (
        estimate.shape[0] == reference.shape[0]
        and perm_idx.shape[0] == estimate.shape[0]
        and np.all((0 <= perm_idx) & (perm_idx < estimate.shape[0]))
    ):
        estimate_perm = estimate[np.argsort(perm_idx), :]
    else:
        estimate_perm = estimate

    framewise = None
    if frame_cfg:
        frame_mixture = mixture_eval
        frame_compute_perm = bool(frame_cfg.compute_permutation)
        if not frame_compute_perm and frame_mixture.shape[0] != reference.shape[0]:
            if frame_mixture.shape[0] == 1:
                frame_mixture = np.repeat(frame_mixture, reference.shape[0], axis=0)
            else:
                raise ValueError(
                    "frame.compute_permutation=False requires matching channels for reference and mixture."
                )
        framewise = summarize_framewise_si_sdr(
            reference,
            estimate_perm,
            sample_rate,
            mixture=frame_mixture,
            window_sec=float(frame_cfg.window_sec),
            hop_sec=frame_cfg.hop_sec,
            scaling=bool(frame_cfg.scaling),
            compute_permutation=frame_compute_perm,
        )
        framewise = normalize_framewise_metrics(framewise)

    return MetricsBundle(
        sdr_mix=np.asarray(sdr_mix),
        sdr_est=np.asarray(sdr_est),
        sir_est=np.asarray(sir_est),
        sar_est=np.asarray(sar_est),
        permutation=perm_idx,
        framewise=framewise,
    )
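
Before frame-wise evaluation, `compute_metrics` reorders the estimate with `estimate[np.argsort(perm_idx), :]`. For a valid permutation, argsort yields the inverse permutation. The sketch below shows the effect with labels instead of signals, under the assumption (not confirmed by the code shown here) that `perm_idx[j]` is the reference index matched to estimate row `j`.

```python
def argsort(perm: list[int]) -> list[int]:
    """Indices that sort perm; for a permutation this is its inverse."""
    return sorted(range(len(perm)), key=perm.__getitem__)


# Assumed convention: estimate row j was matched to reference perm_idx[j].
perm_idx = [2, 0, 1]
estimate = ["matched_to_ref_2", "matched_to_ref_0", "matched_to_ref_1"]

inverse = argsort(perm_idx)
estimate_perm = [estimate[i] for i in inverse]

print(inverse)
print(estimate_perm)  # rows now follow reference order
```

Under that convention, row `i` of the reordered estimate lines up with reference row `i`, which is what a frame-wise evaluation with `compute_permutation=False` would require.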

framewise_si_sdr_summary

framewise_si_sdr_summary(ref, est, *, mixture=None, window, hop, scaling=True, compute_permutation=True)

Compute frame-wise SI-SDR and optional mixture baseline.

Source code in src/oobss/evaluation/framewise.py
def framewise_si_sdr_summary(
    ref: np.ndarray,
    est: np.ndarray,
    *,
    mixture: np.ndarray | None = None,
    window: int,
    hop: int,
    scaling: bool = True,
    compute_permutation: bool = True,
) -> dict[str, np.ndarray]:
    """Compute frame-wise SI-SDR and optional mixture baseline."""
    arrays: list[np.ndarray] = [ref, est]
    has_mixture = mixture is not None
    if has_mixture:
        arrays.append(mixture)  # type: ignore[arg-type]
    aligned, _ = align_lengths(arrays)
    ref_aligned, est_aligned = aligned[:2]
    mix_aligned = aligned[2] if has_mixture else None

    si_sdr_est, perm = calc_si_sdr_framewise(
        ref_aligned,
        est_aligned,
        window,
        hop,
        scaling=scaling,
        compute_permutation=compute_permutation,
    )
    summary: dict[str, np.ndarray] = {"si_sdr": si_sdr_est, "perm": perm}
    if has_mixture and mix_aligned is not None:
        si_sdr_mix, _ = calc_si_sdr_framewise(
            ref_aligned,
            mix_aligned,
            window,
            hop,
            scaling=scaling,
            compute_permutation=compute_permutation,
        )
        summary["si_sdr_mix"] = si_sdr_mix
        summary["si_sdr_imp"] = si_sdr_est - si_sdr_mix
    return summary

normalize_framewise_metrics

normalize_framewise_metrics(framewise)

Normalize frame-wise metric keys while keeping backward compatibility.

The framewise SI-SDR utilities return canonical keys such as mean_si_sdr / mean_si_sdr_imp. Some downstream code expects alias keys with *_channels suffix. This function guarantees both key styles are available and fills missing aggregate keys from raw frame matrices.

Parameters:

framewise (dict[str, ndarray] | None, required): Framewise metric dictionary, typically produced by oobss.evaluation.framewise.summarize_framewise_si_sdr. If None, None is returned.

Returns:

dict[str, ndarray] | None: A normalized dictionary where numeric values are converted to numpy arrays and the following aliases are guaranteed when source keys exist:

- mean_si_sdr_channels, median_si_sdr_channels
- mean_si_sdr_mix_channels, median_si_sdr_mix_channels
- mean_si_sdr_imp_channels, median_si_sdr_imp_channels

Source code in src/oobss/evaluation/metrics.py
def normalize_framewise_metrics(
    framewise: dict[str, np.ndarray] | None,
) -> dict[str, np.ndarray] | None:
    """Normalize frame-wise metric keys while keeping backward compatibility.

    The framewise SI-SDR utilities return canonical keys such as
    ``mean_si_sdr`` / ``mean_si_sdr_imp``. Some downstream code expects alias
    keys with ``*_channels`` suffix. This function guarantees both key styles
    are available and fills missing aggregate keys from raw frame matrices.

    Parameters
    ----------
    framewise:
        Framewise metric dictionary, typically produced by
        :func:`oobss.evaluation.framewise.summarize_framewise_si_sdr`.
        If ``None``, ``None`` is returned.

    Returns
    -------
    dict[str, np.ndarray] | None
        A normalized dictionary where numeric values are converted to numpy
        arrays and the following aliases are guaranteed when source keys exist:
        - ``mean_si_sdr_channels``, ``median_si_sdr_channels``
        - ``mean_si_sdr_mix_channels``, ``median_si_sdr_mix_channels``
        - ``mean_si_sdr_imp_channels``, ``median_si_sdr_imp_channels``
    """
    if framewise is None:
        return None

    normalized = {key: np.asarray(value) for key, value in framewise.items()}

    if "si_sdr" in normalized:
        si_sdr = np.asarray(normalized["si_sdr"])
        if "mean_si_sdr" not in normalized:
            normalized["mean_si_sdr"] = np.nanmean(si_sdr, axis=1)
        if "median_si_sdr" not in normalized:
            normalized["median_si_sdr"] = np.nanmedian(si_sdr, axis=1)

    if "mean_si_sdr" in normalized and "mean_si_sdr_channels" not in normalized:
        normalized["mean_si_sdr_channels"] = np.asarray(normalized["mean_si_sdr"])
    if "median_si_sdr" in normalized and "median_si_sdr_channels" not in normalized:
        normalized["median_si_sdr_channels"] = np.asarray(normalized["median_si_sdr"])

    if "si_sdr_mix" in normalized:
        si_sdr_mix = np.asarray(normalized["si_sdr_mix"])
        if "mean_si_sdr_mix" not in normalized:
            normalized["mean_si_sdr_mix"] = np.nanmean(si_sdr_mix, axis=1)
        if "median_si_sdr_mix" not in normalized:
            normalized["median_si_sdr_mix"] = np.nanmedian(si_sdr_mix, axis=1)
        if "mean_si_sdr_mix_channels" not in normalized:
            normalized["mean_si_sdr_mix_channels"] = np.asarray(
                normalized["mean_si_sdr_mix"]
            )
        if "median_si_sdr_mix_channels" not in normalized:
            normalized["median_si_sdr_mix_channels"] = np.asarray(
                normalized["median_si_sdr_mix"]
            )

    if "si_sdr_imp" in normalized:
        si_sdr_imp = np.asarray(normalized["si_sdr_imp"])
        if "mean_si_sdr_imp" not in normalized:
            normalized["mean_si_sdr_imp"] = np.nanmean(si_sdr_imp, axis=1)
        if "median_si_sdr_imp" not in normalized:
            normalized["median_si_sdr_imp"] = np.nanmedian(si_sdr_imp, axis=1)
        if "mean_si_sdr_imp_channels" not in normalized:
            normalized["mean_si_sdr_imp_channels"] = np.asarray(
                normalized["mean_si_sdr_imp"]
            )
        if "median_si_sdr_imp_channels" not in normalized:
            normalized["median_si_sdr_imp_channels"] = np.asarray(
                normalized["median_si_sdr_imp"]
            )

    return normalized
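
The aliasing step can be illustrated on its own. This is a simplified sketch, not the function above: it skips the aggregate fill from raw frame matrices and the array conversion, and it adds an alias whenever the canonical key exists. `add_channel_aliases` is a hypothetical name.

```python
def add_channel_aliases(metrics: dict) -> dict:
    """Copy canonical keys to *_channels aliases without overwriting."""
    out = dict(metrics)
    for stat in ("mean", "median"):
        for family in ("si_sdr", "si_sdr_mix", "si_sdr_imp"):
            key = f"{stat}_{family}"
            alias = f"{key}_channels"
            # Existing aliases are kept as-is, so older producers stay valid.
            if key in out and alias not in out:
                out[alias] = out[key]
    return out


metrics = {
    "mean_si_sdr": [10.0, 12.0],
    "median_si_sdr": [9.0, 11.0],
    "median_si_sdr_channels": [0.0, 0.0],  # pre-existing alias
}
out = add_channel_aliases(metrics)
print(sorted(out))
```

Never overwriting an existing alias is the property that keeps the normalization backward compatible.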

si_bss_eval

si_bss_eval(reference_signals, estimated_signals, scaling=True)

Compute SI-SDR family metrics and permutation.

Parameters:

reference_signals (ndarray, required): Reference matrix shaped (n_samples, n_channels).
estimated_signals (ndarray, required): Estimated matrix shaped (n_samples, n_channels).
scaling (bool, default True): If True, compute scale-invariant scores.

Source code in src/oobss/evaluation/si_sdr.py
def si_bss_eval(
    reference_signals: np.ndarray,
    estimated_signals: np.ndarray,
    scaling: bool = True,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Compute SI-SDR family metrics and permutation.

    Parameters
    ----------
    reference_signals:
        Reference matrix shaped ``(n_samples, n_channels)``.
    estimated_signals:
        Estimated matrix shaped ``(n_samples, n_channels)``.
    scaling:
        If ``True``, compute scale-invariant scores.
    """
    _, n_chan = estimated_signals.shape
    rss = np.dot(reference_signals.transpose(), reference_signals)

    sdr = np.zeros((n_chan, n_chan))
    sir = np.zeros((n_chan, n_chan))
    sar = np.zeros((n_chan, n_chan))

    for ref_idx in range(n_chan):
        for est_idx in range(n_chan):
            sdr[ref_idx, est_idx], sir[ref_idx, est_idx], sar[ref_idx, est_idx] = (
                _compute_measures(
                    estimated_signals[:, est_idx],
                    reference_signals,
                    rss,
                    ref_idx,
                    scaling=scaling,
                )
            )

    row_idx, perm = _linear_sum_assignment_with_inf(-sir)
    return sdr[row_idx, perm], sir[row_idx, perm], sar[row_idx, perm], perm
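
The private helper `_compute_measures` is not shown on this page, so the following is only a sketch of the standard single-reference SI-SDR definition, not necessarily what the helper computes (it also receives all references, presumably to separate interference from artifacts). The estimate is projected onto the reference, and the ratio of projected energy to residual energy is reported in dB.

```python
import math


def si_sdr(reference: list[float], estimate: list[float]) -> float:
    """Scale-invariant SDR in dB for one reference/estimate pair."""
    dot = sum(r * e for r, e in zip(reference, estimate))
    ref_energy = sum(r * r for r in reference)
    alpha = dot / ref_energy  # least-squares scale of the reference in the estimate
    target = [alpha * r for r in reference]
    target_energy = sum(t * t for t in target)
    # Residual after removing the scaled reference; zero for a perfect estimate,
    # which this sketch does not guard against.
    noise_energy = sum((e - t) ** 2 for e, t in zip(estimate, target))
    return 10.0 * math.log10(target_energy / noise_energy)


ref = [1.0, 0.0, -1.0, 0.0]
est = [1.1, 0.1, -0.9, 0.1]  # reference plus a constant offset of 0.1

print(si_sdr(ref, est))                      # about 16.99 dB
print(si_sdr(ref, [3.0 * e for e in est]))   # same value: rescaling has no effect
```

The second call demonstrates the scale invariance that the `scaling=True` option refers to.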

summarize_framewise_si_sdr

summarize_framewise_si_sdr(ref, est, fs, *, window_sec=5.0, hop_sec=None, mixture=None, scaling=True, compute_permutation=True)

Return frame-wise SI-SDR and aggregate statistics.

Source code in src/oobss/evaluation/framewise.py
def summarize_framewise_si_sdr(
    ref: np.ndarray,
    est: np.ndarray,
    fs: int,
    *,
    window_sec: float = 5.0,
    hop_sec: float | None = None,
    mixture: np.ndarray | None = None,
    scaling: bool = True,
    compute_permutation: bool = True,
) -> dict[str, np.ndarray]:
    """Return frame-wise SI-SDR and aggregate statistics."""
    window = max(1, int(round(window_sec * fs)))
    hop = window if hop_sec is None else max(1, int(round(hop_sec * fs)))
    summary = framewise_si_sdr_summary(
        ref,
        est,
        mixture=mixture,
        window=window,
        hop=hop,
        scaling=scaling,
        compute_permutation=compute_permutation,
    )

    result: dict[str, np.ndarray] = dict(summary)
    result["mean_si_sdr"] = np.nanmean(summary["si_sdr"], axis=1)
    result["median_si_sdr"] = np.nanmedian(summary["si_sdr"], axis=1)
    if "si_sdr_imp" in summary:
        result["mean_si_sdr_imp"] = np.nanmean(summary["si_sdr_imp"], axis=1)
        result["median_si_sdr_imp"] = np.nanmedian(summary["si_sdr_imp"], axis=1)
    return result
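
The conversion from seconds to samples at the top of `summarize_framewise_si_sdr` determines how much the windows overlap. The sketch below isolates that step; `window_and_hop` is a hypothetical helper name, and the sampling rate of 16 kHz is only an example value.

```python
from typing import Optional


def window_and_hop(
    fs: int,
    window_sec: float = 5.0,
    hop_sec: Optional[float] = None,
) -> tuple[int, int]:
    """Convert window and hop durations to sample counts (at least 1 each)."""
    window = max(1, int(round(window_sec * fs)))
    # Omitting hop_sec makes the hop equal to the window: no overlap.
    hop = window if hop_sec is None else max(1, int(round(hop_sec * fs)))
    return window, hop


print(window_and_hop(16000))                # default: non-overlapping 5 s windows
print(window_and_hop(16000, hop_sec=2.5))   # 50% overlap
print(window_and_hop(16000, window_sec=0))  # clamped to a single sample
```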

Postprocess

oobss.postprocess

Post-processing utilities.

__all__ module-attribute

__all__ = ['ParameterEstimationResult', 'PerReferenceSeparationResult', 'gaussian_source_model_weight', 'mixing_matrix_from_demixing_for_reference', 'separate_with_reference']

ParameterEstimationResult dataclass

Estimated parameters and raw demixed spectra for one method.

Source code in src/oobss/postprocess/separation.py
@dataclass(frozen=True)
class ParameterEstimationResult:
    """Estimated parameters and raw demixed spectra for one method."""

    method_id: str
    demixed_tf_raw: np.ndarray
    demixing_matrix: np.ndarray
    source_model: np.ndarray
    stft: ISTFTLike

PerReferenceSeparationResult dataclass

Per-reference separation outputs.

Source code in src/oobss/postprocess/separation.py
@dataclass(frozen=True)
class PerReferenceSeparationResult:
    """Per-reference separation outputs."""

    ref_mic: int
    estimate: np.ndarray
    projected_tf: np.ndarray
    mixing_matrix: np.ndarray

gaussian_source_model_weight

gaussian_source_model_weight(demixed_tfm)

Create Gaussian source-model weights from demixed spectra.

This utility averages the squared magnitude over frequency bins for each frame and source, then broadcasts that power back across the frequency axis so the output can be used as a simple Gaussian source model.

Parameters:

Name Type Description Default
demixed_tfm ndarray

Demixed STFT with shape (T, F, M) where T is the number of frames, F is the number of frequency bins, and M is the number of sources/channels.

required

Returns:

Type Description
ndarray

Real-valued source-model weights with shape (T, F, M); for each frame and source the value is identical across all F bins.

Source code in src/oobss/postprocess/separation.py
def gaussian_source_model_weight(demixed_tfm: np.ndarray) -> np.ndarray:
    """Create Gaussian source-model weights from demixed spectra.

    This utility computes frame/source power and broadcasts it across the
    frequency axis so the output can be used as a simple Gaussian source model.

    Parameters
    ----------
    demixed_tfm:
        Demixed STFT with shape ``(T, F, M)`` where ``T`` is the number of
        frames, ``F`` is the number of frequency bins, and ``M`` is the number
        of sources/channels.

    Returns
    -------
    np.ndarray
        Broadcast source model weights with shape ``(T, F, M)`` and
        ``float64``-compatible numeric values.
    """
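    # Guard against division by zero when the frequency axis is empty.
    n_freq = max(int(demixed_tfm.shape[1]), 1)
    # Mean power over frequency per (frame, source): ||y||^2 / F, shape (T, M).
    power_tm = (np.linalg.norm(demixed_tfm, axis=1) ** 2) / float(n_freq)
    # Repeat along the frequency axis; copy() returns a writable array.
    return np.broadcast_to(power_tm[:, None, :], demixed_tfm.shape).copy()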
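
The computation above is equivalent to taking the mean of the squared magnitude over the frequency axis. The following sketch checks that equivalence on synthetic data; the array sizes and the random input are illustrative assumptions, and nothing here calls the library itself.

```python
import numpy as np

rng = np.random.default_rng(0)
n_frame, n_freq, n_src = 4, 5, 2
shape = (n_frame, n_freq, n_src)

# Hypothetical demixed STFT in the frame-first layout (T, F, M).
demixed_tfm = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

# Mean power over frequency for each (frame, source) pair, shape (T, M).
power_tm = np.mean(np.abs(demixed_tfm) ** 2, axis=1)

# Same quantity written as in the source: squared l2 norm over F, divided by F.
power_tm_norm = (np.linalg.norm(demixed_tfm, axis=1) ** 2) / float(n_freq)

# Broadcast back to (T, F, M); copy() makes the result writable.
weights = np.broadcast_to(power_tm[:, None, :], demixed_tfm.shape).copy()
```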

mixing_matrix_from_demixing_for_reference

mixing_matrix_from_demixing_for_reference(demixing_matrix, *, ref_mic)

Return reference-normalized mixing matrices from demixing matrices.

The function first inverts each demixing matrix, then applies projection-back scaling such that the selected reference microphone row becomes source-consistent for each source.

Parameters:

Name Type Description Default
demixing_matrix ndarray

Demixing matrix array. Supported shapes are:
- (F, M, M) for batch separators.
- (T, F, M, M) for online separators.

required
ref_mic int

Reference microphone index used for projection-back normalization.

required

Returns:

Type Description
ndarray

Reference-normalized mixing matrix with the same leading dimensionality as demixing_matrix:
- (F, M, M) for batch.
- (T, F, M, M) for online.

Raises:

Type Description
ValueError

If demixing_matrix is not 3-D or 4-D.

Source code in src/oobss/postprocess/separation.py
def mixing_matrix_from_demixing_for_reference(
    demixing_matrix: np.ndarray,
    *,
    ref_mic: int,
) -> np.ndarray:
    """Return reference-normalized mixing matrices from demixing matrices.

    The function first inverts each demixing matrix, then applies
    projection-back scaling such that the selected reference microphone row
    becomes source-consistent for each source.

    Parameters
    ----------
    demixing_matrix:
        Demixing matrix array. Supported shapes are:
        - ``(F, M, M)`` for batch separators.
        - ``(T, F, M, M)`` for online separators.
    ref_mic:
        Reference microphone index used for projection-back normalization.

    Returns
    -------
    np.ndarray
        Reference-normalized mixing matrix with the same leading dimensionality
        as ``demixing_matrix``:
        - ``(F, M, M)`` for batch.
        - ``(T, F, M, M)`` for online.

    Raises
    ------
    ValueError
        If ``demixing_matrix`` is not 3-D or 4-D.
    """
    if demixing_matrix.ndim == 3:
        return _mixing_matrix_single_frame(demixing_matrix, ref_mic=ref_mic)
    if demixing_matrix.ndim == 4:
        return np.stack(
            [
                _mixing_matrix_single_frame(demixing_matrix[t], ref_mic=ref_mic)
                for t in range(demixing_matrix.shape[0])
            ],
            axis=0,
        )
    raise ValueError(f"Unsupported demixing_matrix ndim: {demixing_matrix.ndim}")
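
The private helper _mixing_matrix_single_frame is not shown on this page, so the sketch below does not reproduce it. It only illustrates the projection-back rule stated for AuxIVA/ILRMA (A_f = W_f^{-1}, with each source scaled by a_{k,f}[m_ref]) on synthetic data. The function projection_back_scale and the helper complex_normal are hypothetical names, and the sketch relies on np.linalg.inv handling stacked matrices instead of the explicit per-frame loop used above.

```python
import numpy as np


def projection_back_scale(demixing_matrix, ref_mic):
    """Return a_{k,f}[m_ref]: shape (F, M) for batch, (T, F, M) for online."""
    if demixing_matrix.ndim not in (3, 4):
        raise ValueError(f"Unsupported demixing_matrix ndim: {demixing_matrix.ndim}")
    # np.linalg.inv inverts the trailing (M, M) matrices of a stacked array.
    mixing = np.linalg.inv(demixing_matrix)  # A_f = W_f^{-1}
    # Row m_ref of A_f is the gain of each source at the reference microphone.
    return mixing[..., ref_mic, :]


def complex_normal(rng, shape):
    """Draw a complex array with independent standard-normal parts."""
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)


rng = np.random.default_rng(0)
n_freq, n_mic = 3, 2
true_mixing = complex_normal(rng, (n_freq, n_mic, n_mic))  # synthetic A_f
demixing = np.linalg.inv(true_mixing)                      # ideal W_f

sources = complex_normal(rng, (n_freq, n_mic))             # one frame, (F, M)
observed = np.einsum("fmk,fk->fm", true_mixing, sources)   # x_f = A_f s_f
demixed = np.einsum("fkm,fm->fk", demixing, observed)      # y_f = W_f x_f

scale = projection_back_scale(demixing, ref_mic=0)
projected = scale * demixed  # a_{k,f}[m_ref] * y_{k,f}
```

With an exact demixing matrix, summing the projected sources over k recovers the observation at the reference microphone, which is the property that makes the scaling unambiguous.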

separate_with_reference

separate_with_reference(params, *, ref_mic, n_samples)

Apply reference-wise projection-back and reconstruct time signals.

Parameters:

Name Type Description Default
params ParameterEstimationResult

Parameter estimation bundle that includes demixed STFT, demixing matrix, and an STFT object implementing istft.

required
ref_mic int

Reference microphone index for projection-back normalization.

required
n_samples int

Number of time-domain samples to keep after iSTFT. The reconstructed output is cropped to this length.

required

Returns:

Type Description
PerReferenceSeparationResult

Result containing:
- estimate: separated time-domain signals (M, N).
- projected_tf: projection-back scaled STFT.
- mixing_matrix: reference-normalized mixing matrix.

Raises:

Type Description
ValueError

If demixing and demixed STFT shapes are unsupported or inconsistent.

Source code in src/oobss/postprocess/separation.py
def separate_with_reference(
    params: ParameterEstimationResult,
    *,
    ref_mic: int,
    n_samples: int,
) -> PerReferenceSeparationResult:
    """Apply reference-wise projection-back and reconstruct time signals.

    Parameters
    ----------
    params:
        Parameter estimation bundle that includes demixed STFT, demixing matrix,
        and an STFT object implementing ``istft``.
    ref_mic:
        Reference microphone index for projection-back normalization.
    n_samples:
        Number of time-domain samples to keep after iSTFT. The reconstructed
        output is cropped to this length.

    Returns
    -------
    PerReferenceSeparationResult
        Result containing:
        - ``estimate``: separated time-domain signals ``(M, N)``.
        - ``projected_tf``: projection-back scaled STFT.
        - ``mixing_matrix``: reference-normalized mixing matrix.

    Raises
    ------
    ValueError
        If demixing and demixed STFT shapes are unsupported or inconsistent.
    """
    projected_tf, mixing_matrix = _project_for_reference(
        demixed_tf_raw=params.demixed_tf_raw,
        demixing_matrix=params.demixing_matrix,
        ref_mic=ref_mic,
    )
    estimate = np.real(params.stft.istft(projected_tf.transpose(2, 1, 0)))
    estimate = np.asarray(estimate, dtype=np.float64)[:, : int(n_samples)]
    return PerReferenceSeparationResult(
        ref_mic=int(ref_mic),
        estimate=estimate,
        projected_tf=projected_tf,
        mixing_matrix=mixing_matrix,
    )

Signal

oobss.signal

Signal processing utilities.

__all__ module-attribute

__all__ = ['STFTPlan', 'build_stft']

STFTPlan dataclass

STFT configuration shared across benchmark runners.

Source code in src/oobss/signal/stft.py
@dataclass(frozen=True)
class STFTPlan:
    """STFT configuration shared across benchmark runners."""

    fft_size: int
    hop_size: int
    window: str

build_stft

build_stft(plan, sample_rate)

Build a scipy.signal.ShortTimeFFT instance from plan.

Source code in src/oobss/signal/stft.py
def build_stft(plan: STFTPlan, sample_rate: int) -> ShortTimeFFT:
    """Build a :class:`scipy.signal.ShortTimeFFT` instance from ``plan``."""
    win = get_window(plan.window, plan.fft_size, fftbins=True)
    return ShortTimeFFT(win=win, hop=plan.hop_size, fs=sample_rate)

Visualization

oobss.visualization

Visualization utilities for inspection and reporting.

__all__ module-attribute

__all__ = ['plot_nmf_factors', 'save_channel_spectrograms']

plot_nmf_factors

plot_nmf_factors(x, basis, activations, *, vmin=None, vmax=None)

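Plot NMF factors and the reconstructed spectrogram (basis @ activations) in a compact grid. x must be shaped (n_freq, n_frame), basis (n_freq, n_components), and activations (n_components, n_frame); a ValueError is raised if any array is not 2-D or if activations does not match (n_components, n_frame). The reconstruction is drawn in dB, floored at 1e-12, and the newly created Figure is returned.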

Source code in src/oobss/visualization/spectrogram.py
def plot_nmf_factors(
    x: np.ndarray,
    basis: np.ndarray,
    activations: np.ndarray,
    *,
    vmin: float | None = None,
    vmax: float | None = None,
) -> plt.Figure:
    """Plot NMF factors and reconstructed spectrogram in a compact grid."""
    if x.ndim != 2:
        raise ValueError("x must be a 2-D array shaped (n_freq, n_frame)")
    if basis.ndim != 2:
        raise ValueError("basis must be a 2-D array shaped (n_freq, n_components)")
    if activations.ndim != 2:
        raise ValueError(
            "activations must be a 2-D array shaped (n_components, n_frame)"
        )

    n_freq, n_frame = x.shape
    _, n_components = basis.shape
    if activations.shape != (n_components, n_frame):
        raise ValueError(
            "activations shape must match (n_components, n_frame): "
            f"got {activations.shape}, expected {(n_components, n_frame)}"
        )

    boxwidth = 0.1
    line_width = 0.2
    freq_axis = np.arange(n_freq)
    time_axis = np.arange(n_frame)

    fig_width, fig_height = plt.gcf().get_size_inches()
    margin_ratio = 1.0 / 4.0
    dx = fig_width * margin_ratio
    dy = fig_height * margin_ratio

    fig, axes = plt.subplots(
        nrows=n_components + 1,
        ncols=n_components + 1,
        figsize=(dx + fig_width, dy + fig_height),
        width_ratios=[dx / n_components] * n_components + [fig_width],
        height_ratios=[fig_height] + [dy / n_components] * n_components,
    )

    reconstructed = np.maximum(basis @ activations, 1.0e-12)
    axes[0, n_components].imshow(
        10.0 * np.log10(reconstructed),
        vmin=vmin,
        vmax=vmax,
        rasterized=True,
    )

    for k in range(n_components):
        axes[0, k].plot(-basis[:, k], freq_axis, linewidth=line_width)
        axes[k + 1, n_components].plot(
            time_axis, activations[k, :], linewidth=line_width
        )

    for row in range(n_components):
        for col in range(n_components):
            fig.delaxes(axes[row + 1, col])

    for ax in axes.flat:
        ax.tick_params(
            axis="both",
            length=0,
            width=0,
            labelbottom=False,
            labelleft=False,
        )
        for spine in ax.spines.values():
            spine.set_linewidth(boxwidth)

    return fig
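
The shape contract and the floored dB conversion can be checked without opening a figure. The sketch below uses random nonnegative factors as stand-ins for a fitted NMF model; the sizes and values are assumptions chosen only to make the shapes visible.

```python
import numpy as np

rng = np.random.default_rng(0)
n_freq, n_frame, n_components = 6, 8, 2

# Hypothetical nonnegative factors with the shapes the checks above require.
basis = rng.random((n_freq, n_components))         # (n_freq, n_components)
activations = rng.random((n_components, n_frame))  # (n_components, n_frame)

# Model spectrogram, floored before the logarithm as in the source.
reconstructed = np.maximum(basis @ activations, 1.0e-12)
reconstructed_db = 10.0 * np.log10(reconstructed)

# An all-zero model hits the floor, which maps to -120 dB instead of -inf.
floor_db = 10.0 * np.log10(np.maximum(np.zeros((n_freq, n_frame)), 1.0e-12))
```

The floor keeps the image finite, so vmin and vmax can be chosen freely without having to mask silent regions first.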

save_channel_spectrograms

save_channel_spectrograms(spec, name, outdir, *, vmin=-40.0, vmax=20.0)

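Save one spectrogram image per channel from a frame-first spectrogram. spec must be 3-D shaped (n_frame, n_freq, n_channel); otherwise a ValueError is raised. Each channel is written to {name}-{ch}.pdf inside outdir, which is created if it does not exist, and the list of written paths is returned.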

Source code in src/oobss/visualization/spectrogram.py
def save_channel_spectrograms(
    spec: np.ndarray,
    name: str,
    outdir: str | Path,
    *,
    vmin: float = -40.0,
    vmax: float = 20.0,
) -> list[Path]:
    """Save one spectrogram image per channel from a frame-first spectrogram."""
    if spec.ndim != 3:
        raise ValueError("spec must be 3-D shaped (n_frame, n_freq, n_channel)")

    output_dir = Path(outdir)
    output_dir.mkdir(parents=True, exist_ok=True)

    saved: list[Path] = []
    n_channel = spec.shape[-1]
    for ch in range(n_channel):
        path = output_dir / f"{name}-{ch}.pdf"
        power = np.maximum(np.abs(spec[:, :, ch].T), 1.0e-12)
        plt.imshow(10.0 * np.log10(power), vmin=vmin, vmax=vmax, rasterized=True)
        plt.axis("off")
        plt.savefig(path)
        plt.close()
        saved.append(path)

    return saved
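
When choosing vmin and vmax, it may help to note that the listing above applies 10 * log10 to the magnitude |spec| rather than to the squared magnitude, even though the intermediate variable is named power. The sketch below compares the two scales on a few hand-picked magnitudes; the values are illustrative only.

```python
import numpy as np

magnitude = np.array([1.0e-3, 1.0, 10.0])

# What the listing above plots: 10 * log10 of the floored magnitude.
plotted_db = 10.0 * np.log10(np.maximum(magnitude, 1.0e-12))

# Conventional power spectrogram in dB: 10 * log10(|X|^2) = 20 * log10(|X|).
power_db = 10.0 * np.log10(np.maximum(magnitude**2, 1.0e-12))
```

For these inputs the plotted values are half of the power-dB values, so the default range of -40 to 20 covers what would be -80 to 40 on a power-dB scale.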