/* 
	Cryptocat: a highly secure data transfer server based on 128-bit AES encryption.
	(a challenge for seclab@unive by r1x, 2012)
*/
#include <errno.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h> 
#include <sys/errno.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <fcntl.h>
#include <unistd.h>
#include <netinet/in.h>
#include <openssl/evp.h>

/* Size of the shared I/O buffer: the most bytes read from a client request
 * and the most bytes read from a requested file. */
#define MAXBUF 2048
/* Error replies written to the client (the terminating NUL is sent too). */
#define E_WRONG_DEC "Decryption failed!\n"
#define E_WRONG_ENC "Encryption under session key failed!\n"
#define E_WRONG_KEY_LENGTH "Wrong session key length!\n"
/* Protocol commands: the first SEP-terminated field of a decrypted request. */
#define C_DIR  "show_file_names"
#define C_SEND "send_encrypted"
/* Field terminator inside a decrypted request. */
#define SEP '#'

/* Shared buffer: first holds the client's encrypted request, then is reused
 * for the contents of the requested file. */
char buffer[MAXBUF];
/* Log labels used by AEScrypt, indexed by its do_encrypt flag
 * (0 = decrypt, 1 = encrypt). */
char processed[][10] = {"Decrypted","Encrypted"};

/* A few logging/error functions */
/*
 * Print a log message to stdout, prefixed with "[<pid>] ".
 *
 * s     printf-style format string
 * args  arguments consumed by the format
 *
 * stdout is flushed after every message. The server forks a process per
 * client, and the show_file_names path replaces fd 1 with the client socket
 * (dup2) and then execs. Without the flush, buffered log text would be
 * discarded by exec or could end up on the socket.
 */
void vprintlog(const char *s, va_list args)
{
	printf("[%i] ", (int)getpid());
	vprintf(s, args);
	fflush(stdout);
}
/* Variadic front end to vprintlog: log a printf-style message with the pid prefix. */
void printlog(char *s, ...)
{
	va_list args;

	va_start(args, s);
	vprintlog(s, args);
	va_end(args);
}
/* Log a printf-style message (pid-prefixed, via vprintlog), then exit with status 1. */
void fatallog(char *s, ...)
{
	va_list args;

	va_start(args, s);
	vprintlog(s, args);
	va_end(args);

	exit(1);
}
/*
 * Report a failed system call and terminate with status 1.
 *
 * msg is printed via perror() on stderr, and a "FATAL" line carrying the
 * errno description is written to the log.
 */
void error(const char *msg)
{
	/* perror() and the logging path may overwrite errno; capture it first. */
	int saved_errno = errno;

	perror(msg);
	printlog("FATAL: %s : %s\n", msg, strerror(saved_errno));
	exit(1);
}

/*
 * AES-128 in ECB mode over a whole buffer, with OpenSSL's default (PKCS)
 * padding.
 *
 * in, inlen   input bytes
 * out         output buffer; must have room for inlen plus one cipher block
 * outlen      set to the number of bytes written to out (0 on failure)
 * key         16-byte AES key
 * do_encrypt  1 to encrypt, 0 to decrypt; any other value is rejected
 *
 * Returns *outlen on success and 0 on failure (e.g. bad padding when
 * decrypting). A successful call that produces zero bytes also returns 0,
 * so callers cannot distinguish it from failure.
 *
 * NOTE(review): ECB reveals repeated plaintext blocks; presumably kept
 * because the challenge protocol depends on it.
 */
int AEScrypt(char *in, int inlen, char *out, int *outlen, char *key, int do_encrypt)
{
	EVP_CIPHER_CTX *ctx = NULL;
	int tmplen = 0;

	*outlen = 0;

	/* processed[] is indexed by do_encrypt, so only 0 and 1 are valid. */
	if (do_encrypt != 0 && do_encrypt != 1)
		goto fail;

	/* EVP_CIPHER_CTX is opaque since OpenSSL 1.1.0 and must be heap-allocated. */
	ctx = EVP_CIPHER_CTX_new();
	if (ctx == NULL)
		goto fail;

	/* AES-128 has a fixed 16-byte key, so cipher and key are set in one
	 * call; ECB takes no IV. */
	if (EVP_CipherInit_ex(ctx, EVP_aes_128_ecb(), NULL,
	                      (unsigned char *)key, NULL, do_encrypt) != 1)
		goto fail;
	if (EVP_CipherUpdate(ctx, (unsigned char *)out, outlen,
	                     (unsigned char *)in, inlen) != 1)
		goto fail;
	/* Final emits the padding block (encrypt) or verifies and strips it (decrypt). */
	if (EVP_CipherFinal_ex(ctx, (unsigned char *)out + *outlen, &tmplen) != 1)
		goto fail;
	*outlen += tmplen;
	EVP_CIPHER_CTX_free(ctx);

	printlog("%s %i bytes into %i bytes\n", processed[do_encrypt], inlen, *outlen);
	return *outlen;

fail:
	/* Release the context on every failure path, too. */
	if (ctx != NULL)
		EVP_CIPHER_CTX_free(ctx);
	*outlen = 0;
	printlog("AESCrypt failure\n");
	return 0;
}

/*
 * Find the end of one SEP-terminated message field.
 *
 * Scans s[1] .. s[maxlen-1] for the first SEP and returns its index, which
 * is also the length of the field that starts at s[0]. s[0] is never
 * examined, so a field is at least one byte long and a return value of 0
 * always means "no separator within the first maxlen bytes". 0 is also
 * returned when maxlen <= 1. No byte at or beyond s[maxlen] is read.
 */
int extract_part(const char *s, int maxlen)
{
	int i;

	for (i = 1; i < maxlen; i++) {
		if (s[i] == SEP) {
			printlog("Extract part len=%i i=%i\n", maxlen, i);
			return i;
		}
	}
	return 0;
}

/* Serve the client */
void newclient(int fd)
{
	int i,n,outlen,fk,f;
	char out[MAXBUF+EVP_MAX_BLOCK_LENGTH];
	char key[16];
	bzero(buffer,MAXBUF);

	printlog("=== New client ===\n");
	i=0;
	while(i<MAXBUF && read(fd,&buffer[i],1)) i++;

	fk = open("key.bin",O_RDONLY);
	if (read(fk,key,16) != 16)
		error("Wrong key length");

	/* Decrypt under long-term key */
	if (AEScrypt(buffer,i,out,&outlen,key,0)==0) {
		write(fd,E_WRONG_DEC,strlen(E_WRONG_DEC)+1); 
		printlog(E_WRONG_DEC);
		exit(1);
	}

	int ep;
	if (!(ep = extract_part(out,outlen)))
		fatallog("Cannot extract message part");
		
	printlog("Command of length %i\n",ep);

	int pt=0;
	if (ep == strlen(C_DIR) && memcmp(out,C_DIR,strlen(C_DIR))==0) 
	{
		printlog("Executing ls\n");
		dup2(fd,1);
		execlp("ls","ls","-al",NULL);
		error("ERROR executing ls");
	} 
	else if (ep == strlen(C_SEND) && memcmp(out,C_SEND,strlen(C_SEND))==0) 
	{
		char skey[16],filename[MAXBUF];
		pt += strlen(C_SEND)+1;

		if (!(ep = extract_part(&out[pt],outlen-pt)))
			fatallog("Cannot find session key\n");
		if (ep != 16) {
			write(fd,E_WRONG_KEY_LENGTH,strlen(E_WRONG_KEY_LENGTH)+1);
			fatallog("Wrong session key length: %i\n",ep);
		}

		memcpy(skey,&out[pt],16);
		pt+=17;
		if (!(ep = extract_part(&out[pt],outlen-pt)))
			fatallog("Cannot find file name\n");

		memcpy(filename,&out[pt],ep);
		filename[ep]='\0';
		f = open(filename,O_RDONLY); 
		i=0;
		while(i<MAXBUF && read(f,&buffer[i],1)) i++;
		buffer[i]=0;
		printlog("Read %i bytes from file %s\n",i,filename);
		
		/* Encrypt under session key */
		if (AEScrypt(buffer,i,out,&outlen,skey,1)==0) { 
			write(fd,E_WRONG_ENC,strlen(E_WRONG_ENC)+1); 
			fatallog(E_WRONG_ENC);
		}

		write(fd,out,outlen);
		printlog("Sent encrypted file\n");

	}
	if (n < 0) error("ERROR writing to socket");
}

/*
 * Entry point. Usage: <program> <port>
 *
 * Listens on all IPv4 interfaces on the given TCP port (0..65535) and forks
 * one child per accepted connection; the child runs newclient() and exits.
 * Finished children are reaped without blocking before each accept().
 */
int main(int argc, char *argv[])
{
	int sockfd, newsockfd, status;
	long portno;
	char *end;
	socklen_t clilen;
	pid_t pid;
	struct sockaddr_in serv_addr, cli_addr;

	if (argc < 2) {
		fprintf(stderr, "ERROR, no port provided\n");
		exit(1);
	}

	/* strtol instead of atoi so that junk or out-of-range ports are rejected
	 * rather than silently turning into some other port. */
	errno = 0;
	portno = strtol(argv[1], &end, 10);
	if (end == argv[1] || *end != '\0' || errno == ERANGE ||
	    portno < 0 || portno > 65535) {
		fprintf(stderr, "ERROR, invalid port: %s\n", argv[1]);
		exit(1);
	}

	sockfd = socket(AF_INET, SOCK_STREAM, 0);
	if (sockfd < 0)
		error("ERROR opening socket");

	memset(&serv_addr, 0, sizeof(serv_addr));
	serv_addr.sin_family = AF_INET;
	serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
	serv_addr.sin_port = htons((in_port_t)portno);
	if (bind(sockfd, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < 0)
		error("ERROR on binding");
	if (listen(sockfd, 5) < 0)
		error("ERROR on listen");

	while (1) {
		/* Reap every child that has finished, not just one per connection. */
		while (waitpid(-1, &status, WNOHANG) > 0)
			;

		/* accept() overwrites clilen, so reset it before every call. */
		clilen = sizeof(cli_addr);
		newsockfd = accept(sockfd, (struct sockaddr *)&cli_addr, &clilen);
		if (newsockfd < 0) {
			if (errno == EINTR)
				continue;
			error("ERROR on accept");
		}

		pid = fork();
		if (pid < 0)
			error("ERROR on fork");
		if (pid == 0) {
			/* The child serves only the accepted connection. */
			close(sockfd);
			newclient(newsockfd);
			exit(0);
		}
		close(newsockfd);
	}
}