import ghmm
import random
import GQLMixture
import GQLCluster
import GQL
import math
import sys
import numpy

def relEntropy(pa, pb):
    """
    Return the negated relative entropy of two discrete probability densities.

    Computes -D(pa || pb) = -sum_i pa[i] * log(pa[i] / pb[i]).
    The sign is that of the original implementation. For proper distributions
    the result is <= 0, and more dissimilar distributions give smaller values.

    Edge cases:
    - Terms with pa[i] == 0 contribute 0, following the limit p*log(p) -> 0.
    - If pa[i] > 0 while pb[i] == 0, the divergence is infinite and -inf is
      returned.

    :param pa: sequence of probabilities.
    :param pb: sequence of probabilities, indexed like ``pa`` (at least as long).
    :return: float, -D(pa || pb).
    """
    relEnt = 0.0
    for i, p in enumerate(pa):
        q = pb[i]
        if p == 0:
            continue  # 0 * log(0 / q) is taken as 0
        if q == 0:
            return float('-inf')  # D(pa || pb) = +inf
        relEnt += p * math.log(float(p) / q)  # float() avoids Py2 int division
    return -relEnt

def generateGeneSequenceSets(multSeqs, noGenes):
    """
    Split multivariate sequences into one sequence set per gene.

    NOTE: not implemented yet. This placeholder ignores both arguments and
    always returns a new empty list.

    The intended result is a list ``geneSeqSets`` in which
    ``geneSeqSets[gene]`` is a sequence set holding the time courses of all
    patients for that gene.
    """
    # TODO: build one univariate sequence set per gene from multSeqs.
    geneSeqSets = list()
    return geneSeqSets

def _univariateEmission(stateParams, dim, noOfGenes):
    """
    Project one state's two-component multivariate emission onto gene ``dim``.

    Layout of ``stateParams`` as used by this code. It is one entry of the
    emission part of ``model.asMatrices()``; TODO confirm against the ghmm
    version in use:
      [0] / [2]  mean vectors of the first / second mixture component
      [1] / [3]  covariance matrices, flattened row-major (noOfGenes x noOfGenes)
      [4]        mixture weights, for entries with at most 5 fields
      [6]        mixture weights, for entries with more than 5 fields
                 (the "noise component" case)

    :return: ``[means, variances, weights]`` for a univariate Gaussian mixture.
    """
    diag = noOfGenes * dim + dim  # flat index of covariance entry (dim, dim)
    means = [stateParams[0][dim], stateParams[2][dim]]
    variances = [stateParams[1][diag], stateParams[3][diag]]

    if len(stateParams) > 5:
        # Noise model: renormalise the first two weights by the SAME total.
        # Previously the second weight was divided by a sum that already
        # contained the normalised first weight.
        w0, w1 = stateParams[6][0], stateParams[6][1]
        total = float(w0 + w1)  # float() avoids Py2 integer division
        weights = [w0 / total, w1 / total]
    else:
        weights = stateParams[4]

    return [means, variances, weights]


def computeUnivariateModels(multModels):
    """
    For each gene, compute the univariate copies of the multivariate models.

    This follows the construction described in the paper. For every gene
    (dimension) ``dim`` and every multivariate model, a univariate
    Gaussian-mixture HMM is built. It keeps the model's transition matrix A
    and initial distribution pi. Each state's emission is projected onto
    ``dim``: the ``dim``-th entries of the component means, plus the diagonal
    covariance entries.

    The per-model first-component means and variances of each gene are
    printed, as before.

    :param multModels: non-empty sequence of multivariate ghmm HMMs. All are
        assumed to share the emission dimension of ``multModels[0]``.
    :return: ``uniModels`` where ``uniModels[dim][k]`` is the univariate copy
        of component ``k`` restricted to gene ``dim``.
    """
    # (A, B, pi, ...) of every multivariate model, extracted once.
    HMMMatrices = [model.asMatrices() for model in multModels]

    # Dimension of the multivariate mean of state 0 in the first model.
    noOfGenes = len(multModels[0].getEmission(0, 0)[0])

    uniModels = []
    for dim in range(noOfGenes):
        diag = noOfGenes * dim + dim
        geneModels = []
        mus = []     # per model: first-component mean of gene dim, per state
        sigmas = []  # per model: matching diagonal covariance entry, per state
        for parameters in HMMMatrices:
            A, emissions, pi = parameters[0], parameters[1], parameters[2]
            states = range(len(A))

            B = [_univariateEmission(emissions[i], dim, noOfGenes) for i in states]
            mus.append([emissions[i][0][dim] for i in states])
            sigmas.append([emissions[i][1][diag] for i in states])

            # Pass an emission-domain *instance* to both the model and the
            # distribution, as in the ghmm examples.
            domain = ghmm.Float()
            geneModels.append(ghmm.HMMFromMatrices(
                domain, ghmm.GaussianMixtureDistribution(domain), A, B, pi))
        uniModels.append(geneModels)
        print(mus)
        print(sigmas)

    return uniModels

def computeUniPriors(uniModelComponents, uniSeqs, multiPriors, clas):
    """
    Compute the class-conditional mixture priors of univariate model components.

    Implements equation (8) of the preprint. For every sequence l, the
    posterior p(z_{gl}=k | Theta_g, O_{gl}) of component k is computed from
    the component priors and likelihoods. The prior of k for the selected
    class is the mean of these posteriors over the sequences with clas[l] true.

    Posteriors are normalised in log space. Exponentiating raw
    log-likelihoods of realistic sequences underflows to 0.0, which used to
    cause a division by zero.

    :param uniModelComponents: models; component k must provide
        ``loglikelihood(seq)``.
    :param uniSeqs: iterable of sequences (one per patient).
    :param multiPriors: ``multiPriors[k]`` is the mixture prior of component k.
    :param clas: per-sequence truth values, e.g. a boolean numpy mask; only
        sequences with ``clas[l]`` true contribute.
    :return: list with one prior per component.
    :raises ValueError: if the number of priors differs from the number of
        components, or if ``clas`` selects no sequence.
    :raises ZeroDivisionError: if a sequence has zero likelihood under every
        component with a positive prior.
    """
    if len(multiPriors) != len(uniModelComponents):
        raise ValueError("need exactly one prior per model component")

    negInf = float('-inf')

    # patientPosteriors[l][k] is p(z_{gl}=k | Theta_g, O_{gl}) in (8).
    patientPosteriors = []
    for seq in uniSeqs:
        # log(prior_k * p(O_{gl} | lambda_k)); each model is evaluated once.
        logWeights = []
        for prior, model in zip(multiPriors, uniModelComponents):
            if prior > 0:
                logWeights.append(math.log(prior) + model.loglikelihood(seq))
            else:
                logWeights.append(negInf)

        top = max(logWeights)
        if top == negInf:
            raise ZeroDivisionError(
                "sequence has zero likelihood under every weighted component")
        # Shift by the maximum before exponentiating (log-sum-exp trick).
        weights = [math.exp(lw - top) for lw in logWeights]
        total = sum(weights)
        patientPosteriors.append([w / total for w in weights])

    selected = [l for l in range(len(patientPosteriors)) if clas[l]]
    if not selected:
        raise ValueError("clas selects no sequence")

    uniPriors = []
    for k in range(len(multiPriors)):
        alpha = sum(patientPosteriors[l][k] for l in selected)
        uniPriors.append(alpha / len(selected))
    return uniPriors


def feature_selection(multModels, multiPriors, uniSeqsList, noOfGenes, seq_classes):
    """
    Score each gene by how differently the two patient classes use the components.

    For each gene:
    - Univariate copies of the multivariate components are built
      (computeUnivariateModels).
    - Component priors are estimated separately for the sequences labelled 1
      and those labelled 0 (computeUniPriors).
    - The score is relEntropy(positivePriors, negativePriors).

    For every gene, the two prior lists and the score are written to stdout.

    :param multModels: multivariate HMM components.
    :param multiPriors: mixture prior of each component (alpha_k in the paper).
    :param uniSeqsList: ``uniSeqsList[gene]`` is the sequence set of that gene.
    :param noOfGenes: number of genes to score (indices 0 .. noOfGenes-1).
    :param seq_classes: class label per sequence; 1 selects the positive
        class, 0 the negative class.
    :return: list of scores, one per gene, in gene order.
    """
    multiPriors = numpy.array(multiPriors)

    # uniModelList[gene] is the set of univariate model components of gene.
    uniModelList = computeUnivariateModels(multModels)

    # The class masks do not depend on the gene, so build them once.
    labels = numpy.array(seq_classes)
    positiveMask = labels == 1
    negativeMask = labels == 0

    # TODO: an earlier draft also skipped genes whose positive (resp.
    # negative) sequences put less than 0.5 of their prior mass on the
    # positive (resp. negative) components. That needs the component indices
    # of good / bad responders, which are not available yet.

    selectionCriteria = []
    for gene in range(noOfGenes):
        models = uniModelList[gene]
        seqs = uniSeqsList[gene]
        positivePriors = computeUniPriors(models, seqs, multiPriors, positiveMask)
        negativePriors = computeUniPriors(models, seqs, multiPriors, negativeMask)
        score = relEntropy(positivePriors, negativePriors)
        # Same text as the former Python 2 `print a, b, c` statement.
        sys.stdout.write('%s %s %s\n' % (positivePriors, negativePriors, score))
        selectionCriteria.append(score)

    return selectionCriteria


if __name__ == '__main__':

    genes = ['Caspase 2', 'Caspase 3', 'Caspase 10', 'Jak2', 'IL-4Ra', 'MAP3K1', 'RAIDD']
    # One CAGED data file per gene, in the same order as `genes`.
    geneFiles = ['data/baranzini-fold-all-%s.txt' % gene for gene in genes]
    noOfGenes = len(genes)

    multModels = ghmm.HMMOpen('res-lab-1-1-2-[4]-0-0.xml')
    noModels = len(multModels)
    multiPriors = [1.0 / noModels] * noModels  # uniform mixture priors

    profileSet = GQL.ProfileSet()
    uniSeqsList = [profileSet.ReadDataFromCaged(gf, end=0) for gf in geneFiles]

    # seq_classes is read only after all gene files have been loaded.
    selectionCriteria = feature_selection(multModels, multiPriors, uniSeqsList,
                                          noOfGenes, profileSet.seq_classes)

    print(genes)

    # (score, gene) pairs in ascending order of score.
    values = sorted(zip(selectionCriteria, genes))
    print(values)
        

    
    
