Commit 0d0898b4 authored by Scott Vokes's avatar Scott Vokes
Browse files

Refactor listener to receive status messages connectly on SSL connections.

The poll(2) / read(2) code needed to be reworked to reflect that SSL is
doing its own buffering -- poll may indicate that a socket has no more
to read, but there may still be data in the SSL socket's BIO. If poll
indicates data is available, keep reeding it and sinking it until
WANT_READ is returned. With this change, test_system_bus is able to
get status messages through SSL.

Also, change the hash table from tracking a (file descriptor) set to
tracking a (file descriptor -> fd_info) map, in preparation for
tracking which connections are plain or SSL in the sender. The
sender is not fully updated to SSL yet.
parent aa30bfb7
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ PUB_INC = ./include
#-------------------------------------------------------------------------------

# FIXME: Currently OSX/homebrew specific, rework before integration.
#OPENSSL_PATH ?=	/usr/local/Cellar/openssl/1.0.1j_1
OPENSSL_PATH ?=	.

#===============================================================================
# Shared Build Variables
+2 −2
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ LIB_PATH= ../../../obj
OPT=		-O3
LIB_INC =	-I${SOCKET99_PATH} -I${THREADPOOL_PATH}
CFLAGS +=	-std=c99 ${OPT} -Wall -g ${LIB_INC}
LDFLAGS +=	-L. -lsocket99 -L${LIB_PATH} -lthreadpool
LDFLAGS +=	-L. -lsocket99 -L${LIB_PATH}

BUS_OBJS = \
	bus.o \
@@ -42,7 +42,7 @@ echosrv: ${ECHOSRV_OBJS}
	${CC} -o $@ $^ ${LDFLAGS}

bus_example: bus_example.o libbus.a
	${CC} -o $@ $^ ${LDFLAGS} -lbus
	${CC} -o $@ $^ ${LDFLAGS} -lbus -lthreadpool

clean:
	rm -f *.a *.o test_casq echosrv bus_example
+16 −11
Original line number Diff line number Diff line
@@ -29,8 +29,6 @@ bool bus_ssl_init(struct bus *b) {
/* Do an SSL / TLS shake for a connection. Blocking.
 * Returns whether the connection succeeded. */
bool bus_ssl_connect(struct bus *b, connection_info *ci) {
    //if (!init_client_SSL_CTX(b, ci)) { return false; }

    SSL *ssl = NULL;

    ssl = SSL_new(b->ssl_ctx);
@@ -38,9 +36,12 @@ bool bus_ssl_connect(struct bus *b, connection_info *ci) {
        ERR_print_errors_fp(stderr);
        return false;
    }

    ci->ssl = ssl;

    if (!SSL_set_fd(ci->ssl, ci->fd)) {
        return false;;
    }

    if (do_blocking_connection(b, ci)) {
        return true;
    } else {
@@ -62,12 +63,15 @@ static bool init_client_SSL_CTX(SSL_CTX **ctx_out) {
    assert(ctx_out);

    /* Create TLS context */
#if KINETIC_USE_TLS_1_2
    const SSL_METHOD *method = TLSv1_2_client_method();
#else
    const SSL_METHOD *method = TLSv1_1_client_method();
#endif
    const SSL_METHOD *method = NULL;

    if (KINETIC_USE_TLS_1_2) {
        method = TLSv1_2_client_method();
    } else {
        method = TLSv1_1_client_method();
    }

    assert(method);
    ctx = SSL_CTX_new(method);
    if (ctx == NULL) {
        ERR_print_errors_fp(stderr);
@@ -123,6 +127,8 @@ static bool do_blocking_connection(struct bus *b, connection_info *ci) {
        } else if (pres > 0) {
            if (fds[0].revents & (POLLOUT | POLLIN)) {
                int connect_res = SSL_connect(ci->ssl);
                BUS_LOG_SNPRINTF(b, 5, LOG_SOCKET_REGISTERED, b->udata, 128,
                    "socket %d: connect_res %d", ci->fd, connect_res);

                if (connect_res == 1) {
                    BUS_LOG_SNPRINTF(b, 5, LOG_SOCKET_REGISTERED, b->udata, 128,
@@ -157,12 +163,10 @@ static bool do_blocking_connection(struct bus *b, connection_info *ci) {
                    break;
                    default:
                    {
                        BUS_LOG_SNPRINTF(b, 5, LOG_SOCKET_REGISTERED, b->udata, 128,
                            "socket %d: ERROR -- reason %d", ci->fd, reason);
                        unsigned long errval = ERR_get_error();
                        char ebuf[256];
                        BUS_LOG_SNPRINTF(b, 5, LOG_SOCKET_REGISTERED, b->udata, 128,
                            "socket %d: ERROR -- %s", ci->fd, ERR_error_string(errval, ebuf));
                            "socket %d: ERROR %d -- %s", ci->fd, reason, ERR_error_string(errval, ebuf));
                        assert(false);
                    }
                    }
@@ -186,6 +190,7 @@ static bool do_blocking_connection(struct bus *b, connection_info *ci) {
            BUS_LOG(b, 4, LOG_SOCKET_REGISTERED, "poll timeout", b->udata);
            elapsed += TIMEOUT_MSEC;
            if (elapsed > MAX_TIMEOUT) {
                BUS_LOG(b, 2, LOG_SOCKET_REGISTERED, "timed out", b->udata);
                return false;
            }
        }
+135 −33
Original line number Diff line number Diff line
@@ -299,6 +299,21 @@ static void set_error_for_socket(listener *l, int id, int fd, rx_error_t err) {
    l->fds[id].events &= ~POLLIN;
}

static void print_SSL_error(struct bus *b, connection_info *ci, int lvl, const char *prefix) {
    unsigned long errval = ERR_get_error();
    char ebuf[256];
    BUS_LOG_SNPRINTF(b, lvl, LOG_LISTENER, b->udata, 64,
        "%s -- ERROR on fd %d -- %s",
        prefix, ci->fd, ERR_error_string(errval, ebuf));
}

static bool socket_read_plain(struct bus *b,
    listener *l, int pfd_i, connection_info *ci);
static bool socket_read_ssl(struct bus *b,
    listener *l, int pfd_i, connection_info *ci);
static bool sink_socket_read(struct bus *b,
    listener *l, connection_info *ci, ssize_t size);

static void attempt_recv(listener *l, int available) {
    /*   --> failure --> close socket, don't die */
    struct bus *b = l->bus;
@@ -325,21 +340,110 @@ static void attempt_recv(listener *l, int available) {
                ci->to_read_size, l->read_buf_size);
            assert(l->read_buf_size >= ci->to_read_size);
            
            ssize_t size = read(fd->fd, l->read_buf, ci->to_read_size);
            switch (ci->type) {
                
            case BUS_SOCKET_PLAIN:
                socket_read_plain(b, l, i, ci);
                break;
            case BUS_SOCKET_SSL:
                socket_read_ssl(b, l, i, ci);
                break;
            default:
                assert(false);
            }
        }
    }
}
    
static bool socket_read_plain(struct bus *b, listener *l, int pfd_i, connection_info *ci) {
    ssize_t size = read(ci->fd, l->read_buf, ci->to_read_size);
    if (size == -1) {
        if (util_is_resumable_io_error(errno)) {
            errno = 0;
        } else {
            BUS_LOG_SNPRINTF(b, 3, LOG_LISTENER, b->udata, 64,
                "read: socket error reading, %d", errno);
                    set_error_for_socket(l, i, ci->fd, RX_ERROR_READ_FAILURE);
            set_error_for_socket(l, pfd_i, ci->fd, RX_ERROR_READ_FAILURE);
            errno = 0;
        }
    }
    
    if (size > 0) {
        return sink_socket_read(b, l, ci, size);
    } else {
        return false;
    }
}

static bool socket_read_ssl(struct bus *b, listener *l, int pfd_i, connection_info *ci) {
    for (;;) {
        ssize_t pending = SSL_pending(ci->ssl);
        ssize_t size = (ssize_t)SSL_read(ci->ssl, l->read_buf, ci->to_read_size);
        fprintf(stderr, "=== PENDING: %zd, got %zd ===\n", pending, size);
        
        if (size == -1) {
            int reason = SSL_get_error(ci->ssl, size);
            switch (reason) {
            case SSL_ERROR_WANT_READ:
                BUS_LOG_SNPRINTF(b, 3, LOG_LISTENER, b->udata, 64,
                    "SSL_read fd %d: WANT_READ\n", ci->fd);
                return true;
                
            case SSL_ERROR_WANT_WRITE:
                assert(false);
                
            case SSL_ERROR_SYSCALL:
            {
                if (errno == 0) {
                    print_SSL_error(b, ci, 1, "SSL_ERROR_SYSCALL errno 0");
                    assert(false);
                } else if (util_is_resumable_io_error(errno)) {
                errno = 0;
                } else {
                    BUS_LOG_SNPRINTF(b, 3, LOG_LISTENER, b->udata, 64,
                        "SSL_read fd %d: errno %d\n", ci->fd, errno);
                    print_SSL_error(b, ci, 1, "SSL_ERROR_SYSCALL");
                    set_error_for_socket(l, pfd_i, ci->fd, RX_ERROR_READ_FAILURE);
                }
            }
            case SSL_ERROR_ZERO_RETURN:
            {
                BUS_LOG_SNPRINTF(b, 3, LOG_LISTENER, b->udata, 64,
                    "SSL_read fd %d: ZERO_RETURN (HUP)\n", ci->fd);
                set_error_for_socket(l, pfd_i, ci->fd, RX_ERROR_POLLHUP);
                break;
            }
            
            default:
                print_SSL_error(b, ci, 1, "SSL_ERROR UNKNOWN");
                set_error_for_socket(l, pfd_i, ci->fd, RX_ERROR_READ_FAILURE);
                assert(false);
            }
        } else if (size > 0) {
            sink_socket_read(b, l, ci, size);
        }
    }
    return true;
}

#define DUMP_READ 0

static bool sink_socket_read(struct bus *b,
        listener *l, connection_info *ci, ssize_t size) {
    BUS_LOG_SNPRINTF(b, 3, LOG_LISTENER, b->udata, 64,
        "read %zd bytes, calling sink CB", size);
    
#if DUMP_READ
    bus_lock_log(b);
    printf("\n");
    for (int i = 0; i < size; i++) {
        if (i > 0 && (i & 15) == 0) { printf("\n"); }
        printf("%02x ", l->read_buf[i]);
    }
    printf("\n\n");
    bus_unlock_log(b);
#endif
    
    bus_sink_cb_res_t sres = b->sink_cb(l->read_buf, size, ci->udata);
    if (sres.full_msg_buffer) {
        BUS_LOG(b, 3, LOG_LISTENER, "calling unpack CB", b->udata);
@@ -361,9 +465,7 @@ static void attempt_recv(listener *l, int available) {
            assert(false);
        }
    }
            }
        }
    }
    return true;
}

static rx_info_t *find_info_by_sequence_id(listener *l,
+33 −5
Original line number Diff line number Diff line
@@ -66,7 +66,7 @@ struct sender *sender_init(struct bus *b, struct bus_config *cfg) {
    if (res != 0) {
        fprintf(stderr, "pthread_mutex_init: %s\n", strerror(res));
        free(s);
        yacht_free(s->fd_hash_table);
        yacht_free(s->fd_hash_table, NULL, NULL);
        return NULL;
    }

@@ -163,9 +163,19 @@ static bool add_fd_to_watch_set(struct sender *s, int fd) {
        s->fds[idx].fd = fd;
        s->fds[idx].events |= POLLOUT;
        s->active_fds++;
        if (!yacht_add(s->fd_hash_table, fd)) {
        fd_info *info = malloc(sizeof(*info));
        if (info == NULL) {
            if (0 != pthread_mutex_unlock(&s->watch_set_mutex)) {
                assert(false);
            }
            return false;
        }
        info->ssl = NULL;       /* FIXME: determine dataflow */
        void *old = NULL;
        if (!yacht_set(s->fd_hash_table, fd, info, &old)) {
            assert(false);
        }
        assert(old == NULL);
    }

    if (0 != pthread_mutex_unlock(&s->watch_set_mutex)) {
@@ -196,9 +206,13 @@ static bool remove_fd_from_watch_set(struct sender *s, tx_info_t *info) {
                    s->fds[i].fd = s->fds[s->active_fds - 1].fd;
                }
                s->active_fds--;
                if (!yacht_remove(s->fd_hash_table, info->fd)) {
                void *old = NULL;
                if (!yacht_remove(s->fd_hash_table, info->fd, &old)) {
                    assert(false);
                }
                fd_info *info = (fd_info *)old;
                assert(info);
                free(info);
                break;
            }
        }
@@ -441,7 +455,14 @@ static void attempt_write(sender *s, int available) {
            } else {
                BUS_LOG_SNPRINTF(b, 6, LOG_SENDER, b->udata, 64,
                    "writing %zd", rem);
                ssize_t wrsz = write(pfd->fd, &msg[info->sent_size], rem);

                /* TODO: if is_ssl according to fd_info
                 * . int wr_sz = SSL_write(ci->ssl, buf, buf_sz)
                 * . ssize_t wr_sz = write(ci->fd, buf, buf_sz) */
                
                ssize_t wrsz = 0;

                wrsz = write(pfd->fd, &msg[info->sent_size], rem);
                if (wrsz == -1) {
                    if (util_is_resumable_io_error(errno)) {
                        errno = 0;
@@ -560,11 +581,18 @@ static void notify_message_failure(sender *s, tx_info_t *info, bus_send_status_t
    release_tx_info(s, info);
}

static void free_fd_info_cb(void *value, void *udata) {
    fd_info *info = (fd_info *)value;
    /* Note: info->ssl will be freed by the listener. */
    (void)udata;
    free(info);
}

void sender_free(struct sender *s) {
    if (s) {
        int res = pthread_mutex_destroy(&s->watch_set_mutex);
        /* Must call sender_shutdown and wait for phtread_join first. */
        yacht_free(s->fd_hash_table);
        yacht_free(s->fd_hash_table, free_fd_info_cb, NULL);
        assert(res == 0);
        free(s);
    }
Loading