diff --git a/Makefile b/Makefile index ea96419..b723245 100644 --- a/Makefile +++ b/Makefile @@ -4,7 +4,8 @@ # This file is released under the BSD license, see the COPYING file OBJ=net.o hiredis.o sds.o async.o read.o sslio.o -EXAMPLES=hiredis-example hiredis-example-libevent hiredis-example-libev hiredis-example-glib hiredis-example-ssl +EXAMPLES=hiredis-example hiredis-example-libevent hiredis-example-libev hiredis-example-glib \ + hiredis-example-ssl hiredis-example-libevent-ssl TESTS=hiredis-test LIBNAME=libhiredis PKGCONFNAME=hiredis.pc @@ -94,6 +95,9 @@ static: $(STLIBNAME) hiredis-example-libevent: examples/example-libevent.c adapters/libevent.h $(STLIBNAME) $(CC) -o examples/$@ $(REAL_CFLAGS) $(REAL_LDFLAGS) -I. $< -levent $(STLIBNAME) +hiredis-example-libevent-ssl: examples/example-libevent-ssl.c adapters/libevent.h $(STLIBNAME) + $(CC) -o examples/$@ $(REAL_CFLAGS) $(REAL_LDFLAGS) -I. $< -levent $(STLIBNAME) + hiredis-example-libev: examples/example-libev.c adapters/libev.h $(STLIBNAME) $(CC) -o examples/$@ $(REAL_CFLAGS) $(REAL_LDFLAGS) -I. $< -lev $(STLIBNAME) diff --git a/async.c b/async.c index 0cecd30..5a14d45 100644 --- a/async.c +++ b/async.c @@ -40,6 +40,7 @@ #include "net.h" #include "dict.c" #include "sds.h" +#include "sslio.h" #define _EL_ADD_READ(ctx) do { \ if ((ctx)->ev.addRead) (ctx)->ev.addRead((ctx)->ev.data); \ @@ -524,6 +525,87 @@ static int __redisAsyncHandleConnect(redisAsyncContext *ac) { } } +#ifndef HIREDIS_NOSSL +/** + * Handle SSL when socket becomes available for reading. This also handles + * read-while-write and write-while-read + */ +static void asyncSslRead(redisAsyncContext *ac) { + int rv; + redisSsl *ssl = ac->c.ssl; + redisContext *c = &ac->c; + + ssl->wantRead = 0; + + if (ssl->pendingWrite) { + int done; + + /* This is probably just a write event */ + ssl->pendingWrite = 0; + rv = redisBufferWrite(c, &done); + if (rv == REDIS_ERR) { + __redisAsyncDisconnect(ac); + return; + } else if (!done) { + _EL_ADD_WRITE(ac); + } + } + + rv = redisBufferRead(c); + if (rv == REDIS_ERR) { + __redisAsyncDisconnect(ac); + } else { + _EL_ADD_READ(ac); + redisProcessCallbacks(ac); + } +} + +/** + * Handle SSL when socket becomes available for writing + */ +static void asyncSslWrite(redisAsyncContext *ac) { + int rv, done = 0; + redisSsl *ssl = ac->c.ssl; + redisContext *c = &ac->c; + + ssl->pendingWrite = 0; + rv = redisBufferWrite(c, &done); + if (rv == REDIS_ERR) { + __redisAsyncDisconnect(ac); + return; + } + + if (!done) { + if (ssl->wantRead) { + /* Need to read-before-write */ + ssl->pendingWrite = 1; + _EL_DEL_WRITE(ac); + } else { + /* No extra reads needed, just need to write more */ + _EL_ADD_WRITE(ac); + } + } else { + /* Already done! */ + _EL_DEL_WRITE(ac); + } + + /* Always reschedule a read */ + _EL_ADD_READ(ac); +} +#else + +/* Just so we're able to compile */ +static void asyncSslRead(redisAsyncContext *ac) { + abort(); + (void)ac; +} +static void asyncSslWrite(redisAsyncContext *ac) { + abort(); + (void)ac; +} + +#endif + /* This function should be called when the socket is readable. * It processes all replies that can be read and executes their callbacks. */ @@ -539,6 +621,11 @@ void redisAsyncHandleRead(redisAsyncContext *ac) { return; } + if (c->flags & REDIS_SSL) { + asyncSslRead(ac); + return; + } + if (redisBufferRead(c) == REDIS_ERR) { __redisAsyncDisconnect(ac); } else { @@ -561,6 +648,11 @@ void redisAsyncHandleWrite(redisAsyncContext *ac) { return; } + if (c->flags & REDIS_SSL) { + asyncSslWrite(ac); + return; + } + if (redisBufferWrite(c,&done) == REDIS_ERR) { __redisAsyncDisconnect(ac); } else { diff --git a/examples/example-libevent-ssl.c b/examples/example-libevent-ssl.c new file mode 100644 index 0000000..f780e3e --- /dev/null +++ b/examples/example-libevent-ssl.c @@ -0,0 +1,72 @@ +#include +#include +#include +#include + +#include +#include +#include + +void getCallback(redisAsyncContext *c, void *r, void *privdata) { + redisReply *reply = r; + if (reply == NULL) return; + printf("argv[%s]: %s\n", (char*)privdata, reply->str); + + /* Disconnect after receiving the reply to GET */ + redisAsyncDisconnect(c); +} + +void connectCallback(const redisAsyncContext *c, int status) { + if (status != REDIS_OK) { + printf("Error: %s\n", c->errstr); + return; + } + printf("Connected...\n"); +} + +void disconnectCallback(const redisAsyncContext *c, int status) { + if (status != REDIS_OK) { + printf("Error: %s\n", c->errstr); + return; + } + printf("Disconnected...\n"); +} + +int main (int argc, char **argv) { + signal(SIGPIPE, SIG_IGN); + struct event_base *base = event_base_new(); + if (argc < 5) { + fprintf(stderr, + "Usage: %s [ca]\n", argv[0]); + exit(1); + } + + const char *value = argv[1]; + size_t nvalue = strlen(value); + + const char *hostname = argv[2]; + int port = atoi(argv[3]); + + const char *cert = argv[4]; + const char *certKey = argv[5]; + const char *caCert = argc > 5 ? argv[6] : NULL; + + redisAsyncContext *c = redisAsyncConnect(hostname, port); + if (c->err) { + /* Let *c leak for now... */ + printf("Error: %s\n", c->errstr); + return 1; + } + if (redisSecureConnection(&c->c, caCert, cert, certKey) != REDIS_OK) { + printf("SSL Error!\n"); + exit(1); + } + + redisLibeventAttach(c,base); + redisAsyncSetConnectCallback(c,connectCallback); + redisAsyncSetDisconnectCallback(c,disconnectCallback); + redisAsyncCommand(c, NULL, NULL, "SET key %b", value, nvalue); + redisAsyncCommand(c, getCallback, (char*)"end-1", "GET key"); + event_base_dispatch(base); + return 0; +} diff --git a/sslio.c b/sslio.c index 6958d37..3b08140 100644 --- a/sslio.c +++ b/sslio.c @@ -99,7 +99,7 @@ int redisSslCreate(redisContext *c, const char *capath, const char *certpath, redisSsl *s = c->ssl; s->ctx = SSL_CTX_new(SSLv23_client_method()); - /* SSL_CTX_set_info_callback(s->ctx, sslLogCallback); */ + SSL_CTX_set_info_callback(s->ctx, sslLogCallback); SSL_CTX_set_mode(s->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); SSL_CTX_set_options(s->ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); SSL_CTX_set_verify(s->ctx, SSL_VERIFY_PEER, NULL); @@ -153,6 +153,22 @@ int redisSslCreate(redisContext *c, const char *capath, const char *certpath, return REDIS_ERR; } +static int maybeCheckWant(redisSsl *rssl, int rv) { + /** + * If the error is WANT_READ or WANT_WRITE, the appropriate flags are set + * and true is returned. False is returned otherwise + */ + if (rv == SSL_ERROR_WANT_READ) { + rssl->wantRead = 1; + return 1; + } else if (rv == SSL_ERROR_WANT_WRITE) { + rssl->pendingWrite = 1; + return 1; + } else { + return 0; + } +} + int redisSslRead(redisContext *c, char *buf, size_t bufcap) { int nread = SSL_read(c->ssl->ssl, buf, bufcap); if (nread > 0) { @@ -162,7 +178,7 @@ int redisSslRead(redisContext *c, char *buf, size_t bufcap) { return -1; } else { int err = SSL_get_error(c->ssl->ssl, nread); - if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { + if (maybeCheckWant(c->ssl, err)) { return 0; } else { __redisSetError(c, REDIS_ERR_IO, NULL); @@ -181,7 +197,7 @@ int redisSslWrite(redisContext *c) { c->ssl->lastLen = len; int err = SSL_get_error(c->ssl->ssl, rv); - if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { + if (maybeCheckWant(c->ssl, err)) { return 0; } else { __redisSetError(c, REDIS_ERR_IO, NULL); diff --git a/sslio.h b/sslio.h index a410cb3..1f46b03 100644 --- a/sslio.h +++ b/sslio.h @@ -33,6 +33,15 @@ typedef struct redisSsl { * previously called with in the event of an SSL_read/SSL_write situation */ size_t lastLen; + + /** Whether the SSL layer requires read (possibly before a write) */ + int wantRead; + + /** + * Whether a write was requested prior to a read. If set, the write() + * should resume whenever a read takes place, if possible + */ + int pendingWrite; } redisSsl; struct redisContext;