#!/usr/bin/env python2.3
################################################################################
#
#       This file is part of the GQL (Graphical Query Language) Toolkit
#
#       file:   GQLQuery.py
#       author: Ivan Costa (filho@molgen.mpg.de)
#
#       Copyright (C) 2003-2004 Alexander Schliep and Ivan Costa
#
#       Contact: filho@molgen.mpg.de
#
#       Information: http://ghmm.org/gql
#
#   GQL is free software; you can redistribute it and/or modify
#   it under the terms of the GNU General Public License as published by
#   the Free Software Foundation; either version 2 of the License, or
#   (at your option) any later version.
#
#   GQL is distributed in the hope that it will be useful,
#   but WITHOUT ANY WARRANTY; without even the implied warranty of
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#   GNU General Public License for more details.
#
#   You should have received a copy of the GNU General Public License
#   along with GQL; if not, write to the Free Software
#   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
#
#
#       This file is version $Revision: 1215 $
#                       from $Date: 2006-10-20 12:53:00 -0300 (Fri, 20 Oct 2006) $
#             last change by $Author: filho $.
#
################################################################################

import GQLMixture
import GQLCluster
from GQLValidation import *
import numpy.oldnumeric as Numeric
#import Numeric
import ghmm
import math
import time
from GODag import *
import GQLComponentClustering

################################################################################
# Methods for evaluating specific sources of biological data with the
# results of the clustering analysis (this should be improved)
################################################################################

def _componentGoIds(components, goIds, columnIndex):
    '''Translate groups of filtered mixtureGO column indices into sorted GO id lists.

    columnIndex maps a column of the filtered mixtureGO matrix back to its
    position in goIds (the columns were subset with take(..., columnIndex, axis=1)).
    '''
    return [sorted([goIds[columnIndex[e]] for e in group]) for group in components]


def clusterEvaluationMixture(mixtureCluster,genes,go,level,ontologies,entropy=0,contigency=0,cutoff=0):
    '''Compare a probabilistic (mixture) clustering with GO annotations.

    The GO terms at `level` are restricted to `ontologies` and passed through
    go.filterNotMeaningfull; each gene is then described by a posterior over
    those terms (see mixtureFromGoMapping).

    Parameters:
      mixtureCluster -- per-gene cluster posteriors; row i is assumed to belong
                        to genes[i], since it is filtered with the same row
                        indices as the GO matrix (TODO confirm with callers)
      genes          -- gene ids, in the row order of mixtureCluster
      go             -- GO DAG object (nodesAtLevel, filterNodesOntology,
                        filterNotMeaningfull, genes2TermsMapping)
      level          -- GO level whose terms are used
      ontologies     -- ontologies passed to go.filterNodesOntology
      entropy        -- 0: use the GO terms as they are;
                        1: join components with GQLComponentClustering.greedJoining;
                        2: join components with GQLComponentClustering.inverse
      contigency     -- if true, also compute contigencyTableFromMixture and
                        return it as the first element
      cutoff         -- if non-zero, keep only genes whose cluster posterior has
                        GQLMixture.Entropy below this value

    Returns:
      ([table,] gonodes, nSelectedGenes, goids, tp, fn, fp, tn, tpm, fnm, fpm, tnm)
      table is present only when contigency is true; gonodes is the number of
      GO columns kept before component joining; goids holds the sorted GO ids
      of each joined component (empty when entropy is 0); (tp, fn, fp, tn) come
      from computeProbabilisticErrors2 against the clustering and
      (tpm, fnm, fpm, tnm) against the max-posterior GO assignment.
    '''
    atLevel = go.nodesAtLevel(level)
    atLevel = go.filterNodesOntology(atLevel, ontologies)
    atLevel = go.filterNotMeaningfull(atLevel)
    [mapping, goIds] = go.genes2TermsMapping(atLevel)
    (mixtureGO, mask, mixtureGOSum) = mixtureFromGoMapping(mapping, goIds, genes)

    # Keep genes annotated in GO (mask is 1.0 for them) and, when a cutoff is
    # given, whose cluster posterior has entropy below it.  The test is done
    # per gene: comparing a whole Python list with a number gives no mask.
    if cutoff != 0:
        entropies = [GQLMixture.Entropy(p) for p in mixtureCluster]
        keep = [m and e < cutoff for (m, e) in zip(mask, entropies)]
    else:
        keep = mask
    selected = [i for (i, k) in enumerate(keep) if k]
    print("selecao %s" % len(selected))

    # Rows are genes (columns are GO terms, see mixtureGO[:, i] below), so the
    # gene selection is applied along axis 0.
    mixtureGO = take(mixtureGO, selected, axis=0)
    mixtureCluster = take(mixtureCluster, selected, axis=0)

    # Drop GO columns that no selected gene uses, and columns whose count-matrix
    # column sums to exactly one (a single annotated gene).
    nzeros = [i for i in range(mixtureGO.shape[1])
              if sum(mixtureGO[:, i]) != 0.0 and sum(mixtureGOSum[:, i]) != 1.0]
    mixtureGO = take(mixtureGO, nzeros, axis=1)
    gonodes = mixtureGO.shape[1]

    # Optional clustering of GO components.
    goids = []
    if entropy == 1:
        (mixtureGO, _, components) = GQLComponentClustering.greedJoining(mixtureGO, mixtureCluster)
        goids = _componentGoIds(components, goIds, nzeros)
    elif entropy == 2:
        (mixtureGO, _, components) = GQLComponentClustering.inverse(mixtureGO, mixtureCluster)
        goids = _componentGoIds(components, goIds, nzeros)

    # Pairwise error counts: clustering vs GO, and max-posterior GO vs GO.
    (tp, fn, fp, tn) = computeProbabilisticErrors2(mixtureGO, mixtureCluster)
    (tpm, fnm, fpm, tnm) = computeProbabilisticErrors2(maxPosteriorMixture(mixtureGO), mixtureGO)

    if contigency:
        table = contigencyTableFromMixture(mixtureGO, mixtureCluster)
        return table, gonodes, len(selected), goids, tp, fn, fp, tn, tpm, fnm, fpm, tnm
    return gonodes, len(selected), goids, tp, fn, fp, tn, tpm, fnm, fpm, tnm

def mixtureFromGoMapping(mapping,goIds,genes):
    '''Build gene x GO-term matrices from a gene -> term-index mapping.

    Parameters:
      mapping -- dict from gene id to a sequence of column indices into goIds
      goIds   -- GO term ids; only its length (number of columns) is used here
      genes   -- gene ids; row i of the matrices corresponds to genes[i]

    Returns (mixture, notInMapping, countMatrix):
      mixture      -- len(genes) x len(goIds) float matrix; for a gene found in
                      mapping, each listed column gets 1.0/len(nodes), so the
                      row is a uniform posterior over its terms; other rows stay 0
      notInMapping -- float vector, 1.0 for genes present in mapping and 0.0 for
                      genes missing from it (usable as a selection mask)
      countMatrix  -- len(genes) x len(goIds) 0/1 annotation indicator matrix
    '''
    nGenes = len(genes)
    nTerms = len(goIds)
    # 'd' is the double-precision typecode
    mixture = Numeric.zeros([nGenes, nTerms], 'd')
    countMatrix = Numeric.zeros([nGenes, nTerms], 'd')
    notInMapping = Numeric.ones(nGenes, 'd')
    for i, gene in enumerate(genes):
        try:
            nodes = mapping[gene]
        except KeyError:
            # gene is not annotated at this level: clear its mask entry
            notInMapping[i] = 0.0
            continue
        size = len(nodes)
        for n in nodes:
            mixture[i, n] = 1.0 / size
            countMatrix[i, n] = 1.0
    return mixture, notInMapping, countMatrix


def tableToString(table):
    '''Format a 2-D numeric table as fixed-width text.

    The first line holds 1-based column numbers (12 characters each), followed
    by a dashed rule.  Every row whose sum is non-zero is then written as its
    1-based row number and its values with five decimals, right-aligned in 12
    characters.  Rows summing to zero are omitted.  table[0] is read for the
    column count, so an empty table raises IndexError.
    '''
    columnCount = len(table[0])
    header = '     |' + ''.join(["%12d" % (c + 1) for c in range(columnCount)])
    rule = "------" + "------------" * columnCount
    lines = [header + " \n" + rule]
    for rowNumber, row in enumerate(table):
        if sum(row) == 0:
            continue
        cells = [("%.5f" % (value)).rjust(12) for value in row]
        lines.append("%5d |" % (rowNumber + 1) + ''.join(cells))
    return "\n".join(lines)


################################################################################
# Old code: earlier versions of the cluster-evaluation routines
# (hard clusterings, partitions and per-class scores), kept below.
################################################################################

def clusterEvaluationSpecific(cluster,genes,go,level,ontologies,contigency=0):
    '''Score a clustering of genes against the GO terms at one level.

    Terms come from go.nodesAtLevel(level), are restricted to `ontologies`,
    and are filtered with go.filterNotMeaningfull before the gene -> term
    mapping handed to computeGOErrors is built.

    Returns (correctedRand, correctedSen, correctedSpe, table, ct).  table is
    only a placeholder: [] normally, [[]] when contigency is true.  ct is the
    fifth value returned by computeGOErrors.
    '''
    candidates = go.filterNodesOntology(go.nodesAtLevel(level), ontologies)
    mapping, _goIds = go.genes2TermsMapping(go.filterNotMeaningfull(candidates))

    tp, fn, fp, tn, ct = computeGOErrors(cluster, mapping, genes)

    # No contingency table is computed here; only a placeholder is returned.
    if contigency:
        table = [[]]
    else:
        table = []

    rand = correctedRand(tp, fn, fp, tn)
    sen = correctedSen(tp, fn, fp, tn)
    spe = correctedSpe(tp, fn, fp, tn)
    return rand, sen, spe, table, ct

def clusterEvaluationMixtureSpecific(mixtureCluster,genes,go,level,ontologies,contigency=0,genesFilter=None):
    '''Score a mixture clustering against the GO terms at one level.

    Terms come from go.nodesAtLevel(level), are restricted to `ontologies`,
    and are filtered with go.filterNotMeaningfull.  Error counts are computed
    by computeGOErrorsMixture, with genesFilter passed as its `filter` argument.

    Parameters:
      contigency  -- if true, also build contigencyTableFromMixture from the
                     GO posterior matrix and the clustering
      genesFilter -- list forwarded to computeGOErrorsMixture; defaults to a
                     fresh empty list on every call

    Returns (correctedRand, correctedSen, correctedSpe, table, tp, fn, fp, tn),
    where table is [] unless contigency is true.
    '''
    # Avoid a shared mutable default: a new list is created per call.
    if genesFilter is None:
        genesFilter = []

    atLevel = go.nodesAtLevel(level)
    atLevel = go.filterNodesOntology(atLevel, ontologies)
    atLevel = go.filterNotMeaningfull(atLevel)
    [mapping, goIds] = go.genes2TermsMapping(atLevel)

    (tp, fn, fp, tn) = computeGOErrorsMixture(mixtureCluster, mapping, genes, filter=genesFilter)

    table = []
    if contigency:
        # mixtureFromGoMapping returns (mixture, mask, countMatrix); only the
        # posterior matrix is needed for the contingency table.
        (mixtureGO, _mask, _counts) = mixtureFromGoMapping(mapping, goIds, genes)
        table = contigencyTableFromMixture(mixtureGO, mixtureCluster)

    return correctedRand(tp,fn,fp,tn), correctedSen(tp,fn,fp,tn), correctedSpe(tp,fn,fp,tn), table, tp, fn, fp, tn
def clusterEvaluationPartition(matrixCluster,genes,go,level,ontologies,rep):
    '''Compare a clustering matrix with the GO partition at one level.

    GO terms at `level` are restricted to `ontologies`, filtered with
    go.filterNotMeaningfull and turned into a partition with go.partition.
    The GO and clustering matrices are both reduced to the genes used by that
    partition (SubMatrix), and the error counts plus a random-label null
    model (rep repetitions) are computed.  The counts and the index
    statistics for randF and correctedRand are printed to stdout.

    Returns correctedRand(tp, fn, fp, tn) for the observed counts.
    '''
    terms = go.nodesAtLevel(level)
    terms = go.filterNodesOntology(terms, ontologies)
    terms = go.filterNotMeaningfull(terms)
    partGo, _goIds, genesUsed = go.partition(terms, genes)

    goMatrix = matrixFromPartitionWithOverlap(partGo, len(genes))
    subGo = SubMatrix(goMatrix, genesUsed)
    subCluster = SubMatrix(matrixCluster, genesUsed)

    tp, fn, fp, tn = computeErrosFromMatrices(subGo, subCluster)
    tpd, fnd, fpd, tnd = randomLabelNullHypothesis(subGo, subCluster, rep)

    print("%s %s %s %s %s" % (tp, fn, fp, tn, tp + fn + fp + tn))
    for indexFunction in (randF, correctedRand):
        statistics = indexStatistics(tp, fn, fp, tn, tpd, fnd, fpd, tnd, indexFunction)
        print("%s" % (statistics,))

    return correctedRand(tp, fn, fp, tn)

def clusterEvaluationClass(partitionCluster,genes,go,level,ontologies,rep,file):
    '''Score each joined clustering partition against the GO partition at one level.

    For every index k of partitionCluster, partitionJoin(partitionCluster, k)
    is compared with the GO partition (both restricted to the genes used by
    go.partition).  The Rand statistic from indexStatistics (first element),
    correctedRand, correctedSen and correctedSpe are collected.  One
    tab-separated line per k is written to `file`, and the raw counts are
    printed to stdout.

    Returns four lists (nRand, cRand, sen, spe), one entry per k.
    '''
    terms = go.nodesAtLevel(level)
    terms = go.filterNodesOntology(terms, ontologies)
    terms = go.filterNotMeaningfull(terms)
    partGo, _goIds, genesUsed = go.partition(terms, genes)

    goMatrix = matrixFromPartitionWithOverlap(partGo, len(genes))

    nRand, cRand, sen, spe = [], [], [], []
    for k in range(len(partitionCluster)):
        joined = partitionJoin(partitionCluster, k)
        clusterMatrix = matrixFromPartitionWithOverlap(joined, len(genes))

        # SubMatrix of the GO matrix is rebuilt every iteration: the null model
        # below receives it and may modify it (not verified here).
        subGo = SubMatrix(goMatrix, genesUsed)
        subCluster = SubMatrix(clusterMatrix, genesUsed)

        tp, fn, fp, tn = computeErrosFromMatrices(subGo, subCluster)
        tpd, fnd, fpd, tnd = randomLabelNullHypothesis(subGo, subCluster, rep)

        print("%s %s %s %s %s" % (tp, fn, fp, tn, tp + fn + fp + tn))
        randStatistics = indexStatistics(tp, fn, fp, tn, tpd, fnd, fpd, tnd, randF)
        nRand.append(randStatistics[0])
        cRand.append(correctedRand(tp, fn, fp, tn))
        sen.append(correctedSen(tp, fn, fp, tn))
        spe.append(correctedSpe(tp, fn, fp, tn))
        file.write("%f\t%f\t%f\t%f\n" % (nRand[-1], cRand[-1], sen[-1], spe[-1]))

    return nRand, cRand, sen, spe

def clusterEvaluationOld(cluster,genes,go,level,ontologies):
    '''Legacy comparison of a clustering with the GO terms at one level.

    Builds the gene -> term mapping for the filtered terms at `level` and
    delegates to compareClustering2Annotation.  The resulting contingency
    value `ct` is printed to stdout.

    Returns (crand, sen, spe, mapping, ct).
    '''
    candidates = go.filterNodesOntology(go.nodesAtLevel(level), ontologies)
    mapping, _goIds = go.genes2TermsMapping(go.filterNotMeaningfull(candidates))

    ct, crand, sen, spe, _rand = compareClustering2Annotation(cluster, genes, mapping)

    print("%s" % (ct,))
    return (crand, sen, spe, mapping, ct)


def compareClustering2Annotation(clustering, genes, annotation):
    '''Compare a hard clustering with an annotation that may be overlapping.

    A gene can be annotated with several classes but belongs to one cluster.
    Each (gene, class) pair therefore becomes one observation that pairs
    clustering[i] with that class.  Genes missing from `annotation` are
    skipped.

    Parameters:
      clustering -- cluster label per gene, indexed like `genes`
      genes      -- gene ids
      annotation -- dict from gene id to a sequence of integer class labels

    Returns whatever computeExternalIndices returns for the expanded label
    lists.  The number of classes/clusters is taken as max label + 1.

    Raises ValueError if none of the genes appears in `annotation`.
    '''
    newClustering = []
    newClasses = []

    print("%s %s" % (len(clustering), len(genes)))

    for (i, gene) in enumerate(genes):
        try:
            classes = annotation[gene]
        except KeyError:
            # unannotated genes do not contribute observations
            continue
        for c in classes:
            newClustering.append(clustering[i])
            newClasses.append(c)

    if not newClasses:
        raise ValueError("compareClustering2Annotation: none of the %d genes is annotated" % len(genes))

    return computeExternalIndices(newClasses, newClustering, max(newClasses) + 1, max(newClustering) + 1)
