/*
wrapper.c

Steve S

This is intended to be an inetd wrapper, so you can allow/deny connections
from certain IP addresses or domains.
Use it by prefixing the line in /etc/inetd.conf, thus:

OLD:
telnet	stream	tcp	nowait	NOLUID	/etc/telnetd	telnetd
NEW:
telnet  stream  tcp     nowait  NOLUID  /usr/local/etc/wrapper wrapper /etc/telnetd telnetd

That is, the wrapper's own path and name are inserted in front of the
original program path and name.

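The called program will have two new environment variables --
SRC_HOST (the client's reverse-DNS hostname, empty if the lookup fails) and
SRC_IP (the client's dotted-quad address).
Note that telnetd seems to rewrite its environment anyway, so these are lost
for telnet.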

The program looks in DEFAULT_PATH (currently /usr/local/etc/rules/) for
a file named after its SECOND argument (argv[0] of the wrapped program).  In
the above example, this would be 'telnetd'.  The -f option names a different
rules file instead, and -d prints debugging output.  The rules file holds
lines of the format:

+full.host.name
+*.domain.name
+1.1.1.1   
+1.1
+1.1.0.0
-full.host.name
-*.domain.name
-1.1.1.1   
-1.1
-1.1.0.0

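Lines prefixed with a + are 'allow' lines and lines prefixed with a - are
'deny' lines; blank lines and lines starting with # are ignored.  A
connection is refused if it matches any deny line, or if there is at least
one allow line and it matches none of them.  If there are no allow lines,
the default is to allow ALL connections that are not denied.  The order of
the lines does not matter: a matching deny line always wins.  If you
specify a host by IP address, you can specify a subnet by using trailing
0's or by omitting the 0's altogether.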
*/

#define SCO
#include <stdio.h>
#ifdef AIX
#define _BSD 43
#include <sys/time.h>
#endif
#include <sys/select.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <string.h>
#include <stdlib.h>
#include <unistd.h>
#include <ctype.h>
#include <signal.h>
#include <netdb.h>
#include <errno.h>
#include <sys/param.h>
#include <sys/resource.h>

/* Typed null-pointer shorthands, presumably meant for select()/accept()
   style calls.  NOTE(review): none of NONE, NEVER or IGNORE is referenced
   anywhere in the visible source -- candidates for removal. */
#define NONE (fd_set *) NULL
#define NEVER (struct timeval *) NULL
#define IGNORE (struct sockaddr *) NULL

/* Directory holding per-service rule files.  When -f is not given, main()
   appends the wrapped program's argv[0] to this to form the file name, so
   the trailing '/' is required. */
#define DEFAULT_PATH "/usr/local/etc/rules/"

/*
 * One access rule read from the rules file.  Rules are chained into a
 * singly linked list through r_next.
 *
 * r_type combines an action bit (T_ACCEPT / T_REJECT) with a kind bit
 * (T_IP / T_HOST), giving the four values IP_ACCEPT .. HOST_REJECT.
 * The kind decides which member of r_site is valid.
 */
struct rule {
	int r_type;
	#define T_ACCEPT 0
	#define T_REJECT 1
	#define T_IP     0
	#define T_HOST   2
	#define IP_ACCEPT   0
	#define IP_REJECT   1
	#define HOST_ACCEPT 2
	#define HOST_REJECT 3
	union { 
		/* IP rule: four octets in dotted-quad order.  Unsigned so that
		   octets above 127 compare correctly with the peer's bytes. */
		unsigned char ip[4];
		/* Only reachable through r_addr.  Deliberately not called
		   s_addr: some <netinet/in.h> implementations #define s_addr
		   as a macro, which would break this declaration. */
		long addr_l;
		/* Hostname rule, NUL-terminated.  Sized for a full DNS name
		   (at most 253 characters) so long patterns are not truncated,
		   which would silently change what they match. */
		char hostname[256];
	} r_site;
	#define r_ip r_site.ip
	#define r_addr r_site.addr_l
	#define r_hostname r_site.hostname
	struct rule * r_next;
};

/* Process-wide state shared by read_file(), the matchers and main(). */
int debug;			/* verbosity; incremented once per -d option */
struct rule *r_head = NULL;	/* head of the rule list built by read_file() */
struct sockaddr_in sin;		/* client address, filled in by getpeername() */
struct hostent *shes = NULL;	/* gethostbyaddr() result for sin; NULL when
				   the reverse lookup failed */

/*
 * read_file - load access rules from file `f` and push them onto r_head.
 *
 * Blank lines, lines starting with '#', and any line that does not start
 * with '+' (allow) or '-' (deny) are ignored.  If the character after the
 * sign is a digit, the rest of the line is parsed as a dotted-quad IP
 * address whose missing trailing octets become 0.  Otherwise the rest is
 * kept as a hostname pattern (trailing newline included; host_compare()
 * strips trailing whitespace).  Each rule is pushed on the front of the
 * list, so r_head ends up in reverse file order.
 *
 * Exits the process with status 1 if the file cannot be opened or memory
 * runs out.  Returns the number of rules added.
 */
int read_file( const char *f )
{
	FILE *fp;
	char buf[128];
	struct rule *srs;
	int a, b, c, d, n, ch;
	int ruletype;
	int count = 0;
	size_t len;

	if(debug) printf("Reading file %s\n",f);

	fp = fopen(f,"r");
	if(!fp) {	
		printf("Unable to read configuration file.  Aborting.\n");
		exit(1);
	}

	while(fgets(buf,sizeof(buf),fp)) {
		/* A line longer than buf arrives in pieces.  Discard the rest of
		   it so a stray '+' or '-' in the tail is not read as a rule. */
		len = strlen(buf);
		if(len > 0 && buf[len-1] != '\n')
			while((ch = getc(fp)) != EOF && ch != '\n')
				;

		if((buf[0]=='\n')||(buf[0]=='#')) continue;
		if(buf[0]=='+')
			ruletype = T_ACCEPT;
		else if(buf[0]=='-')
			ruletype = T_REJECT;
		else
			continue;	/* unknown prefix: never guess a rule type */

		srs = calloc(1, sizeof *srs);
		if(!srs) {
			printf("Out of memory reading configuration file.  Aborting.\n");
			exit(1);
		}

		if(isdigit((unsigned char)buf[1])) {
			ruletype |= T_IP;
			/* n >= 1 here because buf[1] is a digit */
			n=sscanf(buf+1,"%d.%d.%d.%d",&a,&b,&c,&d);
			if(n<4) d=0;
			if(n<3) c=0;
			if(n<2) b=0;
			srs->r_ip[0]=(unsigned char)a;
			srs->r_ip[1]=(unsigned char)b;
			srs->r_ip[2]=(unsigned char)c;
			srs->r_ip[3]=(unsigned char)d;
		} else {
			ruletype |= T_HOST;
			/* snprintf always NUL-terminates, unlike strncpy */
			snprintf(srs->r_hostname,sizeof(srs->r_hostname),"%s",buf+1);
		}
		srs->r_type = ruletype;
		srs->r_next = r_head;
		r_head = srs;
		count++;
	}

	/* close so the descriptor is not inherited by the exec'd service */
	fclose(fp);
	return count;
}

int ip_compare( unsigned char addr[4] )
/* return 1 for match */
{
	int i;
	union {
		u_long l;
		unsigned char c[4];
	} u;

	u.l = sin.sin_addr.s_addr;

	if(debug) {
		printf("comparing ");
		for(i=0;i<4;i++)printf("[%d]",(int)addr[i]);
		printf(" to ");
		for(i=0;i<4;i++)printf("[%d]",(int)u.c[i]);
		printf("\n");
	}
	for(i=0;i<4;i++) {
		if(addr[i]==0) break;
		if(addr[i] != u.c[i] ) return 0;
	}
	return 1;
}

int host_compare( char *host )
/* return 1 for match */
{
	char *p, *q;
	int l,m;

	if(debug) {
		printf( "comparing [%s] to [%s]\n", host,shes->h_name);
	}

	for(l = strlen(host);l && ((host[l-1]==' ')||(host[l-1]=='\n'));l--);
	if(!l) return 1;
	p = host+l-1;
	for(m = strlen(shes->h_name);m && ((shes->h_name[m-1]==' ')||
		(shes->h_name[m-1]=='\n'));m--);
	if(!m) return 1;
	q = shes->h_name+m-1;
	
	while(l && m) {
		if(*p == '*') break;
		if(*p != *q ) return 0;
		p--; q--;
		l--; m--;
	}
	return 1;
}

/*
 * check_rules - decide whether the current peer must be refused.
 *
 * Every rule on r_head is evaluated against the peer.  A matching deny
 * rule refuses the peer immediately.  Otherwise the peer is refused only
 * when at least one allow rule exists and none of them matched.  The
 * position of a rule in the list therefore does not change the outcome.
 *
 * Returns 1 to refuse the connection, 0 to let it through.
 */
int check_rules()
{
	struct rule *rp;
	int hit;
	int have_allow = 0;	/* at least one allow rule was seen   */
	int allow_hit = 0;	/* ...and at least one of them matched */

	if(debug) printf("Checking rules...\n");

	for(rp = r_head; rp != NULL; rp = rp->r_next) {
		int kind = rp->r_type;

		if(debug) printf("Rule type [%d]\n",kind);

		/* does this rule describe the peer? */
		if(kind == IP_ACCEPT || kind == IP_REJECT)
			hit = ip_compare(rp->r_ip);
		else if(kind == HOST_ACCEPT || kind == HOST_REJECT)
			hit = host_compare(rp->r_hostname);
		else
			hit = 0;

		if(debug) printf("Rule check yields %d\n",hit);

		/* act on the result */
		if(kind == IP_REJECT || kind == HOST_REJECT) {
			if(hit) return 1;
		} else if(kind == IP_ACCEPT || kind == HOST_ACCEPT) {
			have_allow = 1;
			if(hit) allow_hit = 1;
		}
	}

	return have_allow && !allow_hit;
}

main(int argc, char * argv[], char * envp[])
{
	char buffer[128];
	char src_host[128], src_ip[128];
	int c;
	extern char *optarg;
	extern int optind;
	char *rulefile = (char *)0;
	int len;
	extern char ** environ;

	while((c = getopt(argc,argv,"df:")) != -1) 
		switch(c) {
			case 'd':	debug++; break;
			case 'f':   rulefile = optarg; break;
		}
	
	if(optind>=argc) {
		printf("Bad format: need minimum of 2 arguents.\n");
		exit(-1);
	}

	r_head = (struct rule *)0;
	if(rulefile)  {
		read_file(rulefile);
	} else {
		sprintf(buffer,"%s%s",DEFAULT_PATH,argv[optind+1]);
		read_file(buffer);
	}
	
	/* now, work out where we're coming from */
	len = sizeof(sin);
	if( getpeername(0,&sin,&len) < 0 ) {
		printf("ERROR: unable to identify calling socket.\n");
		perror("getpeername");
		exit(1);
	}
	shes = gethostbyaddr(&sin.sin_addr,sizeof(sin.sin_addr),AF_INET);

	/* now check for accept/reject */
	if(check_rules()) {
		printf("Sorry, your site is not allowed to connect to this service.\n");
		printf("Connection REFUSED.\n");
		exit(1);
	}

	/* now set the environment */
	sprintf(src_host,"SRC_HOST=%s", shes?shes->h_name:"" );
	putenv(src_host);
	if(debug) puts(src_host);
	sprintf(src_ip,"SRC_IP=%d.%d.%d.%d", 
		(unsigned char)((char *)&(sin.sin_addr.s_addr))[0],
		(unsigned char)((char *)&(sin.sin_addr.s_addr))[1],
		(unsigned char)((char *)&(sin.sin_addr.s_addr))[2],
		(unsigned char)((char *)&(sin.sin_addr.s_addr))[3]
	);
	putenv(src_ip);
	if(debug) puts(src_ip);

	if(debug) {
		printf("Calling [%s] %s...\n",argv[optind],argv[optind+1]);
	}
	if(debug) fflush(stdout);

	/* now exec the necessary program */
	execvp(argv[optind],argv+optind+1,environ);

	/* catch errors */
	printf("Oops, we shouldnt get this far.\n");
	exit(0); /* should never be reached */
}
