package edu.upenn.hms;

import Jama.Matrix;

/**
 * Small linear-algebra helpers built on top of {@link Matrix}.
 */
public class MatrixUtils {
	/**
	 * Solves {@code A x = b} with the conjugate gradient method, starting from {@code x = 0}.
	 *
	 * <p>Conjugate gradient is only guaranteed to converge when {@code A} is symmetric positive
	 * definite; neither property is checked here.
	 *
	 * @param A         square n-by-n coefficient matrix
	 * @param b         n-by-1 right-hand side
	 * @param max_iters maximum number of CG steps to take; a value {@code <= 0} returns the zero vector
	 * @param tol       absolute tolerance on the Frobenius (Euclidean) norm of the residual
	 *                  {@code b - A x}; iteration stops as soon as the residual norm is below it,
	 *                  or is exactly zero
	 * @return the n-by-1 approximate solution after convergence or {@code max_iters} steps
	 * @throws IllegalArgumentException if {@code A} is not square or {@code b} is not an
	 *                                  n-by-1 column vector
	 */
	public static Matrix conjGrad(Matrix A, Matrix b, int max_iters, double tol) throws IllegalArgumentException
	{
		int n = A.getRowDimension();
		if(n != A.getColumnDimension())
		{
			throw new IllegalArgumentException("A must be square");
		}
		if(b.getColumnDimension() != 1)
		{
			throw new IllegalArgumentException("b must be a column vector");
		}
		if(b.getRowDimension() != n)
		{
			throw new IllegalArgumentException("b must have same height as A");
		}
		
		Matrix x = new Matrix(n,1);
		
		// With x0 = 0 the initial residual r0 = b - A*x0 is simply b.
		Matrix r = b.copy();
		Matrix p = r.copy();
		// r·r is carried across iterations so it is computed only once per step.
		double rdotr = dot(r, r);
		
		for(int niter=0; niter < max_iters; niter++)
		{
			// Convergence is checked before every step (including the first). The
			// rdotr == 0 test catches an exact solution (e.g. b == 0, or tol <= 0 after
			// exact convergence); without it the step below would compute
			// alpha = 0/0 and overwrite x with NaN.
			if(rdotr == 0 || r.normF() < tol)
			{
				return x;
			}
			Matrix Ap = A.times(p);
			double alpha = rdotr / dot(p, Ap);
			x.plusEquals(p.times(alpha));
			r.minusEquals(Ap.times(alpha));
			double rdotrNew = dot(r, r);
			double beta = rdotrNew / rdotr;
			p = r.plus(p.times(beta));
			rdotr = rdotrNew;
		}
		
		return x;
	}
	
	/**
	 * Computes the dot product of two column vectors of equal length.
	 *
	 * @param a n-by-1 column vector
	 * @param b n-by-1 column vector
	 * @return {@code sum_i a[i] * b[i]}
	 * @throws IllegalArgumentException if either argument is not a column vector or their
	 *                                  heights differ
	 */
	public static double dot(Matrix a, Matrix b) throws IllegalArgumentException
	{
		int n = a.getRowDimension();
		if(a.getColumnDimension() != 1)
		{
			throw new IllegalArgumentException("a must be a column vector");
		}
		if(b.getColumnDimension() != 1)
		{
			throw new IllegalArgumentException("b must be a column vector");
		}
		if(b.getRowDimension() != n)
		{
			throw new IllegalArgumentException("b and a must be of the same size");
		}
		
		// Accumulate directly rather than materializing an element-wise product matrix.
		double sum = 0;
		for(int i=0; i<n; i++)
		{
			sum += a.get(i, 0) * b.get(i, 0);
		}
		return sum;
	}
}

