#include <math.h>
#include <stdio.h>
/*
* this square root algorithm is based on "bit guessing, squaring and checking".
*
* the basic idea is starting with a 0 result and then trying to set bits from high to low
* then square it and check if it stays under the square term (input).
* a very simple and logical approach.
*
* below is a good example of a slow (but clear) implementation:
*
* int mborg_sqrt(int val) {
* int guess=0;
* int bit = 1 << 15;
* do {
* guess ^= bit;
* // check to see if we can set this bit without going over sqrt(val)...
* if (guess * guess > val )
* guess ^= bit; // it was too much, unset the bit...
* } while ((bit >>= 1) != 0);
*
* return guess;
* }
*
* we don't need to multiply (square) at all. it is certainly not an
* expensive operation on a CPU, nowadays. however, in hardware implementations
* or on old processors, multiplies are not very cheap.
*
* this is why we have to eliminate the multiply in the following way by keeping in
* mind the loop invariant. for this purpose, we consider the mathematical
* property of "distribution":
*
* (a+b)^2 = a^2 + b^2 + 2ab (1)
*
* consider 'a' to be our last guess (and we already got its square) and note
* that 'b' is a mask with just one bit. this has a very simple square.
*
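* b^2[n] = (1<<n)^2 = 1<<(2n) (2)
*
* the 2ab term is just as cheap, because multiplying by a one-bit mask is a shift:
*
* 2ab[n] = 2*a*(1<<n) = a<<(n+1) (3)
*
* moving the mask one bit down (with the same 'a') gives:
*
* 2ab[n-1] = 2*a*(1<<(n-1)) = a<<n (4)
*
* 2ab[n-1] = (2ab[n])>>1 (5)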
*
* which means we can keep it stored as variable that is shifted each iteration!
*
* the invariant for b^2 term may be derived simply from the fact that:
*
* b^2[n-1] = (1<<(n-1))^2 = 1<<(2n-2) = [(1<<n)^2]>>2 = (b^2[n])>>2 (6)
*
* so we can use a mask for this that is shifted right by 2 every iteration.
*
* this final algorithmic optimisation arises from the fact that we wish to
* compare the square with the current square approximation every iteration.
*
* if our input square is 'X' then the comparison will look as follows at
* iteration n:
*
* cmp[n] = X - a^2[n] (7)
*
* and:
*
* cmp[n-1] = X - a^2[n-1] = X - ( a^2[n] + b^2[n-1] + 2ab[n-1] ) (8)
*
* this means we subtract loads of the same stuff every iteration. we don't
* want this. if we realise the following it becomes very clear:
*
* cmp[n-i] = X - a^2[n-i]
* = X - ( sum(j=n-1; j>=n-i; b^2[j]+2ab[j]) + a^2[n] ) (9)
*
* this is why we shouldn't use X but cmp, and we should only subtract
* the b^2 and 2ab terms from it every iteration.
*
* all the optimisations listed above make the iteration very dense and limit
* it to 7 ALU operations worst-case and 5 ALU operations best-case.
*
* there is one way to gain a lot of speed in most cases. this is to quickly
* move through the leading zeroes at the start. this is a speed-up in almost
* any case and in the worst case it stays just as fast.
*
* the implementations are listed below and are called earx_sqrt 1, 2, 3, 4 and 5
* in increasing order of optimisation (and obfuscation ;)).
*
*
* Pieter van der Meer, 2006
*
*
*/
/*
 * First step away from the multiply-based version: the square of every trial
 * guess is derived from the last accepted square via
 * (a+b)^2 = a^2 + b^2 + 2ab, still using explicit shift counts.
 *
 * Runs on a hard-coded demo input, prints two trace lines per trial bit and
 * finally the resulting integer square root in hex.
 */
void earx_sqrt1(void)
{
    const int N = 15;   /* index of the highest result bit that is tried */
    long x = 25;        /* demo input; the head search assumes x >= 1 */
    long res = 0;
    long sqr = 0;
    long sqr_old;
    int n = 0;

    /* walk through leading zeroes, after this we have a minimal square root approximation */
    /* n is capped at N so the shift count below can never become negative */
    while (n < N && (1L << (2 * (N - n))) > x) {
        n++;
    }
    sqr_old = 1L << (2 * (N - n));
    res |= 1L << (N - n);

    for (n++; n <= N; n++) {
        long bit_sqr = 1L << (2 * (N - n));   /* b^2 for the bit being tried */
        long cross = res << (N + 1 - n);      /* 2ab for the bit being tried */

        /* sqr[n+1]= sqr[n] + .. */
        sqr = sqr_old + bit_sqr + cross;
        printf("n=%2d res=0x%08lX sqr=0x%08lX sqr_old=0x%08lX\n", n,
               (unsigned long)res, (unsigned long)sqr, (unsigned long)sqr_old);
        printf("sqr (%ld) = sqr_old (%ld) + 1<<2(N-n) (%ld) + (res<<(N+1-n)) (%ld) \n",
               sqr, sqr_old, bit_sqr, cross);
        if (sqr <= x) {
            /* the trial square still fits under x: keep the bit */
            res |= 1L << (N - n);
            sqr_old = sqr;
        }
    }
    printf("0x%08lX\n", (unsigned long)res);
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/*
 * Second step: the b^2 term is kept in a mask (smask) that is shifted right by
 * two every iteration, and the bit index counts down instead of up.
 *
 * Runs on a hard-coded demo input, prints two trace lines per trial bit and
 * finally the resulting integer square root in hex.
 */
void earx_sqrt2(void)
{
    long x = 0x500;     /* demo input; the head search assumes x >= 1 */
    long res = 0;
    int ni = 15;
    long sqr, sqr_old;
    long smask = 1L << (2 * ni);

    /* walk through leading zeroes, after this we have a minimal square root approximation */
    for (; smask > x; smask >>= 2, ni--) {
    }
    res = 1L << ni;
    sqr_old = smask;
    for (--ni; ni >= 0; --ni) {
        smask >>= 2;
        /* sqr[n+1]= sqr[n] + .. */
        /* todo: +1 can be dropped if place this at the loop end and write a head case */
        sqr = sqr_old + smask + (res << (ni + 1));
        printf("ni=%2d res=0x%08lX sqr=0x%08lX sqr_old=0x%08lX\n", ni,
               (unsigned long)res, (unsigned long)sqr, (unsigned long)sqr_old);
        printf("sqr (%ld) = sqr_old (%ld) + 1<<2(N-n) (%ld) + (res<<(N+1-n)) (%ld) \n",
               sqr, sqr_old, smask, res << (ni + 1));
        if (sqr <= x) {
            res |= 1L << ni;
            sqr_old = sqr;
        }
    }
    printf("0x%08lX\n", (unsigned long)res);
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/*
 * Third step: the trial square for the next bit is computed at the end of the
 * loop body, so the 2ab term becomes res<<ni (no +1) and the leading bit is
 * handled by the same loop as all the others.
 *
 * Runs on a hard-coded demo input and prints the integer square root in hex.
 */
void earx_sqrt3(void)
{
    long x = 0x500;     /* demo input */
    long res = 0;
    int ni = 15;
    long sqr;
    long sqr_old = 0;
    long smask = 1L << (2 * ni);

    /* walk through leading zeroes, after this we have a minimal square root approximation */
    for (; smask > x; smask >>= 2, --ni) {
    }
    sqr = smask;
    for (; ni >= 0; --ni) {
        if (sqr <= x) {
            /* we only use res in one place and hence we can preshift it! */
            /* this will also kill the counter operation */
            res |= 1L << ni;
            sqr_old = sqr;
        }
        smask >>= 2;
        /* sqr[n+1]= sqr[n] + .. */
        sqr = sqr_old + smask + (res << ni);
    }
    printf("0x%08lX\n", (unsigned long)res);
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/*
 * Fourth step: the result is kept pre-shifted (res_shift), so it serves as the
 * 2ab term directly and the bit counter disappears. One right shift per
 * iteration moves it down until, after the last bit, it equals the plain
 * square root.
 *
 * Runs on a hard-coded demo input and prints the integer square root in hex.
 */
void earx_sqrt4(void)
{
    const int N = 15;
    long x = 0x500;         /* demo input */
    long sqr;
    long sqr_old = 0;       /* a^2 of the bits accepted so far (none yet) */
    long smask = 1L << (2 * N);
    long res_shift = 0;     /* 2ab term for the bit being tried (a is still 0) */

    /* walk through leading zeroes, after this we have a minimal square root approximation */
    for (; smask > x; smask >>= 2) {
    }
    while (smask) {
        sqr = sqr_old + smask + res_shift;
        res_shift >>= 1;
        if (sqr <= x) {
            res_shift |= smask;
            sqr_old = sqr;
        }
        smask >>= 2;
    }
    printf("0x%08lX\n", (unsigned long)res_shift /*res*/);
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/*
 * Integer square root without multiplication.
 *
 * cmp is the input X; during the loop it is reused as the running remainder
 * X - a^2, where a is the root built so far.
 *
 * Returns floor(sqrt(X)). The search starts at result bit 15, so the result is
 * only exact for inputs below 2^32. Zero and negative inputs return 0.
 *
 * Declared static: under C99/C11 rules a plain inline definition emits no
 * external symbol, so a non-inlined call from this file would not link.
 */
static inline long earx_sqrt5(long cmp /* X */)
{
    const int N = 15;
    long smask = 1L << (2 * N);     /* b^2 of the bit being tried */
    long res_shift = 0;             /* 2ab of the bit being tried; ends up as the root */

    /* without this guard a negative input never leaves the leading-zero walk */
    if (cmp <= 0) {
        return 0;
    }

    /* walk through leading zeroes, after this we have a lower bound */
    while (smask > cmp) {
        smask >>= 2;
    }

    while (smask) {
        /* calculate square difference caused by guessed bit */
        /* sqr[n] = b^2[n]+2ab[n] */
        long sqr = smask + res_shift;

        /* shift temporary squareroot result down by one */
        /* 2ab[n-1] = a<<n = (2ab[n])>>1 */
        res_shift >>= 1;

        /* if the guessed bit doesn't exceed the square, add it to the result */
        /* cmp[n-i] = X - ( sum(j=n-1; j>=n-i; b^2[j]+2ab[j]) + a^2[n] ) */
        if (cmp >= sqr) {
            res_shift |= smask;
            cmp -= sqr;
        }

        /* shift square bit down by two */
        /* b^2[n-1] = (1<<(n-1))^2 = 1<<(2n-2) = [(1<<n)^2]>>2 = (b^2[n])>>2 */
        smask >>= 2;
    }
    return res_shift;
}
/* TEST_OUTPUT: compare every result against the C library sqrt() and report mismatches */
#define TEST_OUTPUT 0
/* DEBUG_OUTPUT: print one sample result every 0x40000 inputs */
#define DEBUG_OUTPUT 1
/* TEST_RANGE: highest input (inclusive) fed to earx_sqrt5() */
#define TEST_RANGE 0x07FFFFFF
/*
 * Test driver: runs earx_sqrt5() over 0..TEST_RANGE.
 *
 * Returns 0 when no mismatch was detected and 1 otherwise. Mismatches are only
 * looked for when TEST_OUTPUT is enabled.
 *
 * earx_sqrt1() .. earx_sqrt4() are self-contained demos; call them from here
 * to see their output.
 */
int main(void)
{
    long l;
    int err = 0;

    for (l = 0; l <= TEST_RANGE; l++) {
        long earx_out = earx_sqrt5(l);
#if DEBUG_OUTPUT
        if (!(l & 0x3FFFF)) {
            printf("0x%08lX: %9ld\n", (unsigned long)l, earx_out);
        }
#endif
#if TEST_OUTPUT
        {
            long math_out = (long)sqrt((double)l);

            if (earx_out != math_out) {
                err = 1;
                printf("0x%08lX: %9ld %9ld\n", (unsigned long)l, earx_out, math_out);
            }
        }
#endif
    }
    /* propagate a failed comparison to the exit status */
    return err;
}