#!/usr/bin/env python2.3
################################################################################
#
#       This file is part of the GQL (Graphical Query Language) Toolkit
#
#       file:   GODag.py
#       author: Ivan Costa (filho@molgen.mpg.de)
#
#       Copyright (C) 2003-2004 Alexander Schliep and Ivan Costa
#
#       Contact: filho@molgen.mpg.de
#
#       Information: http://ghmm.org/gql
#
#   GQL is free software; you can redistribute it and/or modify
#   it under the terms of the GNU General Public License as published by
#   the Free Software Foundation; either version 2 of the License, or
#   (at your option) any later version.
#
#   GQL is distributed in the hope that it will be useful,
#   but WITHOUT ANY WARRANTY; without even the implied warranty of
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#   GNU General Public License for more details.
#
#   You should have received a copy of the GNU General Public License
#   along with GQL; if not, write to the Free Software
#   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
#
#
#       This file is version $Revision: 1683 $
#                       from $Date: 2007-10-10 07:46:21 -0300 (Wed, 10 Oct 2007) $
#             last change by $Author: filho $.
#
################################################################################


from Term import *
from Gene import *
#import Graph
from Gato.Graph import Graph
from string import *
from Annotation import *
import copy

class GODag(Graph):
    """ Todo
    """

    directed = 0

    terms = {} # terms in this GO
    genes = {} # genes in annotated in this GO
    termsIdsToGoIds = {}
    roots = []

    def __init__(self):
        Graph.__init__(self)
        self.euclidian = 0
        self.terms= {}
        self.genes= {}
        self.termsIdsToGoIds = {}
        self.roots=[]

    def addTerm(self,goId,name,definition,namespace,subset=[],parents=[]):
        id = self.AddVertex()
        term = Term(id,goId,name,definition,namespace,subset,parents)
        self.terms[goId] = term
        self.termsIdsToGoIds[id] = goId

    def addRelation(self,parentId,childId,type):
        # type is currently not being kept ... do this latter
        self.AddEdge(self.terms[parentId].id,self.terms[childId].id)

    def addGene(self,id,acc,name,type,synonym=[]):
        self.genes[id] = Gene(id,acc,name,type,synonym=[])

    def getTermFromGoId(self,goId):
        return self.terms[goId]

    def getGeneFromGoId(self,goId):
        return self.genes[goId]

    def setRoots(self,roots):
        if len(roots) == 0:
          # if none is give, do this by hand       
          # this information is fixed and cannot be read from the GO files, so I fixed then 
          rootFunction = self.terms["GO:0003674"]
          rootProcess = self.terms["GO:0008150"]
          rootComponent = self.terms["GO:0005575"] 
          self.roots= [rootFunction,rootProcess,rootComponent]
        else:
          self.roots = roots

    def setLevels(self):
        ''' Set the levels of the go tree'''

        # this information is fixed and cannot be read from the GO files, so I fixed then
        #rootFunction = self.terms["GO:0003674"]
        #rootProcess = self.terms["GO:0008150"]
        #rootComponent = self.terms["GO:0005575"]
        #self.roots= [rootFunction,rootProcess,rootComponent]

        #clearing up

        for t in self.terms.values():
            t.maxLevel = 1
            t.minLevel = 9999

        for r in self.roots:
          r.minLevel = 1          

        #self.roots[Term.FUNCTION].minLevel = 1
        #self.roots[Term.PROCESS].minLevel = 1
        #self.roots[Term.COMPONENT].minLevel = 1

        queue = self.roots[:]
        
        maxAux = 1

        while (queue != []):
            e = queue.pop(0)
            siblings = self.OutNeighbors(e.id)

            for s in siblings:
                term = self.terms[self.termsIdsToGoIds[s]]
                term.maxLevel = max(term.maxLevel,e.maxLevel+1)
                term.minLevel = min(term.minLevel,e.minLevel+1)
                queue.append(term)
                maxAux = max(maxAux,term.maxLevel)
        
        print "max. level", maxAux

    def getLeaves(self):
        ''' return a list of nodes at a certain level'''
        nodes = []
        for t in self.terms.values():
            if( len(self.OutNeighbors(t.id))==0):
                aux = self.getParent(t.goId)
                if len(aux) > 0:
                  nodes.append(t)
                  print t.name, t.minLevel
                else:
                  print 'first level node'
        return nodes

    def nodesAtLevel(self,level,useMinLevel=0):
        ''' return a list of nodes at a certain level'''
        nodes = []
        for t in self.terms.values():
            if( t.maxLevel == level):
                nodes.append(t)
        return nodes

    def getParent(self,goId):
      term = self.terms[goId] 
      nodes = self.InNeighbors(term.id)
      res = []
      for n in nodes:
        res.append(self.termsIdsToGoIds[n])
      return res


    def allNodes(self,noroots=0):
        ''' return a list of nodes at a certain level'''
        nodes = []
        for t in self.terms.values():
            if (noroots==1):
              if t not in self.roots:
                nodes.append(t)
            else:
              nodes.append(t)
        return nodes

    def getGOSlim(self,name):
        ''' get only the nodes from the goslim "name"
        right now it is not carrying a real copy of the dag'''
        
        nodes = {}
        for t in self.terms.values():
            if( name in t.subset):
                nodes[t.goId] = t

        #print len(nodes), nodes

        goSlimDag = GODag()
        goSlimDag.terms = copy.copy(self.terms)
        goSlimDag.genes =  copy.copy(self.genes)
        goSlimDag.termsIdsToGoIds =  copy.copy(self.termsIdsToGoIds)
        goSlimDag.roots =  copy.copy(self.roots[:])
        
        graphCopy(self,goSlimDag)

        # first clearing the root
        for r in goSlimDag.roots:
            out = goSlimDag.OutNeighbors(r.id)
            auxOut = out[:]
            for o in auxOut:
                goSlimDag.DeleteEdge(r.id,o)
                
        #hack -- copy nodes not in the second level to the first level
        # and get rid of any parent association
        for n in nodes.values():
            if n.minLevel > 1:
                # this operation is being carried by the original graph ...
                # update this ...
                edges = goSlimDag.InNeighbors(n.id)
                auxEdges = edges[:]
                for e in auxEdges:
                    goSlimDag.DeleteEdge(e,n.id)
                # link it to the root
                goSlimDag.AddEdge(self.roots[n.namespace].id,n.id)                
        goSlimDag.setLevels()
        goSlimDag.setRoots([])
        return goSlimDag

    def getSubDagParents(self,termsNames):
        ''' get only the sub DAG containing the terms and all parents'''        
        queue = termsNames
        visited = []
        while len(queue) > 0:
          n = queue.pop()          
          if n not in visited:
            visited.append(n) 
            pars = self.getParent(n)
            for par in pars:
                if par not in visited:
                    if par not in queue:
                      queue.append(par)                    
        return self.getSubDag(visited)

    def toDotandSave(self,filename,highlight=[],extra=[]):
        import pydot
        edges = self.Edges()
        newEdges = []
        map = {}
        for (i,o) in edges:
            t = self.terms[self.termsIdsToGoIds[i]]
            inn = t.name.replace(' ','_')
            try:                
              inn = inn+'_'+extra[self.termsIdsToGoIds[i]]         
            except:
              pass                
            map[self.termsIdsToGoIds[i]] = inn
            t =  self.terms[self.termsIdsToGoIds[o]]
            out = t.name.replace(' ','_')
            try:                
              out = out+'_'+extra[self.termsIdsToGoIds[o]]         
            except:
              pass    
            map[self.termsIdsToGoIds[o]] = out
            newEdges.append((inn,out))
        g = pydot.graph_from_edges(newEdges,directed=True)        
        nodes = g.get_node_list()
        colors = ['gray86','gray57','blue','red','green']
        for i,list in enumerate(highlight):
          for name in list:
              t = g.get_node(map[name])
              t.set('fillcolor',colors[i])
              t.set('style','filled')
              #t.set('bottomlabel','test\ntest')
              #t.to_string()
            
        g.write_png(filename,prog='dot')
        #print g.to_string()

    def getSubDag(self,termsNames):
        ''' get only the sub DAG with the terms
            it assumes that the roots should be in the list'''
        
        nodes = {}
        notNodes = {}
        newtermsIdsToGoIds = {}
        for t in self.terms.values():
            if( t.goId in termsNames):
                nodes[t.goId] = t
                newtermsIdsToGoIds[t.id] = t.goId
            else:
                notNodes[t.goId] = t
                

        #print len(nodes), nodes

        subDag = GODag()
        #goSlimDag.terms = self.terms
        #goSlimDag.genes = self.genes
        subDag.termsIdsToGoIds = newtermsIdsToGoIds
        subDag.terms =  nodes
        subDag.roots = copy.copy(self.roots[:])
        #print nodes
        #print newtermsIdsToGoIds
        
        #subDag = copy.deepcopy(self)
        graphCopy(self,subDag)
        #print subDag.adjLists


        #print 'not nodes', notNodes

        # clearing the relation of terms not in the list
        for r in notNodes.values():
            ins = copy.copy(subDag.InNeighbors(r.id))
            #print 'inside', r.id, ins
            for e in ins:
                #print 'removing', e, r.id, ins
                subDag.DeleteEdge(e,r.id)
                #print 'removing', e, r.id, ins
            out = copy.copy(subDag.OutNeighbors(r.id))
            #print r.id, out
            for o in out:
                #print 'removing2', o, r.id, out
                subDag.DeleteEdge(r.id,o)
                #subDag.DeleteEdge(o,r.id)
                
        subDag.setLevels()
        return subDag


    def genesFromTerms(self,nodes):
        ''' given a set of nodes, a list of genes related to it
        with no gene repetition!
        think about latter how can repeated gene annotation
        information can be put here (distribution?
        are there enough annotations?)'''
        genes = {}
        for n in nodes:
            for a in n.annotations:
                genes[a.gene.id] = a.gene
        return genes.values()
        
    def genes2TermsMapping(self,nodes,max=9999999,min=0):
        mapping = {}
        id2GoId = {}
        countmin=0
        countmax=0
        endNodes = []
        i = 0
        for n in nodes:
            genes = self.genesFromSubDAG(n)
            #print genes
            if len(genes) < min:
                countmin +=1
            elif len(genes) > max:
                countmax +=1
            else:
              id2GoId[i] = n.goId
              #print n.goId, len(genes)
              #s = ""
              genes.sort()
              for g in genes:
              #    s += g.acc+" "
                try:
                  aux =  mapping[g.acc]
                  aux.append(i)
                except:
                  mapping[g.acc] = [i]
                for s in g.synonym:
                  try:
                    aux =  mapping[s]
                    aux.append(i)
                  except:
                    mapping[s] = [i]                  
              endNodes.append(n)
              i+=1

        print "removed terms", countmax,countmin
                
        return mapping, id2GoId, endNodes

    def partition(self,nodes,clustGenes):
        mapping = {}
        annotated = {}

        for i,g in enumerate(clustGenes):
            mapping[g] = i
        
        id2GoId = {}
        partition = []
        
        for i,n in enumerate(nodes):
            list = []
            genes = self.genesFromSubDAG(n)
            id2GoId[i] = n.goId
            for g in genes:
                if g.acc in mapping:
                    list.append(mapping[g.acc])
                    annotated[mapping[g.acc]]=mapping[g.acc]
                else:
                    for s in g.synonym:
                        if s in mapping:
                            list.append(mapping[s])
                            annotated[mapping[s]] = mapping[s]
                            break
            partition.append(list)
        #print annotated
        genesUsed = annotated.values()
        #print genesUsed
        genesUsed.sort()
        #print genesUsed
        return partition,id2GoId,genesUsed


    def filterNodesOntology(self,nodes,ontologies):
        auxNodes = []
        for n in nodes:
            if n.namespace in ontologies:
                auxNodes.append(n)
        print 'ontologies',ontologies, len(auxNodes)
        return auxNodes

    def filterNodesQuery(self,nodes,queries):
        auxNodes = []
        for n in nodes:
          #if n.namespace in ontologies:
            for s in queries:
              if s == n.name:                
                auxNodes.append(n)
        print len(auxNodes), auxNodes
        return auxNodes
    
    def filterNotMeaningfull(self,nodes):
        ''' filter nodes that are not interessant for gene
        labeling, such as not_known and ???'''

        filterNodes = ['GO:0000004','GO:0008370','GO:0005554']
        auxNodes = []
        for n in nodes:
            if n.goId not in filterNodes:
                auxNodes.append(n)
        return auxNodes

    
    def genesFromSubDAG(self,root):
        ''' starting from a node, look
        for all siblings from this node'''

        queue = [root]
        nodes = {}
        nodes[root.id] = root
        
        while (queue != []):
            e = queue.pop(0)
            siblings = self.OutNeighbors(e.id)

            for s in siblings:
                term = self.terms[self.termsIdsToGoIds[s]]
                queue.append(term)
                nodes[term.id] = term 

        #print len(nodes.values())
        return self.genesFromTerms(nodes.values())

    def loadTermsFromFile(self,fileName):
        ''' loads GO terms from a OBO File '''
        file = open(fileName,'r')
        lines = file.readlines()
        file.close()

        goId = None
        name = None
        definition = None
        namespace = None
        subset = []
        parents = []
        newTerm = 0
        for l in lines:
            l = l.strip('\n\c')
            items = l.split(':')
            if(items[0] == '' and newTerm): # reached end of record
                self.addTerm(goId,name,definition,namespace,subset,parents)
                goId = None
                name = None
                definition = None
                namespace = None
                subset = []
                parents = []
                newTerm = 0
            elif(items[0] == '[Term]'): #beginning of a record
                newTerm = 1
            elif(items[0] == 'id' and newTerm):
                goId = 'GO:'+items[2].lstrip()
            elif(items[0] == 'name' and newTerm):
                name = items[1].lstrip()
            elif(items[0] == 'def'):
                definition = items[1].lstrip()
            elif(items[0] == 'namespace'):
                aux = items[1].strip()
                if (aux == 'molecular_function'):
                    namespace = 0
                elif (aux == 'biological_process'):
                    namespace = 1
                elif (aux == 'cellular_component'):
                    namespace = 2
            elif(items[0] == 'subset'):
                subset.append(items[1].lstrip())
                #print subset
            elif(items[0] == 'is_a'):
                aux = items[2].split('!')
                parents.append('GO:'+aux[0].strip())
            elif(items[0] == 'relationship'):
                aux = items[2].split('!')
                parents.append('GO:'+aux[0].strip())
            elif(items[0] == 'is_obsolete'):
                goId = None
                name = None
                definition = None
                namespace = None
                subset = []
                parents = []
                newTerm = 0

        # putting the edges in the DAG

        for t in self.terms.values():
            for p in t.parents:
                self.addRelation(p,t.goId,"not in use")

        self.setRoots([])
        self.setLevels()


    def loadGenesFromFile(self,fileName):
        ''' loads GO gene annotation from a flat File '''

        if len(self.terms) == 0:
            print "You need to load GO ontologies first"
            return

        file = open(fileName,'r')
        lines = file.readlines()
        file.close()

        for l in lines:
            l = l.strip('\n\c')
            items = l.split('\t')

            if( len(items) > 1): 
    
                geneGoId = items[1]
                acc = items[2]
                qualifiers = items[3] # not used - find out what is this !!!

                annotatedGoTerm = items[4]
                annotationGoId = items[5] # can be two ???
                evidence = items[6]
                otherQualifier = items[7] # not used
                ontology = items[8]
                geneName = items[9]
                synonyms = items[10].split('|')
                synonyms.append(geneGoId)
                type = items[11]
                dataBase = items[14]

                if geneGoId not in self.genes:
                    self.genes[geneGoId] = Gene(geneGoId,acc,geneName,type,synonyms)
                 
                if annotatedGoTerm in self.terms:
                    gene = self.genes[geneGoId]
                    termAux = self.terms[annotatedGoTerm]
                    termAux.addAnnotation(Annotation(annotationGoId,evidence,dataBase,gene))
                else:
                    print "did not found ",annotatedGoTerm



def graphCopy(g1,g2):
    ''' Copy the Gato Graph state of g1 onto g2, in place; returns 0.

    Flags are copied shallowly, and the structural containers (vertices,
    adjacency lists, embedding, labeling, size, annotations) are deep-copied.
    The weight dictionaries and properties are new dicts; of the edge
    weights, only entry 0 is copied separately, and the other entries stay
    shared with g1.
    '''
    for attr in ('simple', 'euclidian', 'directed'):
        setattr(g2, attr, copy.copy(getattr(g1, attr)))

    g2.vertices = copy.deepcopy(g1.vertices[:])
    for attr in ('adjLists', 'invAdjLists', 'highVertexID',
                 'embedding', 'labeling'):
        setattr(g2, attr, copy.deepcopy(getattr(g1, attr)))

    weights = copy.copy(g1.edgeWeights.copy())
    weights[0] = copy.copy(g1.edgeWeights[0])
    g2.edgeWeights = weights
    g2.vertexWeights = copy.copy(g1.vertexWeights.copy())

    for attr in ('size', 'edgeWidth', 'vertexAnnotation', 'edgeAnnotation'):
        setattr(g2, attr, copy.deepcopy(getattr(g1, attr)))

    g2.properties = copy.copy(g1.properties.copy())
    return 0
        
if __name__ == '__main__':
    
    go = GODag()
    go.loadTermsFromFile('gene_ontology.obo')
    print "no of terms", len(go.terms)
    print "no of edges in the tree", go.Size()
    sizeOnt = go.Size()
    
    go.loadGenesFromFile('gene_association.sgd')
    
    print "no of genes", len(go.genes)
    sizeTotal = go.Size()
    print "last sizes", (sizeTotal - sizeOnt), sizeTotal
    
    counts = []    
    for i in go.terms.values():
        counts.append(len(i.annotations))
       
    print "gene - terms relations", sum(counts)
    
    import numpy
    
    counts = numpy.array(counts)
    
    print sum(counts == 0)
    print sum(counts < 3)
    print sum(counts > 500)

    go.setRoots([])
    go.setLevels()
