"""
 Recurrent Temporal Boltzmann Machine
 
 Application and analysis on graphotactic learning task.

 Author:                Alberto Testolin
 Original Source Code:  Ilya Sutskever
 Last modified:         16/01/2013

"""


### Main file of the program.
### First it imports useful modules which contain all the dependencies and functions definitions.
### Then it starts the training session (or loads a previous model) and eventually performs further analyses.

import pickle                                       ## I/O from files
import time                                         ## used to calculate running time
import rbm                                          ## class which implements an RBM
import npmat                                        ## module to simulate a GPU if you don't have one
import gnumpy                as gpu                 ## module to parallelize computation exploiting GPU's power
import numpy                 as np                  ## NumPy library for scientific computations
import data.words            as words               ## function which generates the training data
import data.trie             as trie                ## prefix tree data structure
import empirical_evaluation  as e                   ## used to calculate empirical error
import std.basic             as bas                 ## general-utility implemented functions (e.g. plotting procedures)
import rnn_trbm.hidden_analysis as hidden_analyse   ## module to perform analyses on hidden units
from   mats.bias_mat         import bias_mat        ## class which implements a weights matrix with bias
from   mats.std_mat          import std_mat         ## class which implements a basic weights matrix
from   rnn_trbm.rnn_trbm     import rnn_trbm        ## class which implements an RTRBM
from   trainers.std_trainer  import std_trainer     ## general trainer
import warnings

with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    def LR(x):
        """Adaptive learning rate for iteration x.

        Decreases linearly from initial_lr (at x = 0) to 0 (at x = num_iter);
        both values are read from the script-level settings.
        """
        fraction_done = float(x)/num_iter
        return initial_lr * (1 - fraction_done)
    ##    if x > (0.1 * num_iter):
    ##        return initial_lr * (1- float(x - (0.1 * num_iter))/num_iter)
    ##    else:
    ##        return initial_lr

    def WD(x):
        """Weight decay for iteration x; currently the same constant for every x."""
        decay = .0002
        return decay
    ##    if x > (0.6 * num_iter):
    ##        return .0001
    ##    else:
    ##        return .0002
    def grad(W, x, l):
        """RTRBM gradient function: forwards x and l to W.grad unchanged."""
        return W.grad(x, l)

    def loss(W, x):
        """Loss function to minimize: the second item of what W.grad(x) returns."""
        grad_result = W.grad(x)
        return grad_result[1]

    def dat(w):
        """Transform external data (a word) into its neural representation (vectors),
        using words.codify_word and the script-level code type."""
        return words.codify_word(w, code)

    def dat_2(w):
        """Same purpose as dat, but through words.codify_word_2.

        The literal 7 is passed as its second argument -- presumably the maximum
        word length (max_length is also 7 in this script); TODO confirm.
        """
        return words.codify_word_2(w, 7, code)

    ## Experiment settings. They are plain script-level names read by the helper functions
    ## above (initial_lr, num_iter, code) and by the rest of the script.
    shell               = False                         ## run from ipython shell? (selects which set of relative paths is used)
    realtime            = False                         ## implementation with prefixes (handed to the trainer as realtime)
    gpu.board_id_to_use = 1                             ## select the GPU board
    CD_n                = 5                             ## for pre-training and for the first training phase use CD-5
    CD_n_sched          = [1000, 100000, 1600000]       ## schedule the Contrastive Divergence parameter (handed to the trainer as CD_n_sched)
                                                        ## NOTE(review): the note that used to be here read "CD-5 until iteration 100, then CD-10 until
                                                        ## iteration 40000, then CD-25 until iteration 1000000 and finally CD-40"; those thresholds do not
                                                        ## match the values above. Presumably the list holds the iterations at which CD-n is increased,
                                                        ## and the CD-n of each phase is decided inside the trainer -- TODO confirm
    code                = 'symbolic'                    ## code type to use for word representations
    v                   = 27                            ## visible units
                                                        ## NOTE(review): the old note said "n_digits * 10", which cannot give 27;
                                                        ## presumably one unit per symbol of the alphabet plus the terminal -- TODO confirm
    h                   = 200                           ## hidden units
    num_iter            = 2000000                       ## number of iterations (|TR| * num_epochs)
    initial_lr          = 0.3                           ## learning rate at iteration 0 (see LR)
    minibatch_size      = 50                            ## number of training sequences to process in parallel
    sequence_length     = 10                            ## length of the sequence to sample
    max_length          = 7                             ## max length of a sequence in the training set
    gibbs_steps         = 50                            ## Gibbs steps when sampling
    samples_num         = 100000                         ## -> multiple of 10000! number of samples to collect when testing generation capability
                                                        ## (the sampling loop draws batches of at most 20000, so above that size a value
                                                        ## that is not a multiple of 20000 makes it draw more than requested)
    patterns_num        = 5000000                       ## samples to collect when storing hidden units patterns
    max_patterns        = 50000                         ## maximum number of positive and negative patterns when collecting activations
    max_freq            = 10                            ## maximum number of words repetitions when collecting activation patterns
    number_of_tests     = 3                             ## how many times we have to test the model and average results
    lamb                = 1                             ## handed to the trainer as lamb; its meaning is not visible in this file

    ## Try to balance LR and WD
    parameters = '_h200_2000M_TR5300_MS50_LR0.3_WD0.0002_one'      ## tag appended to the output file names of this run
    simulation_name = 'r_words' + parameters

    ## Every location depends on the shell flag: the first alternative is used when the
    ## script is run from the ipython shell, the second one otherwise.
    dataset_path = 'data/Spokefre2' if shell else '../../data/Spokefre2'                                    ## to extract the training set
    dataset_save_path = 'rnn_trbm/r_data/TR_5300.txt' if shell else '../r_data/TR_5300.txt'                 ## to save the training set
    testset_save_path = 'rnn_trbm/r_data/TS_1370.txt' if shell else '../r_data/TS_1370.txt'                 ## to save the test set
    trainer_path = 'rnn_trbm/r_data/' if shell else '../r_data/'                                            ## to save the model
    ## to save generated samples (only the non-shell name carries the extra sampling tag)
    samples_path = (('rnn_trbm/r_data/samples_data/samples' + parameters) if shell
                    else ('../r_data/samples_data/samples' + parameters + '_mediumSampling'))
    ## to save error measures
    evaluation_path = (('rnn_trbm/r_data/evaluation/evaluation' + parameters) if shell
                       else ('../r_data/evaluation/evaluation' + parameters))
    ## to save hidden units activations (only the non-shell name carries the extra sampling tag)
    patterns_save_path = (('rnn_trbm/r_data/patterns/patterns' + parameters) if shell
                          else ('../r_data/patterns/patterns' + parameters + '_mediumSampling'))
    ## to save distances statistics
    stats_save_path = (('rnn_trbm/r_data/hidden_analysis/stats' + parameters) if shell
                       else ('../r_data/hidden_analysis/stats' + parameters))


    ### create the Visible-Hidden and Hidden-Hidden connections (as objects)
    ### and randomly initialize weights and add biases on units
    ### (sizes come from the settings: v visible units, h hidden units)
    VH = bias_mat(std_mat(v, h))
    HH = bias_mat(std_mat(h, h))

    ### create a new RTRBM (with small weights)
    ### the model object is multiplied by .005 as a whole -- presumably this scales
    ### all of its weights down; it relies on rnn_trbm supporting scalar multiplication

    print '\nCreating a new RTRBM...'
    W = .005 * rnn_trbm(VH, HH, CD_n)
    ## W.w[0] / W.w[1] are read here as the Visible-Hidden / Hidden-Hidden parts of the model
    print '\nSize of Visible-Hidden connections matrix: ',  W.w[0].w[0][0].shape
    print 'Size of Hidden-Hidden connections matrix: ',     W.w[1].w[0][0].shape

    ### load training and test set (the context tree is built from them right after)
    ### create == 'y': extract both sets from dataset_path and overwrite the saved files
    ### otherwise:     read the previously saved files back
    ### File layout: the training file holds the tab-separated lengths on its first line and
    ### the tab-separated words on its second line; the test file holds one line of words.
    create = 'n'#raw_input ('\nDo you want to extract a new dataset? (y/n)')
    if create == 'y':
        training_set, test_set, lengths = words.extract_words(dataset_path)
        ## 'with' closes each file even if a write fails half-way
        with open(dataset_save_path, 'w') as f:
            f.write('\t'.join(lengths))
            f.write('\n')
            f.write('\t'.join(training_set))
        with open(testset_save_path, 'w') as f:
            f.write('\t'.join(test_set))
        ## the loading branch produces NumPy arrays and the code further down reads
        ## .shape on both sets, so make this branch hand over arrays as well
        ## (a no-op when extract_words already returns arrays)
        training_set = np.asarray(training_set)
        test_set = np.asarray(test_set)
    else:
        with open(dataset_save_path, 'r') as f:
            lengths = f.readline() ## first line contains information on lengths; it is not parsed
            line = f.readline()
        ## drop a trailing line break, otherwise it would stick to the last word
        training_set = np.array(line.rstrip('\r\n').split('\t'))
        with open(testset_save_path, 'r') as f:
            line = f.readline()
        test_set = np.array(line.rstrip('\r\n').split('\t'))

    ## number of training words (training_set must expose .shape here)
    tr_size = training_set.shape[0]
    print 'Training set size: ', tr_size
    ## prefix tree over all the training words; it is used later on to check
    ## generated samples and to evaluate the model on the training set
    context_tree = trie.Trie()
    for word in training_set:
        context_tree.insert(word)
    ## the root's prefix_count is set by hand to the number of words
    ## before the probabilities are calculated starting from the root
    context_tree.head.prefix_count = tr_size
    context_tree.calculate_probabilities(context_tree.head)

    ### build the trainer that drives the main learning task
    ### (all the values come from the settings and helper functions defined above)
    t = std_trainer(
        ## where the run is stored and which data it uses
        name=simulation_name,
        path=trainer_path,
        dataset_path=dataset_save_path,
        ## model, gradient and validation loss
        W=W,
        unnorm_grad_fn=grad,
        unnorm_valid_loss=loss,
        ## codification of words: dat_2 for training data, dat for validation data
        data_fn=dat_2,
        valid_data_fn=dat,
        ## optimisation schedule
        num_iter=num_iter,
        minibatch_size=minibatch_size,
        CD_n_sched=CD_n_sched,
        LR=LR,
        WD=WD,
        ## remaining options
        max_length=max_length,
        realtime=realtime,
        lamb=lamb,
    )

    ### given a codified sample, print the corrensponding word
    def show_sample(x):
        word = words.decodify_word(x)
        print '\n', ''.join(word[0:-1]) ## do not print '$' symbol (terminal)



    ##-----------------------------  MAIN PROGRAM  -----------------------------##


    train = 'n'#raw_input ('\nDo you want to start training? (y/n)')
    if train == 'y':
        x1 = time.strftime('%s')
        t.train()
        x2 = time.strftime('%s')
        timediff = int(x2) - int(x1)
        t.save()
        print '...training complete. It took ', timediff, ' seconds.'
    else:
        print 'No training selected.'


    samples = 'n'#raw_input ('\nDo you want to collect samples? (y/n)')
    if samples == 'y':
        print 'Generating samples...'
        not_matches = []
        matches = []
        matches_with_freq = []
        not_matches_with_freq = []
        max_size = 20000
        generated = 0
        empty = 0
        fake = 0
        if samples_num < max_size:
            batch = samples_num
        else:
            batch = max_size
        x1 = time.strftime('%s')
        while generated < samples_num:
            #x, temp = t.W.get_samples_and_hidden(sequence_length, gibbs_steps, batch)
            x, H, V_b, H_b = t.W.get_samples_and_hidden(sequence_length, gibbs_steps, batch, real_and_bin = True)
            for i in range(0, batch):
                w = ''.join(words.decodify_word(x[i]))
                if context_tree.search(w) == True:
                    #print 'found:', w
                    matches.append(w)
                else:
                    ## do not count empty words
                    if w == '$':
                        empty += 1
                    elif w == '?':
                        fake += 1
                    else:
                        #print 'not found:', w
                        not_matches.append(w)
            generated += batch
            if generated % (10*batch) == 0:
                print generated
        x2 = time.strftime('%s')
        timediff = int(x2) - int(x1)
        print 'Generation complete. It took ', timediff, ' seconds.'
        unique_matches = sorted(set(matches))
        unique_not_matches = sorted(set(not_matches))
        f = open(samples_path, 'w')
        f.write('\nTotal samples generated:\t' + str(samples_num) + '\t(empty: ' + str(empty) + '   fake: ' + str(fake) + ')')
        f.write('\nTotal matches found:\t\t' + str(len(matches)) + ' (ratio: ' + str(round(float(len(matches))/(samples_num-empty-fake) * 100, 2)) + '%)')
        f.write('\nUnique matches:\t\t\t' + str(len(unique_matches)) + ' (ratio: ' + str(round(float(len(unique_matches))/(tr_size) * 100, 2)) + '%)')
        f.write('\n\n\n*** MATCHED ***\n\n')
        for i in unique_matches:
            matches_with_freq.append(tuple((matches.count(i), i)))
        matches_with_freq = sorted(matches_with_freq, reverse = True)
        for k in matches_with_freq:
            f.write(''.join(str(k[1][0:-1])) + '\t' + str(k[0]) + '\n')
        f.write('\n\n\n*** NOT MATCHED ***\n\n')
        for j in unique_not_matches:
            not_matches_with_freq.append(tuple((not_matches.count(j), j)))
        not_matches_with_freq = sorted(not_matches_with_freq, reverse = True)
        for k in not_matches_with_freq:
            f.write(''.join(str(k[1][0:-1])) + '\t' + str(k[0]) + '\n')
        f.write('\n\n\n*** NOT GENERATED ***\n\n')
        for i in range(len(unique_matches)):
            unique_matches[i] = unique_matches[i][0:-1]
        not_generated = list(set(training_set) - set(unique_matches))
        for w in not_generated:
            f.write(''.join(str(w)) + '\n')
        f.close()
        print '...samples saved.'

##    #sample_data = ['grave', 'swept', 'likes', 'seats', 'swing']
##    #sample_data = ['wince', 'yearn', 'gaols', 'could', 'vetch']
##    #sample_data = ['zqbcw', 'xqxqx', 'lnmxz', 'yyhqp', 'pvdjk']
##    #sample_data = test_set
##    context_tree_sample = trie.Trie()
##    for word in sample_data:
##        context_tree_sample.insert(word)
##    context_tree_sample.head.prefix_count = len(sample_data)
##    context_tree_sample.calculate_probabilities(context_tree_sample.head)
##    err_L2_avg = err_cos_avg = err_KL_avg = 0
##    for i in range(0, 3):
##        #x1 = time.strftime('%s')
##        err_cos, err_KL, err_L2 = e.match_successor_distribution(t, context_tree_sample, gibbs_steps)
##        #x2 = time.strftime('%s')
##        #timediff = int(x2) - int(x1)
##        #print 'Evaluation took ', timediff, ' seconds.'
##        err_cos_avg += err_cos
##        err_L2_avg += err_L2
##        err_KL_avg += err_KL
##        l = e.calculate_perplexity(t, sample_data, gibbs_steps)
##    print float(err_L2_avg)/3, float(err_cos_avg)/3, float(err_KL_avg)/3
##    print 'Perplexity: ', l, '\n'


    evaluate = 'y'#raw_input ('\nDo you want to test the model? (y/n)')
    if evaluate == 'y':
        
        f = open(evaluation_path, 'w')
        f.write('Tot. evaluations: ' + str(number_of_tests))
        
        ts_size = test_set.shape[0]
        context_tree_test = trie.Trie()
        for word in test_set:
            context_tree_test.insert(word)
        context_tree_test.head.prefix_count = ts_size
        context_tree_test.calculate_probabilities(context_tree_test.head)
        err_cos_t = err_KL_t = err_L2_t = 0
        for i in range(number_of_tests):
            err_cos, err_KL, err_L2 = e.match_successor_distribution(t, context_tree_test, gibbs_steps)
            err_cos_t += err_cos
            err_KL_t += err_KL
            err_L2_t += err_L2
            #err_M_t += err_M
        print 'Test set\nEmpirical error:\t', float(err_cos_t)/number_of_tests, '\t', float(err_KL_t)/number_of_tests, '\t', float(err_L2_t)/number_of_tests
        f.write('\n\nTest set\nEmpirical error:\t' + str(float(err_cos_t)/number_of_tests) + '\t' + str(float(err_KL_t)/number_of_tests) + '\t' + str(float(err_L2_t)/number_of_tests))

        err_cos_t = err_KL_t = err_L2_t = 0
        for i in range(number_of_tests):
            err_cos, err_KL, err_L2 = e.match_successor_distribution(t, context_tree, gibbs_steps)
            err_cos_t += err_cos
            err_KL_t += err_KL
            err_L2_t += err_L2
        print 'Training set\nEmpirical error:\t', float(err_L2_t)/number_of_tests, '\t', float(err_cos_t)/number_of_tests, '\t', float(err_KL_t)/number_of_tests 
        #l = e.calculate_perplexity(t, training_set, gibbs_steps)
        #print 'Perplexity: ', l, '\n'
        f.write('\n\nTraining set\nEmpirical error:\t' + str(float(err_L2_t)/number_of_tests) + '\t' + str(float(err_cos_t)/number_of_tests) + '\t' + str(float(err_KL_t)/number_of_tests))
        f.close()




    ## Optional analysis of the hidden units: set inp to 'y' to run it.
    ## Only one call is active (train_linear_classifier); the other analyses are kept
    ## commented out and can be enabled one at a time.
    inp = 'n'#raw_input ('\nProceed with hidden units activations analysis? (y/n)')
    if inp == 'y':

        print '\n*** MODEL ANALYSIS ***\n'

        ## Plot a square matrix that shows representations similarity
        ## and calculate average similarity and higher similarities:
        #hidden_analyse.compare_distributed_representations(t, h, samples_num, sequence_length, gibbs_steps, context_tree)

        ## Compare final hidden representation of a word with the representations of its prefixes,
        ## in order to analyse the attractors dynamic during the generation of a sequence
        #hidden_analyse.analyse_representations_dynamic(t, h, v, samples_num, sequence_length, training_set, gibbs_steps)

        ## Analysis on the hidden layer
        #hidden_analyse.save_hidden_activations(t, h, v, patterns_num, sequence_length, training_set, gibbs_steps, patterns_save_path)

        #hidden_analyse.distances_stats(patterns_save_path, stats_save_path)
        #hidden_analyse.compare_representations_and_Levenshtein(training_set, patterns_save_path)

        ## Plot the weights matrix
        #hidden_analyse.plot_weights(W)

        ## Collect activations during generation
        #hidden_analyse.collect_activations(t, h, v, patterns_num, sequence_length, training_set, gibbs_steps, patterns_save_path, context_tree, max_patterns, max_freq)
        #hidden_analyse.prepare_classification_data(patterns_save_path, context_tree)
        ## active step: works on the data found at patterns_save_path -- presumably the
        ## output of the two commented-out steps above; TODO confirm
        hidden_analyse.train_linear_classifier(patterns_save_path)
















        ## Plot hidden units activations when the context is a specific letter
        #pattern = 'a'
        #hidden_analyse.activations_histogram(t, h, samples_num, sequence_length, gibbs_steps, pattern)

        ## Plot average neuron's dynamic activations:
        #hidden_analyse.activations_matrix_avg(t, h, samples_num, sequence_length, gibbs_steps)

        ## Perform linear regression on activations to discover possible specific patterns
        #hidden_analyse.plot_neurons_specific_activations(t, h, samples_num, sequence_length, gibbs_steps, context_tree)
        #hidden_analyse.analyse_position_sensitivity(t, h, samples_num, sequence_length, gibbs_steps, context_tree)


        ## The following analysis was only partially performed:
        #hidden_analyse.check_stability_of_representations(t, h, v, gibbs_steps, context_tree)



    ##sample = raw_input ('\nDo you want to get a sample from the network? (y/n)')
    ##while sample == 'y':
    ##    x = t.W.get_samples(sequence_length, gibbs_steps, 1)
    ##    show_sample(x[0])
    ##    sample = raw_input ('\nDo you want to get a sample from the network? (y/n)')

    ##### ANALYSIS ON PARTICULAR PATTERNS (e.g. 'th' vs 'wh')
        
    ##    inp = 'y'#raw_input ('\nCompute activations matrix for PREDICTIONS? (y/n)')
    ##    if inp == 'y':
    ##        pattern = 'h'
    ##        matrix_p, label_p, neuron_list_pred = hidden_analyse.activations_matrix_for_predictions(t, h, samples_num, sequence_length, gibbs_steps, pattern)
    ##
    ##    inp = 'y'#raw_input ('\nCompute activations matrix for CONTEXTS? (y/n)')
    ##    if inp == 'y':
    ##        pattern = 'h'
    ##        matrix_c, label_c, neuron_list_ctx = hidden_analyse.activations_matrix_for_contexts(t, h, samples_num, sequence_length, gibbs_steps, pattern)
    ##
    ##    inp = 'y'#raw_input ('\nCompute activations matrix for RULES? (y/n)')
    ##    if inp == 'y':
    ##        pattern = 'th'
    ##        matrix_r, label_r = hidden_analyse.activations_matrix_for_rules(t, h, samples_num, sequence_length, gibbs_steps, pattern)
    ##
    ##    inp = 'y'#raw_input ('\nCompute activations matrix for BIGRAMS? (y/n)')
    ##    if inp == 'y':
    ##        pattern = 'th'
    ##        matrix_b, label_b = hidden_analyse.activations_matrix_for_bigrams(t, h, samples_num, sequence_length, gibbs_steps, pattern)
    ##
    ##    hidden_analyse.plot_activations_matrices(h, matrix_p, label_p, matrix_c, label_c, matrix_r, label_r, matrix_b, label_b)
    ##
    ##    print '\nNeurons specific response:\n'
    ##    for n in range(h):
    ##        if (neuron_list_pred[n] != []) or (neuron_list_ctx[n] != []):
    ##            print n, '\tpred:  ', neuron_list_pred[n], '\n\tcont:  ', neuron_list_ctx[n], '\n'
    ##
    ##    hidden_analyse.plot_neurons_specific_activations(t, h, gibbs_steps, context_tree)

    ##print 'Weights analysis:'
    ##VH = gpu.as_numpy_array(W.w[0].w[0][0])
    ##VH_abs = np.abs(VH)
    ##print '\nVH Avg:\t%.2f'     % np.mean(VH)
    ##print 'VH StdDev:\t%.2f'    % np.std(VH)
    ##print 'VH Max:\t%.2f'       % np.max(VH)
    ##print 'VH Min:\t%.2f'       % np.min(VH)
    ##print 'VH_abs Avg:\t%.2f'   % np.mean(VH_abs)
    ##m = np.mean(VH_abs, axis=0)
    ##print 'VH NeuronsAvg: (input)'
    ##for i in m:
    ##    print '%.4f' % i
    ##m2 = np.mean(VH_abs, axis=1)
    ##print 'VH NeuronsAvg: (output)'
    ##for i in m2:
    ##    print '%.4f' % i
    ##
    ##HH = gpu.as_numpy_array(W.w[1].w[0][0])
    ##HH_abs = np.abs(HH)
    ##print '\nHH Avg: %.2f'      % np.mean(HH)
    ##print 'HH StdDev:\t%.2f'    % np.std(HH)
    ##print 'HH Max: %.2f'        % np.max(HH)
    ##print 'HH Min: %.2f'        % np.min(HH)
    ##print 'HH_abs Avg:\t%.2f'   % np.mean(HH_abs)
    ##m = np.mean(HH_abs, axis=0)
    ##print 'HH NeuronsAvg: (input)'
    ##for i in m:
    ##    print '%.4f' % i
    ##m2 = np.mean(HH_abs, axis=1)
    ##print 'HH NeuronsAvg: (output)'
    ##for i in m2:
    ##    print '%.4f' % i
