#!/usr/bin/env python2.3
################################################################################
#
#       This file is part of the GQL (Graphical Query Language) Toolkit
#
#       file:   GQLQuery.py
#       author: Ivan Costa (filho@molgen.mpg.de)
#
#       Copyright (C) 2003-2004 Alexander Schliep and Ivan Costa
#
#       Contact: filho@molgen.mpg.de
#
#       Information: http://ghmm.org/gql
#
#   GQL is free software; you can redistribute it and/or modify
#   it under the terms of the GNU General Public License as published by
#   the Free Software Foundation; either version 2 of the License, or
#   (at your option) any later version.
#
#   GQL is distributed in the hope that it will be useful,
#   but WITHOUT ANY WARRANTY; without even the implied warranty of
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#   GNU General Public License for more details.
#
#   You should have received a copy of the GNU General Public License
#   along with GQL; if not, write to the Free Software
#   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
#
#
#       This file is version $Revision: 1871 $
#                       from $Date: 2008-11-03 15:06:11 -0300 (Mon, 03 Nov 2008) $
#             last change by $Author: filho $.
#
################################################################################

import GQLMixture
import GQLCluster
import GQLValidation
from numpy.oldnumeric import *
import ghmm
import math
import time
from GO.GODag import *
import GQLComponentClustering
import random
from PPI.PPI import *
import ConfigParser


################################################################################
# Methods for evaluating specific sources of biological data with the
# results of the clustering analysis (to be improved)
################################################################################

def clusterEvaluationMixture(mixtureCluster,mixtureAnnotation,
                             mask,entropy=0,contigency=0,cutoff=0,
                             entropies=None,go=None,terms=None,filter=None):
    '''Compare a clustering mixture with an annotation mixture.

    Both matrices are first restricted to the rows selected by mask.  When
    cutoff is non-zero the mask is further limited to rows whose value in
    entropies is below cutoff.  The annotation matrix is then optionally
    transformed by one of the GQLComponentClustering strategies and the
    pair counts are obtained from GQLValidation.computeProbabilisticErrors.

    Arguments:
      mixtureCluster    -- clustering matrix, one row per gene
      mixtureAnnotation -- annotation matrix, one row per gene
      mask              -- per-row selector; rows with a zero entry are dropped
      entropy           -- strategy applied to the annotation components:
                             1  greedJoining
                             2  inverseOld
                             3  leastSquares(..., 1, 400) + leastSquaresWithFilter
                             4  leastSquares(..., 2, 1) + leastSquaresWithFilter
                             5  greedJoiningDag on go.getSubDag(terms)
                           any other value leaves the annotation untouched
      contigency        -- when true the contingency table is returned too
      cutoff            -- entropy threshold; 0 disables the filter
      entropies         -- per-row entropies, only read when cutoff != 0
      go, terms         -- GO DAG and term ids, only read when entropy == 5
      filter            -- accepted for compatibility, not used here

    Returns the tuple
      (compAnnot, n, componentsPartition, tp, fn, fp, tn, assign, entropyRes)
    preceded by the contingency table when contigency is true.  compAnnot
    is the number of annotation columns before any transformation and n is
    len() of the final annotation matrix.  assign and entropyRes are only
    produced by strategy 5; for every other strategy they are [] and 0.0.
    '''
    if entropies is None:
        entropies = []

    # entropy cutoff / excluding genes ambiguously assigned
    if cutoff != 0:
        mask = mask*(array(entropies) < cutoff)

    # getting rid of the genes that are masked out (not annotated or
    # not below the cutoff)
    selected = nonzero(mask)
    mixtureAnnotation = take(mixtureAnnotation,selected,axis=0)
    mixtureCluster = take(mixtureCluster,selected,axis=0)

    compAnnot = len(mixtureAnnotation[0])

    componentsPartition = [[0]]
    # Only strategy 5 computes these two; give them defaults so that the
    # return statements below work for every strategy (callers format
    # entropyRes with "%f", hence a float).
    assign = []
    entropyRes = 0.0

    # performing the clustering of components
    if entropy == 1:
        (mixtureAnnotation, _unused,
         componentsPartition) = GQLComponentClustering.greedJoining(mixtureAnnotation,
                                                                    mixtureCluster)
    elif entropy == 2:
        (mixtureAnnotation, _unused,
         componentsPartition) = GQLComponentClustering.inverseOld(mixtureAnnotation,
                                                                  mixtureCluster)
    elif entropy == 3:
        w = GQLComponentClustering.leastSquares(mixtureAnnotation,mixtureCluster,1,400)
        (xf, wf, filtered,
         (mixtureAnnotation, serror)) = GQLComponentClustering.leastSquaresWithFilter(mixtureAnnotation,
                                                                                      mixtureCluster,w)
        print("error %s" % (serror,))
    elif entropy == 4:
        w = GQLComponentClustering.leastSquares(mixtureAnnotation,mixtureCluster,2,1)
        (xf, wf, filtered,
         (mixtureAnnotation, serror)) = GQLComponentClustering.leastSquaresWithFilter(mixtureAnnotation,
                                                                                      mixtureCluster,w)
        print("error %s" % (serror,))
    elif entropy == 5:
        subGO = go.getSubDag(terms)
        (mixtureAnnotation, assign, entropyRes,
         componentsPartition) = GQLComponentClustering.greedJoiningDag(mixtureAnnotation,
                                                                       mixtureCluster,subGO,terms)

    # calculating the tp, fn ...
    (tp, fn, fp, tn) = GQLValidation.computeProbabilisticErrors(mixtureAnnotation,mixtureCluster)

    if contigency:
        table = GQLValidation.contigencyTableFromMixture(mixtureAnnotation,mixtureCluster)
        return table, compAnnot, len(mixtureAnnotation), componentsPartition, tp, fn, fp, tn, assign, entropyRes
    else:
        return compAnnot , len(mixtureAnnotation), componentsPartition, tp, fn, fp, tn, assign, entropyRes

def mixtureFromGoMapping(go,genes,level,ontologies,all=0,max=999999,min=1,countM=0,query=None):
    '''Build a gene-by-term annotation mixture from a GO gene mapping.

    Arguments:
      go         -- GO DAG object providing the node filters, the
                    genes2TermsMapping method and the terms dictionary
      genes      -- sequence of gene identifiers (one output row each)
      level      -- DAG level whose nodes are used when all is false
      ontologies -- forwarded to go.filterNodesOntology
      all        -- when true every node of the DAG is considered and
                    max/min are forwarded to go.genes2TermsMapping
      max, min   -- a term is kept only if strictly more than min and
                    strictly fewer than max of the given genes map to it
      countM     -- 0: return three values; > 0: also return the count
                    matrix; 2: count matrix entries are multiplied by the
                    minLevel of the corresponding term
      query      -- optional; when non-empty it is forwarded to
                    go.filterNodesQuery

    Returns (mixture, notInMapping, nzerosTerms) or, when countM > 0,
    (mixture, notInMapping, nzerosTerms, countMatrix):
      mixture      -- one row per gene, uniform over the kept terms of the
                      gene (all zeros when the gene has no kept term)
      notInMapping -- despite its name, 1.0 for genes found in the mapping
                      and 0.0 for genes that are missing from it
      nzerosTerms  -- ids of the kept terms, in column order
      countMatrix  -- 0/1 membership matrix for the kept terms
    '''
    if all:
        atLevel = go.allNodes()
        print("all nodes %s" % (len(atLevel),))
    else:
        atLevel = go.nodesAtLevel(level)

    atLevel = go.filterNodesOntology(atLevel,ontologies)
    atLevel = go.filterNotMeaningfull(atLevel)

    if query is not None and query != []:
        atLevel = go.filterNodesQuery(atLevel,query)

    if all:
        [mapping,goIds,atLevel] = go.genes2TermsMapping(atLevel,max=max,min=min)
        print("end nodes %s" % (len(atLevel),))
        print(len(goIds))
        print(len(atLevel))
    else:
        [mapping,goIds,atLevel] = go.genes2TermsMapping(atLevel)

    mixture = zeros([len(genes),len(goIds)],Float)
    countMatrix = zeros([len(genes),len(goIds)],Float)
    notInMapping = ones(len(genes),Float)

    for i,g in enumerate(genes):
        try:
            nodes = mapping[g]
        except KeyError:
            # gene without any annotation in the mapping
            notInMapping[i] = 0.0
            continue
        for n in nodes:
            mixture[i,n] = 1.0
            countMatrix[i,n] = 1.0

    # getting rid of go nodes with too few (<= min) or too many (>= max) genes
    nzeros = []
    nzerosTerms = []
    for i in range(len(mixture[0])):
        genesInTerm = sum(countMatrix[:,i])
        if (genesInTerm > min) and (genesInTerm < max):
            nzeros.append(i)
            nzerosTerms.append(goIds[i])

    mixture = take(mixture,nzeros,axis=1)
    countMatrix = take(countMatrix,nzeros,axis=1)

    # each annotated gene gets a uniform distribution over its kept terms
    for i,p in enumerate(countMatrix):
        counts = sum(p)
        if counts > 0:
            mixture[i,:] = p/counts

    if countM == 2: # put level information
        for i in range(len(nzerosTerms)):
            countMatrix[:,i] = multiply(countMatrix[:,i],(go.terms[nzerosTerms[i]]).minLevel)

    print("no final porra %s %s" % (len(mixture[0]), len(mixture)))
    if countM>0:
        return mixture,notInMapping, nzerosTerms, countMatrix
    else:
        return mixture,notInMapping, nzerosTerms

def mixtureFromPPI(ppibaits,genes,countMat=0):
    '''Build a gene-by-bait annotation mixture from bait/target groups.

    Arguments:
      ppibaits -- dictionary mapping each bait to the list of its targets;
                  it is only read, never modified
      genes    -- sequence of gene identifiers (one output row each)
      countMat -- when true the 0/1 membership matrix is returned as well

    Every bait defines one group made of its targets plus the bait itself.
    Empty names and names that do not occur in genes are ignored.  Groups
    without any gene from genes are dropped.

    Returns (posterior, notInMapping, nzerosGenes) or, when countMat is
    true, (posterior, notInMapping, nzerosGenes, countMatrix):
      posterior    -- one row per gene, uniform over the kept groups the
                      gene belongs to
      notInMapping -- despite its name, 1 for genes that belong to at
                      least one group and 0 for all other genes
      nzerosGenes  -- baits of the kept groups, in column order
      countMatrix  -- 0/1 membership matrix for the kept groups
    '''
    posterior = zeros([len(genes),len(ppibaits)],Float)
    countMatrix = zeros([len(genes),len(ppibaits)],Float)
    notInMapping = zeros(len(genes),Float)

    geneindex = {}
    for i,g in enumerate(genes):
        geneindex[g] = i

    # column i belongs to baits[i]
    baits = list(ppibaits)
    for i,bait in enumerate(baits):
        # work on a copy: appending the bait to the stored target list
        # would make the caller's dictionary grow on every call
        targets = list(ppibaits[bait]) + [bait]

        for g in targets:
            if g != '' and g in geneindex:
                row = geneindex[g]
                countMatrix[row,i] = 1
                posterior[row,i] = 1
                notInMapping[row] = 1

    annotated = []
    nzerosGenes = []
    print(len(countMatrix[0]))
    for i in range(len(countMatrix[0])):
        if sum(countMatrix[:,i]) > 0:
            annotated.append(i)
            nzerosGenes.append(baits[i])

    print("orig x annot  %s %s" % (len(posterior[0]), len(annotated)))
    print("mean genes annotated %s" % (sum(countMatrix),))

    posterior = take(posterior,annotated,axis=1)
    countMatrix = take(countMatrix,annotated,axis=1)

    # each annotated gene gets a uniform distribution over its groups
    for i,p in enumerate(countMatrix):
        counts = sum(p)
        if counts > 0:
            posterior[i,:] = posterior[i,:]/counts

    if countMat:
        return posterior, notInMapping, nzerosGenes, countMatrix
    else:
        return posterior, notInMapping, nzerosGenes


def tableToString(table):
    '''Render a two-dimensional numeric table as fixed-width text.

    The first line holds the 1-based column numbers, the second one a
    dashed rule.  Every row whose entries do not sum to zero follows,
    prefixed with its 1-based row number and with the values printed with
    two decimals in six-character cells.  Rows summing to zero are
    skipped.  An empty table yields just the (column-less) header and rule.
    '''
    ncols = 0
    if len(table) > 0:
        ncols = len(table[0])

    header = '     |' + ''.join(["%6d" % (j+1) for j in range(ncols)]) + " "
    rule = "------" * (ncols+1)

    lines = [header, rule]
    for i,row in enumerate(table):
        if sum(row) != 0:
            cells = ''.join([("%.2f" % v).rjust(6) for v in row])
            lines.append("%5d |" % (i+1) + cells)

    return "\n".join(lines)

def randomizeMixture(mixture,perc,mask,type=0):
    '''Return a copy of mixture in which part of the rows are perturbed.

    Arguments:
      mixture -- array with one posterior distribution per row; it is not
                 modified
      perc    -- percentage (0-100) of the rows selected by mask that are
                 drawn at random for perturbation
      mask    -- per-row selector; only rows with a true entry can be drawn
      type    -- 1: the drawn rows are taken two by two and each pair is
                 swapped (with an odd number the last one stays in place);
                 any other value: for every drawn row o, the rows from o
                 to the end are put in random order

    In both modes the result contains exactly the rows of mixture, only
    in a different order.
    '''
    perc = float(perc)/100.00
    n = len(mask)
    num = [i for i,e in enumerate(mask) if e]
    ro = random.sample(num,int(perc*len(num)))
    newMixture = mixture.copy()
    print("%s %s" % (n, ro))
    if type == 1:
        for pos in range(0,len(ro)-1,2):
            e1 = ro[pos]
            e2 = ro[pos+1]
            print("%s %s" % (e1, e2))
            newMixture[e2] = mixture[e1]
            newMixture[e1] = mixture[e2]
    else:
        total = len(newMixture)
        for o in ro:
            # random.shuffle() applied directly to a 2-D array slice swaps
            # row *views* and ends up duplicating rows; shuffle the row
            # indices instead and gather the rows in the new order.
            order = list(range(o,total))
            random.shuffle(order)
            newMixture[o:] = newMixture[order]
    return newMixture

def loadMatrix(lines,initial,final):
   '''Parse tab-separated text lines into a float matrix.

   Row r of the result holds the fields initial..final-1 of lines[r]
   converted with float().  The matrix has final-initial columns; cells
   without a corresponding field keep the value 0.
   '''
   width = final-initial
   data = zeros((len(lines),width),float)
   row = 0
   for text in lines:
      fields = text.strip('\n').split('\t')[initial:final]
      for col in range(len(fields)):
         data[row][col] += float(fields[col])
      row += 1
   return data


def std(x):
  '''Return the population standard deviation of the values in x.

  The squared deviations from the mean are divided by len(x), not by
  len(x)-1.  An empty x raises ZeroDivisionError.
  '''
  # float count: with integer input, sum(x)/len(x) would be floored by
  # Python 2 integer division
  n = float(len(x))
  mean = sum(x)/n
  err = 0.0
  for v in x:
    err = err + (v-mean)*(v-mean)
  return sqrt(err/n)
  

def _writeGoTermReports(go, expt, ctMat, labelnames, labPart, ass, fileNameAux):
    '''Write, for every cluster, the GO terms chosen by the component clustering.

    labPart[i] lists the indices (into labelnames) of the terms of group i
    and ass[i] is the cluster that group was assigned to.  For every
    cluster three files are produced, all prefixed with fileNameAux:
      _<cluster>_goterms.txt     one line per term: id, name, cluster, a
                                 0/1 flag (1 when none of the parents of
                                 the term is in the list of the cluster)
                                 and the per-component counts followed by
                                 the mutual information value of the term
      _<cluster>_goterms.png     plot of the terms and their parents
      _<cluster>_headgoterms.png plot restricted to the flagged terms
    '''
    # imported here because the module header does not import copy itself
    import copy

    anotCounts = dot(transpose(expt),ctMat)
    mut = GQLComponentClustering.mutualInfoTerms(ctMat,expt)

    counts = {}
    for idx,label in enumerate(labelnames):
        counts[label] = '_'.join([str(z) for z in anotCounts[:,idx]])+('_'+str(mut[idx]))

    # collect the GO ids of all groups assigned to the same cluster
    termscluster = {}
    for idx,group in enumerate(labPart):
        goIds = [go.terms[labelnames[e]].goId for e in group]
        termscluster[ass[idx]] = termscluster.get(ass[idx],[])+goIds

    for (cluster,terms) in termscluster.items():
        prefix = fileNameAux+'_'+str(cluster)
        dagaux = go.getSubDagParents(copy.deepcopy(terms))
        # improve this
        headTerms = []
        report = open(prefix+'_goterms.txt','w')
        try:
            for t in terms:
                ps = dagaux.getParent(t)
                term = go.terms[t]
                hasIn = 0
                for p in ps:
                    hasIn = hasIn | (p in terms)

                if not hasIn:
                    headTerms.append(t)
                    flag = '1'
                else:
                    flag = '0'
                report.write(t+'\t'+term.name+'\t'+str(cluster)+'\t'+flag+'\t'+counts[t].replace('_','\t')+'\n')
        finally:
            report.close()
        dagaux.toDotandSave(prefix+'_goterms.png',highlight=[terms,headTerms], extra=counts)

        dagaux = go.getSubDagParents(copy.deepcopy(headTerms))
        dagaux.toDotandSave(prefix+'_headgoterms.png',highlight=[headTerms], extra=counts)


def main(fileName, posteriorassig):
    '''Evaluate a clustering against external annotation and log the scores.

    Arguments:
      fileName       -- cluster data file read with ProfileSet.ReadDataFromCaged
      posteriorassig -- file with posterior assignments; when empty the
                        cluster assignments found in fileName are used

    The run is configured by config.ini in the current directory.  Output
    files share the prefix <name>_<external_data>_1_<component_clustering>,
    where <name> is posteriorassig when given and fileName otherwise (a
    leading "../" is removed):
      <prefix>.txt       one line per GO level with, for every ontology
                         group, rand, corrected rand, entropy and the pair
                         counts, followed by z-scores when a random
                         baseline is available
      <prefix>_rand.txt  scores of the randomized clusterings; written when
                         the "random" option is > 0 and read back (when it
                         exists) when the option is 0
    With component_clustering == 5 and a baseline available, per-cluster
    GO term reports are written as well.
    '''
    conf = ConfigParser.ConfigParser()
    conf.read("config.ini")

    # reading variables from the configuration file
    # (kept in randomMode so that the random module is not shadowed)
    randomMode = conf.getint('general','random')
    cutoffs = conf.getint('general','cutoffs')
    componentClustering = conf.getint('general','component_clustering')
    data = conf.get('general','external_data')
    data_path = conf.get('general','external_data_path')
    numberOfRepetitions = conf.getint('general','random_repetitions')

    # reading attributes specific to GO
    numberOfLevels = conf.getint('GO','no_of_levels')
    GOTermFile = conf.get('GO','term_file')
    GOAnnotationFile = conf.get('GO','annotation_file')
    # NOTE(review): eval() executes whatever expression config.ini holds;
    # acceptable only as long as the configuration file is trusted.
    ontologies = eval(conf.get('GO','ontologies'))
    maxGenesPerNode = conf.getint('GO','max_genes_node')
    minGenesPerNode = conf.getint('GO','min_genes_node')
    filterEntropy = conf.getint('GO','filter_entropy')

    # reading attributes of other annotation types
    PPIFile = conf.get('PPI','ppi_file')
    RRRFile = conf.get('RRR','rrr_file')

    if data not in ('GO','GOSlim','GOAll','PPI','RRR'):
        raise ValueError("unknown external_data in config.ini: %s" % (data,))

    profileSet = GQLCluster.ProfileSet()
    profileSet.ReadDataFromCaged(fileName)
    profile = GQLCluster.ProfileClustering()
    profile.setProfileSet(profileSet)

    if posteriorassig != "":
        profile.readMixtureDistributions(posteriorassig)
        expt = profile.P
        no_components = len(expt[0])
        fileName = posteriorassig
    else:
        no_components = max(profileSet.seq_classes) + 1
        profileSet.cluster = profileSet.seq_classes
        expt = GQLValidation.mixtureFromPartition(profileSet.cluster,no_components)
        profile.P = expt
        cutoffs = 1

    # write the results to the current directory
    if fileName.startswith("../"):
        fileName = fileName[3:]

    fileNameAux = fileName+"_"+data+"_1_"+str(componentClustering)

    go = None
    if data == 'GO' or data == 'GOSlim' or data == 'GOAll':
        go = GODag()
        go.loadTermsFromFile(data_path+GOTermFile)
        go.loadGenesFromFile(data_path+GOAnnotationFile)
        if data == 'GOSlim':
            go = go.getGOSlim('goslim_yeast')
    else:
        ppi = PPIGraph()
        if data == 'PPI':
            ppi.loadPPIFromFile(data_path+PPIFile)
        else:
            ppi.loadPPIFromFile(data_path+RRRFile)
        # NOTE(review): the evaluation below is fed the count matrix, as it
        # is for GO data -- confirm that this is the intended input for
        # PPI/RRR annotation.
        (mixtureAno,mask,labelnames,ctMat) = mixtureFromPPI(ppi.baitTargets,profileSet.genename,countMat=1)
        # there is a single group of "ontologies" for these data
        ontologies = [[0]]

    dist = None
    if randomMode == 1:
        dist = sum(expt)
        dist = dist/sum(dist)

    profile.calculateEntropies()

    # calculating the cutoffs steps
    if cutoffs > 1:
        cutOffStep = len(profileSet)//cutoffs
    else:
        cutOffStep = 1

    fileRandom = None
    outFile = open(fileNameAux+".txt","a")
    try:
        if randomMode > 0:
            fileRandom = open(fileNameAux+"_rand.txt","a")

        for k in range(cutoffs):

            if cutoffs == 1:
                currentCuttOff = 0
            else:
                # NOTE(review): for k == 0 the index is cutoffs*cutOffStep,
                # which may be one past the end of sortedEntropies -- verify.
                currentCuttOff = profile.sortedEntropies[(cutoffs-k)*cutOffStep]

            for i in range(numberOfLevels):

                outFile.write("%i"%i)

                for j in ontologies:

                    if data == 'GO' or data == 'GOSlim':
                        (mixtureAno,mask,labelnames,ctMat) = mixtureFromGoMapping(go,profileSet.genename,i+2,j,countM=1)
                    elif data == 'GOAll':
                        (mixtureAno,mask,labelnames,ctMat) = mixtureFromGoMapping(go,profileSet.genename,i+2,j,all=1,
                                                                                  max=maxGenesPerNode,min=minGenesPerNode,
                                                                                  countM=1)

                    # NOTE(review): clusterEvaluationMixture returns its pair
                    # counts in the order (tp, fn, fp, tn) while the calls
                    # below unpack them as (tp, fp, fn, tn).  The original
                    # names are kept so that the files written stay
                    # unchanged -- confirm which order is intended.

                    # baseline of random clusterings for the z-scores
                    haveBaseline = False
                    cmean = 0.0
                    ents = []
                    cRands = []

                    if randomMode > 0:
                        cRandSum = 0.0
                        for rep in range(numberOfRepetitions):
                            if randomMode == 1:
                                print("dist %s" % (dist,))
                                cluster = GQLValidation.computeRandomShuffle(profileSet.cluster)
                                randomized = GQLValidation.mixtureFromPartition(cluster,no_components)
                            else:
                                randomized = randomizeMixture(expt,randomMode,mask,type=1)

                            (g,n,labelPart, tp, fp, fn, tn, ass, ent) = clusterEvaluationMixture(randomized,
                                                                                 mixtureAno,mask,
                                                                                 entropy=componentClustering,
                                                                                 cutoff=currentCuttOff,
                                                                                 entropies=profile.entropies,
                                                                                 go=go,
                                                                                 terms=labelnames,
                                                                                 filter=filterEntropy)

                            orand = GQLValidation.rand(tp, fp, fn, tn)
                            cRand = GQLValidation.correctedRand(tp, fn, fp, tn, maxi=1)
                            print("rand random %s %s %s %s %s" % (cRand, tp, fn, fp, tn))
                            fileRandom.write("%i\t%f\t%1.10f\t%f\t%f\t%f\t%f\t%f\t%f\n"%(rep,orand,cRand,ent,tp,fp,fn,tn,tp+fp+fn+tn))

                            cRandSum += cRand
                            cRands.append(cRand)
                            ents.append(ent)

                        if numberOfRepetitions > 0:
                            cmean = cRandSum/numberOfRepetitions
                            haveBaseline = True
                    else:
                        # no randomization requested: reuse the baseline of
                        # an earlier run when its file can be read
                        try:
                            baselineFile = open(fileNameAux+"_rand.txt","r")
                            try:
                                lines = baselineFile.readlines()
                            finally:
                                baselineFile.close()
                            if len(lines) > 0:
                                randomRes = loadMatrix(lines,0,4)
                                cRands = randomRes[:,2]
                                ents = randomRes[:,3]
                                cmean = sum(cRands)/len(ents)
                                haveBaseline = True
                        except (EnvironmentError, ValueError, IndexError):
                            print("comp eval %s" % (componentClustering,))

                    (ct,g,n,labPart,tp, fp, fn, tn, ass, ent) = clusterEvaluationMixture(expt,ctMat,mask,
                                                                         entropy=componentClustering,
                                                                         contigency=1,
                                                                         cutoff=currentCuttOff,
                                                                         entropies=profile.entropies,
                                                                         go=go,
                                                                         terms=labelnames,
                                                                         filter=filterEntropy)

                    randReal = GQLValidation.rand(tp,fn,fp,tn)
                    cRandReal = GQLValidation.correctedRand(tp, fn, fp, tn, maxi=1.0)
                    outFile.write("\t%f\t%f\t%f\t%f\t%f\t%f\t%f\t%f"%(randReal,cRandReal,
                                                                      ent,tp,fp,fn,tn,tp+fp+fn+tn))

                    if haveBaseline:
                        zscoree = (ent - sum(ents)/len(ents))/std(ents)
                        zscorecr = (cRandReal - cmean)/std(cRands)
                        outFile.write("\t%f\t%f\t%f"%(zscoree,zscorecr,
                                                      GQLValidation.correctedIndex(GQLValidation.correctedRand(tp,fn,fp,tn),cmean)))

                        if componentClustering == 5: # analysis of component partition
                            _writeGoTermReports(go,expt,ctMat,labelnames,labPart,ass,fileNameAux)

                outFile.write("\n")
    finally:
        outFile.close()
        if fileRandom is not None:
            fileRandom.close()



# Help text printed by the command line entry point at the end of this file.
usage_info="""
Evaluate a clustering of genes against external biological annotation
(e.g. Gene Ontology terms).

Usage: GQLEvaluation <cluster_data_file.txt> [<posterior_assignment_data_file.txt>]

Example: GQLEvaluation yspo_hmm.txt

Cluster data file follows the formats used by GQL, where cluster
assignments are in the second column. Check the documentation at
http://www.ghmm.org/gql

You need a <config.ini> file in the directory from which you run the
software. Make sure you set the variable <external_data_path> to the
path where the files gene_association.sgd and gene_ontology.obo are
located.

Use the other variables at your own risk!

Output:

<filename>_<external_data>_1_5_<cluster_number>_goterms.txt - list with all informative terms (4th column indicates non-redundant terms)
<filename>_<external_data>_1_5_<cluster_number>_goterms.png - png plot with all informative terms per cluster (light grey indicates redundant, dark grey non-redundant)
<filename>_<external_data>_1_5_<cluster_number>_headgoterms.png - png plot restricted to the non-redundant terms

For method details please refer to

I. G. Costa, M. C. P. de Souto and A. Schliep Validating Gene
Clusterings by Selecting Informative Gene Ontology Terms with Mutual
Information Advances in Bioinformatics and Computational Biology,
Proceedings of Brazilian Symposium on Bioinformatics 2007, LNBI,
Springer, 81-92.
"""

if __name__ == '__main__':

    import sys

    print(usage_info)

    # the cluster data file is mandatory; without it the usage text printed
    # above is all that can be offered
    if len(sys.argv) < 2:
        sys.exit(1)

    fileName = sys.argv[1]

    # posterior assignments (if empty, the cluster assignments in fileName
    # will be used)
    if len(sys.argv) > 2:
        posteriorassig = sys.argv[2]
    else:
        posteriorassig = ''

    main(fileName, posteriorassig)
    

