diff --git a/sys/src/cmd/aan.c b/sys/src/cmd/aan.c index 7931dfe96..6994312e8 100644 --- a/sys/src/cmd/aan.c +++ b/sys/src/cmd/aan.c @@ -4,12 +4,10 @@ #include #include -#define NS(x) ((vlong)x) -#define US(x) (NS(x) * 1000LL) -#define MS(x) (US(x) * 1000LL) -#define S(x) (MS(x) * 1000LL) - -#define LOGNAME "aan" +#define NS(x) ((vlong)x) +#define US(x) (NS(x) * 1000LL) +#define MS(x) (US(x) * 1000LL) +#define S(x) (MS(x) * 1000LL) enum { Synctime = S(8), @@ -17,7 +15,7 @@ enum { K = 1024, Bufsize = 8 * K, Stacksize = 8 * K, - Timer = 0, // Alt channels. + Timer = 0, // Alt channels. Unsent = 1, Maxto = 24 * 3600, // A full day to reconnect. }; @@ -31,29 +29,28 @@ struct Endpoints { }; typedef struct { - ulong nb; // Number of data bytes in this message - ulong msg; // Message number - ulong acked; // Number of messages acked + uchar nb[4]; // Number of data bytes in this message + uchar msg[4]; // Message number + uchar acked[4]; // Number of messages acked } Hdr; typedef struct t_Buf { - Hdr hdr; - uchar buf[Bufsize]; + Hdr hdr; + uchar buf[Bufsize]; } Buf; -static char *progname; static Channel *unsent; static Channel *unacked; static Channel *empty; -static int netfd; -static int inmsg; +static int netfd; +static int inmsg; static char *devdir; -static int debug; -static int done; +static int debug; +static int done; static char *dialstring; -static int maxto = Maxto; -static char *Logname = LOGNAME; -static int client; +static int maxto = Maxto; +static char *Logname = "aan"; +static int client; static Alt a[] = { /* c v op */ @@ -64,7 +61,7 @@ static Alt a[] = { static void fromnet(void*); static void fromclient(void*); -static void reconnect(void); +static void reconnect(int); static void synchronize(void); static int sendcommand(ulong, ulong); static void showmsg(int, char *, Buf *); @@ -72,13 +69,13 @@ static int writen(int, uchar *, int); static int getport(char *); static void dmessage(int, char *, ...); static void timerproc(void *); -static Endpoints *getendpoints(char *); +static Endpoints* getendpoints(char *); static void freeendpoints(Endpoints *); static void usage(void) { - fprint(2, "Usage: %s [-cd] [-m maxto] dialstring|netdir\n", progname); + fprint(2, "Usage: %s [-cd] [-m maxto] dialstring|netdir\n", argv0); exits("usage"); } @@ -86,22 +83,46 @@ static int catch(void *, char *s) { if (!strcmp(s, "alarm")) { - syslog(0, Logname, "Timed out while waiting for client on %s, exiting...", - devdir); + syslog(0, Logname, "Timed out while waiting for reconnect, exiting..."); threadexitsall(nil); } return 0; } +static void* +emalloc(int n) +{ + ulong pc; + void *v; + + pc = getcallerpc(&n); + v = malloc(n); + if(v == nil) + sysfatal("Cannot allocate memory; pc=%lux", pc); + setmalloctag(v, pc); + return v; +} + +static char* +estrdup(char *s) +{ + char *v; + int n; + + n = strlen(s)+1; + v = emalloc(n); + memmove(v, s, n); + return v; +} + void threadmain(int argc, char **argv) { - int i, failed; - Buf *b; - Channel *timer; vlong synctime; + int i, n, failed; + Channel *timer; + Buf *b; - progname = argv[0]; ARGBEGIN { case 'c': client++; @@ -110,7 +131,7 @@ threadmain(int argc, char **argv) debug++; break; case 'm': - maxto = (int)strtol(EARGF(usage()), (char **)nil, 0); + maxto = (int)strtol(EARGF(usage()), nil, 0); break; default: usage(); @@ -138,29 +159,33 @@ threadmain(int argc, char **argv) atnotify(catch, 1); + /* + * Set up initial connection. use short timeout + * of 60 seconds so we wont hang arround for too + * long if there is some general connection problem + * (like NAT). + */ + netfd = -1; + reconnect(60); + unsent = chancreate(sizeof(Buf *), Nbuf); unacked = chancreate(sizeof(Buf *), Nbuf); empty = chancreate(sizeof(Buf *), Nbuf); timer = chancreate(sizeof(uchar *), 1); + if(unsent == nil || unacked == nil || empty == nil || timer == nil) + sysfatal("Cannot allocate channels"); - for (i = 0; i != Nbuf; i++) { - Buf *b = malloc(sizeof(Buf)); - sendp(empty, b); - } - - netfd = -1; + for (i = 0; i < Nbuf; i++) + sendp(empty, emalloc(sizeof(Buf))); if (proccreate(fromnet, nil, Stacksize) < 0) - sysfatal("%s; Cannot start fromnet; %r", progname); - - reconnect(); // Set up the initial connection. - synchronize(); + sysfatal("Cannot start fromnet; %r"); if (proccreate(fromclient, nil, Stacksize) < 0) - sysfatal("%s; Cannot start fromclient; %r", progname); + sysfatal("Cannot start fromclient; %r"); if (proccreate(timerproc, timer, Stacksize) < 0) - sysfatal("%s; Cannot start timerproc; %r", progname); + sysfatal("Cannot start timerproc; %r"); a[Timer].c = timer; a[Unsent].c = unsent; @@ -169,9 +194,6 @@ threadmain(int argc, char **argv) synctime = nsec() + Synctime; failed = 0; while (!done) { - vlong now; - int delta; - if (failed) { // Wait for the netreader to die. while (netfd >= 0) { @@ -180,20 +202,17 @@ threadmain(int argc, char **argv) } // the reader died; reestablish the world. - reconnect(); + reconnect(maxto); synchronize(); failed = 0; } - now = nsec(); - delta = (synctime - nsec()) / MS(1); - - if (delta <= 0) { + if (nsec() >= synctime) { Hdr hdr; - hdr.nb = 0; - hdr.acked = inmsg; - hdr.msg = -1; + PBIT32(hdr.nb, 0); + PBIT32(hdr.acked, inmsg); + PBIT32(hdr.msg, -1); if (writen(netfd, (uchar *)&hdr, sizeof(Hdr)) < 0) { dmessage(2, "main; writen failed; %r\n"); @@ -201,7 +220,6 @@ threadmain(int argc, char **argv) continue; } synctime = nsec() + Synctime; - assert(synctime > now); } switch (alt(a)) { @@ -211,19 +229,20 @@ threadmain(int argc, char **argv) case Unsent: sendp(unacked, b); - b->hdr.acked = inmsg; + PBIT32(b->hdr.acked, inmsg); if (writen(netfd, (uchar *)&b->hdr, sizeof(Hdr)) < 0) { dmessage(2, "main; writen failed; %r\n"); failed = 1; } - if (writen(netfd, b->buf, b->hdr.nb) < 0) { + n = GBIT32(b->hdr.nb); + if (writen(netfd, b->buf, n) < 0) { dmessage(2, "main; writen failed; %r\n"); failed = 1; } - if (b->hdr.nb == 0) + if (n == 0) done = 1; break; } @@ -237,48 +256,46 @@ static void fromclient(void*) { static int outmsg; + int n; + Buf *b; - for (;;) { - Buf *b; - - b = recvp(empty); - if ((int)(b->hdr.nb = read(0, b->buf, Bufsize)) <= 0) { - if ((int)b->hdr.nb < 0) + do { + b = recvp(empty); + n = read(0, b->buf, Bufsize); + if (n <= 0) { + if (n < 0) dmessage(2, "fromclient; Cannot read 9P message; %r\n"); else dmessage(2, "fromclient; Client terminated\n"); - b->hdr.nb = 0; + n = 0; } - b->hdr.msg = outmsg++; - + PBIT32(b->hdr.nb, n); + PBIT32(b->hdr.msg, outmsg); showmsg(1, "fromclient", b); sendp(unsent, b); - - if (b->hdr.nb == 0) - break; - } + outmsg++; + } while(n > 0); } static void fromnet(void*) { static int lastacked; + int n, m, len, acked; Buf *b; - b = (Buf *)malloc(sizeof(Buf)); - assert(b); - + b = emalloc(sizeof(Buf)); while (!done) { - int len, acked, i; - while (netfd < 0) { - dmessage(1, "fromnet; waiting for connection... (inmsg %d)\n", - inmsg); + if(done) + return; + dmessage(1, "fromnet; waiting for connection... (inmsg %d)\n", inmsg); sleep(1000); } // Read the header. - if ((len = readn(netfd, &b->hdr, sizeof(Hdr))) <= 0) { + len = readn(netfd, (uchar *)&b->hdr, sizeof(Hdr)); + if (len <= 0) { if (len < 0) dmessage(1, "fromnet; (hdr) network failure; %r\n"); else @@ -287,18 +304,27 @@ fromnet(void*) netfd = -1; continue; } - dmessage(2, "fromnet: Got message, size %d, nb %d, msg %d\n", len, - b->hdr.nb, b->hdr.msg); + n = GBIT32(b->hdr.nb); + m = GBIT32(b->hdr.msg); + acked = GBIT32(b->hdr.acked); + dmessage(2, "fromnet: Got message, size %d, nb %d, msg %d, acked %d, lastacked %d\n", + len, n, m, acked, lastacked); - if (b->hdr.nb == 0) { - if ((long)b->hdr.msg >= 0) { + if (n == 0) { + if (m >= 0) { dmessage(1, "fromnet; network closed\n"); break; } continue; } - - if ((len = readn(netfd, b->buf, b->hdr.nb)) <= 0 || len != b->hdr.nb) { + + if (n > Bufsize) { + dmessage(1, "fromnet; message too big %d > %d\n", n, Bufsize); + break; + } + + len = readn(netfd, b->buf, n); + if (len <= 0 || len != n) { if (len == 0) dmessage(1, "fromnet; network closed\n"); else @@ -308,28 +334,25 @@ fromnet(void*) continue; } - if (b->hdr.msg < inmsg) { - dmessage(1, "fromnet; skipping message %d, currently at %d\n", - b->hdr.msg, inmsg); + if (m < inmsg) { + dmessage(1, "fromnet; skipping message %d, currently at %d\n", m, inmsg); continue; } // Process the acked list. - acked = b->hdr.acked - lastacked; - for (i = 0; i != acked; i++) { + while(lastacked != acked) { Buf *rb; rb = recvp(unacked); - if (rb->hdr.msg != lastacked + i) { - dmessage(1, "rb %p, msg %d, lastacked %d, i %d\n", - rb, rb? rb->hdr.msg: -2, lastacked, i); - assert(0); + m = GBIT32(rb->hdr.msg); + if (m != lastacked) { + dmessage(1, "fromnet; rb %p, msg %d, lastacked %d\n", rb, m, lastacked); + sysfatal("fromnet; bug"); } - rb->hdr.msg = -1; + PBIT32(rb->hdr.msg, -1); sendp(empty, rb); + lastacked++; } - lastacked = b->hdr.acked; - inmsg++; showmsg(1, "fromnet", b); @@ -341,13 +364,14 @@ fromnet(void*) } static void -reconnect(void) +reconnect(int secs) { char ldir[40]; int lcfd, fd; if (dialstring) { syslog(0, Logname, "dialing %s", dialstring); + alarm(secs*1000); while ((fd = dial(dialstring, nil, nil, nil)) < 0) { char err[32]; @@ -360,16 +384,16 @@ reconnect(void) dmessage(1, "reconnect: dialed %s; %s\n", dialstring, err); sleep(1000); } + alarm(0); syslog(0, Logname, "reconnected to %s", dialstring); } else { Endpoints *ep; syslog(0, Logname, "waiting for connection on %s", devdir); - alarm(maxto * 1000); + alarm(secs*1000); if ((lcfd = listen(devdir, ldir)) < 0) sysfatal("reconnect; cannot listen; %r"); - if ((fd = accept(lcfd, ldir)) < 0) sysfatal("reconnect; cannot accept; %r"); alarm(0); @@ -381,7 +405,8 @@ reconnect(void) freeendpoints(ep); } - netfd = fd; // Wakes up the netreader. + // Wakes up the netreader. + netfd = fd; } static void @@ -389,6 +414,7 @@ synchronize(void) { Channel *tmp; Buf *b; + int n; // Ignore network errors here. If we fail during // synchronization, the next alarm will pick up @@ -396,7 +422,9 @@ synchronize(void) tmp = chancreate(sizeof(Buf *), Nbuf); while ((b = nbrecvp(unacked)) != nil) { - writen(netfd, (uchar *)b, sizeof(Hdr) + b->hdr.nb); + n = GBIT32(b->hdr.nb); + writen(netfd, (uchar *)&b->hdr, sizeof(Hdr)); + writen(netfd, b->buf, n); sendp(tmp, b); } chanfree(unacked); @@ -406,14 +434,14 @@ synchronize(void) static void showmsg(int level, char *s, Buf *b) { + int n; + if (b == nil) { dmessage(level, "%s; b == nil\n", s); return; } - - dmessage(level, - "%s; (len %d) %X %X %X %X %X %X %X %X %X (%p)\n", s, - b->hdr.nb, + n = GBIT32(b->hdr.nb); + dmessage(level, "%s; (len %d) %X %X %X %X %X %X %X %X %X (%p)\n", s, n, b->buf[0], b->buf[1], b->buf[2], b->buf[3], b->buf[4], b->buf[5], b->buf[6], b->buf[7], b->buf[8], b); @@ -483,16 +511,16 @@ getendpoint(char *dir, char *file, char **sysp, char **servp) serv = strchr(buf, '!'); if(serv){ *serv++ = 0; - serv = strdup(serv); + serv = estrdup(serv); } - sys = strdup(buf); + sys = estrdup(buf); } close(fd); } if(serv == 0) - serv = strdup("unknown"); + serv = estrdup("unknown"); if(sys == 0) - sys = strdup("unknown"); + sys = estrdup("unknown"); *servp = serv; *sysp = sys; } @@ -502,7 +530,7 @@ getendpoints(char *dir) { Endpoints *ep; - ep = malloc(sizeof(*ep)); + ep = emalloc(sizeof(*ep)); getendpoint(dir, "local", &ep->lsys, &ep->lserv); getendpoint(dir, "remote", &ep->rsys, &ep->rserv); return ep; @@ -517,4 +545,3 @@ freeendpoints(Endpoints *ep) free(ep->rserv); free(ep); } -