Commit c3d8a644 authored by James Hughes's avatar James Hughes
Browse files

Merge pull request #25 from thaimai/master

Merge complete. This seems to be a combination of SSL support, Error handling and cork/uncork. In the future, it would be nicer for multiple pull requests. Also, please look at the comment in nonblocking_packet.cc to see if you can improve the performance in the >5KB value size.
parents 9c37022b 294a9d5d
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -74,10 +74,12 @@ class MessageStreamFactory : public MessageStreamFactoryInterface {
    MessageStreamFactory(SSL_CTX *ssl_context, IncomingValueFactoryInterface &value_factory);
    bool NewMessageStream(int fd, bool use_ssl, SSL *ssl, uint32_t max_message_size_bytes,
        MessageStreamInterface **message_stream);
    virtual ~MessageStreamFactory() {}
    ~MessageStreamFactory();

    private:
    SSL_CTX *ssl_context_;
    SSL *ssl_;
    bool ssl_created_;
    IncomingValueFactoryInterface &value_factory_;
    DISALLOW_COPY_AND_ASSIGN(MessageStreamFactory);
};
+2 −2
Original line number Diff line number Diff line
@@ -45,9 +45,9 @@ class MockMessageStream : public MessageStreamInterface {
class MockMessageStreamFactory : public MessageStreamFactoryInterface {
    public:
    MockMessageStreamFactory() {}
    virtual ~MockMessageStreamFactory() {}
    ~MockMessageStreamFactory() {}
    MOCK_METHOD4(NewMessageStream,
        bool(int fd, bool use_ssl, uint32_t max_message_size_bytes,
        bool(int fd, bool use_ssl, SSL *ssl, uint32_t max_message_size_bytes,
            MessageStreamInterface **message_stream));
};

+20 −8
Original line number Diff line number Diff line
@@ -107,32 +107,44 @@ int MessageStream::WriteMessage(const ::google::protobuf::Message &message,

MessageStreamFactory::MessageStreamFactory(SSL_CTX *ssl_context,
        IncomingValueFactoryInterface &value_factory)
    : ssl_context_(ssl_context), value_factory_(value_factory) {}
    : ssl_context_(ssl_context), value_factory_(value_factory) {
    ssl_created_ = false;
    }


MessageStreamFactory::~MessageStreamFactory() {
    if (ssl_created_) {
        SSL_free(ssl_);
    }
}

bool MessageStreamFactory::NewMessageStream(int fd, bool use_ssl, SSL *ssl, uint32_t max_message_size_bytes,
        MessageStreamInterface **message_stream) {
    if (use_ssl) {
        if (ssl == NULL) {
            SSL *ssl = SSL_new(ssl_context_);
            ssl_ = SSL_new(ssl_context_);
            // We want to automatically retry reads and writes when a renegotiation
            // takes place. This way the only errors we have to handle are real,
            // permanent ones.
            SSL_set_mode(ssl, SSL_MODE_AUTO_RETRY);
            if (ssl == NULL) {

            if (ssl_ == NULL) {
                LOG(ERROR) << "Failed to create new SSL object";
                return false;
            }
            if (SSL_set_fd(ssl, fd) != 1) {
            SSL_set_mode(ssl_, SSL_MODE_AUTO_RETRY);
            if (SSL_set_fd(ssl_, fd) != 1) {
                LOG(ERROR) << "Failed to associate SSL object with file descriptor";
                SSL_free(ssl);
                SSL_free(ssl_);
                return false;
            }
            if (SSL_accept(ssl) != 1) {
            if (SSL_accept(ssl_) != 1) {
                LOG(ERROR) << "Failed to perform SSL handshake";
                LOG(ERROR) << "The client may have attempted to use an SSL/TLS version below TLSv1.1";
                SSL_free(ssl);
                SSL_free(ssl_);
                return false;
            }
            ssl_created_ = true;
            ssl = ssl_;
        }
        LOG(INFO) << "Successfully performed SSL handshake";
        *message_stream = new MessageStream(max_message_size_bytes, new SslByteStream(ssl));
+25 −1
Original line number Diff line number Diff line
@@ -19,8 +19,12 @@
 */

#include "nonblocking_packet.h"

#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <sys/stat.h>

#include "glog/logging.h"

@@ -41,6 +45,16 @@ NonblockingPacketWriter::~NonblockingPacketWriter() {
}

NonblockingStringStatus NonblockingPacketWriter::Write() {
    struct stat statbuf;
    if (fstat(fd_, &statbuf)) {
        PLOG(ERROR) << "Unable to fstat socket";
        return kFailed;
    }
    if (S_ISSOCK(statbuf.st_mode)) {
        int optval = 1;
        setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &optval, sizeof(optval));
    }

    while (true) {
        NonblockingStringStatus status = writer_->Write();
        if (status != kDone) {
@@ -67,6 +81,16 @@ NonblockingStringStatus NonblockingPacketWriter::Write() {
                TransitionFromValue();
                break;
            case kFinished:
                if (fstat(fd_, &statbuf)) {
                    PLOG(ERROR) << "Unable to fstat socket";
                    return kFailed;
                }
                if (S_ISSOCK(statbuf.st_mode)) {
                    int optval = 0;
                    setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &optval, sizeof(optval));
                    optval = 1;
                    setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, &optval, sizeof(optval));
                }
                return kDone;
            default:
                CHECK(false);