import mixture
import mixtureLinearGaussian
import TabDataSet
import numpy
import random
import copy
import getopt
import os
import sys
import math
import matplotlib
#matplotlib.use("Agg")
import pylab
import markup
import shutil

def runMixture(minCluster, maxCluster, data, dist, repetitions, iterations, stopCriteria, bayesian, folds):
    """
    Fit one mixture model for every number of components in a range.

    For each k in [minCluster, maxCluster] a k-component mixture of `dist`
    is built (mixtureComponents). When folds > 1 it is first scored by
    cross-validation (crossvalidationregression). It is then fitted on the
    full data with replicated EM (estimateWithReplication) and used to
    classify the data. Model-selection criteria are finally computed over
    all fitted models with mixture.modelSelection.

    @param minCluster: cluster minimal number
    @param maxCluster: cluster maximum number (inclusive)
    @param data: DataSet object
    @param dist: ProductDistribution of one component
    @param repetitions: number of EM repetitions (random restarts)
    @param iterations: EM iterations
    @param stopCriteria: EM stop criteria
    @param bayesian: perform MAP estimation instead of maximum likelihood
    @param folds: folds for cross-validation; cross-validation runs only
                  when folds > 1

    @return: tuple (models, NEC, BIC, AIC, classifies, correlations, stds,
             errors, estds). models, classifies, correlations, stds, errors
             and estds have one entry per k, in increasing k. The last four
             hold the cross-validated mean/std of the correlation and error,
             or 0 when cross-validation was not run.

    Example (command line): fit models with up to 3 components using 2 EM
    iterations per run:
        python mixtureTools.py -m 3 -i 2 th1_mod.txt
    """
    models = []
    classifies = []
    correlations = []
    stds = []
    errors = []
    estds = []

    for k in range(minCluster, maxCluster + 1):
        print("#################### MIXTURE CLUSTER %d ####################" % k)
        # untrained mixture model for k components with distribution dist
        train = mixtureComponents(k, dist, bayesian)

        # Cross-validation scores stay 0 when cross-validation is disabled.
        correlation, std, error, estd = 0, 0, 0, 0
        if folds > 1:
            correlation, std, error, estd = crossvalidationregression(
                train, data, repetitions, iterations, stopCriteria, bayesian, folds)

        # best of `repetitions` EM runs on the full data set
        bestmix = estimateWithReplication(train, data, repetitions, iterations, stopCriteria, bayesian)

        models.append(bestmix)
        classifies.append(bestmix.classify(data, silent=1))
        correlations.append(correlation)
        stds.append(std)
        errors.append(error)
        estds.append(estd)

    NEC, BIC, AIC = mixture.modelSelection(data, models)
    return models, NEC, BIC, AIC, classifies, correlations, stds, errors, estds


def cv_samples(folds, n):
    """
    Randomly partition the sample indices 0..n-1 into `folds` disjoint folds.

    The indices are shuffled with random.shuffle and dealt out so that fold
    sizes differ by at most one: the first n % folds folds get one extra
    index. Every fold is returned sorted. When folds > n, the trailing folds
    are empty.

    @param folds: number of folds (must be >= 1)
    @param n: number of samples

    @return: list of `folds` sorted lists of sample indices
    @raise ValueError: if folds < 1
    """
    if folds < 1:
        raise ValueError("folds must be >= 1, got %r" % (folds,))

    # Floor division: the fold sizes must add up to exactly n.
    base_size, rest = divmod(n, folds)
    # list() so the sequence is mutable (and shuffleable) on Python 3 too
    indices = list(range(n))
    random.shuffle(indices)

    results = []
    for i in range(folds):
        size = base_size + 1 if i < rest else base_size
        fold = [indices.pop() for _ in range(size)]
        fold.sort()
        results.append(fold)
    return results
	
	    

def crossvalidationregression(mix, data, repetitions, iterations, stopCriteria, bayesian, folds):
    """
    Score a mixture regression model by N-fold cross-validation.

    The samples are split into `folds` random folds with cv_samples. For each
    fold, a copy.copy of `mix` is fitted with estimateWithReplication on all
    other samples. It is then evaluated with
    mixtureLinearGaussian.evaluateRegression on the held-out fold.

    @param mix: untrained mixture model used as template for every fold
    @param data: DataSet object with all samples
    @param repetitions: number of EM repetitions per fold
    @param iterations: EM iterations
    @param stopCriteria: EM stop criteria
    @param bayesian: perform MAP estimation instead of maximum likelihood
    @param folds: number of folds

    @return: [mean correlation, std of correlation, mean error, std of error]
             over the folds
    """
    fold_correlations = []
    fold_errors = []
    for sample in cv_samples(folds, data.N):
        # training set: every sample except the ones in this fold
        traindata = copy.copy(data)
        traindata.removeSamples([data.sampleIDs[s] for s in sample])

        trainmix = copy.copy(mix)
        traindata.internalInit(trainmix)
        bestmix = estimateWithReplication(trainmix, traindata, repetitions, iterations, stopCriteria, bayesian)

        # test set: the complement of the training set, i.e. this fold
        test = copy.copy(data)
        test.removeSamples(traindata.sampleIDs)
        test.internalInit(bestmix)

        [r, p, ypred, y, m, s, e, genes] = mixtureLinearGaussian.evaluateRegression(bestmix, test)

        print("test %s %s %s %s %s" % (r, e, len(data), len(traindata), len(test)))
        fold_correlations.append(r)
        fold_errors.append(e)
    return [numpy.mean(fold_correlations), numpy.std(fold_correlations),
            numpy.mean(fold_errors), numpy.std(fold_errors)]



def mixtureComponents(n, dist, bayesian):
    """
    Build an untrained mixture of n components, all given by `dist`.

    The initial mixture weights are uniform (1/n each). The component list is
    [dist] * n, i.e. the same object n times. NOTE(review): whether the
    mixture classes copy it is defined in mixture.py; confirm before relying
    on per-component independence.

    @param n: number of components
    @param dist: ProductDistribution of one component
    @param bayesian: if true, build a BayesMixtureModel with the fixed priors
                     below (for MAP estimation); otherwise a MixtureModel

    @return: MixtureModel (or BayesMixtureModel) with n components of dist distribution
    """
    pi = [1. / n] * n
    components = [dist] * n

    if not bayesian:
        return mixture.MixtureModel(n, pi, components, identifiable=0)

    # Fixed prior hyper-parameters: DirichletPrior gets 4.0 per component
    # (presumably the Dirichlet parameters on the weights) and the
    # linear-Gaussian prior gets 0.5 per component. NOTE(review): the meaning
    # of the two 1.0 arguments of MixtureModelPrior is defined in mixture.py.
    pipr = mixture.DirichletPrior(n, [4.0] * n)
    sp1 = mixtureLinearGaussian.LinearGaussianPriorDistribution([0.5] * n)
    prior = mixture.MixtureModelPrior(1.0, 1.0, pipr, [sp1])
    return mixture.BayesMixtureModel(n, pi, components, prior, identifiable=0)

def estimateWithReplication(_mixture, data, repetitions, iterations, stopCriteria, bayesian):
    """
    Run EM from several random starts and return the best replicate.

    Each replicate is a copy.copy of _mixture. It is re-initialised on data
    with modelInitialization and fitted with mapEM when bayesian is true, or
    EM otherwise. The replicate whose EStep reports the highest
    log-likelihood is returned. Replicates raising
    mixture.ConvergenceFailureEM are skipped silently. Replicates raising
    mixtureLinearGaussian.EmptyComponent are skipped after printing
    "Empty Component".

    @param _mixture: MixtureModel object used as template
    @param data: DataSet object
    @param repetitions: number of repetitions
    @param iterations: EM iterations
    @param stopCriteria: EM stop criteria
    @param bayesian: use MAP estimation (mapEM) instead of maximum likelihood (EM)

    @return: the replicate with maximal log-likelihood
    @raise mixture.ConvergenceFailureEM: if no replicate finished successfully
    """
    best = None
    best_log_l = -float('inf')
    for _ in range(repetitions):
        candidate = copy.copy(_mixture)
        try:
            candidate.modelInitialization(data)
            fit = candidate.mapEM if bayesian else candidate.EM
            fit(data, iterations, stopCriteria, silent=0)
            _, log_l = candidate.EStep(data)
        except mixture.ConvergenceFailureEM:
            continue
        except mixtureLinearGaussian.EmptyComponent:
            print("Empty Component")
            continue
        if log_l > best_log_l:
            best_log_l = log_l
            best = candidate
    # `is None` rather than comparing a model object against a sentinel list
    if best is None:
        raise mixture.ConvergenceFailureEM("Convergence failed.")
    return best

def usage():
    """
    Print the command-line usage information for this program.

    The documented defaults match the ones applied in main().
    """
    print("""
Usage: mixtureTools.py [options] <input_data_file.txt>
Options:
    -h, --help               Print this help message
    -m, --max-cluster N      Cluster maximum number (required)
    -r, --repetitions N      Number of repetitions (default 15)
    -i, --iterations N       EM iterations (default 10)
    -s, --stop-criteria X    EM stop criteria (default 0.1)
    -b, --bayesian           Bayesian estimates
    -f, --folds N            Folds for cross-validation (default 1, i.e. no cross-validation)
    -o, --output DIR         Path dir output (default res)
""")

def _parse_command_line(argv):
    """
    Parse the command-line arguments of this program.

    Prints the usage message and exits with status 0 for -h/--help. Exits
    with status 2 when an option is unknown, an option value is not a
    number, -m is missing or smaller than 1, or no input file is given.

    @param argv: argument list without the program name (normally sys.argv[1:])

    @return: dict with the keys max_cluster, repetitions, iterations,
             stop_criteria, bayesian, folds, output_path and filename
    """
    try:
        # Short options taking a value end with ':'; long ones with '='.
        opts, args = getopt.getopt(
            argv, 'hm:r:i:s:f:o:b',
            ['help', 'max-cluster=', 'repetitions=', 'iterations=',
             'stop-criteria=', 'folds=', 'output=', 'bayesian'])
    except getopt.GetoptError as err:
        print(str(err))
        usage()
        sys.exit(2)

    options = {
        'max_cluster': None,
        'repetitions': 15,
        'iterations': 10,
        'stop_criteria': 0.1,
        'bayesian': 0,
        'folds': 1,
        'output_path': 'res',
        'filename': None,
    }
    try:
        for o, a in opts:
            if o in ('-h', '--help'):
                usage()
                sys.exit(0)
            elif o in ('-m', '--max-cluster'):
                options['max_cluster'] = int(a)
            elif o in ('-r', '--repetitions'):
                options['repetitions'] = int(a)
            elif o in ('-i', '--iterations'):
                options['iterations'] = int(a)
            elif o in ('-s', '--stop-criteria'):
                options['stop_criteria'] = float(a)
            elif o in ('-f', '--folds'):
                options['folds'] = int(a)
            elif o in ('-o', '--output'):
                options['output_path'] = str(a)
            elif o in ('-b', '--bayesian'):
                options['bayesian'] = 1
    except ValueError as err:
        print("Invalid option value: %s" % err)
        usage()
        sys.exit(2)

    if options['max_cluster'] is None or options['max_cluster'] < 1 or not args:
        usage()
        sys.exit(2)
    # the last positional argument is the input data file
    options['filename'] = args[-1]
    return options


def _ensure_output_dirs(output_path):
    """Create output_path and its pdf/, genes/, m/ and png/ sub-directories if missing."""
    paths = [output_path] + [os.path.join(output_path, sub) for sub in ('pdf', 'genes', 'm', 'png')]
    for path in paths:
        if not os.path.isdir(path):
            os.makedirs(path)


def _plot_model_selection(clusters, BIC, AIC, output_path, name):
    """Save BIC and AIC versus number of components (log y-axis) as pdf/<name>.pdf and png/<name>.png."""
    pylab.figure()
    pylab.semilogy(clusters, BIC, label="BIC")
    pylab.semilogy(clusters, AIC, label="AIC")
    pylab.xlabel('clusters')
    pylab.legend()
    pylab.savefig(os.path.join(output_path, "pdf", name + ".pdf"))
    pylab.savefig(os.path.join(output_path, "png", name + ".png"))
    # release the figure instead of relying on pylab.hold (removed in newer matplotlib)
    pylab.close()


def _write_gene_pages(k, genes, tag, output_path, styles):
    """
    Write the gene lists of the k-component model.

    genes/<tag>.genes gets one space-separated line per component. Each
    component j also gets an HTML page gene_cluster_<k>_<j>.html linking
    every gene to an NCBI Entrez gene query.
    """
    ncbi_query = "http://www.ncbi.nlm.nih.gov/entrez/query.fcgi?db=gene&dopt=summary&term=mouse[organism]+AND+"
    gene_lines = []
    for j, gs in enumerate(genes):
        gene_lines.append(' '.join(gs) + '\n')

        title = "Genes " + tag + '_cluster_' + str(j + 1)
        page = markup.page()
        page.init(title=title, css=styles)
        page.h1(title)
        page.table.open(border=1)
        page.thead.open()
        page.th('Genes')
        page.thead.close()
        page.tbody.open()
        for g in gs:
            page.tr.open()
            page.td.open()
            page.a(str(g), href=ncbi_query + str(g))
            page.td.close()
            page.tr.close()
        page.tbody.close()
        page.table.close()

        html_path = os.path.join(output_path, "gene_cluster_" + str(k) + "_" + str(j + 1) + ".html")
        with open(html_path, "w") as html:
            html.write(str(page))

    with open(os.path.join(output_path, "genes", tag + ".genes"), 'w') as gene_file:
        gene_file.writelines(gene_lines)


def _write_model_report(i, model, data, headers, base, bayesian, output_path, styles):
    """
    Evaluate models[i] (k = i+1 components) on data and write its reports.

    Files (relative to output_path, tag = <base>_<k>_<bayesian>):
      <tag>_cluster.res          tab-separated b0/coefficients/Sigma per component
      m/<tag>.m                  coefficients as MATLAB assignments x(:,j)=[...]
      result_cluster_<k>.html    the same table plus image and gene-page links
      genes/, gene_cluster_*     see _write_gene_pages
      pdf|png/<tag>_predicted.*  observed versus predicted expression

    @return: (r, error) as returned by mixtureLinearGaussian.evaluateRegression
    """
    k = i + 1
    tag = base + "_" + str(k) + "_" + str(bayesian)

    print(model)
    [r, p, predy, y, means, stdsaux, error, genes] = mixtureLinearGaussian.evaluateRegression(model, data)

    # Columns after the intercept b0 are named after headers[1:].
    column_names = [str(h) for h in headers[1:]]

    title = "Results " + tag + '_cluster'
    page = markup.page()
    page.init(title=title, css=styles)
    page.h1(title)
    page.table.open(border=1)
    page.thead.open()
    page.th('Results')
    page.thead.close()
    page.tbody.open()
    page.tr.open()
    page.td("#comp")
    page.td("b0")
    for name in column_names:
        page.td(name)
    page.td("sigma")
    page.tr.close()

    res_lines = ['#comp\tb0' + ''.join('\t' + name for name in column_names) + '\tSigma\n']
    m_lines = []
    for j, component in enumerate(model.components):
        regression = component.distList[0]
        beta = regression.beta.tolist()
        sigma = str(regression.sigma[0])

        m_lines.append("x(:," + str(j + 1) + ")=" + str(beta) + '\n')
        res_lines.append(str(j + 1) + ''.join('\t' + str(b) for b in beta) + '\t' + sigma + '\n')

        page.tr.open()
        page.td(str(j + 1))
        for b in beta:
            page.td(str(b))
        page.td(sigma)
        page.tr.close()
    res_lines.append('\n')

    page.tbody.close()
    page.table.close()
    page.img(src="png/" + tag + "_predicted.png")
    if i != 0:
        page.img(src="tif/" + base + "_tf_" + str(k) + "_" + str(bayesian) + "_cluster.tif")
    page.h3("Genes")
    for j, _ in enumerate(genes):
        page.a("Cluster " + str(j + 1), href="gene_cluster_" + str(k) + "_" + str(j + 1) + ".html")

    with open(os.path.join(output_path, tag + '_cluster.res'), 'w') as res:
        res.writelines(res_lines)
    with open(os.path.join(output_path, "m", tag + ".m"), 'w') as m_file:
        m_file.writelines(m_lines)
    with open(os.path.join(output_path, "result_cluster_" + str(k) + ".html"), "w") as html:
        html.write(str(page))

    _write_gene_pages(k, genes, tag, output_path, styles)

    print("regression %s %s" % (r, p))
    print("means %s" % (means,))
    print("std %s" % (stdsaux,))
    print("error %s" % (error,))

    # y (observed) is on the x-axis and predy (predicted) on the y-axis.
    pylab.figure()
    pylab.plot(y, predy, '.')
    pylab.xlabel('expression')
    pylab.ylabel('predicted expression')
    pylab.title(base + " #Comps: " + str(k) + " " + str(bayesian) + " - corr " + str(r))
    pylab.savefig(os.path.join(output_path, "pdf", tag + "_predicted.pdf"))
    pylab.savefig(os.path.join(output_path, "png", tag + "_predicted.png"))
    pylab.close()
    return r, error


def _write_summary(path, rows):
    """Write one tab-separated line "<label>\\t<v1>\\t<v2>..." per (label, values) pair."""
    with open(path, 'w') as res:
        for label, values in rows:
            res.write(label + ''.join('\t' + str(v) for v in values) + '\n')


def _write_index_page(output_path, base, bayesian, styles, clusters, arguments):
    """Write index.html: BIC/AIC plot, links to every per-model page and a table of the run arguments."""
    title = "Results " + base + "_" + str(bayesian)
    page = markup.page()
    page.init(title=title, css=styles)
    page.h1(title)

    page.img(src="png/" + base + "_" + str(bayesian) + "_BIC_AIC.png")
    page.h3("Clusters Results")
    for c in clusters:
        page.a("Cluster " + str(c), href="result_cluster_" + str(c) + ".html")

    page.h2("Arguments")
    page.table.open(border=1)
    page.tbody.open()
    for label, value in arguments:
        page.tr.open()
        page.td(label)
        page.td(str(value), width="50", align="right")
        page.tr.close()
    page.tbody.close()
    page.table.close()

    with open(os.path.join(output_path, "index.html"), "w") as index_file:
        index_file.write(str(page))


def main():
    """
    Command-line driver: fit mixtures of linear Gaussian regressions with
    1..max_cluster components to the input file and write the reports.

    Outputs under the output directory (default 'res'), where <base> is the
    input file name without directory and extension, and <b> is the bayesian
    flag (0/1):
      <base>_<b>_classifies_<k>.res  data annotated with the classification of
                                     the lowest-BIC model (k components)
      <base>_<b>.res                 per-model correlations/errors and NEC/BIC/AIC
      pdf|png/<base>_<b>_BIC_AIC.*   BIC and AIC versus number of components
      index.html                     summary page linking the per-model pages
    plus the per-model files written by _write_model_report.
    """
    options = _parse_command_line(sys.argv[1:])
    max_cluster = options['max_cluster']
    repetitions = options['repetitions']
    iterations = options['iterations']
    stop_criteria = options['stop_criteria']
    bayesian = options['bayesian']
    folds = options['folds']
    output_path = options['output_path']
    filename = options['filename']
    min_cluster = 1

    if not os.path.exists(filename):
        print("Input data file does not exist.")
        sys.exit(2)
    data = TabDataSet.TabDataSet()
    data.fromFile(filename)

    # Name outputs after the bare file name, so an input given with a
    # directory does not point into a missing sub-directory of output_path.
    base = os.path.splitext(os.path.basename(filename))[0]

    _ensure_output_dirs(output_path)

    # Template component: linear Gaussian with data.p random N(0,1)
    # initial coefficients and sigma [1].
    beta = [random.normalvariate(0, 1) for _ in range(data.p)]
    dist = mixture.ProductDistribution(
        [mixtureLinearGaussian.LinearGaussianDistribution(data.p, beta, [1])])

    (models, NEC, BIC, AIC, classifies,
     correlations, stds, errors, estds) = runMixture(
        min_cluster, max_cluster, data, dist, repetitions, iterations,
        stop_criteria, bayesian, folds)

    # models[i], classifies[i] and BIC[i] all belong to clusters[i] components.
    clusters = list(range(min_cluster, max_cluster + 1))
    bestmodel = list(BIC).index(min(BIC))
    BIC = numpy.array(BIC)
    AIC = numpy.array(AIC)
    _plot_model_selection(clusters, BIC, AIC, output_path, base + "_" + str(bayesian) + "_BIC_AIC")

    if bestmodel == 0:
        # a single component: every sample belongs to component 0
        data.setNotes([0] * data.N)
    else:
        # same index as BIC, so the notes match clusters[bestmodel] in the file name
        data.setNotes(classifies[bestmodel])
    data.writeFile(os.path.join(
        output_path,
        base + "_" + str(bayesian) + "_classifies_" + str(clusters[bestmodel]) + ".res"))

    headers = data.getHeaders()
    styles = ('layout.css', 'table.css')

    for i, model in enumerate(models):
        r, error = _write_model_report(i, model, data, headers, base, bayesian, output_path, styles)
        if folds <= 1:
            # runMixture only cross-validates when folds > 1; otherwise report
            # the in-sample fit instead of its placeholder zeros.
            correlations[i] = r
            stds[i] = 0.0
            errors[i] = error
            estds[i] = 0.0

    print("\n")
    print("")
    print("##################### Correlations ######################\n%s\n%s\n%s\n%s"
          % (correlations, stds, errors, estds))
    print("##################### NEC ######################\n%s\n" % (NEC,))
    print("##################### BIC ######################\n%s\n" % (BIC,))
    print("##################### AIC ######################\n%s\n" % (AIC,))

    _write_summary(
        os.path.join(output_path, base + "_" + str(bayesian) + '.res'),
        [('Correlations', correlations), ('Stds', stds), ('Errors', errors),
         ('Estds', estds), ('NEC', NEC), ('BIC', BIC), ('AIC', AIC)])

    _write_index_page(output_path, base, bayesian, styles, clusters, [
        ("max cluster", max_cluster),
        ("repetitions", repetitions),
        ("iterations", iterations),
        ("stop criteria", stop_criteria),
        ("bayesian", bayesian),
        ("folds", folds),
        ("input data file", filename),
    ])

    # The HTML pages reference these stylesheets by relative path; they are
    # taken from the current working directory.
    for css in styles:
        if os.path.exists(css):
            shutil.copy(css, os.path.join(output_path, css))
        else:
            print("Warning: %s not found, HTML pages will be unstyled." % css)

if __name__ == "__main__":
    # Run the command-line driver only when executed as a script, not on import.
    main()

