// This program is a UDP based tunneling of stdin/out Ethernet packets.
//
// A udptun program is a bi-directional networking plug, that handles
// incoming packets on a given port, and sends outgoing packets to a
// given (remote) host and port. 

// When a key file is given, an incoming packet is decrypted and an
// outgoing packet is encrypted, both with that same pre-shared key.

#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

// Size in bytes of each of the packet buffers below; also the maximum
// number of bytes requested when receiving a UDP datagram.
#define BUFSIZE 65536 
// Holds a datagram received from the UDP socket while it is being
// written to stdout.
static unsigned char buffer[ BUFSIZE ];

// Buffer for reading stdin in fragmented way: one outgoing packet is
// collected here over as many reads as it takes.
static struct {
    // Packet data read so far.
    unsigned char buffer[ BUFSIZE ];
    // Total length of the packet being collected, taken from its 16-bit
    // length prefix; 0 means no packet is in progress.
    unsigned int end;
    // Number of packet bytes already read into buffer.
    unsigned int cur;
} input;

// Remote end: the address and port that outgoing datagrams are sent
// to, and the only source that incoming datagrams are accepted from.
static struct sockaddr_in remote;

// The actual name of this program (argv[0])
static unsigned char *progname;

// Content of the key file, allocated by loadkey; unset when no key
// file is given.
static unsigned char *key;
// Number of bytes in key.
static unsigned int key_length;
// Sum of all key bytes. A non-zero value is what switches encryption
// and decryption on; it is also mixed into the random generator seed.
static unsigned int seed;

// Simple PSK encryption:
//
// First, xor each byte with a key byte that is picked from the key
// by means of an index that includes the prior encoding. Also,
// compute the sum of encrypted bytes into a "magic" that is added to
// the "seed" for seeding the random number generator. Secondly,
// reorder the bytes by swapping each one with a position picked by
// successive random numbers from the seeded generator.
//

// Encrypt the n bytes of buf in place, using the loaded key and seed.
// Pass 1 xors every byte with a key byte and sums up the resulting
// bytes; pass 2 swaps every position with one picked by rand(), after
// seeding the generator with seed plus that sum.
static void encrypt(unsigned char *buf,unsigned int n) {
    unsigned int i;
    unsigned int idx = 0;	// running index into the key
    unsigned int sum = 0;	// sum of the bytes encrypted so far

    for ( i = 0; i < n; i++ ) {
	idx = ( idx + sum + i ) % key_length;
	buf[i] ^= key[ idx ];
	sum += buf[i];
    }

    srand( seed + sum );
    for ( i = 0; i < n; i++ ) {
	unsigned int j = rand() % n;
	unsigned char tmp = buf[i];
	buf[i] = buf[j];
	buf[j] = tmp;
    }
}

// Decrypt the n bytes of buf in place; the inverse of encrypt.
// The swaps leave the byte sum unchanged, so the same generator seed
// can be recomputed from the received data. The swap positions are
// regenerated in the original order and then undone last to first,
// after which the xor pass is reversed.
// n must not exceed BUFSIZE, the capacity of the pick table.
static void decrypt(unsigned char *buf,unsigned int n) {
    static unsigned int picks[ BUFSIZE ];
    unsigned int i;
    unsigned int idx = 0;	// running index into the key
    unsigned int sum = 0;

    for ( i = 0; i < n; i++ ) {
	sum += buf[i];
    }
    srand( seed + sum );
    for ( i = 0; i < n; i++ ) {
	picks[i] = rand() % n;
    }

    // Undo the swaps in reverse order.
    for ( i = n; i-- > 0; ) {
	unsigned int j = picks[i];
	unsigned char tmp = buf[i];
	buf[i] = buf[j];
	buf[j] = tmp;
    }

    // Reverse the xor pass; the key index is driven by the sum of the
    // still encrypted bytes, exactly as it was when encrypting.
    sum = 0;
    for ( i = 0; i < n; i++ ) {
	unsigned char enc = buf[i];
	idx = ( idx + sum + i ) % key_length;
	buf[i] = enc ^ key[ idx ];
	sum += enc;
    }
}

// Load the pre-shared key from the named file.
// On success "key" holds the file content, "key_length" its size in
// bytes, and the sum of all key bytes has been added to "seed".
// On failure a message is printed and "seed" is left unchanged, which
// is how the caller detects the failure. A key file must hold at least
// 256 bytes, and a key whose bytes add up to 0 is reported as bad,
// since a zero seed means "no encryption" to the rest of the program.
// The file descriptor is closed on every path, and the key buffer is
// released again if the file cannot be read completely.
static void loadkey(unsigned char *filename) {
    struct stat filestat;
    unsigned char *p;
    unsigned int i;
    size_t left;
    ssize_t n;
    int fd = open( (char*) filename, O_RDONLY );
    if ( fd < 0 ) {
	perror( "open key file" );
	return;
    }
    if ( fstat( fd, &filestat ) ) {
	perror( "stat of key file" );
	close( fd );
	return;
    }
    key_length = filestat.st_size;
    if ( key_length < 256 ) {
	fprintf( stderr, "Too small key file: %u\n", key_length );
	close( fd );
	return;
    }
    key = malloc( key_length );
    if ( key == NULL ) {
	fprintf( stderr, "Cannot allocate %u bytes\n", key_length );
	close( fd );
	return;
    }
    // Keep reading until the whole key is in; a read may return less
    // than asked for.
    left = key_length;
    p = key;
    while ( left > 0 && ( n = read( fd, p, left ) ) > 0 ) {
	left -= n;
	p += n;
    }
    close( fd );
    if ( left != 0 ) {
	fprintf( stderr, "Failed loading key\n" );
	free( key );
	key = NULL;
	return;
    }
    for ( i = 0; i < key_length; i++ ) {
	seed += key[ i ];
    }
    if ( seed == 0 ) {
	fprintf( stderr, "Bad key; adds up to 0\n" );
    }
}

// Set TRACEMAC to 1 to log a MAC address for each packet on stderr.
#define TRACEMAC 0
#if TRACEMAC
// Format the 6 bytes at mac as colon separated hex pairs.
// Returns a pointer to a static buffer that is overwritten by the
// next call.
static unsigned char *mactoa(unsigned char *mac) {
    static unsigned char text[20];
    int i;
    for ( i = 0; i < 6; i++ ) {
	sprintf( (char*) text + 3 * i, "%02x:", (unsigned int) mac[i] );
    }
    text[17] = 0;	// cut off the trailing colon
    return text;
}
#endif

// Read a UDP packet into the given buffer, which must have room for
// BUFSIZE bytes.
// Returns the number of bytes received, or 0 when the packet is to be
// ignored: it is empty, or it was not sent from the configured remote
// address and port. Exits the program when receiving fails.
// With a key loaded (seed != 0) the packet is decrypted in place.
static int doreadUDP(int fd, unsigned char *buf) {
    ssize_t len;
    struct sockaddr_in src;
    socklen_t addrlen = sizeof( src );
    memset( &src, 0, sizeof( src ) );
    if ( ( len = recvfrom( fd, buf, BUFSIZE, 0,
			   (struct sockaddr*)&src, &addrlen ) ) == -1) {
	perror( "Receiving UDP" );
	exit( 1 );
    }
    // Compare the address members one by one instead of the first
    // bytes of the struct: the layout of sockaddr_in is system
    // specific, and where it starts with a length member (sin_len)
    // a byte-wise comparison against "remote" can never match.
    if ( src.sin_family != remote.sin_family ||
	 src.sin_port != remote.sin_port ||
	 src.sin_addr.s_addr != remote.sin_addr.s_addr ) {
	return 0;
    }
    if ( seed ) {
	decrypt( buf, len );
    }
#if TRACEMAC
    if ( len > 12 ) {
	fprintf( stderr, "FROM %s\n", mactoa( (unsigned char *) buf ) );
    }
#endif
    return len;
}

// Perform a single read of at most n bytes from the given file
// descriptor into buf.
// Returns the number of bytes read, which is 0 at end of input.
// A read error is reported and ends the program.
static int doread(int fd, unsigned char *buf, int n) {
    ssize_t got = read( fd, buf, n );
    if ( got < 0 ) {
	perror( "Reading stdin" );
	exit( 1 );
    }
    return got;
}

// Fill buf with n bytes from the given file descriptor.
// Without "partial" this keeps reading until all n bytes are in and
// returns n. With "partial" it returns after the first read with the
// number of bytes that read delivered.
// Returns 0 as soon as a read hits end of input, whatever was read
// before.
static int read_into(int fd, unsigned char *buf, int n,int partial) {
    int left;
    for ( left = n; left > 0; ) {
	int got = doread( fd, buf, left );
	if ( got == 0 ) {
	    return 0;
	}
	left -= got;
	if ( partial ) {
	    return n - left;
	}
	buf += got;
    }
    return n;
}


// Write a UDP packet from the given buffer: the n bytes of buf are
// sent as one datagram to the remote end.
// With a key loaded (seed != 0) buf is encrypted in place first, so
// the caller's buffer holds the encrypted bytes afterwards.
// Returns the number of bytes sent, or -1 after reporting an error.
static int dowriteUDP(int fd, unsigned char *buf, int n) {
    int w;
#if TRACEMAC
    if ( n > 12 ) {
	fprintf( stderr, "TO %s\n", mactoa( (unsigned char *) buf+6 ) );
    }
#endif
    if ( seed ) {
	encrypt( buf, n );
    }
    w = sendto( fd, buf, n, 0, (const struct sockaddr *) &remote,
		sizeof( remote ) );
    // Only a negative result is a failure. An empty datagram is valid
    // and is reported as 0 bytes sent; treating that as an error would
    // print a message based on a stale errno.
    if ( w < 0 ) {
	perror( "Writing socket" );
	w = -1;
    }
    return w;
}

// Write all n bytes of buf to the given file descriptor.
// A single write may take fewer bytes than asked for, or be
// interrupted by a signal. The data written to stdout is framed by
// length prefixes, so a silently shortened write would corrupt the
// stream; therefore keep writing until everything is out.
// Returns n on success, or -1 after reporting an error.
static int dowrite(int fd, unsigned char *buf, int n) {
    int done = 0;
    while ( done < n ) {
	ssize_t w = write( fd, buf + done, n - done );
	if ( w < 0 ) {
	    if ( errno == EINTR ) {
		continue;
	    }
	    perror( "Writing data" );
	    return -1;
	}
	if ( w == 0 ) {
	    // Nothing was accepted and no error was given; stop rather
	    // than loop forever.
	    fprintf( stderr, "Writing data: no progress\n" );
	    return -1;
	}
	done += w;
    }
    return done;
}

// Print the command line synopsis on stderr and end the program with
// exit status 1.
static void usage(void) {
    fputs( "Usage:\n", stderr );
    fprintf( stderr,
	     "%s localport remotehost remoteport [ keyfile ]\n",
	     progname );
    exit( 1 );
}

// Application main function
// $1 = local UDP port to listen on
// $2 = remote host IP address
// $3 = remote UDP port
// $4 = key file (optional); when given, packets are encrypted
//
// Packets on stdin and stdout are framed as a 16-bit length in network
// byte order followed by that many bytes of packet data. Each complete
// packet from stdin is sent as one UDP datagram to the remote end, and
// each datagram accepted from the remote end is written to stdout.
// The loop ends at end of input on stdin or when writing to stdout
// fails.
int main(int argc, char *argv[]) {
    int port, udp_fd, maxfd, n, fromUDP, flag = 0;
    unsigned int want;
    uint16_t plength;
    
    progname = (unsigned char *) argv[0];
    // Extra arguments are rejected rather than ignored: with more than
    // five, the key file would be skipped silently and the tunnel would
    // run unencrypted.
    if ( argc < 4 || argc > 5 ) {
	usage();
    }
    
    if ( inet_aton( argv[2], &remote.sin_addr ) == 0 ) {
	fprintf( stderr, "Bad remote host IP\n" );
	usage();
    }
    // Ports are range checked, since htons would silently cut off
    // anything beyond 16 bits.
    if ( sscanf( argv[3], "%d", &port ) != 1 ||
	 port < 1 || port > 65535 ) {
	fprintf( stderr, "Bad remote port\n" );
	usage();
    }
    remote.sin_family = AF_INET;
    remote.sin_port = htons( port );
    
    if ( sscanf( argv[1], "%d", &port ) != 1 ||
	 port < 0 || port > 65535 ) {
	fprintf( stderr, "Bad local port\n" );
	usage();
    }
    
    if ( argc == 5 ) {
	loadkey( (unsigned char *) argv[4] );
	if ( seed == 0 ) {
	    fprintf( stderr, "Cannot load keyfile. Exiting.\n" );
	    exit( 1 );
	}
    }
    setbuf( stdout, NULL ); // No buffering on stdout.
    // socket reports failure as -1; 0 is a valid descriptor.
    if ( ( udp_fd = socket( AF_INET, SOCK_DGRAM, 0 ) ) < 0 ) {
	perror( "creating socket");
	exit(1);
    }
    struct sockaddr_in udp_addr;
    memset( &udp_addr, 0, sizeof( udp_addr ) );
    udp_addr.sin_family = AF_INET;
    udp_addr.sin_port = htons( port );
    udp_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    if ( bind( udp_fd, (struct sockaddr*) &udp_addr, sizeof( udp_addr ) ) ) {
	fprintf( stderr, "Error binding socket!\n");
	exit(1);
    }
    maxfd = udp_fd + 1;
    // Handle packets
    while( 1 ) {
	fd_set rd_set;
	FD_ZERO( &rd_set );
	FD_SET( udp_fd, &rd_set ); FD_SET( 0, &rd_set );
	n = select( maxfd, &rd_set, NULL, NULL, NULL);
	if ( n < 0 ) {
	    if ( errno == EINTR ) {
		continue;
	    }
	    perror("select");
	    exit(1);
	}
	fromUDP = FD_ISSET( udp_fd, &rd_set );
	if ( fromUDP && flag && FD_ISSET( 0, &rd_set ) ) {
	    fromUDP = 0; // Prefer handling stdin
	}
	flag = ! flag; // Alternate input preference
	if ( fromUDP ) { // Handle UDP data; read and deliver a full packet
	    n = doreadUDP( udp_fd, buffer );
	    if ( n == 0 ) {
		continue;
	    }
	    plength = htons( n );
	    if ( dowrite( 1, (unsigned char *) &plength,
			  sizeof( plength ) ) < 0 ) {
		break;
	    }
	    if ( dowrite( 1, buffer, n ) < 0 ) {
		break;
	    }
	    continue;
	}
	// Handle stdin data; read packet data and deliver when full
	if ( FD_ISSET( 0, &rd_set ) ) {
	    if ( input.end == 0 ) {
		// No packet in progress: start one by reading its
		// length prefix.
		n = read_into( 0, (unsigned char *) &plength,
			       sizeof( plength ), 0 );
		if ( n == 0 ) {
		    break;
		}
		input.end = ntohs( plength );
		input.cur = 0;
	    }
	    want = input.end - input.cur;
	    n = read_into( 0, input.buffer + input.cur, want, 1 );
	    if ( want > 0 && n == 0 ) {
		// End of input in the middle of a packet. Without
		// this, select would report stdin readable forever
		// and the loop would spin.
		break;
	    }
	    input.cur += n;
	    if ( input.end == input.cur ) {
		dowriteUDP( udp_fd, input.buffer, input.end );
		input.end = 0;
		continue;
	    }
	}
    }
    return 0;
}
