#!/usr/bin/env python2.3
################################################################################
#
#       This file is part of the GQL (Graphical Query Language) Toolkit
#
#       file:        GQLViterbi.py
#       author:      Ruben Schilling (schillin@molgen.mpg.de)
#               
#       Copyright    (C) 2003-2004 Alexander Schliep
#
#       Contact:     alexander@schliep.org
#
#       Information: http://ghmm.org/gql
#
#	GQL is free software; you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation; either version 2 of the License, or
#	(at your option) any later version.
#
#	GQL is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with GQL; if not, write to the Free Software
#	Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
#
#
#       This file is version $Revision: 688 $
#                       from $Date: 2005-10-19 12:08:01 -0300 (Wed, 19 Oct 2005) $
#             last change by $Author: schillin $.
#
################################################################################

"""This module provides the functionality needed to compute decompositions
of HMM-based time-course models with respect to differences in their
Viterbi paths.
Note: this module uses sets (the `sets` module), which were sped up in
Python 2.4, so it is expected to run faster with Python 2.4 and later;
it nevertheless works with Python 2.3 or higher.
"""
import sys
import GQLCluster
import sets
import copy
import ghmm

class ViterbiDecomposition:

    """Controller class to compute viterbi decompositions
    """



    """class variables needed by subgrouping algorithm"""
    checked_states=sets.Set()
    current_position=0           
    current_sequence=0

    """the durations cache stores for all states for all sequences the start and end of their duration
    so the structure of the cache is [ state[ sequence[ [start, end]]]]
    initialized without entries for states etc., it is filled during subgrouping, dictionary (hash table)"""
    durations_cache = {}
    """other caches:"""
    a_coefficients_cache = []
    tmp_a_coefficients_cache = []
    b_coefficients_cache = [0][0]
    tmp_b_coefficients_cache = [0][0]

    """see to main method for a comment regarding a global sequenceSet variable"""
    sequenceSet = []

    
    def __init__(self, sequences, model):
        self.sequences = sequences
        self.model = model
        [self.paths, lh] = self.model.viterbi(self.sequences)





    #-------subgrouping algortihm-------------------------------
    """call the subgrouping algorithm with the subgroup function;
    you must have computed a correct viterbi path beforehand
    subgroup assumes the viterbi path in the variable VitDecomp.paths
    there is no error checking for this so far - be prepared..."""
    
    #give this function the current position and sequence and it will return the next unchecked state
    def find_next_state(self):
        for i in range(self.current_sequence, len(self.paths)) :
            for j in range (self.current_position, len(self.paths[i])) :
                if self.paths[i][j] not in self.checked_states :
                    self.checked_states.add(self.paths[i][j])
                    self.current_sequence=i
                    self.current_position=j
                    return self.paths[i][j]
            self.current_position=0  #you want to start for any new sequence at index 0
        return                       #if no new state was discovered just return
    
    
    
    def find_begin(self, sequence, state) :
        for i in range(len(self.paths[sequence])) :
            if self.paths[sequence][i] == state :
                return i  
        return -1
    
    
    
    def find_end(self, sequence, state, begin_index) :
        end_index=begin_index
        for i in range(begin_index, len(self.paths[sequence])) :
            if self.paths[sequence][i] != state :
                return end_index
            else :
                end_index=i
        return end_index
    


    def check_match(self, current_state, begin_index, end_index, sequence1, sequence2) :

        #case: the 2. state sequence has at least room to look one index to left and right
        if begin_index -1 >= 0 and end_index + 1 < len(self.paths[sequence2]) :
            if self.paths[sequence1][begin_index] == self.paths[sequence2][begin_index] \
	       and self.paths[sequence1][end_index] == self.paths[sequence2][end_index] \
	       and self.paths[sequence1][begin_index] != self.paths[sequence2][begin_index-1] \
	       and self.paths[sequence1][end_index] != self.paths[sequence2][end_index +1]:
                return True
			
        #case: the 2. state sequence reaches to the very right end
        elif    begin_index -1 >= 0 and end_index + 1 >= len(self.paths[sequence2]) :
            if self.paths[sequence1][begin_index] == self.paths[sequence2][begin_index] \
               and self.paths[sequence1][end_index] == self.paths[sequence2][end_index] \
               and self.paths[sequence1][begin_index] != self.paths[sequence2][begin_index-1] :
                return True
				
        #case: the 2. state sequence reaches to the very left end
        elif    begin_index - 1 < 0 and end_index + 1 < len(self.paths[sequence2]) :
            if self.paths[sequence1][begin_index] == self.paths[sequence2][begin_index] \
               and self.paths[sequence1][end_index] == self.paths[sequence2][end_index] \
               and self.paths[sequence1][end_index] != self.paths[sequence2][end_index +1] :
                return True
				
        #case: the 2. sequence is the complete sequence from left to right
        elif    begin_index - 1 < 0 and end_index + 1 >= len(self.paths[sequence2]) :
            if self.paths[sequence1][begin_index] == self.paths[sequence2][begin_index] \
               and self.paths[sequence1][end_index] == self.paths[sequence2][end_index] :
                return True

        return False


        
    def subgroup(self):
        groupings = []
        current_state=self.find_next_state() 
        
        while (current_state != None) :
            self.durations_cache[current_state]={}  #add a dictionary to the cache for each state
            current_subgrouping=[]                                       
            unchecked_sequences=range(len(self.paths))

            index=0                                           #iterate in unchecked_sequences
            modul_delimiter = len(self.paths)-1               #iterate in unchecked_sequences

            while(unchecked_sequences != []) :
                i = unchecked_sequences[index]
                begin_index=self.find_begin(i, current_state) #begin of duration of this state     
                end_index=self.find_end(i,current_state, begin_index)#end   of duration of this state
                self.durations_cache[current_state][i]=(begin_index, end_index)#dict.:start-end tuple for each sequence

                if begin_index != -1  and end_index != -1 :   #else:no such state in current sequence, possible at all?
                    current_subgrouping.append([i])           #create list for current sequence in the subgroup
                    unchecked_sequences.remove(i)

                    for j in range(i+1, len(self.paths)) :    #check all other sequences for shared subgroup
                        match=self.check_match(current_state, begin_index, end_index, i, j)
                        if match != False :
                            #append index of sequence j to subgroup of sequence i
                            current_subgrouping[len(current_subgrouping)-1].append(j)
                            #prevent this sequence from being added twice given this state&subgroup
                            unchecked_sequences.remove(j)
                            #dict. /w start-end tuple f. e. sequence
                            self.durations_cache[current_state][j]=(begin_index, end_index)
                            
                i = (i+1)%modul_delimiter
                
            groupings.append(current_subgrouping)
            current_subgrouping=[]   #reuse
            current_state=self.find_next_state()
            
        return groupings     
    #-------subgrouping algorithm-------------------------------





    #-------quality assessment----------------------------------
    """call the quality assessment with calculate_c(subgrouping)
    of course the program assumes, that subgrouping has been done beforehand 
    expects a list in the format: list(list(subgroup1),......,list(subgroupN))
    subgroups are defined as the index No's of the sequences from the viterbi path
    it also assumes, that you ran subgrouping already, because it uses the global
    #checked_states to get the sequence of appeared states"""
    def faculty(self, n) :
        faculty = 1

        #watch out: n+1 is used, because range's upper bound is exclusive, not inclusive and you want the 'n' to be multiplied too
        for i in range(2, n+1) :
            faculty = faculty * i
        return faculty



    """This is a combinatoric helper function. It expects of course a positive integer, everything else is a logical error."""
    def n_over_two(self, n) :
        #Exceptional case saves function calls and eases design
        if n == 1 or n == 2 :
            return 1

        #regular case:
        else :
            return  (self.faculty(n) / (2 * self.faculty(n - 2)))



    """The number of distinct tuples for coefficient A is the sum of the combinations within each subgroup, which is calculated as n over 2"""
    def number_of_typeA_tuples(self, subgrouping) :
        tuples = 0
        for i in range(len(subgrouping)) :
            tuples = tuples + self.n_over_two(len(subgrouping[i]))
        
        return tuples


    """This function calculates coefficient A in the quality assessment (step3). The passed argument state is the hidden state of the HMM, for which the coefficient shall be calculated. WARNING for maintainer of this algorithm: If you change this method, also change the corresponding *recalculate* method, it is based very much on this method! I have kept this ugly redundancy, since any other solution I thought of, would have resulted in many, many (expensive!) function calls, which waste rare computation time (and thus crippling many hours of my effort to enhance the execution time after all)."""
    def calculate_a(self, subgrouping, state) :
        
        state_result = 0 #result for coefficient 'a' regarding the current state
        completely_missing_pairs = 0 #counts the number of pairs, which consist of nothing but missing data
        distance_result = 0.0
        
        #traverse the subgroups within the state subgrouping
        for subgroup in range(len(subgrouping[state])) :

            #cache the number of viterbi paths for iteration
            total_paths = len(subgrouping[state][subgroup])
            
            #traverse the sequences in current subgroups
            for k in range(total_paths) :
                
                #look-ahead from k to end of subgroup: traversing tuples of sequences
                for l in range(k+1, total_paths) :
                    #for every pair collect the new distance functions result
                    distance_result = 0
                    missing_data = 0 #just counting missing data within one pairing

                    sequence1 = subgrouping[state][subgroup][k]
                    sequence2 = subgrouping[state][subgroup][l]
                    
                    only_missing_values=True #check, if there have been only missing values within this pair
                    
                    for t in range(self.durations_cache[state][sequence1][0],(self.durations_cache[state][sequence1][1])+1) :
                        if self.sequenceSet[sequence1][t] > -9999 and self.sequenceSet[sequence2][t] > -9999 :
                            distance_result = distance_result + pow(self.sequenceSet[sequence1][t] - self.sequenceSet[sequence2][t], 2)
                            only_missing_values=False
                            
                        else :
                            missing_data = missing_data + 1

                    if only_missing_values :
                        print "There has been a pairing consisting of missing values only."
                        completely_missing_pairs = completely_missing_pairs + 1
                    
                    #duration is non negative, since there is a partial order within end/start pairs 
                    duration = self.durations_cache[state][sequence1][1] - self.durations_cache[state][sequence1][0]+1
                    duration = duration - missing_data
                    distance_result = distance_result * float(1/duration)

                    state_result = state_result + distance_result

            self.a_coefficients_cache.append(distance_result)
                    
        total_tuples = self.number_of_typeA_tuples(subgrouping[state]) - completely_missing_pairs        
        if total_tuples > 0 :
            return (1/float(total_tuples) * state_result)
        else :
            return 0
                             


    """This method is used in step 5 and makes excessive use of the cache for the A coefficient. Otherwise it is a copy of calculate_a."""
    def recalculate_a(self, new_group, obsolete_group, subgrouping, state) :

        self.tmp_a_coefficients_cache = copy.deepcopy(self.a_coefficients_cache)
        
        state_result = 0 #result for coefficient 'a' regarding the current state
        completely_missing_pairs = 0 #counts the number of pairs, which consist of nothing but missing data
        distance_result = 0.0
        
        #traverse the subgroups within the state subgrouping
        for subgroup in range(len(subgrouping[state])) :

            #cache the number of viterbi paths for iteration
            total_paths = len(subgrouping[state][subgroup])
            
            if subgroup != new_group :
                state_result = state_result + self.a_coefficients_cache[subgroup]

            if subgroup == obsolete_group :
                pass

            else : #see calculate_a
                #traverse the sequences in current subgroups
                for k in range(total_paths) :
                    
                    #look-ahead from k to end of subgroup: traversing tuples of sequences
                    for l in range(k+1, total_paths) :
                        #for every pair collect the new distance functions result
                        distance_result = 0
                        missing_data = 0 #just counting missing data within one pairing

                        sequence1 = subgrouping[state][subgroup][k]
                        sequence2 = subgrouping[state][subgroup][l]
                    
                        only_missing_values=True #check, if there have been only missing values within this pair
                    
                        for t in range(self.durations_cache[state][sequence1][0],(self.durations_cache[state][sequence1][1])+1) :
                            if self.sequenceSet[sequence1][t] > -9999 and self.sequenceSet[sequence2][t] > -9999 :
                                distance_result = distance_result + pow(self.sequenceSet[sequence1][t] - self.sequenceSet[sequence2][t], 2)
                                only_missing_values=False
                            
                            else :
                                missing_data = missing_data + 1

                        if only_missing_values :
                            print "There has been a pairing consisting of missing values only."
                            completely_missing_pairs = completely_missing_pairs + 1
                    
                        #duration is non negative, since there is a partial order within end/start pairs 
                        duration = self.durations_cache[state][sequence1][1] - self.durations_cache[state][sequence1][0]+1
                        duration = duration - missing_data
                        distance_result = distance_result * float(1/duration)

                        state_result = state_result + distance_result

                self.tmp_a_coefficients_cache.pop(subgroup)
                self.tmp_a_coefficients_cache.insert(subgroup, distance_result) #store the new result
                    
            total_tuples = self.number_of_typeA_tuples(subgrouping[state]) - completely_missing_pairs        
            del self.tmp_a_coefficients_cache[obsolete_group] #clean-up in the tmp cache, don't do it before, you would mess up the indices!
            
            if total_tuples > 0 :
                return (1/float(total_tuples) * state_result)
            else :
                return 0
        


    """This is a combinatorial helper function. In case B the number of tuples is detemined by the cardinality of
    the cartesian product or bluntly spoken the product of the cardinalities of the subgroups."""
    def number_of_typeB_tuples(self, subgrouping) :
        tuples = 1
        for i in range(len(subgrouping)) :
            tuples = tuples * len(subgrouping[i])
        return tuples


        
    def calculate_b(self, subgrouping, state) :

        state_result = 0.0
        completely_missing_pairs = 0 #counts the number of pairs, which consist of nothing but missing data

        #cache this for iteration
        number_subgroups = len(subgrouping[state])
        
        #traverse the subgroups within the state subgrouping
        for subgroup in range(number_subgroups) :

            #cache this for iteration
            number_paths = len(subgrouping[state][subgroup])
            
            #traverse now pairs of different subgroups
            for subgroup_2 in range(subgroup+1, number_subgroups) :
                #now get the correct time set, which is the symmetric difference between time sets of the 2 current subgroups
                time_set_1 = sets.Set(range(self.durations_cache[state][subgrouping[state][subgroup][0]][0], \
                                            self.durations_cache[state][subgrouping[state][subgroup][0]][1]+1))
                time_set_2 = sets.Set(range(self.durations_cache[state][subgrouping[state][subgroup_2][0]][0], \
                                            self.durations_cache[state][subgrouping[state][subgroup_2][0]][1]+1))
                difference = time_set_1.symmetric_difference(time_set_2)

                #cache this for iteration
                number_paths_2 = len(subgrouping[state][subgroup_2])

                #traverse sequence pairs  and their time points
                for k in range(number_paths):

                    for l in range(number_paths_2):
                        distance_result = 0
                        missing_data = 0 #just counting missing data within one pairing

                        sequence1 = subgrouping[state][subgroup][k]
                        sequence2 = subgrouping[state][subgroup_2][l]

                        only_missing_values=True #check, if there have been only missing values within this pair
                        
                        for t in difference :
                            if self.sequenceSet[sequence1][t] > -9999 and self.sequenceSet[sequence2][t] > -9999 :
                                distance_result = distance_result + pow((self.sequenceSet[sequence1][t] - self.sequenceSet[sequence2][t]),2)
                                only_missing_values=False
                                
                            else :
                                missing_data = missing_data + 1
                            
                        duration = len(difference)
                        duration = duration - missing_data

                        if only_missing_values :
                            completely_missing_pairs = completely_missing_pairs + 1

                        if duration==0 :
                            distance_result = distance_result
                        else :
                            distance_result = distance_result * float(1/duration)

                        state_result = state_result + distance_result
    
        total_tuples = self.number_of_typeB_tuples(subgrouping[state]) - completely_missing_pairs
        
        if total_tuples > 0 :
        	return (1/float(total_tuples) * state_result)
        else:
           	return 0
    
    
    
    def calculate_c(self, subgrouping) :
        coefficients_a = []
        coefficients_b = []
        coefficients_c = []

        print "\n"

        #we don't assess the last state, since it is only a stop state, where the assessment is always zero
        for state in range(len(subgrouping)-1) :
            coefficients_a.insert(state, self.calculate_a(subgrouping, state))
            coefficients_b.insert(state, self.calculate_b(subgrouping, state))
            coefficients_c.insert(state, coefficients_b[state] - coefficients_a[state])
            print "state:", state, "c-value:", coefficients_c[state]

        print "\n", "\n", "CALCULATION OF COEFFICIENTS DONE", "\n", "\n"
        return coefficients_c
    #-------quality assessment----------------------------------





    #------choose maximum argument from quality assessment-------------

    def quality_assessment(self, quality_measures) :
        max = quality_measures[0]
        argmax=0
        for i in range(1, len(quality_measures)) :
            if max < quality_measures[i]:
                argmax = i
                max = quality_measures[i]
        print "corresponding to argmax=", argmax, "the max c value is:", quality_measures[argmax], "\n"
        return argmax

    #------choose maximum argument's index from quality assessment-------------





    #------join subgroups---------------------------------------
    """implementation of step 5 (iteratively join subgroups)"""
    def join_subgroups(self, argmax, coefficients, subgrouping) :
	
        print "Join subgroups:", "\n"
	
	max_c_value = coefficients[argmax]
	max_value_decomposition = copy.deepcopy(subgrouping)
        self.max_a_cache = copy.deepcopy(self.a_coefficients_cache)
        c_increase = True
        
        #a valid and enhanced subgrouping (enhanced --> later on :)
        current_valid_group = copy.deepcopy(subgrouping)
        
        while c_increase :
            c_increase = False

            for subgroup1 in range(len(current_valid_group[argmax])) :
                
		#try all possible joins for one iteration
                for subgroup2 in range(subgroup1+1, len(current_valid_group[argmax])) :
	
                    try_group = copy.deepcopy(current_valid_group)
                    #join two groups at index of first group, delete then second group
                    try_group[argmax][subgroup1:subgroup1+1] = \
                                                             [try_group[argmax][subgroup1]+\
                                                              try_group[argmax][subgroup2] ]
                    try_group[argmax][subgroup2:subgroup2+1] = []
                    
                    #calculate quality assessment for this join
                    #a_value = self.calculate_a(try_group, argmax)
                    a_value = self.recalculate_a(subgroup1, subgroup2, try_group, argmax)
                    b_value = self.calculate_b(try_group, argmax)
                    c_value = b_value - a_value
                    
                    #save value and joint groups if c-value increased, set new c value
                    #--> we have a new maximum, continuation of iteration
                    if c_value > max_c_value :
                        max_value_decomposition = copy.deepcopy(try_group)
                        max_c_value = c_value
                        c_increase=True
                        self.max_a_cache = copy.deepcopy(self.tmp_a_coefficients_cache)
                        print "while joining subgroups I found a new max c value:", max_c_value

            self.a_coefficients_cache = copy.deepcopy(self.max_a_cache)
            current_valid_group = copy.deepcopy(max_value_decomposition)
                
	print "max_c_value, max_value_grouping", max_c_value, max_value_decomposition[argmax], "\n"
        return max_value_decomposition
    
    #------join subgroups---------------------------------------



        
    #useless function so far, might be important for integration in the GQL package later on
    def ViterbiGrouping(self):
        return [[1,2],[2,3]] #random return value, see above

        

# If this module (GQLViterbi) is executed directly rather than imported by
# another module, run the complete Viterbi decomposition on the given files.
if __name__ == '__main__':

    if len(sys.argv) < 3:
        sys.stderr.write("usage: %s <data file> <model file>\n" % sys.argv[0])
        sys.exit(2)

    dataFile = sys.argv[1]
    modelFile = sys.argv[2]

    profileSet = GQLCluster.ProfileSet()
    sequenceSet = profileSet.ReadDataFromCaged(dataFile)
    profile = GQLCluster.ProfileClustering()
    profile.setProfileSet(profileSet)
    profile.readModels(modelFile)

    # Step 1: compute the Viterbi paths of all sequences under the first model.
    VitDecomp = ViterbiDecomposition(sequenceSet, profile.modelList()[0])

    # Hand the decomposition a plain-list copy of the observations. The
    # original author notes that working on the raw sequence set leaked
    # memory and was slow. Indexed access is kept because it is not known
    # whether the sequence set type supports iteration.
    VitDecomp.sequenceSet = [list(sequenceSet[i]) for i in range(len(sequenceSet))]

    # Step 2: per state, group the sequences by the position of the state's run.
    subgrouping = VitDecomp.subgroup()

    # Step 3: quality coefficient c = b - a for every state but the last one.
    coefficients = VitDecomp.calculate_c(subgrouping)

    # Step 4: pick the state with the largest c value.
    argmax = VitDecomp.quality_assessment(coefficients)

    # Step 5: greedily join that state's subgroups while c increases.
    viterbi_decomposition = VitDecomp.join_subgroups(argmax, coefficients, subgrouping)
