
### Trie (prefix tree) data structure, used to evaluate how well the model
### predicts successor distributions.

# Number of child slots per trie node: the 26 lowercase Latin letters
# ('a'-'z', slots 0-25) plus one extra slot (index 26) that holds the
# '$' end-of-word marker.
ALPHABET_SIZE = 27

class Node:
    """ One vertex of the trie.

    Attributes:
        is_end: True when this node marks the end of a word.
        prefix_count: number of inserted words whose path from the head
            goes through this node (starts at 0).
        prob: probability slot, 0 until Trie.calculate_probabilities
            fills it in.
        content: the string spelled from the head down to this node
            (empty until Trie.insert sets it).
        children: list of ALPHABET_SIZE slots; a slot is None while no
            child exists at that position.
    """
    def __init__(self, is_end = False):
        # All slots start empty; children are attached lazily on insert.
        self.children = [None] * ALPHABET_SIZE
        self.content = ''
        self.prob = 0
        self.prefix_count = 0
        self.is_end = is_end

class Trie:
    """Prefix tree over lowercase Latin words with successor statistics.

    Supports insert, search, longest-known-prefix lookup, counting of
    words with a given prefix, and computation/printing of the successor
    probability distribution of every node.

    Each inserted word is terminated by an explicit end-of-word child
    stored in slot 26 of its last letter node; that child's content ends
    with '$' and it is the node that carries is_end = True.

    >>> t = Trie()
    >>> t.insert('apple')
    >>> t.insert('banana')
    >>> t.insert('applet')
    >>> t.search('apple')
    True
    >>> t.search('apple$')
    True
    >>> t.search('app')
    False
    >>> t.search_also_prefixes('apricot')
    2
    >>> t.count_words_with_prefix('app')
    2
    >>> t.insert('apple')
    >>> t.count_words_with_prefix('app')
    3
    """

    # Slot reserved for the end-of-word marker and the character used to
    # represent it in node contents and in queries.
    _END_INDEX = 26
    _END_MARKER = '$'

    def __init__(self):
        """ Initialize the Trie with a dummy node head """
        self.head = Node()

    @staticmethod
    def _child_index(letter, allow_end_marker=False):
        """ Map a character to its children slot.

        Returns 0-25 for 'a'-'z', 26 for '$' when allow_end_marker is
        true, and None for every other character.
        """
        if allow_end_marker and letter == Trie._END_MARKER:
            return Trie._END_INDEX
        int_value = ord(letter) - ord('a')
        if 0 <= int_value < 26:
            return int_value
        return None

    def insert(self, word):
        """ Insert a word in the trie.

        The word must consist of latin small letters only. It is
        validated completely before any node is touched, so a rejected
        word leaves all counts unchanged.

        Raises:
            ValueError: if the word contains any other character (the
                offending character is also echoed to stdout).
                ValueError is a subclass of Exception, which earlier
                versions raised directly.
        """
        # Explicit check rather than assert: asserts vanish under -O and a
        # negative index would then silently wrap around the children list.
        indices = []
        for letter in word:
            int_value = self._child_index(letter)
            if int_value is None:
                # Same bytes as the former "print letter, '\n'" statement.
                print(letter + ' \n')
                raise ValueError('Invalid Word. Latin small alphabets required')
            indices.append(int_value)

        current = self.head
        for letter, int_value in zip(word, indices):
            child = current.children[int_value]
            if child is None:
                child = Node()
                child.content = current.content + letter
                current.children[int_value] = child
            child.prefix_count += 1
            current = child

        # Terminate the word with the end-of-word child.
        end = current.children[self._END_INDEX]
        if end is None:
            end = Node()
            end.content = current.content + self._END_MARKER
            current.children[self._END_INDEX] = end
        end.prefix_count += 1
        end.is_end = True

    def search(self, word):
        """ Search for given word in trie. Return True if found and False
        if word is not in trie or is invalid (does not consist of latin
        small characters only).

        The word may be given with or without its trailing '$' marker.
        """
        current = self.head
        for letter in word:
            int_value = self._child_index(letter, allow_end_marker=True)
            if int_value is None:
                return False
            current = current.children[int_value]
            if current is None:
                return False
        if current.is_end:
            # The query already ended on the '$' node.
            return True
        # insert() flags the '$' child, not the last letter node, so a
        # plain word is present exactly when that child exists.
        end = current.children[self._END_INDEX]
        return end is not None and end.is_end

    def search_also_prefixes(self, word):
        """ Search for given existing prefix in trie.
        Return the last legal position, i.e. the number of leading
        characters of word that can be followed in the trie (if the whole
        word exists, return its length; 0 for an empty word).

        A character that is neither a latin small letter nor '$' ends the
        walk at its position.
        """
        current = self.head
        for position, letter in enumerate(word):
            int_value = self._child_index(letter, allow_end_marker=True)
            if int_value is None:
                return position
            child = current.children[int_value]
            if child is None:
                return position
            current = child
        return len(word)

    def count_words_with_prefix(self, word):
        """ Return number of words in the trie which have the given
        argument as prefix, or None if the prefix is not present in the
        trie or is invalid (does not consist of latin small characters
        only). Words inserted several times are counted several times.
        """
        current = self.head
        for letter in word:
            int_value = self._child_index(letter)
            if int_value is None:
                return None
            current = current.children[int_value]
            if current is None:
                return None
        if current is self.head:
            # insert() never increments the head's own counter, so the
            # total for the empty prefix is the sum over its children.
            return sum(child.prefix_count
                       for child in current.children if child is not None)
        return current.prefix_count

    def calculate_probabilities(self, node):
        """ Calculate the probability distribution below node (normally
        self.head): the prob of every descendant is set to its
        prefix_count divided by the total prefix_count of it and its
        brothers.
        """
        # Explicit stack instead of recursion so long words cannot hit
        # the interpreter recursion limit.
        stack = [node]
        while stack:
            current = stack.pop()
            children = [child for child in current.children
                        if child is not None]
            # Sum over the brothers rather than reading
            # current.prefix_count, which stays 0 on the head node and
            # used to cause a division by zero.
            tot = sum(child.prefix_count for child in children)
            for child in children:
                child.prob = child.prefix_count / float(tot) if tot else 0
            stack.extend(children)

    def breadth_first_successor_distr(self):
        """ Print, in breadth-first order, the context (content) of every
        node that has at least one child together with the list of its
        children's prob values (0 for missing children).

        The values printed are whatever is currently stored in prob, so
        calculate_probabilities should be called first.
        """
        from collections import deque  # O(1) pops from the left end

        queue = deque([self.head])
        while queue:
            current = queue.popleft()
            p = []
            has_children = False
            for child in current.children:
                if child is None:
                    p.append(0)
                else:
                    has_children = True
                    queue.append(child)
                    p.append(child.prob)
            if has_children:
                # Single-argument print calls emit the same bytes on
                # Python 2 and Python 3 as the former print statements.
                print('\nContext:  ' + current.content)
                print('Distrib:  ' + str(p))

