/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */

#include "aa_socket.h"

#include <errno.h>
#include <stdint.h>     // uint32_t for the wire-format length prefix
#include <stdio.h>      // snprintf, sprintf, perror
#include <string.h>     // strerror, memcpy, memset

#include <iostream>
using namespace std;

#include <unistd.h>
#include <arpa/inet.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/wait.h>   // waitpid (TEST_SOCKET driver)

#include <netinet/in.h>
#if defined(linux) || defined(__linux__)
#include <endian.h>
#else
#include <sys/endian.h>
#endif /*defined(linux) || defined(__linux__)*/

// for gethostbyname
#include <netdb.h>

// These functions are wrappers, to preserve my nice method naming!
// Forwarding aliases for the BSD socket calls.
// AASocket has a data member named `socket` and member functions named
// connect/listen/send. Inside member bodies those names hide the global
// system calls, so the members call these file-level aliases instead.
// Each alias returns exactly what the underlying call returns.
inline int _socket(int domain, int type, int protocol)
{
  return ::socket(domain, type, protocol);
}

inline int _connect(int fd, const struct sockaddr *addr, socklen_t addr_len)
{
  return ::connect(fd, addr, addr_len);
}

inline int _listen(int fd, int backlog)
{
  return ::listen(fd, backlog);
}

inline int _send(int fd, char *data, unsigned int length, int flags)
{
  return ::send(fd, data, length, flags);
}


/**
 * Construct an unconnected socket object.
 *
 * Both descriptors start as -1 ("not open"). The destructor and force_close()
 * compare against -1, so without this they would read indeterminate values
 * and could close an unrelated descriptor.
 */
AASocket::AASocket()
{
  socket = -1;
  bind_socket = -1;
}

/**
 * Close the connected socket, if one is open.
 *
 * A destructor must not throw. Destructors are implicitly noexcept since
 * C++11, and throwing during stack unwinding calls std::terminate anyway.
 * A close() failure here is therefore ignored (best effort). The -1 check
 * skips objects whose connect()/listen() never produced a descriptor, where
 * close(-1) would fail with EBADF.
 *
 * bind_socket is deliberately left alone; listen() owns and closes it.
 */
AASocket::~AASocket()
{
  if(socket != -1) {
    (void)::close(socket);
    socket = -1;
  }
}

/**
 * Open a TCP/IPv4 connection to `host`:`port`.
 *
 * @param host  host name or dotted-quad address, resolved with gethostbyname().
 * @param port  TCP port in host byte order.
 * @throws Network_error  "socket" if the socket cannot be created,
 *                        "gethostbyname" if `host` cannot be resolved to an
 *                        IPv4 address, or "connect" if the connect fails.
 *
 * On the resolve/connect failures the freshly created descriptor stays in
 * `socket`; the destructor is responsible for closing it.
 */
void AASocket::connect(char *host, unsigned short port)
{
  // IPv4 (PF_INET; PF_INET6 would be IPv6) TCP stream socket
  // (IPPROTO_TCP; IPPROTO_UDP would be UDP).
  socket = _socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
  if(socket == -1) throw Network_error("socket", strerror(errno));

  // gethostbyname() signals failure by returning NULL and setting h_errno
  // (not errno), so its result must be checked before it is dereferenced.
  // NOTE(review): getaddrinfo() would add IPv6 support and thread safety.
  const char *reason = NULL;
  struct hostent *hp = NULL;
  if(host == NULL)
    reason = "no host given";
  else if((hp = gethostbyname(host)) == NULL)
    reason = hstrerror(h_errno);
  else if(hp->h_addrtype != AF_INET || hp->h_addr_list[0] == NULL)
    reason = "no IPv4 address";

  if(reason != NULL) {
    char err_buf[256];
    snprintf(err_buf, sizeof(err_buf), "cannot resolve host '%s': %s",
             host != NULL ? host : "(null)", reason);
    throw Network_error("gethostbyname", err_buf);
  }

  // Zero the whole struct first so sin_zero does not carry stale bytes.
  memset(&socketaddr, 0, sizeof(socketaddr));
  socketaddr.sin_family = AF_INET;
  socketaddr.sin_port = htons(port);  // network byte order
  memcpy(&socketaddr.sin_addr, hp->h_addr_list[0], sizeof(struct in_addr));

  int err = _connect(socket, reinterpret_cast<struct sockaddr*>(&socketaddr),
                     sizeof(socketaddr));
  if(err == -1) throw Network_error("connect", strerror(errno));
}

// Close `fd` if it is open and mark it closed (-1). Returns the errno value
// that was current on entry, so callers can still report the failure that
// led here even though close() may overwrite errno.
static int aa_close_preserving_errno(int& fd)
{
  int saved = errno;
  if(fd != -1) {
    (void)::close(fd);
    fd = -1;
  }
  return saved;
}

/**
 * Listen on `port` (all local IPv4 addresses) and block until one client
 * connects.
 *
 * On success `socket` holds the accepted connection, `socketaddr` holds the
 * peer's address, and the temporary listening descriptor (`bind_socket`) is
 * closed and reset to -1.
 *
 * @param port  TCP port in host byte order.
 * @throws Network_error  "tmp socket", "setsockopt", "bind", "listen",
 *                        "accept" or "tmp close", naming the failing step.
 *
 * Every failure after the listening socket is created closes it first, so
 * the port is not left bound.
 * NOTE(review): force_close() may close bind_socket concurrently from another
 * thread; bind_socket is a plain member, so that hand-off is unsynchronized.
 */
void AASocket::listen(unsigned short port)
{
  int err;

  bind_socket = _socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
  if(bind_socket == -1) throw Network_error("tmp socket", strerror(errno));

  // Allow re-binding the port right after a previous run (TIME_WAIT).
  int optval = 1;
  err = setsockopt(bind_socket, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval));
  if(err == -1) {
    int e = aa_close_preserving_errno(bind_socket);
    throw Network_error("setsockopt", strerror(e));
  }

  memset(&socketaddr, 0, sizeof(socketaddr));
  socketaddr.sin_family = AF_INET;
  socketaddr.sin_port = htons(port);
  socketaddr.sin_addr.s_addr = htonl(INADDR_ANY);  // every local interface

  // Qualified ::bind: with `using namespace std` in effect, an unqualified
  // bind(...) can select the std::bind template instead of the system call.
  err = ::bind(bind_socket, reinterpret_cast<struct sockaddr*>(&socketaddr),
               sizeof(socketaddr));
  if(err == -1) {
    int e = aa_close_preserving_errno(bind_socket);
    throw Network_error("bind", strerror(e));
  }

  err = _listen(bind_socket, 5);
  if(err == -1) {
    int e = aa_close_preserving_errno(bind_socket);
    throw Network_error("listen", strerror(e));
  }

  // accept() takes a socklen_t* in/out length; use a real socklen_t rather
  // than punning an int.
  socklen_t csalen = sizeof(socketaddr);
  socket = accept(bind_socket,
                  reinterpret_cast<struct sockaddr*>(&socketaddr),
                  &csalen);
  if(socket == -1) {
    int e = aa_close_preserving_errno(bind_socket);
    throw Network_error("accept", strerror(e));
  }

  err = ::close(bind_socket); // We don't need this anymore
  bind_socket = -1;
  if(err == -1) throw Network_error("tmp close", strerror(errno));
}


/**
 * Tear down the temporary listening socket, if one is open, so that a
 * blocked listen()/accept() gives up.
 *
 * shutdown() comes before close() because, per the Linux close(2) notes,
 * closing a descriptor that another thread is blocked on does not
 * necessarily make that call return. On Linux, shutdown() on a listening
 * socket does wake accept(). Errors are ignored: this is best effort.
 * Resetting bind_socket to -1 makes a repeated call a no-op instead of a
 * double close.
 * NOTE(review): bind_socket is read/written here and in listen() without
 * synchronization; confirm the intended threading model.
 */
void AASocket::force_close()
{
  if(bind_socket == -1) return;

  (void)::shutdown(bind_socket, SHUT_RDWR);
  (void)::close(bind_socket);
  bind_socket = -1;
}


void AASocket::send(char* buf, unsigned int size)
{
	//unsigned int newsize = size + sizeof(unsigned int);
	//	char *newbuf = new char[newsize];

	unsigned int nsize = htonl(size);
	int n = _send(socket, (char*)&nsize, sizeof(unsigned int), MSG_WAITALL);
	if(n == -1) throw Network_error("send", strerror(errno));

	n = _send(socket, buf, size, MSG_WAITALL);
	if(n == -1) throw Network_error("send", strerror(errno));
}


// Read exactly `len` bytes from `fd` into `data`.
// MSG_WAITALL may still return early (signal, peer close), so loop until
// everything has arrived; EINTR is retried. A return of 0 means the peer
// closed the connection before `len` bytes arrived, which is reported as an
// error instead of silently yielding a partial or uninitialized buffer.
// Throws Network_error("recv", ...).
static void aa_recv_all(int fd, char* data, size_t len)
{
  while(len > 0) {
    ssize_t n = ::recv(fd, data, len, MSG_WAITALL);
    if(n == -1) {
      if(errno == EINTR) continue;
      throw Network_error("recv", strerror(errno));
    }
    if(n == 0) {
      char msg[] = "connection closed by peer";
      throw Network_error("recv", msg);
    }
    data += n;
    len -= static_cast<size_t>(n);
  }
}

/**
 * Receive one framed message (a 4-byte big-endian length prefix followed by
 * the payload) into `buf`.
 *
 * @param buf   destination; no NUL terminator is appended.
 * @param size  capacity of `buf` in bytes.
 * @return      number of payload bytes written to `buf`.
 * @throws Network_error  "recv" on socket errors or premature peer close;
 *                        "receive" if the payload is larger than `size`.
 *
 * NOTE(review): when the payload is too large, its bytes are left unread on
 * the socket, so the stream is out of sync afterwards.
 */
int AASocket::receive(char* buf, unsigned int size)
{
  uint32_t insize = 0;
  aa_recv_all(socket, reinterpret_cast<char*>(&insize), sizeof(insize));
  insize = ntohl(insize);

  if(insize > size) {
    char err_buf[256];
    snprintf(err_buf, sizeof(err_buf), "Buffer is too small. Should be %u is %u.",
             static_cast<unsigned int>(insize), size);
    throw Network_error("receive", err_buf);
  }

  aa_recv_all(socket, buf, insize);
  return static_cast<int>(insize);
}


/**
 * Send the characters of `str` as one framed message via send().
 * Only str.length() characters go out; no trailing NUL is transmitted.
 */
void AASocket::send_string(string str)
{
  const unsigned int length = str.length();
  char* payload = const_cast<char*>(str.c_str());
  send(payload, length);
}


/**
 * Receive one framed message of at most 1024 bytes and return it as a
 * string.
 *
 * The text is taken up to the first NUL byte (or the end of the payload).
 * The buffer has one spare byte beyond the 1024-byte capacity passed to
 * receive(), so it is always NUL-terminated, even when a message fills the
 * capacity exactly. Previously such a message made string(buf) read past
 * the array.
 *
 * @throws Network_error  whatever receive() throws, including "receive" for
 *                        messages longer than 1024 bytes.
 */
string AASocket::receive_string()
{
  const unsigned int capacity = 1024;
  char buf[capacity + 1];
  memset(buf, 0, sizeof(buf));

  receive(buf, capacity);

  return string(buf);
}



#ifdef TEST_SOCKET

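/**
 * Test application for AASocket.
 * The parent process ("A") listens on port 6666 and the forked child ("B")
 * connects to it; each side prints the messages it receives.
 * It should print the following to stdout:
 * A: Hello how are you?
 * B: Fine thanks.
 * B: What about you?
 * A: I'm fine too.
 */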

#include <sys/types.h>
#include <unistd.h>

#include <string>
#include <iostream>

int main()
{
	char buf[1024];
	memset(buf, 0, sizeof(buf));
  int f = fork();
  switch(f) {
  case -1: // Fork error
    perror("Fork failed!");
    return 1;

  case 0:  // Forked child
		{
			try {
				AASocket out;

				sleep(1); // Make sure the other end is listening

				// Test connect
				out.connect("127.0.0.1", 6666);

				// Test raw communication send
				sprintf(buf, "Hello how are you?");
				out.send(buf, sizeof(buf));

				// Test raw communication receive
				out.receive(buf, sizeof(buf));
				std::cout << "B: " << buf << std::endl;

				// Test string receive
				std::string q = out.receive_string();
				std::cout << "B: " << q << std::endl;

				// Test string send
				out.send_string(std::string("I'm fine too."));
				return 0;
			} catch(Network_error e) {
				std::cerr << "Out: " << e.error << std::endl;
			}
		}
  default: // Parent
		{
			try {
				AASocket in;
				
				// Test listen
				in.listen(6666);

				// Test raw communication receive
				in.receive(buf, sizeof(buf));
				std::cout << "A: " << buf << std::endl;

				// Test raw communication send
				sprintf(buf, "Fine thanks.");
				in.send(buf, sizeof(buf));

				// Test string send
				in.send_string(std::string("What about you?"));

				// Test string receive	
				std::string a = in.receive_string();
				std::cout << "A: " << a << std::endl;
				return 0;
			} catch(Network_error e) {
				std::cerr << "In: " << e.error << std::endl;
			}
		}
	}
	return 0;
}
#endif/*TEST_SOCKET*/