Commit 0e551543 authored by Paul Lensing's avatar Paul Lensing
Browse files

Initial SSL support.

parent a11bedc7
Loading
Loading
Loading
Loading
+7 −3
Original line number Diff line number Diff line
@@ -35,6 +35,8 @@
#include "mock_callbacks.h"
#include "matchers.h"

#include "glog/logging.h"

namespace kinetic {

using ::testing::StrictMock;
@@ -57,7 +59,9 @@ class IntegrationTest : public ::testing::Test {
    IntegrationTest() : use_external_(false),
                        pid_(0), done_(false),
                        nonblocking_connection_(nullptr),
                        blocking_connection_(nullptr) {}
                        blocking_connection_(nullptr) {
          //google::LogToStderr();
    }

    void SetUp() {
        const char *kinetic_path = getenv("KINETIC_PATH");
@@ -71,8 +75,8 @@ class IntegrationTest : public ::testing::Test {
        }
        ConnectionOptions options;
        options.host = "localhost";
        options.port = 8123;
        options.use_ssl = false;
        options.port = 8443;
        options.use_ssl = true;
        options.user_id = 1;
        options.hmac_key = "asdfasdf";

+21 −23
Original line number Diff line number Diff line
@@ -115,18 +115,14 @@ Status KineticConnectionFactory::NewThreadsafeBlockingConnection(
Status KineticConnectionFactory::doNewConnection(
        ConnectionOptions const& options,
        unique_ptr <NonblockingKineticConnection>& connection, bool threadsafe) {
    auto socket_wrapper = make_shared<SocketWrapper>(options.host, options.port, true);

    if (!socket_wrapper->Connect()) {
        return Status::makeInternalError("Connection error");
    }
    try{
        auto socket_wrapper = make_shared<SocketWrapper>(options.host, options.port, options.use_ssl, true);
        if (!socket_wrapper->Connect())
            throw std::runtime_error("Could not connect to socket.");

        shared_ptr<NonblockingReceiverInterface> receiver;
    try{
        receiver = shared_ptr<NonblockingReceiverInterface>(new NonblockingReceiver(socket_wrapper, hmac_provider_, options));
    }catch(std::exception& e){
        return Status::makeInternalError("Connection error:"+std::string(e.what()));
    }

        auto writer_factory =
            unique_ptr<NonblockingPacketWriterFactoryInterface>(new NonblockingPacketWriterFactory());
        auto sender = unique_ptr<NonblockingSenderInterface>(new NonblockingSender(socket_wrapper,
@@ -139,7 +135,9 @@ Status KineticConnectionFactory::doNewConnection(
        } else {
            connection.reset(new NonblockingKineticConnection(service));
        }

    } catch(std::exception& e){
           return Status::makeInternalError("Connection error: "+std::string(e.what()));
    }
    return Status::makeOk();
}
} // namespace kinetic
+19 −19
Original line number Diff line number Diff line
@@ -33,10 +33,10 @@ namespace kinetic {
using std::make_shared;
using std::string;

NonblockingPacketWriter::NonblockingPacketWriter(int fd, unique_ptr<const Message> message,
NonblockingPacketWriter::NonblockingPacketWriter(shared_ptr<SocketWrapperInterface> socket_wrapper, unique_ptr<const Message> message,
    const shared_ptr<const string> value)
    : fd_(fd), message_(move(message)), value_(value), state_(kMagic),
    writer_(new NonblockingStringWriter(fd, make_shared<string>("F"))) {}
    : socket_wrapper_(socket_wrapper), message_(move(message)), value_(value), state_(kMagic),
    writer_(new NonblockingStringWriter(socket_wrapper_, make_shared<string>("F"))) {}

NonblockingPacketWriter::~NonblockingPacketWriter() {
    if (writer_ != NULL) {
@@ -46,14 +46,14 @@ NonblockingPacketWriter::~NonblockingPacketWriter() {

NonblockingStringStatus NonblockingPacketWriter::Write() {
    struct stat statbuf;
    if (fstat(fd_, &statbuf)) {
    if (fstat(socket_wrapper_->fd(), &statbuf)) {
        PLOG(ERROR) << "Unable to fstat socket";
        return kFailed;
    }
#ifndef __APPLE__
    if (S_ISSOCK(statbuf.st_mode)) {
        int optval = 1;
        setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &optval, sizeof(optval));
        setsockopt(socket_wrapper_->fd(), IPPROTO_TCP, TCP_CORK, &optval, sizeof(optval));
    }
#endif

@@ -83,17 +83,17 @@ NonblockingStringStatus NonblockingPacketWriter::Write() {
                TransitionFromValue();
                break;
            case kFinished:
                if (fstat(fd_, &statbuf)) {
                if (fstat(socket_wrapper_->fd(), &statbuf)) {
                    PLOG(ERROR) << "Unable to fstat socket";
                    return kFailed;
                }
                if (S_ISSOCK(statbuf.st_mode)) {
                    int optval = 0;
#ifndef __APPLE__
                    setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &optval, sizeof(optval));
                    setsockopt(socket_wrapper_->fd(), IPPROTO_TCP, TCP_CORK, &optval, sizeof(optval));
#endif
                    optval = 1;
                    setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, &optval, sizeof(optval));
                    setsockopt(socket_wrapper_->fd(), IPPROTO_TCP, TCP_NODELAY, &optval, sizeof(optval));
                }
                return kDone;
            default:
@@ -111,7 +111,7 @@ bool NonblockingPacketWriter::TransitionFromMagic() {
    uint32_t size = htonl(serialized_message_.size());
    delete writer_;
    std::string encoded_size(reinterpret_cast<char *>(&size), sizeof(size));
    writer_ = new NonblockingStringWriter(fd_, make_shared<string>(encoded_size));
    writer_ = new NonblockingStringWriter(socket_wrapper_, make_shared<string>(encoded_size));
    state_ = kMessageLength;
    return true;
}
@@ -121,21 +121,21 @@ void NonblockingPacketWriter::TransitionFromMessageLength() {
    uint32_t size = htonl(value_->size());
    delete writer_;
    std::string encoded_size(reinterpret_cast<char *>(&size), sizeof(size));
    writer_ = new NonblockingStringWriter(fd_, make_shared<string>(encoded_size));
    writer_ = new NonblockingStringWriter(socket_wrapper_, make_shared<string>(encoded_size));
    state_ = kValueLength;
}

void NonblockingPacketWriter::TransitionFromValueLength() {
    // Move on to writing the serialized message
    delete writer_;
    writer_ = new NonblockingStringWriter(fd_, make_shared<string>(serialized_message_));
    writer_ = new NonblockingStringWriter(socket_wrapper_, make_shared<string>(serialized_message_));
    state_ = kMessage;
}

void NonblockingPacketWriter::TransitionFromMessage() {
    // Move on to writing the value
    delete writer_;
    writer_ = new NonblockingStringWriter(fd_, value_);
    writer_ = new NonblockingStringWriter(socket_wrapper_, value_);
    state_ = kValue;
}

@@ -144,10 +144,10 @@ void NonblockingPacketWriter::TransitionFromValue() {
    state_ = kFinished;
}

NonblockingPacketReader::NonblockingPacketReader(int fd, Message* response,
NonblockingPacketReader::NonblockingPacketReader(shared_ptr<SocketWrapperInterface> socket_wrapper, Message* response,
        unique_ptr<const string> &value)
    : fd_(fd), response_(response), state_(kMagic), value_(value), magic_(),
    reader_(new NonblockingStringReader(fd, 1, magic_)) {
    : socket_wrapper_(socket_wrapper), response_(response), state_(kMagic), value_(value), magic_(),
    reader_(new NonblockingStringReader(socket_wrapper_, 1, magic_)) {
}

NonblockingPacketReader::~NonblockingPacketReader() {
@@ -199,7 +199,7 @@ bool NonblockingPacketReader::TransitionFromMagic() {
        return false;
    }
    delete reader_;
    reader_ = new NonblockingStringReader(fd_, 4, message_length_);
    reader_ = new NonblockingStringReader(socket_wrapper_, 4, message_length_);
    state_ = kMessageLength;
    return true;
}
@@ -207,7 +207,7 @@ bool NonblockingPacketReader::TransitionFromMagic() {
void NonblockingPacketReader::TransitionFromMessageLength() {
    // Move on to reading the value length
    delete reader_;
    reader_ = new NonblockingStringReader(fd_, 4, value_length_);
    reader_ = new NonblockingStringReader(socket_wrapper_, 4, value_length_);
    state_ = kValueLength;
}

@@ -216,7 +216,7 @@ void NonblockingPacketReader::TransitionFromValueLength() {
    delete reader_;
    CHECK_EQ(4u, message_length_->size());
    uint32_t length = ntohl(*reinterpret_cast<const uint32_t *>(message_length_->data()));
    reader_ = new NonblockingStringReader(fd_, length, message_);
    reader_ = new NonblockingStringReader(socket_wrapper_, length, message_);
    state_ = kMessage;
}

@@ -225,7 +225,7 @@ void NonblockingPacketReader::TransitionFromMessage() {
    delete reader_;
    CHECK_EQ(4u, value_length_->size());
    uint32_t length = ntohl(*reinterpret_cast<const uint32_t *>(value_length_->data()));
    reader_ = new NonblockingStringReader(fd_, length, value_);
    reader_ = new NonblockingStringReader(socket_wrapper_, length, value_);
    state_ = kValue;
}

+6 −6
Original line number Diff line number Diff line
@@ -53,7 +53,7 @@ class NonblockingPacketWriterInterface {

class NonblockingPacketWriter : public NonblockingPacketWriterInterface {
    public:
    NonblockingPacketWriter(int fd, unique_ptr<const Message> message,
    NonblockingPacketWriter(shared_ptr<SocketWrapperInterface> socket_wrapper, unique_ptr<const Message> message,
            const shared_ptr<const string> value);
    ~NonblockingPacketWriter();
    NonblockingStringStatus Write();
@@ -64,7 +64,7 @@ class NonblockingPacketWriter : public NonblockingPacketWriterInterface {
    void TransitionFromValueLength();
    void TransitionFromMessage();
    void TransitionFromValue();
    const int fd_;
    shared_ptr<SocketWrapperInterface> socket_wrapper_;
    unique_ptr<const Message> message_;
    const shared_ptr<const string> value_;
    State state_;
@@ -75,7 +75,7 @@ class NonblockingPacketWriter : public NonblockingPacketWriterInterface {

class NonblockingPacketReader {
    public:
    NonblockingPacketReader(int fd, Message* response, unique_ptr<const string>& value);
    NonblockingPacketReader(shared_ptr<SocketWrapperInterface> socket_wrapper, Message* response, unique_ptr<const string>& value);
    ~NonblockingPacketReader();
    NonblockingStringStatus Read();

@@ -85,7 +85,7 @@ class NonblockingPacketReader {
    void TransitionFromValueLength();
    void TransitionFromMessage();
    bool TransitionFromValue();
    const int fd_;
    shared_ptr<SocketWrapperInterface> socket_wrapper_;
    Message* const response_;
    State state_;
    unique_ptr<const string>& value_;
@@ -100,13 +100,13 @@ class NonblockingPacketReader {
class NonblockingPacketWriterFactoryInterface {
    public:
    virtual ~NonblockingPacketWriterFactoryInterface() {}
    virtual unique_ptr<NonblockingPacketWriterInterface> CreateWriter(int fd,
    virtual unique_ptr<NonblockingPacketWriterInterface> CreateWriter(shared_ptr<SocketWrapperInterface> socket_wrapper,
        unique_ptr<const Message> message, const shared_ptr<const string> value) = 0;
};

class NonblockingPacketWriterFactory : public NonblockingPacketWriterFactoryInterface {
    public:
    unique_ptr<NonblockingPacketWriterInterface> CreateWriter(int fd,
    unique_ptr<NonblockingPacketWriterInterface> CreateWriter(shared_ptr<SocketWrapperInterface> socket_wrapper,
        unique_ptr<const Message> message, const shared_ptr<const string> value);
};

+1 −1
Original line number Diff line number Diff line
@@ -155,7 +155,7 @@ NonblockingPacketServiceStatus NonblockingReceiver::Receive() {

            // Start working on the next thing in the request queue
            nonblocking_response_ = new NonblockingPacketReader(
                socket_wrapper_->fd(), &message_, value_);
                socket_wrapper_, &message_, value_);
        }

        NonblockingStringStatus status = nonblocking_response_->Read();
Loading