jmulticastsocket.cpp 7.27 KB
/***************************************************************************
 *   Copyright (C) 2005 by Jeff Ferr                                       *
 *   root@sat                                                              *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/
#include "jmulticastsocket.h"
#include "jsocketexception.h"
#include "jioexception.h"
#include <string.h>
#include <errno.h>

namespace jsocket {

// Class-wide port value shared by all MulticastSocket instances, starting at 1024.
// NOTE(review): not referenced anywhere in this file; its purpose is not visible here.
uint16_t MulticastSocket::_used_port(1024);

/**
 * Creates a UDP multicast endpoint for the group @p host_ on @p port_.
 *
 * Opens separate send (_fds) and receive (_fdr) descriptors, fills in both
 * socket addresses, binds the receive descriptor, joins the group and finally
 * wraps the descriptors in std::istream / std::ostream objects.
 *
 * @param host_ group address, resolved through InetAddress::GetByName()
 * @param port_ UDP port used for both sending and receiving
 *
 * Any exception raised during setup is rethrown. Before rethrowing, every
 * descriptor and stream object created so far is released, because the
 * destructor does not run for a partially constructed object.
 */
MulticastSocket::MulticastSocket(std::string host_, uint16_t port_)
{
	_buffer = NULL;
	_sock_opt = NULL;
	_is = NULL;
	_os = NULL;

	// Set up front so the counters are valid on every path and the failure
	// handler below can tell which descriptors were actually opened.
	_sent_bytes = 0;
	_receive_bytes = 0;
	_fds = -1;
	_fdr = -1;

	try {
		CreateSocket();

		// Resolve the group once; every setup step below only reads it.
		// NOTE(review): ownership of the returned InetAddress is not visible
		// here (the original never released it either) -- confirm whether it
		// must be deleted by the caller.
		InetAddress *group = InetAddress::GetByName(host_);

		ConnectSocket(group, port_);
		BindSocket(group, port_);
		Join(group);
		InitStream();
	} catch (...) {
		delete _is;
		delete _os;
		delete _buffer;

		_is = NULL;
		_os = NULL;
		_buffer = NULL;

		if (_fdr >= 0) {
			::close(_fdr);
			_fdr = -1;
		}

		if (_fds >= 0) {
			::close(_fds);
			_fds = -1;
		}

		throw;
	}
}

/**
 * Closes both descriptors and releases the stream objects, the stream
 * buffer and the socket-option object owned by this instance.
 */
MulticastSocket::~MulticastSocket()
{
	// A destructor must not throw, so Close() failures are deliberately ignored.
	try {
		Close();
	} catch (...) {
	}

	// Release the streams before the buffer they were constructed on.
	// delete on NULL is a no-op, so no guards are needed. (The original
	// guarded the _sock_opt deletion with a _buffer test, leaking _sock_opt
	// whenever _buffer was NULL.)
	delete _is;
	delete _os;
	delete _buffer;
	delete _sock_opt;
}

/** Private */

/**
 * Opens the two UDP descriptors: _fds for sending and _fdr for receiving.
 *
 * @throws SocketException if either socket() call fails. If the second call
 *         fails, the already-open sender descriptor is closed first, so it
 *         does not leak.
 */
void MulticastSocket::CreateSocket()
{
	_fds = ::socket(PF_INET, SOCK_DGRAM, IPPROTO_IP); // IPPROTO_MTP

	if (_fds < 0) {
		throw SocketException("Create multicast socket error !");
	}

	_fdr = ::socket(PF_INET, SOCK_DGRAM, IPPROTO_IP); // IPPROTO_MTP

	if (_fdr < 0) {
		// Don't leak the sender descriptor opened above.
		::close(_fds);
		_fds = -1;

		throw SocketException("Create multicast socket error !");
	}
}

/**
 * Binds the receive descriptor (_fdr) to the address previously stored in
 * _sock_r. Both parameters are currently ignored; only _sock_r is used.
 *
 * @throws SocketException when bind() fails
 */
void MulticastSocket::BindSocket(InetAddress *addr_, uint16_t local_port_)
{
	const int status = ::bind(_fdr, (SA *)&_sock_r, sizeof(_sock_r));

	if (status >= 0) {
		return;
	}

	throw SocketException("Bind multicast socket error !");
}

/**
 * Fills in the receive address (_sock_r) and the send address (_sock_s).
 * No connect() call is made; the addresses are used later by bind(),
 * recvfrom() and sendto().
 *
 * @param addr_ group address; when NULL, INADDR_ANY is used instead
 * @param port_ UDP port for both addresses
 */
void MulticastSocket::ConnectSocket(InetAddress *addr_, uint16_t port_)
{
	// Group address in network byte order, or INADDR_ANY when none is given.
	// INADDR_ANY is a numeric constant, so it must go through htonl(); the
	// original passed it to inet_addr(), which expects a dotted string.
	in_addr_t group = htonl(INADDR_ANY);

	if (addr_ != NULL) {
		group = inet_addr(addr_->GetHostAddress().c_str());
	}

	// Receive
	memset(&_sock_r, 0, sizeof(_sock_r));

	_sock_r.sin_family = AF_INET;
	_sock_r.sin_port = htons(port_);

#ifdef SOLARIS
	_sock_r.sin_addr.s_addr = htonl(INADDR_ANY);
#else
	_sock_r.sin_addr.s_addr = group;
#endif

	// Send: datagrams must be addressed to the group itself. The original
	// used INADDR_ANY here, so sendto() never targeted the multicast group.
	memset(&_sock_s, 0, sizeof(_sock_s));

	_sock_s.sin_family = AF_INET;
	_sock_s.sin_port = htons(port_);
	_sock_s.sin_addr.s_addr = group;
}

void MulticastSocket::InitStream()
{
	_buffer = new MulticastSocketBuffer(_fds, _fdr, _sock_s, _sock_r);
    
	_is = new std::istream(_buffer); // dsb
	_os = new std::ostream(_buffer);
}

/** End */

/**
 * @return the input stream created by InitStream(); it remains owned by
 *         this socket.
 */
std::istream & MulticastSocket::GetInputStream()
{
	std::istream &in = *_is;

	return in;
}

/**
 * @return the output stream created by InitStream(); it remains owned by
 *         this socket.
 */
std::ostream & MulticastSocket::GetOutputStream()
{
	std::ostream &out = *_os;

	return out;
}

/**
 * Receives one datagram of at most @p size bytes into @p data.
 *
 * @return number of bytes received
 * @throws SocketException when recvfrom() fails with EAGAIN/EWOULDBLOCK
 *         (reported as a timeout; presumably a receive timeout or
 *         non-blocking mode was configured elsewhere -- confirm)
 * @throws IOException on any other recvfrom() failure
 */
int MulticastSocket::Receive(char *data, int size)
{
	// Store the sender's address in a local. Passing _sock_r here (as the
	// original did) overwrote the bound local address that GetLocalPort()
	// reads.
	struct sockaddr_in from;
	socklen_t length = sizeof(from);

	int n = ::recvfrom(_fdr, data, size, 0, (SA *)&from, &length);

	if (n < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
		throw SocketException("Socket timeout exception !");
	} else if (n < 0) {
		throw IOException("Read socket error !");
	}

	_receive_bytes += n;

	return n;
}

/**
 * Sends @p size bytes from @p data as one datagram to the address held in
 * _sock_s, using the send descriptor.
 *
 * @return number of bytes sent; this is also added to the sent-bytes counter
 * @throws IOException when sendto() fails
 */
int MulticastSocket::Send(char *data, int size)
{
	const int written = ::sendto(_fds, data, size, 0, (SA *)&_sock_s, sizeof(_sock_s));

	if (written >= 0) {
		_sent_bytes += written;

		return written;
	}

	throw IOException("Write udp data error !");
}

/**
 * Joins the multicast group @p group_ (dotted-quad string) on the receive
 * descriptor, using INADDR_ANY as the interface.
 *
 * The group is recorded in the group list only after the kernel accepts
 * the membership.
 *
 * @throws SocketException when IP_ADD_MEMBERSHIP fails
 */
void MulticastSocket::Join(std::string group_)
{
	struct ip_mreq request;

	request.imr_interface.s_addr = htonl(INADDR_ANY);
	request.imr_multiaddr.s_addr = inet_addr(group_.c_str());

	const int status = setsockopt(_fdr, IPPROTO_IP, IP_ADD_MEMBERSHIP, &request, sizeof(request));

	if (status < 0) {
		throw SocketException("Join group error !");
	}

	_groups.push_back(group_);
}

/**
 * Joins the multicast group described by @p group_.
 *
 * Delegates to Join(std::string), so both overloads share a single
 * implementation.
 *
 * @throws SocketException when @p group_ is NULL or the join fails
 */
void MulticastSocket::Join(InetAddress *group_)
{
	if (group_ == NULL) {
		throw SocketException("Join group error !");
	}

	Join(group_->GetHostAddress());
}

/**
 * Leaves the multicast group @p group_ if it was previously joined.
 * Groups that are not in the group list are silently ignored.
 *
 * The group is now also removed from the group list; the original never
 * erased it, so GetGroupList() went stale and a second Leave() of the same
 * group failed in the kernel.
 *
 * @throws SocketException when IP_DROP_MEMBERSHIP fails (the list is left
 *         unchanged in that case)
 */
void MulticastSocket::Leave(std::string group_)
{
	for (std::vector<std::string>::iterator i = _groups.begin(); i != _groups.end(); ++i) {
		if (*i != group_) {
			continue;
		}

		struct ip_mreq imr;

		imr.imr_multiaddr.s_addr = inet_addr(group_.c_str());
		imr.imr_interface.s_addr = htonl(INADDR_ANY);

		if (setsockopt(_fdr, IPPROTO_IP, IP_DROP_MEMBERSHIP, &imr, sizeof(struct ip_mreq)) < 0) {
			throw SocketException("Leave group error !");
		}

		// The iterator is invalidated by erase(), so return immediately.
		_groups.erase(i);

		return;
	}
}

/**
 * Leaves the multicast group described by @p group_.
 *
 * Delegates to Leave(std::string), so both overloads share a single
 * implementation.
 *
 * @throws SocketException when @p group_ is NULL or the leave fails
 */
void MulticastSocket::Leave(InetAddress *group_)
{
	if (group_ == NULL) {
		throw SocketException("Leave group error !");
	}

	Leave(group_->GetHostAddress());
}

/**
 * @return a mutable reference to the internal list of joined group
 *         addresses (as recorded by Join()); modifying it changes this
 *         socket's list directly.
 */
std::vector<std::string> & MulticastSocket::GetGroupList()
{
	std::vector<std::string> &joined = _groups;

	return joined;
}

/**
 * Closes the receive and send descriptors.
 *
 * Both descriptors are always attempted, even if the first close fails.
 * Each is marked closed (-1), so a repeated Close() -- for example from the
 * destructor after an explicit Close() -- cannot close an unrelated
 * descriptor that has since reused the same number.
 *
 * @throws IOException if closing either descriptor fails (receiver error
 *         is reported first)
 */
void MulticastSocket::Close()
{
	int r = 0;
	int s = 0;

	if (_fdr >= 0) {
		r = close(_fdr);
		_fdr = -1;
	}

	if (_fds >= 0) {
		s = close(_fds);
		_fds = -1;
	}

	if (r < 0) {
		throw IOException("Close multicast receiver descriptor error !");
	}

	if (s < 0) {
		throw IOException("Close multicast sender descriptor error !");
	}
}

/**
 * @return the port stored in the receive address _sock_r, converted to
 *         host byte order.
 */
uint16_t MulticastSocket::GetLocalPort()
{
	const uint16_t port = ntohs(_sock_r.sin_port);

	return port;
}

/**
 * @return the bytes counted by Send() if that counter is non-zero, otherwise
 *         the count reported by the stream buffer.
 *
 * NOTE(review): the two counters are never combined. Once Send() has been
 * used, traffic written through GetOutputStream() is no longer reported.
 * Confirm this is intended.
 */
uint64_t MulticastSocket::GetSentBytes()
{
	if (_sent_bytes == 0) {
		return _buffer->GetSentBytes();
	}

	return _sent_bytes;
}

/**
 * @return the bytes counted by Receive() if that counter is non-zero,
 *         otherwise the count reported by the stream buffer.
 *
 * NOTE(review): the two counters are never combined. Once Receive() has been
 * used, traffic read through GetInputStream() is no longer reported.
 * Confirm this is intended.
 */
uint64_t MulticastSocket::GetReceiveBytes()
{
	if (_receive_bytes == 0) {
		return _buffer->GetReceiveBytes();
	}

	return _receive_bytes;
}

/**
 * Sets the IP_MULTICAST_TTL option on the send descriptor.
 *
 * @param ttl_ time-to-live for outgoing multicast datagrams (one byte)
 * @throws SocketException when setsockopt() reports an error
 */
void MulticastSocket::SetMulticastTTL(uint8_t ttl_)
{
	const int status = setsockopt(_fds, IPPROTO_IP, IP_MULTICAST_TTL, &ttl_, sizeof(ttl_));

	if (status != 0) {
		throw SocketException("Seting multicast ttl error !");
	}
}

/**
 * @return a newly allocated SocketOption bound to the receive descriptor,
 *         created with SOCK_MCAST. A fresh object is created on every call.
 *
 * NOTE(review): the result is not stored in _sock_opt, so the caller
 * presumably owns it and must delete it -- confirm against callers.
 */
SocketOption * MulticastSocket::GetSocketOption()
{
	SocketOption *option = new SocketOption(_fdr, SOCK_MCAST);

	return option;
}

/**
 * @return the constant name "Multicast" identifying this socket kind.
 */
std::string MulticastSocket::what()
{
	return std::string("Multicast");
}

};