#include "socket.hpp"

extern "C" {
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <errno.h>      // for errno / EINTR
#include <string.h>     // for memcpy() / memset()
#include <netdb.h>      // for gethostbyname() / gethostbyaddr()
#include <unistd.h>     // for read() / write() / close() / sleep()
}

#include <cstddef>
#include <string>

namespace {

// Closes `descriptor` (when valid) and throws dsl::socket_error(message).
// Constructors use this on failure: a constructor that throws never runs
// the destructor, so the descriptor would otherwise leak.
void close_and_throw(int descriptor, const std::string& message)
{
    if (descriptor >= 0)
        close(descriptor);
    throw dsl::socket_error(message);
}

// Writes exactly `num_bytes` bytes from `buffer`, looping over partial
// writes and retrying when interrupted by a signal (EINTR).
// Throws dsl::socket_error if write() fails or makes no progress.
void write_all(int descriptor, const void* buffer, std::size_t num_bytes)
{
    const char* cursor = static_cast<const char*>(buffer);
    while (num_bytes > 0) {
        const ssize_t n = write(descriptor, cursor, num_bytes);
        if (n < 0 && errno == EINTR)
            continue;
        if (n <= 0)
            throw dsl::socket_error("error in write() system call");
        cursor += n;
        num_bytes -= static_cast<std::size_t>(n);
    }
}

// Reads exactly `num_bytes` bytes into `buffer`, looping over short reads
// and retrying on EINTR. Throws dsl::socket_error on a read() error, or if
// the peer closes the connection (read() returns 0) before all bytes arrive.
void read_all(int descriptor, void* buffer, std::size_t num_bytes)
{
    char* cursor = static_cast<char*>(buffer);
    while (num_bytes > 0) {
        const ssize_t n = read(descriptor, cursor, num_bytes);
        if (n < 0 && errno == EINTR)
            continue;
        if (n < 0)
            throw dsl::socket_error("error in read() system call");
        if (n == 0)
            throw dsl::socket_error("connection closed by peer");
        cursor += n;
        num_bytes -= static_cast<std::size_t>(n);
    }
}

// Creates a TCP (AF_INET / SOCK_STREAM) socket and fills `socket_address`
// with the first IPv4 address of location.host_name and location.port.
// Returns the new descriptor; the caller owns it.
// Throws dsl::socket_error if socket() fails or the host cannot be resolved
// to an IPv4 address. The descriptor is closed before a resolution failure
// is thrown.
int setup_socket(const dsl::location_t& location, struct sockaddr_in& socket_address)
{
    const int descriptor = socket(AF_INET, SOCK_STREAM, 0);
    if (descriptor < 0)
        throw dsl::socket_error("failed in socket() system call");

    const struct hostent* hp = gethostbyname(location.host_name.c_str());
    // Only an IPv4 result fits into sockaddr_in::sin_addr; reject anything
    // else rather than copying h_length bytes blindly.
    if (hp == 0
        || hp->h_addrtype != AF_INET
        || hp->h_length != static_cast<int>(sizeof(struct in_addr))
        || hp->h_addr_list[0] == 0)
        close_and_throw(descriptor, "could not find host " + location.host_name);

    // Zero the whole structure so sin_zero (and any padding) is not garbage.
    memset(&socket_address, 0, sizeof(socket_address));
    socket_address.sin_family = AF_INET;
    socket_address.sin_port = htons(location.port);
    memcpy(&socket_address.sin_addr, hp->h_addr_list[0], sizeof(socket_address.sin_addr));
    return descriptor;
}

} // local namespace

namespace dsl {

//=============================================================================
// socket_acceptor

// Creates a socket bound to `location` and puts it in listening state
// (backlog of 15 pending connections).
// Throws socket_error on failure; the descriptor is closed before throwing.
socket_acceptor::socket_acceptor(const location_t& location)
    : port_number(location.port)
{
    descriptor = setup_socket(location, socket_address);

    if (bind(descriptor, reinterpret_cast<struct sockaddr*>(&socket_address),
             sizeof(socket_address)) < 0)
        close_and_throw(descriptor, "failed in bind() system call");

    const int max_pending_connects = 15;
    if (listen(descriptor, max_pending_connects) != 0)
        close_and_throw(descriptor, "failed in listen() system call");
}

socket_acceptor::~socket_acceptor()
{
    close(descriptor);
}

//=============================================================================
// socket_connection

// Performs a single write() of up to `num_bytes` bytes and returns the count
// actually written, which may be fewer than requested. Retries on EINTR.
// Throws socket_error if write() fails.
std::size_t socket_connection::send(const void* buffer, std::size_t num_bytes) const
{
    ssize_t num_written;
    do {
        num_written = write(descriptor, buffer, num_bytes);
    } while (num_written < 0 && errno == EINTR);

    if (num_written < 0)
        throw socket_error("error in write() system call");
    return static_cast<std::size_t>(num_written);
}

// Performs a single read() of up to `num_bytes` bytes and returns the count
// actually read. The count may be fewer than requested, and is 0 at end of
// stream. Retries on EINTR. Throws socket_error if read() fails.
std::size_t socket_connection::receive(void* buffer, std::size_t num_bytes) const
{
    ssize_t num_read;
    do {
        num_read = read(descriptor, buffer, num_bytes);
    } while (num_read < 0 && errno == EINTR);

    if (num_read < 0)
        throw socket_error("error in read() system call");
    return static_cast<std::size_t>(num_read);
}

// Sends `str` as a length-prefixed message: first the raw std::size_t
// length, then the bytes.
// The prefix is written in host byte order at the host's sizeof(size_t),
// so both peers must share word size and endianness.
// Throws socket_error on failure.
void socket_connection::send(const std::string& str) const
{
    const std::size_t len = str.size();
    write_all(descriptor, &len, sizeof(len));
    write_all(descriptor, str.data(), len);
}

// Receives a message written by send(const std::string&) into `str`.
// Throws socket_error on a read error, or if the peer closes the
// connection mid-message.
// NOTE(review): the length prefix comes from the peer unchecked; a hostile
// or corrupt peer can request an arbitrarily large allocation here.
void socket_connection::receive(std::string& str) const
{
    std::size_t len = 0;
    read_all(descriptor, &len, sizeof(len));

    str.resize(len);
    if (len > 0)
        read_all(descriptor, &str[0], len);
}

//=============================================================================
// server_connection

// Blocks in accept() on the acceptor's socket until a client connects
// (retrying on EINTR). Then verifies that a reverse lookup of the peer
// address succeeds.
// Throws socket_error on failure; an accepted descriptor is closed first.
server_connection::server_connection(const socket_acceptor& s)
{
    struct sockaddr_in incoming;
    socklen_t len;
    do {
        len = sizeof(incoming);
        descriptor = accept(s.descriptor, reinterpret_cast<struct sockaddr*>(&incoming), &len);
    } while (descriptor < 0 && errno == EINTR);

    if (descriptor < 0)
        throw socket_error("failed in accept() system call");

    if (gethostbyaddr(reinterpret_cast<const char*>(&incoming.sin_addr),
                      sizeof(struct in_addr), AF_INET) == NULL)
        close_and_throw(descriptor, "failed to get host by address");
}

server_connection::~server_connection()
{
    close(descriptor);
}

//=============================================================================
// client_connection

// Connects to `location`, then verifies that a reverse lookup of the server
// address succeeds.
// connect() is attempted once, then retried up to five more times, one
// second apart.
// Throws socket_error on failure; no descriptor is leaked.
client_connection::client_connection(const location_t& location)
    : port_number(location.port)
{
    descriptor = setup_socket(location, socket_address);

    const int max_retries = 5;
    for (int attempt = 0; ; ++attempt) {
        if (connect(descriptor, reinterpret_cast<struct sockaddr*>(&socket_address),
                    sizeof(socket_address)) == 0)
            break;

        // After a failed connect() the socket's state is unspecified, so it
        // is closed and replaced before the next attempt.
        close(descriptor);
        descriptor = -1;
        if (attempt >= max_retries)
            throw socket_error("error in connect() system call");

        sleep(1);
        descriptor = socket(AF_INET, SOCK_STREAM, 0);
        if (descriptor < 0)
            throw socket_error("failed in socket() system call");
    }

    if (gethostbyaddr(reinterpret_cast<const char*>(&socket_address.sin_addr),
                      sizeof(struct in_addr), AF_INET) == NULL)
        close_and_throw(descriptor, "get host by address failed");
}

client_connection::~client_connection()
{
    close(descriptor);
}

} // namespace dsl