update
[rrq/nfblocker.git] / src / database.c
1 #include <fcntl.h>
2 #include <stdio.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <sys/stat.h>
6 #include <sys/types.h>
7 #include <unistd.h>
8
9 /**
10  * This file implements a "database" of "bad" domains, loaded from
11  * ".acl" files of a fairly strict format; each domain to block is
12  * written on a line starting with a period, immediately followed by
13  * the domain to block, then an optional comment.
14  *
15  * The database is populated by using the call sequence:
16  * 1. start_domain_database_loading();
17  * 2. load_domains( filename ); // repeated
18  * N. end_domain_database_loading();
19  *
20  * The final call triggers a reordering of domains so as to support
21  * binary search in reverse text order, for matching domain suffixes.
22  * See the function `tail_compare` for details.
23  */
24
25 /**
26  * This is the Entry type for the "database", which basically is an
27  * array of these. The domain pointer will point at a domain name in
28  * the loaded ".acl" file, and length is the domain name length.
29  */
30 typedef struct _Entry {
31     int length;
32     unsigned char *domain;
33 } Entry;
34
35 /**
36  * This is the domain name database root structure. It holds a pointer
37  * to the array of Entry records, the fill of that array, and the
38  * allocated size for that array (no lesser than the fill, of course).
39  */
40 static struct {
41     Entry *table;
42     int fill;
43     int size;
44 } database = { 0, 0, 0 };
45
46 /**
47  * This function compares strings backwars; the last k bytes of string
48  * (a,na) versus string (b,nb). It also holds '.' as the least of
49  * characters, so as to ensure that refined/extended domain names are
50  * comparatively greater that their base domain names.
51  */
52 static int tail_compare(unsigned char *a,unsigned char *b,int k) {
53     while ( k-- > 0 ) {
54         int c = *(--a) - *(--b);
55         if ( c != 0) {
56             if ( *a == '.' ) {
57                 return -1;
58             }
59             if ( *b == '.' ) {
60                 return 1;
61             }
62             return c;
63         }
64     }
65     return 0;
66 }
67
68 /**
69  * Extend the domain name table to allow additions.
70  */
71 #define STARTSIZE 100000
72 static void grow() {
73     if ( database.table ) {
74         Entry *old = database.table;
75         int s = database.size;
76         database.size += 100000;
77         database.table = (Entry*) calloc( database.size, sizeof( Entry ) );
78         memcpy( database.table, old, s * sizeof( Entry ) );
79         free( old );
80     } else {
81         database.table = (Entry*) calloc( STARTSIZE, sizeof( Entry ) );
82         database.size = STARTSIZE;
83     }
84 }
85
86 /**
87  * Determine the index for given domain. This matches computes a tail
88  * match between the given domain and the databse domains, returning
89  * the index for the matching database entry, or (-index-1) to
90  * indicate insertion point. In lookup mode, a database entry being a
91  * tail domain part of the given domain is also considered a match.
92  */
93 static int index_domain(unsigned char *domain,int n,int lookup) {
94     int lo = 0;
95     int hi = database.fill;
96     while ( lo < hi ) {
97         int m = ( lo + hi ) / 2;
98         Entry *p = &database.table[ m ];
99         int k = p->length;
100         if ( n < k ) {
101             k = n;
102         }
103         int q = tail_compare( p->domain + p->length, domain + n, k );
104 #if 0
105         fprintf( stderr, "%s %d %d %d\n", domain, k, m, q );
106 #endif
107         if ( q == 0 ) {
108             if ( p->length < n ) {
109                 // table entry shorter => new entry after, or match on lookup
110                 if ( lookup && *(domain+n-k-1) == '.' ) {
111                     return m;
112                 }
113                 lo = m + 1;
114             } else if ( p->length > n ) {
115                 // table entry longer  => new entry before
116                 hi = m;
117             } else {
118                 // equal
119                 return m;
120             }
121         } else if ( q < 0 ) {
122             // new entry after
123             lo = m + 1;
124         } else {
125             // new entry before
126             hi = m;
127         }
128     }
129     return -lo - 1;
130 }
131
132 /**
133  * Determine the length of a "word"
134  */
135 static int wordlen(unsigned char *p) {
136     unsigned char *q = p;
137     while ( *q > ' ' ) {
138         q++;
139     }
140     return q - p;
141 }
142
143 #if 0
144 static void add_domain(char *domain) {
145     if ( database.fill >= database.size ) {
146         grow();
147     }
148     int length = wordlen( domain );
149     int i = index_domain( domain, length, 0 );
150     if ( i < 0 ) {
151         i = -i-1;
152         int tail = database.fill - i;
153         if ( tail ) {
154             memmove( &database.table[ i+1 ],
155                      &database.table[i],
156                      tail * sizeof( Entry ) );
157         }
158         database.table[ i ].domain = domain;
159         database.table[ i ].length = length;
160         database.fill++;
161     } else {
162         char *p1 = strndup( domain, length );
163         char *p2 = strndup( database.table[i].domain,
164                             database.table[i].length );
165         fprintf( stderr, "fill = %d %d %s  == %s\n",
166                  i, database.fill, p1, p2 );
167         free( p1 );
168         free( p2 );
169     }
170 }
171 #endif 
172
173 static void fast_add_domain(unsigned char *domain,int length) {
174     int fill = database.fill;
175     if ( fill >= database.size ) {
176         grow();
177     }
178     database.table[ fill ].length = length;
179     database.table[ fill ].domain = domain;
180     database.fill++;
181 }
182
183 static int table_order(Entry *a,Entry *b) {
184     int k = ( a->length < b->length )? a->length : b->length;
185     int c = tail_compare( a->domain + a->length,
186                           b->domain + b->length, k );
187     if ( c != 0 ) {
188         return c;
189     }
190     return a->length - b->length;
191 }
192
193 /**
194  * External call to check a given domain.
195  */
196 unsigned int check_domain(unsigned char *domain) {
197     int i = index_domain( domain, wordlen( domain ), 1 );
198     return ( i < 0 )? 0 : ( i + 1 );
199 }
200
201 void start_domain_database_loading(void) {
202 }
203
204 #if 0
205 static void dump_table() {
206     fprintf( stderr, "Table fill=%d size=%d\n", database.fill, database.size );
207     int i = 0;
208     for ( ; i < database.fill; i++ ) {
209         char *p = strndup( database.table[i].domain,
210                            database.table[i].length );
211         fprintf( stderr, "[%d] %d %p %s\n",
212                  i, database.table[i].length, database.table[i].domain, p );
213         free( p );
214     }
215 }
216 #endif
217
218 void end_domain_database_loading(void) {
219     qsort( database.table, database.fill, sizeof( Entry ),
220            (__compar_fn_t) table_order );
221     //dump_table();
222 }
223
224 /**
225  * Load BAD domain names from file. The file is line based where data
226  * lines consist of domain name starting with period and ending with
227  * space or newline, and other lines ignored.
228  */
229 void load_domains(char *file) {
230     struct stat info;
231     unsigned char *data;
232     //fprintf( stderr, "state(\"%s\",&info)\n", file );
233     if ( stat( file, &info ) ) {
234         perror( file );
235         exit( 1 );
236     }
237     int n = info.st_size;
238     data = (unsigned char *) malloc( n );
239     //fprintf( stderr, "open(\"%s\",)\n", file );
240     int fd = open( file, O_RDONLY );
241     if ( fd < 0 ) {
242         perror( file );
243         exit( 1 );
244     }
245     //fprintf( stderr, "Loading %s\n", file );
246     unsigned char *end = data;
247     while ( n > 0 ) {
248         int k = read( fd, end, n );
249         if ( k == 0 ) {
250             fprintf( stderr, "Premature EOF for %s\n", file );
251             exit( 1 );
252         }
253         end += k;
254         n -= k;
255     }
256     //fprintf( stderr, "processing %s %p %p\n", file, data, end );
257     unsigned char *p = data;
258 #if 0
259     int count = 0;
260 #endif
261     while( p < end ) {
262 #if 0
263         if ( ( ++count % 10000 ) == 0 ) {
264             fprintf( stderr, "%d rules\n", count );
265         }
266 #endif
267         if ( *p == '.' ) {
268             unsigned char *domain = ++p;
269             while ( *p > ' ' ) {
270                 p++;
271             }
272             fast_add_domain( domain, p - domain );
273         }
274         while ( p < end && *p != '\n' ) {
275             p++;
276         }
277         p++;
278     }
279     close( fd );
280 }