#!/usr/bin/env python2.3
################################################################################
#
#       This file is part of the GQL (Graphical Query Language) Toolkit
#
#       file:   GQL.py
#       author: Alexander Schliep (alexander@schliep.org)
#
#       Copyright (C) 2003-2004 Alexander Schliep
#
#       Contact: alexander@schliep.org
#
#       Information: http://ghmm.org/gql
#
#	GQL is free software; you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation; either version 2 of the License, or
#	(at your option) any later version.
#
#	GQL is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with GQL; if not, write to the Free Software
#	Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
#
#
#       This file is version $Revision: 2197 $
#                       from $Date: 2009-12-07 14:27:46 -0300 (Mon, 07 Dec 2009) $
#             last change by $Author: filho $.
#
################################################################################
#
#
#
#----- Globals -----------------------------------------------------------------
# Version number and build date of the GQL toolkit.
gGQLVersion = 0.6
gGQLBuilddate = "1/15/2003"
# URL template for an NCBI nucleotide search; the single %s placeholder is
# filled with an accession number.
gNCBIURL = "http://www.ncbi.nlm.nih.gov/entrez/query.fcgi?cmd=search&db=nucleotide&term=%s[accn]"
# Sentinel stored in a profile in place of a missing observation.
MISSING_DATA = -9999.99
# Sentinel that marks the end of a sequence (written as '9999.99' to .sqd files).
END_DATA = 9999.99

#-------------------------------------------------------------------------------
import os
import string
import numpy as Numeric
import random
import math
from ghmm import SequenceSet, Float

# missing data value

def minWithoutMissing(array):
    """Return the smallest entry of ``array``, ignoring missing-data markers.

    Entries equal to the module-level MISSING_DATA sentinel are skipped.
    ``min`` raises ValueError when no entry is left (empty input or input
    that consists only of missing values).
    """
    observed = [value for value in array if value != MISSING_DATA]
    return min(observed)

class ProfileSet:
    """Hold all information pertinent to a profile data set.

    - profiles are referenced via an integer (from 0 to #profiles - 1)
    - all additional info (accession, gene name, info text, out-link,
      class label) is stored in parallel lists indexed the same way
    - ``id2genename`` maps a gene name to the index of its profile
    - values equal to the module-level MISSING_DATA mark missing
      observations; values equal to END_DATA mark the end of a sequence
    """

    def __init__(self):
        self.new()

    def new(self):
        """(Re)initialise every attribute to the state of an empty set."""
        self.virgin = True   # stays True until a Read* method has filled the set

        self.size = 0

        self.profile = []
        self.info = []
        self.acc = []
        self.genename = []
        self.id2genename = {}
        self.outlink = []
        self.likelihood = []
        self.seq_classes = []   # class label per sequence (-1: no label)
        self.classes_no = 0     # number of distinct classes
        self.max_class = 0      # largest class label (at least 1 if any label exists)
        self.classes = []       # distinct class labels, in order of appearance

        self.xrange = [0, 0]
        # A real infinity instead of the string 'Inf': the string only worked
        # because CPython 2 happens to order numbers before strings.
        self.yrange = [float('inf'), 0]
        self.likelihoodrange = [0, 0]

        self.ghmm_seqs = None   # a SequenceSet object.
        self.id2Index = {}
        self.fileName = None
        self.missingRate = 0.0

    def beforeRead(self):
        """Drop previously loaded data so that a Read* method starts clean."""
        if not self.virgin:
            # new() rebinds every attribute, which releases the old containers.
            #XXX del(self.ghmm_seqs.free())
            self.new()

    def afterRead(self):
        """Rebuild the gene name -> profile index dictionary."""
        for i in range(self.size):
            self.id2genename[self.genename[i]] = i

    def ReadDataFromCaged(self, fileName, end=1):
        """Read a tab-delimited 'caged' file and load it as a SequenceSet.

        File format (and a Matlab compatible variant without header):
          - optional first line containing 'dim'; its second tab-separated
            field replaces ``end``
          - optional header line containing 'Gene'
          - data lines: name <tab> info/class <tab> value1 ... valueN
        Empty fields, '_M' and 'NA' are read as MISSING_DATA; rows that
        contain nothing but missing values are skipped.

        As a side effect ``fileName + ".sqd"`` is written; every sequence in
        it is followed by ``end`` end markers (9999.99).

        Returns the SequenceSet built from the .sqd file.
        """
        self.fileName = fileName
        self.beforeRead()

        infile = open(fileName, 'r')
        try:
            lines = infile.readlines()
        finally:
            infile.close()

        if lines[0].find('dim') > -1:
            end = int(lines[0].split('\t')[1])
            lines = lines[1:]

        self.headerLine = lines[0]
        if "Gene" in self.headerLine:
            start_i = 1
        else:
            print("ReadDataFromCaged: No Header, assuming Matlab compat")
            start_i = 0

        missingTokens = ('', '_M', 'NA')   # XXX Fix me ! (spellings of 'missing')
        nrTimeSteps = len(lines[0].split('\t')) - 2

        if end > 0:
            lineEnd = ', ' + ', '.join(['9999.99'] * end) + ';\n'
        else:
            lineEnd = ';\n'

        outfile = open(fileName + ".sqd", 'w')
        try:
            outfile.write("SEQD = {\n      O = {\n")

            for l in lines[start_i:]:
                items = l.split('\t')
                items[-1] = items[-1].strip('\r\n')

                # Fix too-short sequences: values missing at the end and no tabs
                shortBy = nrTimeSteps - len(items[2:])
                if shortBy > 0:
                    items += [''] * shortBy

                values = []
                for x in items[2:]:
                    if x in missingTokens:
                        values.append(MISSING_DATA)
                    else:
                        values.append(x)

                if max([float(v) for v in values]) == MISSING_DATA:
                    continue   # row contains only missing data

                self.genename.append(items[0].strip())
                self.acc.append(items[0].strip())
                self.info.append(items[1].strip())

                # If the info column is just a number, assume it is the class
                # of the sequence.
                try:
                    class_aux = int(items[1].strip())
                except ValueError:
                    class_aux = 0                 # label written to the .sqd file
                    self.seq_classes.append(-1)   # no class information ...
                else:
                    self.seq_classes.append(class_aux)
                    if class_aux not in self.classes:
                        self.classes.append(class_aux)

                # Profiles are counted from 1 in the .sqd file and the out-link.
                number = len(self.genename)
                outfile.write('(' + str(number) + ') |' + str(class_aux) + ' | ')
                outfile.write(', '.join(map(str, values)))
                outfile.write(lineEnd)

                # The former separate counter was overwritten by a list
                # comprehension variable (which leaks in Python 2), so every
                # link carried the wrong number.
                self.outlink.append("www.somedb.org/%d.html" % number)

            self.classes_no = len(self.classes)
            if self.classes == []:
                self.max_class = 0
            else:
                self.max_class = max(1, max(self.classes))

            outfile.write("};};")
        finally:
            outfile.close()

        self.ghmm_seqs = SequenceSet(Float(), fileName + ".sqd")

        for i in range(len(self.ghmm_seqs)):
            p = list(self.ghmm_seqs[i])
            # NOTE: Sqd files are already tagged with 9999 value as end marker
            self.addProfile("#%d" % i, p[:(len(p) - end)], "info about %d" % i)

        self.afterRead()
        self.virgin = False
        self.missingDataRate()
        return self.ghmm_seqs

    def ReadDataFromDSequences(self, ghmm_seqs, fileName):
        """Load the set from an existing sequence set object.

        Accession, gene name and out-link are synthesised from the index;
        the sequence label (``getSeqLabel``) becomes the info text and, if it
        converts to an int, the class of the sequence.
        """
        print("read data from ds")
        self.fileName = fileName
        self.beforeRead()
        self.ghmm_seqs = ghmm_seqs
        for i in range(len(ghmm_seqs)):
            p = list(ghmm_seqs[i])
            # NOTE: Sqd files are already tagged with 9999 value as end marker
            self.addProfile("#%d" % i, p[:-1], "info about %d" % i)
            self.acc.append("acc%d" % i)
            self.genename.append("genename%d" % i)
            self.info.append(ghmm_seqs.getSeqLabel(i))
            self.outlink.append("www.somedb.org/%d.html" % i)
            try:
                class_aux = int(ghmm_seqs.getSeqLabel(i))
            except ValueError:
                self.seq_classes.append(-1)   # no class information ...
            else:
                print("read class %s" % (class_aux,))
                self.seq_classes.append(class_aux)
                if class_aux not in self.classes:
                    self.classes.append(class_aux)

        if self.classes == []:
            self.max_class = 0
        else:
            self.max_class = max(1, max(self.classes))
        self.classes_no = len(self.classes)
        self.missingDataRate()
        self.afterRead()
        self.virgin = False

    def missingDataRate(self):
        """Compute ``self.missingRate``: missing values / data values.

        Every value that is not an END_DATA marker counts as a data slot, so
        the rate does not depend on the number of end markers per sequence or
        on all sequences having the same length.  An empty set has rate 0.0.
        """
        missingCount = 0.0
        totalCount = 0
        for sequence in self.ghmm_seqs:
            for value in sequence:
                if value == END_DATA:
                    continue
                totalCount += 1
                if value == MISSING_DATA:
                    missingCount += 1.0

        if totalCount == 0:
            self.missingRate = 0.0
            return

        self.missingRate = missingCount / totalCount
        print("Missing data rate %f" % self.missingRate)
        print("%s %s" % (len(self.ghmm_seqs[0]), len(self.ghmm_seqs)))

    def addProfile(self, id, profile, info, classe=-1):
        """Append ``profile`` and update the x/y ranges.

        ``id`` and ``info`` are accepted for interface compatibility but not
        stored.  If ``classe`` is not -1 it is appended to ``seq_classes``.
        The gene name -> index entry is only created when the gene name for
        the new index has already been appended to ``self.genename``.
        """
        # Determine min and max of the y coordinate from observed values only;
        # a profile without any observed value leaves the range untouched
        # instead of raising from min() on an empty list.
        observed = [v for v in profile if v != MISSING_DATA]
        if observed:
            self.yrange[0] = min(self.yrange[0], min(observed))
            self.yrange[1] = max(self.yrange[1], max(observed))

        # Determine maximal x-value!
        self.xrange[1] = max(self.xrange[1], len(profile) - 1)

        self.profile.append(profile)
        try:
            self.id2genename[self.genename[self.size]] = self.size
        except IndexError:
            # Gene name not known yet; afterRead() fills the dictionary later.
            pass

        self.size += 1
        if classe != -1:
            self.seq_classes.append(classe)

    def __getitem__(self, i):
        return self.profile[i]

    def Info(self, i):
        """Return a one-line description of profile ``i``."""
        return "Acc#%s, its genename is %s - %s" % (self.acc[i], self.genename[i], self.info[i],)

    def __len__(self):
        return self.size

    def GHMMProfileSet(self):   # XXX Should be GHMMSequenceSet
        return self.ghmm_seqs

    def getclass(self, i):
        """Return the class label of profile ``i`` (-1: no label)."""
        return self.seq_classes[i]

    def tabDelim(self, query, delim='\t'):
        """Return id, likelihood, acc, genename, timepoint1 ... timepointn
        of every profile in ``query.Result()``, one line per profile,
        separated by ``delim``."""
        rows = []
        for i in query.Result():
            head = "%s%s%2.2f%s%s%s%s%s" % (str(i), delim,
                                            query.Likelihood(i), delim,
                                            self.acc[i], delim,
                                            self.genename[i], delim)
            # Don't want final 9999.9
            # NOTE(review): the Read* methods already strip the end marker
            # before storing a profile, so this slice may be dropping the last
            # real time point - confirm against callers before changing.
            rows.append(head + delim.join(map(str, self.profile[i][:-1])) + "\n")
        return "".join(rows)

    def getSubset(self, ids):
        """Return a new ProfileSet holding the profiles with indices ``ids``."""
        subset = ProfileSet()
        subset.virgin = False
        for i in ids:
            subset.genename.append(self.genename[i])
            subset.acc.append(self.acc[i])
            subset.info.append(self.info[i])
            subset.outlink.append(self.outlink[i])
            subset.addProfile("#%d" % i, self.profile[i], "info about %d" % i)
            subset.seq_classes.append(self.seq_classes[i])

        subset.classes_no = self.classes_no
        subset.max_class = self.max_class
        subset.classes = list(self.classes)   # copy: do not share the list
        subset.fileName = self.fileName
        subset.ghmm_seqs = self.ghmm_seqs.getSubset(ids)
        subset.missingDataRate()
        return subset

    def getStandardizedData(self, collumns=0):
        """Return a standardized copy of this set (see standardize())."""
        res = self.getSubset(list(range(len(self))))
        res.standardize(collumns=collumns)
        return res   # was missing: the copy used to be thrown away

    def standardize(self, collumns=0):
        """Standardize the profiles in place to zero mean and unit deviation.

        With ``collumns == 0`` every column (position across profiles) is
        standardized, otherwise every profile.  Missing values are ignored
        for mean and deviation and stay MISSING_DATA.  A row without spread
        is only centred (no division by zero); a row without any observed
        value is left untouched.
        """
        if not self.profile:
            return

        data = Numeric.array(self.profile, dtype=float)
        byColumn = (collumns == 0)
        if byColumn:
            print('aqui')   # debug trace, kept for output compatibility
            data = Numeric.transpose(data)
            print("%s %s" % (len(data), len(data[0])))

        for i, row in enumerate(data):
            observed = (row != MISSING_DATA)
            if not observed.any():
                continue
            values = row[observed]
            m = values.mean()
            s = values.std()
            if s == 0:
                s = 1.0
            data[i, observed] = (values - m) / s

        if byColumn:
            print('aqui')
            data = Numeric.transpose(data)
            print("%s %s" % (len(data), len(data[0])))

        self.profile = data.tolist()

    def getScaledData(self, collumns=0):
        """Return a scaled copy of this set (see scale())."""
        res = self.getSubset(list(range(len(self))))
        res.scale(collumns=collumns)
        return res   # was missing: the copy used to be thrown away

    def scale(self, collumns=0, min=0, max=1):
        """Scale the profiles in place linearly onto [min, max].

        With ``collumns == 0`` every column (position across profiles) is
        scaled, otherwise every profile.  Missing values are ignored when
        determining the observed range and stay MISSING_DATA.  A row whose
        observed values are all equal is mapped to ``min``.
        """
        if not self.profile:
            return

        data = Numeric.array(self.profile, dtype=float)
        byColumn = (collumns == 0)
        if byColumn:
            data = Numeric.transpose(data)

        for i, row in enumerate(data):
            observed = (row != MISSING_DATA)
            if not observed.any():
                continue
            # Only the observed entries are rewritten.  The former code built
            # the result from the missing-stripped values but indexed and
            # assigned it with full-row positions, which breaks as soon as a
            # row contains a missing value.
            values = row[observed]
            minpoint = values.min()
            span = values.max() - minpoint
            if span == 0:
                span = 1.0
            data[i, observed] = ((values - minpoint) / span) * (max - min) + min

        if byColumn:
            data = Numeric.transpose(data)

        self.profile = data.tolist()

    def fold(self):
        """Replace every profile in place by its natural logarithm.

        MISSING_DATA and END_DATA markers are preserved.  Each profile is
        stored as a numpy array afterwards.
        """
        for i, profile in enumerate(self.profile):
            p = Numeric.array(profile, dtype=float)
            missing = (p == MISSING_DATA)
            end = (p == END_DATA)
            p[missing] = 1.0   # neutral placeholder so that log() is defined
            pstand = Numeric.log(p)
            pstand[missing] = MISSING_DATA
            pstand[end] = END_DATA
            self.profile[i] = pstand

    def meanValues(self, dimension=1):
        """Return the mean of the observed values per component.

        The columns are treated as interleaved components: component ``j``
        consists of columns j, j + dimension, j + 2*dimension, ...  If the
        data contains END_DATA markers, the last ``dimension`` columns are
        excluded.  Returns a list with ``dimension`` means.
        """
        p = Numeric.array(self.profile)
        # Presence test; the former sum over the marker indices depended on
        # where the markers were, not on whether there were any.
        hasEnd = int((p == END_DATA).any())
        totalDim = len(p[0])
        means = []
        for j in range(dimension):
            sel = list(range(j, totalDim - hasEnd * dimension, dimension))
            psel = p[:, sel]
            means.append(Numeric.mean(psel[psel != MISSING_DATA]))
        return means

    def fillMissing(self, dimension=1):
        """Replace missing values in place by the mean of their component.

        Within each profile (up to its first END_DATA marker) component ``j``
        consists of positions j, j + dimension, ...; a missing value is
        replaced by the mean of the observed values of its component in that
        profile.  A component without observed values stays missing.  For
        ``dimension > 1`` a point whose components are all missing is kept as
        missing.  Each profile is stored as a numpy array afterwards.
        """
        for i, profile in enumerate(self.profile):
            p = Numeric.array(profile, dtype=float)
            end = Numeric.nonzero(p == END_DATA)[0]
            if len(end) > 0:
                paux = p[:end[0]]
            else:
                paux = p
            size = len(paux)
            points = size // dimension
            pstand = p.copy()

            for j in range(dimension):
                sel = list(range(j, size, dimension))
                psel = paux[sel]   # fancy indexing: psel is a copy
                observed = psel[psel != MISSING_DATA]
                if len(observed) == 0:
                    continue   # nothing to average; would fill with NaN
                mean = Numeric.mean(observed)
                psel[psel == MISSING_DATA] = mean
                pstand[sel] = psel
                print("mean %s %s" % (mean, sel))

            if dimension > 1:
                for j in range(points):
                    sel = list(range(j * dimension, (j + 1) * dimension))
                    missingHere = int(Numeric.sum(p[sel] == MISSING_DATA))
                    print("%s %s" % (sel, missingHere))
                    if missingHere == dimension:
                        pstand[sel] = MISSING_DATA
            self.profile[i] = pstand

    def getIndexGenes(self, genes):
        """Return the profile indices of ``genes``; unknown names are
        reported on stdout and skipped."""
        ids = []
        for g in genes:
            try:
                ids.append(self.id2genename[g])
            except KeyError:
                print("Gene not found %s" % (g,))
        return ids

    def getIndexGene(self, gene):
        """Return the profile index of ``gene``; raises KeyError if unknown."""
        return self.id2genename[gene]





# When run as a script, start the GQL application; GQLApp is imported here
# rather than at the top of the file, so importing this module does not pull
# it in.
if __name__ == '__main__':
    import GQLApp
    app = GQLApp.GQLApp()
    app.mainloop()


