
### General trainer class.

from   trainers.common_functions import *
from   pylab import *
from   std   import *
from   collections import defaultdict
import std
import pickle
import data.words           as words
import empirical_evaluation as err
import gnumpy               as gpu
import numpy                as np

class std_trainer:
    """
    This is a sequential, non-parallel (! modified: now it supports mini-batch learning !),
    most general version of a trainer of stochastic gradient descent problems with momentum.
    (URGENT FUTURE WORK: make use of some second order information, by using stochastic newton's method
    approximating the inverse hessian).

    It is given functions that gives training and validation cases,
    and functions that return unnormalized gradients and unnormalized loss/gain vaules.
    The gradient is combined with the momentum, which is then added to the parameters.
    It is recommended that when the dataset is finite, then the train_function should
    be implemented with a global variable that uses a random permutation estimator.
    This way, there will probably be less variance in the nodes.

    As the class implements grad_check, when used, it assums that the entry 'loss' is
    given in the output of the unnorm_valid_loss function; The gradient function returns
    direction in the steepest descent in loss, thus being its negative derivative.

    The learning rate, weight decay and momentum are given in their most generality as
    functions; however, a dictionary of values (that represent the 'vertices' of a sum of step
    functions or a plain constant can be added as well.

    Given the save path of the simulation and its name, the object automatically loads
    itself provided it can find the file. In addition, the saveing interval parameter
    determines how often the simulation is to be saved.

    This can be used for both finetuning and pretraining, where in pretraining, the gradient-computing
    function accesses the inference weights and the validation could be the reconstruction error
    on the test set, as well as possibly the free energy. 

    """

    def __init__(self,
                 name,          ### of simulation
                 path,          ### path to save/load simulation
                 dataset_path,  ### path to load the training set

                 W,
                 ### ff_net of weight parameters; can be anything.
                 ### if W has few "essential parameters", then by
                 ### providing a member function "W.compress()" that
                 ### returns a -REFERENCE- to the essential parameters,
                 ### this could allow for faster saving and loading.

                 unnorm_grad_fn,
                 ### gradient function:
                 ### input:  whatever data getter gives
                 ### output: (unnormalized) gradient for each W 
                 ### this could allow for more general backpropagations.
                 ### in addition, grad_fn returns an unnormalized dictionary of statistics:
                 ### losses. In particular, there should be a 'loss' which is to be minimized.
                 ### (containing various reconstruction errors)

                 unnorm_valid_loss,
                 ### unnorm_valid_loss is a function that outputs some validation statistics
                 ### which are averaged over all the batches. 
                 ### The results of unnorm_valid_loss are averaged over all the
                 ### test cases, stored and printed out (averaging takes different batch sizes into account)

                 data_fn, 
                 ### data_fn returns some object with data that is fed to unnorm_grad_fn
                 ### as well as the length of this object. This is relevant because oftentimes
                 ### this object is a tuple of inputs and labels. 
                 ### typically, data_fn might be implemented by creating a random permutation
                 ### of the datapoints and using it internally, perhaps by updating a global variable.

                 valid_data_fn,
                 ### valid_data_fn is a function class with 3 member functions:
                 ### init(): sets a new random permutation to 0
                 ### next(): returns the next batch of examples
                 ###         None if nothing is left.

                 num_iter,            ### how many weight updates (update follows after batch)
                 minibatch_size = 1,  ### number of training sequences to use in parallel
                 save_iter = 200000,### how many iterations between saves
                 test_iter = None,    ### 'None' means no validation
                 min_D_len = 1,       ### if the input sequence is any shorter, we ignore it

                 weight_update_freq = 1,
                 ### we could update the weights every datapoint (when the data is iid)
                 ### But to reduce variance (in case the data is a sequence, for instance),
                 ### we can collect a gradient over several iterations;
                 ### This can sometimes result in a significant reduction
                 ### in update variance.
                 
                 weights_constraints = lambda x : x,
                 ### Weights_constraints: W = weights_constraints(W) is executed after each weight update.
                 ### Useful for printing out stuff as well as enforcing some hard constraints, eg
                 ### to upper-bound the norm of W.
                 
                 weights_constraints_iter = lambda x, i : x,
                 ### This is another weights constraints function that also gets an iteration number.
                 ### it is convenient if we want to do things like restricting the amount of hidden
                 ### units that get updated.

                 CD_n_sched = [100, 10000, 1000000],
                 ## schedule the Contrastive Divergence parameter
                 ##(CD-5 until the first bound, then CD-10 until the second, then CD-25 until the third and finally CD-40)

                 grad_constraints_iter = lambda x, i : x,
                 ### While this one is to bound the gradient, in case we wish to freeze some weights
                 ### instead of set them to a prespecified vaule, as done by the weight constraints.

                 LR = .01,   ### learning rate
                 WD = 0, 
                 momentum = .9,
                 ### Thus, the update equations are the following:
                 ### D  = NORMALIZED_GRAD
                 ### V *= momenutm
                 ### V += LR * D
                 ### W += V

                 const_mean_batch_size = None,
                 ### If we wish not to normalize by batch size,
                 ### (eg if the batches are of different sizes) then
                 ### it is better to use the mean batch size.

                 len_fn = None,
                 max_length = None,
                 realtime = False,
                 lamb = 1,
                 ### How to compute the length of a batch
                 ### in case we wish to normalize by batch length.

                 plot_iter = 4,   ### plot every n complete swaps
                 plot = False,       ### Whether to plot at all
                 plotter = None,    ### should a specific plotting function be used?

                 backup_name = None,
                 ### If we save something and wish to resume from there,
                 ### we can use backup_name for initalization.
                 
                 train_stats_sparsity = 100000,
                 train_print_sparsity = 100000,
                 pre_training = False
                 ):


        ## initialize the trainer
        self.min_D_len = min_D_len
        self.max_length = max_length
        self.realtime = realtime
        self.lamb = lamb
        if path[-1] != '/': path = path + '/'   ### normalize path
        self.name, self.path, self.dataset_path = name, path, dataset_path
        if backup_name == None: backup_name = name
        self.W = W
        self.valid_data_fn, self.data_fn = valid_data_fn, data_fn
        self.unnorm_valid_loss, self.unnorm_grad_fn = unnorm_valid_loss, unnorm_grad_fn
        self.LR, self.WD, self.momentum = map(make_parameter, (LR, WD, momentum))
        self.num_iter, self.save_iter =  num_iter + 1, save_iter
        if test_iter == None: test_iter = inf
        self.test_iter = test_iter
        if plot_iter == None: self.plot_iter = inf
        else: self.plot_iter = plot_iter
        self.iter, self.train_stats, self.valid_stats = 0, {}, {}
        if train_stats_sparsity == None: self.train_stats_sparsity = 1
        else: self.train_stats_sparsity = train_stats_sparsity
        self.train_stats_counter = 0
        if train_print_sparsity == None: self.train_print_sparsity = 1
        else: self.train_print_sparsity = train_print_sparsity
        self.weight_update_freq = weight_update_freq
        self.dW_acc = 0 * W
        self.batch_size_acc = 0
        self.const_mean_batch_size = const_mean_batch_size
        self.minibatch_size = minibatch_size
        self.pre_training = pre_training
        def get_essential_W(W):
            try:
                ### compress returns a reference, so by setting it to something
                ### the value of the original array changes.
                return W.compress()
            except:
                return W

        self.get_essential_W = get_essential_W            
        self.V = 0 * W

        ## try to load previously saved data
        try:
            try:
                print '\nAttempting to load %s' % (path + name)
                file = open(path + name, 'r')
                data = pickle.load(file)
            except IOError:
                print 'file does not exist. Loading backup file...'
                try:
                    data = load(path + backup_name)
                    print 'loaded backup name %s' % backup_name
                except IOError:
                    print  'backup file does not exist. '
            WE = get_essential_W(self.W)
            WE.set(data['W'])
            VE = get_essential_W(self.V)
            VE.set(data['V'])
            self.train_stats = data['train_stats']
            self.valid_stats = data['valid_stats']
            self.iter = data['iter']
            self.train_stats_counter = data['train_stats_counter']
            print 'Run %s: load successful.' % name
        except:
            print 'Run %s: starting a fresh run.' % name

        self.weights_constraints, self.weights_constraints_iter = weights_constraints, weights_constraints_iter
        self.CD_n_sched = CD_n_sched
        self.grad_constraints_iter = grad_constraints_iter
        self.plot, self.plotter = plot, plotter
        
        if len_fn == None:
            def my_len_fn(x):
                if type(x) == tuple:
                    return len(x[0])
                else:
                    return len(x)
            self.len_fn = my_len_fn
        else:
            self.len_fn = len_fn


    ###### MAIN TRAINING LOOP ######
    
    def train(self):

        f = open(self.dataset_path, 'r')
        info = f.readline() ## first line that contains information on lengths
        lengths = info.split('\t')
        l_3 = int(lengths[0])
        l_4 = int(lengths[1])
        l_5 = int(lengths[2])
        l_6 = int(lengths[3])
        l_7 = int(lengths[4])
        line = f.readline()
        training_set = line.split('\t')
        training_set = np.array(training_set)
        f.close()
        
        tr_size = training_set.shape[0]
        tot_minibatch = tr_size / self.minibatch_size
        ## convert to vector representation
        _Data = gpu.zeros((tr_size, self.max_length + 1, 27))
        for i in range(tr_size):
            _Data[i, :, :] = self.data_fn(training_set[i])
        
        minibatch_schedule = np.random.permutation(tot_minibatch)
        mini = 0
        complete_swap = 1
        minibatch_loss = defaultdict(int)
        mini_loss = []
        update = False
        ion() ## matplotlib interactive mode

        print '\n*** Training parameters ***\n'
        if self.realtime:
            print 'Real-time implementation\n'
        else:
            print 'Standard implementation\n'
        print 'Hidden units:     ', self.W.h
        print 'LR:               ', round(self.LR(self.iter), 4)
        print 'WD:               ', self.WD(self.iter)
        print 'Minibatch size:   ', self.minibatch_size
        print 'Momentum:         ', self.momentum(self.iter)
        print 'Total iterations: ', self.num_iter
        print 'Saved iterations: ', self.iter, '\n'

        while self.iter < self.num_iter:

            ## update CD-k parameter according to iteration number
            if self.iter >= self.CD_n_sched[2]:
                self.W.CD_n = 40
            elif self.iter >= self.CD_n_sched[1]:
                self.W.CD_n = 25
            elif self.iter >= self.CD_n_sched[0]:
                self.W.CD_n = 10
            else:
                self.W.CD_n = 5

##            if self.iter % 20000 == 0:
##                update = True
##            else:
##                update = False

            W, V, LR, momentum, WD = self.W, self.V, self.LR, self.momentum, self.WD

            if self.iter == 0:
                ## print weigths statistics
                VH = gpu.as_numpy_array(W.w[0].w[0][0])
                VH_abs = np.abs(VH)
                print 'VH_abs Avg:\t%.3f'   % np.mean(VH_abs)
                print 'VH_abs Max:\t%.3f'       % np.max(VH_abs)
                print 'VH_abs Min:\t%.3f'       % np.min(VH_abs)
                print ''
                HH = gpu.as_numpy_array(W.w[1].w[0][0])
                HH_abs = np.abs(HH)
                print 'HH_abs Avg:\t%.3f'   % np.mean(HH_abs)
                print 'HH_abs Max:\t%.3f'       % np.max(HH_abs)
                print 'HH_abs Min:\t%.3f'       % np.min(HH_abs)
                print '\n'

            start_index = minibatch_schedule[mini] * self.minibatch_size
##            minibatch = training_set[start_index : start_index + self.minibatch_size]
##            sequence_length = len(minibatch[0])
##            _D = gpu.zeros((self.minibatch_size, sequence_length + 1, 27))
##            for i in range(self.minibatch_size):
##                _D[i, :, :] = self.data_fn(minibatch[i])
            
            sequence_length = len(training_set[start_index])
            _D = _Data[start_index : start_index + self.minibatch_size, 0:sequence_length+1, :]
            mini += 1
            self.iter += self.minibatch_size

            mb_size, seq_length, v_size = _D.shape
            if self.realtime: pref = 1              ## loop also through prefixes for each minibatch! (without adding '$')
            else:             pref = seq_length
            
            while pref < (seq_length + 1):
                (dW_loc, stats_dict, lamb) = self.unnorm_grad_fn(W, _D[:, 0:pref, :], self.lamb)
                self.lamb = lamb
                self.batch_size_acc += pref
                self.dW_acc         += dW_loc
                V.__imul__(momentum(self.iter))
                decay = W.__rmul__(WD(self.iter))
                V += ((1/float(self.batch_size_acc)) * self.dW_acc - decay).__rmul__(LR(self.iter))
                self.dW_acc *= 0
                self.batch_size_acc = 0
                W += V
                
                if pref == seq_length:
                    minibatch_loss[complete_swap] += stats_dict['loss']  ## calculate loss only on real (i.e. complete) sequences
                pref += 1


            ## after a complete sweep over all mini batches, shuffle them and start again
            if mini == tot_minibatch:
                mini = 0
                minibatch_schedule = permutation(tot_minibatch)
                np.random.shuffle(training_set[0:l_3])
                np.random.shuffle(training_set[l_3:l_3+l_4])
                np.random.shuffle(training_set[l_3+l_4:l_3+l_4+l_5])
                np.random.shuffle(training_set[l_3+l_4+l_5:l_3+l_4+l_5+l_6])
                np.random.shuffle(training_set[l_3+l_4+l_5+l_6:l_3+l_4+l_5+l_6+l_7])
                
                ## calculate loss value of all minibatches:
                minibatch_loss[complete_swap] *= (self.minibatch_size / float(100))
                print '  ', minibatch_loss[complete_swap]
                if complete_swap % self.plot_iter == 0:
                    mini_loss.append(minibatch_loss[complete_swap])
                    x = arange(0, (complete_swap / self.plot_iter))
                    y = np.array(mini_loss)
                    plot(x, y, 'b-')
                    suptitle(self.name, fontsize=12)
                    xlabel('epochs')
                    ylabel('loss')
                    ylim((200, 600))
                    draw()
                complete_swap += 1
                
##                ## print weigths statistics
##                print 'LR: ', LR(self.iter)
##                VH = gpu.as_numpy_array(W.w[0].w[0][0])
##                VH_abs = np.abs(VH)
##                print 'VH_abs Avg:\t%.3f'   % np.mean(VH_abs)
##                dVH = gpu.as_numpy_array(V.w[0].w[0][0])
##                dVH_abs = np.abs(dVH)
##                print 'dVH_abs Avg:\t%.5f'   % np.mean(dVH_abs)
##                print ''
##                HH = gpu.as_numpy_array(W.w[1].w[0][0])
##                HH_abs = np.abs(HH)
##                print 'HH_abs Avg:\t%.3f'   % np.mean(HH_abs)
##                dHH = gpu.as_numpy_array(V.w[1].w[0][0])
##                dHH_abs = np.abs(dHH)
##                print 'dHH_abs Avg:\t%.5f'   % np.mean(dHH_abs)
##                print '\n'

            ## normalize the training errors over sequences lengths
            for k in stats_dict.keys():
                stats_dict[k] /= float(sequence_length)
            stats_dict['batch_size'] = sequence_length

            if self.train_stats_counter % self.train_stats_sparsity == 0:
                self.train_stats[self.iter] = stats_dict
            self.train_stats_counter += 1

            if (self.iter % self.train_print_sparsity == 0) or (self.iter == self.minibatch_size):
                show_stats('stats    :', stats_dict)

            #if self.iter % self.plot_iter == 0 and self.plot:
                #self.show_W()

            ## validation
##            if self.iter % self.test_iter == 0:
##                #self.valid_stats[self.iter] = self.compute_validation_cost()  ### Validation
##                empirical_error = err.match_successor_distribution(self, self.dataset_path, 100)
##                print '\n\n*** Empirical error: ', empirical_error, ' ***\n\n'

            if self.iter % self.save_iter == 0:
                print 'saving...'
                fig_name = self.name + '.png'
                savefig(fig_name)
                self.save()
                

        ### end of main for loop.




    def plot_loss(self, x, y, plot_params='b-'):
        from matplotlib import pyplot as PLT
##        i = 0
##        x = np.zeros(len(values))
##        y = np.zeros(len(values))
##        for k, v in values.items():
##            x[i] = k
##            y[i] = v
##            i += 1
        plot(x, y, plot_params)
        #PLT.ylim( (0, 3) )
        draw()


    def check_grad(self, num_tries = 10, eps = 1e-6, inds = []):
        """
        This function verifies that the gradient and the loss function agree.
        There is a cavet to this function:
        the gradient function returns the direction that minimizes the loss function;
        therefore, the gradient function is actually the negative gradient function.
        """
        W_copy = 1 * self.W
        def get(W):
            for x in inds:
                W = W[x]
            return W
        ## select a random batch from the validation score.
        d = self.valid_data_fn()
        dW, grad_stats = self.unnorm_grad_fn(W_copy, d)
        Actual_W = get(W_copy).flatten()
        Ref_W    = get(W_copy)   ### the reference.
        len_W = len(Actual_W.flatten())

        def set_to_W_copy(Actual_W):
            try:
                Ref_W.set(Ref_W.unpack(Actual_W))
            except AttributeError:
                try:
                    Ref_W[:] = Actual_W.reshape(shape(Ref_W))
                except AttributeError:
                    Ref_W[:] = Actual_W

        for i in range(num_tries):
            k = int(multinomial(len_W))
            Actual_W[k] += eps
            set_to_W_copy(Actual_W)
            l1 = self.unnorm_valid_loss(W_copy , d)['loss']
            print 'l1=',l1
            Actual_W[k] -= 2*eps
            set_to_W_copy(Actual_W)
            l2 = self.unnorm_valid_loss(W_copy, d)['loss']
            Actual_W[k] += eps
            set_to_W_copy(Actual_W)
            estimated_grad = (l1-l2)/(2*eps)
            compute_grad   = - get(dW).flatten()[k]
            print i,'c',compute_grad,'e',estimated_grad,': diff=', compute_grad-estimated_grad


    ## this allows for using other data sources, too!
    def compute_validation_cost(self, valid_data_fn = None):
        W = self.W
        if valid_data_fn == None: valid_data_fn = self.valid_data_fn
        print 'testing...'
        valid_tot_num_cases = 0
        unnorm_stat_sum = dict()
        stat_mean       = dict()
        self.valid_data_fn.init()
        i = 0
        while 1:
            printf('valid: i=%d      \r' % i)
            i += 1
            d = self.valid_data_fn.next()
            if d == None:
                break  ### then we are truly done.
            len_d = self.len_fn(d)
            if len_d < self.min_D_len:
                continue
            valid_tot_num_cases += len_d
            v_stat = self.unnorm_valid_loss(W, d)
            #### We update the sum of all the losess.
            for key in v_stat.iterkeys():
                try:
                    unnorm_stat_sum[key] += v_stat[key]
                except KeyError:
                    unnorm_stat_sum[key]  = v_stat[key]

        ## Now we normalize by the total number of examples processed.
        for key in v_stat.iterkeys():
            stat_mean[key] = unnorm_stat_sum[key] / float(valid_tot_num_cases)
            
        stat_mean['batch_size'] = int(valid_tot_num_cases)
        return stat_mean


    def save(self):
        """
        To save the current state manually.
        """
        to_save = dict(train_stats = self.train_stats,
                       valid_stats = self.valid_stats,
                       iter = self.iter - self.minibatch_size,
                       W    = self.get_essential_W(self.W),
                       V    = self.get_essential_W(self.V),
                       train_stats_counter = self.train_stats_counter)
        file = open(self.path + self.name, 'w') #+ '_' + str(self.iter)
        pickle.dump(to_save, file)
        file.close()
        del file
        print 'training saved.\n'


    def show_W(self):
        W = self.W
        try:
            self.plotter.__call__(W)
        except:
            try:
                show(self.W[0].show_W())
            except:
                show(self.W.show_W(),-1,1)


    def plot_valid_stats(self, key='loss',plot_params='-',log_scale=False, down_sampling=1):
        Xs = sort(array(self.valid_stats.keys()))
        Ys = 0 * Xs.astype('d')
        for i in range(len(Xs)):
            Ys[i] = self.valid_stats[int(Xs[i])][key]
        if log_scale:
            Xs = log10(Xs)
        plot(Xs[::down_sampling], Ys[::down_sampling], plot_params)


def show_normalized_stats(name, stats_dict, total_batch_size):
    print name
    for key in stats_dict.iterkeys():
        if key!='batch_size':
            print key,'=', stats_dict[key]/float(total_batch_size)


def show_stats(name, stats_dict):
    print name, ' :::',
    i = 0
    for key in stats_dict.iterkeys():
        if key != 'batch_size':
            print key, '=', stats_dict[key]
            if i == 3:
                i = 0
                print ''
            i += 1
    print ''
