from ghmm import *
import re

################### Input and output files ######################################
# Input FASTA files. The first entry is the test sequence; the remaining
# entries are the training sequences (one experiment per training file).
arquivos = [
    'contig3.fasta',
    'contig1.fasta',
    'contig2.fasta',
    'contig12.fasta',
]

# Islands predicted with at most this many consecutive positions are removed
# by the post-processing step.
threshold = 3000

# Outputs of the experiment trained on contig1.
saidas_1 = [
    'viterbi_1.txt',
    'viterbi_bw_1.txt',
    'viterbi_pp_1.txt',
]
saida_posterior_1 = [
    'posterior_viterbi_1.txt',
    'posterior_viterbi_bw_1.txt',
]
saida_erro_1 = [
    'viterbi_erro_1.txt',
    'viterbi_bw_erro_1.txt',
    'viterbi_pp_erro_1.txt',
]

# Outputs of the experiment trained on contig2.
saidas_2 = [
    'viterbi_2.txt',
    'viterbi_bw_2.txt',
    'viterbi_pp_2.txt',
]
saida_posterior_2 = [
    'posterior_viterbi_2.txt',
    'posterior_viterbi_bw_2.txt',
]
saida_erro_2 = [
    'viterbi_erro_2.txt',
    'viterbi_bw_erro_2.txt',
    'viterbi_pp_erro_2.txt',
]

# Outputs of the experiment trained on contig12 (contig1 + contig2).
saidas_3 = [
    'viterbi_3.txt',
    'viterbi_bw_3.txt',
    'viterbi_pp_3.txt',
]
saida_posterior_3 = [
    'posterior_viterbi_3.txt',
    'posterior_viterbi_bw_3.txt',
]
saida_erro_3 = [
    'viterbi_erro_3.txt',
    'viterbi_bw_erro_3.txt',
    'viterbi_pp_erro_3.txt',
]

# Outputs of the experiment that starts from a hand-made ("random") HMM.
saidas_4 = [
    'viterbi_bw_4.txt',
    'viterbi_pp_4.txt',
]
saida_posterior_4 = [
    'posterior_viterbi_bw_4.txt',
]
saida_erro_4 = [
    'viterbi_bw_erro_4.txt',
    'viterbi_pp_erro_4.txt',
]

# All outputs gathered per experiment so they can be indexed by experiment.
saidas = [saidas_1, saidas_2, saidas_3, saidas_4]
saidas_posterior = [
    saida_posterior_1,
    saida_posterior_2,
    saida_posterior_3,
    saida_posterior_4,
]
saidas_erro = [saida_erro_1, saida_erro_2, saida_erro_3, saida_erro_4]
#################################################################################################




################### Reading the input files ####################################
# Every file listed in `arquivos` is read as a single-record FASTA file.
# The first file is the test sequence and is stored in `teste`; the sequences
# of the remaining files are stored, in order, in `treinamento`.
# `teste_lower` and `treinamento_lower` hold upper-cased copies of those
# sequences (see the note next to their definition).


def _ler_fasta(caminho):
    """Return the sequence stored in the FASTA file at *caminho*.

    The first line must be a ">" header. The following lines are stripped of
    trailing whitespace and concatenated; reading stops at the first blank
    line or at end of file, whichever comes first.

    Raises:
        TypeError: if the first line does not start with ">".
    """
    # The context manager closes the file even when the header check fails.
    with open(caminho) as infile:
        cabecalho = infile.readline()
        if not cabecalho.startswith(">"):
            raise TypeError("Not a FASTA file: %r" % cabecalho)

        linhas = []
        for linha in infile:
            linha = linha.rstrip()
            if linha == "":
                break
            linhas.append(linha)

    return "".join(linhas)


_sequencias = [_ler_fasta(nome) for nome in arquivos]

teste = _sequencias[0] if _sequencias else ""
treinamento = _sequencias[1:]

# NOTE: despite the "_lower" suffix these lists hold UPPER-case copies. The
# names are kept because the rest of the script refers to them.
teste_lower = [teste.upper()]
treinamento_lower = [sequencia.upper() for sequencia in treinamento]

#################################################################################################

#print 'teste = ', teste, '\n'
#print 'teste_lower = ', teste_lower, '\n'
#print 'treino1 = ', treinamento[0], '\n'
#print 'treino1_lower = ', treinamento_lower[0], '\n'
#print 'treino2_lower = ', treinamento_lower[1], '\n'
#print 'treino12_lower = ', treinamento_lower[2], '\n'
#print '\n'
#print '\n'


################### Computing the frequencies ##################################
# The frequencies are estimated from the letter case of each training
# sequence. `matriz_frequencias` has one row per training sequence (same order
# as `treinamento`) with the following columns:
#   0-3   emission frequencies of a, c, g, t among the lower-case letters
#   4-7   emission frequencies of A, C, G, T among the upper-case letters
#   8-11  transition frequencies lower->lower, lower->upper,
#         upper->lower, upper->upper between adjacent letters

_LETRAS_MINUSCULAS = "abcdefghijklmnopqrstuvwxyz"
_LETRAS_MAIUSCULAS = _LETRAS_MINUSCULAS.upper()


def _razao(numerador, denominador, descricao):
    """Return numerador / denominador as a float.

    Raises:
        ValueError: if *denominador* is zero, i.e. the sequence contains no
            observation from which *descricao* could be estimated.
    """
    if denominador == 0:
        raise ValueError(
            "Cannot estimate %s: the training sequence has no matching "
            "symbols" % descricao
        )
    return float(numerador) / denominador


def _calcular_frequencias(sequence):
    """Return the 12 emission/transition frequencies of *sequence*.

    The result is ordered as
    [a, c, g, t, A, C, G, T, lower->lower, lower->upper, upper->lower,
    upper->upper].

    Raises:
        ValueError: if the sequence has no lower-case letters, no upper-case
            letters, or no transition leaving one of the two cases.
    """
    # Emissions: each nucleotide count is divided by the number of letters of
    # the same case (any letter a-z / A-Z counts towards the total).
    total_minusculas = sum(sequence.count(letra) for letra in _LETRAS_MINUSCULAS)
    total_maiusculas = sum(sequence.count(letra) for letra in _LETRAS_MAIUSCULAS)

    emissoes = [
        _razao(sequence.count(simbolo), total_minusculas,
               "the lower-case emission frequencies")
        for simbolo in "acgt"
    ]
    emissoes += [
        _razao(sequence.count(simbolo), total_maiusculas,
               "the upper-case emission frequencies")
        for simbolo in "ACGT"
    ]

    # Transitions: classify every symbol once ("l" lower, "u" upper, "" for
    # anything that is not a cased letter) and count adjacent pairs in which
    # both symbols are letters.
    contagem = {"ll": 0, "lu": 0, "ul": 0, "uu": 0}
    anterior = ""
    for simbolo in sequence:
        if simbolo.islower():
            atual = "l"
        elif simbolo.isupper():
            atual = "u"
        else:
            atual = ""
        if anterior and atual:
            contagem[anterior + atual] += 1
        anterior = atual

    saindo_de_minuscula = contagem["ll"] + contagem["lu"]
    saindo_de_maiuscula = contagem["uu"] + contagem["ul"]

    transicoes = [
        _razao(contagem["ll"], saindo_de_minuscula,
               "the transitions leaving lower case"),
        _razao(contagem["lu"], saindo_de_minuscula,
               "the transitions leaving lower case"),
        _razao(contagem["ul"], saindo_de_maiuscula,
               "the transitions leaving upper case"),
        _razao(contagem["uu"], saindo_de_maiuscula,
               "the transitions leaving upper case"),
    ]

    return emissoes + transicoes


matriz_frequencias = [_calcular_frequencias(sequence) for sequence in treinamento]

######################################################################################################

#print matriz_frequencias
#print '\n'
#print '\n'



################### Building the HMMs, running Viterbi and Baum-Welch ##########
# State 0 is the non-island state (its emission row comes from the lower-case
# frequencies) and state 1 is the island state (upper-case frequencies).
# A predicted path is scored against the letter case of `teste`: an
# upper-case symbol is a true island position, a lower-case one is not.

_ESTADOS_POR_LINHA = 65     # states written per line of a path output file
_MAX_POSTERIOR = 100000     # the (huge) posterior output is capped at this many positions


def _escrever_caminho(nome_arquivo, caminho):
    """Write the state path *caminho* to *nome_arquivo*, 65 states per line.

    Every state of the path is written (no trailing newline is added).
    """
    simbolos = [str(estado) for estado in caminho]
    linhas = [
        "".join(simbolos[k:k + _ESTADOS_POR_LINHA])
        for k in range(0, len(simbolos), _ESTADOS_POR_LINHA)
    ]
    with open(nome_arquivo, 'w') as arquivo:
        arquivo.write("\n".join(linhas))


def _matriz_confusao(caminho, referencia):
    """Return (tp, fp, tn, fn) of *caminho* against the case of *referencia*.

    Positions whose reference symbol is not a cased letter, or whose state is
    neither 0 nor 1, are not counted.
    """
    tp = fp = tn = fn = 0
    for estado, simbolo in zip(caminho, referencia):
        if estado == 1:
            if simbolo.islower():
                fp += 1
            elif simbolo.isupper():
                tp += 1
        elif estado == 0:
            if simbolo.islower():
                tn += 1
            elif simbolo.isupper():
                fn += 1
    return tp, fp, tn, fn


def _taxa(numerador, denominador):
    """Return numerador / denominador, or NaN when the rate is undefined."""
    if denominador == 0:
        return float('nan')
    return float(numerador) / float(denominador)


def _escrever_erro(nome_arquivo, contagens, final=""):
    """Write the error rate, island hit rate and confusion counts to a file.

    *contagens* is the (tp, fp, tn, fn) tuple; *final* is appended verbatim
    after the last line.
    """
    tp, fp, tn, fn = contagens
    with open(nome_arquivo, 'w') as arquivo:
        arquivo.write("Taxa de Erro: " + str(_taxa(fn + fp, fp + tp + fn + tn)) + "\n")
        arquivo.write("Taxa de Acerto de ilhas: " + str(_taxa(tp, tp + fn)) + "\n")
        arquivo.write("FP: " + str(fp) + " TP: " + str(tp) + " FN: " + str(fn)
                      + " TN: " + str(tn) + final)


def _registrar_resultado(caminho, arquivo_caminho, arquivo_erro, final=""):
    """Write a path and its error report (scored against `teste`)."""
    _escrever_caminho(arquivo_caminho, caminho)
    _escrever_erro(arquivo_erro, _matriz_confusao(caminho, teste), final)


def _escrever_posterior(nome_arquivo, posterior):
    """Write the state-0 posterior of the first positions, one per line.

    At most _MAX_POSTERIOR positions are written; shorter posteriors are
    written in full instead of failing.
    """
    limite = min(_MAX_POSTERIOR, len(posterior))
    with open(nome_arquivo, 'w') as arquivo:
        for j in range(limite):
            arquivo.write(str(posterior[j][0]) + '\n')


def _remover_ilhas_curtas(caminho, limite):
    """Set to 0, in place, every run of 1s whose length is <= *limite*.

    Runs that reach the end of the path are handled like any other run.
    """
    n = len(caminho)
    inicio = 0
    while inicio < n:
        if caminho[inicio] != 1:
            inicio += 1
            continue
        fim = inicio
        while fim < n and caminho[fim] == 1:
            fim += 1
        if fim - inicio <= limite:
            for k in range(inicio, fim):
                caminho[k] = 0
        inicio = fim


def _imprimir_modelo_treinado(modelo):
    """Print the transition matrix and the full model after Baum-Welch."""
    print("HMM apos o Baum-Welch: ")
    print("Transicoes: ")
    print(str(modelo.getTransition(0, 0)) + ' ' + str(modelo.getTransition(0, 1)))
    print(str(modelo.getTransition(1, 0)) + ' ' + str(modelo.getTransition(1, 1)) + '\n')
    print(modelo)


# Alphabet of the sequences and the sequence set holding the test sequence;
# both are the same for every experiment.
alfabeto = Alphabet(['A', 'C', 'G', 'T'])
sequence_set_teste = SequenceSet(alfabeto, teste_lower)

for i, sequencia_treino in enumerate(treinamento_lower):

    print("++++++++++++++++++++++++ Saida no." + str(i) + ' ++++++++++++++++++++++++' + '\n')

    ################## Building the HMM ######################################
    frequencias = matriz_frequencias[i]

    sigma = IntegerRange(1, 5)  # emission domain with four symbols

    # Transition matrix estimated from the training sequence.
    A = [[frequencias[8], frequencias[9]],
         [frequencias[10], frequencias[11]]]
    print("Transicoes: ")
    print(str(frequencias[8]) + ' ' + str(frequencias[9]))
    print(str(frequencias[10]) + ' ' + str(frequencias[11]) + '\n')

    naoilha = frequencias[0:4]  # emissions of the non-island state
    ilha = frequencias[4:8]     # emissions of the island state
    B = [naoilha, ilha]

    pi = [0.99, 0.01]  # initial probabilities of non-island and island

    m = HMMFromMatrices(sigma, DiscreteDistribution(sigma), A, B, pi)
    print("HMM inicial: ")
    print(m)

    # Train only on the sequence of this experiment (contig1, contig2 or
    # contig12), as described for each group of output files.
    sequence_set_treino = SequenceSet(alfabeto, [sequencia_treino])

    ################## Viterbi with the frequency-based HMM ##################
    vi = m.viterbi(sequence_set_teste)
    _registrar_resultado(vi[0], saidas[i][0], saidas_erro[i][0])
    _escrever_posterior(saidas_posterior[i][0], m.posterior(sequence_set_teste[0]))

    ################## Baum-Welch, then Viterbi again ########################
    m.baumWelch(sequence_set_treino)
    _imprimir_modelo_treinado(m)

    vi_bw = m.viterbi(sequence_set_teste)
    _registrar_resultado(vi_bw[0], saidas[i][1], saidas_erro[i][1], "\n\n")
    _escrever_posterior(saidas_posterior[i][1], m.posterior(sequence_set_teste[0]))

    ################## Post-processing: drop short islands ###################
    _remover_ilhas_curtas(vi_bw[0], threshold)
    _registrar_resultado(vi_bw[0], saidas[i][2], saidas_erro[i][2])

## END OF THE LOOP OVER THE EXPERIMENTS WITH FREQUENCY-BASED HMMS


################## Experiment with the hand-made ("random") HMM ################

print("++++++++++++++++++++ Experimento aleatorio: ++++++++++++++++++++++++\n")

# Initial model: fixed transitions and uniform emissions for both states.
sigma = IntegerRange(1, 5)
A = [[0.9, 0.1], [0.2, 0.8]]
um_quarto = [0.25, 0.25, 0.25, 0.25]
B = [um_quarto, um_quarto]
pi = [0.9, 0.1]
m = HMMFromMatrices(sigma, DiscreteDistribution(sigma), A, B, pi)

print("HMM aleatorio: ")
print(m)

# This experiment is trained on all training sequences together.
sequence_set_treino = SequenceSet(alfabeto, treinamento_lower)

m.baumWelch(sequence_set_treino)
_imprimir_modelo_treinado(m)

vi_bw = m.viterbi(sequence_set_teste)
_registrar_resultado(vi_bw[0], saidas[3][0], saidas_erro[3][0])

# The posterior output of this experiment (saidas_posterior[3]) is not
# produced; it was already disabled in this script.

# Post-processing: drop short islands and report again.
_remover_ilhas_curtas(vi_bw[0], threshold)
_registrar_resultado(vi_bw[0], saidas[3][1], saidas_erro[3][1])

#####################################################################################################
#####################################################################################################

