#include "defaults.hpp"
#include "packet.hpp"
#include "sched.hpp"
#include "classifier.hpp"
#include "config.hpp"
#include <new>
#include <queue>
#include <stdio.h>
#include <string.h>

extern "C" {
	#include <errno.h>
	#include <unistd.h>
	#include <netinet/in.h>
	#include <arpa/inet.h>
	#include <sys/stat.h>
	#include <sys/time.h>
	#include <sys/types.h>
	#include <sys/socket.h>
	#include <syslog.h>
	#include <signal.h>
	#include <linux/version.h>
	#include "log.h"

	#ifdef WITH_IPQ
	#include <linux/netfilter.h>
	#include <libipq.h>
	#endif //WITH_IPQ
}

using namespace std;

// Options collected from the command line by parse_cmdline().
struct cmdline_args {
	// configuration file path: DEFAULT_CFGFILE unless overridden with -c
	char *cfg_file;
};

// set to 1 at the end of daemon_init(), once the process runs as a daemon
int daemon_proc = 0;
// descriptor sleep() waits on with select(); assigned from init_sock_raw()
// in main() (divert) or from the libipq handle in ipq_stuff::ipq_init()
int kernel_sock = -1;
// compared against the LL_* levels in transmit() and queue_packet()
volatile int log_level = LL_MAX;
// set by sighup_handler(); cleared in main() when the reload is started
volatile int reload_cfg = 0;
// set by sigint_handler()/sigterm_handler(); makes do_shape() and main() stop
volatile int shutdown_flag = 0;
// opaque argument handed to (*get_packet)() in do_shape()
void *generic_handle = 0;
// packet reader chosen in main() according to cfg->packet_fw
generic_packet *(*get_packet)(void *arg) = 0;

#ifdef WITH_IPQ
// Holder for the libipq handle used when cfg->packet_fw is OPT_FW_IPQ.
struct ipq_stuff {
	// handle returned by ipq_create_handle(); reset to 0 by ipq_init() on
	// failure and by ipq_release()
	struct ipq_handle *ipq_h;
	// creates the handle, sets copy mode and publishes its fd in
	// kernel_sock; returns 0 on success, -1 on error
	int ipq_init();
	// destroys the handle and resets kernel_sock; returns 0 or -1
	int ipq_release();
} ipq;

// NOTE(review): main() assigns get_ipq_packet to get_packet, whose type is
// generic_packet *(*)(void *); this prototype does not have that type --
// confirm against the declaration in the project headers.
int get_ipq_packet(shaper_config *cfg);
#endif //WITH_IPQ

#ifdef WITH_DIVERT
// protocol number passed to socket() by init_sock_raw()
#define IPPROTO_DIVERT 254

// creates a raw divert socket bound to `port`; returns the descriptor or -1
int init_sock_raw(int port);
#endif //WITH_DIVERT

// shaping loop: runs until reload_cfg/shutdown_flag is set (returns 0) or
// transmit() reports -1 (returns -1)
int do_shape(shaper_config *cfg);
// NOTE: an overload next to sleep() from <unistd.h>. Waits for the next
// event: >=0 class index to transmit, -2 kernel_sock readable, -3 interrupted
int sleep(shaper_config *cfg);
// dequeues one packet of class `class_idx` and hands it to the class scheduler
int transmit(shaper_config *cfg, int class_idx);
// classifies `packet` and queues it on the first matching class
int queue_packet(shaper_config *cfg, generic_packet *packet);
// double-fork daemonization, writes the pid to `pidfile`; 0 ok, -1 error
int daemon_init(const char *pidfile, const char *pname, int facility);
// fills `cmdl_cfg` from argv; 0 ok, -1 on a usage error
int parse_cmdline(int argc, char **argv, struct cmdline_args *cmdl_cfg);
// 0 when `fn` does not exist, -1 when it exists or cannot be stat()ed
int test_pidfile(char *fn);
// removes the pidfile, logging a failure
void unlink_pidfile(const char *fn);
// signal handlers: SIGHUP requests a reload, SIGINT/SIGTERM a shutdown
void sighup_handler(int i);
void sigint_handler(int i);
void sigterm_handler(int i);
// prints the command line synopsis
void usage(char *pname);
// prints version/build information and calls exit(0)
void print_version_and_die(void);

/*
 * Entry point of the shaper.
 *
 * Parses the command line, loads the configuration, refuses to start when
 * the pidfile already exists, opens the configured packet source (divert
 * socket and/or libipq, depending on the build options), optionally
 * daemonizes, installs the signal handlers and then runs do_shape() until a
 * shutdown is requested or do_shape() fails.  A pending reload request
 * (reload_cfg) is served between two do_shape() runs.
 *
 * Returns 0 after a normal run, -1 when start-up fails or when the divert
 * socket cannot be closed; exits with status 1 on a command line error.
 */
int main(int argc, char **argv)
{
	int result;
	int exit_status = 0;
	shaper_config *cfg;
	struct cmdline_args cmdl_cfg;

	result = parse_cmdline(argc, argv, &cmdl_cfg);
	if( result==-1 ) {
		usage(argv[0]);
		exit(1);
	}

	cfg = new(nothrow) shaper_config;
	if( cfg==0 ) {
		log_info(LL_ERR, "no memory for shaper_config!");
		goto err_new;
	}

	result = cfg->read_config(cmdl_cfg.cfg_file);
	if( result!=0 ) {
		goto err_read_config;
	}

	result = test_pidfile(cfg->get_pidfile());
	if( result!=0 ) {
		goto err_test_pidfile;
	}

	#ifdef WITH_DIVERT
	if( cfg->packet_fw==OPT_FW_DIVERT ) {
		result = kernel_sock = init_sock_raw(cfg->get_divert_port());
		if( result==-1 ) {
			goto err_fw_init;
		}
		get_packet = get_divert_packet;
		generic_handle = (void*)(kernel_sock);
	}
	#endif //WITH_DIVERT

	#ifdef WITH_IPQ
	if( cfg->packet_fw==OPT_FW_IPQ ) {
		result = ipq.ipq_init();
		if( result!=0 ) {
			goto err_fw_init;
		}
		get_packet = get_ipq_packet;
		generic_handle = (void*)(ipq.ipq_h);
	}
	#endif //WITH_IPQ

	if( cfg->daemon ) {
		result = daemon_init(cfg->get_pidfile(), "shaperd", LOG_DAEMON);
		if( result==-1 ) goto err_daemon_init;

		if( signal(SIGHUP, &sighup_handler)==SIG_ERR ) {
			log_info(LL_ALERT, "error installing SIGHUP handler "
				"{%s:%d}", __FILE__, __LINE__);
			goto err_signal;
		}
	} else {
		if( signal(SIGINT, &sigint_handler)==SIG_ERR ) {
			log_info(LL_ALERT, "error installing SIGINT handler "
				"{%s:%d}", __FILE__, __LINE__);
			goto err_signal;
		}

		if( cfg->catch_sighup==1 && 
		    signal(SIGHUP, &sighup_handler)==SIG_ERR ) {
			log_info(LL_ALERT, "error installing SIGHUP handler "
				"{%s:%d}", __FILE__, __LINE__);
			goto err_signal;
		}
	}

	if( signal(SIGTERM, &sigterm_handler)==SIG_ERR ) {
		log_info(LL_ALERT, "error installing SIGTERM handler {%s:%d}",
			__FILE__, __LINE__);
		goto err_signal;
	}

	do {
		if( reload_cfg ) {
			reload_cfg = 0;

			shaper_config *new_cfg = new(nothrow) shaper_config(cfg->packet_fw);
			if( new_cfg==0 ) {
				// out of memory: keep shaping with the current
				// configuration instead of dereferencing a null
				// pointer
				log_info(LL_ALERT, "no memory for loading "
					"new configuration {%s:%d}",
					__FILE__, __LINE__);
			} else {
				result = new_cfg->read_config(cmdl_cfg.cfg_file);
				if( result!=0 ) {
					// broken file: drop it, old config stays
					log_info(LL_ERR, "error loading config file "
						"[%s] {%s:%d}", cmdl_cfg.cfg_file, 
						__FILE__, __LINE__);
					delete new_cfg;
				} else {
					log_info(LL_ERR, "config file [%s] seems to "
						"be ok {%s:%d}", cmdl_cfg.cfg_file, 
						__FILE__, __LINE__);
					log_info(LL_INFO, "reconfiguration OK {%s:%d}", 
						__FILE__, __LINE__);
					delete cfg;
					cfg = new_cfg;
				}
			}
		}

		result = do_shape(cfg);
		if( result!=0 ) {
			log_info(LL_INFO, "do_shape() error. terminating ...");
			break;
		}

		if( shutdown_flag ) {
			log_info(LL_INFO, "cleaning up ... {%s:%d}",
				__FILE__, __LINE__);
			break;
		}
	} while( 1 );

	if( cfg->daemon ) 
		unlink_pidfile(cfg->get_pidfile());

	#ifdef WITH_DIVERT
	if( cfg->packet_fw==OPT_FW_DIVERT ) {
		log_info(LL_DEBUG1, "closing divert sock(%d) ... {%s:%d}", 
			kernel_sock, __FILE__, __LINE__);
		// the socket is closed exactly once, here
		result = close(kernel_sock);
		kernel_sock = -1;
		if( result==-1 ) {
			log_info(LL_ERR, "error closing divert socket (%s) "
				"{%s:%d}", strerror(errno), 
				__FILE__, __LINE__);
			// still release the configuration below before
			// reporting the failure
			exit_status = -1;
		}
	}
	#endif //WITH_DIVERT

	// braces keep `fw` out of the scope the error labels live in, so the
	// gotos above do not jump over an initialization
	{
		int fw = cfg->packet_fw;
		delete cfg;
		(void)fw;	// unused when built without WITH_IPQ

		#ifdef WITH_IPQ
		if( fw==OPT_FW_IPQ )
			ipq.ipq_release();
		#endif //WITH_IPQ
	}

	return exit_status;

	// ---- start-up error unwinding; each label falls through to the next
	err_signal:
		if( cfg->daemon ) 
			unlink_pidfile(cfg->get_pidfile());
	err_daemon_init:
		{
			int fw = cfg->packet_fw;
			delete cfg;
			cfg = 0;	// makes the delete below a no-op
			(void)fw;

			#ifdef WITH_IPQ
			if( fw==OPT_FW_IPQ )
				ipq.ipq_release();
			#endif //WITH_IPQ

			#ifdef WITH_DIVERT
			if( fw==OPT_FW_DIVERT ) {
				close(kernel_sock);
			}
			#endif //WITH_DIVERT
		}
	err_fw_init:
	err_test_pidfile:
	err_read_config:
		delete cfg;
	err_new:
		return -1;
}

/*
 * Shaping loop.  Repeatedly asks sleep() for the next event and either
 * transmits a packet of the class it names or reads one packet from the
 * kernel and queues it.
 *
 * Returns 0 as soon as reload_cfg or shutdown_flag is set, -1 when
 * transmit() reports -1.
 */
int do_shape(shaper_config *cfg)
{
	for(;;) {
		if( reload_cfg || shutdown_flag ) {
			log_info(LL_DEBUG1, "do_shape() exiting {%s:%d}",
				__FILE__, __LINE__);
			return 0;
		}

		const int event = sleep(cfg);

		// a non-negative event is the index of a class ready to send
		if( event>=0 ) {
			if( transmit(cfg, event)==-1 )
				return -1;
			continue;
		}

		// anything but -2 (e.g. -3, interrupted) just re-checks the flags
		if( event!=-2 )
			continue;

		// -2: the kernel socket is readable
		generic_packet *incoming = (*get_packet)(generic_handle);
		if( incoming==0 ) {
			// error
			continue;
		}

		// the result is not acted upon here
		queue_packet(cfg, incoming);
	}
}

/*
 * Wait for the next thing the shaper has to do.
 *
 * Scans all classes with a non-empty queue.  A class whose scheduler is not
 * sleeping is returned immediately.  Otherwise the shortest remaining wait
 * is used as select() timeout on kernel_sock (no timeout at all when no
 * class is waiting).
 *
 * Returns:
 *   >= 0  index of the class to transmit from (ready now, or its wait
 *         elapsed)
 *   -2    kernel_sock is readable
 *   -3    select() was interrupted by a signal
 * Calls exit(-1) on is_sleeping()/select() errors.
 */
int sleep(shaper_config *cfg)
{
	// sentinel meaning "no class is waiting"
	const int no_deadline = 1 << (sizeof(int)*8-2);
	const int usec_per_sec = 1000000;
	int i, result;
	int shortest_wait = no_deadline;
	int whichclass = -1;
	struct timeval tv, *tv_ptr;
	fd_set rfds;

	for( i = 0 ; i<cfg->n_classes ; i++ ) {
		classdef *cur_class = cfg->v_classes[i];

		if( cur_class->queue_empty() ) {
			continue;
		}

		const int dt = cur_class->my_sched.is_sleeping();
		if( dt==-1 ) {
			log_info(LL_ERR, "sched::is_sleeping() error");
			exit(-1);
		}
		if( dt==0 ) {
			// not empty, not sleeping -> ready to transmit
			return i;
		}

		if( dt<shortest_wait ) {
			shortest_wait = dt;
			whichclass = i;
		}
	}

	if( shortest_wait==no_deadline ) {
		tv_ptr = 0;
	} else {
		// The wait is used as a microsecond count.  Split it so that
		// tv_usec stays below one second: a larger tv_usec is an
		// invalid timeval and select() may reject it with EINVAL.
		tv_ptr = &tv;
		tv.tv_sec = shortest_wait / usec_per_sec;
		tv.tv_usec = shortest_wait % usec_per_sec;
	}

	// *** warning ***
	// this is a hack for netfilter's blocking api
	// but, it's o.k. for divert sockets :)
	FD_ZERO(&rfds);
	FD_SET(kernel_sock, &rfds);

	result = select(kernel_sock+1, &rfds, 0, 0, tv_ptr);
	if( result>0 ) {
		return -2;
	}
	if( result==0 ) {
		return whichclass;
	}
	if( result==-1 ) {
		if( errno==EINTR ) {
			return -3;
		}
		log_info(LL_ERR, "select() error (%s) {%s:%d}", 
			strerror(errno), __FILE__, __LINE__);
		exit(-1);
	}

	log_info(LL_ALERT, "inconsistency. terminating {%s:%d}",
		__FILE__, __LINE__);
	exit(-1);
}

/*
 * Send the next queued packet of class `class_idx`.
 *
 * The class must be awake and non-empty and its scheduler must allow
 * sending; otherwise the program is terminated with exit(-1).  A failed
 * send() is only logged.  The packet is deleted in every case.
 *
 * Always returns 0.
 */
int transmit(shaper_config *cfg, int class_idx)
{
	classdef *const cls = cfg->v_classes[class_idx];

	if( cls->my_sched.is_sleeping() || cls->queue_empty() ) {
		log_info(LL_ERR, "internal logic error: class [%s] is "
			"sleeping/empty in transmit()! {%s:%d}", 
			cls->get_name(), __FILE__, __LINE__);
		exit(-1);
	}

	generic_packet *const outgoing = cls->dequeue_packet();

	if( cls->my_sched.can_send()==0 ) {
		log_info(LL_ERR, "can't send packet, aborting program {%s:%d}",
			__FILE__, __LINE__);
		delete outgoing;
		exit(-1);
	}

	const int send_rc = cls->my_sched.send(outgoing);
	if( send_rc==-1 ) {
		char description[256];

		outgoing->log_packet(description, sizeof(description));
		if( log_level>=LL_WARN ) {
			// verbose enough: include the packet description
			log_info(LL_WARN, "error reinjecting packet -> %s "
				"{%s:%d}", description, __FILE__, __LINE__);
		} else {
			log_info(LL_ERR, "error reinjecting packet {%s:%d}",
				__FILE__, __LINE__);
		}
	}

	delete outgoing;
	return 0;
}

/*
 * Classify a packet received from the kernel and queue it on the first
 * class whose check_packet() reports a match (+1).
 *
 * A packet no class matches is logged and deleted.
 * Returns 0 on every path.
 */
int queue_packet(shaper_config *cfg, generic_packet *packet)
{
	static int seq=0;
	classdef *target = 0;
	int idx, prio;

	++seq;
	if( log_level==LL_DEBUG2 ) {
		log_info(LL_DEBUG2, "%d packets received from kerneal", seq);
	} else if( seq%100==0 ) {
		log_info(LL_DEBUG1, "%d packets received from kerneal", seq);
	}

	// first match wins; prio is filled in by the matching class
	for( idx = 0 ; idx<cfg->n_classes && target==0 ; idx++ ) {
		classdef *candidate = cfg->v_classes[idx];
		const int verdict = candidate->check_packet(packet, &prio);

		if( verdict==+1 )
			target = candidate;
	}

	if( target==0 ) {
		char description[256];

		packet->log_packet(description, sizeof(description));
		log_info(LL_WARN, "unmatched packet -> seq=%d %s", 
			seq, description);
		delete packet;
		return 0;
	}

	#ifdef WITH_DIVERT
	if( cfg->packet_fw == OPT_FW_DIVERT && target->divert_reinjection==1 ) {
		// outgoing packet
		divert_packet *div_pckt = (divert_packet*)(packet);
		div_pckt->from.sin_addr.s_addr = 0;
	}
	#endif //WITH_DIVERT

	// todo: check 4 errors
	target->queue_packet(packet, prio);

	return 0;
}

/*
 * Parse the command line into `cmdl_cfg`.
 *
 *   -c file  use `file` instead of DEFAULT_CFGFILE
 *   -h       print the usage text and exit(0)
 *   -v       print version information (print_version_and_die())
 *
 * Returns 0 on success, -1 on an unknown option, on leftover non-option
 * arguments or when no configuration file is set.
 */
int parse_cmdline(int argc, char **argv, struct cmdline_args *cmdl_cfg)
{
	int opt;

	cmdl_cfg->cfg_file = DEFAULT_CFGFILE;

	for(;;) {
		opt = getopt(argc, argv, "hvc:");
		if( opt==EOF )
			break;

		if( opt=='c' ) {
			cmdl_cfg->cfg_file = optarg;
		} else if( opt=='h' ) {
			usage(argv[0]);
			exit(0);
		} else if( opt=='v' ) {
			print_version_and_die();
		} else {
			return -1;
		}
	}

	if( optind<argc || cmdl_cfg->cfg_file==0 )
		return -1;

	return 0;
}

/*
 * Remove the pidfile `fn`.  A failing unlink() is logged, nothing is
 * reported to the caller.
 */
void unlink_pidfile(const char *fn)
{
	if( unlink(fn)==0 )
		return;

	log_info(LL_ERR, "can't unlink() pidfile (%s) {%s:%d}",
		strerror(errno), __FILE__, __LINE__);
}

/*
 * SIGHUP handler: requests a configuration reload by setting reload_cfg.
 *
 * errno is saved and restored because the handler can run between a failing
 * system call in the interrupted code and that code's errno check, and
 * log_info() may change errno.
 * NOTE(review): log_info() is called from signal context; confirm that it
 * only uses async-signal-safe functions.
 */
void sighup_handler(int i)
{
	const int saved_errno = errno;

	log_info(LL_DEBUG1, "catched SIGHUP signal (%d)", i);
	reload_cfg = 1;

	errno = saved_errno;
}

/*
 * SIGINT handler: requests a shutdown by setting shutdown_flag.
 *
 * errno is saved and restored because the handler can run between a failing
 * system call in the interrupted code and that code's errno check, and
 * log_info() may change errno.
 * NOTE(review): log_info() is called from signal context; confirm that it
 * only uses async-signal-safe functions.
 */
void sigint_handler(int i)
{
	const int saved_errno = errno;

	log_info(LL_DEBUG1, "catched SIGINT signal (%d)", i);
	shutdown_flag = 1;

	errno = saved_errno;
}

/*
 * SIGTERM handler: requests a shutdown by setting shutdown_flag.
 *
 * errno is saved and restored because the handler can run between a failing
 * system call in the interrupted code and that code's errno check, and
 * log_info() may change errno.
 * NOTE(review): log_info() is called from signal context; confirm that it
 * only uses async-signal-safe functions.
 */
void sigterm_handler(int i)
{
	const int saved_errno = errno;

	log_info(LL_DEBUG1, "catched SIGTERM signal (%d)", i);
	shutdown_flag = 1;

	errno = saved_errno;
}

/*
 * Check that no pidfile is in the way.
 *
 * Returns 0 when `fn` does not exist (stat() fails with ENOENT).
 * Returns -1, after logging, when the file exists or when stat() fails for
 * any other reason.
 */
int test_pidfile(char *fn)
{
	struct stat st;
	const int rc = stat(fn, &st);

	if( rc==0 ) {
		// the file is there: treat it as a running instance
		log_message(LL_ALERT, 
			"already running!!! (see %s)",
			fn);
		return -1;
	}

	if( rc==-1 && errno==ENOENT )
		return 0;

	log_message(LL_ALERT, 
		"can't stat() (%s) %s",
		strerror(errno), fn);
	return -1;
}

/*
 * Turn the process into a daemon.
 *
 * Forks twice (the parents exit(0)), starts a new session in between,
 * ignores SIGHUP, writes the pid of the surviving process to `pidfile`,
 * closes descriptors 0-2 and opens syslog as `pname` with `facility`.
 * Sets daemon_proc to 1 on success.
 *
 * Returns 0 on success, -1 on error (already logged).
 * NOTE(review): descriptors 0-2 are closed but not re-opened, so later
 * open()/socket() calls can be handed these numbers -- confirm nothing
 * writes to stdout/stderr afterwards.
 */
int daemon_init(const char *pidfile, const char *pname, int facility)
{
	pid_t pid;
	int result;
	int pidfile_ok;
	FILE *fp;
	char buf[1024];

	log_info(LL_DEBUG1, "daemonizing process ...");

	pid = fork();
	if( pid==-1 ) {
		log_info(LL_ALERT, "can't fork() (%s) {%s:%d}",
			strerror(errno), __FILE__, __LINE__);
		return -1;
	} else if( pid!=0 ) {
		exit(0);
	}

	result = setsid();
	if( result==-1 ) {
		log_info(LL_ALERT, "setsid() error (%s) {%s:%d}",
			strerror(errno), __FILE__, __LINE__);
		return -1;
	}

	if( signal(SIGHUP, SIG_IGN)==SIG_ERR ) {
		log_info(LL_ALERT, "signal() error (%s) {%s:%d}",
			strerror(errno), __FILE__, __LINE__);
		return -1;
	}

	pid = fork();
	if( pid==-1 ) {
		log_info(LL_ALERT, "can't fork() (%s) {%s:%d}",
			strerror(errno), __FILE__, __LINE__);
		return -1;
	} else if( pid!=0 ) {
		exit(0);
	}

	// the pidfile logic has been taken from courier-mta; the stream is
	// now closed on every path once it has been opened
	pidfile_ok = 0;
	result = snprintf(buf, sizeof(buf), "%ld\n", (long)getpid());
	if( result>=0 && result<(int)sizeof(buf) ) {
		fp = fopen(pidfile, "w");
		if( fp!=NULL ) {
			pidfile_ok = ( fprintf(fp, "%s", buf)>=0 &&
				       fflush(fp)>=0 );
			// fclose() can report a deferred write error too
			if( fclose(fp)!=0 )
				pidfile_ok = 0;
		}
	}
	if( !pidfile_ok ) {
		log_info(LL_ALERT, "error creating pidfile {%s:%d}",
			__FILE__, __LINE__);
		return -1;
	}

	close(0); close(1); close(2);
	openlog(pname, LOG_PID, facility);
	daemon_proc = 1;
	return 0;
}

/*
 * Print the version, the compiled-in packet sources, the build date and the
 * compiler/kernel identification, then exit(0).  Never returns.
 */
void print_version_and_die(void)
{
	// string literals are const: point at them with const char *
	const char *str_ipq = "with_ipq=no";
	const char *str_divert = "with_divert=no";

	#ifdef WITH_DIVERT
	str_divert = "with_divert=yes";
	#endif

	#ifdef WITH_IPQ
	str_ipq = "with_ipq=yes";
	#endif

	printf("version %s\n", VERSION_STR);
	printf("built %s, %s; %s %s\n",
		str_divert, str_ipq,  __DATE__, __TIME__);
	printf("Compiled against: %s\n", UTS_RELEASE);
	printf("GNU CC: %s\n", __VERSION__);

	exit(0);
}

/*
 * Print the command line synopsis for program `pname` on stdout.
 */
void usage(char *pname)
{
	fprintf(stdout, "usage: %s [-h] [-v] [-c file]\n", pname);
}

#ifdef WITH_IPQ
/*
 * Open the libipq handle, switch it to IPQ_COPY_PACKET mode and publish its
 * descriptor in kernel_sock.
 *
 * Returns 0 on success.  On failure logs the libipq error, leaves ipq_h at
 * 0 and returns -1.
 */
int ipq_stuff::ipq_init()
{
#ifdef IPQ_124
	ipq_h = ipq_create_handle(0);
#else
	ipq_h = ipq_create_handle(0, PF_INET);
#endif
	if( ipq_h==0 ) {
		log_info(LL_ERR, "ipq_create_handle() error (%s) {%s:%d}",
			ipq_errstr(), __FILE__, __LINE__);
		return -1;
	}

	if( ipq_set_mode(ipq_h, IPQ_COPY_PACKET, 60+60)==-1 ) {
		log_info(LL_ERR, "ipq_set_mode() error (%s) {%s:%d}",
			ipq_errstr(), __FILE__, __LINE__);
		// undo the handle creation
		ipq_destroy_handle(ipq_h);
		ipq_h = 0;
		return -1;
	}

	kernel_sock = ipq_h->fd;
	log_info(LL_DEBUG1, "using netlink socket %d {%s:%d}",
		kernel_sock, __FILE__, __LINE__);

	return 0;
}

/*
 * Destroy the libipq handle, if there is one, and reset kernel_sock to -1.
 *
 * Returns 0 when nothing had to be done or the handle was destroyed, -1
 * (after logging) when ipq_destroy_handle() failed.
 */
int ipq_stuff::ipq_release()
{
	kernel_sock = -1;

	if( !ipq_h )
		return 0;

	const int rc = ipq_destroy_handle(ipq_h);
	ipq_h = 0;
	if( rc==0 )
		return 0;

	log_info(LL_ERR, "ipq_destroy_handle() error (%s) "
		"{%s:%d}", ipq_errstr(), __FILE__, __LINE__);
	return -1;
}
#endif //WITH_IPQ

#ifdef WITH_DIVERT
/*
 * Create a raw IPPROTO_DIVERT socket and bind it to `port` on INADDR_ANY.
 *
 * Returns the socket descriptor, or -1 after logging when socket() or
 * bind() fails (the socket is closed again in the latter case).
 */
int init_sock_raw(int port)
{
	struct sockaddr_in sin;
	const int sock = socket(AF_INET, SOCK_RAW, IPPROTO_DIVERT);

	if( sock==-1 ) {
		log_info(LL_ALERT, "socket() error (%s)", strerror(errno));
		return -1;
	}

	memset(&sin, 0, sizeof(struct sockaddr_in));
	sin.sin_family = AF_INET;
	sin.sin_port = htons(port);
	sin.sin_addr.s_addr = INADDR_ANY;

	if( bind(sock, (struct sockaddr *)(&sin), sizeof(sin))==-1 ) {
		log_info(LL_ALERT, "bind() error (%s)", strerror(errno));
		close(sock);
		return -1;
	}

	return sock;
}
#endif //WITH_DIVERT

