/* 
 * Copyright (C) 2000-2001 Computer & Communications Research Laboratories,
 *			   Industrial Technology Research Institute
 */
/*
 * msgSock.c
 *
 * $Id: msgSock.c,v 1.37 2001/06/19 23:32:13 hcc Exp $
 */

#include <stdlib.h>
#include <string.h>
#include <ctype.h>
#include <errno.h>
#ifdef UNIX
#include <sys/socket.h>
#include <sys/select.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <unistd.h>
#endif
#include "msgSock.h"
#include "osdep.h"

/* Presumably an upper bound for message buffers; it is not referenced by
 * any code visible in this file -- TODO confirm whether it is still needed. */
#define msgMAXBUFLEN  10000

#ifdef WIN32
/* Set to 1 by msgLibInit() when it had to call WSAStartup() itself, so
 * that msgLibClean() knows to call WSACleanup(). */
static int	need_winsock_cleanup = 0;
#endif

/* Tracer handle passed to TCRPrint(); changed through msgSockSetTracer(). */
TCR	sock_tracer = NULL;

/* Object behind the msgSockAddr handle: the textual host given by the
 * caller, the port, and the resolved IPv4 socket address. */
struct msgSockAddrObj
{
	/* host text as passed to msgSockAddrNew(); empty string when NULL was given */
	char			addr_[64];
	/* port exactly as passed by the caller (host byte order) */
	UINT16			port_;
	/* AF_INET address; sin_port holds htons(port_) */
	struct sockaddr_in	sockaddr_;
	/* link to another address object; not used by the code visible in this file */
	msgSockAddr		next_;	
};

/* Object behind the msgSock handle: one OS socket plus its addresses. */
struct msgSockObj
{
	/* OS socket descriptor; INVALID_SOCKET until one is created or assigned */
	SOCKET			sockfd_;
	/* kind of socket (datagram, listening server, client/server stream) */
	msgSockType		type_;
	/* private copy of the local address, or NULL; released by msgSockFree() */
	msgSockAddr		laddr_;
	/* private copy of the remote address, or NULL; released by msgSockFree() */
	msgSockAddr		raddr_;
};

/*
 * Library start-up.
 *
 * On WIN32 a throw-away TCP socket is opened to find out whether Winsock
 * is already usable.  Only when that probe fails is WSAStartup() called,
 * and a successful start-up is remembered so that msgLibClean() can undo
 * it.  On other platforms this function does nothing.
 */
void msgLibInit()
{
#ifdef WIN32
	WSADATA	startup_info;
	SOCKET probe = socket(PF_INET, SOCK_STREAM, 0);

	if (probe != INVALID_SOCKET) {
		/* Winsock is already up; just discard the probe socket. */
		closesocket(probe);
		return;
	}

	if (WSAStartup(MAKEWORD(1, 1), &startup_info) == 0)
		need_winsock_cleanup = 1;
#endif
}
 
/*
 * Library shutdown.
 *
 * On WIN32, calls WSACleanup() only when msgLibInit() was the one that
 * started Winsock.  On other platforms this function does nothing.
 */
void msgLibClean()
{
#ifdef WIN32
	if (!need_winsock_cleanup)
		return;

	WSACleanup();
#endif
}

/**************** msgSockAddr Functions *****************/
/*
 * Create an address object for a host and port.
 *
 * addr: dotted-quad IP address or host name.  NULL selects INADDR_ANY and
 *       stores an empty host string.
 * port: port number in host byte order.
 *
 * A string whose first character is not a letter is handed to inet_addr()
 * (its result is stored unchecked, as before); any other string is
 * resolved with gethostbyname().  The text kept for msgSockAddrGetAddr()
 * is truncated if it does not fit the fixed-size addr_ buffer; resolution
 * always uses the complete string.
 *
 * Returns the new object (release it with msgSockAddrFree()), or NULL when
 * memory allocation or host name resolution fails.
 */
msgSockAddr msgSockAddrNew(const char* addr, UINT16 port)
{
	struct hostent *he;
	msgSockAddr _this = (msgSockAddr)malloc(sizeof(struct msgSockAddrObj));

	if (_this == NULL) return NULL;

	if (addr != NULL) {
		/* bounded copy: an over-long name must not overrun addr_ */
		size_t n = strlen(addr);

		if (n >= sizeof(_this->addr_))
			n = sizeof(_this->addr_) - 1;
		memcpy(_this->addr_, addr, n);
		_this->addr_[n] = '\0';
	} else
		_this->addr_[0] = '\0';
	_this->port_ = port;
	_this->next_ = NULL;	/* never leave the link pointer indeterminate */

	memset((char *)&_this->sockaddr_, 0, sizeof(_this->sockaddr_));
	_this->sockaddr_.sin_family = AF_INET;
	_this->sockaddr_.sin_port = htons(port);

	if (addr == NULL) {
		_this->sockaddr_.sin_addr.s_addr = htonl(INADDR_ANY); /* auto-fill with local IP */
	} else if (!isalpha((unsigned char)addr[0])) {
		/* <ctype.h> functions require an unsigned char value */
		_this->sockaddr_.sin_addr.s_addr = inet_addr(addr);
	} else if ((he = gethostbyname(addr)) != NULL
		   && he->h_length > 0
		   && (size_t)he->h_length <= sizeof(_this->sockaddr_.sin_addr)) {
		/* length is checked so the copy cannot overrun sin_addr */
		memcpy(&_this->sockaddr_.sin_addr, he->h_addr, (size_t)he->h_length);
	} else {
		/* host name could not be resolved to a usable address */
		msgSockAddrFree(_this);
		return NULL;
	}

	return _this;
}

/*
 * Release an address object obtained from msgSockAddrNew() or
 * msgSockAddrDup().  Passing NULL is allowed and does nothing.
 */
void msgSockAddrFree(msgSockAddr _this)
{
	/* free(NULL) is defined as a no-op, so no explicit guard is needed */
	free(_this);
}

/*
 * Make an independent copy of an address object.
 *
 * Returns the copy (release it with msgSockAddrFree()), or NULL when
 * _this is NULL or memory allocation fails.  The copy's next_ link is
 * reset to NULL rather than left indeterminate.
 */
msgSockAddr msgSockAddrDup(msgSockAddr _this)
{
	msgSockAddr _copy;

	if (_this == NULL) return NULL;

	_copy = (msgSockAddr)malloc(sizeof(struct msgSockAddrObj));
	if (_copy == NULL) return NULL;

	/* Whole-struct copy is bounded by the object size, unlike strcpy. */
	*_copy = *_this;
	/* Guarantee termination even if the source text was not terminated. */
	_copy->addr_[sizeof(_copy->addr_) - 1] = '\0';
	/* The copy stands alone; it must not share or inherit a link. */
	_copy->next_ = NULL;

	return _copy;
}
 
/*
 * Return the host text stored in the address object.  The pointer refers
 * to storage inside the object; _this must not be NULL.
 */
char* msgSockAddrGetAddr(msgSockAddr _this)
{
	char *host = &_this->addr_[0];

	return host;
}
 
/*
 * Return the port stored in the address object (the value given at
 * creation, in host byte order).  _this must not be NULL.
 */
UINT16 msgSockAddrGetPort(msgSockAddr _this)
{
	UINT16 port = _this->port_;

	return port;
}

/**************** msgSock Functions *****************/ 
/*
 * Create a socket object of the requested type.
 *
 * type:  msgSockDgram   - datagram socket, bound to laddr if laddr != NULL
 *        msgSockServer  - stream socket; if laddr != NULL it is bound to
 *                         laddr with SO_REUSEADDR set; then put into the
 *                         listening state
 *        msgSockStreamC - stream socket, created but not connected
 *        msgSockStreamS - no OS socket is created here; the descriptor is
 *                         assigned later with msgSockSetSock()
 *        msgSockUnknown - rejected
 * laddr: optional local address; the object keeps its own copy.
 *
 * Returns the new object (release it with msgSockFree()), or NULL on any
 * failure, including failure to copy laddr.
 */
msgSock msgSockNew(msgSockType type, msgSockAddr laddr)
{
	msgSock _this;

	if (type == msgSockUnknown)
		return NULL;

	_this = (msgSock)malloc(sizeof *_this);
	if (_this == NULL) return NULL;

	/* Give every field a defined value first so that msgSockFree() is
	 * safe on each error path below. */
	_this->sockfd_ = INVALID_SOCKET;
	_this->type_ = type;
	_this->laddr_ = NULL;
	_this->raddr_ = NULL;

	if (laddr != NULL) {
		_this->laddr_ = msgSockAddrDup(laddr);
		if (_this->laddr_ == NULL) {
			/* could not copy the local address */
			msgSockFree(_this);
			return NULL;
		}
	}

	/* A server-side stream receives its descriptor later. */
	if (type == msgSockStreamS)
		return _this;

	if (type == msgSockDgram)
		_this->sockfd_ = socket(PF_INET, SOCK_DGRAM, 0);
	else if ((type == msgSockServer) || (type == msgSockStreamC))
		_this->sockfd_ = socket(PF_INET, SOCK_STREAM, 0);

	if (_this->sockfd_ == INVALID_SOCKET) {
		TCRPrint(sock_tracer, 1, "<msgSockNew>: socket() error.\n");
		msgSockFree(_this);
		return NULL;
	}
	/* SOCKET need not be int-sized; cast so the argument matches %d */
	TCRPrint(sock_tracer, 2, "<msgSockNew>: socket = %d\n", (int)_this->sockfd_);

	/* The socket is left in blocking mode. */

	if (type == msgSockDgram && laddr != NULL) {
		if (bind(_this->sockfd_, (struct sockaddr *)&laddr->sockaddr_, sizeof(laddr->sockaddr_))
			== -1) {
			msgSockFree(_this);
			return NULL;
		}
	}

	if (type == msgSockServer) {
		if (laddr != NULL) {
			int yes = 1;

			/* allow the listening address to be reused */
			if (setsockopt(_this->sockfd_, SOL_SOCKET, SO_REUSEADDR, (const void*)&yes, sizeof(int)) == SOCKET_ERROR) {
				msgSockFree(_this);
				return NULL;
			}
			if (bind(_this->sockfd_, (struct sockaddr *)&laddr->sockaddr_, sizeof(laddr->sockaddr_))
				== -1) {
				msgSockFree(_this);
				return NULL;
			}
		}
		if (listen(_this->sockfd_, 5) == -1) {
			msgSockFree(_this);
			return NULL;
		}
	}

	return _this;
}
 
/*
 * Destroy a socket object: close its descriptor if it has one, release
 * the stored local and remote addresses, then free the object itself.
 * Passing NULL is allowed and does nothing, matching msgSockAddrFree().
 */
void msgSockFree(msgSock _this)
{
	if (_this == NULL)
		return;

	if (_this->sockfd_ != INVALID_SOCKET)
		closesocket(_this->sockfd_);

	/* msgSockAddrFree() accepts NULL, so no per-field guards are needed */
	msgSockAddrFree(_this->laddr_);
	msgSockAddrFree(_this->raddr_);

	free(_this);
}

/*
 * Return the OS socket descriptor held by the object (INVALID_SOCKET if
 * none has been created or assigned).  _this must not be NULL.
 */
SOCKET msgSockGetSock(msgSock _this)
{
	SOCKET fd = _this->sockfd_;

	return fd;
}

/*
 * Store descriptor s in the object.  Any descriptor previously held is
 * simply overwritten; it is not closed here.  _this must not be NULL.
 */
void msgSockSetSock(msgSock _this, SOCKET s)
{
	(*_this).sockfd_ = s;
}

/*
 * Return the remote address stored in the object, or NULL if none is set.
 * The address remains owned by the socket object; _this must not be NULL.
 */
msgSockAddr msgSockGetRaddr(msgSock _this)
{
	msgSockAddr remote = _this->raddr_;

	return remote;
}

/*
 * Replace the stored remote address with a private copy of raddr.
 * The caller keeps ownership of raddr.  Passing NULL (or a copy failure)
 * leaves the socket without a remote address.
 *
 * The copy is made before the old address is released, so it is safe to
 * pass the address currently stored in the socket.
 */
void msgSockSetRaddr(msgSock _this, msgSockAddr raddr)
{
	msgSockAddr copy = msgSockAddrDup(raddr);

	msgSockAddrFree(_this->raddr_);	/* accepts NULL */
	_this->raddr_ = copy;
}

/*
 * Return the local address stored in the object, or NULL if none is set.
 * The address remains owned by the socket object; _this must not be NULL.
 */
msgSockAddr msgSockGetLaddr(msgSock _this)
{
	msgSockAddr local = _this->laddr_;

	return local;
}

/*
 * Replace the stored local address with a private copy of laddr.
 * The caller keeps ownership of laddr.  Passing NULL (or a copy failure)
 * leaves the socket without a local address.  This only updates the
 * stored address; it does not bind the socket.
 *
 * The copy is made before the old address is released, so it is safe to
 * pass the address currently stored in the socket.
 */
void msgSockSetLaddr(msgSock _this, msgSockAddr laddr)
{
	msgSockAddr copy = msgSockAddrDup(laddr);

	msgSockAddrFree(_this->laddr_);	/* accepts NULL */
	_this->laddr_ = copy;
}

/*
 * Wait (without timeout) for an incoming connection on a listening
 * msgSockServer socket and accept it.
 *
 * Returns a new msgSockStreamS object that owns the accepted descriptor
 * and carries the peer's address as its remote address; release it with
 * msgSockFree().  Returns NULL if _this is not a server socket, if
 * select() or accept() fails, or if the new object cannot be allocated
 * (the accepted descriptor is closed in that case).
 */
msgSock msgSockAccept(msgSock _this)
{
	msgSock stream;
	msgSockAddr raddr;
	fd_set rset;
	int n;
	struct sockaddr_in addr;
	UINT32 addrlen;
	SOCKET s;

	if (_this->type_ != msgSockServer)
		return NULL;

	FD_ZERO(&rset);
	FD_SET(_this->sockfd_, &rset);

	n = select(_this->sockfd_ + 1, &rset, NULL, NULL, NULL/*blocking*/);

	/* On failure the descriptor set is not meaningful; never fall
	 * through with an unassigned socket. */
	if (n <= 0 || !FD_ISSET(_this->sockfd_, &rset))
		return NULL;

	memset(&addr, 0, sizeof(addr));
	addrlen = sizeof(addr);
	s = accept(_this->sockfd_, (struct sockaddr *)&addr, &addrlen);
	if (s == INVALID_SOCKET) return NULL;

	stream = msgSockNew(msgSockStreamS, NULL);
	if (stream == NULL) {
		/* do not leak the accepted connection */
		closesocket(s);
		return NULL;
	}
	msgSockSetSock(stream, s);

	raddr = msgSockAddrNew(inet_ntoa(addr.sin_addr), ntohs((UINT16)addr.sin_port));
	msgSockSetRaddr(stream, raddr);
	/* msgSockSetRaddr() stored its own copy; release the temporary */
	msgSockAddrFree(raddr);

	return stream;
}
				
/*
 * Connect a msgSockStreamC socket to raddr.
 *
 * If connect() reports that the attempt is still in progress, the function
 * waits (without timeout) until the socket becomes readable or writable
 * and then reads SO_ERROR to learn the real outcome.
 *
 * Returns 0 when connected, -1 on any failure (wrong socket type, NULL
 * raddr, connect/select/getsockopt failure, or a pending socket error).
 */
int msgSockConnect(msgSock _this, msgSockAddr raddr)
{
	int n, result;
	int error = 0;
	UINT32 len;
	fd_set rset, wset;

	if (_this->type_ != msgSockStreamC)
		return -1;

	if (raddr == NULL)
		return -1;

	result = connect(_this->sockfd_, (struct sockaddr *)&raddr->sockaddr_, sizeof(raddr->sockaddr_));
	if (result != SOCKET_ERROR)
		return 0;	/* connected immediately */

	/* have to take care of reentrant conditions */
	if (getError() != EINPROGRESS)
		return -1;

	/* connect still in progress */
	FD_ZERO(&rset);
	FD_SET(_this->sockfd_, &rset);
	wset = rset;

	n = select(_this->sockfd_ + 1, &rset, &wset, NULL, NULL/*blocking*/);

	if (n == 0) {
		/* timeout: close, and mark the descriptor invalid so that
		 * msgSockFree() does not close it a second time */
		closesocket(_this->sockfd_);
		_this->sockfd_ = INVALID_SOCKET;
		return -1;
	}

	if (n < 0)
		return -1;	/* select() itself failed */

	if (!FD_ISSET(_this->sockfd_, &rset) && !FD_ISSET(_this->sockfd_, &wset))
		return -1;

	len = sizeof(error);
	if (getsockopt(_this->sockfd_, SOL_SOCKET, SO_ERROR, (char*)&error, &len) < 0)
		/* Solaris pending error */
		return -1;

	/* readiness alone does not mean success; SO_ERROR holds the result */
	if (error != 0)
		return -1;

	/* connected */
	return 0;
}

/*
 * Send exactly 'len' bytes from buf over a stream socket.
 *
 * Before every send() the function waits, without timeout, until the
 * socket is writable, and it keeps going until the whole buffer has been
 * handed over.
 *
 * Returns the number of bytes sent (equal to len; 0 when len <= 0), or -1
 * if the socket is not a stream socket, has no descriptor, or select() /
 * send() fails or send() transfers nothing.
 */
int msgSockSend(msgSock _this, const char* buf, int len)
{
	int total = 0;
	fd_set writable;

	if (!(_this->type_ == msgSockStreamC || _this->type_ == msgSockStreamS))
		return -1;

	if (_this->sockfd_ == INVALID_SOCKET)
		return -1;

	while (total < len) {
		int chunk;

		FD_ZERO(&writable);
		FD_SET(_this->sockfd_, &writable);

		if (select(_this->sockfd_ + 1, NULL, &writable, NULL, NULL/*blocking*/) == SOCKET_ERROR)
			return -1;

		/* select() has just reported the socket writable */
		chunk = send(_this->sockfd_, buf + total, len - total, 0/*no flags*/);
		if (chunk == SOCKET_ERROR || chunk == 0)
			return -1;

		total += chunk;
	}

	return total;
}

/*
 * Receive exactly 'len' bytes into buf from a stream socket.
 *
 * Before every recv() the function waits, without timeout, until the
 * socket is readable, and it keeps going until the buffer is full.
 *
 * Returns the number of bytes received (equal to len; 0 when len <= 0),
 * 0 if recv() returns 0 before 'len' bytes have arrived (bytes already
 * read are not reported), or -1 if the socket is not a stream socket, has
 * no descriptor, or select() / recv() fails.
 */
int msgSockRecv(msgSock _this, char* buf, int len)
{
	int total = 0;
	fd_set readable;

	if (!(_this->type_ == msgSockStreamC || _this->type_ == msgSockStreamS))
		return -1;

	if (_this->sockfd_ == INVALID_SOCKET)
		return -1;

	while (total < len) {
		int chunk;

		FD_ZERO(&readable);
		FD_SET(_this->sockfd_, &readable);

		if (select(_this->sockfd_ + 1, &readable, NULL, NULL, NULL/*blocking*/) == SOCKET_ERROR)
			return -1;

		/* select() has just reported the socket readable */
		chunk = recv(_this->sockfd_, buf + total, len - total, 0/*no flags*/);
		if (chunk == SOCKET_ERROR)
			return -1;

		/* recv() returned 0: report 0 and drop the partial count */
		if (chunk == 0)
			return 0;

		total += chunk;
	}

	return total;
}

/*
 *  RFC: 2126 (ISO Transport Service on top of TCP (ITOT))
 * 
 *	A TPKT consists of two parts:
 *	
 *  - a Packet Header
 *  - a TPDU.
 *
 *  The format of the Packet Header is constant regardless of the type of
 *  TPDU. The format of the Packet Header is as follows:
 *
 *  +--------+--------+----------------+-----------....---------------+
 *  |version |reserved| packet length  |             TPDU             |
 *  +----------------------------------------------....---------------+
 *  <8 bits> <8 bits> <   16 bits    > <       variable length       >
 *
 *  where:
 *
 *  - Protocol Version Number
 *    length: 8 bits
 *    Value:  3
 *
 *  - Reserved
 *    length: 8 bits
 *    Value:  0
 *
 *  - Packet Length
 *    length: 16 bits
 *    Value:  Length of the entire TPKT in octets, including Packet
 *            Header
 */

/*
 * Send 'len' bytes from buf as one TPKT: a 4-byte header (version 3,
 * reserved 0, 16-bit big-endian total length) followed by the payload.
 *
 * Returns len + 4 on success and -1 on failure.  A negative len, or a len
 * whose total length (len + 4) does not fit the 16-bit length field, is
 * rejected with -1 before anything is sent.
 */
int msgSockTpktSend(msgSock _this, const char* buf, int len)
{
	unsigned char tpktHeader[4];
	int total;

	/* the packet length field is 16 bits and includes the header */
	if (len < 0 || len > 0xFFFF - 4)
		return -1;

	total = len + 4;

	tpktHeader[0] = 3;	/* Protocol Version Number, always 3 */
	tpktHeader[1] = 0;	/* Reserved, 0 */
	tpktHeader[2] = (unsigned char)((total >> 8) & 0xFF);	/* high byte of length */
	tpktHeader[3] = (unsigned char)(total & 0xFF);		/* low byte of length */

	if (msgSockSend(_this, (const char *)tpktHeader, 4) != 4)
		return -1;

	if (msgSockSend(_this, buf, len) != len)
		return -1;

	return total;
}

/*
 * Receive one TPKT and store its payload (the TPDU) in buf.
 *
 * buf: destination buffer.
 * len: capacity of buf in bytes.
 *
 * Returns the payload length on success.  Returns 0 when the 4-byte
 * header could not be read (msgSockRecv() returned -1 or 0) or when
 * msgSockRecv() returned 0 while reading the payload.  Returns -1 when
 * the header announces a total length below 4 or a payload larger than
 * len, or when receiving the payload fails.  After a -1 caused by a bad
 * length the unread payload is still in the stream, so the connection
 * should not be reused.
 */
int msgSockTpktRecv(msgSock _this, char* buf, int len)
{
	int nRecv, nTPDU;
	unsigned char tpktHeader[4];

	nRecv = msgSockRecv(_this, (char*)tpktHeader, 4);
	if (nRecv == -1)
		return 0;	/* existing contract: header read failure yields 0 */
	if (nRecv != 4)
		return (nRecv > 0) ? -1 : 0;

	/* total length is big-endian and includes the 4 header bytes */
	nTPDU = (((int)tpktHeader[2]) << 8) + tpktHeader[3];
	nTPDU -= 4;

	/* the length comes from the peer: never trust it beyond buf's size */
	if (nTPDU < 0 || nTPDU > len)
		return -1;

	nRecv = msgSockRecv(_this, buf, nTPDU);
	if (nRecv != -1 && nRecv != nTPDU)
		nRecv = (nRecv > 0) ? -1 : 0;

	return nRecv;
}

/*
 * Send one datagram of 'len' bytes from buf to raddr.
 *
 * Returns the number of bytes sent as reported by sendto(), or -1 if the
 * socket is not a msgSockDgram socket or sendto() fails.
 */
int msgSockSendto(msgSock _this, const char* buf, int len, msgSockAddr raddr)
{
	int rc;

	if (_this->type_ != msgSockDgram)
		return -1;

	rc = sendto(_this->sockfd_, buf, len, 0 /* No flags */,
		    (struct sockaddr*)&(raddr->sockaddr_),
		    sizeof(raddr->sockaddr_));

	return (rc == SOCKET_ERROR) ? -1 : rc;
}

/*
 * Receive one datagram of at most 'len' bytes into buf.
 *
 * On success the sender's address becomes the socket's remote address
 * (readable with msgSockGetRaddr()).
 *
 * Returns the number of bytes received as reported by recvfrom(), or -1
 * if the socket is not a msgSockDgram socket or recvfrom() fails.
 */
int msgSockRecvfrom(msgSock _this, char* buf, int len)
{
	msgSockAddr raddr;
	int nrecv;
	UINT32 addrlen;
	struct sockaddr_in addr;

	if (_this->type_ != msgSockDgram)
		return -1;

	memset(&addr, 0, sizeof(addr));
	addrlen = sizeof(addr);
	nrecv = recvfrom(_this->sockfd_,
			buf,
			len,
			0, /* no flags */
			(struct sockaddr*)&addr,
			&addrlen);

	if (nrecv == SOCKET_ERROR)
		return -1;

	raddr = msgSockAddrNew(inet_ntoa(addr.sin_addr), ntohs((UINT16)addr.sin_port));
	msgSockSetRaddr(_this, raddr);
	/* msgSockSetRaddr() stored its own copy; release the temporary so it
	 * is not leaked on every datagram */
	msgSockAddrFree(raddr);

	return nrecv;
}

/*========================================================================
// msgSock Tracer */
/*
 * Install 'tracer' as the tracer handle used by this module's TCRPrint()
 * calls, replacing whatever handle was set before.
 */
void msgSockSetTracer(TCR tracer)
{
	sock_tracer = tracer;
	return;
}

/*
 * Return the tracer handle currently used by this module (NULL until one
 * is installed with msgSockSetTracer()).
 */
TCR msgSockGetTracer(void)
{
	TCR current = sock_tracer;

	return current;
}

