Loading src/main/socket_wrapper.cc +92 −86 Original line number Diff line number Diff line Loading @@ -25,6 +25,7 @@ #include <string> #include <exception> #include <stdexcept> #include <openssl/err.h> #include "glog/logging.h" #include "socket_wrapper.h" Loading @@ -33,8 +34,10 @@ namespace { pthread_mutex_t* mutex_buffer = NULL; void pthread_mutex_funs(int mode, int index, const char* file, int line) { void pthread_mutex_funs(int mode, int index, const char* file, int line) { if (mode & CRYPTO_LOCK) { pthread_mutex_lock(&mutex_buffer[index]); } else { Loading @@ -42,13 +45,11 @@ namespace { } } unsigned long pthread_id_fun() { unsigned long pthread_id_fun() { return (unsigned long) pthread_self(); } void SSL_register_locks() { void SSL_register_locks() { const int num_locks = CRYPTO_num_locks(); mutex_buffer = (pthread_mutex_t*) malloc(num_locks * sizeof(pthread_mutex_t)); if (!mutex_buffer) { Loading @@ -62,8 +63,7 @@ namespace { CRYPTO_set_locking_callback(pthread_mutex_funs); } void SSL_free_locks() { void SSL_free_locks() { if (mutex_buffer) { CRYPTO_set_id_callback(NULL); CRYPTO_set_locking_callback(NULL); Loading @@ -74,42 +74,40 @@ namespace { mutex_buffer = NULL; } } } } // namespace #else namespace { void SSL_register_locks() { void SSL_register_locks() { LOG(INFO) << "No locks configured for OpenSSL. Do so yourself if you require thread-safety."; } void SSL_free_locks() {} } #endif namespace kinetic { using std::string; class OpenSSLInitializer { class OpenSSLInitializer { public: OpenSSLInitializer() { OpenSSLInitializer() { SSL_library_init(); SSL_register_locks(); SSL_load_error_strings(); OpenSSL_add_all_algorithms(); } ~OpenSSLInitializer() { ~OpenSSLInitializer() { SSL_free_locks(); } }; static OpenSSLInitializer init; SocketWrapper::SocketWrapper(const std::string& host, int port, bool use_ssl, bool nonblocking) SocketWrapper::SocketWrapper(const std::string& host, int port, bool use_ssl, bool nonblocking) : ctx_(NULL), ssl_(NULL), host_(host), port_(port), nonblocking_(nonblocking), fd_(-1) { if (use_ssl) { ctx_ = SSL_CTX_new(SSLv23_client_method()); Loading @@ -122,21 +120,20 @@ SocketWrapper::SocketWrapper(const std::string& host, int port, bool use_ssl, bo } SocketWrapper::~SocketWrapper() { if (fd_ == -1) { LOG(INFO) << "Not connected so no cleanup needed"; } else { LOG(INFO) << "Closing socket with fd " << fd_; if (fd_ != -1) { if (close(fd_)) { PLOG(ERROR) << "Error closing socket fd " << fd_; } } if(ssl_) SSL_free(ssl_); if(ctx_) SSL_CTX_free(ctx_); if (ssl_) { SSL_free(ssl_); } if (ctx_) { SSL_CTX_free(ctx_); } } bool SocketWrapper::Connect() { LOG(INFO) << "Connecting to " << host_ << ":" << port_; struct addrinfo hints; memset(&hints, 0, sizeof(struct addrinfo)); Loading @@ -151,8 +148,7 @@ bool SocketWrapper::Connect() { string port_str = std::to_string(static_cast<long long>(port_)); if (int res = getaddrinfo(host_.c_str(), port_str.c_str(), &hints, &result) != 0) { LOG(ERROR) << "Could not resolve host " << host_ << " port " << port_ << ": " << gai_strerror(res); LOG(ERROR) << "Could not resolve host " << host_ << " port " << port_ << ": " << gai_strerror(res); return false; } Loading @@ -161,12 +157,15 @@ bool SocketWrapper::Connect() { for (ai = result; ai != NULL; ai = ai->ai_next) { char host[NI_MAXHOST]; char service[NI_MAXSERV]; if (int res = getnameinfo(ai->ai_addr, ai->ai_addrlen, host, sizeof(host), service, sizeof(service), NI_NUMERICHOST | NI_NUMERICSERV) != 0) { if (int res = getnameinfo(ai->ai_addr, ai->ai_addrlen, host, sizeof(host), service, sizeof(service), NI_NUMERICHOST | NI_NUMERICSERV) != 0) { LOG(ERROR) << "Could not get name info: " << gai_strerror(res); continue; } else { LOG(INFO) << "Trying to connect to " << string(host) << " on " << string(service); } socket_fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol); Loading Loading @@ -227,25 +226,32 @@ bool SocketWrapper::Connect() { } fd_ = socket_fd; if(ssl_) return ConnectSSL(); if (ssl_) { return ConnectSSL(); } return true; } #include <openssl/err.h> bool SocketWrapper::ConnectSSL() { bool SocketWrapper::ConnectSSL() { SSL_set_fd(ssl_, fd_); int rtn = SSL_connect(ssl_); if(rtn == 1) if (rtn == 1) { return true; } int err = SSL_get_error(ssl_, rtn); if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { fd_set read_fds, write_fds; FD_ZERO(&read_fds); FD_ZERO(&write_fds); if(err == SSL_ERROR_WANT_READ) FD_SET(fd_, &read_fds); if(err == SSL_ERROR_WANT_WRITE) FD_SET(fd_, &write_fds); FD_ZERO(&read_fds); FD_ZERO(&write_fds); if (err == SSL_ERROR_WANT_READ) { FD_SET(fd_, &read_fds); } if (err == SSL_ERROR_WANT_WRITE) { FD_SET(fd_, &write_fds); } struct timeval tv = {1, 1}; select(fd_ + 1, &read_fds, &write_fds, NULL, &tv); return ConnectSSL(); Loading Loading
src/main/socket_wrapper.cc +92 −86 Original line number Diff line number Diff line Loading @@ -25,6 +25,7 @@ #include <string> #include <exception> #include <stdexcept> #include <openssl/err.h> #include "glog/logging.h" #include "socket_wrapper.h" Loading @@ -33,8 +34,10 @@ namespace { pthread_mutex_t* mutex_buffer = NULL; void pthread_mutex_funs(int mode, int index, const char* file, int line) { void pthread_mutex_funs(int mode, int index, const char* file, int line) { if (mode & CRYPTO_LOCK) { pthread_mutex_lock(&mutex_buffer[index]); } else { Loading @@ -42,13 +45,11 @@ namespace { } } unsigned long pthread_id_fun() { unsigned long pthread_id_fun() { return (unsigned long) pthread_self(); } void SSL_register_locks() { void SSL_register_locks() { const int num_locks = CRYPTO_num_locks(); mutex_buffer = (pthread_mutex_t*) malloc(num_locks * sizeof(pthread_mutex_t)); if (!mutex_buffer) { Loading @@ -62,8 +63,7 @@ namespace { CRYPTO_set_locking_callback(pthread_mutex_funs); } void SSL_free_locks() { void SSL_free_locks() { if (mutex_buffer) { CRYPTO_set_id_callback(NULL); CRYPTO_set_locking_callback(NULL); Loading @@ -74,42 +74,40 @@ namespace { mutex_buffer = NULL; } } } } // namespace #else namespace { void SSL_register_locks() { void SSL_register_locks() { LOG(INFO) << "No locks configured for OpenSSL. Do so yourself if you require thread-safety."; } void SSL_free_locks() {} } #endif namespace kinetic { using std::string; class OpenSSLInitializer { class OpenSSLInitializer { public: OpenSSLInitializer() { OpenSSLInitializer() { SSL_library_init(); SSL_register_locks(); SSL_load_error_strings(); OpenSSL_add_all_algorithms(); } ~OpenSSLInitializer() { ~OpenSSLInitializer() { SSL_free_locks(); } }; static OpenSSLInitializer init; SocketWrapper::SocketWrapper(const std::string& host, int port, bool use_ssl, bool nonblocking) SocketWrapper::SocketWrapper(const std::string& host, int port, bool use_ssl, bool nonblocking) : ctx_(NULL), ssl_(NULL), host_(host), port_(port), nonblocking_(nonblocking), fd_(-1) { if (use_ssl) { ctx_ = SSL_CTX_new(SSLv23_client_method()); Loading @@ -122,21 +120,20 @@ SocketWrapper::SocketWrapper(const std::string& host, int port, bool use_ssl, bo } SocketWrapper::~SocketWrapper() { if (fd_ == -1) { LOG(INFO) << "Not connected so no cleanup needed"; } else { LOG(INFO) << "Closing socket with fd " << fd_; if (fd_ != -1) { if (close(fd_)) { PLOG(ERROR) << "Error closing socket fd " << fd_; } } if(ssl_) SSL_free(ssl_); if(ctx_) SSL_CTX_free(ctx_); if (ssl_) { SSL_free(ssl_); } if (ctx_) { SSL_CTX_free(ctx_); } } bool SocketWrapper::Connect() { LOG(INFO) << "Connecting to " << host_ << ":" << port_; struct addrinfo hints; memset(&hints, 0, sizeof(struct addrinfo)); Loading @@ -151,8 +148,7 @@ bool SocketWrapper::Connect() { string port_str = std::to_string(static_cast<long long>(port_)); if (int res = getaddrinfo(host_.c_str(), port_str.c_str(), &hints, &result) != 0) { LOG(ERROR) << "Could not resolve host " << host_ << " port " << port_ << ": " << gai_strerror(res); LOG(ERROR) << "Could not resolve host " << host_ << " port " << port_ << ": " << gai_strerror(res); return false; } Loading @@ -161,12 +157,15 @@ bool SocketWrapper::Connect() { for (ai = result; ai != NULL; ai = ai->ai_next) { char host[NI_MAXHOST]; char service[NI_MAXSERV]; if (int res = getnameinfo(ai->ai_addr, ai->ai_addrlen, host, sizeof(host), service, sizeof(service), NI_NUMERICHOST | NI_NUMERICSERV) != 0) { if (int res = getnameinfo(ai->ai_addr, ai->ai_addrlen, host, sizeof(host), service, sizeof(service), NI_NUMERICHOST | NI_NUMERICSERV) != 0) { LOG(ERROR) << "Could not get name info: " << gai_strerror(res); continue; } else { LOG(INFO) << "Trying to connect to " << string(host) << " on " << string(service); } socket_fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol); Loading Loading @@ -227,25 +226,32 @@ bool SocketWrapper::Connect() { } fd_ = socket_fd; if(ssl_) return ConnectSSL(); if (ssl_) { return ConnectSSL(); } return true; } #include <openssl/err.h> bool SocketWrapper::ConnectSSL() { bool SocketWrapper::ConnectSSL() { SSL_set_fd(ssl_, fd_); int rtn = SSL_connect(ssl_); if(rtn == 1) if (rtn == 1) { return true; } int err = SSL_get_error(ssl_, rtn); if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { fd_set read_fds, write_fds; FD_ZERO(&read_fds); FD_ZERO(&write_fds); if(err == SSL_ERROR_WANT_READ) FD_SET(fd_, &read_fds); if(err == SSL_ERROR_WANT_WRITE) FD_SET(fd_, &write_fds); FD_ZERO(&read_fds); FD_ZERO(&write_fds); if (err == SSL_ERROR_WANT_READ) { FD_SET(fd_, &read_fds); } if (err == SSL_ERROR_WANT_WRITE) { FD_SET(fd_, &write_fds); } struct timeval tv = {1, 1}; select(fd_ + 1, &read_fds, &write_fds, NULL, &tv); return ConnectSSL(); Loading