
### Utility functions for a restricted Boltzmann machine (RBM).
### They cover the (log) partition function, the gradient procedures used
### during training, and Gibbs-style samplers.

from pylab     import newaxis, amap, binary_repr, zeros, log, exp, log_sum_exp, sigmoid, stochastic, Rsigmoid, rand, randn
from pylab     import dot
from std.basic import Rsigmoid_no_GPU
import numpy    as np
import gnumpy   as gpu
import time

def int_to_bin(i, v):
    """
    Encode the integer i as a row of its binary digits.

    i must satisfy 0 <= i < 2**v (enforced with an assert).  The string
    binary_repr(i) is left-padded with '0' up to v characters, each
    character is converted to an int, a leading axis of length one is
    added and the result is cast with astype('f').

    """
    assert 0 <= i < 2**v
    digits = binary_repr(i).zfill(v)
    row = amap(int, list(digits))
    return row[newaxis, :].astype('f')

def log1exp(b):
    """
    Elementwise log(1 + exp(b)), written so that exp never overflows.

    The value is split as log(1 + exp(-|b|)) plus b where b > 0 (and plus
    zero elsewhere); the mask (b > 0) is cast with astype('f') before the
    multiplication, so b has to support comparison returning an array.

    """
    positive_part = b * (b > 0).astype('f')
    small_exp = exp(-abs(b))
    return log(1 + small_exp) + positive_part

def free_energy(w, x):
    """
    Per-row log of the unnormalised probability of the rows of x.

    The result is the axis-1 sum of log1exp(w * x) plus the axis-1 sum of
    w[1] broadcast-multiplied with x.  rbm_grad_exact exponentiates this
    value minus log Z to obtain a probability.

    NOTE(review): w[1] is presumably the visible bias and w * x the hidden
    pre-activation -- confirm against the weight class.

    """
    hidden_term = log1exp(w * x).sum(1)
    bias_term = (w[1][newaxis, :] * x).sum(1)
    return hidden_term + bias_term

def brute_force_Z(W):
    """
    For binary visibles and hiddens.

    Returns log_sum_exp of free_energy(W, V) over all 2**v binary rows V,
    i.e. the log partition function obtained by exhaustive enumeration.

    When there are more visible than hidden units the roles of the two
    layers are swapped first, so the enumeration runs over the smaller
    layer.

    """
    v, h = W.v, W.h
    if v > h:
        # 0 * W is only used to get a fresh container of the same type.
        # Its parameters are all zero, so the swapped parameters must be
        # read from the original W (reading them from W1 would yield the
        # partition function of an all-zero model).
        W1 = 0 * W
        W1.w = [W.w[0].T(), W.w[2], W.w[1]]
        # NOTE(review): assumes W1.v / W1.h follow the newly assigned W1.w;
        # otherwise this recursion would not terminate -- confirm.
        return brute_force_Z(W1)
    if v > 20:
        # Single-argument form: prints the same text as the original
        # comma-separated print statement and also parses under Python 3.
        print("v =  %s  may take a long time. " % (v,))
    Z = zeros(2**v)
    for i in xrange(2**v):
        Z[i] = free_energy(W, int_to_bin(i, v))
    return log_sum_exp(Z)
    
def brute_force_Z_vis_gauss(W):
    """
    Log-sum-exp over all 2**h binary hidden configurations.

    For every hidden row H the accumulated term is
    .5 * dot(b, b.T) + dot(H, W[2]) with b = W.T() * H.

    NOTE(review): presumably the log partition function (up to an additive
    constant) for Gaussian visibles -- confirm.

    """
    n_visible, n_hidden = W.v, W.h
    n_configs = 2**n_hidden
    terms = zeros(n_configs)
    for idx in xrange(n_configs):
        hid = int_to_bin(idx, n_hidden)
        act = W.T() * hid
        terms[idx] = .5 * dot(act, act.T) + dot(hid, W[2])
    return log_sum_exp(terms)

def rbm_grad_exact(W, x, vis_gauss = False):
    """
    Exact gradient and loss obtained by enumerating all model states.

    W         -- weight container (supports W * x, W.T(), W.outp, W[i]).
    x         -- batch of data rows; len(x) is the batch size.
    vis_gauss -- if False, enumerate the 2**v binary visible states;
                 if True, enumerate the 2**h binary hidden states instead.

    Returns (batch_size * G, dict(loss = batch_size * loss)), where G is
    the data-driven term W.outp(x, H) / batch_size minus the
    probability-weighted sum of W.outp over the enumerated states.

    """
    batch_size = float(len(x))
    v, h = W.v, W.h
    G = 0 * W

    ### Data-driven term.
    H = sigmoid(W * x)
    G += 1./batch_size * W.outp(x, H)

    if not vis_gauss:
        Z = brute_force_Z(W)
        def prob(V):
            return exp(free_energy(W, V)[0] - Z)
        for i in xrange(2**v):
            V = int_to_bin(i, v)
            H = sigmoid(W * V)
            G -= prob(V) * W.outp(V, H)
        ### The loss does not depend on the enumerated state, so it is
        ### computed once after the loop rather than on every iteration.
        loss = - (free_energy(W, x).mean(0) - Z)
    else:
        Z = brute_force_Z_vis_gauss(W)
        for i in xrange(2**h):
            H = int_to_bin(i, h)
            ### b is needed both for the state probability and for the
            ### outer product, so it is computed once per state.
            b = W.T() * H
            z = .5 * dot(b, b.T) + dot(H, W[2])
            G -= float(exp(z - Z)) * W.outp(b, H)
        loss = - (-.5 * amap(dot, x, x).mean() + free_energy(W, x).mean(0) - Z)
    return batch_size * G, dict(loss = batch_size * loss)


def rbm_grad_cd(W, x, cd, vis_gauss = False):
    """
    Gradient estimate from cd steps of alternating sampling.

    W         -- weight container (supports W * x, W.T(), W.outp).
    x         -- batch of data rows; len(x) is the batch size.
    cd        -- number of sampling steps in the negative phase.
    vis_gauss -- if True the visibles are W.T() * H plus gpu.randn noise,
                 otherwise they are Rsigmoid(W.T() * H).

    Returns (G, separate, dict(loss = loss)):
      G        -- W.outp(x, H_data) - W.outp(V_model, H_model)
      separate -- H_data - H_model (hidden activations only)
      loss     -- sum of |V_model - x|

    """
    ### Only used as a shape dimension for gpu.randn, so keep it an int.
    batch_size = len(x)
    v = W.v
    V = x

    #---------------------------------#GPU
    ### Positive phase:
    H = (W * x).logistic()
    G = W.outp(V, H)
    ### 1 * H makes a new array, so the in-place update of `separate`
    ### below cannot touch the array H refers to here.
    separate = 1 * H

    ### Negative phase (generate from the model):
    for g in range(cd):
        H = stochastic(H)
        if vis_gauss:
            V = W.T() * H + gpu.randn(batch_size, v)
        else:
            V = Rsigmoid(W.T() * H)
        H = (W * V).logistic()
    G -= W.outp(V, H)
    separate -= H
    #---------------------------------#

    ### loss = discrepancy between observed data and generated data
    loss = abs((V - x).as_numpy_array()).sum()
    return G, separate, dict(loss = loss)


def rbm_grad_cd_pre(W, x, cd, vis_gauss = False):
    """
    Gradient estimate from cd steps of alternating sampling.

    Same procedure as rbm_grad_cd, but without the separate
    hidden-activation difference in the return value.

    W         -- weight container (supports W * x, W.T(), W.outp).
    x         -- batch of data rows; len(x) is the batch size.
    cd        -- number of sampling steps in the negative phase.
    vis_gauss -- if True the visibles are W.T() * H plus gpu.randn noise,
                 otherwise they are Rsigmoid(W.T() * H).

    Returns (G, dict(loss = loss)) with
    G = W.outp(x, H_data) - W.outp(V_model, H_model) and
    loss = sum of |V_model - x|.

    """
    ### Only used as a shape dimension for gpu.randn, so keep it an int.
    batch_size = len(x)
    v = W.v
    V = x

    #---------------------------------#GPU
    ### Positive phase:
    H = (W * x).logistic()
    G = W.outp(V, H)

    ### Negative phase (generate from the model):
    for g in range(cd):
        H = stochastic(H)
        if vis_gauss:
            V = W.T() * H + gpu.randn(batch_size, v)
        else:
            V = Rsigmoid(W.T() * H)     ### W' * H
        H = (W * V).logistic()
    G -= W.outp(V, H)
    #---------------------------------#

    ### loss = discrepancy between observed data and generated data
    loss = abs((V - x).as_numpy_array()).sum()
    return G, dict(loss = loss)



def sample(W, g, batch_size, vis_gauss = False):
    """
    Run g alternating Rsigmoid updates from a random visible start.

    The numpy generator is re-seeded with np.random.seed() on every call,
    then a (batch_size, W.v) uniform array is drawn and moved to a garray.
    vis_gauss is accepted but not used.

    Returns (V, H, V_real, H_real):
      V, H   -- visibles / hiddens after the last Rsigmoid update
      V_real -- (W.T() * H).logistic()
      H_real -- (W * V_real).logistic()

    NOTE(review): g must be at least 1; with g == 0 H is never assigned.

    """
    n_visible, n_hidden = W.v, W.h

    #---------------------------------## GPU
    np.random.seed()
    start = np.random.rand(batch_size, n_visible)
    V = gpu.as_garray(start)

    for step in range(g):
        H = Rsigmoid(W * V)
        V = Rsigmoid(W.T() * H)

    ## real-valued activations computed from the last hidden state:
    V_real = (W.T() * H).logistic()
    H_real = (W * V_real).logistic()
    #---------------------------------#

    return V, H, V_real, H_real



def sample_last_mf_no_GPU_no_stochastic(W, g, batch_size, vis_gauss = False):
    """
    Run g alternating sigmoid updates without any sampling step.

    Starts from rand(batch_size, W.v); every update converts the product
    with gpu.as_numpy_array and applies sigmoid.  vis_gauss is accepted
    but not used.

    Returns (V, H) after the last update.

    NOTE(review): g must be at least 1; with g == 0 H is never assigned.

    """
    n_visible, n_hidden = W.v, W.h
    V = rand(batch_size, n_visible)
    for step in range(g):
        hidden_input = gpu.as_numpy_array(W * V)
        H = sigmoid(hidden_input)
        visible_input = gpu.as_numpy_array(W.T() * H)
        V = sigmoid(visible_input)
    return V, H


## used to match the successor distribution:
def sample_last_mf(W, g, batch_size, vis_gauss = False):
    """
    Run g alternating Rsigmoid updates, then one real-valued visible pass.

    Starts from gpu.rand(batch_size, W.v).  After the loop the visibles
    are replaced by (W.T() * H).logistic() computed from the last H.
    vis_gauss is accepted but not used.

    Returns (V, H).

    NOTE(review): g must be at least 1; with g == 0 H is never assigned.

    """
    n_visible, n_hidden = W.v, W.h
    #---------------------------------#GPU
    V = gpu.rand(batch_size, n_visible)
    for step in range(g):
        H = Rsigmoid(W * V)
        V = Rsigmoid(W.T() * H)
    ## final visibles: real-valued activations instead of an Rsigmoid draw
    V = (W.T() * H).logistic()
    #---------------------------------#
    return V, H


##def sample_last_mf_no_GPU(W, g, batch_size, vis_gauss = False):
##    v, h = W.v, W.h
##    V = rand(batch_size, v)
##    for gg in range(g):
##        H = Rsigmoid_no_GPU(gpu.as_numpy_array(W * V))
##        if vis_gauss:
##            V = gpu.as_numpy_array(W.T() * H) + randn(batch_size, v)
##        else:
##            V = Rsigmoid_no_GPU(gpu.as_numpy_array(W.T() * H))
##    if vis_gauss:	
##        V = gpu.as_numpy_array(W.T() * H)
##    else:
##        V = sigmoid(gpu.as_numpy_array(W.T() * H))
##    return V, H







## To check!!
def sample_and_store(W, g, batch_size, vis_gauss = False):
    """
    Run g alternating Rsigmoid updates and remember the first hidden state.

    The numpy generator is re-seeded with np.random.seed() on every call;
    the chain starts from a (batch_size, W.v) uniform array moved to a
    garray.  vis_gauss is accepted but not used.

    Returns (V, V_stoch, H_stoch, H_stoch_init):
      V            -- (W.T() * H).logistic() with H = (W * V_stoch).logistic()
      V_stoch      -- visibles after the last Rsigmoid update
      H_stoch      -- hiddens from the last Rsigmoid update
      H_stoch_init -- hiddens from the first update (step 0)

    NOTE(review): g must be at least 1; with g == 0 H_stoch and
    H_stoch_init are never assigned.

    """
    n_visible, n_hidden = W.v, W.h
    np.random.seed()
    start = np.random.rand(batch_size, n_visible)
    V_stoch = gpu.as_garray(start)
    for step in range(g):
        H_stoch = Rsigmoid(W * V_stoch)
        ## keep the hidden state the chain started from
        if step == 0:
            H_stoch_init = H_stoch
        V_stoch = Rsigmoid(W.T() * H_stoch)
    H = (W * V_stoch).logistic()
    V = (W.T() * H).logistic()
    return V, V_stoch, H_stoch, H_stoch_init

def sample_from_previous(W, H_stoch_init, g, batch_size):
    """
    Re-run g alternating updates starting from a stored hidden state.

    Step 0 uses H_stoch_init as the hidden state; every later step draws
    the hiddens with Rsigmoid(W * V_stoch).  batch_size is accepted but
    not used.

    Returns (V, V_stoch):
      V_stoch -- visibles after the last Rsigmoid update
      V       -- (W.T() * H).logistic() with H = (W * V_stoch).logistic()

    NOTE(review): g must be at least 1; with g == 0 V_stoch is never
    assigned.

    """
    n_visible, n_hidden = W.v, W.h
    for step in range(g):
        ## the first step starts from the stored hidden state
        H = H_stoch_init if step == 0 else Rsigmoid(W * V_stoch)
        V_stoch = Rsigmoid(W.T() * H)
    H = (W * V_stoch).logistic()
    V = (W.T() * H).logistic()
    return V, V_stoch

def sample_from_previous_no_stoch(W, H_stoch_init, g, batch_size):
    """
    Re-run g alternating logistic updates from a stored hidden state.

    Step 0 uses H_stoch_init as the hidden state; every later step uses
    (W * V).logistic().  No Rsigmoid draw is involved.  batch_size is
    accepted but not used.

    Returns V, the visibles after the last update.

    NOTE(review): g must be at least 1; with g == 0 V is never assigned.

    """
    n_visible, n_hidden = W.v, W.h
    for step in range(g):
        ## the first step starts from the stored hidden state
        H = H_stoch_init if step == 0 else (W * V).logistic()
        V = (W.T() * H).logistic()
    return V
