#!/usr/bin/env python2.3
################################################################################
#
#       This file is part of the GQL (Graphical Query Language) Toolkit
#
#       file:   GQLQMixture.py
#       author: Alexander Schliep (alexander@schliep.org) and
#
#       Copyright (C) 2003-2004 Alexander Schliep
#                                   
#       Contact: alexander@schliep.org           
#
#       Information: http://ghmm.org/gql
#
#   GQL is free software; you can redistribute it and/or modify
#   it under the terms of the GNU General Public License as published by
#   the Free Software Foundation; either version 2 of the License, or
#   (at your option) any later version.
#
#   GQL is distributed in the hope that it will be useful,
#   but WITHOUT ANY WARRANTY; without even the implied warranty of
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#   GNU General Public License for more details.
#
#   You should have received a copy of the GNU General Public License
#   along with GQL; if not, write to the Free Software
#   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
#
#
#       This file is version $Revision: 1211 $ 
#                       from $Date: 2006-10-16 16:53:45 -0300 (Mon, 16 Oct 2006) $
#             last change by $Author: filho $.
#
################################################################################
from ghmm import *
import numpy as Numeric
import math
import getopt, sys, string
import copy

# Ben: import functions in non-parallel GQLMixture
from GQLMixture import *

# Ben: import pypar module
try:
    import pypar
except ImportError:
    # String exceptions (raise 'text') are rejected by Python >= 2.6 with an
    # unrelated TypeError, which would hide the real problem. Raise a proper
    # ImportError carrying the same message instead, and only react to a
    # failed import rather than to every possible exception.
    raise ImportError('Module pypar must be present to run parallel')

# Ben: for timing only
import time

def estimate_mixture_parallel(models, seqs, max_iter, eps, fixed_models, alpha=None):
    """ Given a Python-list of models and a SequenceSet seqs
        perform a nested EM to estimate maximum-likelihood
        parameters for the models and the mixture coefficients,
        distributing the per-model work over several MPI processes.

        Every process scores and re-estimates only the models whose
        indices are in the module-level global mpi_myrange. The master
        process (mpi_myid == mpi_masterid) collects the likelihood
        columns from the processes listed in mpi_slaves (column ranges
        are taken from mpi_partitioning), computes the posteriors and
        the mixture coefficients, and tells the other processes whether
        to continue. These mpi_* globals must be set before calling.

        The iteration stops if max_iter is 0 (after one scoring pass),
        once the iteration counter exceeds max_iter, or if the
        improvement in log-likelihood is non-negative and less than eps.

        models       -- list of models; each must offer loglikelihoods(seqs)
                        and baumWelch(seqs, nr_steps, threshold)
        seqs         -- sequence set; must offer len() and setWeight(i, w)
        max_iter     -- maximal number of iterations
        eps          -- likelihood improvement below which training stops
        fixed_models -- container of model indices whose model parameters
                        and mixture coefficient are left untouched
        alpha        -- sequence of len(models) mixture coefficients. If
                        alpha is not given, uniform values will be chosen.

        Result: Models owned by the calling process are changed in place.
        NOTE(review): re-estimated models are not sent back to the master
        by this function, so each process only holds updated versions of
        its own models.

        On the master the return value is (l, alpha, P) where l is the
        final log likelihood of seqs under the mixture, alpha is a list
        of len(models) mixture coefficients and P is a
        (#sequences x #models)-matrix containing P[model j| sequence i].
        All other processes return (None, None, None).

    """
    minus_infinity = -float('Inf')
    last_mixture_likelihood = None
    iteration = 1

    # The (nr of seqs x nr of models)-matrix holding the likelihoods
    l = Numeric.zeros((len(seqs), len(models)), float)

    # 'is None' rather than '== None': comparing an array with == is
    # elementwise and cannot be used as a truth value.
    if alpha is None:  # Uniform alpha
        logalpha = Numeric.ones(len(models), float) * math.log(1.0 / len(models))
    else:
        logalpha = Numeric.log(alpha)

    print("%s %s" % (logalpha, Numeric.exp(logalpha)))
    log_nrseqs = math.log(len(seqs))

    while 1:
        # Score all sequences with the models owned by this process
        for i in mpi_myrange:
            # NOTE: loglikelihoods might contain -Inf for sequences which
            # cannot be built. l[:,i] is the i-th column of l
            l[:, i] = Numeric.array(models[i].loglikelihoods(seqs))

        # Collect results at master
        if mpi_myid == mpi_masterid:
            for proc in mpi_slaves:
                l[:, mpi_partitioning[proc][0]:mpi_partitioning[proc][1]] = pypar.receive(proc)
        else:
            pypar.send(l[:, mpi_partitioning[mpi_myid][0]:mpi_partitioning[mpi_myid][1]], mpi_masterid)

        # Now, do this only on the master
        if mpi_myid == mpi_masterid:
            # l[i,k] = log( a_k * P[seq i| model k]); the row vector logalpha
            # is broadcast over all rows. Leaves -Inf values unchanged.
            l += logalpha

            # Compute P[model j| seq i]
            mixture_likelihood = 0.0
            for row in l:
                # We want sum_k a_k P[seq i| model k] from the values
                # log( a_k * P[seq i| model k]) we have.
                # NOTE: sumlogs returns the sum for values != -Inf
                seq_logprob = sumlogs(row)
                # Subtracting the log of the sum divides by the sum and
                # yields a prob. dist.; row is a view, so l is updated.
                row -= seq_logprob  # row = ( log P[model j | seq i] )
                mixture_likelihood += seq_logprob

            # NOTE: exp gives underflow warnings for very small arguments.
            # Everything at or below -4.8e2 is set to -Inf manually, so
            # that it becomes exactly 0.0 in l_exp.
            l = Numeric.where(l > -4.8e2, l, minus_infinity)
            l_exp = Numeric.exp(l)  # XXX Use approx with table lookup

            # Sanity check: every row of l_exp should sum to one. The axis
            # has to be given explicitly; numpy's sum without an axis adds
            # up the whole matrix and returns a scalar.
            row_sums = Numeric.sum(l_exp, axis=1)
            if abs(1.0 - min(row_sums)) < 1e-10 and abs(max(row_sums) - 1.0) < 1e-10:
                print("l_exp row sums are all one")
            else:
                print(row_sums)

            print("# iter %s joint likelihood = %f" % (iteration, mixture_likelihood))

            # Compute priors alpha; fixed models keep their coefficient
            free_models = [k for k in range(len(models)) if k not in fixed_models]
            for k in free_models:
                # NOTE: sumlogs returns the sum for values != -Inf
                logalpha[k] = sumlogs(l[:, k]) - log_nrseqs

            # Normalizing the prior of the free models. logalpha holds
            # natural logarithms (math.log / Numeric.log / Numeric.exp are
            # used everywhere else), so the normalisation must use exp and
            # natural log as well; base 2 here would leave exp(logalpha)
            # not summing to one.
            # NOTE(review): as before, the free coefficients are rescaled
            # to sum to one among themselves; with a non-empty fixed_models
            # the total over all models can exceed one -- confirm intended.
            norm = 0.0
            for k in free_models:
                norm += math.exp(logalpha[k])
            if norm > 0.0:  # nothing to rescale if all free priors are zero
                log_norm = math.log(norm)
                for k in free_models:
                    logalpha[k] -= log_norm

            logalpha_exp = Numeric.exp(logalpha)
            print("logalpha %s %s %s %s" % (logalpha, min(logalpha_exp),
                                            max(logalpha_exp), logalpha_exp))

            # Decide whether we want to go on or not
            if max_iter == 0:
                break
            if last_mixture_likelihood is None:  # First time through while-loop
                last_mixture_likelihood = mixture_likelihood
            else:
                improvement = mixture_likelihood - last_mixture_likelihood
                # A decrease does not stop the training, but an improvement
                # of exactly zero does (it is below eps as well).
                if iteration > max_iter or (0.0 <= improvement and improvement < eps):
                    break

        # Ben: Resume parallel execution
        # Ben: Master signals whether to continue or not. The matching
        # "break" message is sent by the master after it left the loop.
        if mpi_myid == mpi_masterid:
            for proc in mpi_slaves:
                pypar.send("continue", proc)
        else:
            msg = pypar.receive(mpi_masterid)
            if msg == "break":
                break

        # Distribute the result l_exp to all processors
        if mpi_myid == mpi_masterid:
            for proc in mpi_slaves:
                pypar.send(l_exp, proc)
        else:
            l_exp = pypar.receive(mpi_masterid)

        for j in mpi_myrange:
            m = models[j]

            # Set the sequence weight for sequence i under model m to P[m| i]
            # NOTE: If model m is really unpopular this can lead to numerical
            # instabilities. Rescale the weight vector, so that it sums to
            # unity. This doesnt solve the problem.
            # More generally: if s below is really tiny we should neither
            # reestimate nor use that model in the calculations
            w = Numeric.array(l_exp[:, j])  # private copy, l_exp stays intact
            s = Numeric.sum(w)
            print("weight sum=%e min=%e max=%e" % (s, min(w), max(w)))

            if s < 1e-200:
                # This case cannot be handled due to limited range. In a
                # log-based implementation we would have enough precision
                # to still train the corresponding model. Here we dont.
                # XXX Possible fix. Train model with all sequences equally
                # weighted?
                print("# unnecessary model %d in mixture" % j)
                continue

            w /= s
            print("scaled sum=%e min=%e max=%e" % (Numeric.sum(w), min(w), max(w)))

            if j not in fixed_models:
                for i in range(len(seqs)):
                    seqs.setWeight(i, w[i])
                print(" Reestimating model %s" % j)
                m.baumWelch(seqs, 20, 0.1)

        iteration += 1

        if mpi_myid == mpi_masterid:
            last_mixture_likelihood = mixture_likelihood

    if mpi_myid == mpi_masterid:
        # Ben: Master signals other processors to break before returning.
        for proc in mpi_slaves:
            pypar.send("break", proc)
        return (mixture_likelihood, Numeric.exp(logalpha).tolist(), l_exp)
    else:
        # Only the master holds the results
        return (None, None, None)


# Help text printed by usage(). Not a raw string: the \t sequences below are
# emitted as real tab characters.
usage_info = """
GQLMixturePar.py [options] hmms.smo seqs.sqd [hmms-reestimated.smo]

Estimate a mixture of hmms from a file hmms.smo containing continuous
emission HMMs and a set of sequences of reals given in the file
seqs.sqd.

This is a version of GQLMixture that runs on parallel processors using
MPI. Requires module PyPar for parallel execution.

Running:

-m iter Maximal number of iterations (default is 100)

-e eps  If the improvement in likelihood is below eps, the training
        is terminated (default is 0.001)

Post-analysis (the options are mutually exclusive):

-p      Output the matrix p_{ij} = P[model j| sequence i] to the console

-c      Cluster sequences. Assign each sequence i to the model maximizing
        P[model j| sequence i]. Outputs seq_id\tcluster_nr to the console
        
-d ent  Decode mixture. If the entropy of { P[model j| sequence i] } is
        less than 'ent', sequence i is assigned to the model maximizing
        P[model j| sequence i]. Outputs seq_id\tcluster_nr to the console,
        cluster_nr is None if no assignment was possible


Example:

GQLMixturePar.py -m 10 -e 0.1 -d 0.15 test2.smo test100.sqd reestimated.smo

"""

def usage():
    """ Print the command line help, on the master process only.

        The process ids are read from the module-level globals mpi_myid
        and mpi_masterid. usage() is called while the command line is
        still being parsed, when these may not have been assigned yet;
        a missing name is treated as 0 (the master id used by this
        script) instead of failing with a NameError.
    """
    namespace = globals()
    if namespace.get('mpi_myid', 0) == namespace.get('mpi_masterid', 0):
        print(usage_info)

if __name__ == '__main__':
    # NOTE: all mpi_* names assigned in this block are module-level globals
    # on purpose: estimate_mixture_parallel() and usage() read them as
    # globals. Do not move this code into a function.

    # Ben: How many processors and which one am I?
    mpi_howmany = pypar.size()
    mpi_myid = pypar.rank()

    # Ben: One processor (0) is the 'master', the others are 'slaves'.
    # Set before anything else: usage() and the result output below need it
    # even when only a single process is running.
    mpi_masterid = 0

    # Default values
    max_iter = 100
    eps = 0.001
    output = None
    entropy_cutoff = None

    try:
        opts, args = getopt.getopt(sys.argv[1:], "m:e:pcd:", [])
        for o, a in opts:
            if o == '-m':
                max_iter = int(a)
            elif o == '-e':
                eps = float(a)
            elif o == '-p':
                output = 'p_matrix'
            elif o == '-c':
                output = 'cluster'
            elif o == '-d':
                output = 'decode'
                entropy_cutoff = float(a)
    except (getopt.GetoptError, ValueError):
        # Unknown option or non-numeric option value
        usage()
        pypar.finalize()  # shut MPI down cleanly before leaving
        sys.exit(2)

    # The file for the re-estimated models is optional (see usage_info)
    if len(args) not in (2, 3):
        usage()
        pypar.finalize()
        sys.exit(2)

    hmmsFileName = args[0]
    seqsFileName = args[1]
    # NOTE(review): outFileName is parsed but nothing in this script writes
    # the re-estimated models to it.
    outFileName = None
    if len(args) == 3:
        outFileName = args[2]

    models = HMMOpen.all(hmmsFileName)
    print("# Read %d models from '%s'" % (len(models), hmmsFileName))
    seqs = SequenceSet(Float(), seqsFileName)
    print("# Read %d sequences from '%s'" % (len(seqs), seqsFileName))

    # Ben: Check number of processors.  If 1, fall back to non-parallel version
    if mpi_howmany == 1:
        starttime = time.time()
        (ml, alpha, P) = estimate_mixture(models, seqs, max_iter, eps, [])
        endtime = time.time()
    else:
        # Ben: Divide the models between the processors.
        mpi_partitioning = [pypar.balance(len(models), mpi_howmany, proc)
                            for proc in range(mpi_howmany)]
        mpi_myrange = range(*mpi_partitioning[mpi_myid])
        mpi_slaves = range(1, mpi_howmany)

        starttime = time.time()
        (ml, alpha, P) = estimate_mixture_parallel(models, seqs, max_iter, eps, [])
        endtime = time.time()

    # Ben: only the master delivers the results.
    if mpi_myid == mpi_masterid:
        print(">>>>Mixture estimation took %s seconds on %d processors." % (endtime - starttime, mpi_howmany))
        if output == 'p_matrix':
            for i in range(len(seqs)):
                print('\t'.join(["%1.3f" % x for x in P[i]]))
        elif output is not None:
            if output == 'cluster':
                assignment = decode_mixture(P, len(models)) # max ent: log(len(models))
            else:
                assignment = decode_mixture(P, entropy_cutoff)
            for i, c in enumerate(assignment):
                print("%s\t%s" % (str(i), str(c)))

    #Ben: clean up MPI
    pypar.finalize()
        
