update
[rrq/nfblocker.git] / src / nfblocker.c
1 #include <linux/types.h>
2 #include <netinet/in.h>
3 #include <netinet/ip.h>
4 #include <netinet/tcp.h>
5 #include <stdio.h>
6 #include <stdlib.h>
7 #include <string.h>
8 #include <unistd.h>
9 #include <linux/netfilter.h>            /* for NF_ACCEPT */
10
11 #include <libnetfilter_queue/libnetfilter_queue.h>
12
13 // Caching of verdicts
14 unsigned int lookup_cache(unsigned char *domain);
15 void add_cache(unsigned char *domain,unsigned int ix);
16 int hash_code(unsigned char *domain);
17
18 // BAD domains database
19 unsigned int check_domain(unsigned char *domain);
20 void load_domains(char *file);
21 void start_domain_database_loading(void);
22 void end_domain_database_loading(void);
23
24 /**
25  * Return packet id, or 0 on error.
26  */
27 static u_int32_t get_packet_id(struct nfq_data *tb) {
28     struct nfqnl_msg_packet_hdr *ph = nfq_get_msg_packet_hdr( tb );
29     return ( ph )? ntohl( ph->packet_id ) : 0;
30 }
31
32 // Payload headers
33 struct headers {
34     struct ip first;
35     struct tcphdr second;
36     //unsigned char pad[12]; // ??
37 };
38
39 ///////// Debugging
40
41 static unsigned char *tell_ip(u_int32_t ip) {
42     static unsigned char THEIP[20];
43     unsigned char *b = (unsigned char *)&ip;
44     sprintf( (char*) THEIP, "%d.%d.%d.%d%c", b[0], b[1], b[2], b[3], 0 );
45     return THEIP;
46 }
47
48 static u_int32_t get_dest_ip4(unsigned char *data) {
49     struct headers *p = (struct headers *) data;
50     return p->first.ip_dst.s_addr;
51 }
52
53 /**
54  * Review payload packet payload
55  */
56 static void view_payload(unsigned char *data,int length) {
57     u_int32_t ip4 = get_dest_ip4( data );
58     u_int16_t port = ntohs( ((struct headers *) data )->second.th_dport );
59     u_int8_t syn = sizeof( struct headers );
60     unsigned char *body = data ;//+ sizeof( struct headers );
61 #define END 400
62     unsigned char * end = body + ( ( length > END )? END : length );
63     fprintf( stderr, "%s %d %d %d ", tell_ip( ip4 ), syn, port, length );
64     while ( body < end ) {
65         unsigned char c = *body++;
66         if ( c < ' ' || c >= 127 || 1 ) {
67             fprintf( stderr, "%02x ", c );
68         } else {
69             fprintf( stderr, "%c", c );
70         }
71     }
72     fprintf( stderr, "\n" );
73 }
74
75 //////////////////
76 static unsigned char buffer[1000];
77
78 /**
79  * SSL traffic includes a data packet with a clear text host name.
80  * This is knwon as the SNI extension.
81  */
82 static unsigned char *ssl_host(unsigned char *data,int length) {
83     // Check that it's a "Client Hello" message
84     unsigned char *p = data + sizeof( struct headers ) + 12;
85     if ( p[0] != 0x16 || p[1] != 0x03 || p[5] != 0x01 || p[6] != 0x00 ) {
86         return 0;
87     }
88     //fprintf( stderr, "Client Hello\n" );
89     // Note minor version p[2] is not checked
90     // record_length = 256 * p[3] + p[4]
91     // handshake_message_length = 256 * p[7] + p[8]
92     if ( p[9] != 0x03 || p[10] != 0x03 ) { // TLS 1.2 (?ralph?)
93         return 0;
94     }
95     //fprintf( stderr, "TLS 1.2\n" );
96     unsigned int i = 46 + ( 256 * p[44] ) + p[45];
97     i += p[i] + 1;
98     unsigned int extensions_length = ( 256 * p[i] ) + p[i+1];
99     i += 2;
100     int k = 0;
101     //fprintf( stderr, "TLS 1.2 %d %d\n", i, extensions_length );
102     while ( k < extensions_length ) {
103         unsigned int type = ( 256 * p[i+k] ) + p[i+k+1];
104         k += 2;
105         unsigned int length = ( 256 * p[i+k] ) + p[i+k+1];
106         k += 2;
107         //fprintf( stderr, "Extension %d %d\n", k-4, type );
108         if ( type == 0 ) { // Server Name
109             if ( p[i+k+2] ) {
110                 break; // Name badness
111             }
112             unsigned int name_length = ( 256 * p[i+k+3] ) + p[i+k+4];
113             unsigned char *path = &p[i+k+5];
114             memcpy( buffer, path, name_length );
115             buffer[ name_length ] = '\0';
116             return buffer;
117         }
118         k += length;
119     }
120     // This point is only reached on "missing or bad SNI".
121     view_payload( data, length );
122     return 0;
123 }
124
125 /**
126  * HTTP traffic includes a data packet with the host name as a
127  * "Host:" attribute.
128  */
129 static unsigned char *http_host(unsigned char *data,int length) {
130     unsigned char *body = data + sizeof( struct headers );
131     if ( ( strncmp( (char*) body, "GET ", 4 ) != 0 ) &&
132          ( strncmp( (char*) body, "POST ", 5 ) != 0 ) ) {
133         return 0;
134     }
135     unsigned char *end = data + length - 6;
136     int check = 0;
137     for ( ; body < end; body++ ) {
138         if ( check ) {
139             if ( strncmp( (char*) body, "Host:", 5 ) == 0 ) {
140                 body += 5;
141                 for( ; body < end; body++ ) if ( *body != ' ' ) break;
142                 unsigned char *start = body;
143                 int n = 0;
144                 for( ; body < end; n++, body++ ) if ( *body <= ' ' ) break;
145                 if ( n < 5 ) {
146                     return 0;
147                 }
148                 memcpy( buffer, start, n );
149                 buffer[ n ] = '\0';
150                 return buffer;
151             }
152             if ( strncmp( (char*) body, "\r\n", 2 ) == 0 ) {
153                 return 0;
154             }
155             for( ; body < end; body++ ) if ( *body == '\n' ) break;
156             if ( body >= end ) {
157                 return 0;
158             }
159         }
160         check = ( *body == '\n' );
161     }
162     return 0;
163 }
164
165 /**
166  * Callback function to handle a packet.
167  */
168 static int cb(
169     struct nfq_q_handle *qh,
170     struct nfgenmsg *nfmsg,
171     struct nfq_data *nfa, void *code )
172 {
173     u_int32_t id = get_packet_id( nfa );
174     unsigned char *data;
175     int length = nfq_get_payload( nfa, &data);
176     int verdict = NF_ACCEPT;
177     u_int32_t ip4 = get_dest_ip4( data );
178 #if 0
179     fprintf( stderr, "PKT %s %d\n", tell_ip( ip4 ), length );
180 #endif
181     if ( length >= 100 ) {
182         unsigned char *host = http_host( data, length );
183 #if 1
184             fprintf( stderr, "HTTP HOST %s %s\n", tell_ip( ip4 ), host );
185 #endif
186         if ( host == 0 ) {
187             host = ssl_host( data, length );
188 #if 1
189             fprintf( stderr, "SSL HOST %s %s\n", tell_ip( ip4 ), host );
190 #endif
191         }
192         if ( host ) {
193             int i = lookup_cache( host );
194             if ( i < 0 ) {
195                 unsigned int ix = check_domain( host );
196                 add_cache( host, ix );
197 #if 1
198                 fprintf( stderr, "%s %d %s ** %d\n",
199                          tell_ip( ip4 ), hash_code( host ), host, ix );
200 #endif
201                 if ( ix > 0 ) {
202                     verdict = NF_DROP;
203                 }
204             } else if ( i > 0 ) {
205                 verdict = NF_DROP;
206             }
207         }
208     }
209     return nfq_set_verdict(qh, id, verdict, 0, NULL);
210 }
211
212 /**
213  * Program main function.
214  */
215 int main(int argc, char **argv) {
216     // Load the database
217     start_domain_database_loading();
218     int n = 1;
219     for ( ; n < argc; n++ ) {
220         fprintf( stderr, "Loading blacklist %s\n", argv[ n ] );
221         load_domains( argv[ n ] );
222     }
223     end_domain_database_loading();
224     
225     struct nfq_handle *h;
226     struct nfq_q_handle *qh;
227     //struct nfnl_handle *nh;
228     int fd;
229     int rv;
230     char buf[4096] __attribute__ ((aligned));
231     
232     fprintf( stderr, "opening library handle\n");
233     h = nfq_open();
234     if ( !h ) {
235         fprintf(stderr, "error during nfq_open()\n");
236         exit(1);
237     }
238     
239     fprintf( stderr, "unbinding any existing nf_queue handler\n" );
240     if ( nfq_unbind_pf(h, AF_INET) < 0 ) {
241         fprintf(stderr, "error during nfq_unbind_pf()\n");
242         exit(1);
243     }
244     
245     fprintf( stderr, "binding nfnetlink_queue as nf_queue handler\n" );
246     if ( nfq_bind_pf(h, AF_INET) < 0 ) {
247         fprintf(stderr, "error during nfq_bind_pf()\n");
248         exit(1);
249     }
250
251 #define THEQUEUE 99
252     fprintf( stderr, "binding this socket to queue '%d'\n", THEQUEUE );
253     qh = nfq_create_queue( h,  THEQUEUE, &cb, NULL );
254     if ( !qh ) {
255         fprintf(stderr, "error during nfq_create_queue()\n");
256         exit(1);
257     }
258     
259     fprintf( stderr, "setting copy_packet mode\n" );
260     if ( nfq_set_mode(qh, NFQNL_COPY_PACKET, 0xffff ) < 0) {
261         fprintf(stderr, "can't set packet_copy mode\n");
262         exit(1);
263     }
264     
265     fd = nfq_fd( h );
266     
267     while ( ( rv = recv(fd, buf, sizeof(buf), 0) ) && rv >= 0 ) {
268         //printf( "pkt received\n" );
269         nfq_handle_packet(h, buf, rv);
270     }
271     
272     fprintf( stderr, "unbinding from queue %d\n", THEQUEUE);
273     nfq_destroy_queue(qh);
274     
275 #ifdef INSANE
276     /* normally, applications SHOULD NOT issue this command, since it
277        detaches other programs/sockets from AF_INET, too ! */
278     fprintf( stderr, "unbinding from AF_INET\n");
279     nfq_unbind_pf(h, AF_INET);
280 #endif
281     
282     fprintf( stderr, "closing library handle\n");
283     nfq_close( h );
284     
285     exit( 0 );
286 }