package dr.app.tools;

import java.io.FileReader;

import dr.evolution.alignment.Alignment;
import dr.evolution.io.NexusImporter;

/**
 * Command-line tool that reads a NEXUS alignment and prints, for every
 * unordered pair of sequences, the uncorrected p-distance together with the
 * Jukes-Cantor, Kimura 2-parameter, F84 and Tamura-Nei corrected distances as
 * one tab-separated line per pair.
 *
 * <p>Substitution proportions are computed over all alignment columns: a
 * column in which either sequence holds something other than an upper-case
 * {@code A}, {@code C}, {@code G} or {@code T} is not counted as a difference
 * but still contributes to the denominator.
 */
public class TestPairwiseDistance {

    /** Alignment that is read when no file name is supplied on the command line. */
    private static final String DEFAULT_ALIGNMENT_FILE = "C:\\Temp\\CRF02_AG.nex";

    /**
     * Reads the alignment and writes the pairwise distance table to standard output.
     *
     * @param args optional; {@code args[0]} is the path of the NEXUS file to read.
     *             When absent, {@link #DEFAULT_ALIGNMENT_FILE} is used.
     * @throws Exception if the file cannot be opened, read or parsed
     */
    public static void main(String[] args) throws Exception {
        String fileName = (args != null && args.length > 0) ? args[0] : DEFAULT_ALIGNMENT_FILE;

        // Close the reader as soon as the alignment has been imported; it is not used afterwards.
        Alignment alignment;
        FileReader reader = new FileReader(fileName);
        try {
            alignment = new NexusImporter(reader).importAlignment();
        } finally {
            reader.close();
        }

        // State frequencies over the whole alignment, shared by every pair.
        // NOTE(review): assumes the state order is A, C, G, T - confirm against Alignment.
        // (An alternative would be to use the frequencies of just the two sequences
        // being compared; that is deliberately not done here.)
        double[] alignmentFrequencies = alignment.getStateFrequencies();
        double aFrequency = alignmentFrequencies[0];
        double cFrequency = alignmentFrequencies[1];
        double gFrequency = alignmentFrequencies[2];
        double tFrequency = alignmentFrequencies[3];

        System.out.println();

        System.out.println("taxon1\ttaxon2\tpDistance\tjcDistance\tk2pDistance\tf84Distance\ttnDistance");

        int sequenceCount = alignment.getSequenceCount();
        int siteCount = alignment.getSiteCount();

        for (int a = 0; a < sequenceCount - 1; a++) {
            String first = alignment.getAlignedSequenceString(a);

            // Only pairs with b > a are visited, so a sequence is never compared with itself.
            for (int b = a + 1; b < sequenceCount; b++) {
                String second = alignment.getAlignedSequenceString(b);

                int[] counts = countSubstitutions(first, second, siteCount);

                // from absolute numbers to proportions
                double p1 = (double) counts[0] / siteCount;               // A-G transitions
                double p2 = (double) counts[1] / siteCount;               // C-T transitions
                double p = (double) (counts[0] + counts[1]) / siteCount;  // all transitions
                double q = (double) counts[2] / siteCount;                // transversions

                double jcDistance = getJCdistance((p+q));
                double k2pDistance = getK2Pdistance(p,q);
                double f84Distance = getF84distance((p),(q),aFrequency, cFrequency, gFrequency, tFrequency);
                double tnDistance = getTNdistance(p1,p2,q,aFrequency, cFrequency, gFrequency, tFrequency);
                System.out.println(alignment.getTaxonId(a)+"\t"+alignment.getTaxonId(b)+"\t"+(p1+p2+q)+"\t"+jcDistance+"\t"+k2pDistance+"\t"+f84Distance+"\t"+tnDistance);
            }
        }
    }

    /**
     * Counts the differences between two aligned sequences, split by substitution type.
     *
     * @param first     first aligned sequence
     * @param second    second aligned sequence
     * @param siteCount number of leading columns to compare; both strings must be at least this long
     * @return a three-element array: {@code [0]} A-G transitions, {@code [1]} C-T transitions,
     *         {@code [2]} transversions
     */
    private static int[] countSubstitutions(String first, String second, int siteCount) {
        int agTransitions = 0;
        int ctTransitions = 0;
        int transversions = 0;

        for (int site = 0; site < siteCount; site++) {
            char char1 = first.charAt(site);
            char char2 = second.charAt(site);
            if (char1 == char2) {
                continue;
            }

            boolean purine1 = isPurine(char1);
            boolean purine2 = isPurine(char2);
            boolean pyrimidine1 = isPyrimidine(char1);
            boolean pyrimidine2 = isPyrimidine(char2);

            if (purine1 && purine2) {
                // two different purines can only be A and G
                agTransitions++;
            } else if (pyrimidine1 && pyrimidine2) {
                // two different pyrimidines can only be C and T
                ctTransitions++;
            } else if ((purine1 || pyrimidine1) && (purine2 || pyrimidine2)) {
                // one purine and one pyrimidine
                transversions++;
            }
            // anything else (gap, ambiguity code, lower case, ...) is not counted
        }

        return new int[] {agTransitions, ctTransitions, transversions};
    }

    /** Returns true for an upper-case {@code A} or {@code G}. */
    private static boolean isPurine(char nucleotide) {
        return nucleotide == 'A' || nucleotide == 'G';
    }

    /** Returns true for an upper-case {@code C} or {@code T}. */
    private static boolean isPyrimidine(char nucleotide) {
        return nucleotide == 'C' || nucleotide == 'T';
    }

    /**
     * Jukes-Cantor corrected distance.
     *
     * @param p proportion of sites that differ
     * @return {@code -(3/4) * ln(1 - (4/3) * p)}; NaN or infinite when the
     *         argument of the logarithm is not positive
     */
    private static double getJCdistance(double p) {
        double distance = -(3.0/4.0)*Math.log(1 - (4.0/3.0)*p);
        return distance;
    }

    /**
     * Kimura 2-parameter corrected distance.
     *
     * @param p proportion of sites showing a transition
     * @param q proportion of sites showing a transversion
     * @return {@code -0.5 * ln(1 - 2p - q) - 0.25 * ln(1 - 2q)}; NaN or infinite
     *         when either logarithm argument is not positive
     */
    private static double getK2Pdistance(double p, double q) {
        double distance = -0.5*Math.log(1 - 2*p - q) - 0.25*Math.log(1 - 2*q);
        return distance;
    }

    /**
     * F84 corrected distance, in the form given by McGuire, Prentice and Wright.
     *
     * @param p   proportion of sites showing a transition
     * @param q   proportion of sites showing a transversion
     * @param piA frequency of A
     * @param piC frequency of C
     * @param piG frequency of G
     * @param piT frequency of T
     * @return the corrected distance; NaN or infinite when a logarithm argument
     *         is not positive or a frequency sum used as a divisor is zero
     */
    private static double getF84distance(double p, double q, double piA, double piC, double piG, double piT) {

        double a = (piC*piT/(piC + piT)) + piA*piG/(piA + piG);
        double b = piC*piT + piA*piG;
        double c = (piA + piG)*(piC +piT);

        return -2.0*a* Math.log(1.0-p/(2.0*a)-(a-b)*q/(2.0*a*c)) + 2.0*(a-b-c)*Math.log(1.0-q/(2.0*c));
    }

    /**
     * Tamura-Nei corrected distance, which treats A-G and C-T transitions separately.
     *
     * @param p1  proportion of sites showing an A-G transition
     * @param p2  proportion of sites showing a C-T transition
     * @param q   proportion of sites showing a transversion
     * @param piA frequency of A
     * @param piC frequency of C
     * @param piG frequency of G
     * @param piT frequency of T
     * @return the sum of the purine-transition, pyrimidine-transition and
     *         transversion terms; NaN or infinite when a logarithm argument is
     *         not positive or a divisor is zero
     */
    private static double getTNdistance(double p1, double p2, double q, double piA, double piC, double piG, double piT) {

        double piR = piA + piG; // purine frequency
        double piY = piC + piT; // pyrimidine frequency

        double firstTerm = -((2.0*piA*piG)/piR)*Math.log(1.0 - (piR*p1)/(2.0*piA*piG) - q/(2.0*piR));
        double secondTerm = -((2.0*piT*piC)/piY)*Math.log(1.0 - (piY*p2)/(2.0*piT*piC) - q/(2.0*piY));
        double thirdTerm = -2.0*(piR*piY - (piA*piG*piY)/piR - (piT*piC*piR)/piY)*Math.log(1.0 - q/(2*piR*piY));

        double distance = firstTerm + secondTerm + thirdTerm;

        return distance;
    }

/* Disabled draft, from Rzhetsky and Nei 1995, not sure if this implementation is correct
    private static double getRzhetskyHKYdistance(double p1, double p2, double q, double piA, double piC, double piG, double piT) {

        double distance = 0.0;

        double piR = piA + piG;
        double piY = piC + piT;

        double c = (1.0 - (q/(2.0*piR*piY)));
        double e = (1.0 -q/(2.0*piR)-(piR*p1)/(2.0*piA*piG));
        double f = (1.0 - (piY*p2)/(2.0*piC*piT) -q/(2.0*piY));

        double delta = Math.pow((2.0*Math.pow(piR, 2.0)*e), -1) - Math.pow((2.0*Math.pow(piR, 2)*c), -1);
        double epsilon = Math.pow(2.0*piA*piG*e, -1);
        double dzeta = Math.pow((2.0*Math.pow(piY, 2)*f), -1) - Math.pow((2.0*Math.pow(piY, 2)*c),-1);
        double eta = Math.pow((2*piC*piT*f), -1);
        double nu = Math.pow((2*piR*piY*c), -1);

        double va1 = ((Math.pow(delta, 2)*q + Math.pow(epsilon, 2)*p1) - Math.pow((delta*q+epsilon*p1), 2))/eta;
        double va2 = ((Math.pow(dzeta, 2)*q + Math.pow(eta, 2)*p2) - Math.pow((dzeta*q+eta*p2), 2))/eta;
        double cova1a2 = (delta*dzeta*q*(1.0-q) - delta*eta*q*p2 - epsilon*eta*p1*p2)/eta;
        double cova1a3 = nu*q*(delta*(1.0 - q) - epsilon*p1)/eta;
        double cova2a3 = nu*q*(dzeta*(1.0 - q) - eta*p2)/eta;

        double gamma = (va2 - cova1a2)/(va1+va2 - 2.0*cova1a2) + ((piR*piY)/(piA*piG+piC*piT))*((cova1a3 - cova2a3)/(va1 + va2 - 2.0*cova1a2));

        double a1 = (piY/piR)*Math.log(1.0 - q/(2.0*piR*piY)) - Math.log(1.0 - q/(2.0*piR) - (piR*p1)/(2.0*piA*piG))/piR;
        double a2 = (piR/piY)*Math.log(1.0 - q/(2.0*piR*piY)) - Math.log(1.0 - q/(2.0*piY) - (piY*p2)/(2.0*piC*piT))/piY;
        double a3 = -Math.log(1.0 - q/(2*piR*piY));

        distance = 2.0*(piA*piG + piC*piT)*(gamma*a1 + (1 - gamma)*a2) + 2*piR*piY*a3;

        return distance;
    }
*/

}
