#!/usr/bin/env python2.3
################################################################################
#
#       This file is part of the GQL (Graphical Query Language) Toolkit
#
#       file:   GQLQuery.py
#       author: Alexander Schliep (alexander@schliep.org) and
#               Wasinee Rungsarityotin
#
#       Copyright (C) 2003-2004 Alexander Schliep
#
#       Contact: alexander@schliep.org
#
#       Information: http://ghmm.org/gql
#
#   GQL is free software; you can redistribute it and/or modify
#   it under the terms of the GNU General Public License as published by
#   the Free Software Foundation; either version 2 of the License, or
#   (at your option) any later version.
#
#   GQL is distributed in the hope that it will be useful,
#   but WITHOUT ANY WARRANTY; without even the implied warranty of
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#   GNU General Public License for more details.
#
#   You should have received a copy of the GNU General Public License
#   along with GQL; if not, write to the Free Software
#   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
#
#
#       This file is version $Revision: 2146 $
#                       from $Date: 2009-11-16 19:28:10 -0300 (Mon, 16 Nov 2009) $
#             last change by $Author: schliep $.
#
################################################################################


import copy
import sets
import ghmmwrapper
from ghmm import *
import numpy as Numeric

class GQLQuery:
    """ A GQLQuery is a container (models) for the query
        formulation and the results.  It keeps track of listeners (views) for
        result updates.

        Controllers use the set... functions.

        self.allRanked contains all the sequence ids ranked by decreasing
        likelihood (NOTE: by GHMM convention a likelihood of 0.0 is
        understood to be -infinity).

        self.queryResult contains the subset of sequence ids ranked
        above the cutoff (i.e. with ranks lower than the cutoff).

    """

    def __init__(self):
        self.queryLength = None

        self.allRanked = None # All the sequences ranked by the HMM
        self.queryResult = sets.ImmutableSet() # The displayed set
        self.rankCutoff = 0 # Larger cutoff number == lower rank
        # Rank from 0 to n-1
        self.maxRank = 0
        self.data = None

        self.listeners = []
        self.shmm = None

        self.cyclic_prob = 0.5

        self.likelihood = None
        self.viterbiPaths = []
        self.queryModified = 0
        self.missingRate = 0.0001 # assume a low probabilitie of missing data as default ... 
        self.data = None
        


    def newQuery(self, queryLength, mean, variance, duration, cyclic,densityType=[]):
        """ Define a new query, build its model and rank the data with it.

            queryLength    number of emitting states
            mean           vector of mean values for queryLength states
            variance       vector of variance values for queryLength states
            duration       vector of durations for queryLength states
            cyclic         indicator whether the whole HMM cycles
            densityType    vector of emission density types; an empty list
                           (the default) selects normal emissions

            The vectors are shallow-copied, so later changes made by the
            caller to its own lists do not affect the query.
        """
        # Throw away whatever a previous query left behind
        if self.queryLength is not None:
            self.shmm = None
            for stale in ('mean', 'variance', 'duration', 'densityType'):
                delattr(self, stale)

        self.queryLength = queryLength
        self.cyclic = cyclic
        self.mean = copy.copy(mean)
        self.variance = copy.copy(variance)
        self.duration = copy.copy(duration)

        if densityType == []:
            # normal emission by default; note this list gets
            # queryLength+1 entries
            self.densityType = [ghmmwrapper.normal]*(queryLength+1)
        else:
            self.densityType = copy.copy(densityType)

        # The missing data rate comes from the data set, once one is loaded
        if (self.data != None):
            self.missingRate = self.data.missingRate

        self.newSHMM()
        self.runQuery()


    def openQuery(self, fileName, fileType):
        """ Read a model from fileName and make it the current query.

            fileType 'xml' selects readXML(), any other value readSHMM().
            If the reader returns None, the current query is left untouched.
        """
        if fileType == 'xml':
            reader = self.readXML
        else:
            reader = self.readSHMM

        model = reader(fileName)
        if model != None:
            del self.shmm
            self.shmm = model
            self.loadQuery(model)


    def loadQuery(self,model):
        """ Take over model as the current query.

            The per-state mean, variance, density type and duration are read
            back from the model (the last state of the model is the 'end'
            state and is skipped), the cyclic flag is derived from the
            transition last query state -> first state, and finally the
            model is rebuilt with newSHMM() and the query is run.
        """
        self.shmm = model
        self.mean, self.variance = [], []
        self.duration, self.densityType = [], []

        self.queryLength = self.shmm.N - 1 # Correct for 'end' state

        for state in range(self.queryLength):
            if isinstance(self.shmm, ContinuousMixtureHMM):
                emission = self.shmm.getEmission(state, 0)
                if emission[0] in (ghmmwrapper.normal, ghmmwrapper.uniform):
                    mu = emission[1]
                    sigma = emission[2]
                    kind = emission[0]
                else:
                    # NOTE(review): after this message the loop carries on
                    # with the mu/sigma/kind of the previous state (or
                    # fails with a NameError for the first state) --
                    # confirm what should happen here.
                    print('Emission type not suported')
            elif isinstance(self.shmm, GaussianMixtureHMM):
                (mu, sigma, weight) = self.shmm.getEmission(state, 0)
                kind = ghmmwrapper.normal
            else:
                (mu, sigma) = self.shmm.getEmission(state)
                kind = ghmmwrapper.normal

            self.mean.append(mu)
            self.variance.append(sigma)
            self.densityType.append(kind)

            # Duration is derived from the self-transition probability
            stay = self.shmm.getTransition(state, state)
            if stay == 1.0:
                self.duration.append(9999.99) # big self duration!!!
            else:
                self.duration.append(1.0 / (1.0 - stay))

        # check this code
        backToFirst = self.shmm.getTransition(self.queryLength-1, 0)
        if backToFirst > 0.0:
            self.cyclic = 1
        else:
            self.cyclic = 0

        # rebuild the model, now with missing data emission ...
        self.newSHMM()
        self.runQuery()


    def saveQuery(self, fileName):
        """ Save the current query model to fileName; delegates to writeSHMM(). """
        target = fileName
        self.writeSHMM(target)


    def setData(self, data):
        """ Attach a data set to the query and rank it.

            We presume everything is shown at this point. That means
            rankCutoff equals maxRank and the queryResult is the whole set.
            As the result set is filled in before runQuery() is called, the
            listeners receive an update with nothing to show and nothing to
            hide.
        """
        self.data = data

        # because of the missing data rate, we have to rebuild the models
        self.missingRate = data.missingRate
        self.newSHMM()

        # Cutoff at the very end: every sequence is part of the result
        self.maxRank = self.rankCutoff = len(data)
        self.allRanked = self.rankByLikelihood()
        self.queryResult = sets.ImmutableSet(self.allRanked)

        self.runQuery()
        

    def setRankCutoff(self,newCutoff):
        """ Set the rank cutoff to the newCutoff """
        if self.data == None: # Cant have cutoff, if we dont have data
            return

        if newCutoff < self.rankCutoff: # Get more specific
            show = []
            hide = self.allRanked[newCutoff:self.rankCutoff]
            self.queryResult = self.queryResult.difference(sets.ImmutableSet(hide))

        elif newCutoff > self.rankCutoff: # Get less specific
            show = self.allRanked[self.rankCutoff:newCutoff]
            hide = []
            self.queryResult = self.queryResult.union(sets.ImmutableSet(show))

        else: # Nothing changed, nothing to do
            return

        self.updateResult(show,hide)
        self.rankCutoff = newCutoff


    def runQuery(self):
        self.allRanked = self.rankByLikelihood()
        newQueryResult = sets.ImmutableSet(self.allRanked[0:self.rankCutoff])
        hide = self.queryResult.difference(newQueryResult)
        show = newQueryResult.difference(self.queryResult)
        self.queryResult = newQueryResult
        self.updateResult(show, hide)
        self.queryModified = 1

    def updateResult(self, showProfiles, hideProfiles):
        for l in self.listeners:
            l.update(showProfiles, hideProfiles)

    def addListener(self, listener):
        self.listeners.append(listener)

    def removeListener(self, listener):
        self.listeners.remove(listener)

    def Result(self):
        return self.queryResult

    def LowestRankedLikelihood(self):
        if self.likelihood != None and self.allRanked != None:
            p = self.allRanked[self.rankCutoff-1]
            return self.likelihood[p]
        else:
            return 0.0 # Should not get here


    def Likelihood(self, profileNr):
        if self.likelihood != None:
            return self.likelihood[profileNr]
        else:
            return 0.0 # Should not get here


    def newSHMM(self):
        """ (Re)build self.shmm from the current query parameters.

            The model has N = queryLength + 1 states: the query states
            followed by an absorbing 'end' state. Query state i stays in
            itself with probability 1 - 1/duration[i] and otherwise moves
            on to state i+1. The last query state either moves to the end
            state or, in cyclic mode, splits its leaving probability
            between the first query state (share cyclic_prob) and the end
            state.

            Every state gets two emission components: component 0 carries
            the query mean/variance with weight 1-missingRate, component 1
            models missing data (mean -9999.99) with weight missingRate.

            Reads queryLength, duration, mean, variance, densityType,
            cyclic, cyclic_prob and missingRate; replaces self.shmm.
        """
        # Release the old model but keep the attribute, so the object stays
        # usable if building the new model fails
        self.shmm = None

        N = self.queryLength + 1 # Correct for 'end' state
        last = N - 2             # index of the last query state
        end = N - 1              # index of the end state

        # Build the transition matrix
        A = Numeric.zeros((N,N), float)

        for i in range(last):
            stay = 1.0 - 1.0/self.duration[i]
            A[i,i] = stay
            A[i,i+1] = 1.0 - stay

        # Self-transition of the last query state, from its OWN duration.
        # (Taking the value left over from the loop above would use the
        # duration of the state before it, and would be undefined for a
        # query with a single state.)
        p = 1.0 - 1.0/self.duration[last]

        # NOTE: To allow switching on cyclic mode, we *always* need
        # the back-transition (last,0) to be present. That is from the
        # last query state we can go to itself, the end state or the
        # first query state.
        # NOTE(review): for a single-state query (last == 0) the
        # back-transition and the self-transition are the same matrix
        # entry, so the second assignment overwrites the first -- confirm
        # that this is intended.
        q = 1.0 - p
        A[last,0] = self.cyclic_prob * q
        A[last,last] = p
        A[last,end] = (1.0-self.cyclic_prob)* q

        A[end,end] = 1.0 # End state

        # Emission parameters, indexed [state, row, component]: row 0 holds
        # the means, row 1 the variances, row 3 the component weights.
        # Row 2 is always zero -- presumably an unused density parameter;
        # TODO confirm against the GHMM matrix format.
        B = Numeric.zeros((N,4,2), float)
        for i in range(N - 1):
            B[i,0,0] = self.mean[i]
            B[i,0,1] = -9999.99 # missing data
            B[i,1,0] = self.variance[i]
            B[i,1,1] = 0.0001
            B[i,2,0] = 0.0
            B[i,2,1] = 0.0
            B[i,3,0] = 1-self.missingRate
            B[i,3,1] = self.missingRate

        B[end,0,0] = 9999.99 # END symbol, only produced in last state
        B[end,0,1] = -9999.99 # missing data
        B[end,1,0] = 0.01
        B[end,1,1] = 0.01
        B[end,2,0] = 0.0
        B[end,2,1] = 0.0
        B[end,3,0] = 1
        B[end,3,1] = 0

        # The model always starts in the first query state
        pi = Numeric.zeros(N, float)
        pi[0] = 1.0

        # One [query density, missing data density] pair per entry of
        # densityType, plus one pair for the end state
        densities = [[kind, ghmmwrapper.normal] for kind in self.densityType]
        densities.append([ghmmwrapper.normal,ghmmwrapper.normal])

        # Trace of the model parameters on stdout
        print("%s %s %s %s" % (A, B, pi, densities))

        self.shmm = HMMFromMatrices(Float(),ContinuousMixtureDistribution(Float()),
                                    A, B, pi,densities=densities)

        # Set proper transition probabilities out of last query state
        if not self.cyclic:
            self.shmm.setTransition(last, 0, 0.0) # to first query state
            self.shmm.setTransition(last, last, p) # self-transition
            self.shmm.setTransition(last, end, 1.0 - p) # to end state

        for i in range(N):
            self.shmm.setMixtureFix(i,[0,1])
        # the last state is also fixed
        self.shmm.setStateFix(end,1)


    def deleteSHMM(self):
        self.freeSHMM()

    def readSHMM(self, filename):
        """ Read a model from a file in text format (not XML).
            Returns whatever HMMOpen(filename, 0) returns. """
        return HMMOpen(filename,0)

    def readXML(self, filename):
        """ Read a model from a file in XML format.
            Returns the result of HMMOpenXML(filename, 0), or None if
            nothing could be read. """
        #XXX - later on add support for choosing the model
        loaded = HMMOpenXML(filename,0)
        if not (loaded == None):
            return loaded
        return None

    def writeSHMM(self, filename):
        """ Write the current model to filename via the model's write()
            method (text format). """
        model = self.shmm
        model.write(filename)

    def rankByLikelihood(self):
        if self.data is None:
            return []
        seqs = self.data.GHMMProfileSet() # Returns a ghmm SequenceSet object
        self.likelihood = self.shmm.loglikelihoods(seqs)
        indices = range(len(self.likelihood))
        return SortOnItem(indices, self.likelihood)

    def setDuration(self,k,newDuration):
        """ Change the duration of query state k and re-run the query.

            The self-transition of state k becomes 1 - 1/newDuration. For
            an inner state the rest goes to state k+1. For the last query
            state the rest goes to the end state or, in cyclic mode, is
            split between the first query state (share cyclic_prob) and
            the end state.
        """
        self.duration[k] = newDuration
        stay = 1.0 - 1.0/self.duration[k]

        model = self.shmm
        lastQueryState = model.N - 2

        if k < lastQueryState:
            model.setTransition(k, k, stay)       # self-transition
            model.setTransition(k, k+1, 1-stay)   # next state
        elif k == lastQueryState:
            leave = 1.0 - stay # probability of leaving state k
            if self.cyclic == 0:
                toFirst, toEnd = 0.0, leave
            else:
                # share cyclic_prob goes to first, remainder to end
                toFirst = self.cyclic_prob * leave
                toEnd = (1.0-self.cyclic_prob) * leave
            model.setTransition(k, 0, toFirst) # to first query state
            model.setTransition(k, k, stay)    # self-transition
            model.setTransition(k, k+1, toEnd) # to end state
        else:
            model.setTransition(k, k, 1.0) # Should we ever get here?

        self.runQuery()

    def setMean(self,k,newMean):
        """ Set the mean of query state k; k outside the query states is
            ignored.

            The value is always stored in self.mean, but the model is only
            updated (and the query re-run) if the state does not have a
            uniform density.
        """
        if not (k < self.shmm.N - 1):
            return

        self.mean[k] = float(newMean)
        if (self.densityType[k] != ghmmwrapper.uniform):
            params = (self.mean[k], self.variance[k], 0, 1-self.missingRate)
            self.shmm.setEmission(k, 0, self.densityType[k], params)
            self.runQuery()

    def setVariance(self,k,newVariance):
        """ Set the variance of query state k; k outside the query states
            is ignored.

            The value is always stored in self.variance, but the model is
            only updated (and the query re-run) if the state does not have
            a uniform density.
        """
        if not (k < self.shmm.N - 1):
            return

        self.variance[k] = float(newVariance)
        if (self.densityType[k] != ghmmwrapper.uniform):
            params = (self.mean[k], self.variance[k], 0, 1-self.missingRate)
            self.shmm.setEmission(k, 0, self.densityType[k], params)
            self.runQuery()

    def setDensity(self,k,newDensity):
        """ Set the emission density type of query state k and re-run the
            query; k outside the query states is ignored.

            A uniform density is installed with the fixed parameters
            50.0 / -50.0; any other density uses the stored mean and
            variance of the state.
        """
        if not (k < self.shmm.N - 1):
            return

        self.densityType[k] = newDensity
        if (newDensity == ghmmwrapper.uniform):
            params = (50.0, -50.0, 0, 1-self.missingRate, 0)
        else:
            params = (self.mean[k], self.variance[k], 0, 1-self.missingRate)

        self.shmm.setEmission(k, 0, self.densityType[k], params)
        self.runQuery()
            
    def setCyclic(self,cyclic):
        """ Switch cyclic mode on (cyclic != 0) or off (cyclic == 0) and
            re-run the query.

            Only the transitions out of the last query state change:
            last-to-first, last-to-last and last-to-end.
        """
        self.cyclic = cyclic

        last = self.queryLength - 1 # indexing from 0
        stay = 1.0 - 1.0/self.duration[last]
        leave = 1.0 - stay # probability of leaving the last query state

        if cyclic == 0:
            toFirst, toEnd = 0.0, leave
        else:
            # share cyclic_prob goes to first, remainder to end
            toFirst = self.cyclic_prob * leave
            toEnd = (1.0-self.cyclic_prob)* leave

        self.shmm.setTransition(last, 0, toFirst)    # to first query state
        self.shmm.setTransition(last, last, stay)    # self-transition
        self.shmm.setTransition(last, last+1, toEnd) # to end state

        self.runQuery()

    def isModified(self):
            if self.queryModified:
                if self.data != None:
                    self.viterbiPaths = self.shmm.viterbi(self.data.GHMMProfileSet())
                self.queryModified = 0

def SortOnItem(list, sortKey):
    """ Return the elements of list ordered by decreasing sortKey[element].

        Elements with equal keys come out in decreasing order of the
        elements themselves, because (key, element) pairs are compared.
    """
    decorated = sorted([(sortKey[element], element) for element in list])
    decorated.reverse()
    return [element for (key, element) in decorated]

