545c9f7019690413caded5f65221667a781d5a0d
[rrq/rrqmisc.git] / socket-sniff / socket-sniff.c
1 #include <arpa/inet.h>
2 #include <fcntl.h>
3 #include <linux/if_ether.h>
4 #include <linux/in.h>
5 #include <stddef.h>
6 #include <stdio.h>
7 #include <stdlib.h>
8 #include <string.h>
9 #include <sys/socket.h>
10 #include <sys/stat.h>
11 #include <sys/time.h>
12 #include <sys/types.h>
13 #include <unistd.h>
14
15 #include <hashvector.h>
16
17 // Seconds between outputs
18 static int DELAY = 5;
19
20 // Byte count fade-out between displays
21 static int FADE = 10000;
22
23 // Number of top usage to report
24 static int WORST = 20;
25
26 // Drop-out age
27 static int OLD = 600;
28
29 // Number of characters for text format IP holdings
30 #define IPBUFMAX 40
31
32 // Count record for IP -> length mapping
33 typedef struct _Count {
34     struct _Count *next; // Next Count in time order
35     struct _Count *prev; // Previous Count in time order
36     struct timeval when; // Last update time for this Count record
37     int ignore; // Flag to leave out from reports
38     int last;   // The saved accumulation from the last report period
39     int accum;  // Current accumulation
40     int total;  // The total accumulation (reduced by fading)
41     char ip[ IPBUFMAX ]; // The IP concerned, in ascii
42 } Count;
43
44 // Print message and exit
45 static void die(char *m) {
46     fprintf( stderr, "%s\n", m );
47     exit( 1 );
48 }
49
50 // Return pointer to the key for an item
51 static void *Countp_itemkey(void *item) {
52     return ((Count*) item)->ip;
53 }
54
55 // Return 1 if the item has the key, or 0 otherwise.
56 static int Countp_haskey(void *item,void *key) {
57     return memcmp( key, Countp_itemkey( item ), IPBUFMAX ) == 0;
58 }
59
60 // Returns the hashcode for a key
61 static unsigned long Countp_hashcode(void *key) {
62     return hashvector_hashcode( key, IPBUFMAX );
63 }
64
65 // The hashvector of seen IP
66 static hashvector TBL = {
67     .table = { VECTOR_SLOTS, 0 },
68     .fill = 0,
69     .holes = 0,
70     .keyhashcode = Countp_hashcode,
71     .itemkey = Countp_itemkey,
72     .haskey = Countp_haskey
73 };
74
75 // The Count records in time order
76 static struct {
77     Count *head;
78     Count *tail;
79 } trail;
80
81 // Temporary buffer for IP addresses in ascii
82 static char buffer[ IPBUFMAX ];
83
84 /*============================================================
85  * Reading ignore lines.
86  */
87 // Return pointer to the key for an item
88 static void *charp_itemkey(void *item) {
89     return item;
90 }
91
92 // Return 1 if the item has the key, or 0 otherwise.
93 static int charp_haskey(void *item,void *key) {
94     return strcmp( key, item ) == 0;
95 }
96
97 // Returns the hashcode for a key
98 static unsigned long charp_hashcode(void *key) {
99     return hashvector_hashcode( key, strlen( (const char *) key ) );
100 }
101
102 static hashvector IGN = {
103     .table = { 256, 0 },
104     .fill = 0,
105     .holes = 0,
106     .keyhashcode = charp_hashcode,
107     .itemkey = charp_itemkey,
108     .haskey = charp_haskey
109 };
110
111 static void read_ignore_file(char *filename) {
112     #define RDBLKSZ 1000000
113     static char block[ RDBLKSZ ];
114     static char *cur = block;
115     static char *end = block;
116     int fd = open( filename, O_RDONLY );
117     if ( fd < 0 ) {
118         die( "Cannot load the ignare file." );
119     }
120     for ( ;; ) {
121         char *p = cur;
122         size_t n;
123         for ( ;; ) { // move p to end of line
124             while ( p < end && *p != '\n' ) {
125                 p++;
126             }
127             if ( p < end ) {
128                 break;
129             }
130             if ( cur != block && cur != end ) {
131                 memmove( cur, block, end - cur );
132                 end -= cur - block;
133                 cur = block;
134                 p = end;
135             }
136             n = RDBLKSZ - ( end - cur );
137             n = read( fd, end, n );
138             if ( n <= 0 ) {
139                 return; // All done
140             }
141             end += n;
142         }
143         // Line from cur to '\n' at p < end.
144         char *ip = calloc( 1, p - cur + 1 );
145         memcpy( ip, cur, p - cur );
146         cur = p + 1;
147         hashvector_add( &IGN, ip );
148     }
149 }
150
151 /*============================================================*/
152
153 static int Countp_compare(const void *ax, const void *bx) {
154     Count *a = (Count*) ax;
155     Count *b = (Count*) bx;
156     if ( b->ignore ) {
157         return 1;
158     }
159     if ( a->ignore ) {
160         return -1;
161     }
162     int x = a->total - b->total;
163     if ( x ) {
164         return x;
165     }
166     return a->last - b->last;
167 }
168
169 static int Countp_fade_and_print(unsigned long index,void *x,void *d) {
170     if ( x ) {
171         Count *item = (Count *) x;
172         item->last = item->accum;
173         item->total += item->last - FADE;
174         item->accum = 0;
175         if ( item->total <= 0 ) {
176             item->total = 0;
177         } else if ( index < WORST && item->ignore == 0 ) {
178             fprintf( stdout, "... %s %d %d\n",
179                      item->ip, item->total, item->last );
180         }
181     }
182     return 0;
183 }
184
185 static int Countp_reclaim(vector *pv,unsigned long ix,void *item,void *data) {
186     return 0;
187 }
188
189
190 // ip points to [ IPBUFMAX ] of ip address in text
191 static void add_show_table(char *ip,size_t length) {
192     static time_t show = 0;
193     Count *item;
194     int i = hashvector_find( &TBL, ip, (void**) &item );
195     struct timeval now;
196     if ( gettimeofday( &now, 0 ) ) {
197         perror( "gettimeofday" );
198         exit( 1 );
199     }
200     if ( i == 0 ) {
201         item = (Count *) calloc( 1, sizeof( Count ) );
202         memcpy( item->ip, ip, strlen( ip ) );
203         hashvector_add( &TBL, item );
204         item->ignore = hashvector_find( &IGN, ip, 0 );
205         for ( i = strlen( ip )-1; i > 1; i-- ) {
206             if ( ip[i] == '.' || ip[i] == ':' ) {
207                 item->ignore |= hashvector_find( &IGN, ip, 0 );
208             }
209             ip[i] = 0;
210         }
211         fprintf( stdout, "add %s\n", item->ip );
212     } else {
213     // Unlink item from the trail
214         if ( item->next ) {
215             item->next->prev = item->prev;
216         }
217         if ( item->prev ) {
218             item->prev->next = item->next;
219         }
220         if ( trail.head == item ) {
221             trail.head = item->next;
222         }
223         if ( trail.tail == item ) {
224             trail.tail = item->prev;
225         }
226         item->prev = item->next = 0;
227     }
228     item->accum += length;
229     item->when = now;
230     // Link in item to the trail end
231     if ( trail.head == 0 ) {
232         trail.head = item;
233     } else {
234         trail.tail->next = item;
235         item->prev = trail.tail;
236     }
237     trail.tail = item;
238     // Drop counters older than an hour
239     while ( trail.head->when.tv_sec + OLD < item->when.tv_sec ) {
240         Count *old = trail.head;
241         trail.head = old->next;
242         if ( trail.head ) {
243             trail.head->prev = 0;
244         }
245         fprintf( stdout, "drop %s\n", old->ip );
246         hashvector_delete( &TBL, old );
247         free( old );
248     }
249     if ( now.tv_sec < show ) {
250         return;
251     }
252     if ( now.tv_sec - show > DELAY ) {
253         show = now.tv_sec;
254     }
255     show += DELAY; // Time for next output
256     vector ordered = { 0, 0 };
257     hashvector_contents( &TBL, &ordered );
258     vector_qsort( &ordered, Countp_compare );
259     vector_iterate( &ordered, Countp_fade_and_print, 0 );
260     vector_resize( &ordered, 0, Countp_reclaim, 0 );
261     fprintf( stdout, "==%ld/%ld/%ld\n", TBL.fill, TBL.holes, TBL.table.size );
262 }
263
264 static char *ipv4_address(char *b) {
265     memset( buffer, 0, sizeof( buffer ) );
266     sprintf( buffer, "%hhu.%hhu.%hhu.%hhu", b[0], b[1], b[2], b[3] );
267     return buffer;
268 }
269
270 static char *ipv6_address(short *b) {
271     memset( buffer, 0, sizeof( buffer ) );
272     sprintf( buffer, "%x:%x:%x:%x:%x:%x:%x:%x",
273              ntohs(b[0]), ntohs(b[1]), ntohs(b[2]), ntohs(b[3]),
274              ntohs(b[4]), ntohs(b[5]), ntohs(b[6]), ntohs(b[7]) );
275     return buffer;
276 }
277
278 int main(int argc,char **argv) {
279     static char packet[ 2048 ];
280     int ARG = 1;
281     // Check for -fN to set FADE
282     if ( ARG < argc && strncmp( argv[ ARG ], "-d", 2 ) == 0 ) {
283         if ( sscanf( argv[ ARG ]+2, "%d", &DELAY ) != 1 ) {
284             die( "Missing/bad delay value" );
285         }
286         fprintf( stdout, "Delay is %d seconds between reports\n", DELAY );
287         ARG++;
288     }
289     if ( ARG < argc && strncmp( argv[ ARG ], "-f", 2 ) == 0 ) {
290         if ( sscanf( argv[ ARG ]+2, "%d", &FADE ) != 1 ) {
291             die( "Missing/bad fade value" );
292         }
293         fprintf( stdout, "Fading %d bytes before reports\n", FADE );
294         ARG++;
295     }
296     if ( ARG < argc && strncmp( argv[ ARG ], "-n", 2 ) == 0 ) {
297         if ( sscanf( argv[ ARG ]+2, "%d", &WORST ) != 1 ) {
298             die( "Missing/bad number to display" );
299         }
300         fprintf( stdout, "Displaying at most %d lines in reports\n", WORST );
301         ARG++;
302     }
303     if ( ARG < argc && strncmp( argv[ ARG ], "-a", 2 ) == 0 ) {
304         if ( sscanf( argv[ ARG ]+2, "%d", &OLD ) != 1 ) {
305             die( "Missing/bad drop-out age (seconds)" );
306         }
307         fprintf( stdout, "Displaying at most %d lines in reports\n", WORST );
308         ARG++;
309     }
310     if ( ARG < argc && strncmp( argv[ ARG ], "-i", 2 ) == 0 ) {
311         char *filename = argv[ ARG ] + 2;
312         if ( (*filename) == 0 ) {
313             die( "Missing/bad ignore filename" );
314         }
315         read_ignore_file( filename );
316         fprintf( stdout, "ignoring ip prefixes from %s\n", filename );
317         ARG++;
318     }
319     if ( ARG >= argc ) {
320         die( "Please tell which interface to sniff" );
321     }
322     setbuf( stdout, 0 );
323     int N;
324     char *iface = argv[ ARG ];
325     int fd = socket( AF_PACKET, SOCK_RAW, htons( ETH_P_ALL ) );
326     char flags[4] = { 1,0,0,0 };
327     if ( fd < 0 ) {
328         perror( "what?" );
329         die( "socket" );
330     }
331     if ( fd < 0 ) {
332         die( "socket" );
333     }
334     N = strlen(iface);
335     if ( setsockopt( fd, SOL_SOCKET, SO_BINDTODEVICE, iface, N ) ) {
336         die( "setsockopt bind to device" );
337     }
338     if ( setsockopt( fd, SOL_SOCKET, SO_BROADCAST, &flags, 4 ) ) {
339         die( "setsockopt broadcast" );
340     }
341     while ( ( N = read( fd, packet, 2048 ) ) > 0 ) {
342         if ( N < 54 ) {
343             continue;
344         }
345         int code = ntohs( *((short*)(packet+12)) );
346         if ( code == 0x0800 ) {
347             // 14+12=src  14+16=dst
348             char *p = ipv4_address( packet+30 );
349             if ( ( strncmp( p, "127.", 4 ) != 0 ) ) {
350                 add_show_table( p, N );
351             }
352         } else if ( code == 0x86dd ) {
353             // 14+8=src 14+24=dst
354             char *p = ipv6_address( (short*)(packet+38) );
355             if ( ( strncmp( p, "ff02:0:0:0:0:", 13 ) != 0 ) &&
356                  ( strncmp( p, "0:0:0:0:0:0:0:1", 15 ) != 0 ) ) {
357                 add_show_table( p, N );
358             }
359         } else if ( code == 0x8100 ) {
360             // ignore VLAN
361         } else {
362             // funny code
363         }
364     }
365     return 0;
366 }