#!/usr/bin/env python2.3
################################################################################
#
#       This file is part of the GQL (Graphical Query Language) Toolkit
#
#       file:   ViterbiDecomposition.py
#       author: Alexander Schoenhuth (schoenhuth@zpr.uni-koeln.de)
#
#       Copyright (C) 2003-2004 Alexander Schliep
#
#       Contact: alexander@schliep.org
#
#       Information: http://ghmm.org/gql
#
#	GQL is free software; you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation; either version 2 of the License, or
#	(at your option) any later version.
#
#	GQL is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with GQL; if not, write to the Free Software
#	Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
#
#
#       This file is version $Revision: 1871 $
#                       from $Date: 2008-11-03 15:06:11 -0300 (Mon, 03 Nov 2008) $
#             last change by $Author: filho $.
#
################################################################################

"""This module provides the functionality needed to decompose a set of
time courses modeled by an HMM into groups, based on differences between
the Viterbi paths of the individual sequences.
"""

import copy
import ghmm
import numpy as Numeric

class ViterbiDecomposition:

    """Controller class to compute viterbi decompositions
    """

    def __init__(self, sequences, model):

        self.sequences = sequences
        self.model = model
        self.paths = self.model.viterbi(self.sequences)[0]

    def ViterbiSimple(self, noOfGroups):
        """Cluster the viterbi paths into noOfGroups groups with Pycluster.

        The paths are clustered with Pycluster.kcluster (npass=30,
        dist='b') and the centroids of the resulting clusters are
        computed with Pycluster.clustercentroids.

        Returns a tuple (cluster, centroidList): the cluster assignment
        as delivered by kcluster and a list holding one centroid per
        cluster.  Sorting of the centroids is currently disabled, so the
        cluster ids are returned exactly as computed.
        """
        import Pycluster

        clusterResult = Pycluster.kcluster(self.paths, noOfGroups,
                                           npass=30, dist='b')
        [cluster, ml, nfound] = clusterResult

        centroidResult = Pycluster.clustercentroids(self.paths,
                                                    clusterid=cluster)
        [centroids, cmaks] = centroidResult

        # one list entry per centroid (row of the centroid array)
        sortList = [c for c in Numeric.array(centroids)]

        return cluster, sortList

    def computeFixedNumberDecomposition(self, noOfGroups):

        """similar to computeBestDecomposition, but
        collapses until a certain number of components
        has been reached
        """

        #compute the candidates first

        candidates = self.preScreening()

        #compute all groupings, assess their quality
        #and return the indices of the best two choices

        groups, states = self.computeGroupings(candidates)
        ranking = self.fullAssess(groups, states)
        print "Ranking:\n", ranking

        #now cluster the best two groupings hierarchically
        #groupInfo can be altered here

        first_after, first_before = self.collapse(groups[ranking[0]][0],
                                                  groups[ranking[0]][1],
                                                  2, noOfGroups)
        return groups[ranking[0]][0]

    def computeBestDecomposition(self):

        """computes a list of lists of indices indicating
        the best grouping found screening a selection of states
        looking like good candidates
        """

        #compute the candidates first
        candidates = self.preScreening()

        #compute all groupings, assess their quality
        #and return the indices of the best two choices
        groups, states = self.computeGroupings(candidates)
        ranking = self.fullAssess(groups, states)
        print "Ranking:\n", ranking
        #now cluster the best two groupings hierarchically
        #groupInfo can be altered here
        first_after, first_before = self.collapse(groups[ranking[0]][0],
                                                  groups[ranking[0]][1],
                                                  3)
        second_after, second_before = self.collapse(groups[ranking[1]][0],
                                                    groups[ranking[1]][1],
                                                    3)

        #return the best grouping after collapsing
        if second_after > first_after:
            return groups[ranking[1]][0]
        else:
            return groups[ranking[0]][0]


    def preScreening(self):

        """Proposes a selection of states for further inspection.
        """
        #no of states without end state
        noOfStates = self.model.N - 1

        durations = []
        totalDuration = len(self.sequences[0])
        print "Total duration: %2.2f" % (float(totalDuration))
        durationSoFar = 0.0
        for i in range(noOfStates):
            
            dur = 1.0 / (1.0 - self.model.getTransition(i,i))
            if ((i == 0 and durationSoFar + dur < float(i+2))
                or
                (i == noOfStates - 1 and ((totalDuration - durationSoFar) < (noOfStates - i + 1)))):
                
                durations.append(1000.0)
            else:
                durations.append(dur / totalDuration)
            durationSoFar += dur
            
        candidateStates = {}
        for i in range(noOfStates):
            if i == 0:
                if isinstance(self.model,ghmm.GaussianMixtureHMM):
                    ranking = (self.model.getEmission(0,0)[0] - self.model.getEmission(1,0)[0])**2
                    varPenalty = (self.model.getEmission(0,0)[1] + self.model.getEmission(1,0)[1]) / 2.0
                else:
                    ranking = (self.model.getEmission(0)[0] - self.model.getEmission(1)[0])**2
                    varPenalty = (self.model.getEmission(0)[1] + self.model.getEmission(1)[1]) / 2.0
                ranking -= 0.1*varPenalty
                ranking -= durations[0] / 2.0
                if candidateStates.has_key(ranking):
                    candidateStates[ranking].append(i)
                else:
                    candidateStates[ranking] = [i]
            elif i == noOfStates - 1:
                if isinstance(self.model,ghmm.GaussianMixtureHMM):
                    ranking = (self.model.getEmission(i-1,0)[0] - self.model.getEmission(i,0)[0])**2
                    varPenalty = (self.model.getEmission(i-1,0)[1] + self.model.getEmission(i,0)[1]) / 2.0
                else:
                    ranking = (self.model.getEmission(i-1)[0] - self.model.getEmission(i)[0])**2
                    varPenalty = (self.model.getEmission(i-1)[1] + self.model.getEmission(i)[1]) / 2.0
                ranking -= 0.1*varPenalty
                ranking -= durations[i] / 2.0
                if candidateStates.has_key(ranking):
                    candidateStates[ranking].append(i)
                else:
                    candidateStates[ranking] = [i]
            else:
                if isinstance(self.model,ghmm.GaussianMixtureHMM):
                    ranking = ((self.model.getEmission(i-1,0)[0] - self.model.getEmission(i,0)[0])**2 +
                               (self.model.getEmission(i,0)[0] - self.model.getEmission(i+1,0)[0])**2) / 2.0
                    varPenalty = (self.model.getEmission(i-1,0)[1] + self.model.getEmission(i,0)[1] +
                                  self.model.getEmission(i+1,0)[1]) / 3.0
                else:
                    ranking = ((self.model.getEmission(i-1)[0] - self.model.getEmission(i)[0])**2 +
                               (self.model.getEmission(i)[0] - self.model.getEmission(i+1)[0])**2) / 2.0
                    varPenalty = (self.model.getEmission(i-1)[1] + self.model.getEmission(i)[1] +
                                  self.model.getEmission(i+1)[1]) / 3.0
                ranking -= 0.1*varPenalty
                ranking -= durations[i] / 2.0
                if candidateStates.has_key(ranking):
                    candidateStates[ranking].append(i)
                else:
                    candidateStates[ranking] = [i]

        print candidateStates
        Rankings = candidateStates.keys()
        Rankings.sort()
        candidateList = []
        #noOfCandidates = min(int(0.5*noOfStates)+1,5)
        noOfCandidates = 3
        for i in range(1, noOfCandidates+1):
            candidateList += candidateStates[Rankings[-i]]

        print "PreScreening proposes the following states for inspection:", candidateList
        return candidateList


    def computeGroupings(self, candidateList=[]):

        """computes a list of groupings where a grouping
        is a list of lists containing member indices of
        a group and a list of lists containing the
        respective timepoints of the groups
        (see fullPathGrouping) besides from that a list
        storing the respective states is computed"""

        if not candidateList:
            candidateList = range(model.N - 1)
    
        groupInfoList = []
        stateInfoList = []
        for i in candidateList:
            print "Computing groupInfo for state %d" % (i)
            groupInfoList.append(self.fullPathGrouping(i))
            stateInfoList.append(i,)

        return groupInfoList, stateInfoList

    
    def fullAssess(self, groupInfoList, stateInfoList):
        
        """computes a list of lists containing the indices
        corresponding to a grouping being the best of
        all groupings with respect to one state"""

        assess = []
        
        for i, groupInfo in enumerate(groupInfoList):

            print "Computing assessment for state(s)", stateInfoList[i]

            inner, between = self.assessGrouping(groupInfo[0], groupInfo[1])
            
            #here different measures of quality have been tested
            #e.g. between / inner
            assess.append(between - inner)
            

        assessCopy = assess[:]
        sortedIndices = []
        assess.sort()
        assess.reverse()

        print "Assessment:\n", assessCopy

        for measure in assess:
            sortedIndices.append(assessCopy.index(measure))

        sortedIndices = tuple(sortedIndices)
        print "Ranking:\n", sortedIndices

        return sortedIndices
    

    def fullPathGrouping(self, *states):

        """ computes a list of lists containing the indices
        corresponding to the grouping resulting from a sorting
        of the paths with respect to the passed states and
        a list of lists containing the timepoints, which
        refer to the states passed"""
        
        groupedIndices = []
        allinfo = []
        timepointList = []

        for i,path in enumerate(self.paths):

            seqinfo = []

            for x in states:
                seqstateinfo = []
                for k in range(len(path)):
                    if path[k] == x:
                        seqstateinfo.append(k)
                seqinfo.append(seqstateinfo)

            seqinfo.append(i)
            allinfo.append(seqinfo)        

        allinfo.sort()
        oldinfo = []
        
        for x in allinfo:

            if x[:-1] != oldinfo:
                helper = []
                for timelist in x[:-1]:
                    helper += timelist
                timepointList.append(helper)
                groupedIndices.append([])
                
            oldinfo = x[:-1]
            groupedIndices[-1].append(x[-1])

        return groupedIndices, timepointList


    def assessGrouping(self, memberInfo, timepointInfo):

        """gets a grouping and assesses its quality,
        needs a list of lists, each of which contains
        indices of group members, and a list of lists,
        which contain the timepoints the groups refer to,
        returns two values, the first being a measure
        of separation between groups, the second
        being a measure of tightness within groups
        """
        
        a = 0.0
        b = 0.0
        noOfPairingsBetween = 0.0
        noOfPairingsInner = 0.0
        
        for i, groupMembers1 in enumerate(memberInfo):

            count, a_tmp = self.averageInnerDistance(groupMembers1, timepointInfo[i])
            noOfPairingsInner += count
            a += a_tmp
            
            for j, groupMembers2 in enumerate(memberInfo[i+1:]):
                
                difference = self.setDifference(timepointInfo[i], timepointInfo[i+1+j])
                
                card1 = len(groupMembers1)
                card2 = len(groupMembers2)
                noOfPairingsBetween += card1*card2

                b += self.averageDistanceBetween(groupMembers1, groupMembers2, difference)
                        
        if noOfPairingsBetween >= 1.0:
            b /= noOfPairingsBetween
        else:
            b = 0.0


        if noOfPairingsInner >= 1.0:
            a /= noOfPairingsInner
        else:
            a = 0.0

        return a, b



    def collapse(self, groupedIndices, timepointList, safeMerge=0, noOfGroups=0):

        """Method for collapsing a given viterbi decomposition
        choosing the best (not the first best)
        candidate in each iteration.

        groupedIndices -- list of lists with the member indices of the
                          groups; modified in place
        timepointList  -- list of lists with the timepoints the groups
                          refer to; modified in place
        safeMerge      -- if non-zero, two groups are only considered
                          for merging if every pair of timepoint lists
                          tracked for them differs (setDifference) in
                          fewer than safeMerge timepoints
        noOfGroups     -- if non-zero, merging goes on until at most
                          this number of groups is left (safeMerge is
                          raised by one whenever no pair can be
                          merged); if zero, a merge is only accepted
                          if it improves the quality measure

        The quality measure is b_tmp - a_tmp, i.e. the between-group
        distances divided by the number of between pairings minus the
        inner distances divided by the number of inner pairings.

        Returns a tuple (max_tmp, before_collapse): the measure
        recorded for the most recently selected merge (the initial
        measure if nothing was merged) and the measure of the
        grouping as passed in.
        """

        innerPairings = 0.0
        betweenPairings = 0.0
        # innerDists[j] caches (distance, pairings) within group j
        innerDists = []
        # betweenDists[j][k] caches (distance, pairings) between
        # group j and group j+1+k
        betweenDists = []
        b_tmp = 0.0
        a_tmp = 0.0
        for j in range(len(groupedIndices)):

            count, dist = self.averageInnerDistance(groupedIndices[j], timepointList[j])
            innerDists.append((dist, count))
            innerPairings += count
            a_tmp += dist

            if j < len(groupedIndices) - 1:
                betweenDists.append([])

                for k in range(j+1, len(groupedIndices)):

                    noOfPairings = len(groupedIndices[j])*len(groupedIndices[k])
                    betweenPairings += noOfPairings
                    # groups are compared on the timepoints they differ in
                    setDiff = self.setDifference(timepointList[j], timepointList[k])
                    avgDistBet = self.averageDistanceBetween(groupedIndices[j], groupedIndices[k], setDiff)
                    betweenDists[j].append((avgDistBet, noOfPairings))
                    b_tmp += avgDistBet

        # turn the sums into averages per pairing
        if innerPairings >= 1.0:
            a_tmp /= innerPairings

        if betweenPairings >= 1.0:
            b_tmp /= betweenPairings

        # measure of the grouping before any merge
        max_tmp = b_tmp - a_tmp
        before_collapse = max_tmp

        # max_tmp2 tracks the ratio between / inner; without any inner
        # distance nothing is collapsed
        if a_tmp > 0.0:
            max_tmp2 = b_tmp / a_tmp
        else:
            print "There is only one cluster. No collapsing possible."
            return before_collapse, before_collapse


        # timepointTrack[g] holds all original timepoint lists that
        # have been merged into group g so far
        timepointTrack = []
        for x in timepointList:
            timepointTrack.append([x])

        print timepointTrack


        # one merge (at most) per iteration
        while (True):

            print "it will make the test",groupedIndices

            if len(groupedIndices) <= noOfGroups:
                print "Desired number of subgroups reached. Breaking."
                break
	    print "missed ..."

            # candidate becomes the pair (i, l) of the best merge found
            candidate = -1

            # cache entries of the best merge: (distance, pairings) of
            # the merged group and its between values to all others
            candidateInner = None
            candidateBetween = []

            # number of pairs that passed the safeMerge check
            testable = 0

            for i in range(len(groupedIndices) - 1):

                for l in range(i+1, len(groupedIndices)):

                    #Avoid merging groups, which are not mergable at all

                    unifiable = True

                    if safeMerge:

                        for x in timepointTrack[i]:
                            for y in timepointTrack[l]:

                                #if not self.setIntersection(x, y) or len(self.setDifference(x,y)) >= safeMerge:
                                if len(self.setDifference(x,y)) >= safeMerge:

                                    unifiable = False
                                    break
                            if not unifiable:
                                break
                    if not unifiable:
                        #print "Safe merging not possible."
                        continue
                    else:
                        print "Checking", timepointTrack[i], "and", timepointTrack[l]
                        testable += 1
                    #If groups are mergable, compute new scenario

                    # shallow copies describing the grouping after
                    # merging group l into group i
                    IndicesTmp = groupedIndices[:]
                    timepointTmp = timepointList[:]
                    tmp1 = groupedIndices[i] + groupedIndices[l]
                    tmp2 = self.setUnion(timepointList[i], timepointList[l])
                    IndicesTmp[i] = tmp1
                    timepointTmp[i] = tmp2
                    del IndicesTmp[l]
                    del timepointTmp[l]

                    innerPairings = 0.0
                    betweenPairings = 0.0
                    b_tmp = 0.0
                    a_tmp = 0.0

                    # newInner is (pairings, distance) as returned by
                    # averageInnerDistance
                    newInner = ()
                    newBetweens = []

                    #compute new innerDists and betweenDists
                    #use old dists as often as possible
                    #store new values in intermediate data structures
                    #(newInner, newBetweens)

                    for j in range(len(innerDists)):

                        if j < i:
                            a_tmp += innerDists[j][0]
                            innerPairings += innerDists[j][1]

                            for k in range(len(betweenDists[j])):

                                # entries referring to group i or l are
                                # replaced by the newly computed values
                                if k + j + 1 == i or k + j + 1 == l:
                                    continue

                                betweenPairings += betweenDists[j][k][1]
                                b_tmp += betweenDists[j][k][0]


                        elif j == i:

                            #it's time to have some new computations

                            newInner = self.averageInnerDistance(IndicesTmp[j], timepointTmp[j])
                            a_tmp += newInner[1]
                            innerPairings += newInner[0]

                            # between values of the merged group to
                            # every other group of the new grouping
                            for k in range(len(IndicesTmp)):

                                if k == i:
                                    continue

                                else:
                                    setDiff = self.setDifference(timepointTmp[j], timepointTmp[k])

                                    newBetweens.append((self.averageDistanceBetween(IndicesTmp[j],
                                                                                    IndicesTmp[k],
                                                                                    setDiff),
                                                        len(IndicesTmp[j]) * len(IndicesTmp[k])))

                                    betweenPairings += newBetweens[-1][1]
                                    b_tmp += newBetweens[-1][0]

                        elif i < j < l:

                            a_tmp += innerDists[j][0]
                            innerPairings += innerDists[j][1]

                            for k in range(len(betweenDists[j])):

                                # skip the entry referring to group l
                                if k + j + 1 == l:
                                    continue

                                betweenPairings += betweenDists[j][k][1]
                                b_tmp += betweenDists[j][k][0]

                        elif j == l:
                            # group l is part of the merged group
                            continue

                        else: # j > l

                            a_tmp += innerDists[j][0]
                            innerPairings += innerDists[j][1]

                            # the last group has no between entries
                            if j < len(betweenDists):
                                for k in range(len(betweenDists[j])):

                                    betweenPairings += betweenDists[j][k][1]
                                    b_tmp += betweenDists[j][k][0]


                    if innerPairings >= 1.0:

                        a_tmp /= innerPairings

                    if betweenPairings >= 1.0:

                        b_tmp /= betweenPairings

                    #print "B:", b_tmp, "A:", a_tmp, "B-A", b_tmp - a_tmp, "B/A", b_tmp/a_tmp, "Max Tmp:", max_tmp

                    # without a fixed number of groups a merge must
                    # improve b - a and keep b / a at 80% of the best
                    # ratio seen so far
                    # NOTE(review): b_tmp / a_tmp raises
                    # ZeroDivisionError if a_tmp is 0 here -- confirm
                    # whether that can happen
                    if not noOfGroups and (b_tmp - a_tmp) > max_tmp and (b_tmp / a_tmp) >= 0.8 * max_tmp2:

                        max_tmp = b_tmp - a_tmp
                        candidate = i, l
                        # stored as (distance, pairings) like innerDists
                        candidateInner = (newInner[1], newInner[0])
                        candidateBetween = copy.deepcopy(newBetweens)

                        if (b_tmp / a_tmp) > max_tmp2: #rethink here maybe

                            max_tmp2 = b_tmp/a_tmp

                        print "Candidates for merging found: ", candidate

                    elif noOfGroups:
                        # with a fixed number of groups the first
                        # testable pair is always taken as candidate,
                        # later pairs only if their measure is higher
                        if testable == 1:
                            max_tmp = b_tmp - a_tmp
                            candidate = i, l
                            candidateInner = (newInner[1], newInner[0])
                            candidateBetween = copy.deepcopy(newBetweens)
                            print "Candidates: ", candidate
                        else:
                            if b_tmp - a_tmp > max_tmp:
                                max_tmp = b_tmp - a_tmp
                                candidate = i, l
                                print "Candidates: ", candidate
                                candidateInner = (newInner[1], newInner[0])
                                candidateBetween = copy.deepcopy(newBetweens)


            # still too many groups but nothing mergable: relax the
            # safeMerge limit and try again
            if noOfGroups and candidate == -1:
                if len(groupedIndices) > noOfGroups:
                    safeMerge += 1
                    continue

            if candidate == -1: #no candidate for merging found

                print "No candidate found."
                break

            else:
                print "Collapsing timepoints", timepointTrack[candidate[0]], "and", timepointTrack[candidate[1]]
                # NOTE(review): b_tmp and a_tmp hold the values of the
                # pair evaluated last, which is not necessarily the
                # selected candidate
                print "B:", b_tmp, "A:", a_tmp, "Max Tmp:", max_tmp

                # perform the merge on the passed lists (in place)
                tmp1 = groupedIndices[candidate[0]] + groupedIndices[candidate[1]]
                tmp2 = self.setUnion(timepointList[candidate[0]], timepointList[candidate[1]])
                groupedIndices[candidate[0]] = tmp1
                timepointList[candidate[0]] = tmp2
                timepointTrack[candidate[0]] = timepointTrack[candidate[0]] + timepointTrack[candidate[1]]
                del groupedIndices[candidate[1]]
                del timepointList[candidate[1]]
                del timepointTrack[candidate[1]]

                #compute innerDists and betweenDists

                innerDists[candidate[0]] = candidateInner
                del innerDists[candidate[1]]

                for h in range(len(betweenDists)):

                    if h < candidate[0]:
                        # drop the entry for the removed group and
                        # replace the one for the merged group
                        del betweenDists[h][candidate[1]-h-1]
                        betweenDists[h][candidate[0]-h-1] = candidateBetween[h]


                    elif h == candidate[0]:
                        continue

                    elif candidate[0] < h < candidate[1]:
                        del betweenDists[h][candidate[1]-h-1]

                    else: # h >= candidate[1]:
                        continue


                # row of the merged group: values to all later groups
                betweenDists[candidate[0]] = candidateBetween[candidate[0]:]
                if candidate[1] <= len(betweenDists) - 1:
                    del betweenDists[candidate[1]]

	    if (len(groupedIndices) <= 2):

                print "Only two groups left."
                break


        print "Breaking. Max Tmp:", max_tmp

        return max_tmp, before_collapse


    ################################################################

    ################################################################
    #                                                              #
    # some useful operations put here for not polluting namespaces #
    #                                                              #
    ################################################################
        
    def setDifference(self, set1, set2):
        """computes the symmetric difference of two container
        objects, i.e. the elements contained in exactly one of
        them; the result is a list without duplicates, elements
        of set1 first"""
        exclusive = []
        for source, other in ((set1, set2), (set2, set1)):
            for element in source:
                if element in other or element in exclusive:
                    continue
                exclusive.append(element)

        return exclusive


    def setIntersection(self, set1, set2):
        """computes the elements contained in both container
        objects; the result is a list without duplicates in the
        order of set1"""
        common = []
        for element in set1:
            if element not in set2:
                continue
            if element in common:
                continue
            common.append(element)

        return common


    def setUnion(self, set1, set2):
        """computes the elements contained in at least one of the
        two container objects; the result is a list without
        duplicates, elements of set1 first"""
        merged = []
        for source in (set1, set2):
            for element in source:
                if element in merged:
                    continue
                merged.append(element)

        return merged


    def euclidean_distance(self, seq1, seq2, dimensions=None):
        """computes the distance of two points projected to the
        dimensions given by the container dimensions

        The value is the mean of the squared differences over the
        dimensions (no square root is taken).  If dimensions is
        None, all dimensions of seq1 are used.  Returns 0.0 if
        there is no dimension to compare, which includes empty
        sequences."""
        if dimensions is None:
            dimensions = range(len(seq1))

        # checked after the default has been filled in, so that empty
        # sequences do not lead to a division by zero
        if len(dimensions) == 0:
            return 0.0

        total = 0.0

        for x in dimensions:
            total += (seq1[x] - seq2[x])**2

        return total / len(dimensions)


    def max_distance(self, seq1, seq2, dimensions=None):
        """computes the maximum distance of two
        points projected to the dimensions given by
        the container dimensions"""
        if dimensions is None:
            print "Dimensions None in max_distance."
            dimensions = range(len(seq1))
        
        elif len(dimensions) == 0:
            return 0.0

        max = 0.0
        
        for x in dimensions:
            temp = (seq1[x] - seq2[x])**2 
            if temp > max:
                max = temp

        return max


    def min_distance(self, seq1, seq2, dimensions=None):
        """computes the smallest squared difference of two points
        over the dimensions given by the container dimensions

        If dimensions is None, all dimensions of seq1 are used.
        Returns 0.0 if there is no dimension to compare, which
        includes empty sequences (as max_distance does)."""
        if dimensions is None:
            dimensions = range(len(seq1))

        # checked after the default has been filled in, so that empty
        # sequences do not lead to an IndexError below
        if len(dimensions) == 0:
            return 0.0

        smallest = (seq1[dimensions[0]] - seq2[dimensions[0]])**2

        for x in dimensions[1:]:
            squared = (seq1[x] - seq2[x])**2
            if squared < smallest:
                smallest = squared

        return smallest


    def averageInnerDistance(self, memberList, timepoints=None):
                           
        counter = 0.0
        innerdist = 0.0
        for k in range(len(memberList)):
            for l in range(len(memberList[k+1:])):

                counter += 1.0
                innerdist += self.euclidean_distance(self.sequences[memberList[k]], self.sequences[memberList[l]], timepoints)

        return counter, innerdist


    def maxInnerDistance(self, memberList, timepoints=None):

        innerdist = 0.0
        for k in range(len(memberList)):
            for l in range(len(memberList[k+1:])):
            
                dist = self.euclidean_distance(self.sequences[memberList[k]], self.sequences[memberList[l]], timepoints)

                if dist > innerdist:
                    innerdist = dist

        return innerdist
            
    
    def averageDistanceBetween(self, memberList1, memberList2, timepoints=None):

        betweendist = 0.0
        for k in range(len(memberList1)):
            for l in range(len(memberList2)):

                betweendist += self.euclidean_distance(self.sequences[memberList1[k]], self.sequences[memberList2[l]], timepoints)

        return betweendist


    def maxDistanceBetween(self, memberList1, memberList2, timepoints=None):

        betweendist = 0.0
        for k in range(len(memberList1)):
            for l in range(len(memberList2)):

                dist = self.euclidean_distance(self.sequences[memberList1[k]], self.sequences[memberList2[l]], timepoints)
                
                if dist > betweendist:
                    betweendist = dist

        return betweendist




if __name__ == '__main__':
#asdasd
    import sys

    sqd = sys.argv[1]
    smo = sys.argv[2]
    #modno = int(sys.argv[3])

    #states = []
    #for x in sys.argv[4:]:
    #    states.append(int(x))
    
    precType = ghmm.Float()
    seqSet = ghmm.SequenceSetOpen(precType, sqd)[0]
    
    AnHMMFactory = ghmm.HMMOpenFactory(ghmm.GHMM_FILETYPE_SMO)
    AnHMM = AnHMMFactory(smo, 0)

    VitDecomp = ViterbiDecomposition(seqSet, AnHMM)
    print VitDecomp.computeBestDecomposition()

    
