#include <stdio.h>
#include <math.h>

/*
 * this square root algorithm is based on "bit guessing, squaring and checking".
 *
 * the basic idea is starting with a 0 result and then trying to set bits from high to low 
 * then square it and check if it stays under the square term (input).
 * a very simple and logical approach. 
 *
 * below is a good example of a slow (but clear) implementation:
 *
 * int mborg_sqrt(int val) {
 *   int guess=0;
 *   int bit = 1 << 15;
 *   do {
 *     guess ^= bit;  
 *     // check to see if we can set this bit without going over sqrt(val)...
 *     if (guess * guess > val )
 *       guess ^= bit;  // it was too much, unset the bit...
 *   } while ((bit >>= 1) != 0);
 *   
 *   return guess;
 * }
 *
 * we don't need to multiply (square) at all. it is certainly not an
 * expensive operation on a CPU, nowadays. however, in hardware implementations
 * or on old processors, multiplies are not very cheap.
 *
 * this is why we have to eliminate the multiply in the following way by keeping in
 * mind the loop invariant. for this purpose, we consider the mathematical
 * property of "distribution":
 *
 * (a+b)^2 = a^2 + b^2 + 2ab (1)
 *
 * consider 'a' to be our last guess (and we already got its square) and note
 * that 'b' is a mask with just one bit. this has a very simple square.
 * 
 * b^2[n] = (1<<n)^2 = 1<<2n (2)
 *
 * in other words, we just need a shift, that's all
 *
 * the '2ab' term can be written as:
 *
 * 2ab[n] = 2a(1<<n) = a(1<<(n+1)) = a<<(n+1) (3)
 *
 * again, quite simple.
 * 
 * this is what we use for our implementation. the rest is the same algorithm as
 * listed above.
 * 
 * if we check equation 3 again we notice that if we evaluate the final
 * iteration (n goes down from 15 through 0), we end up with the following:
 *
 * 2ab[0] = a<<(0+1) = 2a (4)
 *
 * this means we got our square root term 'a'. we just need to shift it right by
 * one bit. why is this important? this way we can keep temporary results in the
 * same variable instead of using separate a and 2ab terms.
 *
 * in any case, the invariant for the 2ab part is deduced as:
 *
 * 2ab[n-1] = a<<n = (2ab[n])>>1 (5)
 *
 * which means we can keep it stored as a variable that is shifted each iteration!
 *
 * the invariant for b^2 term may be derived simply from the fact that:
 * 
 * b^2[n-1] = (1<<(n-1))^2 = 1<<(2n-2) = [(1<<n)^2]>>2 = (b^2[n])>>2 (6)
 * 
 * so we can use a mask for this that is shifted right by 2 every iteration.
 *
 * this final algorithmic optimisation arises from the fact that we wish to 
 * compare the square with the current square approximation every iteration.
 *
 * if our input square is 'X' then the comparison will look as follows at
 * iteration n:
 *
 * cmp[n] = X - a^2[n] (7)
 *
 * and:
 *
 * cmp[n-1] = X - a^2[n-1] = X - ( a^2[n] + b^2[n-1] + 2ab[n-1] ) (8)
 *
 * this means we subtract loads of the same stuff every iteration. we don't
 * want this. if we realise the following it becomes very clear:
 *
 * cmp[n-i] = X - a^2[n-i]
 *          = X - ( sum(j=n-1; j>=n-i; b^2[j]+2ab[j]) + a^2[n] ) (9)
 *
 * this is why we shouldn't use X but cmp, and why we should only subtract
 * the b^2 and 2ab terms from it every iteration.
 *
 * all the optimisations listed above make the iteration very dense and limit
 * it to 7 ALU operations worst-case and 5 ALU operations best-case.
 *
 * there is one way to gain a lot of speed in most cases. this is to quickly
 * move through the leading zeroes at the start. this is a speed-up in almost
 * any case and in the worst case it stays just as fast.
 * 
 * the implementations are listed below and are called earx_sqrt 1, 2, 3, 4 and 5
 * in increasing order of optimisation (and obfuscation ;)). 
 *
 *
 * Pieter van der Meer, 2006
 *
 *
 */


/*
 * earx_sqrt1: first, least optimised multiply-free variant.
 *
 * computes the integer square root of a hard-coded test value (x=25), prints
 * a two-line trace for every refinement step and finally the root in hex.
 *
 * n counts upwards, so the root bit tried in iteration n is bit (N-n).
 * sqr_old always holds the square of the root accepted so far (res).
 *
 * note: the leading-zero walk assumes x >= 1. for a smaller x it would step
 * past bit 0 and shift by a negative amount.
 */
void earx_sqrt1(void)
{
  long x=25;
  long res=0;
  int n=-1;
  int N=15;
  long sqr=0, sqr_old=0;

  /* walk through leading zeroes, after this we have a minimal square root approximation */
  do {
    n++;
    sqr_old=1<<(2*(N-n));
  } while ( (n<=N) && (sqr_old>x) );

  /* the first bit whose square fits is part of the root for sure */
  res|=1<<(N-n);

  for (n++; n<=N; n++) {

    /* (a+b)^2 = a^2 + b^2 + 2ab with a=res and b=1<<(N-n) */
    sqr = sqr_old + (1<<(2*(N-n))) + (res<<(N+1-n));

    /* sqr, sqr_old and the shifted res are long: print them with %ld / %lX */
    printf("n=%2d res=0x%08lX sqr=0x%08lX sqr_old=0x%08lX\n", n,
           (unsigned long)res, (unsigned long)sqr, (unsigned long)sqr_old);
    printf("sqr (%ld) = sqr_old (%ld) + 1<<2(N-n) (%d) + (res<<(N+1-n)) (%ld) \n",
           sqr, sqr_old, 1<<(2*(N-n)), res<<(N+1-n));

    /* keep the bit only if the candidate square stays at or below x */
    if (sqr<=x) {
      res|=1<<(N-n);
      sqr_old=sqr;
    }

  }

  printf("0x%08lX\n", (unsigned long)res);
}



/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/

/*
 * earx_sqrt2: second variant, the b^2 term is kept in a mask (smask) that is
 * shifted right by two every iteration instead of being rebuilt from the
 * bit index.
 *
 * computes the integer square root of a hard-coded test value (x=0x500),
 * prints a two-line trace for every refinement step and finally the root in
 * hex.
 */
void earx_sqrt2(void)
{
  long x=0x500;
  long res=0;
  int ni=15;
  long sqr, sqr_old;
  long smask=1<<(2*ni);

  /* walk through leading zeroes, after this we have a minimal square root approximation */
  while (smask>x) {
    smask>>=2;
    ni--;
  }
  res=1<<ni;
  sqr_old=smask;

  for (--ni; ni>=0; --ni) {
    smask>>=2;

    /* sqr[n+1]= sqr[n] + .. */
    /* todo: +1 can be dropped if place this at the loop end and write a head case */
    sqr = sqr_old + smask + (res<<(ni+1));

    /* every value printed below is a long: use %ld / %lX, not %d */
    printf("ni=%2d res=0x%08lX sqr=0x%08lX sqr_old=0x%08lX\n", ni,
           (unsigned long)res, (unsigned long)sqr, (unsigned long)sqr_old);
    printf("sqr (%ld) = sqr_old (%ld) + 1<<2(N-n) (%ld) + (res<<(N+1-n)) (%ld) \n",
           sqr, sqr_old, smask, res<<(ni+1));

    /* keep the bit only if the candidate square stays at or below x */
    if (sqr<=x) {
      res|=1<<ni;
      sqr_old=sqr;
    }

  }

  printf("0x%08lX\n", (unsigned long)res);
}

/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/

/*
 * earx_sqrt3: third variant. the accept/reject test sits at the top of the
 * loop body and the next candidate square is prepared at its end, so the
 * accepted root is shifted by the bit index itself rather than by index+1.
 *
 * computes the integer square root of a hard-coded test value (25) and
 * prints it in hex.
 */
void earx_sqrt3(void)
{
  const long target = 25;
  int pos = 15;                  /* index of the root bit under test */
  long bit_sq = 1 << (2 * pos);  /* square of that bit */
  long root = 0;
  long accepted_sq = 0;          /* square of the root accepted so far */
  long trial_sq;

  /* skip the leading zeroes: stop at the first bit whose square fits */
  while (bit_sq > target) {
    bit_sq >>= 2;
    --pos;
  }
  trial_sq = bit_sq;

  while (pos >= 0) {
    /* the candidate prepared in the previous round decides this bit */
    if (trial_sq <= target) {
      root |= 1 << pos;
      accepted_sq = trial_sq;
    }

    /* next candidate: a^2 + b^2 + 2ab for the next lower bit */
    bit_sq >>= 2;
    trial_sq = accepted_sq + bit_sq + (root << pos);

    --pos;
  }

  printf("0x%08lX\n", root);
}

/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/

/*
 * earx_sqrt4: fourth variant. the bit counter is gone; the root is kept
 * pre-shifted in res_shift so that it can be added to the candidate square
 * directly as the 2ab term. it is shifted right once per iteration and ends
 * up holding the plain root when the mask runs out.
 *
 * computes the integer square root of a hard-coded test value (x=121) and
 * prints it in hex.
 */
void earx_sqrt4(void)
{
  long x=121;
  long sqr, sqr_old=0;
  long smask=1L<<(2*15);
  /* no root bit has been accepted yet, so the 2ab accumulator starts empty */
  long res_shift=0;

  /* walk through leading zeroes, after this smask is the largest square bit not above x */
  for (; smask>x; smask>>=2);

  while (smask) {
    /* candidate square: a^2 + b^2 + 2ab */
    sqr = sqr_old + smask + res_shift;
    /* 2ab[n-1] = (2ab[n])>>1 */
    res_shift>>=1;
    if (sqr<=x) {
      /* accept the bit: after the shift above, smask lands on its pre-shifted position */
      res_shift|=smask;
      sqr_old=sqr;
    }
    smask>>=2;
  }

  printf("0x%08lX\n", (unsigned long)res_shift /*res*/);
}

/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/

/*
 * earx_sqrt5: final variant, integer square root without multiplies.
 *
 * cmp: the value X to take the root of. it is consumed as the running
 *      remainder X - a^2 while the root is built up.
 *
 * returns the largest r with r*r <= cmp for cmp >= 0, and 0 for a negative
 * cmp (which has no real root).
 *
 * this is a plain (non-inline) function on purpose: an inline definition
 * without 'static' or 'extern' provides no external symbol under C99 rules,
 * so any call the compiler does not inline would fail to link.
 */
long earx_sqrt5(long cmp /* X */)
{
  long smask=1;      /* b^2: square of the root bit under test */
  long res_shift=0;  /* pre-shifted root, doubles as the 2ab term */
  long sqr;

  if (cmp<=0)
    return 0;

  /* skip the leading zeroes: grow the mask to the largest power of four that */
  /* is not above cmp. comparing against cmp/4 keeps the shift from overflowing */
  /* and covers the whole positive range of long */
  while (smask<=cmp/4)
    smask<<=2;

  while (smask) {
    /* calculate square difference caused by guessed bit */
    /* sqr[n] = b^2[n]+2ab[n] */
    sqr=smask+res_shift;
    /* shift temporary squareroot result down by one */
    /* 2ab[n-1] = a<<n = (2ab[n])>>1 */
    res_shift>>=1;
    /* if the guessed bit doesn't exceed the square, add it to the result */
    /* cmp[n-i] = X - ( sum(j=n-1; j>=n-i; b^2[j]+2ab[j]) + a^2[n] ) */
    if (cmp>=sqr) {
      res_shift|=smask;
      cmp-=sqr;
    }
    /* shift square bit down by two */
    /* b^2[n-1] = (1<<(n-1))^2 = 1<<(2n-2) = [(1<<n)^2]>>2 = (b^2[n])>>2 */
    smask>>=2;
  }

  return res_shift;
}

/* TEST_OUTPUT:  compare every result against sqrt() from <math.h> and report mismatches */
/* DEBUG_OUTPUT: print one sample result every 0x40000 inputs */
/* TEST_RANGE:   highest input value that is fed to earx_sqrt5 */
#define TEST_OUTPUT 0
#define DEBUG_OUTPUT 1
#define TEST_RANGE 0x07FFFFFF

/*
 * test driver: runs earx_sqrt5 over every value in [0, TEST_RANGE].
 *
 * returns 0 when no mismatch was detected and 1 when TEST_OUTPUT is enabled
 * and at least one result differed from the <math.h> reference.
 *
 * the demo variants earx_sqrt1 .. earx_sqrt4 print their own hard-coded
 * example and can be called from here for experiments.
 */
int main (void)
{
  long l;
  int err=0;
  for (l=0; l<=TEST_RANGE; l++) {
    long earx_out = earx_sqrt5(l);
#if DEBUG_OUTPUT
    /* l and earx_out are long, so they need the 'l' length modifier */
    if (!(l&0x3FFFF))
      printf("0x%08lX: %9ld\n", (unsigned long)l, earx_out);
#endif
#if TEST_OUTPUT
    {
      long math_out = (long) sqrt((double)l);
      if (earx_out != math_out) {
        err=1;
        printf("0x%08lX: %9ld %9ld\n", (unsigned long)l, earx_out, math_out);
      }
    }
#endif
  }

  /* let the exit status reflect whether a mismatch was seen */
  return err;
}
