Commit 045db1fe authored by Scott Vokes's avatar Scott Vokes
Browse files

WIP: Eliminate sender thread.

Also, rework dataflow between the client and listener threads.

Now, the adding/removing socket and shutdown commands to the listener
are all blocking for the client thread, which eliminates a few race
conditions for freeing SSL resources. Previously, rapidly adding/removing
the same socket could potentially cause problems.
parent af6fb281
Loading
Loading
Loading
Loading
+119 −169
Original line number Diff line number Diff line
@@ -39,10 +39,8 @@

/* Function pointers for pthreads. */
void *listener_mainloop(void *arg);
void *sender_mainloop(void *arg);

static bool poll_on_completion(struct bus *b, int fd);
static int sender_id_of_socket(struct bus *b, int fd);
static int listener_id_of_socket(struct bus *b, int fd);
static void noop_log_cb(log_event_t event,
        int log_level, const char *msg, void *udata);
@@ -50,7 +48,6 @@ static void noop_error_cb(bus_unpack_cb_res_t result, void *socket_udata);
static bool attempt_to_increase_resource_limits(struct bus *b);

static void set_defaults(bus_config *cfg) {
    if (cfg->sender_count == 0) { cfg->sender_count = 1; }
    if (cfg->listener_count == 0) { cfg->listener_count = 1; }
}

@@ -82,7 +79,6 @@ bool bus_init(bus_config *config, struct bus_result *res) {
    res->status = BUS_INIT_ERROR_ALLOC_FAIL;

    bool log_lock_init = false;
    struct sender **ss = NULL;       /* senders */
    struct listener **ls = NULL;     /* listeners */
    struct threadpool *tp = NULL;
    bool *joined = NULL;
@@ -105,7 +101,7 @@ bool bus_init(bus_config *config, struct bus_result *res) {
        res->status = BUS_INIT_ERROR_MUTEX_INIT_FAIL;
        goto cleanup;
    }
    if (0 != pthread_mutex_init(&b->fd_set_lock, NULL)) {
    if (0 != pthread_rwlock_init(&b->fd_set_lock, NULL)) {
        res->status = BUS_INIT_ERROR_MUTEX_INIT_FAIL;
        goto cleanup;
    }
@@ -117,22 +113,6 @@ bool bus_init(bus_config *config, struct bus_result *res) {
    BUS_LOG_SNPRINTF(b, 3, LOG_INITIALIZATION, b->udata, 64,
        "Initialized bus at %p", (void*)b);

    ss = calloc(config->sender_count, sizeof(*ss));
    if (ss == NULL) {
        goto cleanup;
    }

    for (int i = 0; i < config->sender_count; i++) {
        ss[i] = sender_init(b, config);
        if (ss[i] == NULL) {
            res->status = BUS_INIT_ERROR_SENDER_INIT_FAIL;
            goto cleanup;
        } else {
            BUS_LOG_SNPRINTF(b, 3, LOG_INITIALIZATION, b->udata, 64,
                "Initialized sender %d at %p", i, (void*)ss[i]);
        }
    }

    ls = calloc(config->listener_count, sizeof(*ls));
    if (ls == NULL) {
        goto cleanup;
@@ -155,7 +135,7 @@ bool bus_init(bus_config *config, struct bus_result *res) {
        goto cleanup;
    }

    int thread_count = config->sender_count + config->listener_count;
    int thread_count = config->listener_count;
    joined = calloc(thread_count, sizeof(bool));
    threads = calloc(thread_count, sizeof(pthread_t));
    if (joined == NULL || threads == NULL) {
@@ -167,25 +147,14 @@ bool bus_init(bus_config *config, struct bus_result *res) {
        goto cleanup;
    }

    b->sender_count = config->sender_count;
    b->senders = ss;
    b->listener_count = config->listener_count;
    b->listeners = ls;
    b->threadpool = tp;
    b->joined = joined;
    b->threads = threads;

    for (int i = 0; i < b->sender_count; i++) {
        int pcres = pthread_create(&b->threads[i], NULL,
            sender_mainloop, (void *)b->senders[i]);
        if (pcres != 0) {
            res->status = BUS_INIT_ERROR_PTHREAD_INIT_FAIL;
            goto cleanup;
        }
    }

    for (int i = 0; i < b->listener_count; i++) {
        int pcres = pthread_create(&b->threads[i + b->sender_count], NULL,
        int pcres = pthread_create(&b->threads[i], NULL,
            listener_mainloop, (void *)b->listeners[i]);
        if (pcres != 0) {
            res->status = BUS_INIT_ERROR_PTHREAD_INIT_FAIL;
@@ -199,12 +168,6 @@ bool bus_init(bus_config *config, struct bus_result *res) {
    return true;

cleanup:
    if (ss) {
        for (int i = 0; i < config->sender_count; i++) {
            if (ss[i]) { sender_free(ss[i]); }
        }
        free(ss);
    }
    if (ls) {
        for (int i = 0; i < config->listener_count; i++) {
            if (ls[i]) { listener_free(ls[i]); }
@@ -215,7 +178,7 @@ cleanup:
    if (joined) { free(joined); }
    if (b) {
        if (log_lock_init) {
            pthread_mutex_destroy(&b->fd_set_lock);
            pthread_rwlock_destroy(&b->fd_set_lock);
            pthread_mutex_destroy(&b->log_lock);
        }
        free(b);
@@ -276,20 +239,22 @@ static boxed_msg *box_msg(struct bus *b, bus_user_msg *msg) {
    assert(msg->fd != 0);

    /* Lock hash table and check whether this FD uses SSL. */
    if (0 != pthread_mutex_lock(&b->fd_set_lock)) { assert(false); }
    if (0 != pthread_rwlock_rdlock(&b->fd_set_lock)) { assert(false); }
    void *value = NULL;
    SSL *ssl = NULL;
    connection_info *ci = NULL;
    if (yacht_get(b->fd_set, box->fd, &value)) {
        ssl = (SSL *)value;
        assert(ssl != NULL);
        box->ssl = ssl;
        ci = (connection_info *)value;
    }
    if (0 != pthread_mutex_unlock(&b->fd_set_lock)) { assert(false); }
    if (0 != pthread_rwlock_unlock(&b->fd_set_lock)) { assert(false); }

    if (ssl == NULL) {
    if (ci == NULL) {
        /* socket isn't registered, fail out */
        BUS_LOG_SNPRINTF(b, 3, LOG_MEMORY, b->udata, 64,
            "socket isn't registered, failing -- %p", (void*)box);
        free(box);
        return NULL;
    } else {
        box->ssl = ci->ssl;
    }

    box->timeout_sec = (time_t)msg->timeout_sec;
@@ -320,12 +285,9 @@ bool bus_send_request(struct bus *b, bus_user_msg *msg)
        return false;
    }

    int s_id = sender_id_of_socket(b, msg->fd);
    struct sender *s = b->senders[s_id];

    BUS_LOG_SNPRINTF(b, 3-0, LOG_SENDING_REQUEST, b->udata, 64,
        "Sending request <fd:%d, seq_id:%lld>", msg->fd, (long long)msg->seq_id);
    bool res = sender_send_request(s, box);
    bool res = sender_do_blocking_send(b, box);
    BUS_LOG_SNPRINTF(b, 3, LOG_SENDING_REQUEST, b->udata, 64,
        "...request sent, result %d", res);
    return res;
@@ -337,30 +299,23 @@ static bool poll_on_completion(struct bus *b, int fd) {
    fds[0].fd = fd;
    fds[0].events = POLLIN;

    /* TODO REFACTOR this should be reused between bus, sender, and listener,
     *     or be moved into the listener.
     *     The sender has its own blocking polling for commands. */

    /* FIXME: compare this to TCP timeouts -- try to prevent sender
     * succeeding but failing to notify client. */
    const int TIMEOUT_SECONDS = 10;
    const int ONE_SECOND = 1000;  // msec

    for (int i = 0; i < TIMEOUT_SECONDS; i++) {
    for (;;) {
        BUS_LOG(b, 5, LOG_SENDING_REQUEST, "Polling on completion...tick...", b->udata);
        int res = poll(fds, 1, ONE_SECOND);
        int res = poll(fds, 1, -1);
        if (res == -1) {
            if (util_is_resumable_io_error(errno)) {
                BUS_LOG(b, 3, LOG_SENDING_REQUEST, "Polling on completion...EAGAIN", b->udata);
                if (errno == EINTR && i > 0) { i--; }
                BUS_LOG_SNPRINTF(b, 3, LOG_SENDING_REQUEST, b->udata, 64,
                    "poll_on_completion, resumable IO error %d", errno);
                errno = 0;
                continue;
            } else {
                assert(false);
                break;
                BUS_LOG_SNPRINTF(b, 3, LOG_SENDING_REQUEST, b->udata, 64,
                    "poll_on_completion, non-resumable IO error %d", errno);
                return false;
            }
        } else if (res > 0) {
        } else if (res == 1) {
            uint16_t msec = 0;
            uint8_t read_buf[sizeof(msec)];
            uint8_t read_buf[sizeof(uint8_t) + sizeof(msec)];

            if (fds[0].revents & (POLLERR | POLLHUP | POLLNVAL)) {
                BUS_LOG(b, 1, LOG_SENDING_REQUEST, "failed (broken alert pipe)", b->udata);
@@ -372,38 +327,32 @@ static bool poll_on_completion(struct bus *b, int fd) {

            if (sz == sizeof(read_buf)) {
                /* Payload: little-endian uint16_t, msec of backpressure. */
                msec = (read_buf[0] << 0) + (read_buf[1] << 8);
                if (msec > 0) {
                    BUS_LOG_SNPRINTF(b, 5, LOG_SENDING_REQUEST, b->udata, 64,
                        " -- awakening client thread with backpressure of %d msec", msec);
                    (void)poll(NULL, 0, msec);
                }
                assert(read_buf[0] == LISTENER_MSG_TAG);

                BUS_LOG(b, 3, LOG_SENDING_REQUEST, "sent!", b->udata);
                msec = (read_buf[1] << 0) + (read_buf[2] << 8);
                bus_backpressure_delay(b, msec, 3);
                BUS_LOG(b, 4, LOG_SENDING_REQUEST, "sent!", b->udata);
                return true;
            } else if (sz == -1) {
                if (util_is_resumable_io_error(errno)) {
                    errno = 0;
                } else {
                    assert(false);
                    break;
                }
            }
                    BUS_LOG_SNPRINTF(b, 3, LOG_SENDING_REQUEST, b->udata, 64,
                        "poll_on_completion read, non-resumable IO error %d", errno);
                    errno = 0;
                    return false;
                }
            } else {
                BUS_LOG_SNPRINTF(b, 3, LOG_SENDING_REQUEST, b->udata, 64,
                    "poll_on_completion bad read size %zd", sz);
                return false;
            }
    BUS_LOG(b, 2, LOG_SENDING_REQUEST, "failed (timeout)", b->udata);

    #if 0
        } else {
            BUS_LOG_SNPRINTF(b, 3, LOG_SENDING_REQUEST, b->udata, 64,
                "poll_on_completion, blocking forever returned 0, errno %d", errno);
            assert(false);
    #endif

    return false;
        }

static int sender_id_of_socket(struct bus *b, int fd) {
    /* Just evenly divide sockets between senders by file descriptor. */
    /* could also use sequence ID */
    return fd % b->sender_count;
    }
}

static int listener_id_of_socket(struct bus *b, int fd) {
@@ -411,10 +360,6 @@ static int listener_id_of_socket(struct bus *b, int fd) {
    return fd % b->listener_count;
}

struct sender *bus_get_sender_for_socket(struct bus *b, int fd) {
    return b->senders[sender_id_of_socket(b, fd)];
}

struct listener *bus_get_listener_for_socket(struct bus *b, int fd) {
    return b->listeners[listener_id_of_socket(b, fd)];
}
@@ -437,14 +382,12 @@ const char *bus_log_event_str(log_event_t event) {

bool bus_register_socket(struct bus *b, bus_socket_t type, int fd, void *udata) {
    /* Register a socket internally with a sender and listener. */
    int s_id = sender_id_of_socket(b, fd);
    int l_id = listener_id_of_socket(b, fd);

    BUS_LOG_SNPRINTF(b, 2, LOG_SOCKET_REGISTERED, b->udata, 64,
        "registering socket %d", fd);

    /* Spread sockets throughout the different sender & listener processes. */
    struct sender *s = b->senders[s_id];
    struct listener *l = b->listeners[l_id];
    
    int pipes[2];
@@ -453,30 +396,31 @@ bool bus_register_socket(struct bus *b, bus_socket_t type, int fd, void *udata)
        return false;
    }

    int pipe_out = pipes[0];
    int pipe_in = pipes[1];

    /* Metadata about the connection. Note: This will be shared by the
     * client thread and the listener thread, but each will only modify
     * some of the fields. The client thread will free this. */
    connection_info *ci = malloc(sizeof(*ci));
    if (ci == NULL) { goto cleanup; }

    ci->type = type;
    ci->fd = fd;
    ci->to_read_size = 0;
    ci->udata = udata;
    ci->largest_seq_id_seen = 0;

    SSL *ssl = NULL;
    if (type == BUS_SOCKET_SSL) {
        if (!bus_ssl_connect(b, ci)) { goto cleanup; }
        ssl = bus_ssl_connect(b, fd);
        if (ssl == NULL) { goto cleanup; }
    } else {
        ci->ssl = BUS_NO_SSL;
        ssl = BUS_NO_SSL;
    }
    *ci = (connection_info){
        .fd = fd,
        .type = type,
        .ssl = ssl,
        .udata = udata,
    };

    void *old_value = NULL;

    /* Lock hash table and save whether this FD uses SSL. */
    if (0 != pthread_mutex_lock(&b->fd_set_lock)) { assert(false); }
    bool set_ok = yacht_set(b->fd_set, fd, (void *)ci->ssl, &old_value);
    if (0 != pthread_mutex_unlock(&b->fd_set_lock)) { assert(false); }
    if (0 != pthread_rwlock_wrlock(&b->fd_set_lock)) { assert(false); }
    bool set_ok = yacht_set(b->fd_set, fd, ci, &old_value);
    if (0 != pthread_rwlock_unlock(&b->fd_set_lock)) { assert(false); }

    if (set_ok) {
        assert(old_value == NULL);
@@ -485,66 +429,61 @@ bool bus_register_socket(struct bus *b, bus_socket_t type, int fd, void *udata)
    }

    bool res = false;
    res = sender_register_socket(s, fd, ci->ssl);
    int completion_pipe = -1;
    res = listener_add_socket(l, ci, &completion_pipe);
    if (!res) { goto cleanup; }

    res = listener_add_socket(l, ci, pipe_in);
    if (!res) { goto cleanup; }

    /* FIXME: Move this into listener_add_socket? */
    BUS_LOG(b, 2, LOG_SOCKET_REGISTERED, "polling on socket add...", b->udata);
    bool completed = poll_on_completion(b, pipe_out);
    bool completed = poll_on_completion(b, completion_pipe);
    if (!completed) { goto cleanup; }

    close(pipe_out);
    close(pipe_in);

    BUS_LOG(b, 2, LOG_SOCKET_REGISTERED, "successfully added socket", b->udata);
    return true;
cleanup:
    if (ci) {
        free(ci);
    }
    close(pipe_out);
    close(pipe_in);
    BUS_LOG(b, 2, LOG_SOCKET_REGISTERED, "failed to add socket", b->udata);
    return false;
}

/* Free metadata about a socket that has been disconnected. */
bool bus_release_socket(struct bus *b, int fd) {
    int s_id = sender_id_of_socket(b, fd);
    int l_id = listener_id_of_socket(b, fd);

    BUS_LOG_SNPRINTF(b, 2, LOG_SOCKET_REGISTERED, b->udata, 64,
        "forgetting socket %d", fd);

    struct sender *s = b->senders[s_id];
    struct listener *l = b->listeners[l_id];

    if (!sender_remove_socket(s, fd)) {
        return false;
    int completion_fd = -1;
    if (!listener_remove_socket(l, fd, &completion_fd)) {
        return false;           /* couldn't send msg to listener */
    }

    if (!listener_remove_socket(l, fd)) {
        return false;           /* couldn't send msg to listener */
    bool completed = poll_on_completion(b, completion_fd);
    if (!completed) {
        return false;
    }

    /* Lock hash table and forget whether this FD uses SSL. */
    void *old_value = NULL;
    if (0 != pthread_mutex_lock(&b->fd_set_lock)) { assert(false); }
    if (0 != pthread_rwlock_wrlock(&b->fd_set_lock)) { assert(false); }
    bool rm_ok = yacht_remove(b->fd_set, fd, &old_value);
    if (0 != pthread_mutex_unlock(&b->fd_set_lock)) { assert(false); }
    if (0 != pthread_rwlock_unlock(&b->fd_set_lock)) { assert(false); }
    assert(rm_ok);

    SSL *ssl = (SSL *)old_value;
    assert(ssl != NULL);
    connection_info *ci = (connection_info *)old_value;
    assert(ci != NULL);

    if (ssl == BUS_NO_SSL) {
    if (ci->ssl == BUS_NO_SSL) {
        return true;            /* nothing else to do */
    } else {
        return bus_ssl_disconnect(b, ssl);
        return bus_ssl_disconnect(b, ci->ssl);
    }

    /* TODO: return ci->udata? */
    free(ci);
}

bool bus_schedule_threadpool_task(struct bus *b, struct threadpool_task *task,
@@ -552,44 +491,52 @@ bool bus_schedule_threadpool_task(struct bus *b, struct threadpool_task *task,
    return threadpool_schedule(b->threadpool, task, backpressure);
}

bool bus_shutdown(bus *b) {
    BUS_LOG(b, 2, LOG_SHUTDOWN, "shutting down sender threads", b->udata);
    for (int i = 0; i < b->sender_count; i++) {
        int off = 0;
        if (!b->joined[i + off]) {
            BUS_LOG_SNPRINTF(b, 3, LOG_SHUTDOWN, b->udata, 128,
                "sender_shutdown -- %d", i);
            while (!sender_shutdown(b->senders[i])) {
                BUS_LOG_SNPRINTF(b, 3, LOG_SHUTDOWN, b->udata, 128,
                    "sender_shutdown -- retry %d", i);
                sleep(1);
static void free_connection_cb(void *value, void *udata) {
    struct bus *b = (struct bus *)udata;
    connection_info *ci = (connection_info *)value;

    int l_id = listener_id_of_socket(b, ci->fd);
    struct listener *l = b->listeners[l_id];

    int completion_fd = -1;
    if (!listener_remove_socket(l, ci->fd, &completion_fd)) {
        return;           /* couldn't send msg to listener */
    }
            void *unused = NULL;
            int res = pthread_join(b->threads[i + off], &unused);
            BUS_LOG_SNPRINTF(b, 3, LOG_SHUTDOWN, b->udata, 128,
                "sender_shutdown -- joined %d", i);
            assert(res == 0);
            b->joined[i + off] = true;

    bool completed = poll_on_completion(b, completion_fd);
    if (!completed) {
        return;
    }
}

bool bus_shutdown(bus *b) {
    /* TODO: thread safety for shutdown being called concurrently on several client threads?
     * Maybe use CAS-ing the fd_set to NULL as a flag.*/
    struct yacht *fd_set = b->fd_set;
    b->fd_set = NULL;

    if (fd_set) {
        BUS_LOG(b, 2, LOG_SHUTDOWN, "removing all connections", b->udata);
        yacht_free(fd_set, free_connection_cb, b);
    }

    BUS_LOG(b, 2, LOG_SHUTDOWN, "shutting down listener threads", b->udata);
    for (int i = 0; i < b->listener_count; i++) {
        int off = b->sender_count;
        if (!b->joined[i + off]) {
        if (!b->joined[i]) {
            BUS_LOG_SNPRINTF(b, 3, LOG_SHUTDOWN, b->udata, 128,
                "listener_shutdown -- %d", i);
            while (!listener_shutdown(b->listeners[i])) {
                sleep(1);
            }
            int completion_fd = -1;
            listener_shutdown(b->listeners[i], &completion_fd);
            poll_on_completion(b, completion_fd);

            BUS_LOG_SNPRINTF(b, 3, LOG_SHUTDOWN, b->udata, 128,
                "listener_shutdown -- joining %d", i);
            void *unused = NULL;
            int res = pthread_join(b->threads[i + off], &unused);
            int res = pthread_join(b->threads[i], &unused);
            BUS_LOG_SNPRINTF(b, 3, LOG_SHUTDOWN, b->udata, 128,
                "listener_shutdown -- joined %d", i);
            assert(res == 0);
            b->joined[i + off] = true;
            b->joined[i] = true;
        }
    }

@@ -597,6 +544,17 @@ bool bus_shutdown(bus *b) {
    return true;
}

void bus_backpressure_delay(struct bus *b, size_t backpressure, uint8_t shift) {
    /* Push back if message bus is too busy. */
    backpressure >>= shift;
    
    if (backpressure > 0) {
        BUS_LOG_SNPRINTF(b, 8, LOG_SENDER, b->udata, 64,
            "backpressure %zd", backpressure);
        poll(NULL, 0, backpressure);
    }
}

void bus_lock_log(struct bus *b) {
    pthread_mutex_lock(&b->log_lock);
}
@@ -646,13 +604,6 @@ void bus_free(bus *b) {
    if (b == NULL) { return; }
    bus_shutdown(b);

    for (int i = 0; i < b->sender_count; i++) {
        BUS_LOG_SNPRINTF(b, 3, LOG_SHUTDOWN, b->udata, 128,
            "sender_free -- %d", i);
        sender_free(b->senders[i]);
    }
    free(b->senders);

    for (int i = 0; i < b->listener_count; i++) {
        BUS_LOG_SNPRINTF(b, 3, LOG_SHUTDOWN, b->udata, 128,
            "listener_free -- %d", i);
@@ -678,7 +629,6 @@ void bus_free(bus *b) {

    free(b->joined);
    free(b->threads);
    yacht_free(b->fd_set, NULL, NULL);

    pthread_mutex_destroy(&b->log_lock);

+2 −0
Original line number Diff line number Diff line
@@ -77,4 +77,6 @@ void bus_free(struct bus *b);
bool bus_process_boxed_message(struct bus *b,
    struct boxed_msg *box, size_t *backpressure);

void bus_backpressure_delay(struct bus *b, size_t backpressure, uint8_t shift);

#endif
+18 −11
Original line number Diff line number Diff line
@@ -56,6 +56,7 @@ typedef struct boxed_msg {
    int64_t out_seq_id;
    uint8_t *out_msg;
    size_t out_msg_size;
    size_t out_sent_size;
} boxed_msg;

#define BUS_NO_SSL ((SSL *)-2)
@@ -72,9 +73,6 @@ typedef struct bus {
    bus_log_cb *log_cb;
    pthread_mutex_t log_lock;

    uint8_t sender_count;
    struct sender **senders;

    uint8_t listener_count;
    struct listener **listeners;

@@ -84,9 +82,9 @@ typedef struct bus {
    struct threadpool *threadpool;
    SSL_CTX *ssl_ctx;

    /* Locked hash table for fd -> (SSL * | BUS_NO_SSL) */
    /* Locked hash table for fd -> connection_info */
    struct yacht *fd_set;
    pthread_mutex_t fd_set_lock;
    pthread_rwlock_t fd_set_lock;
} bus;

/* Special timeout value indicating UNBOUND. */
@@ -104,15 +102,24 @@ typedef enum {

/* Per-socket connection context. (Owned by the listener.) */
typedef struct {
    int fd;
    rx_error_t error;
    size_t to_read_size;
    int64_t largest_seq_id_seen;
    /* Shared */
    const int fd;
    const bus_socket_t type;
    void *udata;                /* user connection data */

    /* Shared, cleaned up by sender */
    SSL *ssl;                   /* SSL handle. Must be valid or BUS_NO_SSL. */

    bus_socket_t type;
    void *udata;                /* user connection data */
    /* Set by client thread */
    int64_t largest_wr_seq_id_seen;

    /* Set by listener thread */
    rx_error_t error;
    size_t to_read_size;
    int64_t largest_rd_seq_id_seen;
} connection_info;

/* Arbitrary byte used to tag writes from the listener. */
#define LISTENER_MSG_TAG 0x15

#endif
+22 −24
Original line number Diff line number Diff line
@@ -29,7 +29,7 @@
static bool init_client_SSL_CTX(SSL_CTX **ctx_out);
static void disable_SSL_compression(void);
static void disable_known_bad_ciphers(SSL_CTX *ctx);
static bool do_blocking_connection(struct bus *b, connection_info *ci);
static bool do_blocking_connection(struct bus *b, SSL *ssl, int fd);

/* Initialize the SSL library internals for use by the messaging bus. */
bool bus_ssl_init(struct bus *b) {
@@ -47,26 +47,24 @@ bool bus_ssl_init(struct bus *b) {

/* Do an SSL / TLS shake for a connection. Blocking.
 * Returns whether the connection succeeded. */
bool bus_ssl_connect(struct bus *b, connection_info *ci) {
SSL *bus_ssl_connect(struct bus *b, int fd) {
    SSL *ssl = NULL;

    ssl = SSL_new(b->ssl_ctx);
    if (ssl == NULL) {
        ERR_print_errors_fp(stderr);
        return false;
        return NULL;
    }
    ci->ssl = ssl;

    if (!SSL_set_fd(ci->ssl, ci->fd)) {
        return false;
    if (!SSL_set_fd(ssl, fd)) {
        return NULL;
    }

    if (do_blocking_connection(b, ci)) {
        return true;
    if (do_blocking_connection(b, ssl, fd)) {
        return ssl;
    } else {
        SSL_free(ci->ssl);
        ci->ssl = NULL;
        return false;
        SSL_free(ssl);
        return NULL;
    }
}

@@ -128,12 +126,12 @@ static void disable_known_bad_ciphers(SSL_CTX *ctx) {
    assert(res == 1);
}

static bool do_blocking_connection(struct bus *b, connection_info *ci) {
static bool do_blocking_connection(struct bus *b, SSL *ssl, int fd) {
    BUS_LOG_SNPRINTF(b, 2, LOG_SOCKET_REGISTERED, b->udata, 128,
        "SSL_Connect handshake for socket %d", ci->fd);
        "SSL_Connect handshake for socket %d", fd);

    struct pollfd fds[1];
    fds[0].fd = ci->fd;
    fds[0].fd = fd;
    fds[0].events = POLLOUT;
    
    bool connected = false;
@@ -142,7 +140,7 @@ static bool do_blocking_connection(struct bus *b, connection_info *ci) {
    while (!connected) {
        int pres = poll(fds, 1, TIMEOUT_MSEC);
        BUS_LOG_SNPRINTF(b, 5, LOG_SOCKET_REGISTERED, b->udata, 128,
            "SSL_Connect handshake for socket %d, poll res %d", ci->fd, pres);
            "SSL_Connect handshake for socket %d, poll res %d", fd, pres);

        if (pres < 0) {
            if (util_is_resumable_io_error(errno)) {
@@ -153,16 +151,16 @@ static bool do_blocking_connection(struct bus *b, connection_info *ci) {
            }
        } else if (pres > 0) {
            if (fds[0].revents & (POLLOUT | POLLIN)) {
                int connect_res = SSL_connect(ci->ssl);
                int connect_res = SSL_connect(ssl);
                BUS_LOG_SNPRINTF(b, 5, LOG_SOCKET_REGISTERED, b->udata, 128,
                    "socket %d: connect_res %d", ci->fd, connect_res);
                    "socket %d: connect_res %d", fd, connect_res);

                if (connect_res == 1) {
                    BUS_LOG_SNPRINTF(b, 5, LOG_SOCKET_REGISTERED, b->udata, 128,
                        "socket %d: successfully connected", ci->fd);
                        "socket %d: successfully connected", fd);
                    connected = true;
                } else if (connect_res < 0) {
                    int reason = SSL_get_error(ci->ssl, connect_res);
                    int reason = SSL_get_error(ssl, connect_res);

                    switch (reason) {
                    case SSL_ERROR_WANT_WRITE:
@@ -184,7 +182,7 @@ static bool do_blocking_connection(struct bus *b, connection_info *ci) {
                            unsigned long errval = ERR_get_error();
                            char ebuf[256];
                            BUS_LOG_SNPRINTF(b, 5, LOG_SOCKET_REGISTERED, b->udata, 128,
                                "socket %d: ERROR -- %s", ci->fd, ERR_error_string(errval, ebuf));
                                "socket %d: ERROR -- %s", fd, ERR_error_string(errval, ebuf));
                        }
                    }
                    break;
@@ -193,7 +191,7 @@ static bool do_blocking_connection(struct bus *b, connection_info *ci) {
                        unsigned long errval = ERR_get_error();
                        char ebuf[256];
                        BUS_LOG_SNPRINTF(b, 5, LOG_SOCKET_REGISTERED, b->udata, 128,
                            "socket %d: ERROR %d -- %s", ci->fd, reason, ERR_error_string(errval, ebuf));
                            "socket %d: ERROR %d -- %s", fd, reason, ERR_error_string(errval, ebuf));
                        assert(false);
                    }
                    }
@@ -201,16 +199,16 @@ static bool do_blocking_connection(struct bus *b, connection_info *ci) {
                } else {
                    BUS_LOG_SNPRINTF(b, 5, LOG_SOCKET_REGISTERED, b->udata, 128,
                        "socket %d: unknown state, setting event bask to (POLLIN | POLLOUT)",
                        ci->fd);
                        fd);
                    fds[0].events = (POLLIN | POLLOUT);
                }
            } else if (fds[0].revents & POLLHUP) {
                BUS_LOG_SNPRINTF(b, 2, LOG_SOCKET_REGISTERED, b->udata, 128,
                    "SSL_Connect: HUP on %d", ci->fd);
                    "SSL_Connect: HUP on %d", fd);
                return false;
            } else if (fds[0].revents & POLLERR) {
                BUS_LOG_SNPRINTF(b, 2, LOG_SOCKET_REGISTERED, b->udata, 128,
                    "SSL_Connect: ERR on %d", ci->fd);
                    "SSL_Connect: ERR on %d", fd);
                return false;
            }
        } else {
+2 −3
Original line number Diff line number Diff line
@@ -31,9 +31,8 @@
/* Initialize the SSL library internals for use by the messaging bus. */
bool bus_ssl_init(struct bus *b);

/* Do an SSL / TLS shake for a connection. Blocking.
 * Returns whether the connection succeeded. */
bool bus_ssl_connect(struct bus *b, connection_info *ci);
/* Do an SSL / TLS shake for a connection. Blocking. */
SSL *bus_ssl_connect(struct bus *b, int fd);

/* Disconnect and free an individual SSL handle. */
bool bus_ssl_disconnect(struct bus *b, SSL *ssl);
Loading