/*
 * Copyright 2006 Pascal Gloor <pascal.gloor@spale.com>
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <time.h>

#include <sys/param.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>

#include <netinet/in.h>
#include <arpa/inet.h>

#include "flow5.h"
#include "config.h"
#include "socket.h"
#include "debug.h"

/*
 * Size in bytes of the receive buffer and of each per-peer output buffer:
 * 64 KiB, large enough to hold any IPv4 UDP datagram payload.
 */
#define BUFLEN (64 * 1024)

void      usage        (char *prg);
int       getpeerid    (NFS_conf *conf, struct in_addr *src, uint16_t srcif);

/* usage */
void usage(char *prg)
{
	fprintf(stderr,"Netflow Splitter v1.0 by Pascal Gloor.\n");
	fprintf(stderr,"usage: %s -c <config file>\n",prg);
	exit(EXIT_FAILURE);
}


/*
 * Append one NetFlow v5 record to a peer's output buffer.
 *
 * When the buffer is empty it is first seeded with a copy of the incoming
 * packet header. The record count is then reset to zero in that copy only,
 * so the received packet is left untouched. While records accumulate, the
 * buffered count is kept in host byte order. flush_peers() converts it to
 * network order just before sending.
 */
static void append_record(char *peerbuf, size_t *peerlen,
                          const struct nf5_header *header,
                          const struct nf5_record *record)
{
	struct nf5_header *pheader = (struct nf5_header*)peerbuf;

	if ( *peerlen == 0 )
	{
		memcpy(peerbuf, header, sizeof(struct nf5_header));
		pheader->count = 0;
		*peerlen = sizeof(struct nf5_header);
	}

	memcpy(peerbuf + *peerlen, record, sizeof(struct nf5_record));
	*peerlen += sizeof(struct nf5_record);
	pheader->count++;
}

/*
 * Send every non-empty peer buffer to its configured destination
 * (conf->dst[i]) over outsock, then mark the buffer empty again.
 * The buffered record count is converted to network byte order first.
 */
static void flush_peers(int outsock, NFS_conf *conf, char **pbuf, size_t *plen)
{
	int i;

	for(i=0; i<conf->peer; i++)
	{
		struct nf5_header *pheader;

		if ( plen[i] == 0 ) continue;

		pheader = (struct nf5_header*)pbuf[i];

		debug(printf("sending packet for %s:%i with %i records.",inet_ntoa(conf->dst[i].sin_addr), ntohs(conf->dst[i].sin_port), pheader->count));

		pheader->count = htons(pheader->count);

		/* best effort UDP export: a failed send only drops this batch */
		if ( sendto(outsock, pbuf[i], plen[i], 0, (struct sockaddr*)&conf->dst[i], sizeof(struct sockaddr_in)) == -1 )
		{
			debug(printf("sendto failed for peer %i.", i));
		}
		plen[i] = 0;
	}
}

/*
 * start
 *
 * Load the configuration given with -c and open the listening and sending
 * sockets. Unless DEBUG/PROFILING is enabled, detach into the background.
 * Then, for every received NetFlow v5 PDU, each record is copied to the
 * peer whose (source address, interface) matches the record's input
 * interface, and to the peer matching its output interface. The per-peer
 * packets are sent after each PDU.
 */
int main(int argc, char *argv[])
{
	NFS_conf *conf;
	char *prg = *argv;
	char *file = NULL;
	int sock;
	int outsock;
	int i;
	int loop = 1;
	char **pbuf;
	size_t *plen;
	#ifdef PROFILING
	int loops = 50;
	#endif

	/* parse cmd line arguments */
	while((i=getopt(argc,argv,"c:"))!=-1)
	{
		switch(i)
		{
			case 'c':
				file = optarg;
				break;
			default:
				usage(prg);
		}
	}

	/* tailing junk on cmd line */
	if ( argc != optind ) usage(prg);

	/* no config file */
	if ( file == NULL ) usage(prg);

	/* loading config file */
	debug(printf("loading config: '%s'",file));
	if ( ( conf = load_config(file) ) == NULL )
	{
		fprintf(stderr,"unable to load configuration file '%s', aborting.\n",file);
		exit(EXIT_FAILURE);
	}

	/* setup listener */
	debug(printf("setup listening socket"));
	if ( ( sock = socket_setup(&conf->sock) ) == -1 )
	{
		fprintf(stderr,"unable to setup listening socket, aborting.\n");
		exit(EXIT_FAILURE);
	}

	/* setup "sending" socket */
	debug(printf("setup sending socket"));
	if ( ( outsock = socket(PF_INET, SOCK_DGRAM, 0) ) == -1 )
	{
		perror("socket");
		exit(EXIT_FAILURE);
	}

	/*
	 * allocate buffers for each peer. calloc zero-fills, so every peer
	 * starts with an empty buffer (plen[i] == 0), and it checks the
	 * count * size multiplication for overflow.
	 */
	debug(printf("allocate buffers"));
	plen = calloc(conf->peer, sizeof(size_t));
	pbuf = calloc(conf->peer, sizeof(char*));

	/* calloc(0, ...) may legitimately return NULL, so only a NULL result
	 * with peers configured is an allocation failure */
	if ( conf->peer > 0 && ( plen == NULL || pbuf == NULL ) )
	{
		perror("calloc");
		exit(EXIT_FAILURE);
	}

	for(i=0; i<conf->peer; i++)
	{
		if ( ( pbuf[i] = calloc(1, BUFLEN) ) == NULL )
		{
			perror("calloc");
			exit(EXIT_FAILURE);
		}
	}

	/* detach if not in debug mode */
	debug(printf("not detaching in DEBUG mode"));
	#if DEBUG == 0
	#if PROFILING == 0

	switch(fork())
	{
		case -1:
			perror("fork");
			exit(EXIT_FAILURE);
		case 0:
			break;
		default:
			exit(EXIT_SUCCESS);
	}

	/* close stdin, redirect stdout/stderr to /dev/null */
	fclose(stdin);
	freopen("/dev/null", "a", stdout);
	freopen("/dev/null", "a", stderr);

	/* set session id */
	setsid();

	#endif
	#endif

	#ifdef PROFILING
	printf("not detaching in PROFILING mode\n");
	loop = loops;
	#endif

	debug(printf("entering packet loop"));
	while(loop)
	{
		char buf[BUFLEN];
		size_t off = 0;
		ssize_t len;
		struct sockaddr_in from;
		socklen_t fromlen = sizeof(struct sockaddr_in);

		#ifdef PROFILING
		printf("processing packet %i of %i.\r",loops-loop,loops);
		fflush(stdout);
		loop--;
		#endif

		if ( ( len = recvfrom(sock, buf, sizeof(buf), 0, (struct sockaddr*)&from, &fromlen) ) == -1 )
		{
			/* interrupted by a signal before any data arrived: just wait again */
			if ( errno == EINTR ) continue;
			perror("recv");
			exit(EXIT_FAILURE);
		}

		/* len is an ssize_t; it never exceeds BUFLEN, so it fits in an int for %i */
		debug(printf("packet from %s:%i, length %i",inet_ntoa(from.sin_addr), ntohs(from.sin_port), (int)len));

		/* a datagram may carry several header+records PDUs back to back */
		while( off < (size_t)len )
		{
			struct nf5_header *header = (struct nf5_header*)(buf+off);
			uint16_t rec,count;

			/* a complete header must be present before its count is read */
			if ( (size_t)len - off < sizeof(struct nf5_header) )
			{
				debug(printf("packet too small for nf5 header."));
				break;
			}
			off += sizeof(struct nf5_header);

			count = ntohs(header->count);

			/* packet size mismatch: header claims more records than received */
			if ( (size_t)count * sizeof(struct nf5_record) > (size_t)len - off )
			{
				debug(printf("packet too small for %i nf5 records.",count));
				break;
			}

			debug(printf("-> header with %i records.",count));

			for(rec = 0; rec < count; rec++)
			{
				int peerid;
				struct nf5_record *record = (struct nf5_record*) ( buf + off );

				off += sizeof(struct nf5_record);

				debug(printf("\t-> record from %i to %i",ntohs(record->input),ntohs(record->output)));

				/* peer owning the record's input interface */
				if ( ( peerid = getpeerid(conf, &from.sin_addr,ntohs(record->input)) ) != -1 )
				{
					debug(printf("\t-> record added in peer %i buffer.",peerid));
					append_record(pbuf[peerid], &plen[peerid], header, record);
				}

				/* dont log a packet twice if input==output */
				if ( record->input == record->output )
					continue;

				/* peer owning the record's output interface */
				if ( ( peerid = getpeerid(conf, &from.sin_addr,ntohs(record->output)) ) != -1 )
				{
					debug(printf("\t-> record added in peer %i buffer.",peerid));
					append_record(pbuf[peerid], &plen[peerid], header, record);
				}
			}

			flush_peers(outsock, conf, pbuf, plen);
		}
	}

	#ifdef PROFILING
	printf("\n");
	printf("Profiling done.\n");
	printf("Run 'gprof %s %s.gmon' for the profiling results.\n",prg,prg);
	#endif

	/* release per-peer buffers and sockets (reached only when loop ends) */
	for(i=0; i<conf->peer; i++)
		free(pbuf[i]);
	free(pbuf);
	free(plen);
	close(outsock);
	close(sock);

	return 0;
}

/*
 * Find the first configured peer whose exporter address equals *src and
 * whose interface index equals srcif (host byte order).
 *
 * Returns the peer's index into conf->src / conf->srcif, or -1 if no
 * configured peer matches.
 */
int getpeerid(NFS_conf *conf, struct in_addr *src, uint16_t srcif)
{
	int idx = 0;

	while ( idx < conf->peer )
	{
		const int same_addr = ( conf->src[idx].sin_addr.s_addr == src->s_addr );

		if ( same_addr && conf->srcif[idx] == srcif )
			return idx;

		idx++;
	}

	return -1;
}