A Program for Taking Screenshots of an Android Device from the Command Line

Android DDMS lets one take a screenshot easily by clicking a button. However, a developer may sometimes want to take screenshots from the command line in order to automate the process. This post discusses how to do this and provides a Java program at the end.

0. How DDMS Captures Screenshots

DDMS captures screenshots through the Android Debug Bridge (adb). It creates an adb client, which communicates with the adb server; both the adb client and the adb server run on the developer's machine. DDMS sends the screenshot command through the adb client to the adb server, and the adb server then communicates with the daemon process (adbd) running on the Android phone or emulator to get the screenshot.

Note that USB debugging must be enabled on the device for DDMS to capture a screenshot.

1. Capturing a Screenshot from the Command Line

The Android source code contains a program that captures screenshots, but one needs to compile it oneself.

The program captures the screenshot the same way DDMS does. First, it initializes an Android Debug Bridge instance by calling the code below:

AndroidDebugBridge.init(false); //initialize the library

AndroidDebugBridge bridge = AndroidDebugBridge.createBridge();

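The debug bridge retrieves the list of available devices on a background thread, so the list may still be empty right after createBridge() returns; a program should wait until bridge.hasInitialDeviceList() returns true. We can then query the list of devices by calling the code below: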

IDevice[] devices = bridge.getDevices();

If only one device is available, we can capture the screen by calling:

IDevice device = devices[0];

RawImage rawImage = device.getScreenshot();

If there are multiple devices, one can select a device by its serial number: device.getSerialNumber() returns a string containing the device's serial number.

2. Downloads

One can download the code from here or from my GitHub here. If you only want the executable, you can get it here. Note that you'll need to rename the file from screenshot.jar.zip to screenshot.jar.

Below are a screenshot of a sample execution and the captured phone screen.

 
Figure 1. Execution of screenshot.jar


Figure 2. Screenshot Captured

References:
1. Android debug bridge: http://developer.android.com/guide/developing/tools/adb.html

Big Integer Arithmetic

Most computer applications operate on small integers that fit in 32 bits. However, sometimes we need to deal with extremely large integers that can only be represented with hundreds of bits. One such application is cryptography.

This post gives two simple implementations of big integer arithmetic. Both implementations use a char array to store a big integer. The first stores each digit as a character ('0' to '9'), and the second stores each digit as its integer value (0 to 9).

String Representation
We define a big integer as the structure below,

typedef struct {

   int sign;

   int length;

   char dg[MAXDG];

} BN;

For example, given two input integers "111222333444555" and "12345", we can store them in two BNs a and b, with a.length = 15, a.dg[0] = '1', a.dg[1] = '1', …, a.dg[14] = '5', a.sign = 1; and b.length = 5, b.dg[0] = '1', b.dg[1] = '2', …, b.dg[4] = '5', b.sign = 1. The digits are stored most significant first, as a NUL-terminated string.

Addition: We add the two integers digit by digit from right to left. If the sum of the two digits (plus any incoming carry) is greater than 9, a carry of 1 is produced and added to the sum of the digits at the next position. We store the digits of the result from left to right (least significant first) and then reverse the array when done.
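A minimal sketch of this right-to-left addition on plain digit strings, assuming non-negative inputs without signs or leading zeros. The name add_str is ours, and the function is separate from the program below:

```c
#include <assert.h>
#include <string.h>

/* Adds two non-negative decimal strings (no sign, no leading zeros) and
   writes the sum to out, which needs max(strlen(a), strlen(b)) + 2 bytes. */
void add_str(const char *a, const char *b, char *out) {
    int i = (int)strlen(a) - 1;   /* rightmost digit of a */
    int j = (int)strlen(b) - 1;   /* rightmost digit of b */
    int k = 0, carry = 0;

    assert(i >= 0 && j >= 0);     /* both inputs must be non-empty */

    /* Work right to left; out receives the digits least significant first. */
    while (i >= 0 || j >= 0 || carry) {
        int sum = carry;
        if (i >= 0) sum += a[i--] - '0';
        if (j >= 0) sum += b[j--] - '0';
        out[k++] = (char)('0' + sum % 10);
        carry = sum / 10;         /* 1 if the digit sum exceeded 9 */
    }
    out[k] = '\0';

    /* Reverse so the most significant digit comes first. */
    for (i = 0, j = k - 1; i < j; ++i, --j) {
        char t = out[i];
        out[i] = out[j];
        out[j] = t;
    }
}
```

For the example above, add_str("111222333444555", "12345", r) stores "111222333456900" in r.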

Subtraction: Subtraction is similar to addition. We subtract the digits of the second integer from those of the first, from right to left. If a resulting digit is less than 0, we add 10 to it and borrow 1 from the next position.
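A matching sketch for subtraction, under the same assumptions plus a >= b. The full program below reduces the other sign cases to this one by swapping the operands or calling addition. The name sub_str is ours:

```c
#include <assert.h>
#include <string.h>

/* Computes a - b for non-negative decimal strings with a >= b (no sign,
   no leading zeros). out needs strlen(a) + 1 bytes. */
void sub_str(const char *a, const char *b, char *out) {
    int i = (int)strlen(a) - 1;
    int j = (int)strlen(b) - 1;
    int k = 0, borrow = 0;

    /* Right to left; out receives the digits least significant first. */
    while (i >= 0) {
        int d = (a[i--] - '0') - borrow - (j >= 0 ? b[j--] - '0' : 0);
        if (d < 0) {              /* borrow 1 from the next position */
            d += 10;
            borrow = 1;
        } else {
            borrow = 0;
        }
        out[k++] = (char)('0' + d);
    }
    assert(borrow == 0 && j < 0); /* otherwise a < b was passed in */

    /* Drop leading zeros (trailing in this reversed order), keep one digit. */
    while (k > 1 && out[k - 1] == '0')
        --k;
    out[k] = '\0';

    /* Reverse so the most significant digit comes first. */
    for (i = 0, j = k - 1; i < j; ++i, --j) {
        char t = out[i];
        out[i] = out[j];
        out[j] = t;
    }
}
```

For instance, sub_str("1000", "1", r) borrows through three positions and stores "999" in r.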

Multiplication: A naive approach is to multiply by repeated addition, adding the multiplicand as many times as the value of the multiplier. But if the multiplier is big, this algorithm is very slow.

A better approach is to handle the multiplication by a single digit of the multiplier using addition, then shift the multiplicand left by one position (which is equivalent to multiplying it by 10) and handle the next digit.

For example, if the two inputs are 12345 and 57, we treat the multiplication as 12345 x 7 + 123450 x 5, and each of the two multiplications is handled using additions. In this way, the number of additions needed is reduced to the sum of the digits of the multiplier (7 + 5 = 12 here, instead of 57).
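A minimal sketch of this shift-and-add multiplication, written independently of the program below (the names mul_str and add_shifted are ours). To keep the shifting simple, it works on int arrays holding digits least significant first, and it assumes non-negative inputs whose combined length is below MAXD:

```c
#include <assert.h>
#include <string.h>

#define MAXD 1000   /* assumed capacity: strlen(x) + strlen(y) < MAXD */

/* acc += a * 10^shift, where acc and a hold digits least significant first.
   Returns the new length of acc. */
static int add_shifted(int *acc, int acc_len, const int *a, int a_len, int shift) {
    int len = acc_len > a_len + shift ? acc_len : a_len + shift;
    int i, carry = 0;
    for (i = 0; i < len || carry; ++i) {
        int ad = (i >= shift && i - shift < a_len) ? a[i - shift] : 0;
        int cur = (i < acc_len ? acc[i] : 0) + ad + carry;
        acc[i] = cur % 10;
        carry = cur / 10;
    }
    return i;
}

/* Multiplies two non-negative decimal strings using only additions:
   for the multiplier digit d at position p, the multiplicand shifted
   left by p places is added d times. */
void mul_str(const char *x, const char *y, char *out) {
    int a[MAXD], b[MAXD], acc[MAXD];
    int la = (int)strlen(x), lb = (int)strlen(y);
    int i, j, len = 1;

    assert(la > 0 && lb > 0 && la + lb < MAXD);
    for (i = 0; i < la; ++i) a[i] = x[la - 1 - i] - '0';   /* multiplicand */
    for (i = 0; i < lb; ++i) b[i] = y[lb - 1 - i] - '0';   /* multiplier */
    acc[0] = 0;

    for (i = 0; i < lb; ++i)            /* one pass per multiplier digit */
        for (j = 0; j < b[i]; ++j)      /* b[i] additions of x * 10^i */
            len = add_shifted(acc, len, a, la, i);

    while (len > 1 && acc[len - 1] == 0)  /* strip leading zeros */
        --len;
    for (i = 0; i < len; ++i)
        out[i] = (char)('0' + acc[len - 1 - i]);
    out[len] = '\0';
}
```

For the example above, mul_str("12345", "57", r) performs 7 + 5 = 12 additions and stores "703665" in r.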

Division: Again, the naive approach is repeated subtraction. A better alternative is long division: we repeatedly subtract the divisor from a growing prefix of the dividend, taking the dividend's digits from left to right. For example, to compute 234 / 11, we start by computing 23 / 11 using subtraction. The quotient is 2 and the remainder is 1. We then compute (1*10 + 4) / 11 using subtraction; the quotient is 1 (with remainder 3). So the result is 21.
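A minimal sketch of this long division, going left to right and finding each quotient digit by repeated subtraction. To keep it short, the divisor is assumed to fit in an unsigned long, unlike the program below, which divides by a full BN; the name div_small is ours:

```c
#include <assert.h>
#include <limits.h>
#include <string.h>

/* Long division of a non-negative decimal string by d, assuming
   0 < d <= ULONG_MAX / 10 so that rem * 10 + digit cannot overflow.
   The quotient goes to q (strlen(a) + 1 bytes); returns the remainder. */
unsigned long div_small(const char *a, unsigned long d, char *q) {
    unsigned long rem = 0;
    size_t i, k = 0, n = strlen(a);

    assert(d > 0 && d <= ULONG_MAX / 10);
    for (i = 0; i < n; ++i) {
        int digit = 0;
        rem = rem * 10 + (unsigned long)(a[i] - '0');  /* bring down one digit */
        while (rem >= d) {            /* runs at most 9 times, since rem < 10 * d */
            rem -= d;
            ++digit;
        }
        if (k > 0 || digit > 0)       /* skip leading zeros of the quotient */
            q[k++] = (char)('0' + digit);
    }
    if (k == 0)
        q[k++] = '0';
    q[k] = '\0';
    return rem;
}
```

For example, div_small("234", 11, q) stores "21" in q and returns the remainder 3.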

Below is a C implementation of the algorithms described above,

/**

array implementation of bignum

**/

#include <stdio.h>

#include <string.h>

 

#define MAXDG 500

#define PLUS 1

#define MINUS -1

 

typedef struct {

    int sign;

    int length;

    char dg[MAXDG];

} BN;

 

int cmp(BN a, BN b);

void sub(BN a, BN b, BN *c);

void add(BN a, BN b, BN *c);

void reverse(BN *c);

void rm_tzeros(BN *c);

void rm_lzeros(BN *c);

void init(BN *a);

 

int cmp(BN a, BN b) {

    if (a.sign == PLUS && b.sign == MINUS) {

        return 1;

    } else if (a.sign == MINUS && b.sign == PLUS) {

        return -1;

    } else {

        //a and b have same sign

        if (a.length > b.length) {

            return 1*a.sign;

        } else if (a.length < b.length) {

            return -1*a.sign;

        } else {

            //a and b have same sign, same length

            int i;

            for (i = 0; i < a.length; ++i) {

                if (a.dg[i] > b.dg[i]) {

                    return 1*a.sign;

                } else if (a.dg[i] < b.dg[i]) {

                    return -1*a.sign;

                } 

            }

            return 0; //equal

        }

    }

}

 

void reverse(BN *c) {

    char tmp;

    int i;

    for (i = 0; i < c->length/2; ++i) {

        tmp = c->dg[i];

        c->dg[i] = c->dg[c->length-i-1];

        c->dg[c->length-i-1] = tmp;

    }  

}

 

void rm_lzeros(BN *c) {

    int i, j;

    for (i = 0; i < c->length; ++i) {

        if (c->dg[i] != '0') {

            break;

        }

    }

    for (j = 0; i < c->length; ++i, ++j) {

        c->dg[j] = c->dg[i];

    }

    c->length = j == 0?1:j;

    c->dg[c->length] = '\0';

}

 

void rm_tzeros(BN *c) {

    int i, j = 0;

    for (i = c->length-1; i > 0; --i) {

        if (c->dg[i] == '0') {

            j++;

        } else {

            break;

        }

    }

    c->length -= j;

    c->dg[c->length] = '\0';

}

 

void sub(BN a, BN b, BN *c) {

   int i, j, k;

   int borrow = 0;

   if (a.sign == MINUS) {

       if (b.sign == MINUS) {

           // (-|a|) - (-|b|) = |b| - |a|; make both positive to avoid endless recursion

           a.sign = PLUS;

           b.sign = PLUS;

           sub(b, a, c);

       } else {

           b.sign = MINUS;

           add(a, b, c);

       }

       return;

   } else {

       if (b.sign == MINUS) {

           b.sign = PLUS;

           add(a, b, c);

           return;

       } 

   }

   //a.sign == PLUS, b.sign == PLUS

   if (cmp(a, b) < 0) {

      sub(b, a, c);

      c->sign = MINUS;

      return; 

   } else if (cmp(a, b) == 0) {

      c->dg[0] = '0';

      c->dg[1] = '\0';

      c->length = 1;

      c->sign = PLUS;

      return;

   } 

   //a.sign == PLUS, b.sign == PLUS, a > b

   c->sign = PLUS;

   for (i = a.length-1, j = b.length-1, k = 0; i >= 0 && j >= 0; --i, --j, ++k) {

      // compute in an int: plain char may be unsigned, so "c->dg[k] < 0" could never be true

      int d = a.dg[i] - borrow - b.dg[j];

      if (d < 0) {

          d += 10;

          borrow = 1;

      } else {

          borrow = 0;

      }

      c->dg[k] = d + '0';

   }

   for (; i >= 0; --i, ++k) {

      c->dg[k] = a.dg[i] - borrow;

      if (c->dg[k] - '0' < 0) {

          c->dg[k] += 10;

          borrow = 1; 

      } else {

          borrow = 0;

      } 

   }

   c->dg[k] = '\0';

   c->length = k;

   //printf("%sn", c->dg);

   rm_tzeros(c);

   //printf("%sn", c->dg);

   reverse(c);

}

 

void add(BN a, BN b, BN *c) {

    int i, j, k;

    int carry = 0;

//    printf("%d:%dn", a.length, b.length);

    if (a.sign == b.sign) {

        c->sign = a.sign;

    } else {

        if (a.sign == MINUS) {

            a.sign = PLUS;    //tmp

            sub(b, a, c);

        } else {

            b.sign = PLUS;    //tmp

            sub(a, b, c);

        }

        return;

    }

    for (i = a.length-1, j = b.length-1, k = 0; i >= 0 && j >= 0; --i, --j, ++k) {

        c->dg[k] = a.dg[i] -'0' + b.dg[j] - '0' + carry;

        //printf("%d %d %c %cn",k, c->dg[k], a.dg[i], b.dg[j]);

        carry = c->dg[k]/10;

        c->dg[k] = c->dg[k]%10 + '0';

//        printf("%d %cn", k, (*c->dg[k]);

    }

    for (; i >= 0; --i, ++k){

        if (carry != 0) {

            c->dg[k] = carry + a.dg[i] - '0';

            carry = c->dg[k]/10;

            c->dg[k] = c->dg[k]%10 + '0';

        } else {

            c->dg[k] = a.dg[i];

        }

    }

    for (; j >= 0; --j, ++k){

        if (carry != 0) {

            c->dg[k] = carry + b.dg[j] - '0';

            carry = c->dg[k]/10;

            c->dg[k] = c->dg[k]%10 + '0';

        } else {

            c->dg[k] = b.dg[j];

        }

    }

    if (carry != 0) {

        c->dg[k++] = carry + '0';

    }

    c->dg[k] = '\0';

    c->length = k;

    //printf("%d %sn", c->length, c->dg);

    reverse(c);

    //printf("%d %sn", c->length, c->dg);

}

 

void lshift(BN *a, int num) {

    int i;

    if ((a->length == 1) && (a->dg[0] == '0')) {

        return;

    }

    for (i = 0; i < num; ++i) {

        a->dg[i+a->length] = '0';

    }

    a->length += num;

    a->dg[a->length] = '\0';

}

 

void mul(BN a, BN b, BN *c) {

   int i, j;

   

   init(c);

   for (i = a.length - 1; i >= 0; --i) {

       for (j = 0; j < a.dg[i]-'0'; ++j) {

          //printf("add %s to %sn", b.dg, c->dg);

          add(*c, b, c); 

       }

       lshift(&b, 1);

       //printf("%s %sn", c->dg, b.dg);

   }

   c->sign = a.sign*b.sign; 

}

 

void div(BN a, BN b, BN *c) {

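    int i;

    int sign = a.sign * b.sign;  // sign of the quotient

    BN tmp;

    init(c);

    init(&tmp);

    if (b.length == 1 && b.dg[0] == '0') {

        fprintf(stderr, "division by zero\n");

        return;

    }

    // divide the magnitudes; the sign is applied at the end

    a.sign = PLUS;

    b.sign = PLUS;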

    for (i = 0; i < a.length; ++i) {

        //printf("bf: %sn", tmp.dg);

        if (i > 0) {

           lshift(&tmp, 1);

        } else {

           tmp.length = 1;

        }

        //printf("af: %sn", tmp.dg);

        tmp.dg[tmp.length-1] = a.dg[i];

        c->dg[i] = '0';

        //printf("%d b: %d:%d:%s %d:%d:%sn", i, tmp.sign, tmp.length, tmp.dg, b.sign, b.length, b.dg);

        rm_lzeros(&tmp);

        while (cmp(tmp, b) >= 0) {

            //printf("sub: %s %sn", tmp.dg, b.dg);

            sub(tmp, b, &tmp);

            c->dg[i]++;

        } 

    }

    c->dg[i] = '\0';

    c->length = a.length;

    rm_lzeros(c);

    c->sign = (c->length == 1 && c->dg[0] == '0') ? PLUS : sign;  // avoid printing "-0"

}

 

void init(BN *a) {

    a->sign = PLUS;

    a->length = 1;

    a->dg[0] = '0';

    a->dg[1] = '\0';

}

 

int main(int argc, char** argv) {

    BN m, n, rv;

    if (argc < 3) {

        fprintf(stderr, "usage: %s m n\n", argv[0]);

        return 1;

    }

    if (argv[1][0] == '-') {

        m.sign = MINUS;

        sprintf(m.dg, "%s", &(argv[1][1]));

    } else {

        m.sign = PLUS;

        sprintf(m.dg, "%s", argv[1]);

    }

    m.length = strlen(m.dg);

    if (argv[2][0] == '-') {

        n.sign = MINUS;

        sprintf(n.dg, "%s", &(argv[2][1]));

    } else {

        n.sign = PLUS;

        sprintf(n.dg, "%s", argv[2]);

    }

    n.length = strlen(n.dg);

    printf("m:%c%s\n", m.sign>0?' ':'-', m.dg);

    printf("n:%c%s\n", n.sign>0?' ':'-', n.dg);

    add(m, n, &rv);

    printf("m + n = %c%s\n", rv.sign>0?' ':'-', rv.dg);

    sub(m, n, &rv);

    printf("m - n = %c%s\n", rv.sign>0?' ':'-', rv.dg);

    lshift(&rv, 1);  // note: shifts the previous result (m - n), i.e. multiplies it by 10

    printf("m << 1 = %c%s\n", rv.sign>0?' ':'-', rv.dg);

    mul(m, n, &rv);

    printf("m * n = %c%s\n", rv.sign>0?' ':'-', rv.dg);

    div(m, n, &rv);

    printf("m / n = %c%s\n", rv.sign>0?' ':'-', rv.dg);

}

Save the code to bignumarray.c, and compile the code using the command below,

gcc -o bignumarray bignumarray.c

And here are some simple tests,

./bignumarray 34123432143214321 1342

m: 34123432143214321

n: 1342

m + n =  34123432143215663

m - n =  34123432143212979

m << 1 =  341234321432129790

m * n =  45793645936193618782

m / n =  25427296678997

 

./bignumarray 332 13424312432

m: 332

n: 13424312432

m + n =  13424312764

m - n = -13424312100

m << 1 = -134243121000

m * n =  4456871727424

m / n =  0

Digit Array Representation

The big integer is still defined as below,

typedef struct {

   int sign;

   int length;

   char dg[MAXDG];

} BN;

But now dg holds the numeric value of each digit. For example, given two inputs 12345 and 678, the two big integers are a and b, with a.length = 5, a.sign = 1, and a.dg[0]=5, a.dg[1]=4, a.dg[2]=3, a.dg[3]=2, a.dg[4]=1; and b.length=3, b.sign = 1, and b.dg[0]=8, b.dg[1]=7, b.dg[2]=6. Note that the digits are stored in reverse order (least significant digit first), because it is more convenient to manipulate the end of an array than its start (e.g. dropping a leading zero digit only requires decrementing length).

The arithmetic operations are similar to the previous implementation. Below is the sample C code.

#include <stdio.h>

#include <string.h>

 

#define MAXDG 500

#define PLUS 1

#define MINUS -1

 

typedef struct {

    int sign;

    int length;

    char dg[MAXDG];

} BN;

 

 

void add(BN a, BN b, BN *c);

void mul(BN a, BN b, BN *c);

void div(BN a, BN b, BN *c);

void sub(BN a, BN b, BN *c);

void digit_shift(BN *a, int d);

int cmp(BN a, BN b);

void zero_justify(BN* a);

void init(BN *a);

void printbn(BN a);

 

void printbn(BN a) {

    int i;

    //printf("length: %d    ", a.length);

    if (a.sign == MINUS) {

        printf("-");

    }

    for (i = a.length-1; i >= 0; --i) {

        printf("%c", '0' + a.dg[i]);

    }

    printf("\n");

}

 

void init(BN *a) {

    a->length = 1;

    memset(a->dg, 0x00, sizeof(char)*MAXDG);

    a->sign = PLUS;

}

 

void zero_justify(BN* a) {

    while ((a->length > 1) && (a->dg[a->length-1] == 0)) {

        (a->length)--;

    }

    //if -0, make it +0

    if ((a->length == 1) && (a->dg[0] == 0)) {

        a->sign = PLUS;

    }

}

 

int cmp(BN a, BN b) {

    if (a.sign == PLUS && b.sign == MINUS) {

        return 1;

    } else if (a.sign == MINUS && b.sign == PLUS) {

        return -1;

    } else {

        //a and b have same sign

        if (a.length > b.length) {

            return 1*a.sign;

        } else if (a.length < b.length) {

            return -1*a.sign;

        } else {

            //a and b have same sign, same length

            int i;

            for (i = a.length-1; i >= 0; --i) {

                if (a.dg[i] > b.dg[i]) {

                    return 1*a.sign;

                } else if (a.dg[i] < b.dg[i]) {

                    return -1*a.sign;

                } 

            }

            return 0; //equal

        }

    }

}

 

void digit_shift(BN *a, int d) {

    int i;

    if ((a->length == 1) && (a->dg[0] == 0)) {

        return;

    }

    for (i = a->length - 1; i >= 0; --i) {

        a->dg[i+d] = a->dg[i];

    }

    for (i = 0; i < d; ++i) {

        a->dg[i] = 0;

    }

    a->length = a->length + d;

}

 

void add(BN a, BN b, BN *c) {

    int i;

    int carry = 0;

    init(c);

    if (a.sign == b.sign) {

        c->sign = a.sign;

    } else {

        if (a.sign == MINUS) {

            a.sign = PLUS;

            sub(b, a, c);

        } else {

            b.sign = PLUS;

            sub(a, b, c);

        }

        return;

    }

    c->length = (a.length > b.length)?a.length+1:b.length+1;

    for (i = 0; i < c->length; ++i) {

        c->dg[i] = (char)(a.dg[i] + b.dg[i] + carry)%10;

        carry = (a.dg[i] + b.dg[i] + carry)/10;

    }

    zero_justify(c); 

}

 

void sub(BN a, BN b, BN *c) {

    int borrow;

    int v;

    int i;

    init(c);

    if ((a.sign == MINUS) || (b.sign == MINUS)) {

        b.sign = -1*b.sign;

        add(a, b, c);

        return;

    }

    if (cmp(a, b) < 0) {

        sub(b, a, c);

        c->sign = MINUS;

        return;

    }

    //a > b, and both a and b are +

    c->length = a.length > b.length?a.length:b.length;

    borrow = 0;

    for (i = 0; i < c->length; ++i) {

        int d = a.dg[i] - borrow - b.dg[i];  /* int: plain char may be unsigned */

        if (d < 0) {

            d += 10;

            borrow = 1;

        } else {

            borrow = 0;

        }

        c->dg[i] = (char)d;

    }

    zero_justify(c);

}

 

void mul(BN a, BN b, BN *c) {

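    int i, j;

    int sign = a.sign * b.sign;  /* sign of the product */

    init(c);

    a.sign = PLUS;  /* accumulate |a| * |b|; the sign is applied at the end */

    for (i = 0; i < b.length; ++i) {

        for (j = 0; j < b.dg[i]; ++j) {

            add(*c, a, c);

        }

        digit_shift(&a, 1);

    }

    c->sign = sign;

    zero_justify(c);  /* also turns -0 into +0 */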

}

 

void div(BN a, BN b, BN *c) {

    int i;

    BN tmp;

    init(c);

    init(&tmp);

    c->sign = a.sign * b.sign;

    a.sign = PLUS;

    b.sign = PLUS;

    if (b.length == 1 && b.dg[0] == 0) {

        fprintf(stderr, "division by zero\n");

        init(c);

        return;

    }

    c->length = a.length;

    for (i = a.length - 1; i >= 0; --i) {

        digit_shift(&tmp, 1);

        //printf("%d:%dn", tmp.length, b.length);

        //printbn(tmp);

        tmp.dg[0] = a.dg[i];

        c->dg[i] = 0;

        while (cmp(tmp, b) >= 0) {

            (c->dg[i])++;

            //printbn(tmp);

            sub(tmp, b, &tmp);

            //printbn(tmp);

        }

        //printf("div: %dn", c->dg[i]);

    }

    zero_justify(c);

}

 

void readbn(char *s, BN *a) {

    int i, j;

    init(a);

//    printf("readbn: %dn", strlen(s));

    for (i = strlen(s)-1, j = 0; i >= 0; --i) {

        a->dg[j++] = s[i] - '0';

    }

    a->length = j;

}

 

int main(int argc, char** argv) {

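    BN m, n, rv;

    if (argc < 3) {

        fprintf(stderr, "usage: %s m n\n", argv[0]);

        return 1;

    }

    /* readbn() calls init(), which resets the sign, so set the sign afterwards */

    if (argv[1][0] == '-') {

        readbn(&argv[1][1], &m);

        m.sign = MINUS;

    } else {

        readbn(argv[1], &m);

    }

    if (argv[2][0] == '-') {

        readbn(&argv[2][1], &n);

        n.sign = MINUS;

    } else {

        readbn(argv[2], &n);

    }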

    printf("m:");

    printbn(m);

    printf("n:");

    printbn(n);

    add(m, n, &rv);

    printf("m + n = ");

    printbn(rv);

    sub(m, n, &rv);

    printf("m - n = ");

    printbn(rv);

    digit_shift(&rv, 1);  /* note: shifts the previous result (m - n), i.e. multiplies it by 10 */

    printf("m << 1 = ");

    printbn(rv);

    mul(m, n, &rv);

    printf("m * n = ");

    printbn(rv);

    div(m, n, &rv);

    printf("m / n = ");

    printbn(rv);

}

Save the code to bignumarray2.c, and compile the code using the command below,

gcc -o bignumarray2 bignumarray2.c

And here are some simple tests,

./bignumarray2 34123432143214321 1342

m:34123432143214321

n:1342

m + n = 34123432143215663

m - n = 34123432143212979

m << 1 = 341234321432129790

m * n = 45793645936193618782

m / n = 25427296678997

 

./bignumarray2 332 13424312432

m:332

n:13424312432

m + n = 13424312764

m - n = -13424312100

m << 1 = -134243121000

m * n = 4456871727424

m / n = 0

Additional Notes

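1. A dynamically allocated array is necessary to support big integers of arbitrary length; with the fixed-size dg[MAXDG] array above, operands and results are limited to fewer than MAXDG (500) digits.

2. Using a whole byte to store a single decimal digit is a huge waste, since a byte can hold 256 different values but only 10 are used. There are better representations, such as packing several decimal digits, or a binary word, into each array element.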
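As an illustration of one such representation (our own sketch, not from the original post), each array element below is a uint32_t holding nine decimal digits, i.e. the number is stored in base 10^9. Nine digits then take 4 bytes instead of 9, and addition loops over one ninth as many elements. Only parsing, addition and printing are shown, and the names BigU, bigu_from_str, bigu_add and bigu_to_str are hypothetical:

```c
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define BASE 1000000000u   /* each limb stores 9 decimal digits */
#define MAXLIMBS 64        /* assumed capacity: roughly 576 decimal digits */

typedef struct {
    int len;                  /* number of limbs in use */
    uint32_t limb[MAXLIMBS];  /* least significant limb first */
} BigU;

/* Parses a non-negative decimal string, 9 digits per limb from the right. */
void bigu_from_str(BigU *x, const char *s) {
    int end = (int)strlen(s);
    x->len = 0;
    while (end > 0) {
        int start = end > 9 ? end - 9 : 0;
        uint32_t v = 0;
        for (int i = start; i < end; ++i)
            v = v * 10 + (uint32_t)(s[i] - '0');
        assert(x->len < MAXLIMBS);
        x->limb[x->len++] = v;
        end = start;
    }
    if (x->len == 0)
        x->limb[x->len++] = 0;
    while (x->len > 1 && x->limb[x->len - 1] == 0)  /* drop leading zero limbs */
        x->len--;
}

/* c = a + b, propagating a carry of at most 1 between limbs. */
void bigu_add(const BigU *a, const BigU *b, BigU *c) {
    int n = a->len > b->len ? a->len : b->len;
    uint32_t carry = 0;
    int i;
    for (i = 0; i < n || carry; ++i) {
        uint64_t sum = (uint64_t)carry
                     + (i < a->len ? a->limb[i] : 0)
                     + (i < b->len ? b->limb[i] : 0);
        assert(i < MAXLIMBS);
        c->limb[i] = (uint32_t)(sum % BASE);
        carry = (uint32_t)(sum / BASE);
    }
    c->len = i;
}

/* Writes x in decimal; every limb except the top one is zero-padded to 9 digits. */
void bigu_to_str(const BigU *x, char *out) {
    int pos = sprintf(out, "%u", (unsigned)x->limb[x->len - 1]);
    for (int i = x->len - 2; i >= 0; --i)
        pos += sprintf(out + pos, "%09u", (unsigned)x->limb[i]);
}
```

The zero padding in bigu_to_str matters: a limb value of 5 in a non-top position stands for the nine digits "000000005". Real libraries such as GMP go further and use full binary machine words per element.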

References:
1. Programming Challenges

Primality Testing — Part 4. Fermat’s Little Theorem and A Probabilistic Approach

The previous posts on primality testing all try to factor the input number. For a very large number with a few hundred bits, factoring in that manner is infeasible.
This post introduces a probabilistic approach based on Fermat's primality test.

Fermat’s Little Theorem

Fermat’s primality test is based on Fermat’s Little Theorem,

If p is prime, then for every 1 <= a < p,

[ a ^ (p - 1) ] mod p = 1

Fermat's Little Theorem suggests a method to test whether a number is prime. For a big input number p, it is not feasible to test every a. Instead, we pick a few random values of a: if any of them fails the equation above, p is definitely not prime; if all of them pass, we declare p (probably) prime.

However, Fermat's Little Theorem gives a necessary condition for a number to be prime, but not a sufficient one. In other words, every prime number satisfies Fermat's Little Theorem, but for a composite number n it is also possible to find an a where [ a ^ (n - 1) ] mod n = 1.
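To make this concrete, here is a worked check (our own example, not from the original post) with the composite number 341 = 11 x 31. Base 2 passes Fermat's test, while base 3 exposes 341 as composite; the second line uses 3^30 mod 31 = 1, which holds because 31 is prime:

```latex
\begin{aligned}
2^{10} &= 1024 = 3 \cdot 341 + 1 \equiv 1 \pmod{341}
  &&\Longrightarrow\quad 2^{340} = \left(2^{10}\right)^{34} \equiv 1 \pmod{341},\\
3^{10} &= 59049 = 1904 \cdot 31 + 25 \equiv 25 \pmod{31}
  &&\Longrightarrow\quad 3^{340} = \left(3^{30}\right)^{11} \cdot 3^{10} \equiv 25 \pmod{31}.
\end{aligned}
```

Since 31 divides 341, 3^340 mod 341 cannot be 1, so a = 3 proves 341 composite, while a single unlucky choice of a = 2 would wrongly report it as prime.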

It can be shown that, unless n is a Carmichael number (a rare kind of composite, such as 561 = 3 x 11 x 17, that passes the test for every a coprime to n), at least half of the values 1 <= a < n fail the test for a composite n. So if we pick k random values of a and they all pass, the probability of wrongly declaring a composite number prime is at most 1/2^k. The algorithm therefore cannot guarantee that a positive result is 100% correct, but it is fast, and the probability of error is low enough for many practical use cases.

Implementation

Below is a simple C implementation based on the idea above,

#include <stdio.h>

#include <stdlib.h>

#include <time.h>

#include <sys/time.h>

 

unsigned long modexp(int x, int y, int n) {

    unsigned long i;

    if (y == 0) {

        return 1%n;

    }

    if (y == 1) {

        return x%n;

    } 

    i = modexp(x, y/2, n);

    if (y&0x01) {    //odd

        //return (x*i*i)%n;

//        printf("%d:%lu:%lun", x, i, (x%n*i*i));

        return x%n*i%n*i%n;

    } else {         //even

        return i*i%n;

    }

}

 

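int main(int argc, char **argv) {

    int n, a, i;

    int rounds = 20;             /* number of random bases to try */

    long elapsed;

    struct timeval stTime, edTime;

    if (argc < 2) {

        fprintf(stderr, "usage: %s n\n", argv[0]);

        return 1;

    }

    n = atoi(argv[1]);

    srand(time(NULL));

    gettimeofday(&stTime, NULL);

    if (n < 2) {

        printf("%d is not prime\n", n);

    } else if (n < 4) {

        printf("%d is prime\n", n);       /* 2 and 3 */

    } else {

        for (i = 0; i < rounds; ++i) {

            /* random base in [2, n-2]; a multiple of n would give a false "not prime" */

            a = 2 + rand() % (n - 3);

            if (modexp(a, n-1, n) != 1) {

                printf("%d is not prime, %d\n", n, a);

                break;

            }

        }

        if (i == rounds) {

            printf("%d is prime\n", n);

        }

    }

    gettimeofday(&edTime, NULL);

    elapsed = (edTime.tv_sec - stTime.tv_sec) * 1000000L + (edTime.tv_usec - stTime.tv_usec);

    printf("time: %ld:%ld\n", elapsed / 1000000, elapsed % 1000000);  /* seconds:microseconds */

    return 0;

}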

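Note that the code uses the recursive modular exponentiation algorithm based on repeated squaring. Because intermediate products such as i*i must fit in an unsigned long, the largest input the program handles correctly when unsigned long is 32 bits is 65536 = 2^16.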

Save the code to prim5.c and compile it using the command below,

gcc -o prim5 prim5.c

Below are some simple tests,

roman10@ra-ubuntu-1:~/Desktop/integer/prim$ ./prim5 65521

65521 is prime

time: 0:69

 

roman10@ra-ubuntu-1:~/Desktop/integer/prim$ ./prim5 2

2 is prime

time: 0:53

 

roman10@ra-ubuntu-1:~/Desktop/integer/prim$ ./prim5 65536

65536 is not prime, 8

time: 0:53

 

roman10@ra-ubuntu-1:~/Desktop/integer/prim$ ./prim5 63703

63703 is prime

time: 0:69

 

roman10@ra-ubuntu-1:~/Desktop/integer/prim$ ./prim5 61493

61493 is prime

time: 0:67

 

roman10@ra-ubuntu-1:~/Desktop/integer/prim$ ./prim5 13

13 is prime

time: 0:54
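The 2^16 limit noted above comes from products such as i*i overflowing a fixed-width integer. A common portable workaround, sketched below as our own illustration rather than part of the original program, is to compute (a * b) mod n by add-and-double ("Russian peasant" multiplication), so no intermediate value exceeds 2n. This lets the same Fermat test run on any n below 2^63. The names mulmod, powmod and fermat_round are hypothetical:

```c
#include <assert.h>
#include <stdint.h>

/* (a * b) mod m without overflow for 0 < m < 2^63, using add-and-double:
   every intermediate value stays below 2m. */
uint64_t mulmod(uint64_t a, uint64_t b, uint64_t m) {
    uint64_t r = 0;
    assert(m > 0 && m < (UINT64_C(1) << 63));
    a %= m;
    while (b > 0) {
        if (b & 1) {
            r += a;               /* r, a < m, so r + a < 2^64 */
            if (r >= m) r -= m;
        }
        a += a;                   /* a = 2a mod m */
        if (a >= m) a -= m;
        b >>= 1;
    }
    return r;
}

/* (x ^ y) mod m with the same recursive squaring as modexp() above. */
uint64_t powmod(uint64_t x, uint64_t y, uint64_t m) {
    uint64_t half;
    if (y == 0)
        return 1 % m;
    half = powmod(x, y / 2, m);
    half = mulmod(half, half, m);             /* [x^(y/2)]^2 mod m */
    return (y & 1) ? mulmod(half, x, m) : half;
}

/* One Fermat round: returns 1 if a^(n-1) mod n == 1, i.e. n passes for base a. */
int fermat_round(uint64_t a, uint64_t n) {
    return powmod(a, n - 1, n) == 1;
}
```

For example, fermat_round(2, 341) returns 1 even though 341 = 11 x 31 is composite, while fermat_round(3, 341) returns 0, which is why the test tries several random bases. The price of portability is speed: each mulmod call costs up to 64 additions instead of one hardware multiplication.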

References:
1. Wikipedia Primality test: http://en.wikipedia.org/wiki/Primality_test
2. Algorithms

How to Calculate Modular Exponentiation

Modular exponentiation is exponentiation performed modulo an integer n. It is used in many cryptographic algorithms.

Modular exponentiation accepts three input integers, say x, y and n and computes according to the formula below,

(x ^ y) mod n

Naive Approaches

A naive approach to compute modular exponentiation is to compute the exponentiation first and then the modulus.

This approach is OK for small x, y and n. But in cryptographic algorithms, x, y and n are normally several hundred bits long, and this approach needs far too much memory and too many multiplications: x^y alone has about y * log2(x) bits.

A slightly improved version multiplies by x mod n repeatedly and reduces the current result mod n after each multiplication. That is,

rv = x mod n => rv = [rv * (x mod n)] mod n => rv = [rv * (x mod n)] mod n => …

This approach keeps every intermediate value below n^2 and therefore reduces memory usage. But it still performs poorly when y is very big, because it needs y - 1 multiplications. Suppose y = 2^100; then it needs about 2^100 multiplications, which is infeasible in practice.
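A minimal sketch of this repeated-multiplication approach (the name modexp_linear is ours), using uint64_t and assuming n <= 2^32 so that each product of two reduced values fits in 64 bits:

```c
#include <assert.h>
#include <stdint.h>

/* (x ^ y) mod n by y successive multiplications, reducing after each one so
   intermediate values stay below n^2. Running time grows linearly with y. */
uint64_t modexp_linear(uint64_t x, uint64_t y, uint64_t n) {
    uint64_t rv = 1 % n;
    uint64_t xm;
    assert(n > 0 && n <= (UINT64_C(1) << 32));
    xm = x % n;
    for (uint64_t i = 0; i < y; ++i)
        rv = rv * xm % n;
    return rv;
}
```

modexp_linear(4, 13, 497) returns 445 after 13 loop iterations; with y = 2^100 the same loop could never finish, which is what motivates the next approach.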

A Better Approach

A better approach is to repeatedly square x mod n. That is,

x mod n => (x mod n)^2 mod n => (x mod n)^4 mod n => (x mod n)^8 mod n => … =>

(x mod n)^(2^k) mod n, where k = floor(log2 y)

In this way, only about log2(y) squarings are needed, plus at most as many extra multiplications to combine them. For y = 2^100, that is just 100 squarings, since log2(2^100) = 100.

In general, this approach can be expressed by the equation below,

(x ^ y) mod n = { [ (x ^ floor(y / 2)) mod n ] ^ 2 } mod n                  if y is even

(x ^ y) mod n = { x * [ (x ^ floor(y / 2)) mod n ] ^ 2 } mod n              if y is odd

This approach can be easily implemented in a recursive way. Below is a C implementation for this algorithm,

/**

polynomial-time computation for (x^y)%N

the idea is for mod N

(x^y) = [x^(y/2)]^2 if y is even

        = x*[x^(y/2)]^2 if y is odd

**/

#include <stdio.h>

#include <stdlib.h>

 

unsigned long modexp(int x, int y, int n) {

    unsigned long i;

    if (y == 0) {

        return 1%n;

    }

    if (y == 1) {

        return x%n;

    } 

    i = modexp(x, y/2, n);

    printf("%d:%lu\n", y, i);

    if (y&0x01) {    //odd

        //return (x*i*i)%n;

        return (x%n*i*i)%n;

    } else {         //even

        return (i*i)%n;

    }

}

 

int main(int argc, char **argv) {

    int x, y, N;

    if (argc < 4) {

        fprintf(stderr, "usage: %s x y N\n", argv[0]);

        return 1;

    }

    x = atoi(argv[1]);

    y = atoi(argv[2]);

    N = atoi(argv[3]);

    printf("(x^y)modN = %lu\n", modexp(x, y, N));

    return 0;

}

One thing to note in the code above is that when y is odd, we compute using (x%n*i*i)%n instead of (x*i*i)%n. This is to prevent overflow caused by multiplication of x*i*i when x is a very big number.

Also note that when n is big, both x%n and i can be close to n, so x%n*i*i (up to about n^3) can still overflow. Solving this issue in general requires different representations of integers and is not discussed in this post. For most commonly used (non-cryptographic) integers, the program above is good enough.

For a given input y, the number of recursive calls is about log2(y), so the algorithm performs O(log y) multiplications. When y is measured in bits, say N bits, the number of recursive calls is about N, so the algorithm performs O(N) multiplications (each of which is itself more expensive when the numbers are big).

Save the code as modexp.c and compile it using the command,

gcc -o modexp modexp.c

Below are some sample tests,

roman10@ra-ubuntu-1:~/Desktop/integer$ ./modexp 4 13 497

3:4

6:64

13:120

(x^y)modN = 445

 

roman10@ra-ubuntu-1:~/Desktop/integer$ ./modexp 1234 5678 90

2:64

5:46

11:64

22:64

44:46

88:46

177:46

354:64

709:46

1419:64

2839:64

5678:64

(x^y)modN = 46

 

roman10@ra-ubuntu-1:~/Desktop/integer$ ./modexp 3 14 30

3:3

7:27

14:27

(x^y)modN = 9

 

roman10@ra-ubuntu-1:~/Desktop/integer$ ./modexp 2 16 123

2:2

4:4

8:16

16:10

(x^y)modN = 100

An Alternative Approach

An alternative approach avoids the recursion in the program above. The idea is based on the binary representation of the exponent. For example, in the first test case above, y = 13 = 2^3 + 2^2 + 2^0. Therefore, to compute (4^13) mod 497, we can compute a = (4^(2^3)) mod 497, b = (4^(2^2)) mod 497 and c = (4^(2^0)) mod 497, and then multiply the results together modulo 497.

In general, given x and n, we can precompute [ x ^ (2 ^ i) ] mod n for i = 0, 1, …, N - 1, where N is the number of bits needed to represent the biggest value of y. If y is constrained to be less than 2^100, then N = 100.

In addition, [ x ^ (2 ^ i) ] mod n = ( { [ x ^ (2 ^ (i-1)) ] mod n } ^ 2 ) mod n. Therefore, each entry follows from the previous one with a single squaring, from i = 0 up to i = N - 1:

r0 = [x ^ (2 ^ 0)] mod n = x mod n
r1 = [x ^ (2 ^ 1)] mod n = (r0 ^ 2) mod n
r2 = [x ^ (2 ^ 2)] mod n = (r1 ^ 2) mod n

For a given y, we then scan the bits of y; for every bit i that is set to 1, we multiply the corresponding precomputed value (r0 for bit 0, r1 for bit 1, and so on) into the result, reducing mod n after each multiplication.

Below is a sample C implementation for this approach,

/**

polynomial-time computation for (x^y)%N

non-recursive version

**/

#include <stdio.h>

#include <stdlib.h>

 

#define MAXP 1000

unsigned long m[MAXP];

 

void precompute(int x, int n) {

   int i;

   m[0] = x%n;

   for (i = 1; i < MAXP; ++i) {

       m[i] = (m[i-1]*m[i-1])%n;

//       printf("%d:%ldn", i, m[i]);

   }

}

 

unsigned long modexp(int x, int y, int n) {

    int i;

    unsigned long rv = 1 % n;    /* result for y == 0 */

    /* scan the bits of y from least to most significant */

    for (i = 0; (y >> i) != 0; ++i) {

        if ((y >> i) & 1) {

            rv = (m[i] * rv) % n;    /* bit i is set: multiply in [x^(2^i)] mod n */

            printf("%d:%lu\n", i, rv);

        }

    }

    return rv;

}

 

int main(int argc, char **argv) {

    int x, y, N;

    if (argc < 4) {

        fprintf(stderr, "usage: %s x y N\n", argv[0]);

        return 1;

    }

    x = atoi(argv[1]);

    y = atoi(argv[2]);

    N = atoi(argv[3]);

    precompute(x, N);

    printf("(x^y)modN = %lu\n", modexp(x, y, N));

    return 0;

}

Save the code to modexp2.c and compile it using the command below,

gcc -o modexp2 modexp2.c

Below are some simple tests,

roman10@ra-ubuntu-1:~/Desktop/integer$ ./modexp2 4 13 497

0:4

2:30

3:445

(x^y)modN = 445

 

roman10@ra-ubuntu-1:~/Desktop/integer$ ./modexp2 1234 5678 90

1:46

2:46

3:46

5:46

9:46

10:46

12:46

(x^y)modN = 46

 

roman10@ra-ubuntu-1:~/Desktop/integer$ ./modexp2 3 14 30

1:9

2:9

3:9

(x^y)modN = 9

 

roman10@ra-ubuntu-1:~/Desktop/integer$ ./modexp2 2 16 123

4:100

(x^y)modN = 100
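The table of precomputed squares can also be replaced by a single running value. A minimal sketch of this variant (our own, with the hypothetical name modexp_bits, assuming n <= 2^32 so products fit in 64 bits) computes each r_i on the fly while it scans the bits of y:

```c
#include <assert.h>
#include <stdint.h>

/* (x ^ y) mod n, scanning the bits of y from least to most significant.
   Instead of reading [x ^ (2 ^ i)] mod n from a table, it keeps the current
   square in sq and squares it once per bit. */
uint64_t modexp_bits(uint64_t x, uint64_t y, uint64_t n) {
    uint64_t rv = 1 % n;
    uint64_t sq;
    assert(n > 0 && n <= (UINT64_C(1) << 32));
    sq = x % n;                          /* r0 = x mod n */
    while (y > 0) {
        if (y & 1)
            rv = rv * sq % n;            /* bit i is set: multiply in r_i */
        sq = sq * sq % n;                /* r_(i+1) = r_i ^ 2 mod n */
        y >>= 1;
    }
    return rv;
}
```

It performs at most two multiplications per bit of y, so about 2 * log2(y) in total, and needs only constant extra memory instead of the MAXP-entry table.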

References:
1. S. Dasgupta, C. H. Papadimitriou, and U. V. Vazirani, Algorithms
2. Wikipedia Modular Exponentiation: http://en.wikipedia.org/wiki/Modular_exponentiation
3. Numerical Recipes: http://numericalrecipes.blogspot.com/2009/03/modular-exponentiation.html