# Source code for updates
# -*- coding: utf-8 -*-
"""
updates.py
~~~~~~~~~~
.. topic:: Contents

    The update module groups the update functions used for the beta NMF."""
import theano.tensor as T
import theano
from theano.ifelse import ifelse
def beta_H(X, W, H, beta):
    """Update activation with beta divergence.

    Multiplicative update of the activations for the model ``X ~ H W^T``::

        Xhat = H W^T
        H <- H * ((Xhat^(beta-2) * X) W) / (Xhat^(beta-1) W)

    where ``*``, ``/`` and ``^`` are elementwise.

    Parameters
    ----------
    X : Theano tensor
        data, multiplied elementwise with ``H W^T`` (so presumably of
        shape (n, m) when H is (n, k) and W is (m, k) -- confirm with
        callers)
    W : Theano tensor
        Bases
    H : Theano tensor
        activation matrix
    beta : Theano scalar
        beta-divergence parameter; ``beta == 2`` selects a cheaper
        closed-form branch

    Returns
    -------
    H : Theano tensor
        Updated version of the activations (a new symbolic expression;
        the input ``H`` is not modified)

    Notes
    -----
    NOTE(review): no epsilon is added. Zeros in a denominator, or in
    ``H W^T`` when ``beta < 2``, will produce inf/nan.
    """
    approx = T.dot(H, W.T)
    up = ifelse(
        T.eq(beta, 2),
        # Euclidean case: Xhat^0 == 1, so the update is X W / (H W^T W).
        # Grouping it as H (W^T W) uses the small square Gram matrix of
        # W instead of forming the full H W^T product.
        T.dot(X, W) / T.dot(H, T.dot(W.T, W)),
        # General beta: the approximation is shared by numerator and
        # denominator.
        T.dot(T.mul(T.power(approx, beta - 2), X), W)
        / T.dot(T.power(approx, beta - 1), W),
    )
    return T.mul(H, up)
def beta_W(X, W, H, beta):
    """Update bases with beta divergence.

    Multiplicative update of the bases for the model ``X ~ H W^T``::

        Xhat = H W^T
        W <- W * ((Xhat^(beta-2) * X)^T H) / ((Xhat^(beta-1))^T H)

    where ``*``, ``/`` and ``^`` are elementwise.

    Parameters
    ----------
    X : Theano tensor
        data, multiplied elementwise with ``H W^T`` (so presumably of
        shape (n, m) when H is (n, k) and W is (m, k) -- confirm with
        callers)
    W : Theano tensor
        Bases
    H : Theano tensor
        activation matrix
    beta : Theano scalar
        beta-divergence parameter; ``beta == 2`` selects a cheaper
        closed-form branch

    Returns
    -------
    W : Theano tensor
        Updated version of the bases (a new symbolic expression; the
        input ``W`` is not modified)

    Notes
    -----
    NOTE(review): no epsilon is added. Zeros in a denominator, or in
    ``H W^T`` when ``beta < 2``, will produce inf/nan.
    """
    approx = T.dot(H, W.T)
    up = ifelse(
        T.eq(beta, 2),
        # Euclidean case: (H W^T)^T H == W (H^T H). Using the small square
        # Gram matrix of H avoids forming the full H W^T product.
        T.dot(X.T, H) / T.dot(W, T.dot(H.T, H)),
        # General beta: the approximation is shared by numerator and
        # denominator.
        T.dot(T.mul(T.power(approx, beta - 2), X).T, H)
        / T.dot(T.power(approx, beta - 1).T, H),
    )
    return T.mul(W, up)