#!/usr/bin/env python2.3
################################################################################
#
#       This file is part of the GQL (Graphical Query Language) Toolkit
#
#       file:   GQLCluster.py
#       author: Alexander Schoenhuth (schoenhuth@zpr.uni-koeln.de)
#
#       Copyright (C) 2003-2004 Alexander Schliep
#
#       Contact: alexander@schliep.org
#
#       Information: http://ghmm.org/gql
#
#	GQL is free software; you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation; either version 2 of the License, or
#	(at your option) any later version.
#
#	GQL is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with GQL; if not, write to the Free Software
#	Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
#
#
#       This file is version $Revision: 469 $
#                       from $Date: 2004-10-18 11:26:36 -0300 (Mon, 18 Oct 2004) $
#             last change by $Author: filho $.
#
################################################################################
#

import ghmm
import sys
import copy
import pygsl.rng
#import random
import math
#import numarray
import time

class LinearModel:

    """Class representing linear (left-to-right)
    models where (continuous) emission probabilities
    of states are governed by Gaussian laws.
    Models are lists of lists catching up
    a number of time points and some additional
    features. 
    """
    
    def __init__(self, data):

        """self.params[i][0] is mean of state i,
        self.params[i][1] is sigma of state i,
        self.params[i][2] is expectation
        of staying in state i,
        self.params[i][3:] are all values
        catched up by this state from sequences
        gone here, self.weight is the number
        of sequences this model comes from
        """
        
        if isinstance(data, ghmm.EmissionSequence):

            self.weight = 1
            
            self.params = []
            for i in range(len(data)):
                
                self.params.append([])
                self.params[-1].append(data[i])
                self.params[-1].append(0.0)
                self.params[-1].append(1.0)
                self.params[-1].append(data[i])

        elif isinstance(data, ghmm.SequenceSet):

            self.weight = len(data)
            
            self.params = []
            for i in range(data.sequenceLength(0)):

                self.params.append([])
                self.params[-1].append(0.0)
                self.params[-1].append(0.0)
                self.params[-1].append(1.0)
                for seq in data:
                    
                    self.params[-1].append(seq[i])

            self.updateMeans()

        elif isinstance(data, list):

            seqSet = ghmm.SequenceSet(ghmm.Float(), data)
            LinearModel.__init__(self, seqSet)

        elif isinstance(data, LinearModel):

            self.weight = data.weight
            self.params = copy.deepcopy(data.params)

        else:

            print """Error when trying to construct instance
            of LinearModel, argument passed not of suitable
            type"""
            sys.exit()


    def getMu(self, i):
        return self.params[i][0]


    def getSigma(self, i):
        if self.params[i][1] <= 1.0e-15:
            self.computeSigma(i)
        return self.params[i][1]


    def getSelfTransition(self, i):
        return (self.params[i][2] - 1.0) / self.params[i][2]

    def getNoOfTimepoints(self):
        noOfTimepoints = 0
        for i in range(len(self)):
            noOfTimepoints += len(self.params[i][3:])
        return noOfTimepoints
            

    def getDensity(self, i, x):
        return pygsl.rng.gaussian_pdf(self.getMu(i) - x, self.getSigma(i))
        #return random.gauss(self.getMu(i) - x, self.getSigma(i)) 

    
    def __len__(self):
        return len(self.params)

    def __str__(self):
        self.updateSigmas()
        helpstr = "\nWeight: %d\n" % (self.weight)
        for i in range(len(self.params)):
            helpstr += "State No. %d:\n" % (i)
            helpstr += "Mu: %2.2f, Sigma: %2.2f, Self Transition: %2.2f\n" % (self.params[i][0],
                                                                              self.params[i][1],
                                                                              self.params[i][2])
            valuestring = ""
            for x in self.params[i][3:]:
                valuestring += "%2.2f " % (x)
            helpstr += valuestring + "\n"

        return helpstr

        
    def __repr__(self):
        self.updateSigmas()
        helpstr = "\nWeight: %d\n" % (self.weight)
        for i in range(len(self.params)):
            helpstr += "State No. %d:\n" % (i)
            helpstr += "Mu: %2.2f, Sigma: %2.2f, Self Transition: %2.2f\n" % (self.params[i][0],
                                                                              self.params[i][1],
                                                                              self.params[i][2])
            valuestring = ""
            for x in self.params[i][3:]:
                valuestring += "%2.2f " % (x)
            helpstr += valuestring + "\n"

        return helpstr

    def missingDataRepair(self):

        """if missing data is found corresponding
        time point is replaced by interpolation of
        surrounding time points (in case of weight one
        and original length)
        """
        helpfile = open("Initialhelp.txt", 'w')
        if self.getNoOfTimepoints() != len(self):
            helpfile.write( "No. of timepoints:")
            helpfile.write(str(self.getNoOfTimepoints()))
            helpfile.write("Length: %d" % (len(self)))
            print "Missing data repair must be done before any shrinking or merging."
            sys.exit()
        else:
            actualMissings = []
            beforeMissings = None
            afterMissings = None
            for i in range(len(self.params)-1):#last point is 9999.99
                
                if self.params[i][3] <= -9999.0:

                    if not actualMissings and i >= 1: #first missing data point encountered
                        beforeMissings = self.params[i-1][3]
                        if i != len(self.params)-2:
                            actualMissings.append(i)
                        else:
                            self.params[i][3] = self.params[i-1][3]
                        
                    elif actualMissings and i == len(self.params)-2: #last data point is missing
                        actualMissings.append(i)
                        if beforeMissings is not None:
                            for ind in actualMissings:
                                self.params[ind][3] = beforeMissings
                        else: #no data point at all valid
                            for ind in actualMissings:
                                self.params[ind][3] = 0.0
                                
                    else:
                        actualMissings.append(i)
                        
                    
                else: #valid data point encountered
                    
                    if actualMissings: #valid data point encountered and
                                       #missing data before:
                                       #work to do
                        
                        afterMissings = self.params[i][3]
                    
                        if beforeMissings is None:
                            #first valid data point met
                            #give the previous points constant
                            #values
                            for ind in actualMissings:
                                self.params[ind][3] = afterMissings
                            actualMissings = []
                        
                        elif beforeMissings is not None:
                            #do linear interpolation in this case
                            steps = len(actualMissings) + 1
                            for k, ind in enumerate(actualMissings):
                                self.params[ind][3] = beforeMissings + (float(k + 1)/float(steps)) * (afterMissings - beforeMissings)

                        #reset actualMissings
                        actualMissings = []

            self.updateMeans()
                    
        

    def writeMatrices(self):

        """produces the complete parameter set according
        to the usual standards
        """

        self.updateMeans()
        self.updateSigmas()
        self.updateSelfTransitions()
        
        A = []
        B = []
        pi = []

        noOfStates = len(self.params)

        for i in range(noOfStates):

            A.append([0.0] * noOfStates)
            selftrans = self.getSelfTransition(i)
            A[i][i] = selftrans
            if i < noOfStates - 1:
                A[i][i+1] = 1.0 - selftrans
            
            B.append([self.params[i][0], self.params[i][1]])
                     
            if i == 0:
                pi.append(1.0)
            else:
                pi.append(0.0)

        return A, B, pi

    def toMixGHMM(self):
        """this method produces an instance of ghmm.GaussianMixtureHMM
        using an instance of HMMFromMatricesFactory
        XXX needs writeMixMatrices() not yet implemented...
        """
        A, B, pi = self.writeMatrices()

        modelFactory = ghmm.HMMFromMatrices
        model = modelFactory(ghmm.Float(), ghmm.GaussianMixtureDistribution(ghmm.Float()),
                             A, B, pi)

        return model
    
    def toGHMM(self):

        """this method produces an instance of ghmm.GaussianEmissionHMM
        using an instance of HMMFromMatricesFactory
        """
        A, B, pi = self.writeMatrices()
        
        modelFactory = ghmm.HMMFromMatrices
        model = modelFactory(ghmm.Float(), ghmm.GaussianDistribution(ghmm.Float()),
                             A, B, pi)
        return model
                             
        
    def updateMeans(self, *states):

        if not states:
            states = range(len(self))

        for i in states:
            
            muhelp = 0.0
            for value in self.params[i][3:]:
                muhelp += value

            self.params[i][0] = muhelp / len(self.params[i][3:])


    #make sure that before using updateSigmas, means have
    #been updated
    def updateSigmas(self, *states):

        if not states:
            states = range(len(self))

        for i in states:

            sigmahelp = 0.0
            for value in self.params[i][3:]:

                sigmahelp += (self.getMu(i) - value)**2

            if len(self.params[i][3:]) > 1:
                self.params[i][1] = sigmahelp / (len(self.params[i][3:]) - 1)
                if self.params[i][1] <= 1e-15:
                    self.params[i][1] = 0.01

            else:
                self.params[i][1] = 0.01


    def updateSelfTransitions(self, *states):

        if not states:
            states = range(len(self))

        for i in states:

            self.params[i][2] = float(len(self.params[i][3:])) / self.weight


    def mergeStates(self, j, sigmaOption=0):

        """merges states j and j+1
        """

        self.params[j] += self.params[j+1][3:]
        del self.params[j+1]
        self.updateMeans(j)
        self.updateSelfTransitions(j)

        if sigmaOption:
            self.updateSigmas(j)

        return 1

    
    def fastShrink(self, noOfStates):

        """This method merges states whose means come
        closest.
        """

        if len(self) < noOfStates:
            print "Model cannot be shrinked, no of states already",
            print "lower than requested length."
            return None
        
        noOfMerges = len(self) - noOfStates
        for i in range(noOfMerges):

            mindist = abs(self.params[0][0] - self.params[1][0])
            candidate = 0
            for j in range(1, len(self) - 1):

                temp = abs(self.params[j][0] - self.params[j+1][0])
                if temp < mindist:
                    mindist = temp
                    candidate = j

            self.mergeStates(candidate)

        return 1

            
    def shrink(self, noOfStates):

        """This method implements the horizontal merge
        technique described in Olof Perssons work.
        """

        pass


    def approxDist(self, other, carrierOpt=0):

        """This method computes some heuristically
        determined but fast computable distance between
        two models having the same length.
        """

        if len(self) != len(other):
            print "Distance not computable as models don't have the same length."
            return None
        
        shift = 0.0
        diff = 0.0
        carrier = 0.0
        weight = float(self.weight + other.weight) / 2.0
        for i in range(len(self)):

            if carrierOpt:
                carrier += self.params[i][3] - other.params[i][3]
                shift += abs(carrier)
            else:
                shift += abs(self.params[i][3] - other.params[i][3])
                
            diff += (self.params[i][0] - other.params[i][0])**2

        dist = weight*(diff + 0.1*shift)
        #dist = diff + 0.1*shift
        return dist


    def expIntDist(self, other):

        """This method returns the differences between integrals
        of expoential distributions of self transitions.
        """

        if len(self) != len(other):
            print "Distance not computable as models don't have the same length."
            return None

        expdist = 0.0
        for i in range(len(self)):

            lambda1 = 1.0 / self.params[i][2]
            lambda2 = 1.0 / other.params[i][2]
            if abs(lambda1 - lambda2) <= 1.0e-15:
                continue
            else:
                x = math.log(lambda1/lambda2) / (lambda1 - lambda2)
                
                if lambda1 > lambda2:
                    expdist += 2.0*(math.exp(-lambda2*x)-math.exp(-lambda1*x))
                else:
                    expdist += 2.0*(math.exp(-lambda1*x)-math.exp(-lambda2*x))

        return expdist
    

    def expMutInfo(self, other):

        """This method computes the sum of the mutual informations
        of the transition probability distributions of corresponding
        states of self and other where the transition probability
        distributions are assumed to be exponential.
        """
        
        if len(self) != len(other):
            print "Distance not computable as models don't have the same length."
            return None

        mutinfo = 0.0
        for i in range(len(self)):

            mutinfo += (math.log(other.params[i][2] / self.params[i][2]) +
                        self.params[i][2] * (1.0/other.params[i][2] - 1.0/self.params[i][2]))

        return mutinfo

    def symmExpMutInfo(self, other):

        """This returns expMutInfo(self, other) + expMutInfo(other, self).
        """
        if len(self) != len(other):
            print "Distance not computable as models don't have the same length."
            return None

        mutinfo = 0.0
        for i in range(len(self)):

            mutinfo += (self.params[i][2]*other.params[i][2]) * (1.0/self.params[i][2] - 1.0/other.params[i][2])**2

        return mutinfo

    def transLoss(self, other):

        """Computes the approximate loss of likelihood
        due to differences in transition probabilities.
        """

        if len(self) != len(other):
            print "Distance not computable as models don't have the same length."
            return None

        transLoss = 0.0
        for i in range(len(self)):

            transLoss += math.log(self.getSelfTransition(i)**(self.params[i][2]-1.0) *
                             (1.0 - self.getSelfTransition(i)))
            transLoss -= math.log(other.getSelfTransition(i)**(self.params[i][2]-1.0) *
                             (1.0 - other.getSelfTransition(i)))

        return transLoss

    def transLoss(self, other):

        """Computes the approximate loss of likelihood
        due to differences in transition probabilities.
        """

        if len(self) != len(other):
            print "Distance not computable as models don't have the same length."
            return None

        transLoss = 0.0
        for i in range(len(self)):

            if self.getSelfTransition(i) >= 1.0e-323:
                transLoss += ((self.params[i][2]-1.0) * math.log(self.getSelfTransition(i)) +
                              math.log(1.0 - self.getSelfTransition(i)))
            else:
                transLoss += math.log(1.0 - self.getSelfTransition(i))
                
            if other.getSelfTransition(i) >= 1.0e-323:
                transLoss -= ((self.params[i][2]-1.0) * math.log(other.getSelfTransition(i)) +
                              math.log(1.0 - other.getSelfTransition(i)))
            else:
                transLoss -= math.log(1.0 - other.getSelfTransition(i))

        return transLoss

    def meanDist(self, other):

        """Computes a non-symmetric distance with respect
        to the means of emission probabilities.
        """

        if len(self) != len(other):
            print "Distance not computable as models don't have the same length."
            return None

        meandist = 0.0
        for i in range(len(self)):

            meandist += self.params[i][2] * (self.getMu(i) - other.getMu(i))**2

        return meandist
    
    def approxLogLike(self, other, transWeight=1.0):

        """This method computes an approximation of
        the likelihood that self has been produced by
        other. Transition probability distributions are
        modeled by exponential distributions.
        Regimen in state i is thus: lambda * exp(-lambda*x) where
        lambda == 1 / other.params[i][2]
        Note: this is obviously not symmetric.
        """
        
        if len(self) != len(other):
            print "Distance not computable as models don't have the same length."
            return None

        loglike = transWeight * self.expIntDist(other)
        loglike += self.meanDist(other)

        return loglike


    def symmetricLogLike(self, other, funcoption, transWeight=1.0):

        """This method is a weighted symmetrization of approxLogLike above.
        """

        symmloglike = self.weight * (self.meanDist(other))
        symmloglike += other.weight * (other.meanDist(self))
        if funcoption == 0:
            symmloglike += transWeight * (other.weight + self.weight) * self.expIntDist(other)
        elif funcoption == 1:
            symmloglike += transWeight * (other.weight + self.weight) * self.symmExpMutInfo(other)
        elif funcoption == 2:
            symmloglike += transWeight * self.weight * self.transLoss(other)
            symmloglike += transWeight * other.weight * other.transLoss(self)

        #symmloglike /= (self.weight + other.weight)
        
        return symmloglike


    def symmLogLikeMutInfo(self, other, transWeight=1.0):

        return self.symmetricLogLike(other, 1, transWeight)


    def symmLogLikeIntDist(self, other, transWeight=1.0):

        return self.symmetricLogLike(other, 0, transWeight)


    def symmLogLikeLoss(self, other, transWeight=1.0):

        return self.symmetricLogLike(other, 2, transWeight)

        
    def bayesDist(self, other):

        """This method computes a distance between two
        models meaning how well the original models would
        be represented by a common merged model.
        """

        return 1.0


    def merge(self, other, sigmaOption=0):

        """This method merges a model with another one.
        """

        if len(self) != len(other):
            print "Models cannot be merged as they don't have the same length."

        self.weight += other.weight

        for i in range(len(self)):

            self.params[i] += other.params[i][3:]
            
        self.updateMeans()
        self.updateSelfTransitions()
        if sigmaOption:
            self.updateSigmas()

        return 1


    def __add__(self, other):

        """Merges as well but returns new instance.
        """

        C = copy.deepcopy(self)
        C.merge(other)
        return C


class InitialCollection:

    """Controller class for clustering linear models
    hierarchically using different distance measures.

    Pairwise distances are kept as an upper-triangular list of lists:
    ``self.distanceList[i][j]`` is the distance between ``self.models[i]``
    and ``self.models[i + j + 1]``. Row i therefore has
    ``len(self.models) - i - 1`` entries, and there are
    ``len(self.models) - 1`` rows.
    """

    def __init__(self, modelSet, distOption=0, transWeight=1.0, noOfClusters=20):

        """Set up clustering of `modelSet` and precompute all pairwise distances.

        modelSet     -- list of at least two linear models of equal length
        distOption   -- 0: approxDist, 1: symmLogLikeIntDist,
                        2: symmLogLikeMutInfo, 3: symmLogLikeLoss
        transWeight  -- passed as third argument to the distance function
        noOfClusters -- hierarchicalCluster stops merging at this many models
        Invalid input prints an error and terminates the program.
        """
        if len(modelSet) < 2:
            print("No clustering possible, have too little sequences.")
            sys.exit()

        for x in modelSet[1:]:
            if len(x) != len(modelSet[0]):
                print("Clustering not possible, have models of different lengths.")
                sys.exit()

        distFunctions = {0: LinearModel.approxDist,
                         1: LinearModel.symmLogLikeIntDist,
                         2: LinearModel.symmLogLikeMutInfo,
                         3: LinearModel.symmLogLikeLoss}
        if distOption not in distFunctions:
            # Previously self.dist stayed unset and failed later.
            print("Clustering not possible, unknown distance option %s." % (distOption,))
            sys.exit()
        # NOTE(review): approxDist's third parameter is carrierOpt, so for
        # distOption 0 transWeight is effectively used as that flag -- confirm.
        self.dist = distFunctions[distOption]

        self.bayesDist = LinearModel.bayesDist

        self.models = modelSet

        self.noOfClusters = noOfClusters

        self.transWeight = transWeight

        self.distanceList = []
        for i in range(len(self.models) - 1):
            if i % 20 == 0:
                print("Distances for %d models computed." % (i))
            row = []
            for j in range(i + 1, len(self.models)):
                row.append(self.dist(self.models[i], self.models[j], self.transWeight))
            self.distanceList.append(row)
        print("Distance list initialized.")


    def determineCandidates(self):

        """Return ((i, j), distance) for the closest pair of models, with
        i < j. Ties go to the first pair in row-major order. Returns
        (None, None) when no pair is left.
        """

        candidates = None
        shortestDist = None

        for i in range(len(self.distanceList)):
            row = self.distanceList[i]
            for j in range(len(row)):
                if shortestDist is None or row[j] < shortestDist:
                    shortestDist = row[j]
                    # row i, column j stores the pair (i, i + j + 1)
                    candidates = (i, i + j + 1)

        return candidates, shortestDist


    def mergeModels(self, k, l):

        """Merge model l into model k (in either order) and update the
        distance list. Returns 1, or None if k == l.
        """

        if k == l:
            print("Nothing to be done. Model cannot be merged with itself.")
            return None
        elif k > l:
            k, l = l, k

        # update the model set itself
        self.models[k].merge(self.models[l])
        del self.models[l]

        # row l holds distances from model l to later models
        if l < len(self.distanceList):
            del self.distanceList[l]

        for i in range(l):

            if i < k:
                # index k-i-1 < l-i-1, so recomputing first is safe
                self.distanceList[i][k - i - 1] = self.dist(self.models[i], self.models[k], self.transWeight)
                del self.distanceList[i][l - i - 1]

            elif i == k:
                del self.distanceList[i][l - i - 1]
                for j in range(len(self.distanceList[i])):
                    self.distanceList[i][j] = self.dist(self.models[i], self.models[j + i + 1], self.transWeight)

            else:  # k < i < l
                del self.distanceList[i][l - i - 1]

        # When l was the last model, the former last row is now empty;
        # drop it to keep len(distanceList) == len(models) - 1.
        if self.distanceList and not self.distanceList[-1]:
            del self.distanceList[-1]

        return 1

    def hierarchicalCluster(self, stopCrit, bayesOpt=0):

        """Merge the closest models until only noOfClusters remain, no pair
        is left, or (with bayesOpt) bayesDist of the closest pair exceeds
        stopCrit. Returns 1.
        """

        while len(self.models) > self.noOfClusters:
            if len(self.models) > 1500 and len(self.models) % 5 == 0:
                print("%d models remaining..." % (len(self.models)))
            elif len(self.models) > 700 and len(self.models) % 20 == 0:
                print("%d models remaining..." % (len(self.models)))
            elif len(self.models) % 50 == 0:
                print("%d models remaining..." % (len(self.models)))

            candidates, shortestDist = self.determineCandidates()
            if candidates is None:
                break

            if bayesOpt and (self.bayesDist(self.models[candidates[0]], self.models[candidates[1]])
                             > stopCrit):
                break

            self.mergeModels(candidates[0], candidates[1])

        print("Ready with learning initial collection.")

        return 1


    def toGHMMs(self):

        """Return a list of ghmm models, one per cluster (via toGHMM())."""

        return [model.toGHMM() for model in self.models]


    def __str__(self):
        """Concatenate the string form of every model."""
        return "".join([str(model) for model in self.models])


    def __repr__(self):
        """Same text as __str__."""
        return self.__str__()

    
if __name__ == '__main__':

    # Usage: GQLCluster.py <sequence file> <transition weight>
    if len(sys.argv) < 3:
        print("Usage: %s <sequence file> <transition weight>" % (sys.argv[0]))
        sys.exit(1)

    sqd = sys.argv[1]
    precType = ghmm.Float()
    seqSet = ghmm.SequenceSetOpen(precType, sqd)

    # float() replaces the deprecated string.atof (gone in Python 3)
    transWeight = float(sys.argv[2])

    # One linear model per sequence of the first set in the file, each
    # shrunk to 6 states (fastShrink leaves shorter models unchanged, and
    # InitialCollection requires equal lengths).
    linearModels = []
    for seq in seqSet[0]:
        model = LinearModel(seq)
        model.fastShrink(6)
        linearModels.append(model)

    # distOption 2 selects LinearModel.symmLogLikeMutInfo
    clustering = InitialCollection(linearModels, 2, transWeight)
    clustering.hierarchicalCluster(2.0)
    print(clustering)
    #ghmmlist = clustering.toGHMMs()
    #for model in ghmmlist:
    #    print model
    



