## Support Vector Machine Concept: VC Dimension


Given a set of n samples $x_1, x_2, \dots, x_n$, we want to label each of them as either -1 or 1. In total, there are $2^n$ possible label combinations.

A class of learning machines H can be used to label the samples. If for every label combination we can always find a learning machine $h \in H$ that labels the samples correctly, we say that H shatters the n points.

The VC (Vapnik-Chervonenkis) dimension of H is then defined as the maximum number of points that can be shattered by H. It measures the capacity of the hypothesis class H.

Note that VC(H) = 4 does not mean that H can shatter every set of 4 points; it is enough that there exists one set of 4 points that H shatters. It does, however, imply that no set of more than 4 points (say 5) can be shattered by H.

Now consider a simple example of finding the VC dimension of a hypothesis class H. Let H be the class of straight lines in the two-dimensional plane, where a line labels the points on one side as 1 and the points on the other side as -1.

We can find 3 points a, b and c (for example, three points that do not lie on the same line) for which all $2^3$ possible label combinations can be classified correctly by H.

As mentioned before, it is also possible to find 3 points that H cannot shatter. For example, if we place a, b and c on the same line with b in the middle, then the combination that labels a and c as 1 and b as -1 cannot be produced by any straight line. But this does not affect the VC dimension, because it is defined as the maximum number of points that can be shattered.

One can try to separate 4 points using a straight line, but no matter where we place those 4 points, there will be label combinations that cannot be separated. Therefore, the VC dimension of straight lines in the plane is 3.
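The shattering argument can be checked numerically with a small sketch. It is not part of the original discussion: it enumerates all $2^n$ labelings of a point set and runs the perceptron algorithm (a substitute technique chosen for this sketch) on each one. The perceptron is guaranteed to find a separating line in finitely many steps when one exists, so a labeling it separates really is separable; because of the hypothetical iteration cap `MAX_EPOCHS`, reporting a labeling as non-separable is only a heuristic. The function names are also assumptions.

```c
#include <assert.h>

#define MAX_EPOCHS 1000  /* hypothetical cap: non-separable labelings never converge */

/* Look for w0 + w1*x + w2*y whose sign matches labels[i] for every point.
   Returns 1 if a separating line was found within MAX_EPOCHS, 0 otherwise. */
static int perceptron_separates(const double pts[][2], const int labels[], int n) {
    double w[3] = {0.0, 0.0, 0.0};   /* bias, weight for x, weight for y */
    for (int epoch = 0; epoch < MAX_EPOCHS; ++epoch) {
        int errors = 0;
        for (int i = 0; i < n; ++i) {
            double s = w[0] + w[1] * pts[i][0] + w[2] * pts[i][1];
            if (labels[i] * s <= 0.0) {   /* misclassified or exactly on the line */
                w[0] += labels[i];
                w[1] += labels[i] * pts[i][0];
                w[2] += labels[i] * pts[i][1];
                ++errors;
            }
        }
        if (errors == 0) {
            return 1;   /* a full pass without mistakes: w is a separating line */
        }
    }
    return 0;
}

/* Returns 1 if straight lines (heuristically) shatter the n points, i.e. every
   one of the 2^n labelings is separable, and 0 otherwise. */
int lines_shatter(const double pts[][2], int n) {
    int labels[16];
    assert(n >= 1 && n <= 16);
    for (unsigned mask = 0; mask < (1u << n); ++mask) {
        for (int i = 0; i < n; ++i) {
            labels[i] = ((mask >> i) & 1u) ? 1 : -1;
        }
        if (!perceptron_separates(pts, labels, n)) {
            return 0;
        }
    }
    return 1;
}
```

With the vertices of a triangle such as (0,0), (1,0), (0,1), `lines_shatter` returns 1; with three collinear points or with the four corners of a square (which admit the XOR labeling), it returns 0.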

## Naive Bayes

#### Bayes’ Theorem

Let's start from Bayes' theorem, also referred to as Bayes' law or Bayes' rule.

P(A|B) = P(B, A) / P(B)

= P(B|A) * P(A) / P(B)

= P(B|A) * P(A) / (P(B|A) * P(A) + P(B|¬A) * P(¬A))

where the last step expands P(B) using the law of total probability.

P(A): prior probability, i.e., the probability that event A happens before B is taken into account.

P(¬A): the probability that event A does not happen, i.e., 1 - P(A).

P(B): evidence, or background: the probability that event B happens.

P(B|A), P(B|¬A): conditional probability, or likelihood: the probability that event B happens given that A did or did not happen, respectively.

P(A|B): posterior probability: the probability that A happens, taking into account the evidence B for and against A.
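To make the last form of the formula concrete, here is a small sketch that computes the posterior from P(A), P(B|A) and P(B|¬A). The function name `bayes_posterior` and the numbers in the usage note are hypothetical, chosen only for illustration.

```c
#include <assert.h>

/* P(A|B) = P(B|A)P(A) / (P(B|A)P(A) + P(B|not A)P(not A)),
   where P(not A) = 1 - P(A) and the denominator is the evidence P(B). */
double bayes_posterior(double p_a, double p_b_given_a, double p_b_given_not_a) {
    assert(p_a >= 0.0 && p_a <= 1.0);
    double joint_a = p_b_given_a * p_a;                  /* P(B, A) */
    double joint_not_a = p_b_given_not_a * (1.0 - p_a);  /* P(B, not A) */
    double evidence = joint_a + joint_not_a;             /* P(B) */
    assert(evidence > 0.0);
    return joint_a / evidence;
}
```

For example, with a hypothetical prior P(A) = 0.01, P(B|A) = 0.9 and P(B|¬A) = 0.1, `bayes_posterior(0.01, 0.9, 0.1)` returns 0.009 / 0.108 = 1/12 ≈ 0.083: even though B is much more likely under A, A remains unlikely because its prior is small.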

#### Naive Bayes

When used for classification, Bayes' theorem can be expressed as follows:

P(C | F1, F2, …, Fn) = P(C, F1, F2, …, Fn) / P(F1, F2, …, Fn)

= P(F1, F2, …, Fn | C) * P(C) / P(F1, F2, …, Fn)

C is a class/label that a sample can be classified into, and F1, F2, …, Fn represent the features of the sample.

P(F1, F2, …, Fn) does not depend on C. For a given sample it is effectively a constant, so it can be ignored when we only need to compare classes.

The numerator can be expanded using the chain rule of probability:

P(C, F1, F2 … , Fn)

= P(C) * P(F1, F2, … , Fn|C)

= P(C) * P(F1 | C) * P(F2, F3, … Fn | F1, C)

= P(C) * P(F1 | C) * P(F2 | F1, C) * P(F3, … Fn | F1, F2, C)

= P(C) * P(F1 | C) * P(F2 | F1, C) * P(F3 | F1, F2, C) * …  * P(Fn | F1, F2, …, Fn-1, C)

In Naive Bayes, all features are assumed to be conditionally independent given the class: Fi is independent of every other feature Fj (j ≠ i) once C is known. Therefore we have

P(F2 | F1, C) = P(F2 | C)

P(F3 | F1, F2, C) = P(F3 | C)

P(Fn | F1, F2, … Fn-1, C) = P(Fn | C)

Thus,

P(C, F1, F2, …, Fn) = P(C) * P(F1 | C) * P(F2 | C) * P(F3 | C) * … * P(Fn | C)

For example, two authors A and B like to use the words "love", "life" and "money". The probabilities of these words appearing in A's articles are 0.1, 0.1 and 0.8, and in B's articles 0.5, 0.3 and 0.2. Now, given the phrase "love life", which author is more likely to have written it?

Without any other information, we assume a prior probability of 50% for each of A and B. Treating the words as independent features, we can apply Naive Bayes:

P(A | love, life) = P(A) * P(love | A) * P(life | A) / P(love, life) = 0.5 * 0.1 * 0.1 / P(love, life)

P(B | love, life) = P(B) * P(love | B) * P(life | B) / P(love, life) = 0.5 * 0.5 * 0.3 / P(love, life)

Since 0.5 * 0.5 * 0.3 = 0.075 is larger than 0.5 * 0.1 * 0.1 = 0.005, it is more likely that the phrase "love life" was written by author B. Note that P(love, life) does not depend on the author; it is just a scaling factor.
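The same comparison can be written as a short C sketch. The word indices, the function name `naive_bayes_posterior` and the choice to normalize by the sum of the class scores (which assumes the phrase was written by either A or B, so that the sum equals P(love, life)) are assumptions made for this illustration.

```c
#include <assert.h>

#define NWORDS 3   /* vocabulary: 0 = "love", 1 = "life", 2 = "money" */

/* Computes P(class | words) for each class with the Naive Bayes rule:
   score(c) = P(c) * prod_k P(word_k | c), then divides every score by their
   sum, which plays the role of P(words) when the classes are exhaustive. */
void naive_bayes_posterior(const double prior[], const double likelihood[][NWORDS],
                           int nclasses, const int words[], int nwords, double post[]) {
    double total = 0.0;
    for (int c = 0; c < nclasses; ++c) {
        double score = prior[c];
        for (int k = 0; k < nwords; ++k) {
            assert(words[k] >= 0 && words[k] < NWORDS);
            score *= likelihood[c][words[k]];   /* independence assumption */
        }
        post[c] = score;
        total += score;
    }
    assert(total > 0.0);
    for (int c = 0; c < nclasses; ++c) {
        post[c] /= total;
    }
}
```

With the numbers above, the unnormalized scores are 0.005 for A and 0.075 for B, so P(A | love, life) = 0.0625 and P(B | love, life) = 0.9375.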

References:

1. Bayes’ theorem: http://en.wikipedia.org/wiki/Bayes%27_theorem
2. Naive Bayes classifier: http://en.wikipedia.org/wiki/Naive_Bayes_classifier
3. Udacity, Intro to Machine Learning. Lesson 1. Naive Bayes.

## Performance Metrics for Binary Classification

Binary classification classifies samples as either 0 (negative) or 1 (positive). Depending on the class a sample actually belongs to and its predicted class, each sample falls into one cell of the following table:

| | Actual positive | Actual negative |
|---|---|---|
| Predicted positive | true positive (TP) | false positive (FP) |
| Predicted negative | false negative (FN) | true negative (TN) |

We can then calculate different metrics, each of which measures the classification results from a different angle.

True Positive Rate (Sensitivity, hit rate, recall): Out of all the positive samples, what fraction is actually detected as positive.

TPR = TP / (TP + FN)

True Negative Rate (Specificity): Out of all the negative samples, what fraction is actually detected as negative.

TNR = TN / (TN + FP)

Positive Predictive Value (Precision): Out of all samples predicted as positive, what fraction is actually positive.

PPV = TP / (TP + FP)

Negative Predictive Value: Out of all samples predicted as negative, what fraction is actually negative.

NPV = TN / (TN + FN)

False Positive Rate (Fall out): Out of all negative samples, what fraction is detected as positive by mistake.

FPR = FP / (FP + TN) = 1 – TNR

False Discovery Rate: Out of all samples predicted as positive, what fraction is actually negative.

FDR = FP / (FP + TP) = 1 – PPV

Accuracy: Out of all samples, what fraction is predicted correctly. That is, positive samples are predicted as positive, negative samples are predicted as negative.

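Accuracy = (TP + TN) / (P + N), where P = TP + FN is the number of positive samples and N = FP + TN is the number of negative samples.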

The table below gives a nice overall view of the metrics mentioned above:

| | Actual positive | Actual negative | |
|---|---|---|---|
| Predicted positive | true positive (TP) | false positive (FP) | PPV = TP / (TP + FP) |
| Predicted negative | false negative (FN) | true negative (TN) | NPV = TN / (FN + TN) |
| | TPR = TP / (TP + FN) | TNR = TN / (FP + TN) | Accuracy = (TP + TN) / (P + N) |

F1 score: the harmonic mean of precision and recall.

F1 = 2 * TPR * PPV / (TPR + PPV) = 2TP / (2TP + FP + FN)

Matthews correlation coefficient: it takes into account true and false positives and negatives, and is regarded as a balanced measure. It can be used even when the numbers of samples in the two classes differ drastically.

MCC = (TP * TN - FP * FN) / √((TP + FP)(TP + FN)(TN + FP)(TN + FN))

MCC returns a value between -1 and 1, where -1 indicates total disagreement between prediction and observation, 0 means the prediction is no better than random guessing, and 1 indicates a perfect prediction.

References:

1. Matthews correlation coefficient: http://en.wikipedia.org/wiki/Matthews_correlation_coefficient
2. Sensitivity and specificity: http://en.wikipedia.org/wiki/Sensitivity_and_specificity

## Logistic Regression

These are some notes I took to summarize what I learned from Andrew Ng's machine learning course on Coursera.

#### Introduction

Linear regression predicts continuous values. At times, however, we need to categorize things. Logistic regression is a probabilistic statistical classification model that does exactly that.

We will first examine how logistic regression classifies things into two categories (either 0 or 1), and then how it is used for multiple categories.

The logistic regression model uses the logistic (sigmoid) function $g(z) = \frac{1}{1 + e^{-z}}$ as its hypothesis:

$$h_\theta(x) = g(\theta^T x) = \frac{1}{1 + e^{-\theta^T x}}$$

$h_\theta(x)$ can be interpreted as the estimated probability that y = 1 on input x.

If $\theta^T x \ge 0$, then $h_\theta(x) \ge 0.5$, and we predict y = 1.

If $\theta^T x < 0$, then $h_\theta(x) < 0.5$, and we predict y = 0.

The equation $\theta^T x = 0$ therefore describes the decision boundary. Note that we can use a cutoff other than 0.5 if it is more suitable for the problem.

#### Cost Function

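The cost function for logistic regression is defined as

$$J(\theta) = \frac{1}{m} \sum_{i=1}^{m} \mathrm{Cost}\left(h_\theta(x^{(i)}), y^{(i)}\right)$$

where the cost of a single sample is

$$\mathrm{Cost}(h_\theta(x), y) = \begin{cases} -\log\left(h_\theta(x)\right) & \text{if } y = 1 \\ -\log\left(1 - h_\theta(x)\right) & \text{if } y = 0 \end{cases}$$

Since y is always 0 or 1, the two cases can be merged, and the cost function becomes

$$J(\theta) = -\frac{1}{m} \sum_{i=1}^{m} \left[ y^{(i)} \log h_\theta(x^{(i)}) + (1 - y^{(i)}) \log\left(1 - h_\theta(x^{(i)})\right) \right]$$

With regularization, the cost function becomes

$$J(\theta) = -\frac{1}{m} \sum_{i=1}^{m} \left[ y^{(i)} \log h_\theta(x^{(i)}) + (1 - y^{(i)}) \log\left(1 - h_\theta(x^{(i)})\right) \right] + \frac{\lambda}{2m} \sum_{j=1}^{n} \theta_j^2$$

Note that, by convention, the regularization sum starts from j = 1, so $\theta_0$ is not penalized.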

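The gradient descent update for logistic regression looks identical to the one for linear regression,

$$\theta_j := \theta_j - \alpha \frac{1}{m} \sum_{i=1}^{m} \left(h_\theta(x^{(i)}) - y^{(i)}\right) x_j^{(i)}$$

except that $h_\theta(x^{(i)})$ is now the sigmoid of $\theta^T x^{(i)}$ instead of $\theta^T x^{(i)}$ itself. With regularization, the term $\frac{\lambda}{m}\theta_j$ is added to the gradient for $j \ge 1$.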

#### Multi-class Classification: One-vs-All

We can use the one-vs-all technique to apply logistic regression to multi-class classification. The idea is to train a logistic regression classifier for each class i to predict the probability that y = i. For a new input, we then pick the class whose classifier outputs the highest probability.

## Linear Regression

These are some notes I took to summarize what I learned from Andrew Ng's machine learning course on Coursera.

#### Introduction

Regression is a technique for modeling relationships among variables. Typically, there is one dependent variable y and one or more independent variables. This relationship is usually expressed as a regression function.

Linear regression, as the name suggests, models the relationship using a linear regression function. Depending on the number of independent variables, we have simple linear regression (one independent variable) or multivariate linear regression (more than one independent variable).

The hypothesis of linear regression can be described by the following equation, where $x_0 = 1$:

$$h_\theta(x) = \theta_0 x_0 + \theta_1 x_1 + \dots + \theta_n x_n = \theta^T x$$

The X are called features, and theta are the parameters. Given a set of training samples, we’ll need to choose theta to fit the training examples.

To measure how well we fit the training examples, we define the cost function of linear regression as

$$J(\theta) = \frac{1}{2m} \sum_{i=1}^{m} \left(h_\theta(x^{(i)}) - y^{(i)}\right)^2$$

Here m is the number of training samples, $h_\theta(x^{(i)})$ is the predicted value and $y^{(i)}$ is the actual output of sample i. The cost function computes the average squared error over all samples, divided by 2 (the factor 1/2 simplifies the derivative).

This is essentially an optimization problem where we need to choose parameter theta such that the cost defined by the cost function is minimized.
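As a sketch, the cost function translates directly into C. The row-major layout with a leading column of ones and the function names are assumptions for this example.

```c
#include <assert.h>

/* h(x) = theta^T x for row i of the row-major m x p matrix X (column 0 is all ones) */
static double lin_h(const double *theta, const double *X, int i, int p) {
    double z = 0.0;
    for (int j = 0; j < p; ++j) {
        z += theta[j] * X[i * p + j];
    }
    return z;
}

/* J(theta) = 1/(2m) * sum_i (h(x^(i)) - y^(i))^2 */
double linear_cost(const double *theta, const double *X, const double *y, int m, int p) {
    assert(m > 0);
    double sum = 0.0;
    for (int i = 0; i < m; ++i) {
        double err = lin_h(theta, X, i, p) - y[i];
        sum += err * err;
    }
    return sum / (2.0 * m);
}
```

For three hypothetical samples lying exactly on y = 1 + 2x (x = 1, 2, 3), θ = (1, 2) gives a cost of 0, while θ = (0, 0) gives (9 + 25 + 49) / 6 = 83/6 ≈ 13.83.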

#### Over-fitting and Regularization

Fitting the regression parameters minimizes the error on the training samples. However, if we try too hard, the regression function may not generalize well, i.e., the hypothesis produces a high error for inputs outside the training set. This problem is known as overfitting.

Two commonly used techniques to address overfitting are reducing the number of features and regularization.

Regularization adds a term to the cost function that penalizes large θ values, which tends to produce much smoother curves:

$$J(\theta) = \frac{1}{2m} \left[ \sum_{i=1}^{m} \left(h_\theta(x^{(i)}) - y^{(i)}\right)^2 + \lambda \sum_{j=1}^{n} \theta_j^2 \right]$$

Note that, by convention, the regularization term excludes the j = 0 case, i.e., $\theta_0$.

Given the hypothesis and its cost function, there are many ways to fit the parameters θ (i.e., to solve the optimization problem), including conjugate gradient, BFGS and L-BFGS. One of the most commonly used techniques is gradient descent.

The idea of gradient descent is to start with some (e.g., random) values of θ and evaluate the cost, then keep updating θ with the rule below (all $\theta_j$ updated simultaneously) to reduce the cost until we reach a minimum:

$$\theta_j := \theta_j - \alpha \frac{\partial}{\partial \theta_j} J(\theta)$$

Here α is called the learning rate. It can be shown that if α is sufficiently small, the cost decreases on every iteration and converges to a minimum. However, in practice we do not want α to be too small, because convergence then takes longer. Typically, we try out a range of α values (0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 1) and plot the cost against the number of iterations to see how fast it converges.

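For linear regression with regularization, the update rule above becomes

$$\begin{aligned} \theta_0 &:= \theta_0 - \alpha \frac{1}{m} \sum_{i=1}^{m} \left(h_\theta(x^{(i)}) - y^{(i)}\right) x_0^{(i)} \\ \theta_j &:= \theta_j - \alpha \left[ \frac{1}{m} \sum_{i=1}^{m} \left(h_\theta(x^{(i)}) - y^{(i)}\right) x_j^{(i)} + \frac{\lambda}{m} \theta_j \right], \quad j = 1, \dots, n \end{aligned}$$

The second update can easily be rewritten as

$$\theta_j := \theta_j \left(1 - \alpha \frac{\lambda}{m}\right) - \alpha \frac{1}{m} \sum_{i=1}^{m} \left(h_\theta(x^{(i)}) - y^{(i)}\right) x_j^{(i)}$$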
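A sketch of batch gradient descent that follows the regularized update rule, including the $\theta_j(1 - \alpha\lambda/m)$ shrink factor for $j \ge 1$. The data layout, the function name `linear_gd` and the fixed iteration count are assumptions; in practice one would monitor the cost, as described for choosing α, rather than run a fixed number of steps.

```c
#include <assert.h>

/* Batch gradient descent for (optionally regularized) linear regression.
   X is a row-major m x p matrix whose column 0 is all ones; theta[0] is not regularized. */
void linear_gd(double *theta, const double *X, const double *y, int m, int p,
               double alpha, double lambda, int iterations) {
    assert(m > 0 && p > 0);
    double grad[p];   /* C99 variable-length array */
    for (int it = 0; it < iterations; ++it) {
        for (int j = 0; j < p; ++j) {
            grad[j] = 0.0;
        }
        for (int i = 0; i < m; ++i) {
            double err = -y[i];
            for (int j = 0; j < p; ++j) {
                err += theta[j] * X[i * p + j];   /* err = h(x^(i)) - y^(i) */
            }
            for (int j = 0; j < p; ++j) {
                grad[j] += err * X[i * p + j];
            }
        }
        /* simultaneous update:
           theta_j := theta_j * (1 - alpha*lambda/m) - alpha * grad_j / m  (j >= 1),
           theta_0 is updated without the shrink factor */
        for (int j = 0; j < p; ++j) {
            double shrink = (j == 0) ? 1.0 : 1.0 - alpha * lambda / m;
            theta[j] = theta[j] * shrink - alpha * grad[j] / m;
        }
    }
}
```

On three hypothetical samples lying on y = 1 + 2x, with α = 0.1, λ = 0 and 5000 iterations, θ converges to (1, 2) to well within 10⁻⁶.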

#### Feature Scaling and Mean Normalization

When we run gradient descent, the values of different features often differ greatly in scale. For example, feature A may take values in the range [1, 10], while feature B varies within [-10000, 10000].

Gradient descent tends to converge faster when the feature values have similar scales and are centered around 0 (i.e., have a mean of approximately 0).

The former can be achieved with feature scaling: divide every value of a feature by a number such that the range becomes approximately [-1, 1]. The latter is accomplished with mean normalization, i.e., replacing each value x with x - μ, where μ is the mean of that feature (this does not apply to $x_0 = 1$). The two are often combined as $(x - \mu)/s$, where s is the range (or the standard deviation) of the feature.
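A minimal sketch of the combined transformation for a single feature column, using the range as the scale s (the standard deviation would be an equally common choice). The function name is hypothetical, and the column $x_0$ would simply be left out.

```c
#include <assert.h>

/* Mean normalization plus feature scaling for one feature column:
   v[i] := (v[i] - mean) / range. The mean and range are returned so the same
   transformation can be applied to new inputs at prediction time. */
void normalize_feature(double *v, int m, double *mean_out, double *range_out) {
    assert(m > 0);
    double sum = 0.0, lo = v[0], hi = v[0];
    for (int i = 0; i < m; ++i) {
        sum += v[i];
        if (v[i] < lo) lo = v[i];
        if (v[i] > hi) hi = v[i];
    }
    double mean = sum / m;
    double range = hi - lo;
    if (range == 0.0) {
        range = 1.0;   /* constant feature: only center it */
    }
    for (int i = 0; i < m; ++i) {
        v[i] = (v[i] - mean) / range;
    }
    *mean_out = mean;
    *range_out = range;
}
```

For the values {1, 5, 9}, the mean is 5 and the range is 8, so the normalized values are {-0.5, 0, 0.5}.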

#### Analytical Solution (Normal Equation)

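Besides using optimization algorithms to fit θ iteratively, it turns out we can also compute θ analytically, in closed form.

Without regularization, the solution (the normal equation) is

$$\theta = \left(X^T X\right)^{-1} X^T y$$

where X is the m × (n+1) design matrix whose i-th row is $(x^{(i)})^T$, and y is the vector of outputs.

While this method does not require choosing a learning rate or iterating, it becomes computationally expensive as the number of features n gets large, because of the matrix multiplication and inversion (inverting the (n+1) × (n+1) matrix $X^T X$ takes roughly $O(n^3)$ time). In addition, the inverse may not even exist. This is typically caused by redundant features (some features are linearly dependent) or by too many features for too few samples (m ≤ n).

With regularization, the closed-form solution is

$$\theta = \left(X^T X + \lambda L\right)^{-1} X^T y, \quad L = \mathrm{diag}(0, 1, 1, \dots, 1)$$

where L is the (n+1) × (n+1) identity matrix with its top-left entry set to 0, so that $\theta_0$ is not regularized.

Note that for λ > 0 the matrix inside the inverse is always invertible, even when $X^T X$ itself is not.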
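A sketch of the regularized closed-form solution. Instead of computing the inverse explicitly, it solves the linear system $(X^T X + \lambda L)\theta = X^T y$ with Gaussian elimination and partial pivoting, a standard substitute that is numerically preferable; the function name, the data layout and the singularity tolerance `1e-12` are assumptions.

```c
#include <assert.h>
#include <math.h>

/* Solves (X^T X + lambda * L) theta = X^T y, where L is the identity with L[0][0] = 0.
   X is a row-major m x p matrix with a leading column of ones.
   Returns 0 on success, -1 if the system is (numerically) singular. */
int normal_equation(const double *X, const double *y, int m, int p,
                    double lambda, double *theta) {
    assert(m > 0 && p > 0);
    double A[p][p + 1];   /* augmented matrix [X^T X + lambda L | X^T y] */
    for (int r = 0; r < p; ++r) {
        for (int c = 0; c <= p; ++c) {
            double s = 0.0;
            for (int i = 0; i < m; ++i) {
                s += X[i * p + r] * (c < p ? X[i * p + c] : y[i]);
            }
            A[r][c] = s;
        }
        if (r > 0) {
            A[r][r] += lambda;   /* theta_0 is not regularized */
        }
    }
    for (int col = 0; col < p; ++col) {
        int piv = col;   /* partial pivoting: row with the largest entry in this column */
        for (int r = col + 1; r < p; ++r) {
            if (fabs(A[r][col]) > fabs(A[piv][col])) piv = r;
        }
        if (fabs(A[piv][col]) < 1e-12) {
            return -1;
        }
        for (int c = 0; c <= p; ++c) {   /* swap the pivot row into place */
            double t = A[col][c]; A[col][c] = A[piv][c]; A[piv][c] = t;
        }
        for (int r = col + 1; r < p; ++r) {   /* eliminate entries below the pivot */
            double f = A[r][col] / A[col][col];
            for (int c = col; c <= p; ++c) {
                A[r][c] -= f * A[col][c];
            }
        }
    }
    for (int r = p - 1; r >= 0; --r) {   /* back substitution */
        double s = A[r][p];
        for (int c = r + 1; c < p; ++c) {
            s -= A[r][c] * theta[c];
        }
        theta[r] = s / A[r][r];
    }
    return 0;
}
```

With a duplicated feature column, $X^T X$ is singular and the call with λ = 0 returns -1, whereas any λ > 0 makes the system solvable, in line with the note above.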

## Bit Set

One specific use of bits is to represent sets. Suppose we have a set of N elements {e1, e2, …, eN}, and we want to select a subset of K (K ≤ N) elements. We can use a bit pattern to represent the selection, where a 1 indicates that the corresponding element is selected and a 0 indicates that it is not.
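To make the mapping concrete, here is a small sketch (the function name and the convention that bit i stands for element e(i+1) are assumptions) that decodes a bit pattern back into the indices of the selected elements.

```c
#include <assert.h>

/* Writes the indices of the selected elements (the set bits of mask among
   elements 0..n-1) into out[] and returns how many there are. */
int mask_to_subset(unsigned int mask, int n, int out[]) {
    assert(n >= 0 && n <= 32);
    int k = 0;
    for (int i = 0; i < n; ++i) {
        if (mask & (1u << i)) {   /* bit i set: element e(i+1) is selected */
            out[k++] = i;
        }
    }
    return k;
}
```

For example, the pattern 0001 0111 yields the indices 0, 1, 2 and 4, i.e., it selects e1, e2, e3 and e5.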

Suppose N = 8 and K = 4. We start with the bit pattern 0000 1111 and want to find the next permutation in lexicographical order, i.e., the next larger number with exactly K bits set.

The first few permutations are:

```
0000 1111
0001 0111
0001 1011
0001 1101
0001 1110
0010 0111
0010 1011
0010 1101
0010 1110
0011 0011
```

By looking at the patterns, we can summarize the following steps:
1. Ignoring the trailing zeros, find the least significant zero bit (the first zero above the lowest run of ones) and change it to one.
2. Set all the bits below that position to zero.
3. Construct a run of ones covering all the bits below the position found in step 1, shift it right by the number of trailing zeros in the original number plus 1, and OR it into the result.

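This can be expressed with the code below. `__builtin_ctz(x)` is a GCC (and Clang) builtin that returns the number of trailing zero bits of x; its result is undefined for x = 0.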

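```c
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* print x as a 32-bit binary string */
void prbin(unsigned int x) {
    static char b[65];
    unsigned int z;

    b[0] = '\0';
    for (z = 0x80000000u; z > 0; z >>= 1) {
        strcat(b, ((x & z) == z) ? "1" : "0");
    }
    printf("%u=%s\n", x, b);
}

int main(int argc, char **argv) {
    int m, n;
    unsigned int x, y;

    if (argc < 3) {
        fprintf(stderr, "usage: %s K N\n", argv[0]);
        return 1;
    }
    m = atoi(argv[1]);   /* number of selected elements (K) */
    n = atoi(argv[2]);   /* size of the set (N) */
    if (m < 1 || m > n || n > 31) {
        fprintf(stderr, "need 1 <= K <= N <= 31\n");
        return 1;
    }
    printf("permute %d in %d element set\n", m, n);

    x = (1u << m) - 1;             /* first pattern: the lowest m bits set */
    while (!(x & (1u << n))) {     /* stop once a bit beyond position n-1 is set */
        prbin(x);
        y = x | (x - 1);           /* set all trailing zeros to one */
        /* y + 1 performs steps 1 and 2; the shifted term rebuilds the low ones (step 3) */
        x = (y + 1) | (((~y & -~y) - 1) >> (__builtin_ctz(x) + 1));
    }
    return 0;
}
```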

An alternative method that does not use `__builtin_ctz(x)` is shown below; the comments briefly explain the algorithm.

```c
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* print x as a 32-bit binary string */
void prbin(unsigned int x) {
    static char b[65];
    unsigned int z;

    b[0] = '\0';
    for (z = 0x80000000u; z > 0; z >>= 1) {
        strcat(b, ((x & z) == z) ? "1" : "0");
    }
    printf("%u=%s\n", x, b);
}

int main(int argc, char **argv) {
    int m, n;
    unsigned int x, y, z;

    if (argc < 3) {
        fprintf(stderr, "usage: %s K N\n", argv[0]);
        return 1;
    }
    m = atoi(argv[1]);   /* number of selected elements (K) */
    n = atoi(argv[2]);   /* size of the set (N) */
    if (m < 1 || m > n || n > 31) {
        fprintf(stderr, "need 1 <= K <= N <= 31\n");
        return 1;
    }
    printf("permute %d in %d element set\n", m, n);

    x = (1u << m) - 1;
    while (!(x & (1u << n))) {
        prbin(x);
        y = x & (~(x - 1));            /* get the least significant one bit */
        z = (~x) & (x + y);            /* get the least significant zero bit above y */
        x = x | z;                     /* set that zero bit to 1 */
        x = x & (~(z - 1));            /* clear all bits below it */
        x = x | (((z / y) >> 1) - 1);  /* append the remaining ones at the bottom */
    }
    return 0;
}
```

Save the two programs as bitsets.c and bitsets2.c, respectively, and compile them using the commands below:

```
gcc -o test bitsets.c
gcc -o test2 bitsets2.c
```

Below are some sample executions:

```
./test 3 5
permute 3 in 5 element set
7=00000000000000000000000000000111
11=00000000000000000000000000001011
13=00000000000000000000000000001101
14=00000000000000000000000000001110
19=00000000000000000000000000010011
21=00000000000000000000000000010101
22=00000000000000000000000000010110
25=00000000000000000000000000011001
26=00000000000000000000000000011010
28=00000000000000000000000000011100
```

```
./test2 3 5
permute 3 in 5 element set
7=00000000000000000000000000000111
11=00000000000000000000000000001011
13=00000000000000000000000000001101
14=00000000000000000000000000001110
19=00000000000000000000000000010011
21=00000000000000000000000000010101
22=00000000000000000000000000010110
25=00000000000000000000000000011001
26=00000000000000000000000000011010
28=00000000000000000000000000011100
```

## Big Integer Arithmetic

Most computer applications operate on small integers that can be represented by 32-bit integers. However, at times we need to deal with extremely big integers that can only be represented with hundreds of bits. One such application is cryptography.

This post gives two simple implementations of big integer arithmetic. Both implementations use a char array to store the big integer. The first one stores the digits of the integer as characters ('0' to '9'), and the second one stores them as single-digit integer values (0 to 9).

#### String Representation

We define a big integer as the structure below:

```c
typedef struct {
   int sign;
   int length;
   char dg[MAXDG];
} BN;
```

For example, given the two input integers "111222333444555" and "12345", we can store them in two BNs a and b, with a.length = 15, a.dg[0] = '1', a.dg[1] = '1', …, a.dg[14] = '5', a.sign = 1; and b.length = 5, b.dg[0] = '1', b.dg[1] = '2', …, b.dg[4] = '5', b.sign = 1.

Addition: We add the two integers digit by digit from right to left. If the sum of two digits (plus any incoming carry) is greater than 9, a carry of 1 is produced and added to the sum of the digits at the next position. We store the digits of the result from left to right and reverse them when we are done.

Subtraction: Subtraction is similar to addition. We subtract the digits of the two integers from right to left. If the resulting digit is less than 0, we add 10 to it and borrow 1 from the next position.

Multiplication: A naive approach is to implement multiplication as repeated addition, but if both inputs are big, the algorithm is very slow.

A better approach is to handle the multiplication by a single digit of the multiplier using repeated addition, then shift the multiplicand left by one digit (which is equivalent to multiplying it by 10) and handle the next digit.

For example, if the two inputs are 12345 and 57, we treat the multiplication as 12345 x 7 + 123450 x 5, and each of the two multiplications can be handled using additions. In this way, the number of additions needed is reduced to the sum of all digits of the multiplier.

Division: Again, the naive approach is repeated subtraction. A better alternative is long division: we take the digits of the dividend from left to right and repeatedly subtract the divisor from the current partial remainder. For example, to compute 234 / 11, we start by computing 23 / 11 using subtraction. The quotient is 2 and the remainder is 1. We then compute (1*10 + 4) / 11 using subtraction, and the quotient is 1 (with remainder 3). So the result is 21.

Below is a C implementation of the algorithms described above,

```c
/**
 * array implementation of bignum: digits are stored as characters
 * ('0'..'9'), most significant digit first, NUL-terminated
 **/
#include <stdio.h>
#include <string.h>

#define MAXDG 500
#define PLUS 1
#define MINUS -1

typedef struct {
    int sign;
    int length;
    char dg[MAXDG];
} BN;

int cmp(BN a, BN b);
void sub(BN a, BN b, BN *c);
void add(BN a, BN b, BN *c);
void reverse(BN *c);
void rm_tzeros(BN *c);
void rm_lzeros(BN *c);
void init(BN *a);

/* returns 1 if a > b, -1 if a < b and 0 if a == b */
int cmp(BN a, BN b) {
    if (a.sign == PLUS && b.sign == MINUS) {
        return 1;
    } else if (a.sign == MINUS && b.sign == PLUS) {
        return -1;
    } else {
        //a and b have same sign
        if (a.length > b.length) {
            return 1*a.sign;
        } else if (a.length < b.length) {
            return -1*a.sign;
        } else {
            //a and b have same sign, same length
            int i;
            for (i = 0; i < a.length; ++i) {
                if (a.dg[i] > b.dg[i]) {
                    return 1*a.sign;
                } else if (a.dg[i] < b.dg[i]) {
                    return -1*a.sign;
                }
            }
            return 0; //equal
        }
    }
}

void reverse(BN *c) {
    char tmp;
    int i;
    for (i = 0; i < c->length/2; ++i) {
        tmp = c->dg[i];
        c->dg[i] = c->dg[c->length-i-1];
        c->dg[c->length-i-1] = tmp;
    }
}

/* remove leading zeros, keeping at least one digit */
void rm_lzeros(BN *c) {
    int i, j;
    for (i = 0; i < c->length; ++i) {
        if (c->dg[i] != '0') {
            break;
        }
    }
    for (j = 0; i < c->length; ++i, ++j) {
        c->dg[j] = c->dg[i];
    }
    c->length = j == 0?1:j;
    c->dg[c->length] = '\0';
}

/* remove trailing zeros, keeping at least one digit; used on a result that is
   still stored least significant digit first, so these are its leading zeros */
void rm_tzeros(BN *c) {
    int i, j = 0;
    for (i = c->length-1; i > 0; --i) {
        if (c->dg[i] == '0') {
            j++;
        } else {
            break;
        }
    }
    c->length -= j;
    c->dg[c->length] = '\0';
}

void sub(BN a, BN b, BN *c) {
    int i, j, k, d;
    int borrow = 0;
    if (a.sign == MINUS) {
        if (b.sign == MINUS) {
            //(-|a|) - (-|b|) = |b| - |a|
            a.sign = PLUS;
            b.sign = PLUS;
            sub(b, a, c);
        } else {
            //(-|a|) - |b| = -(|a| + |b|)
            b.sign = MINUS;
            add(a, b, c);
        }
        return;
    } else if (b.sign == MINUS) {
        //|a| - (-|b|) = |a| + |b|
        b.sign = PLUS;
        add(a, b, c);
        return;
    }
    //a.sign == PLUS, b.sign == PLUS
    if (cmp(a, b) < 0) {
        sub(b, a, c);
        c->sign = MINUS;
        return;
    } else if (cmp(a, b) == 0) {
        init(c);   //the result is 0
        return;
    }
    //a.sign == PLUS, b.sign == PLUS, a > b
    c->sign = PLUS;
    //subtract from the right; d is an int so that the "< 0" test also
    //works on platforms where plain char is unsigned
    for (i = a.length-1, j = b.length-1, k = 0; i >= 0 && j >= 0; --i, --j, ++k) {
        d = (a.dg[i] - '0') - borrow - (b.dg[j] - '0');
        if (d < 0) {
            d += 10;
            borrow = 1;
        } else {
            borrow = 0;
        }
        c->dg[k] = d + '0';
    }
    for (; i >= 0; --i, ++k) {
        d = (a.dg[i] - '0') - borrow;
        if (d < 0) {
            d += 10;
            borrow = 1;
        } else {
            borrow = 0;
        }
        c->dg[k] = d + '0';
    }
    c->dg[k] = '\0';
    c->length = k;
    rm_tzeros(c);   //drop the leading zeros of the (still reversed) result
    reverse(c);
}

void add(BN a, BN b, BN *c) {
    int i, j, k;
    int carry = 0;
    if (a.sign == b.sign) {
        c->sign = a.sign;
    } else {
        //different signs: turn the addition into a subtraction
        if (a.sign == MINUS) {
            a.sign = PLUS;
            sub(b, a, c);
        } else {
            b.sign = PLUS;
            sub(a, b, c);
        }
        return;
    }
    //add digit by digit from the right, storing the result reversed
    for (i = a.length-1, j = b.length-1, k = 0; i >= 0 && j >= 0; --i, --j, ++k) {
        c->dg[k] = a.dg[i] - '0' + b.dg[j] - '0' + carry;
        carry = c->dg[k]/10;
        c->dg[k] = c->dg[k]%10 + '0';
    }
    for (; i >= 0; --i, ++k) {
        if (carry != 0) {
            c->dg[k] = carry + a.dg[i] - '0';
            carry = c->dg[k]/10;
            c->dg[k] = c->dg[k]%10 + '0';
        } else {
            c->dg[k] = a.dg[i];
        }
    }
    for (; j >= 0; --j, ++k) {
        if (carry != 0) {
            c->dg[k] = carry + b.dg[j] - '0';
            carry = c->dg[k]/10;
            c->dg[k] = c->dg[k]%10 + '0';
        } else {
            c->dg[k] = b.dg[j];
        }
    }
    if (carry != 0) {
        c->dg[k++] = carry + '0';
    }
    c->dg[k] = '\0';
    c->length = k;
    reverse(c);
}

/* multiply a by 10^num by appending num zeros */
void lshift(BN *a, int num) {
    int i;
    if ((a->length == 1) && (a->dg[0] == '0')) {
        return;
    }
    for (i = 0; i < num; ++i) {
        a->dg[i+a->length] = '0';
    }
    a->length += num;
    a->dg[a->length] = '\0';
}

void mul(BN a, BN b, BN *c) {
    int i, j;
    int sign = a.sign*b.sign;
    b.sign = PLUS;   //accumulate |a| * |b| and set the sign at the end
    init(c);
    for (i = a.length - 1; i >= 0; --i) {
        for (j = 0; j < a.dg[i]-'0'; ++j) {
            add(*c, b, c);   //add b once per unit of the current digit of a
        }
        lshift(&b, 1);       //the next digit of a is worth 10 times more
    }
    //avoid printing "-0"
    c->sign = (c->length == 1 && c->dg[0] == '0') ? PLUS : sign;
}

/* long division; b must not be 0 */
void div(BN a, BN b, BN *c) {
    int i;
    int sign = a.sign*b.sign;
    BN tmp;   //current partial remainder
    init(c);
    init(&tmp);
    b.sign = PLUS;   //work with magnitudes
    for (i = 0; i < a.length; ++i) {
        //bring down the next digit of a
        if (i > 0) {
            lshift(&tmp, 1);
        } else {
            tmp.length = 1;
        }
        tmp.dg[tmp.length-1] = a.dg[i];
        c->dg[i] = '0';
        rm_lzeros(&tmp);
        while (cmp(tmp, b) >= 0) {
            sub(tmp, b, &tmp);
            c->dg[i]++;
        }
    }
    c->dg[i] = '\0';
    c->length = a.length;
    rm_lzeros(c);
    c->sign = (c->length == 1 && c->dg[0] == '0') ? PLUS : sign;
}

void init(BN *a) {
    a->sign = PLUS;
    a->length = 1;
    a->dg[0] = '0';
    a->dg[1] = '\0';
}

int main(int argc, char** argv) {
    BN m, n, rv;
    if (argc < 3) {
        fprintf(stderr, "usage: %s M N\n", argv[0]);
        return 1;
    }
    if (argv[1][0] == '-') {
        m.sign = MINUS;
        snprintf(m.dg, MAXDG, "%s", &(argv[1][1]));
    } else {
        m.sign = PLUS;
        snprintf(m.dg, MAXDG, "%s", argv[1]);
    }
    m.length = strlen(m.dg);
    if (argv[2][0] == '-') {
        n.sign = MINUS;
        snprintf(n.dg, MAXDG, "%s", &(argv[2][1]));
    } else {
        n.sign = PLUS;
        snprintf(n.dg, MAXDG, "%s", argv[2]);
    }
    n.length = strlen(n.dg);
    printf("m:%c%s\n", m.sign>0?' ':'-', m.dg);
    printf("n:%c%s\n", n.sign>0?' ':'-', n.dg);
    add(m, n, &rv);
    printf("m + n = %c%s\n", rv.sign>0?' ':'-', rv.dg);
    sub(m, n, &rv);
    printf("m - n = %c%s\n", rv.sign>0?' ':'-', rv.dg);
    lshift(&rv, 1);   //note: shifts the previous result, i.e. prints (m - n) * 10
    printf("m << 1 = %c%s\n", rv.sign>0?' ':'-', rv.dg);
    mul(m, n, &rv);
    printf("m * n = %c%s\n", rv.sign>0?' ':'-', rv.dg);
    if (n.length == 1 && n.dg[0] == '0') {
        printf("m / n: division by zero\n");
    } else {
        div(m, n, &rv);
        printf("m / n = %c%s\n", rv.sign>0?' ':'-', rv.dg);
    }
    return 0;
}
```

Save the code to bignumarray.c and compile it using the command below:

```
gcc -o bignumarray bignumarray.c
```

And here are some simple tests:

```
./bignumarray 34123432143214321 1342
m: 34123432143214321
n: 1342
m + n =  34123432143215663
m - n =  34123432143212979
m << 1 =  341234321432129790
m * n =  45793645936193618782
m / n =  25427296678997
```

```
./bignumarray 332 13424312432
m: 332
n: 13424312432
m + n =  13424312764
m - n = -13424312100
m << 1 = -134243121000
m * n =  4456871727424
m / n =  0
```

Note that the `m << 1` line shifts the previous result, so it actually prints (m - n) × 10.

#### Digit Array Representation

The big integer is still defined as below:

```c
typedef struct {
   int sign;
   int length;
   char dg[MAXDG];
} BN;
```

But now dg holds the numeric value of each digit. For example, given the two inputs 12345 and 678, the two big integers a and b have a.length = 5, a.sign = 1, a.dg[0] = 5, a.dg[1] = 4, a.dg[2] = 3, a.dg[3] = 2, a.dg[4] = 1; and b.length = 3, b.sign = 1, b.dg[0] = 8, b.dg[1] = 7, b.dg[2] = 6. Note that the digits are stored in reverse order, because it is more convenient to manipulate the end of an array than its start (e.g., to truncate the array by one digit).

The arithmetic operations are similar to the previous implementation. Below is the sample C code.

`#include <stdio.h>`

`#include <string.h>`

` `

`#define MAXDG 500`

`#define PLUS 1`

`#define MINUS -1`

` `

`typedef struct {`

`    int sign;`

`    int length;`

`    char dg[MAXDG];`

`} BN;`

` `

` `

`void add(BN a, BN b, BN *c);`

`void mul(BN a, BN b, BN *c);`

`void div(BN a, BN b, BN *c);`

`void sub(BN a, BN b, BN *c);`

`void digit_shift(BN *a, int d);`

`int cmp(BN a, BN b);`

`void zero_justify(BN* a);`

`void init(BN *a);`

`void printbn(BN a);`

` `

`void printbn(BN a) {`

`    int i;`

`    //printf("length: %d    ", a.length);`

`    if (a.sign == MINUS) {`

`        printf("-");`

`    }`

`    for (i = a.length-1; i >= 0; --i) {`

`        printf("%c", '0' + a.dg[i]);`

`    }`

`    printf("\n");`

`}`

` `

`void init(BN *a) {`

`    a->length = 1;`

`    memset(a->dg, 0x00, sizeof(char)*MAXDG);`

`    a->sign = PLUS;`

`}`

` `

`void zero_justify(BN* a) {`

`    while ((a->length > 1) && (a->dg[a->length-1] == 0)) {`

`        (a->length)--;`

`    }`

`    //if -0, make it +0`

`    if ((a->length == 1) && (a->dg[0] == 0)) {`

`        a->sign = PLUS;`

`    }`

`}`

` `

`int cmp(BN a, BN b) {`

`    if (a.sign == PLUS && b.sign == MINUS) {`

`        return 1;`

`    } else if (a.sign == MINUS && b.sign == PLUS) {`

`        return -1;`

`    } else {`

`        //a and b have same sign`

`        if (a.length > b.length) {`

`            return 1*a.sign;`

`        } else if (a.length < b.length) {`

`            return -1*a.sign;`

`        } else {`

`            //a and b have same sign, same length`

`            int i;`

`            for (i = a.length-1; i >= 0; --i) {`

`                if (a.dg[i] > b.dg[i]) {`

`                    return 1*a.sign;`

`                } else if (a.dg[i] < b.dg[i]) {`

`                    return -1*a.sign;`

`                } `

`            }`

`            return 0; //equal`

`        }`

`    }`

`}`

` `

`void digit_shift(BN *a, int d) {`

`    int i;`

`    if ((a->length == 1) && (a->dg[0] == 0)) {`

`        return;`

`    }`

`    for (i = a->length - 1; i >= 0; --i) {`

`        a->dg[i+d] = a->dg[i];`

`    }`

`    for (i = 0; i < d; ++i) {`

`        a->dg[i] = 0;`

`    }`

`    a->length = a->length + d;`

`}`

` `

`void add(BN a, BN b, BN *c) {`

`    int i;`

`    int carry = 0;`

`    init(c);`

`    if (a.sign == b.sign) {`

`        c->sign = a.sign;`

`    } else {`

`        if (a.sign == MINUS) {`

`            a.sign = PLUS;`

`            sub(b, a, c);`

`        } else {`

`            b.sign = PLUS;`

`            sub(a, b, c);`

`        }`

`        return;`

`    }`

`    c->length = (a.length > b.length)?a.length+1:b.length+1;`

`    for (i = 0; i < c->length; ++i) {`

`        c->dg[i] = (char)(a.dg[i] + b.dg[i] + carry)%10;`

`        carry = (a.dg[i] + b.dg[i] + carry)/10;`

`    }`

`    zero_justify(c); `

`}`

` `

`void sub(BN a, BN b, BN *c) {`

`    int borrow;`

`    int v;`

`    int i;`

`    init(c);`

`    if ((a.sign == MINUS) || (b.sign == MINUS)) {`

`        b.sign = -1*b.sign;`

`        add(a, b, c);`

`        return;`

`    }`

`    if (cmp(a, b) < 0) {`

`        sub(b, a, c);`

`        c->sign = MINUS;`

`        return;`

`    }`

`    //a > b, and both a and b are +`

`    c->length = a.length > b.length?a.length:b.length;`

`    borrow = 0;`

`    for (i = 0; i < c->length; ++i) {`

`        v = a.dg[i] - borrow - b.dg[i];   //compute in an int: plain char may be unsigned`

`        if (v < 0) {`

`            v += 10;`

`            borrow = 1;`

`        } else {`

`            borrow = 0;`

`        }`

`        c->dg[i] = (char)v;`

`    }`

`    zero_justify(c);`

`}`

` `

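`void mul(BN a, BN b, BN *c) {`

`    int i, j;`

`    int sign = a.sign * b.sign;   // sign of the product`

`    init(c);`

`    a.sign = PLUS;                // accumulate |a| only; the sign is applied at the end`

`    for (i = 0; i < b.length; ++i) {`

`        for (j = 0; j < b.dg[i]; ++j) {   // add |a| * 10^i, b.dg[i] times`

`            add(*c, a, c);`

`        }`

`        digit_shift(&a, 1);       // multiply |a| by 10 for the next digit of b`

`    }`

`    zero_justify(c);`

`    if (!((c->length == 1) && (c->dg[0] == 0))) {   // leave a zero result unsigned`

`        c->sign = sign;`

`    }`

`}`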

` `

`void div(BN a, BN b, BN *c) {`

`    int i;`

`    BN tmp;                       // running remainder`

`    init(c);`

`    init(&tmp);`

`    if ((b.length == 1) && (b.dg[0] == 0)) {   // division by zero: report it and leave c as 0`

`        printf("division by zero\n");`

`        return;`

`    }`

`    c->sign = a.sign * b.sign;`

`    a.sign = PLUS;`

`    b.sign = PLUS;`

`    c->length = a.length;`

`    for (i = a.length - 1; i >= 0; --i) {   // long division, most significant digit first`

`        digit_shift(&tmp, 1);     // bring down the next digit of a`

`        tmp.dg[0] = a.dg[i];`

`        c->dg[i] = 0;`

`        while (cmp(tmp, b) >= 0) {   // count how many times b fits into the remainder`

`            (c->dg[i])++;`

`            sub(tmp, b, &tmp);`

`        }`

`    }`

`    zero_justify(c);`

`}`

` `

`void readbn(char *s, BN *a) {`

`    int i, j;`

`    init(a);`

`//    printf("readbn: %dn", strlen(s));`

`    for (i = strlen(s)-1, j = 0; i >= 0; --i) {`

`        a->dg[j++] = s[i] - '0';`

`    }`

`    a->length = j;`

`}`

` `

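`int main(int argc, char** argv) {`

`    BN m, n, rv;`

`    if (argc < 3) {`

`        printf("usage: %s m n\n", argv[0]);`

`        return 1;`

`    }`

`    // readbn() calls init(), which may reset the sign, so set the sign after reading`

`    if (argv[1][0] == '-') {`

`        readbn(&argv[1][1], &m);`

`        m.sign = MINUS;`

`    } else {`

`        readbn(argv[1], &m);`

`        m.sign = PLUS;`

`    }`

`    if (argv[2][0] == '-') {`

`        readbn(&argv[2][1], &n);`

`        n.sign = MINUS;`

`    } else {`

`        readbn(argv[2], &n);`

`        n.sign = PLUS;`

`    }`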

`    printf("m:");`

`    printbn(m);`

`    printf("n:");`

`    printbn(n);`

`    add(m, n, &rv);`

`    printf("m + n = ");`

`    printbn(rv);`

`    sub(m, n, &rv);`

`    printf("m - n = ");`

`    printbn(rv);`

`    digit_shift(&rv, 1);`

`    printf("m << 1 = ");`

`    printbn(rv);`

`    mul(m, n, &rv);`

`    printf("m * n = ");`

`    printbn(rv);`

`    div(m, n, &rv);`

`    printf("m / n = ");`

`    printbn(rv);`

`}`

Save the code to bignumarray2.c, and compile the code using the command below,

gcc -o bignumarray2 bignumarray2.c

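And here are some simple tests. Note that the line labeled "m << 1" calls digit_shift() on rv, which at that point holds m - n, so it actually prints (m - n) shifted left by one decimal digit, i.e. (m - n) * 10: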

`./bignumarray2 34123432143214321 1342`

`m:34123432143214321`

`n:1342`

`m + n = 34123432143215663`

`m - n = 34123432143212979`

`m << 1 = 341234321432129790`

`m * n = 45793645936193618782`

`m / n = 25427296678997`

` `

`./bignumarray2 332 13424312432`

`m:332`

`n:13424312432`

`m + n = 13424312764`

`m - n = -13424312100`

`m << 1 = -134243121000`

`m * n = 4456871727424`

`m / n = 0`

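1. A dynamic array is needed to support big integers of arbitrary length; with a fixed-size digit array, the length of the operands and results is bounded.

2. Using a whole byte to store a single decimal digit is wasteful. Packing several digits into each array element (for example, storing base-10000 "digits" in an int) saves memory and reduces the number of carry and borrow steps.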
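As an illustration of note 2, the sketch below stores four decimal digits per array element (base 10000) instead of one, so an addition performs one carry step per four digits. BigU, bigu_add and bigu_print are hypothetical names, the sketch handles non-negative numbers only, and MAXLIMBS is an arbitrary fixed capacity chosen for the example.

```c
#include <stdio.h>

#define BASE 10000      /* each limb holds 4 decimal digits */
#define MAXLIMBS 64     /* hypothetical fixed capacity for this sketch */

typedef struct {
    int limb[MAXLIMBS]; /* least significant limb first */
    int length;         /* number of limbs in use */
} BigU;                 /* hypothetical type: non-negative integers only */

/* c = a + b; assumes the result fits in MAXLIMBS limbs */
void bigu_add(const BigU *a, const BigU *b, BigU *c) {
    int i, carry = 0;
    int n = a->length > b->length ? a->length : b->length;
    for (i = 0; i < n; ++i) {
        int s = carry;
        if (i < a->length) s += a->limb[i];
        if (i < b->length) s += b->limb[i];
        c->limb[i] = s % BASE;   /* keep the low 4 digits in this limb */
        carry = s / BASE;        /* carry is 0 or 1 */
    }
    if (carry)
        c->limb[n++] = carry;
    c->length = n;
}

/* print the top limb as-is and zero-pad every lower limb to 4 digits */
void bigu_print(const BigU *a) {
    int i;
    printf("%d", a->limb[a->length - 1]);
    for (i = a->length - 2; i >= 0; --i)
        printf("%04d", a->limb[i]);
    printf("\n");
}
```

For example, adding 99999999 (limbs {9999, 9999}) and 1 (limb {1}) produces the limbs {0, 0, 1}, which bigu_print prints as 100000000. The zero padding in bigu_print is what keeps inner limbs such as 0 from losing their digits.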

References:
1. Programming Challenges

## Primality Testing — Part 4. Fermat’s Little Theorem and A Probabilistic Approach

The previous posts on primality testing all try to find a factor of the input number by trial division. For a very large number, say a few hundred bits long, computing in that manner is impractically slow.
This post introduces a probabilistic approach based on Fermat's primality test.

Fermat’s Little Theorem

Fermat’s primality test is based on Fermat’s Little Theorem,

If p is prime, then for every 1 <= a < p,

[ a ^ (p - 1) ] mod p = 1

Fermat's Little Theorem suggests a way to test whether a number is prime. For a big input number p, it is not practical to test every a. Instead, we pick a few values of a at random: if any of them violates the equation above, p is definitely composite; if all of them satisfy it, we declare p prime.
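As a small worked example of the test (values chosen for illustration): for the prime 7 with a = 3 the equation holds, while for the composite 15 the base a = 2 already violates it, which proves 15 is not prime.

```latex
\begin{aligned}
3^{6}  &= 729   = 104 \cdot 7 + 1   &&\Rightarrow\ 3^{6} \bmod 7 = 1\\
2^{14} &= 16384 = 1092 \cdot 15 + 4 &&\Rightarrow\ 2^{14} \bmod 15 = 4 \neq 1
\end{aligned}
```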

However, Fermat's Little Theorem gives a necessary condition for a number to be prime, not a sufficient one. In other words, every prime number satisfies Fermat's Little Theorem, but for a composite number n it is also possible to find an a where [ a ^ (n - 1) ] mod n = 1.
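For instance, the composite number 341 = 11 · 31 satisfies the equation for a = 2, because 2^10 = 1024 = 3 · 341 + 1, but not for a = 3. The value for a = 3 follows from the repeated squares 3^8 ≡ 82, 3^16 ≡ 245, 3^64 ≡ 81 and 3^256 ≡ 245 (mod 341):

```latex
\begin{aligned}
2^{340} \bmod 341 &= \left(2^{10}\right)^{34} \bmod 341 = 1^{34} \bmod 341 = 1\\
3^{340} \bmod 341 &= \left(3^{256} \cdot 3^{64} \cdot 3^{16} \cdot 3^{4}\right) \bmod 341
                   = \left(245 \cdot 81 \cdot 245 \cdot 81\right) \bmod 341
                   = 67^{2} \bmod 341 = 56 \neq 1
\end{aligned}
```

So a = 2 wrongly suggests that 341 is prime, while a = 3 proves it composite.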

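It can be proven that if a composite n fails the test for at least one a coprime to n, then it fails for at least half of all values of a in [1, n). So if k randomly chosen values of a all pass, the probability of wrongly declaring such an n prime is at most 1/2^k. The exception is Carmichael numbers (the smallest is 561), which pass the test for every a coprime to them. The algorithm therefore cannot guarantee that a positive result is 100% correct, but it is fast, and the probability of error is low enough for many practical use cases.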

Implementation

Below is a simple C implementation based on the idea above,

`#include <stdio.h>`

`#include <stdlib.h>`

`#include <time.h>        // time()`

`#include <sys/time.h>    // gettimeofday(), struct timeval`

` `

`unsigned long modexp(int x, int y, int n) {`

`    unsigned long i;`

`    if (y == 0) {`

`        return 1%n;`

`    }`

`    if (y == 1) {`

`        return x%n;`

`    } `

`    i = modexp(x, y/2, n);`

`    if (y&0x01) {    //odd: x^y = x * [x^(y/2)]^2`

`        return x%n*i%n*i%n;   // reduce after every multiplication to limit overflow`

`    } else {         //even: x^y = [x^(y/2)]^2`

`        return i*i%n;`

`    }`

`}`

` `

`int main(int argc, char **argv) {`

`    int n;`

`    int a;`

`    int i;`

`    struct timeval stTime, edTime;`

`    if (argc < 2) {`

`        printf("usage: %s n\n", argv[0]);`

`        return 1;`

`    }`

`    n = atoi(argv[1]);`

`    srand(time(NULL));`

`    gettimeofday(&stTime, NULL);`

`    if (n < 2) {`

`        printf("%d is not prime\n", n);`

`    } else {`

`        for (i = 0; i < 20 && i < n; ++i) {   // up to 20 random trials`

`            a = 1 + rand() % (n - 1);           // random a with 1 <= a < n, as the theorem requires`

`            if (modexp(a, n-1, n) != 1) {       // a is a witness: n is definitely composite`

`                printf("%d is not prime, %d\n", n, a);`

`                break;`

`            }`

`        }`

`        if (i == 20 || i == n) {               // no witness found: n is probably prime`

`            printf("%d is prime\n", n);`

`        }`

`    }`

`    gettimeofday(&edTime, NULL);`

`    printf("time: %u:%u\n", (unsigned int)(edTime.tv_sec - stTime.tv_sec), (unsigned int)(edTime.tv_usec - stTime.tv_usec));`

`    return 0;`

`}`

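Note that the code uses the recursive modular exponentiation algorithm (repeated squaring) and reduces modulo n after every multiplication. Even so, the intermediate product i*i must fit in an unsigned long, so on systems where unsigned long is 32 bits wide, the largest input the program can accept is 65536 = 2^16.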

Save the code to prim5.c and compile it using the command below,

gcc -o prim5 prim5.c

Below are some simple tests,

`roman10@ra-ubuntu-1:~/Desktop/integer/prim\$ ./prim5 65521`

`65521 is prime`

`time: 0:69`

` `

`roman10@ra-ubuntu-1:~/Desktop/integer/prim\$ ./prim5 2`

`2 is prime`

`time: 0:53`

` `

`roman10@ra-ubuntu-1:~/Desktop/integer/prim\$ ./prim5 65536`

`65536 is not prime, 8`

`time: 0:53`

` `

`roman10@ra-ubuntu-1:~/Desktop/integer/prim\$ ./prim5 63703`

`63703 is prime`

`time: 0:69`

` `

`roman10@ra-ubuntu-1:~/Desktop/integer/prim\$ ./prim5 61493`

`61493 is prime`

`time: 0:67`

` `

`roman10@ra-ubuntu-1:~/Desktop/integer/prim\$ ./prim5 13`

`13 is prime`

`time: 0:54`
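The 65536 limit above comes from intermediate products overflowing unsigned long. Below is a minimal sketch, not the original program, that keeps the same repeated-squaring idea but uses 64-bit uint64_t arithmetic and checks a caller-supplied list of bases instead of random ones, so results are reproducible. It assumes n < 2^32, so that every product of two residues stays below 2^64; powmod_u64 and fermat_probable_prime are hypothetical names.

```c
#include <stdint.h>

/* (x^y) mod n by repeated squaring, reducing after every multiplication.
   Assumes 1 <= n < 2^32 so every product of two residues fits in 64 bits. */
uint64_t powmod_u64(uint64_t x, uint64_t y, uint64_t n) {
    uint64_t result = 1 % n;       /* x^0 mod n */
    x %= n;                        /* x^(2^0) mod n */
    while (y > 0) {
        if (y & 1)                 /* lowest remaining bit of y is set */
            result = result * x % n;
        x = x * x % n;             /* next power x^(2^(i+1)) mod n */
        y >>= 1;
    }
    return result;
}

/* Fermat test against a fixed list of bases. Returns 0 if some base proves
   n composite, 1 if n passes every base ("probably prime"; Carmichael
   numbers can still pass every base coprime to them). */
int fermat_probable_prime(uint64_t n, const uint64_t *bases, int nbases) {
    int i;
    if (n < 4)
        return n == 2 || n == 3;
    for (i = 0; i < nbases; ++i) {
        uint64_t a = bases[i] % n;
        if (a < 2)                 /* a = 0 would look like a false witness, a = 1 always passes */
            continue;
        if (powmod_u64(a, n - 1, n) != 1)
            return 0;              /* a is a Fermat witness: n is composite */
    }
    return 1;
}
```

With the single base 2, the composite 341 = 11 * 31 passes, while base 3 exposes it (3^340 mod 341 = 56). The Carmichael number 561 = 3 * 11 * 17 passes base 2 and is only caught by a base that shares a factor with it, such as 3.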

References:
1. Wikipedia Primality test: http://en.wikipedia.org/wiki/Primality_test
2. Algorithms

## How to Calculate Modular Exponentiation

Modular exponentiation is exponentiation performed modulo an integer. It is used in many cryptographic algorithms.

Modular exponentiation takes three integers x, y and n as input and computes the value given by the formula below,

(x ^ y) mod n

Naive Approaches

A naive approach to compute modular exponentiation is to compute the exponentiation first and then the modulus.

This approach is fine for small x, y and n. But in cryptographic algorithms, x, y and n are normally several hundred bits long. Then x ^ y has roughly y * log2(x) bits, far too many to store, and computing it takes an enormous number of multiplications.

A slightly improved version multiplies the running result by x mod n repeatedly, reducing modulo n after every multiplication. That is,

rv = x mod n => rv = (rv * (x mod n)) mod n => rv = (rv * (x mod n)) mod n => …

This keeps every intermediate value below n^2 and so reduces memory usage. But it still performs poorly when y is very big, because it needs y - 1 multiplications. Suppose y = 2^100; then about 2^100 multiplications are needed, which is infeasible in practice.

A Better Approach

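A better approach is to repeatedly square x mod n, reducing modulo n after each squaring. That is,

x mod n => (x ^ 2) mod n => (x ^ 4) mod n => (x ^ 8) mod n => … =>

(x ^ (2 ^ k)) mod n, where k = floor(log2 y)

Each term is the square of the previous one modulo n, so reaching the last term takes only k ≈ log2 y squarings; combining the terms that correspond to the 1 bits of y costs at most one extra multiplication per bit. For y = 2^100, only log2(2^100) = 100 squarings are needed.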

In general, this approach can be expressed by the equations below, where y / 2 denotes integer division (rounded down),

(x ^ y) mod n = { [ (x ^ (y / 2)) mod n ] ^ 2 } mod n                  if y is even

(x ^ y) mod n = { x * [ (x ^ (y / 2)) mod n ] ^ 2 } mod n          if y is odd
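For example, applying the two cases to x = 4, y = 13, n = 497 (the inputs of the first test below), the exponent halves as 13 → 6 → 3 → 1, and each level reuses the result of the level below it:

```latex
\begin{aligned}
4^{1}  \bmod 497 &= 4\\
4^{3}  \bmod 497 &= \left(4 \cdot 4^{2}\right) \bmod 497 = 64 &&\text{(odd)}\\
4^{6}  \bmod 497 &= 64^{2} \bmod 497 = 4096 \bmod 497 = 120 &&\text{(even)}\\
4^{13} \bmod 497 &= \left(4 \cdot 120^{2}\right) \bmod 497 = \left(4 \cdot 484\right) \bmod 497 = 445 &&\text{(odd)}
\end{aligned}
```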

This approach can be easily implemented recursively. Below is a C implementation of this algorithm,

`/**`

`polynomial-time computation for (x^y)%N`

`the idea is for mod N`

`(x^y) = [x^(y/2)]^2 if y is even`

`        = x*[x^(y/2)]^2 if y is odd`

`**/`

`#include <stdio.h>`

`#include <stdlib.h>`

` `

`unsigned long modexp(int x, int y, int n) {`

`    unsigned long i;`

`    if (y == 0) {`

`        return 1%n;`

`    }`

`    if (y == 1) {`

`        return x%n;`

`    } `

`    i = modexp(x, y/2, n);`

`    printf("%d:%lu\n", y, i);   // trace: exponent y and (x^(y/2)) mod n`

`    if (y&0x01) {    //odd`

`        //return (x*i*i)%n;`

`        return (x%n*i*i)%n;`

`    } else {         //even`

`        return (i*i)%n;`

`    }`

`}`

` `

`int main(int argc, char **argv) {`

`    int x, y, N;`

`    if (argc < 4) {`

`        printf("usage: %s x y N\n", argv[0]);`

`        return 1;`

`    }`

`    x = atoi(argv[1]);`

`    y = atoi(argv[2]);`

`    N = atoi(argv[3]);`

`    printf("(x^y)modN = %lu\n", modexp(x, y, N));`

`    return 0;`

`}`

One thing to note in the code above is that when y is odd, we compute using (x%n*i*i)%n instead of (x*i*i)%n. This is to prevent overflow caused by multiplication of x*i*i when x is a very big number.

Note, however, that when n is large, x%n and i can both be close to n, so x%n*i*i (up to about n^3) can still overflow. Reducing after each multiplication, as in x%n*i%n*i%n, keeps intermediate values below n^2, but multiplying truly large numbers requires a different representation of integers, which is not discussed in this post. For the integer sizes commonly used outside cryptography, the program above is good enough.

For a given exponent y, the recursion depth is about log2(y), so the algorithm performs O(log y) multiplications. If y is N bits long, that is O(N) multiplications.

Save the code as modexp.c and compile it using the command,

gcc -o modexp modexp.c

Below are some sample tests,

`roman10@ra-ubuntu-1:~/Desktop/integer\$ ./modexp 4 13 497`

`3:4`

`6:64`

`13:120`

`(x^y)modN = 445`

` `

`roman10@ra-ubuntu-1:~/Desktop/integer\$ ./modexp 1234 5678 90`

`2:64`

`5:46`

`11:64`

`22:64`

`44:46`

`88:46`

`177:46`

`354:64`

`709:46`

`1419:64`

`2839:64`

`5678:64`

`(x^y)modN = 46`

` `

`roman10@ra-ubuntu-1:~/Desktop/integer\$ ./modexp 3 14 30`

`3:3`

`7:27`

`14:27`

`(x^y)modN = 9`

` `

`roman10@ra-ubuntu-1:~/Desktop/integer\$ ./modexp 2 16 123`

`2:2`

`4:4`

`8:16`

`16:10`

`(x^y)modN = 100`

An Alternative Approach

An alternative approach gets rid of the recursion in the program above. It is based on the binary representation of the exponent. For example, in the first test case above, y = 13 = 2^3 + 2^2 + 2^0. Therefore, to compute (4^13) mod 497, we can compute a = (4^(2^3)) mod 497, b = (4^(2^2)) mod 497 and c = (4^(2^0)) mod 497, and then multiply the results together modulo 497.

In general, given x and n, we can pre-compute [ x ^ (2 ^ i) ] mod n for i = 0, 1, …, N - 1, where N is the number of bits needed to represent the biggest value of y. If y is constrained to be less than 2^100, then N = 100.

In addition, [ x ^ (2 ^ i) ] mod n = ( { [ x ^ (2 ^ (i - 1)) ] mod n } ^ 2 ) mod n, so each entry is obtained from the previous one with a single squaring, and the whole table can be computed easily from i = 0 up to i = 99:

r0 = [x ^ (2 ^ 0)] mod n = x mod n
r1 = [x ^ (2 ^ 1)] mod n = (r0 ^ 2) mod n
r2 = [x ^ (2 ^ 2)] mod n = (r1 ^ 2) mod n

For a given y, we then scan the bits of y, and for every bit that is set to 1, we read the corresponding pre-computed value (r0 for bit 0, r1 for bit 1, and so on) and multiply it into the result modulo n.

Below is a sample C implementation for this approach,

`/**`

`polynomial-time computation for (x^y)%N`

`non-recursive version`

`**/`

`#include <stdio.h>`

`#include <stdlib.h>`

` `

`#define MAXP 1000`

`unsigned long m[MAXP];`

` `

`void precompute(int x, int n) {`

`   int i;`

`   m[0] = x%n;                       // r0 = x mod n`

`   for (i = 1; i < MAXP; ++i) {`

`       m[i] = (m[i-1]*m[i-1])%n;     // square the previous entry`

`   }`

`}`

` `

`unsigned long modexp(int x, int y, int n) {`

`    int i;`

`    unsigned long rv = 1%n;      // x^0 mod n, also the answer when y == 0`

`    // x is unused here: precompute() has already stored (x^(2^i)) mod n in m[i]`

`    for (i = 0; (y >> i) != 0; ++i) {   // visit the bits of y, lowest first`

`        if ((y >> i) & 1) {      // bit i is set: multiply in (x^(2^i)) mod n`

`            rv = (m[i]*rv)%n;`

`            printf("%d:%lu\n", i, rv);`

`        }`

`    }`

`    return rv;`

`}`

` `

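`int main(int argc, char **argv) {`

`    int x, y, N;`

`    if (argc < 4) {`

`        printf("usage: %s x y N\n", argv[0]);`

`        return 1;`

`    }`

`    x = atoi(argv[1]);`

`    y = atoi(argv[2]);`

`    N = atoi(argv[3]);`

`    precompute(x, N);     // fill m[i] = (x^(2^i)) mod N`

`    printf("(x^y)modN = %lu\n", modexp(x, y, N));`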

`    return 0;`

`}`

Save the code to modexp2.c and compile it using the command below,

gcc -o modexp2 modexp2.c

Below are some simple tests,

`roman10@ra-ubuntu-1:~/Desktop/integer\$ ./modexp2 4 13 497`

`0:4`

`2:30`

`3:445`

`(x^y)modN = 445`

` `

`roman10@ra-ubuntu-1:~/Desktop/integer\$ ./modexp2 1234 5678 90`

`1:46`

`2:46`

`3:46`

`5:46`

`9:46`

`10:46`

`12:46`

`(x^y)modN = 46`

` `

`roman10@ra-ubuntu-1:~/Desktop/integer\$ ./modexp2 3 14 30`

`1:9`

`2:9`

`3:9`

`(x^y)modN = 9`

` `

`roman10@ra-ubuntu-1:~/Desktop/integer\$ ./modexp2 2 16 123`

`4:100`

`(x^y)modN = 100`
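The pre-computed table m[] is not strictly necessary: each entry is the square of the previous one, so the loop can keep only the current power and square it while walking over the bits of y. Below is a minimal sketch of this right-to-left binary method; modexp_iter is a hypothetical name, and like the programs above it assumes products of two values below n fit in an unsigned long (true for n < 2^32 when unsigned long is 64 bits).

```c
/* Right-to-left binary exponentiation: (x^y) mod n without a table.
   Assumes n >= 1 and that products of two values below n fit in an
   unsigned long. */
unsigned long modexp_iter(unsigned long x, unsigned long y, unsigned long n) {
    unsigned long result = 1 % n;   /* x^0 mod n; also the answer for y == 0 */
    unsigned long base = x % n;     /* x^(2^i) mod n for the current bit i */

    while (y > 0) {
        if (y & 1)                      /* bit i of y is set */
            result = result * base % n; /* multiply in x^(2^i) mod n */
        base = base * base % n;         /* x^(2^(i+1)) mod n, i.e. m[i+1] */
        y >>= 1;                        /* move on to the next bit */
    }
    return result;
}
```

On the inputs of the tests above it returns the same results (445, 46, 9 and 100), and it needs no upper bound such as MAXP on the number of bits of y.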

References:
1. S. Dasgupta, C. H. Papadimitriou, and U. V. Vazirani. Algorithms.
2. Wikipedia Modular Exponentiation: http://en.wikipedia.org/wiki/Modular_exponentiation
3. Numerical Recipes: http://numericalrecipes.blogspot.com/2009/03/modular-exponentiation.html

## Primality Testing — Part 3. A Further Optimized Method

The previous two posts cover the naive approaches to primality testing and a method with pre-computation. This post introduces an optimization that skips most of the trial divisors, so it can perform better than the previous approaches, and it can be combined with the pre-computation.

The optimization is based on the observation that every prime number other than 2 and 3 has the form 6k+1 or 6k+5. This is because every integer can be expressed as 6k+i for some integer k and some i in the set {0, 1, 2, 3, 4, 5}; 2 divides 6k+0, 6k+2 and 6k+4, and 3 divides 6k+3, so only 6k+1 and 6k+5 can be prime (apart from 2 and 3 themselves).

Therefore, to test whether a given integer n is prime, we first check whether n is divisible by 2 or 3, and then check whether n is divisible by any number greater than 1 of the form 6k+1 or 6k+5, up to sqrt(n).

The idea of this optimization is similar to the Sieve of Eratosthenes algorithm described in a previous post: we skip trial divisors that are multiples of the prime factors 2 and 3, because if n is not divisible by 2 or 3, it cannot be divisible by any of their multiples either.

In general, for a given g, every prime number that does not divide g has the form gk + i, where 0 <= i < g and i is coprime to g. In the case above, g = 6 and the numbers below 6 that are coprime to 6 are 1 and 5. If i and g are not coprime, they share a prime factor, which then also divides gk + i, so gk + i cannot be prime unless it equals that factor.

Let's go through another example. Suppose g = 2*3*5 = 30; then every integer can be expressed in the form 30*k + i, where i is in the set {0, 1, 2, …, 29}. Now we mark the values of i that are divisible by 2, 3 and 5 respectively: 30*k + {0, 2, 4, 6, 8, 10, 12, …, 28} are divisible by 2; 30*k + {3, 6, 9, 12, 15, 18, 21, 24, 27} are divisible by 3; and 30*k + {5, 10, 15, 20, 25} are divisible by 5. The remaining values {1, 7, 11, 13, 17, 19, 23, 29} are not divisible by 2, 3 or 5, and are therefore coprime to 30.

So to test whether a number n is prime, we first test whether it is divisible by 2, 3 or 5, and then check whether it is divisible by any number greater than 1 of the form 30*k + {1, 7, 11, 13, 17, 19, 23, 29}, up to sqrt(n).
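The residues coprime to g can also be computed instead of listed by hand. The sketch below keeps every i in [0, g) with gcd(i, g) = 1, using Euclid's algorithm; gcd_int and coprime_residues are hypothetical helper names.

```c
/* Greatest common divisor by Euclid's algorithm. */
int gcd_int(int a, int b) {
    while (b != 0) {
        int t = a % b;
        a = b;
        b = t;
    }
    return a;
}

/* Store in out[] every i in [0, g) with gcd(i, g) == 1 and return how many
   there are. A prime that does not divide g must be congruent to one of
   these residues modulo g. out[] must have room for g entries. */
int coprime_residues(int g, int *out) {
    int i, count = 0;
    for (i = 0; i < g; ++i) {
        if (gcd_int(i, g) == 1)
            out[count++] = i;
    }
    return count;
}
```

For g = 30 it returns the eight residues 1, 7, 11, 13, 17, 19, 23, 29, and for g = 6 the two residues 1 and 5.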
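Putting this together for g = 30, trial division can be sketched as follows. This is an illustration under the description above, not the program below (which uses g = 6); is_prime_wheel30 is a hypothetical name. It compares d > n / d instead of d * d > n so that the bound check cannot overflow.

```c
/* Trial division with the mod-30 wheel: handle 2, 3 and 5 directly, then
   try only divisors d = 30k + r with r coprime to 30, up to sqrt(n). */
int is_prime_wheel30(long n) {
    static const int r[8] = {1, 7, 11, 13, 17, 19, 23, 29};
    long k, d;
    int i;
    if (n < 2)
        return 0;
    if (n == 2 || n == 3 || n == 5)
        return 1;
    if (n % 2 == 0 || n % 3 == 0 || n % 5 == 0)
        return 0;
    for (k = 0; ; k += 30) {
        for (i = 0; i < 8; ++i) {
            d = k + r[i];
            if (d == 1)
                continue;          /* 1 divides everything; skip it */
            if (d > n / d)
                return 1;          /* d * d > n: no divisor up to sqrt(n) */
            if (n % d == 0)
                return 0;          /* found a factor */
        }
    }
}
```

Compared with g = 6, which tries 2 of every 6 numbers, the mod-30 wheel tries only 8 of every 30, so it needs roughly 20% fewer trial divisions.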

Below is an implementation of the optimization for primality testing when g is set to 6,

`/**`

`primality test: trial division by 2, 3 and numbers of the form 6k+1 or 6k+5`

`**/`

`#include <stdio.h>`

`#include <stdlib.h>`

`#include <math.h>`

`#include <sys/time.h>    // gettimeofday(), struct timeval`

` `

`int rdiv(int n, int m) {`

`    if (n%m ==0) {`

`        printf("%d is divisible by %d\n", n, m);`

`        return 0;`

`    }`

`    return 1;`

`}`

` `

`int prime(int n) {`

`    int i, k, ed, m;`

`    int cpr[2] = {1, 5};`

`    //first check divisibility by the prime factors of 6`

`    if (n < 2) {return 0;}            //0, 1 and negative numbers are not prime`

`    if (n == 2 || n == 3) {return 1;}`

`    if (rdiv(n, 2) == 0) {`

`        return 0;`

`    } else if (rdiv(n, 3) == 0) {`

`        return 0;`

`    }`

`    //check up to sqrt(n) of the form 6k + {coprime of 6} `

`    ed = sqrt(n) + 1;`

`    for (k = 0; ;++k) {`

`        for (i = 0; i < 2; ++i) {`

`            m = 6*k + cpr[i];`

`            if (m < 3) {continue;}`

`            if (m > ed) {`

`                return 1;`

`            } `

`            if (n%m == 0) {`

`                printf("%d is divisible by %d\n", n, m);`

`                return 0;`

`            }`

`        }`

`    }`

`    return 1;`

`}`

` `

`int main(int argc, char **argv) {`

`    int n;`

`    struct timeval stTime, edTime;`

`    if (argc < 2) {`

`        printf("usage: %s n\n", argv[0]);`

`        return 1;`

`    }`

`    n = atoi(argv[1]);`

`    gettimeofday(&stTime, NULL);`

`    if (prime(n)) {`

`        printf("%d is prime number\n", n);`

`    } else {`

`        printf("%d is not prime number\n", n);`

`    }`

`    gettimeofday(&edTime, NULL);`

`    printf("time: %u:%u\n", (unsigned int)(edTime.tv_sec - stTime.tv_sec), (unsigned int)(edTime.tv_usec - stTime.tv_usec));`

`    return 0;`

`}`

Save the code to prim4.c and compile it with the command below (-lm links the math library needed by sqrt()),

gcc -o prim4 prim4.c -lm

Below are some simple tests:

`roman10@ra-ubuntu-1:~/Desktop/integer/prim\$ ./prim4 2`

`2 is prime number`

`time: 0:49`

` `

`roman10@ra-ubuntu-1:~/Desktop/integer/prim\$ ./prim4 3`

`3 is prime number`

`time: 0:62`

` `

`roman10@ra-ubuntu-1:~/Desktop/integer/prim\$ ./prim4 5`

`5 is prime number`

`time: 0:52`

` `

`roman10@ra-ubuntu-1:~/Desktop/integer/prim\$ ./prim4 2147483647`

`2147483647 is prime number`

`time: 0:267`

` `

`roman10@ra-ubuntu-1:~/Desktop/integer/prim\$ ./prim4 2147483641`

`2147483641 is divisible by 2699`

`2147483641 is not prime number`

`time: 0:72`

Note that the code can be further improved by adding the pre-computation described in Part 2.