diff --git a/async.c b/async.c index 6555114..3dad137 100644 --- a/async.c +++ b/async.c @@ -148,6 +148,7 @@ static redisAsyncContext *redisAsyncInitialize(redisContext *c) { ac->sub.replies.tail = NULL; ac->sub.channels = channels; ac->sub.patterns = patterns; + ac->sub.pending_unsubs = 0; return ac; oom: @@ -411,11 +412,11 @@ void redisAsyncDisconnect(redisAsyncContext *ac) { static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply, redisCallback *dstcb) { redisContext *c = &(ac->c); dict *callbacks; - redisCallback *cb; + redisCallback *cb = NULL; dictEntry *de; int pvariant; char *stype; - sds sname; + sds sname = NULL; /* Match reply with the expected format of a pushed message. * The type and number of elements (3 to 4) are specified at: @@ -432,42 +433,43 @@ static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply, callbacks = ac->sub.channels; /* Locate the right callback */ - assert(reply->element[1]->type == REDIS_REPLY_STRING); - sname = sdsnewlen(reply->element[1]->str,reply->element[1]->len); - if (sname == NULL) - goto oom; + if (reply->element[1]->type == REDIS_REPLY_STRING) { + sname = sdsnewlen(reply->element[1]->str,reply->element[1]->len); + if (sname == NULL) goto oom; - de = dictFind(callbacks,sname); - if (de != NULL) { - cb = dictGetEntryVal(de); - - /* If this is an subscribe reply decrease pending counter. */ - if (strcasecmp(stype+pvariant,"subscribe") == 0) { - cb->pending_subs -= 1; + if ((de = dictFind(callbacks,sname)) != NULL) { + cb = dictGetEntryVal(de); + memcpy(dstcb,cb,sizeof(*dstcb)); } + } - memcpy(dstcb,cb,sizeof(*dstcb)); + /* If this is an subscribe reply decrease pending counter. */ + if (strcasecmp(stype+pvariant,"subscribe") == 0) { + assert(cb != NULL); + cb->pending_subs -= 1; - /* If this is an unsubscribe message, remove it. */ - if (strcasecmp(stype+pvariant,"unsubscribe") == 0) { - if (cb->pending_subs == 0) - dictDelete(callbacks,sname); + } else if (strcasecmp(stype+pvariant,"unsubscribe") == 0) { + if (cb == NULL) + ac->sub.pending_unsubs -= 1; + else if (cb->pending_subs == 0) + dictDelete(callbacks,sname); - /* If this was the last unsubscribe message, revert to - * non-subscribe mode. */ - assert(reply->element[2]->type == REDIS_REPLY_INTEGER); + /* If this was the last unsubscribe message, revert to + * non-subscribe mode. */ + assert(reply->element[2]->type == REDIS_REPLY_INTEGER); - /* Unset subscribed flag only when no pipelined pending subscribe. */ - if (reply->element[2]->integer == 0 - && dictSize(ac->sub.channels) == 0 - && dictSize(ac->sub.patterns) == 0) { - c->flags &= ~REDIS_SUBSCRIBED; + /* Unset subscribed flag only when no pipelined pending subscribe + * or pending unsubscribe replies. */ + if (reply->element[2]->integer == 0 + && dictSize(ac->sub.channels) == 0 + && dictSize(ac->sub.patterns) == 0 + && ac->sub.pending_unsubs == 0) { + c->flags &= ~REDIS_SUBSCRIBED; - /* Move ongoing regular command callbacks. */ - redisCallback cb; - while (__redisShiftCallback(&ac->sub.replies,&cb) == REDIS_OK) { - __redisPushCallback(&ac->replies,&cb); - } + /* Move ongoing regular command callbacks. */ + redisCallback cb; + while (__redisShiftCallback(&ac->sub.replies,&cb) == REDIS_OK) { + __redisPushCallback(&ac->replies,&cb); } } } @@ -540,7 +542,7 @@ void redisProcessCallbacks(redisAsyncContext *ac) { /* Even if the context is subscribed, pending regular * callbacks will get a reply before pub/sub messages arrive. */ - redisCallback cb = {NULL, NULL, 0, NULL}; + redisCallback cb = {NULL, NULL, 0, 0, NULL}; if (__redisShiftCallback(&ac->replies,&cb) != REDIS_OK) { /* * A spontaneous reply in a not-subscribed context can be the error @@ -757,6 +759,7 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void redisContext *c = &(ac->c); redisCallback cb; struct dict *cbdict; + dictIterator it; dictEntry *de; redisCallback *existcb; int pvariant, hasnext; @@ -773,6 +776,7 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void cb.fn = fn; cb.privdata = privdata; cb.pending_subs = 1; + cb.unsubscribe_sent = 0; /* Find out which command will be appended. */ p = nextArgument(cmd,&cstr,&clen); @@ -812,6 +816,51 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void * subscribed to one or more channels or patterns. */ if (!(c->flags & REDIS_SUBSCRIBED)) return REDIS_ERR; + if (pvariant) + cbdict = ac->sub.patterns; + else + cbdict = ac->sub.channels; + + if (hasnext) { + /* Send an unsubscribe with specific channels/patterns. + * Bookkeeping the number of expected replies */ + while ((p = nextArgument(p,&astr,&alen)) != NULL) { + sname = sdsnewlen(astr,alen); + if (sname == NULL) + goto oom; + + de = dictFind(cbdict,sname); + if (de != NULL) { + existcb = dictGetEntryVal(de); + if (existcb->unsubscribe_sent == 0) + existcb->unsubscribe_sent = 1; + else + /* Already sent, reply to be ignored */ + ac->sub.pending_unsubs += 1; + } else { + /* Not subscribed to, reply to be ignored */ + ac->sub.pending_unsubs += 1; + } + sdsfree(sname); + } + } else { + /* Send an unsubscribe without specific channels/patterns. + * Bookkeeping the number of expected replies */ + int no_subs = 1; + dictInitIterator(&it,cbdict); + while ((de = dictNext(&it)) != NULL) { + existcb = dictGetEntryVal(de); + if (existcb->unsubscribe_sent == 0) { + existcb->unsubscribe_sent = 1; + no_subs = 0; + } + } + /* Unsubscribing to all channels/patterns, where none is + * subscribed to, results in a single reply to be ignored. */ + if (no_subs == 1) + ac->sub.pending_unsubs += 1; + } + /* (P)UNSUBSCRIBE does not have its own response: every channel or * pattern that is unsubscribed will receive a message. This means we * should not append a callback function for this command. */ diff --git a/async.h b/async.h index 4c65203..41951d4 100644 --- a/async.h +++ b/async.h @@ -46,6 +46,7 @@ typedef struct redisCallback { struct redisCallback *next; /* simple singly linked list */ redisCallbackFn *fn; int pending_subs; + int unsubscribe_sent; void *privdata; } redisCallback; @@ -105,6 +106,7 @@ typedef struct redisAsyncContext { redisCallbackList replies; struct dict *channels; struct dict *patterns; + int pending_unsubs; } sub; /* Any configured RESP3 PUSH handler */ diff --git a/test.c b/test.c index f991ef1..f43bc24 100644 --- a/test.c +++ b/test.c @@ -1729,10 +1729,14 @@ void subscribe_channel_a_cb(redisAsyncContext *ac, void *r, void *privdata) { strcmp(reply->element[2]->str,"Hello!") == 0); state->checkpoint++; - /* Unsubscribe to channels, including a channel X which we don't subscribe to */ + /* Unsubscribe to channels, including channel X & Z which we don't subscribe to */ redisAsyncCommand(ac,unexpected_cb, (void*)"unsubscribe should not call unexpected_cb()", - "unsubscribe B X A"); + "unsubscribe B X A A Z"); + /* Unsubscribe to patterns, none which we subscribe to */ + redisAsyncCommand(ac,unexpected_cb, + (void*)"punsubscribe should not call unexpected_cb()", + "punsubscribe"); /* Send a regular command after unsubscribing, then disconnect */ state->disconnect = 1; redisAsyncCommand(ac,integer_cb,state,"LPUSH mylist foo"); @@ -1767,8 +1771,10 @@ void subscribe_channel_b_cb(redisAsyncContext *ac, void *r, void *privdata) { /* Test handling of multiple channels * - subscribe to channel A and B - * - a published message on A triggers an unsubscribe of channel B, X and A - * where channel X is not subscribed to. + * - a published message on A triggers an unsubscribe of channel B, X, A and Z + * where channel X and Z are not subscribed to. + * - the published message also triggers an unsubscribe to patterns. Since no + * pattern is subscribed to the responded pattern element type is NIL. * - a command sent after unsubscribe triggers a disconnect */ static void test_pubsub_multiple_channels(struct config config) { test("Subscribe to multiple channels: ");