How to Write a Linux Firewall in Less than 1000 Lines of Code–Put Everything Together

This is the final post in the series on how to write a Linux firewall in less than 1000 lines of code. If you haven’t read the previous posts, you may want to read them first, since this post builds on them.

Part 1: Overview

Part 2: Command Line Arguments Parsing in glibc

Part 3.1: Linux Kernel Module Basics and Hello World

Part 3.2: Linux Kernel Programming – Linked List

Part 3.3 Linux Kernel Programming – Memory Allocation

Part 4.1: How to Filter Network Packets using Netfilter – Part 1: Netfilter Hooks

Part 4.2: How to Filter Network Packets using Netfilter – Part 2: Implement the Hook Function

Part 5: Linux procfs Virtual File System

The Combined Code

The kernel module code is given below:

#include <linux/module.h>

#include <linux/kernel.h>

#include <linux/proc_fs.h>

#include <linux/list.h>

#include <asm/uaccess.h>

#include <linux/udp.h>

#include <linux/tcp.h>

#include <linux/skbuff.h>

#include <linux/ip.h>

#include <linux/netfilter.h>

#include <linux/netfilter_ipv4.h>

 

/* Size in bytes of the procfs scratch buffer (procf_buffer). */
#define PROCF_MAX_SIZE 1024

/* Name of the control file, i.e. /proc/minifirewall. */
#define PROCF_NAME "minifirewall"

MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("Linux minifirewall");
MODULE_AUTHOR("Liu Feipeng/roman10");

/* The /proc/minifirewall entry: created in init_module(), removed in cleanup_module(). */
static struct proc_dir_entry *mf_proc_file;

/*
 * Scratch buffer of PROCF_MAX_SIZE bytes (vmalloc'd in init_module()) shared
 * by procf_read() and procf_write(), plus the number of bytes used in it.
 * Declared static: they are private to this module and should not be
 * global symbols.
 */
static unsigned long procf_buffer_pos;
static char *procf_buffer;

/* Hook registrations: nfho for NF_INET_LOCAL_IN, nfho_out for NF_INET_LOCAL_OUT. */
static struct nf_hook_ops nfho;
static struct nf_hook_ops nfho_out;

 

 

/*structure for firewall policies*/

struct mf_rule_desp {

    unsigned char in_out;

    char *src_ip;

    char *src_netmask;

    char *src_port;

    char *dest_ip;

    char *dest_netmask;

    char *dest_port;

    unsigned char proto;

    unsigned char action;

};

 

 

/*
 * Parsed firewall rule, kept in policy_list in insertion order.
 * Addresses and netmasks are in host byte order (see ip_str_to_hl()).
 * A value of 0 in src_ip, dest_ip, src_port or dest_port means "any":
 * the hooks skip that check, and procf_read() prints it as "-".
 */
struct mf_rule {
    unsigned char in_out;        //0: neither in nor out, 1: in, 2: out
    unsigned int src_ip;         //host byte order; 0 = any
    unsigned int src_netmask;    //0 = compare all 32 bits (see check_ip())
    unsigned int src_port;       //0 = any port
    unsigned int dest_ip;        //host byte order; 0 = any
    unsigned int dest_netmask;   //0 = compare all 32 bits (see check_ip())
    unsigned int dest_port;      //0 = any port
    unsigned char proto;         //0: all, 1: tcp, 2: udp
    unsigned char action;        //0: for block, 1: for unblock
    struct list_head list;       //link in policy_list
};

/* Sentinel head of the rule list; only its .list member is used. */
static struct mf_rule policy_list;

 

/*
 * Convert a decimal port string (e.g. "8080") to an integer.
 *
 * Parsing stops at the first character that is not a digit, including
 * the terminating NUL, so trailing garbage or a newline is ignored.
 * Returns 0 for NULL or for a string that does not start with a digit;
 * callers treat 0 as "any port".
 */
unsigned int port_str_to_int(char *port_str) {
    unsigned int port = 0;
    int i;

    if (port_str == NULL)
        return 0;

    for (i = 0; port_str[i] >= '0' && port_str[i] <= '9'; ++i)
        port = port * 10 + (unsigned int)(port_str[i] - '0');

    return port;
}

 

/*
 * Write @port as decimal text into @port_str.
 * @port_str must have room for up to 10 digits plus the terminating NUL.
 */
void port_int_to_str(unsigned int port, char *port_str)
{
    (void)sprintf(port_str, "%u", port);
}

 

/*
 * Convert a dotted-quad IPv4 string to a 32-bit value in host byte order,
 * e.g. "131.132.162.25" -> 0x8384A219.
 *
 * Parsing stops at the terminating NUL or at the first character that is
 * neither a digit nor '.', so malformed input never reads past the end of
 * the string. Missing octets are treated as 0, and each octet is reduced
 * modulo 256 (as the byte-array implementation did).
 * Returns 0 for a NULL string.
 */
unsigned int ip_str_to_hl(char *ip_str) {
    unsigned int octets[4] = {0, 0, 0, 0};
    int idx = 0;
    const char *p;

    if (ip_str == NULL)
        return 0;

    for (p = ip_str; *p != '\0'; ++p) {
        if (*p == '.') {
            if (++idx > 3)
                break;          /* more than four octets: ignore the rest */
        } else if (*p >= '0' && *p <= '9') {
            octets[idx] = (octets[idx] * 10 + (unsigned int)(*p - '0')) & 0xFFu;
        } else {
            break;
        }
    }

    /* Shift unsigned values; shifting a promoted (signed) byte by 24 overflows for octets >= 128. */
    return (octets[0] << 24) | (octets[1] << 16) | (octets[2] << 8) | octets[3];
}

 

/*
 * Format a host-byte-order IPv4 address as dotted-quad text, most
 * significant byte first (0x8384A219 -> "131.132.162.25").
 * @ip_str needs room for 16 bytes ("255.255.255.255" plus NUL).
 */
void ip_hl_to_str(unsigned int ip, char *ip_str) {
    unsigned int b0 = (ip >> 24) & 0xFFu;
    unsigned int b1 = (ip >> 16) & 0xFFu;
    unsigned int b2 = (ip >> 8) & 0xFFu;
    unsigned int b3 = ip & 0xFFu;

    sprintf(ip_str, "%u.%u.%u.%u", b0, b1, b2, b3);
}

 

/*
 * Check whether packet address @ip (network byte order, as read from the IP
 * header) matches rule address @ip_rule (host byte order).
 *
 * Only the leading run of set bits in @mask (its prefix length) is
 * compared; a mask of 0 means "compare all 32 bits". A non-zero mask whose
 * top bit is clear has prefix length 0 and therefore matches any address.
 */
bool check_ip(unsigned int ip, unsigned int ip_rule, unsigned int mask) {
    unsigned int tmp = ntohl(ip);    //network to host long
    unsigned int prefix_mask, diff;
    int cmp_len = 32;
    int i;

    printk(KERN_INFO "compare ip: %u <=> %u\n", tmp, ip_rule);

    if (mask != 0) {
        /* prefix length = number of leading 1 bits; 1U avoids the signed overflow of 1 << 31 */
        cmp_len = 0;
        for (i = 31; i >= 0 && (mask & (1U << i)); --i)
            cmp_len++;
    }
    if (cmp_len == 0)
        return true;

    /* cmp_len is 1..32 here, so the shift count 32 - cmp_len is 0..31 (well defined). */
    prefix_mask = ~0U << (32 - cmp_len);
    diff = (tmp ^ ip_rule) & prefix_mask;
    if (diff != 0) {
        /* report the first (most significant) mismatching bit, counted from 1 */
        for (i = 31; !(diff & (1U << i)); --i)
            ;
        printk(KERN_INFO "ip compare: %d bit doesn't match\n", 32 - i);
        return false;
    }
    return true;
}

 

void add_a_rule(struct mf_rule_desp* a_rule_desp) {

    struct mf_rule* a_rule;

    a_rule = kmalloc(sizeof(*a_rule), GFP_KERNEL);

    if (a_rule == NULL) {

        printk(KERN_INFO "error: cannot allocate memory for a_new_rulen");

        return;

    }

    a_rule->in_out = a_rule_desp->in_out;

    if (strcmp(a_rule_desp->src_ip, "-") != 0) 

        a_rule->src_ip = ip_str_to_hl(a_rule_desp->src_ip);

    else

        a_rule->src_ip = NULL;

    if (strcmp(a_rule_desp->src_netmask, "-") != 0)

        a_rule->src_netmask = ip_str_to_hl(a_rule_desp->src_netmask);

    else

        a_rule->src_netmask = NULL;

    if (strcmp(a_rule_desp->src_port, "-") != 0)

        a_rule->src_port = port_str_to_int(a_rule_desp->src_port);

    else 

        a_rule->src_port = NULL;

    if (strcmp(a_rule_desp->dest_ip, "-") != 0)

        a_rule->dest_ip = ip_str_to_hl(a_rule_desp->dest_ip);

    else 

        a_rule->dest_ip = NULL;

    if (strcmp(a_rule_desp->dest_netmask, "-") != 0)

        a_rule->dest_netmask = ip_str_to_hl(a_rule_desp->dest_netmask);

    else 

        a_rule->dest_netmask = NULL;

    if (strcmp(a_rule_desp->dest_port, "-") != 0)

        a_rule->dest_port = port_str_to_int(a_rule_desp->dest_port);

    else 

        a_rule->dest_port = NULL;

    a_rule->proto = a_rule_desp->proto;

    a_rule->action = a_rule_desp->action;

    printk(KERN_INFO "add_a_rule: in_out=%u, src_ip=%u, src_netmask=%u, src_port=%u, dest_ip=%u, dest_netmask=%u, dest_port=%u, proto=%u, action=%un", a_rule->in_out, a_rule->src_ip, a_rule->src_netmask, a_rule->src_port, a_rule->dest_ip, a_rule->dest_netmask, a_rule->dest_port, a_rule->proto, a_rule->action);

    INIT_LIST_HEAD(&(a_rule->list));

    list_add_tail(&(a_rule->list), &(policy_list.list));

}

 

/*
 * Reset a rule description. The numeric fields are set to 0, and each of
 * the six string fields gets a fresh 16-byte kmalloc() buffer.
 * Allocation failures are not reported; a failed field is left NULL.
 * The buffers are not freed here.
 */
void init_mf_rule_desp(struct mf_rule_desp* a_rule_desp) {
    char **fields[] = {
        &a_rule_desp->src_ip,
        &a_rule_desp->src_netmask,
        &a_rule_desp->src_port,
        &a_rule_desp->dest_ip,
        &a_rule_desp->dest_netmask,
        &a_rule_desp->dest_port,
    };
    size_t k;

    a_rule_desp->in_out = 0;
    for (k = 0; k < sizeof(fields) / sizeof(fields[0]); ++k)
        *fields[k] = (char *)kmalloc(16, GFP_KERNEL);
    a_rule_desp->proto = 0;
    a_rule_desp->action = 0;
}

 

void delete_a_rule(int num) {

    int i = 0;

    struct list_head *p, *q;

    struct mf_rule *a_rule;

    printk(KERN_INFO "delete a rule: %dn", num);

    list_for_each_safe(p, q, &policy_list.list) {

        ++i;

        if (i == num) {

            a_rule = list_entry(p, struct mf_rule, list);

            list_del(p);

            kfree(a_rule);

            return;

        }

    }

}

 

int procf_read(char *buffer, char **buffer_location, off_t offset, int buffer_length, int *eof, void *data)

 

{

 

    int ret;

 

    struct mf_rule *a_rule;

 

    char token[20];

 

    printk(KERN_INFO "procf_read (/proc/%s) called n", PROCF_NAME);

 

    if (offset > 0) {

 

        printk(KERN_INFO "eof is 1, nothing to readn");

 

        *eof = 1;

 

        return 0;

 

    } else {

 

        procf_buffer_pos = 0;

 

        ret = 0;

 

        list_for_each_entry(a_rule, &policy_list.list, list) {

 

            //in or out

 

            if (a_rule->in_out==1) {

 

                strcpy(token, "in");

 

            } else if (a_rule->in_out==2) {

 

                strcpy(token, "out");

 

            }

 

            printk(KERN_INFO "token: %sn", token);

 

            memcpy(procf_buffer + procf_buffer_pos, token, strlen(token));

 

            procf_buffer_pos += strlen(token);

 

            memcpy(procf_buffer + procf_buffer_pos, " ", 1);

 

            procf_buffer_pos++;

 

            //src ip

 

            if (a_rule->src_ip == NULL) {

 

                strcpy(token, "-");

 

            } else {

 

                ip_hl_to_str(a_rule->src_ip, token);

 

            } 

 

            printk(KERN_INFO "token: %sn", token);

 

            memcpy(procf_buffer + procf_buffer_pos, token, strlen(token));

 

            procf_buffer_pos += strlen(token);

 

            memcpy(procf_buffer + procf_buffer_pos, " ", 1);

 

            procf_buffer_pos++;

 

            //src netmask

 

            if (a_rule->src_netmask==NULL) {

 

                strcpy(token, "-");

 

            } else {

 

                ip_hl_to_str(a_rule->src_netmask, token);

 

            } 

 

            printk(KERN_INFO "token: %sn", token);

 

            memcpy(procf_buffer + procf_buffer_pos, token, strlen(token));

 

            procf_buffer_pos += strlen(token);

 

            memcpy(procf_buffer + procf_buffer_pos, " ", 1);

 

            procf_buffer_pos++;

 

            //src port

 

            if (a_rule->src_port==0) {

 

                strcpy(token, "-");

 

            } else {

 

                port_int_to_str(a_rule->src_port, token);

 

            } 

 

            printk(KERN_INFO "token: %sn", token);

 

            memcpy(procf_buffer + procf_buffer_pos, token, strlen(token));

 

            procf_buffer_pos += strlen(token);

 

            memcpy(procf_buffer + procf_buffer_pos, " ", 1);

 

            procf_buffer_pos++;

 

            //dest ip

 

            if (a_rule->dest_ip==NULL) {

 

                strcpy(token, "-");

 

            } else {

 

                ip_hl_to_str(a_rule->dest_ip, token);

 

            } 

 

            printk(KERN_INFO "token: %sn", token);

 

            memcpy(procf_buffer + procf_buffer_pos, token, strlen(token));

 

            procf_buffer_pos += strlen(token);

 

            memcpy(procf_buffer + procf_buffer_pos, " ", 1);

 

            procf_buffer_pos++;

 

            //dest netmask

 

            if (a_rule->dest_netmask==NULL) {

 

                strcpy(token, "-");

 

            } else {

 

                ip_hl_to_str(a_rule->dest_netmask, token);

 

            } 

 

            printk(KERN_INFO "token: %sn", token);

 

            memcpy(procf_buffer + procf_buffer_pos, token, strlen(token));

 

            procf_buffer_pos += strlen(token);

 

            memcpy(procf_buffer + procf_buffer_pos, " ", 1);

 

            procf_buffer_pos++;

 

            //dest port

 

            if (a_rule->dest_port==0) {

 

                strcpy(token, "-");

 

            } else {

 

                port_int_to_str(a_rule->dest_port, token);

 

            } 

 

            printk(KERN_INFO "token: %sn", token);

 

            memcpy(procf_buffer + procf_buffer_pos, token, strlen(token));

 

            procf_buffer_pos += strlen(token);

 

            memcpy(procf_buffer + procf_buffer_pos, " ", 1);

 

            procf_buffer_pos++;

 

            //protocol

 

            if (a_rule->proto==0) {

 

                strcpy(token, "ALL");

 

            } else if (a_rule->proto==1) {

 

                strcpy(token, "TCP");

 

            }  else if (a_rule->proto==2) {

 

                strcpy(token, "UDP");

 

            }

 

            printk(KERN_INFO "token: %sn", token);

 

            memcpy(procf_buffer + procf_buffer_pos, token, strlen(token));

 

            procf_buffer_pos += strlen(token);

 

            memcpy(procf_buffer + procf_buffer_pos, " ", 1);

 

            procf_buffer_pos++;

 

            //action

 

            if (a_rule->action==0) {

 

                strcpy(token, "BLOCK");

 

            } else if (a_rule->action==1) {

 

                strcpy(token, "UNBLOCK");

 

            }

 

            printk(KERN_INFO "token: %sn", token);

 

            memcpy(procf_buffer + procf_buffer_pos, token, strlen(token));

 

            procf_buffer_pos += strlen(token);

 

            memcpy(procf_buffer + procf_buffer_pos, "n", 1);

 

            procf_buffer_pos++;

 

        }

 

        //copy from procf_buffer to buffer

 

        printk(KERN_INFO "procf_buffer_pos: %ldn", procf_buffer_pos);

 

        memcpy(buffer, procf_buffer, procf_buffer_pos);

 

        ret = procf_buffer_pos;

 

    }

 

    return ret;

 

}

 

 

 

int procf_write(struct file *file, const char *buffer, unsigned long count, void *data)

 

{

 

   int i, j;

   struct mf_rule_desp *rule_desp;

 

   printk(KERN_INFO "procf_write is called.n");

 

   /*read the write content into the storage buffer*/

 

   procf_buffer_pos = 0;

 

   printk(KERN_INFO "pos: %ld; count: %ldn", procf_buffer_pos, count);

 

   if (procf_buffer_pos + count > PROCF_MAX_SIZE) {

 

       count = PROCF_MAX_SIZE-procf_buffer_pos;

 

   } 

 

   if (copy_from_user(procf_buffer+procf_buffer_pos, buffer, count)) {

 

       return -EFAULT;

 

   }

 

   if (procf_buffer[procf_buffer_pos] == 'p') {

 

       //print command

 

       return 0;

 

   } else if (procf_buffer[procf_buffer_pos] == 'd') {

 

       //delete command

 

       i = procf_buffer_pos+1; j = 0;

 

       while ((procf_buffer[i]!=' ') && (procf_buffer[i]!='n') ) {

 

           printk(KERN_INFO "delete: %dn", procf_buffer[i]-'0');

 

           j = j*10 + (procf_buffer[i]-'0');

 

           ++i;

 

       }

 

       printk(KERN_INFO "delete a rule: %dn", j);

 

       delete_a_rule(j);

 

       return count;

 

   }

 

   /*add a new policy according to content int the storage buffer*/

   rule_desp = kmalloc(sizeof(*rule_desp), GFP_KERNEL);

   if (rule_desp == NULL) {

       printk(KERN_INFO "error: cannot allocate memory for rule_despn");

       return -ENOMEM;

   }

 

   init_mf_rule_desp(rule_desp);

 

   

 

   /**fill in the content of the new policy **/

 

   /***in_out***/

 

   i = procf_buffer_pos; j = 0;

 

   if (procf_buffer[i]!=' ') {

 

       rule_desp->in_out = (unsigned char)(procf_buffer[i++] - '0');

 

   }

 

   ++i;

 

   printk(KERN_INFO "in or out: %un", rule_desp->in_out);

 

   /***src ip***/

 

   j = 0;

 

   while (procf_buffer[i]!=' ') {

 

       rule_desp->src_ip[j++] = procf_buffer[i++];

 

   }

 

   ++i;

 

   rule_desp->src_ip[j] = '';

 

   printk(KERN_INFO "src ip: %sn", rule_desp->src_ip);

 

   /***src netmask***/

 

   j = 0;

 

   while (procf_buffer[i]!=' ') {

 

       rule_desp->src_netmask[j++] = procf_buffer[i++];

 

   }

 

   ++i;

 

   rule_desp->src_netmask[j] = '';

 

   printk(KERN_INFO "src netmask: %sn", rule_desp->src_netmask);

 

   /***src port number***/

 

   j = 0;

 

   while (procf_buffer[i]!=' ') {

 

       rule_desp->src_port[j++] = procf_buffer[i++];

 

   }

 

   ++i;

 

   rule_desp->src_port[j] = '';

 

   printk(KERN_INFO "src_port: %sn", rule_desp->src_port);

 

   /***dest ip***/

 

   j = 0;

 

   while (procf_buffer[i]!=' ') {

 

       rule_desp->dest_ip[j++] = procf_buffer[i++];

 

   }

 

   ++i;

 

   rule_desp->dest_ip[j] = '';

 

   printk(KERN_INFO "dest ip: %sn", rule_desp->dest_ip);

 

   /***dest netmask***/

 

   j = 0;

 

   while (procf_buffer[i]!=' ') {

 

       rule_desp->dest_netmask[j++] = procf_buffer[i++];

 

   }

 

   ++i;

 

   rule_desp->dest_netmask[j] = '';

 

   printk(KERN_INFO "dest netmask%sn", rule_desp->dest_netmask);

 

   /***dest port***/

 

   j = 0;

 

   while (procf_buffer[i]!=' ') {

 

       rule_desp->dest_port[j++] = procf_buffer[i++];

 

   }

 

   ++i;

 

   rule_desp->dest_port[j] = '';

 

   printk(KERN_INFO "dest port: %sn", rule_desp->dest_port);

 

   /***proto***/

 

   j = 0;

 

   if (procf_buffer[i]!=' ') {

       if (procf_buffer[i] != '-')

 

           rule_desp->proto = (unsigned char)(procf_buffer[i++]-'0');

       else

           ++i;

 

   }

 

   ++i;

 

   printk(KERN_INFO "proto: %dn", rule_desp->proto);

 

   /***action***/

 

   j = 0;

   if (procf_buffer[i]!=' ') {

       if (procf_buffer[i] != '-')

 

           rule_desp->action = (unsigned char)(procf_buffer[i++]-'0');

       else

           ++i;

 

   }

 

   ++i;

 

   printk(KERN_INFO "action: %dn", rule_desp->action);

   add_a_rule(rule_desp);

   kfree(rule_desp);

 

   printk(KERN_INFO "--------------------n");

 

   return count;

 

}

 

//the hook function itself: regsitered for filtering outgoing packets

 

unsigned int hook_func_out(unsigned int hooknum, struct sk_buff *skb, 

 

        const struct net_device *in, const struct net_device *out,

 

        int (*okfn)(struct sk_buff *)) {

 

   /*get src address, src netmask, src port, dest ip, dest netmask, dest port, protocol*/

 

   struct iphdr *ip_header = (struct iphdr *)skb_network_header(skb);

 

   struct udphdr *udp_header;

 

   struct tcphdr *tcp_header;

 

   struct list_head *p;

 

   struct mf_rule *a_rule;

   char src_ip_str[16], dest_ip_str[16];

 

   int i = 0;

 

   /**get src and dest ip addresses**/

 

   unsigned int src_ip = (unsigned int)ip_header->saddr;

 

   unsigned int dest_ip = (unsigned int)ip_header->daddr;

 

   unsigned int src_port = 0;

 

   unsigned int dest_port = 0;

 

   /***get src and dest port number***/

 

   if (ip_header->protocol==17) {

 

       udp_header = (struct udphdr *)skb_transport_header(skb);

 

       src_port = (unsigned int)ntohs(udp_header->source);

 

       dest_port = (unsigned int)ntohs(udp_header->dest);

 

   } else if (ip_header->protocol == 6) {

 

       tcp_header = (struct tcphdr *)skb_transport_header(skb);

 

       src_port = (unsigned int)ntohs(tcp_header->source);

 

       dest_port = (unsigned int)ntohs(tcp_header->dest);

 

   }

   ip_hl_to_str(ntohl(src_ip), src_ip_str);

   ip_hl_to_str(ntohl(dest_ip), dest_ip_str);

 

   printk(KERN_INFO "OUT packet info: src ip: %u = %s, src port: %u; dest ip: %u = %s, dest port: %u; proto: %un", src_ip, src_ip_str, src_port, dest_ip, dest_ip_str, dest_port, ip_header->protocol); 

 

   //go through the firewall list and check if there is a match

 

   //in case there are multiple matches, take the first one

 

   list_for_each(p, &policy_list.list) {

 

       i++;

 

       a_rule = list_entry(p, struct mf_rule, list);

       //printk(KERN_INFO "rule %d: a_rule->in_out = %u; a_rule->src_ip = %u; a_rule->src_netmask=%u; a_rule->src_port=%u; a_rule->dest_ip=%u; a_rule->dest_netmask=%u; a_rule->dest_port=%u; a_rule->proto=%u; a_rule->action=%un", i, a_rule->in_out, a_rule->src_ip, a_rule->src_netmask, a_rule->src_port, a_rule->dest_ip, a_rule->dest_netmask, a_rule->dest_port, a_rule->proto, a_rule->action);

 

       //if a rule doesn't specify as "out", skip it

 

       if (a_rule->in_out != 2) {

 

           printk(KERN_INFO "rule %d (a_rule->in_out: %u) not match: out packet, rule doesn't specify as outn", i, a_rule->in_out);

 

           continue;

 

       } else {

 

           //check the protocol

 

           if ((a_rule->proto==1) && (ip_header->protocol != 6)) {

 

               printk(KERN_INFO "rule %d not match: rule-TCP, packet->not TCPn", i);

 

               continue;

 

           } else if ((a_rule->proto==2) && (ip_header->protocol != 17)) {

 

               printk(KERN_INFO "rule %d not match: rule-UDP, packet->not UDPn", i);

 

               continue;

 

           }

 

           //check the ip address

 

           if (a_rule->src_ip==0) {

 

              //rule doesn't specify ip: match

 

           } else {

 

              if (!check_ip(src_ip, a_rule->src_ip, a_rule->src_netmask)) {

 

                  printk(KERN_INFO "rule %d not match: src ip mismatchn", i);

 

                  continue;

 

              }

 

           }

 

           if (a_rule->dest_ip == 0) {

 

               //rule doesn't specify ip: match

 

           } else {

 

               if (!check_ip(dest_ip, a_rule->dest_ip, a_rule->dest_netmask)) {

 

                   printk(KERN_INFO "rule %d not match: dest ip mismatchn", i);

 

                   continue;

 

               }

 

           }

 

           //check the port number

 

           if (a_rule->src_port==0) {

 

               //rule doesn't specify src port: match

 

           } else if (src_port!=a_rule->src_port) {

 

               printk(KERN_INFO "rule %d not match: src port dismatchn", i);

 

               continue;

 

           }

 

           if (a_rule->dest_port == 0) {

 

               //rule doens't specify dest port: match

 

           }

 

           else if (dest_port!=a_rule->dest_port) {

 

               printk(KERN_INFO "rule %d not match: dest port mismatchn", i);

 

               continue;

 

           }

 

           //a match is found: take action

 

           if (a_rule->action==0) {

 

               printk(KERN_INFO "a match is found: %d, drop the packetn", i);

 

              printk(KERN_INFO "---------------------------------------n");

 

               return NF_DROP;

 

           } else {

 

               printk(KERN_INFO "a match is found: %d, accept the packetn", i);

 

              printk(KERN_INFO "---------------------------------------n");

 

               return NF_ACCEPT;

 

           }

 

       }

 

   }

 

   printk(KERN_INFO "no matching is found, accept the packetn");

 

   printk(KERN_INFO "---------------------------------------n");

 

   return NF_ACCEPT;            

 

}

 

 

 

 

 

//the hook function itself: registered for filtering incoming packets

 

unsigned int hook_func_in(unsigned int hooknum, struct sk_buff *skb, 

 

        const struct net_device *in, const struct net_device *out,

 

        int (*okfn)(struct sk_buff *)) {

 

   /*get src address, src netmask, src port, dest ip, dest netmask, dest port, protocol*/

 

   struct iphdr *ip_header = (struct iphdr *)skb_network_header(skb);

 

   struct udphdr *udp_header;

 

   struct tcphdr *tcp_header;

 

   struct list_head *p;

 

   struct mf_rule *a_rule;

   char src_ip_str[16], dest_ip_str[16];

 

   int i = 0;

 

   /**get src and dest ip addresses**/

 

   unsigned int src_ip = (unsigned int)ip_header->saddr;

 

   unsigned int dest_ip = (unsigned int)ip_header->daddr;

 

   unsigned int src_port = 0;

 

   unsigned int dest_port = 0;

 

   /***get src and dest port number***/

 

   if (ip_header->protocol==17) {

 

       udp_header = (struct udphdr *)(skb_transport_header(skb)+20);

 

       src_port = (unsigned int)ntohs(udp_header->source);

 

       dest_port = (unsigned int)ntohs(udp_header->dest);

 

   } else if (ip_header->protocol == 6) {

 

       tcp_header = (struct tcphdr *)(skb_transport_header(skb)+20);

 

       src_port = (unsigned int)ntohs(tcp_header->source);

 

       dest_port = (unsigned int)ntohs(tcp_header->dest);

 

   }

   ip_hl_to_str(ntohl(src_ip), src_ip_str);

   ip_hl_to_str(ntohl(dest_ip), dest_ip_str);

 

   printk(KERN_INFO "IN packet info: src ip: %u = %s, src port: %u; dest ip: %u = %s, dest port: %u; proto: %un", src_ip, src_ip_str, src_port, dest_ip, dest_ip_str, dest_port, ip_header->protocol); 

 

   //go through the firewall list and check if there is a match

 

   //in case there are multiple matches, take the first one

 

   list_for_each(p, &policy_list.list) {

 

       i++;

 

       a_rule = list_entry(p, struct mf_rule, list);

//printk(KERN_INFO "rule %d: a_rule->in_out = %u; a_rule->src_ip = %u; a_rule->src_netmask=%u; a_rule->src_port=%u; a_rule->dest_ip=%u; a_rule->dest_netmask=%u; a_rule->dest_port=%u; a_rule->proto=%u; a_rule->action=%un", i, a_rule->in_out, a_rule->src_ip, a_rule->src_netmask, a_rule->src_port, a_rule->dest_ip, a_rule->dest_netmask, a_rule->dest_port, a_rule->proto, a_rule->action);

 

       //if a rule doesn't specify as "in", skip it

 

       if (a_rule->in_out != 1) {

 

           printk(KERN_INFO "rule %d (a_rule->in_out:%u) not match: in packet, rule doesn't specify as inn", i, a_rule->in_out);

 

           continue;

 

       } else {

 

           //check the protocol

 

           if ((a_rule->proto==1) && (ip_header->protocol != 6)) {

 

               printk(KERN_INFO "rule %d not match: rule-TCP, packet->not TCPn", i);

 

               continue;

 

           } else if ((a_rule->proto==2) && (ip_header->protocol != 17)) {

 

               printk(KERN_INFO "rule %d not match: rule-UDP, packet->not UDPn", i);

 

               continue;

 

           }

 

           //check the ip address

 

           if (a_rule->src_ip==0) {

 

              //

 

           } else {

 

              if (!check_ip(src_ip, a_rule->src_ip, a_rule->src_netmask)) {

 

                  printk(KERN_INFO "rule %d not match: src ip mismatchn", i);

 

                  continue;

 

              }

 

           }

 

           if (a_rule->dest_ip == 0) {

 

               //

 

           } else {

 

               if (!check_ip(dest_ip, a_rule->dest_ip, a_rule->dest_netmask)) {

 

                  printk(KERN_INFO "rule %d not match: dest ip mismatchn", i);                  

 

                  continue;

 

               }

 

           }

 

           //check the port number

 

           if (a_rule->src_port==0) {

 

               //rule doesn't specify src port: match

 

           } else if (src_port!=a_rule->src_port) {

 

               printk(KERN_INFO "rule %d not match: src port mismatchn", i);

 

               continue;

 

           }

 

           if (a_rule->dest_port == 0) {

 

               //rule doens't specify dest port: match

 

           }

 

           else if (dest_port!=a_rule->dest_port) {

 

               printk(KERN_INFO "rule %d not match: dest port mismatchn", i);

 

               continue;

 

           }

 

           //a match is found: take action

 

           if (a_rule->action==0) {

 

               printk(KERN_INFO "a match is found: %d, drop the packetn", i);

 

               printk(KERN_INFO "---------------------------------------n");

 

               return NF_DROP;

 

           } else {

 

               printk(KERN_INFO "a match is found: %d, accept the packetn", i);

 

               printk(KERN_INFO "---------------------------------------n");

 

               return NF_ACCEPT;

 

           }

 

       }

 

   }

 

   printk(KERN_INFO "no matching is found, accept the packetn");

 

   printk(KERN_INFO "---------------------------------------n");

 

   return NF_ACCEPT;                

 

}

 

/* Initialization routine */

int init_module() {

    printk(KERN_INFO "initialize kernel modulen");

    procf_buffer = (char *) vmalloc(PROCF_MAX_SIZE);

    INIT_LIST_HEAD(&(policy_list.list));

    mf_proc_file = create_proc_entry(PROCF_NAME, 0644, NULL);

 

    if (mf_proc_file==NULL) {

 

        printk(KERN_INFO "Error: could not initialize /proc/%sn", PROCF_NAME);

 

        return -ENOMEM; 

 

    } 

 

    mf_proc_file->read_proc = procf_read;

 

    mf_proc_file->write_proc = procf_write;

 

    printk(KERN_INFO "/proc/%s is createdn", PROCF_NAME);

    /* Fill in the hook structure for incoming packet hook*/

 

    nfho.hook = hook_func_in;

 

    nfho.hooknum = NF_INET_LOCAL_IN;

 

    nfho.pf = PF_INET;

 

    nfho.priority = NF_IP_PRI_FIRST;

 

    nf_register_hook(&nfho);         // Register the hook

 

    /* Fill in the hook structure for outgoing packet hook*/

 

    nfho_out.hook = hook_func_out;

 

    nfho_out.hooknum = NF_INET_LOCAL_OUT;

 

    nfho_out.pf = PF_INET;

 

    nfho_out.priority = NF_IP_PRI_FIRST;

 

    nf_register_hook(&nfho_out);    // Register the hook

    return 0;

}

 

/* Cleanup routine */

void cleanup_module() {

    struct list_head *p, *q;

    struct mf_rule *a_rule;

    nf_unregister_hook(&nfho);

 

    nf_unregister_hook(&nfho_out);

    printk(KERN_INFO "free policy listn");

    list_for_each_safe(p, q, &policy_list.list) {

        printk(KERN_INFO "free onen");

        a_rule = list_entry(p, struct mf_rule, list);

        list_del(p);

        kfree(a_rule);

    }

    remove_proc_entry(PROCF_NAME, NULL);

    printk(KERN_INFO "kernel module unloaded.n");

 

}

The background needed to understand the code above has been covered in the previous posts. If you have difficulty understanding the code above, please refer to the relevant parts.

Save the kernel module source above as mfkm.c (the Makefile builds mfkm.o from it) and use the following Makefile to compile the kernel module:

# Kbuild makefile: builds mfkm.ko from mfkm.c against the running kernel's build tree.
# $(MAKE) propagates flags and the jobserver to the recursive make, and
# $(CURDIR) does not depend on the PWD environment variable (which may be
# unset or stale, e.g. under sudo).
obj-m += mfkm.o
all:
	$(MAKE) -C /lib/modules/$(shell uname -r)/build M=$(CURDIR) modules
clean:
	$(MAKE) -C /lib/modules/$(shell uname -r)/build M=$(CURDIR) clean

 

The user-space program code is given below:

#include <stdio.h>

#include <stdlib.h>

#include <string.h>

#include <getopt.h>

 

/* Yield string x, or the "-" placeholder (meaning "not specified") when x is NULL. x is fully parenthesized; it is evaluated twice. */
#define print_value(x) ((x) == NULL ? "-" : (x))

 

/*
 * Rule being built for the kernel module. Being a static object, every
 * pointer starts out NULL; a NULL string is sent as "-" ("not specified")
 * by send_rule_to_proc(). NOTE(review): presumably filled from the command
 * line options; main() is not shown here.
 */
static struct mf_rule_struct {
    int in_out;                 //sent as in_out + 1; the kernel expects 1 = in, 2 = out
    char *src_ip;               //dotted quad, or NULL
    char *src_netmask;          //dotted quad, or NULL
    char *src_port;            //decimal string, or NULL (sent as "-")
    char *dest_ip;              //dotted quad, or NULL
    char *dest_netmask;         //dotted quad, or NULL
    char *dest_port;            //decimal string, or NULL (sent as "-")
    char *proto;                //"ALL", "TCP" or "UDP" (see get_proto())
    char *action;               //"BLOCK" or "UNBLOCK" (see get_action())
} mf_rule;

 

/*
 * Pending delete request. row is the 1-based rule number, sent to the
 * kernel as "d<row>" by send_delete_to_proc().
 * NOTE(review): the use of cmd is not visible in this excerpt.
 */
static struct mf_delete_struct {
    char *cmd;
    char *row;                  //1-based rule number as text, or NULL
} mf_delete;

 

/*
 * Write @str to /proc/minifirewall, where the kernel module parses it as a
 * command. Failures are printed rather than returned. Because stdio
 * buffers the data, a write rejected by the kernel is usually only
 * reported by fclose(), so its result is checked too.
 */
void send_to_proc(char *str)
{
    FILE *pf = fopen("/proc/minifirewall", "w");

    if (pf == NULL) {
        printf("Cannot open /proc/minifirewall for writing\n");
        return;
    }
    if (fputs(str, pf) == EOF)
        printf("Cannot write to /proc/minifirewall\n");
    if (fclose(pf) != 0)
        printf("Error while writing to /proc/minifirewall\n");
}

 

/*
 * Map a protocol name to the numeric code the kernel module expects:
 * "ALL" -> 0, "TCP" -> 1, "UDP" -> 2.
 *
 * NULL (option not given) maps to 0 (ALL), which is also the kernel's
 * default when the protocol field is "-". The original function fell off
 * the end for unknown names (undefined value); those now map to 0 as well.
 * NOTE(review): consider rejecting unknown names in the option parser.
 */
int get_proto(char* proto) {
    if (proto == NULL)
        return 0;
    if (strcmp(proto, "TCP") == 0)
        return 1;
    if (strcmp(proto, "UDP") == 0)
        return 2;
    return 0;   /* "ALL" or unrecognized */
}

 

/*
 * Map an action name to the numeric code the kernel module expects:
 * "BLOCK" -> 0, "UNBLOCK" -> 1.
 *
 * NULL (option not given) and unknown names map to 0 (BLOCK), which is
 * also the kernel's default when the action field is "-". The original
 * function returned an undefined value for unknown names.
 */
int get_action(char* action) {
    if (action != NULL && strcmp(action, "UNBLOCK") == 0)
        return 1;
    return 0;   /* "BLOCK", NULL or unrecognized */
}

 

/*
 * Serialize mf_rule into the one-line rule command parsed by the kernel
 * module's procf_write():
 *   "<1|2> <src ip> <src mask> <src port> <dest ip> <dest mask> <dest port> <proto> <action>\n"
 * with "-" for unset fields, and send it to /proc/minifirewall.
 * The fields are user-supplied, so the line is bounded. A rule that does
 * not fit is reported and not sent.
 */
void send_rule_to_proc(void)
{
    char a_rule[200];
    int len;

    len = snprintf(a_rule, sizeof(a_rule), "%d %s %s %s %s %s %s %d %d\n",
                   mf_rule.in_out + 1,
                   print_value(mf_rule.src_ip),
                   print_value(mf_rule.src_netmask),
                   print_value(mf_rule.src_port),
                   print_value(mf_rule.dest_ip),
                   print_value(mf_rule.dest_netmask),
                   print_value(mf_rule.dest_port),
                   get_proto(mf_rule.proto),
                   get_action(mf_rule.action));
    if (len < 0 || (size_t)len >= sizeof(a_rule)) {
        printf("Rule is too long, not sent\n");
        return;
    }

    send_to_proc(a_rule);
}

 

/*
 * Send "d<row>\n" to /proc/minifirewall, asking the kernel module to delete
 * the row-th (1-based) rule. An unset row is sent as "-"; the kernel parses
 * no digits from it and deletes nothing. Rows too long for the command
 * buffer are reported and not sent.
 */
void send_delete_to_proc(void)
{
    char delete_cmd[20];
    int len;

    len = snprintf(delete_cmd, sizeof(delete_cmd), "d%s\n", print_value(mf_delete.row));
    if (len < 0 || (size_t)len >= sizeof(delete_cmd)) {
        printf("Rule number is too long, not sent\n");
        return;
    }

    send_to_proc(delete_cmd);
}

 

/*
 * Print one field of a rule line in the column layout used by print_rule().
 * idx is the field position (0: in/out ... 8: action). A "-" placeholder in
 * an address or mask column gets extra padding so it sits under its header.
 */
static void print_rule_field(int idx, const char *tok)
{
    int dash = (strcmp(tok, "-") == 0);

    switch (idx) {
    case 0:     /* in/out */
        printf("  %s  ", tok);
        break;
    case 1:     /* src ip */
    case 4:     /* dest ip */
        printf(dash ? "      %s     " : " %s ", tok);
        break;
    case 2:     /* src mask */
        printf(dash ? "     %s         " : " %s ", tok);
        break;
    case 3:     /* src port */
        printf(" %s     ", tok);
        break;
    case 5:     /* dest mask */
        printf(dash ? "      %s             " : " %s ", tok);
        break;
    case 6:     /* dest port */
        printf("%s      ", tok);
        break;
    case 7:     /* proto */
        printf("    %s    ", tok);
        break;
    default:    /* action */
        printf(" %s", tok);
        break;
    }
}

/*
 * Print the rule table exported by /proc/minifirewall under a header line.
 * Each line of the proc file holds up to nine space-separated fields
 * (in/out, src ip, src mask, src port, dest ip, dest mask, dest port, proto,
 * action). Blank lines are skipped, and fields longer than 19 characters are
 * truncated.
 */
void print_rule()
{
    FILE *pf;
    char line[512];
    char fields[9][20];
    int nfields, k;

    printf("in/out    src ip    src mask    src port    dest ip    dest mask     dest port    proto    action\n");

    pf = fopen("/proc/minifirewall", "r");
    if (pf == NULL) {
        fprintf(stderr, "Cannot open /proc/minifirewall for reading\n");
        return;
    }

    /* Line-based parsing: field widths are bounded by %19s and EOF is
     * reported by fgets, so no char/EOF confusion and no token overflow. */
    while (fgets(line, sizeof line, pf) != NULL) {
        nfields = sscanf(line, "%19s %19s %19s %19s %19s %19s %19s %19s %19s",
                         fields[0], fields[1], fields[2], fields[3], fields[4],
                         fields[5], fields[6], fields[7], fields[8]);
        if (nfields < 1)
            continue;    /* blank line */
        for (k = 0; k < nfields; ++k)
            print_rule_field(k, fields[k]);
        putchar('\n');
    }

    if (ferror(pf))
        fprintf(stderr, "Error while reading /proc/minifirewall\n");
    fclose(pf);
}

 

/*
 * Entry point of the mf configuration tool.
 *
 * Options fill in mf_rule (or mf_delete). After parsing, exactly one action runs:
 *   - default:           add the rule described by the options (send_rule_to_proc)
 *   - --print / -o:      print the current rule table (print_rule)
 *   - --delete N / -d N: delete rule number N (send_delete_to_proc)
 * If both --print and --delete are given, the last one wins.
 * Returns 0 on success. On a command-line error (unknown option or stray
 * non-option argument) it returns 1 and sends nothing to the kernel.
 */
int main(int argc, char **argv)
{
    int c;
    int action = 1;    /* 1: new rule; 2: print; 3: delete */

    mf_rule.in_out = -1; mf_rule.src_ip = NULL; mf_rule.src_netmask = NULL; mf_rule.src_port = NULL;
    mf_rule.dest_ip = NULL; mf_rule.dest_netmask = NULL; mf_rule.dest_port = NULL; mf_rule.proto = NULL;
    mf_rule.action = NULL;

    while (1)
    {
        static struct option long_options[] =
        {
            /* These store directly into mf_rule.in_out; getopt_long returns 0. */
            {"in", no_argument, &mf_rule.in_out, 0},
            {"out", no_argument, &mf_rule.in_out, 1},
            /* These return the matching short-option character. */
            {"print", no_argument, 0, 'o'},
            {"delete", required_argument, 0, 'd'},
            {"srcip", required_argument, 0, 's'},
            {"srcnetmask", required_argument, 0, 'm'},
            {"srcport", required_argument, 0, 'p'},
            {"destip", required_argument, 0, 't'},
            {"destnetmask", required_argument, 0, 'n'},
            {"destport", required_argument, 0, 'q'},
            {"proto", required_argument, 0, 'c'},
            {"action", required_argument, 0, 'a'},
            {0, 0, 0, 0}
        };
        int option_index = 0;

        c = getopt_long(argc, argv, "od:s:m:p:t:n:q:c:a:", long_options, &option_index);
        if (c == -1)
            break;    /* end of the options */

        switch (c)
        {
            case 0:
              /* --in / --out: the flag is already stored in mf_rule.in_out. */
              break;
            case 'o':
              action = 2;    /* print */
              break;
            case 'd':
              action = 3;    /* delete */
              /* option_index is not set for the short form -d, so name the command directly. */
              mf_delete.cmd = "delete";
              mf_delete.row = optarg;
              break;
            case 's':
              mf_rule.src_ip = optarg;
              break;
            case 'm':
              mf_rule.src_netmask = optarg;
              break;
            case 'p':
              mf_rule.src_port = optarg;
              break;
            case 't':
              mf_rule.dest_ip = optarg;
              break;
            case 'n':
              mf_rule.dest_netmask = optarg;
              break;
            case 'q':
              mf_rule.dest_port = optarg;
              break;
            case 'c':
              mf_rule.proto = optarg;
              break;
            case 'a':
              mf_rule.action = optarg;
              break;
            case '?':
              /* getopt_long already printed an error message; do not send a partial rule. */
              return 1;
            default:
              abort();
        }
    }

    /* Reject stray arguments (e.g. "BLOCK" given without --action) instead of ignoring them. */
    if (optind < argc) {
        fprintf(stderr, "unexpected argument(s):");
        while (optind < argc)
            fprintf(stderr, " %s", argv[optind++]);
        fputc('\n', stderr);
        return 1;
    }

    if (action == 1) {
        send_rule_to_proc();
    } else if (action == 2) {
        print_rule();
    } else if (action == 3) {
        send_delete_to_proc();
    }

    return 0;
}

You can compile the code by,

gcc -o mf mf.c

Sample Testing

1. Block all incoming traffic, unblock all outgoing traffic

1.1 Enter the configuration commands below,

roman10@roman10-laptop:~/hello$ sudo ./mf --in --proto ALL --action BLOCK

roman10@roman10-laptop:~/hello$ sudo ./mf --out --proto ALL --action UNBLOCK

roman10@roman10-laptop:~/hello$ ./mf --print

in/out    src ip    src mask    src port    dest ip    dest mask     dest port    proto    action

  in        -          -          -           -           -             -          ALL     BLOCK

  out        -          -          -           -           -             -          ALL     UNBLOCK

 

1.2 Run: ping 127.0.0.1 -c 1

1.3 Check the output of the minifirewall kernel module by,

tail -f /var/log/messages

Below is a sample output of /var/log/messages:

Jul 31 15:12:16 roman10-laptop kernel: [72959.708449] OUT packet info: src ip: 16777343 = 127.0.0.1, src port: 0; dest ip: 16777343 = 127.0.0.1, dest port: 0; proto: 1

Jul 31 15:12:16 roman10-laptop kernel: [72959.708488] rule 1 (a_rule->in_out: 1) not match: out packet, rule doesn't specify as out

Jul 31 15:12:16 roman10-laptop kernel: [72959.708514] a match is found: 2, accept the packet

Jul 31 15:12:16 roman10-laptop kernel: [72959.708532] ---------------------------------------

Jul 31 15:12:16 roman10-laptop kernel: [72959.711527] IN packet info: src ip: 16777343 = 127.0.0.1, src port: 0; dest ip: 16777343 = 127.0.0.1, dest port: 0; proto: 1

Jul 31 15:12:16 roman10-laptop kernel: [72959.711610] a match is found: 1, drop the packet

Jul 31 15:12:16 roman10-laptop kernel: [72959.711660] ---------------------------------------

2. Test IP address and Netmask

2.1 Replace 10.0.2.15 with your local IP address. Enter the commands below, and check if you can get similar result.

roman10@roman10-laptop:~/hello$ sudo ./mf --out --srcip 10.0.2.15 --proto UDP --action BLOCK 

roman10@roman10-laptop:~/hello$ ./mf --print

in/out    src ip    src mask    src port    dest ip    dest mask     dest port    proto    action

  out   10.0.2.15      -          -           -           -             -          UDP     BLOCK

roman10@roman10-laptop:~/hello$ ping google.com

ping: unknown host google.com

roman10@roman10-laptop:~/hello$ sudo ./mf --delete 1

roman10@roman10-laptop:~/hello$ sudo ./mf --out --srcip 10.0.2.16 --proto UDP --action BLOCK

roman10@roman10-laptop:~/hello$ ./mf --print

in/out    src ip    src mask    src port    dest ip    dest mask     dest port    proto    action

  out   10.0.2.16      -          -           -           -             -          UDP     BLOCK

roman10@roman10-laptop:~/hello$ ping google.com

PING google.com (74.125.235.20) 56(84) bytes of data.

64 bytes from 74.125.235.20: icmp_seq=1 ttl=52 time=82.2 ms

^C64 bytes from 74.125.235.20: icmp_seq=2 ttl=52 time=17.6 ms

 

--- google.com ping statistics ---

2 packets transmitted, 2 received, 0% packet loss, time 5173ms

rtt min/avg/max/mdev = 17.675/49.975/82.276/32.301 ms

roman10@roman10-laptop:~/hello$ sudo ./mf --delete 1

roman10@roman10-laptop:~/hello$ sudo ./mf --out --srcip 10.0.2.16 --srcnetmask 255.252.0.0 --proto UDP --action BLOCK

roman10@roman10-laptop:~/hello$ ./mf --print

in/out    src ip    src mask    src port    dest ip    dest mask     dest port    proto    action

  out   10.0.2.16  255.252.0.0  -           -           -             -          UDP     BLOCK

roman10@roman10-laptop:~/hello$ ping google.com

ping: unknown host google.com

roman10@roman10-laptop:~/hello$ 

2.2 The first ping fails because the source IP address matches the BLOCK rule. The rule blocks outgoing UDP, so the DNS lookup for google.com cannot be sent and the host name cannot be resolved. The second ping goes through because the IP address no longer matches the rule. The third ping is blocked again because the first 14 bits (as selected by the mask 255.252.0.0) match.

One can also examine the /var/log/messages content for verification.

This completes the tutorial How to Write a Linux Firewall in Less than 1000 Lines.

Linux procfs Virtual File System

procfs is a software-created virtual file system that is mounted on the /proc directory at boot time. It was originally designed to provide information about the running processes of a Linux system, but as Linux kernel development proceeded it has grown far beyond that original purpose.

Basics of procfs

It can act as a bridge connecting the user space and the kernel space. User space program can use proc files to read the information exported by kernel. For example, /proc/modules contains the information of loaded kernel modules. The command

cat /proc/modules

gives similar information as lsmod. Most proc files are read-only and only expose kernel information to user space programs.

proc files can also be used to control and modify kernel behavior on the fly. In this case the proc files need to be writable.

For example, to enable IPv4 packet forwarding, one can use the command below,

echo 1 > /proc/sys/net/ipv4/ip_forward

Programming Interface of proc File

1. Create and Remove a proc File

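The procfs programming interface is defined in /lib/modules/$(uname -r)/build/include/linux/proc_fs.h. Note that the interface described here (create_proc_entry with read_proc/write_proc callbacks) applies only to kernels older than 3.10. It was removed in Linux 3.10, and newer kernels create proc files with proc_create() and a struct file_operations (struct proc_ops since Linux 5.6).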

To create a proc file, use the following function,

extern struct proc_dir_entry *create_proc_entry(const char *name, mode_t mode,

                                                struct proc_dir_entry *parent);

The function accepts three parameters, a file name, the file permissions, and a location where this file is to be created.  If NULL is passed as parent, the /proc directory will be set as parent.

Below is an example,

#define PROCF_NAME "minifirewall"

static struct proc_dir_entry *mf_proc_file;

mf_proc_file = create_proc_entry(PROCF_NAME, 0644, NULL);

To remove a proc file, the function below is defined,

extern void remove_proc_entry(const char *name, struct proc_dir_entry *parent);

Following the example above, the code below removes the created proc file,

remove_proc_entry(PROCF_NAME, NULL);

2. Read and Write through proc File

The proc_fs.h allows one to register read and write callbacks for proc file.

typedef int (read_proc_t)(char *page, char **start, off_t off,

                          int count, int *eof, void *data);

typedef int (write_proc_t)(struct file *file, const char __user *buffer,

                           unsigned long count, void *data);

The registered read callback is invoked when a user space program reads the proc file. Its parameters are:

- page: the buffer into which the kernel code writes the data for the user space program.
- count: the maximum number of characters that can be written.
- start and off: used only when the kernel code needs to return more than one page of data.
- eof: set to 1 when all data has been written.
- data: a pointer to private data.

The page buffer is in kernel space, so it can be used directly.

The write callback is invoked when a user space program writes to the proc file. Its parameters are:

- file: the file structure representing the proc file.
- buffer: a user space buffer. Unlike the page buffer of the read callback, it cannot be accessed directly; the kernel code must copy the data in with copy_from_user. (copy_from_user takes three parameters: the destination address in kernel space, the source address in user space, and the number of bytes to copy. It returns the number of bytes that could not be copied.)
- count: the number of bytes of data in buffer.

Suppose the callback functions are defined as procf_read and procf_write respectively. The code below will register the callback functions,

mf_proc_file->read_proc = procf_read;

mf_proc_file->write_proc = procf_write;

A Sample Program

Below is a simple program that demonstrates the usage of procfs. It consists of two parts: the kernel module test_proc.ko, which runs in kernel space, and the configuration tool mf, which runs in user space.

The test_proc.c program is as below,

#include <linux/module.h>

#include <linux/kernel.h>

#include <linux/proc_fs.h>

#include <linux/list.h>

#include <asm/uaccess.h>

#include <linux/udp.h>

#include <linux/tcp.h>

#include <linux/skbuff.h>

#include <linux/ip.h>

#include <linux/netfilter.h>

#include <linux/netfilter_ipv4.h>

 

/* Size limit for data exchanged through the proc file; procf_write clamps incoming data to it. */
#define PROCF_MAX_SIZE 1024

/* Name of the proc entry (the user space tool opens /proc/minifirewall). */
#define PROCF_NAME "minifirewall"

MODULE_LICENSE("GPL");

MODULE_DESCRIPTION("test_proc");

MODULE_AUTHOR("Liu Feipeng/roman10");

/* The proc entry owned by this module. */
static struct proc_dir_entry *mf_proc_file;

/* Fill position inside procf_buffer; procf_read and procf_write reset it to 0 before use. */
unsigned long procf_buffer_pos;

/* Staging buffer for proc-file data. It is allocated outside this chunk;
 * presumably it holds PROCF_MAX_SIZE bytes (TODO confirm against the allocation). */
char *procf_buffer;

 

 

/*
 * Rule as parsed from text written to the proc file. The address, mask and
 * port fields are kept as strings ("-" when unspecified); in_out, proto and
 * action are already numeric codes.
 */
struct mf_rule_desp {
    unsigned char in_out;
    char *src_ip, *src_netmask, *src_port;
    char *dest_ip, *dest_netmask, *dest_port;
    unsigned char proto;
    unsigned char action;
};

 

 

/*
 * Rule in the numeric form the module matches against. Addresses and masks
 * are in host byte order (check_ip converts the packet address with ntohl
 * before comparing). add_a_rule stores 0 in any address, mask or port field
 * the user left unspecified.
 */
struct mf_rule {
    unsigned char in_out;            //0: neither in nor out, 1: in, 2: out
    unsigned int src_ip;             //host byte order; 0: unspecified
    unsigned int src_netmask;        //0: unspecified
    unsigned int src_port;           //0: unspecified; only 1~65535 can match a real port
    unsigned int dest_ip;            //host byte order; 0: unspecified
    unsigned int dest_netmask;       //0: unspecified
    unsigned int dest_port;          //0: unspecified; only 1~65535 can match a real port
    unsigned char proto;             //0: all, 1: tcp, 2: udp
    unsigned char action;            //0: for block, 1: for unblock
    struct list_head list;
};

/* Head of the rule list: only its list member is used; add_a_rule appends rules in order. */
static struct mf_rule policy_list;

 

/*
 * Convert a decimal port string to an unsigned int.
 * Returns 0 for NULL. Parsing stops at the first non-digit character
 * (e.g. a trailing newline), so only the leading digits are used.
 */
unsigned int port_str_to_int(char *port_str) {
    unsigned int port = 0;
    int i;

    if (port_str == NULL) {
        return 0;
    }
    for (i = 0; port_str[i] >= '0' && port_str[i] <= '9'; ++i) {
        port = port * 10 + (unsigned int)(port_str[i] - '0');
    }
    return port;
}

 

void port_int_to_str(unsigned int port, char *port_str) {
    /* Render the port in decimal; port_str needs room for up to 11 bytes. */
    const unsigned int value = port;
    char *out = port_str;

    sprintf(out, "%u", value);
}

 

/*
 * Convert a dotted-decimal IPv4 string (e.g. "131.132.162.25") to an
 * unsigned int in host byte order, with the first octet in the most
 * significant byte. Returns 0 for NULL.
 * Parsing never reads past the terminating NUL:
 *   - missing octets are taken as 0;
 *   - an octet ends at its first non-digit, and the rest up to the next dot is ignored;
 *   - each octet is reduced modulo 256.
 */
unsigned int ip_str_to_hl(char *ip_str) {
    const char *p = ip_str;
    unsigned int ip = 0;
    unsigned int octet;
    int part;

    if (ip_str == NULL) {
        return 0;
    }
    for (part = 0; part < 4; ++part) {
        octet = 0;
        /* Accumulate the leading digits of this octet ... */
        while (*p >= '0' && *p <= '9') {
            octet = octet * 10 + (unsigned int)(*p++ - '0');
        }
        /* ... then skip anything else up to the next dot or the end. */
        while (*p != '\0' && *p != '.') {
            ++p;
        }
        if (*p == '.') {
            ++p;
        }
        /* Unsigned shifts: octets >= 128 cannot overflow a signed int here. */
        ip = (ip << 8) | (octet & 0xFFu);
    }
    return ip;
}

 

void ip_hl_to_str(unsigned int ip, char *ip_str) {
    /* Split the host-order address into its four octets, most significant
     * first, and print them dotted-decimal; ip_str needs room for 16 bytes. */
    unsigned char b0 = (unsigned char)(ip >> 24);
    unsigned char b1 = (unsigned char)(ip >> 16);
    unsigned char b2 = (unsigned char)(ip >> 8);
    unsigned char b3 = (unsigned char)ip;

    sprintf(ip_str, "%u.%u.%u.%u", b0, b1, b2, b3);
}

 

/*
 * Check whether a packet address matches a rule address on the leading bits
 * selected by a mask.
 *   ip:      packet address in network byte order (converted with ntohl)
 *   ip_rule: rule address in host byte order
 *   mask:    the number of leading one bits is the compared length; 0 compares all 32 bits
 * Returns true when all compared bits are equal.
 */
bool check_ip(unsigned int ip, unsigned int ip_rule, unsigned int mask) {
    unsigned int tmp = ntohl(ip);    //network to host long
    int cmp_len = 32;
    int i, j;

    printk(KERN_INFO "compare ip: %u <=> %u\n", tmp, ip_rule);
    if (mask != 0) {
        cmp_len = 0;
        /* Count the leading one bits; 1u keeps the shift by 31 well defined. */
        for (i = 0; i < 32; ++i) {
            if (mask & (1u << (31 - i)))
                cmp_len++;
            else
                break;
        }
    }
    /*compare the two IP addresses for the first cmp_len bits*/
    for (i = 31, j = 0; j < cmp_len; --i, ++j) {
        if ((tmp & (1u << i)) != (ip_rule & (1u << i))) {
            printk(KERN_INFO "ip compare: %d bit doesn't match\n", (32 - i));
            return false;
        }
    }
    return true;
}

 

void add_a_rule(struct mf_rule_desp* a_rule_desp) {

    struct mf_rule* a_rule;

    a_rule = kmalloc(sizeof(*a_rule), GFP_KERNEL);

    if (a_rule == NULL) {

        printk(KERN_INFO "error: cannot allocate memory for a_new_rulen");

        return;

    }

    a_rule->in_out = a_rule_desp->in_out;

    if (strcmp(a_rule_desp->src_ip, "-") != 0) 

        a_rule->src_ip = ip_str_to_hl(a_rule_desp->src_ip);

    else

        a_rule->src_ip = NULL;

    if (strcmp(a_rule_desp->src_netmask, "-") != 0)

        a_rule->src_netmask = ip_str_to_hl(a_rule_desp->src_netmask);

    else

        a_rule->src_netmask = NULL;

    if (strcmp(a_rule_desp->src_port, "-") != 0)

        a_rule->src_port = port_str_to_int(a_rule_desp->src_port);

    else 

        a_rule->src_port = NULL;

    if (strcmp(a_rule_desp->dest_ip, "-") != 0)

        a_rule->dest_ip = ip_str_to_hl(a_rule_desp->dest_ip);

    else 

        a_rule->dest_ip = NULL;

    if (strcmp(a_rule_desp->dest_netmask, "-") != 0)

        a_rule->dest_netmask = ip_str_to_hl(a_rule_desp->dest_netmask);

    else 

        a_rule->dest_netmask = NULL;

    if (strcmp(a_rule_desp->dest_port, "-") != 0)

        a_rule->dest_port = port_str_to_int(a_rule_desp->dest_port);

    else 

        a_rule->dest_port = NULL;

    a_rule->proto = a_rule_desp->proto;

    a_rule->action = a_rule_desp->action;

    printk(KERN_INFO "add_a_rule: in_out=%u, src_ip=%u, src_netmask=%u, src_port=%u, dest_ip=%u, dest_netmask=%u, dest_port=%u, proto=%u, action=%un", a_rule->in_out, a_rule->src_ip, a_rule->src_netmask, a_rule->src_port, a_rule->dest_ip, a_rule->dest_netmask, a_rule->dest_port, a_rule->proto, a_rule->action);

    INIT_LIST_HEAD(&(a_rule->list));

    list_add_tail(&(a_rule->list), &(policy_list.list));

}

 

void init_mf_rule_desp(struct mf_rule_desp* a_rule_desp) {
    /* Reset the numeric fields and give each string field its own 16-byte
     * kmalloc buffer (allocated in declaration order). Allocation results are
     * not checked here. */
    char **string_fields[] = {
        &a_rule_desp->src_ip,
        &a_rule_desp->src_netmask,
        &a_rule_desp->src_port,
        &a_rule_desp->dest_ip,
        &a_rule_desp->dest_netmask,
        &a_rule_desp->dest_port,
    };
    size_t k;

    a_rule_desp->in_out = 0;
    for (k = 0; k < sizeof string_fields / sizeof string_fields[0]; ++k)
        *string_fields[k] = (char *)kmalloc(16, GFP_KERNEL);
    a_rule_desp->proto = 0;
    a_rule_desp->action = 0;
}

 

/*
 * Remove the num-th rule (1-based, in list order) from policy_list and free it.
 * Numbers outside 1..list length leave the list unchanged.
 * NOTE(review): no lock is taken; confirm nothing else walks policy_list concurrently.
 */
void delete_a_rule(int num) {
    int i = 0;
    struct list_head *p, *q;
    struct mf_rule *a_rule;

    printk(KERN_INFO "delete a rule: %d\n", num);
    if (num <= 0)
        return;
    list_for_each_safe(p, q, &policy_list.list) {
        if (++i == num) {
            a_rule = list_entry(p, struct mf_rule, list);
            list_del(p);
            kfree(a_rule);
            return;
        }
    }
}

 

int procf_read(char *buffer, char **buffer_location, off_t offset, int buffer_length, int *eof, void *data)

 

{

 

    int ret;

 

    struct mf_rule *a_rule;

 

    char token[20];

 

    printk(KERN_INFO "procf_read (/proc/%s) called n", PROCF_NAME);

 

    if (offset > 0) {

 

        printk(KERN_INFO "eof is 1, nothing to readn");

 

        *eof = 1;

 

        return 0;

 

    } else {

 

        procf_buffer_pos = 0;

 

        ret = 0;

 

        list_for_each_entry(a_rule, &policy_list.list, list) {

 

            //in or out

 

            if (a_rule->in_out==1) {

 

                strcpy(token, "in");

 

            } else if (a_rule->in_out==2) {

 

                strcpy(token, "out");

 

            }

 

            printk(KERN_INFO "token: %sn", token);

 

            memcpy(procf_buffer + procf_buffer_pos, token, strlen(token));

 

            procf_buffer_pos += strlen(token);

 

            memcpy(procf_buffer + procf_buffer_pos, " ", 1);

 

            procf_buffer_pos++;

 

            //src ip

 

            if (a_rule->src_ip == NULL) {

 

                strcpy(token, "-");

 

            } else {

 

                ip_hl_to_str(a_rule->src_ip, token);

 

            } 

 

            printk(KERN_INFO "token: %sn", token);

 

            memcpy(procf_buffer + procf_buffer_pos, token, strlen(token));

 

            procf_buffer_pos += strlen(token);

 

            memcpy(procf_buffer + procf_buffer_pos, " ", 1);

 

            procf_buffer_pos++;

 

            //src netmask

 

            if (a_rule->src_netmask==NULL) {

 

                strcpy(token, "-");

 

            } else {

 

                ip_hl_to_str(a_rule->src_netmask, token);

 

            } 

 

            printk(KERN_INFO "token: %sn", token);

 

            memcpy(procf_buffer + procf_buffer_pos, token, strlen(token));

 

            procf_buffer_pos += strlen(token);

 

            memcpy(procf_buffer + procf_buffer_pos, " ", 1);

 

            procf_buffer_pos++;

 

            //src port

 

            if (a_rule->src_port==0) {

 

                strcpy(token, "-");

 

            } else {

 

                port_int_to_str(a_rule->src_port, token);

 

            } 

 

            printk(KERN_INFO "token: %sn", token);

 

            memcpy(procf_buffer + procf_buffer_pos, token, strlen(token));

 

            procf_buffer_pos += strlen(token);

 

            memcpy(procf_buffer + procf_buffer_pos, " ", 1);

 

            procf_buffer_pos++;

 

            //dest ip

 

            if (a_rule->dest_ip==NULL) {

 

                strcpy(token, "-");

 

            } else {

 

                ip_hl_to_str(a_rule->dest_ip, token);

 

            } 

 

            printk(KERN_INFO "token: %sn", token);

 

            memcpy(procf_buffer + procf_buffer_pos, token, strlen(token));

 

            procf_buffer_pos += strlen(token);

 

            memcpy(procf_buffer + procf_buffer_pos, " ", 1);

 

            procf_buffer_pos++;

 

            //dest netmask

 

            if (a_rule->dest_netmask==NULL) {

 

                strcpy(token, "-");

 

            } else {

 

                ip_hl_to_str(a_rule->dest_netmask, token);

 

            } 

 

            printk(KERN_INFO "token: %sn", token);

 

            memcpy(procf_buffer + procf_buffer_pos, token, strlen(token));

 

            procf_buffer_pos += strlen(token);

 

            memcpy(procf_buffer + procf_buffer_pos, " ", 1);

 

            procf_buffer_pos++;

 

            //dest port

 

            if (a_rule->dest_port==0) {

 

                strcpy(token, "-");

 

            } else {

 

                port_int_to_str(a_rule->dest_port, token);

 

            } 

 

            printk(KERN_INFO "token: %sn", token);

 

            memcpy(procf_buffer + procf_buffer_pos, token, strlen(token));

 

            procf_buffer_pos += strlen(token);

 

            memcpy(procf_buffer + procf_buffer_pos, " ", 1);

 

            procf_buffer_pos++;

 

            //protocol

 

            if (a_rule->proto==0) {

 

                strcpy(token, "ALL");

 

            } else if (a_rule->proto==1) {

 

                strcpy(token, "TCP");

 

            }  else if (a_rule->proto==2) {

 

                strcpy(token, "UDP");

 

            }

 

            printk(KERN_INFO "token: %sn", token);

 

            memcpy(procf_buffer + procf_buffer_pos, token, strlen(token));

 

            procf_buffer_pos += strlen(token);

 

            memcpy(procf_buffer + procf_buffer_pos, " ", 1);

 

            procf_buffer_pos++;

 

            //action

 

            if (a_rule->action==0) {

 

                strcpy(token, "BLOCK");

 

            } else if (a_rule->action==1) {

 

                strcpy(token, "UNBLOCK");

 

            }

 

            printk(KERN_INFO "token: %sn", token);

 

            memcpy(procf_buffer + procf_buffer_pos, token, strlen(token));

 

            procf_buffer_pos += strlen(token);

 

            memcpy(procf_buffer + procf_buffer_pos, "n", 1);

 

            procf_buffer_pos++;

 

        }

 

        //copy from procf_buffer to buffer

 

        printk(KERN_INFO "procf_buffer_pos: %ldn", procf_buffer_pos);

 

        memcpy(buffer, procf_buffer, procf_buffer_pos);

 

        ret = procf_buffer_pos;

 

    }

 

    return ret;

 

}

 

 

 

int procf_write(struct file *file, const char *buffer, unsigned long count, void *data)

 

{

 

   int i, j;

   struct mf_rule_desp *rule_desp;

 

   printk(KERN_INFO "procf_write is called.n");

 

   /*read the write content into the storage buffer*/

 

   procf_buffer_pos = 0;

 

   printk(KERN_INFO "pos: %ld; count: %ldn", procf_buffer_pos, count);

 

   if (procf_buffer_pos + count > PROCF_MAX_SIZE) {

 

       count = PROCF_MAX_SIZE-procf_buffer_pos;

 

   } 

 

   if (copy_from_user(procf_buffer+procf_buffer_pos, buffer, count)) {

 

       return -EFAULT;

 

   }

 

   if (procf_buffer[procf_buffer_pos] == 'p') {

 

       //print command

 

       return 0;

 

   } else if (procf_buffer[procf_buffer_pos] == 'd') {

 

       //delete command

 

       i = procf_buffer_pos+1; j = 0;

 

       while ((procf_buffer[i]!=' ') && (procf_buffer[i]!='n') ) {

 

           printk(KERN_INFO "delete: %dn", procf_buffer[i]-'0');

 

           j = j*10 + (procf_buffer[i]-'0');

 

           ++i;

 

       }

 

       printk(KERN_INFO "delete a rule: %dn", j);

 

       delete_a_rule(j);

 

       return count;

 

   }

 

   /*add a new policy according to content int the storage buffer*/

   rule_desp = kmalloc(sizeof(*rule_desp), GFP_KERNEL);

   if (rule_desp == NULL) {

       printk(KERN_INFO "error: cannot allocate memory for rule_despn");

       return -ENOMEM;

   }

 

   init_mf_rule_desp(rule_desp);

 

   

 

   /**fill in the content of the new policy **/

 

   /***in_out***/

 

   i = procf_buffer_pos; j = 0;

 

   if (procf_buffer[i]!=' ') {

 

       rule_desp->in_out = (unsigned char)(procf_buffer[i++] - '0');

 

   }

 

   ++i;

 

   printk(KERN_INFO "in or out: %un", rule_desp->in_out);

 

   /***src ip***/

 

   j = 0;

 

   while (procf_buffer[i]!=' ') {

 

       rule_desp->src_ip[j++] = procf_buffer[i++];

 

   }

 

   ++i;

 

   rule_desp->src_ip[j] = '';

 

   printk(KERN_INFO "src ip: %sn", rule_desp->src_ip);

 

   /***src netmask***/

 

   j = 0;

 

   while (procf_buffer[i]!=' ') {

 

       rule_desp->src_netmask[j++] = procf_buffer[i++];

 

   }

 

   ++i;

 

   rule_desp->src_netmask[j] = '';

 

   printk(KERN_INFO "src netmask: %sn", rule_desp->src_netmask);

 

   /***src port number***/

 

   j = 0;

 

   while (procf_buffer[i]!=' ') {

 

       rule_desp->src_port[j++] = procf_buffer[i++];

 

   }

 

   ++i;

 

   rule_desp->src_port[j] = '';

 

   printk(KERN_INFO "src_port: %sn", rule_desp->src_port);

 

   /***dest ip***/

 

   j = 0;

 

   while (procf_buffer[i]!=' ') {

 

       rule_desp->dest_ip[j++] = procf_buffer[i++];

 

   }