From 25c0505bf5484fa241f9519af79375456bc7b16a Mon Sep 17 00:00:00 2001 From: Nate Choe Date: Fri, 1 Apr 2022 19:58:50 -0500 Subject: [PATCH] Got it working for http --- Makefile | 2 +- example/sitefile | 2 +- src/main.c | 165 +++++++++++++++++++++++++++++++------------- src/runner.c | 39 +++++++---- src/sockets.c | 36 +++++++--- src/swebs/runner.h | 21 +----- src/swebs/sockets.h | 2 + src/swebs/util.h | 3 + src/util.c | 93 ++++++++++++++++--------- 9 files changed, 238 insertions(+), 125 deletions(-) diff --git a/Makefile b/Makefile index b28a0cb..f2d8ae0 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ SRC = $(wildcard src/*.c) OBJ = $(subst .c,.o,$(subst src,work,$(SRC))) LIBS = -pthread -pie -lrt -ldl $(shell pkg-config --libs gnutls) -CFLAGS := -O2 -pipe -Wall -Wpedantic -ansi +CFLAGS := -O2 -pipe -Wall -Wpedantic -Wshadow -ansi -ggdb CFLAGS += -Isrc/ -fpie -D_POSIX_C_SOURCE=200809L $(shell pkg-config --cflags gnutls) INSTALLDIR := /usr/sbin HEADERDIR := /usr/include/ diff --git a/example/sitefile b/example/sitefile index d8d93b2..2380c2b 100644 --- a/example/sitefile +++ b/example/sitefile @@ -1,6 +1,6 @@ define port 8000 -#define transport TLS +define transport TLS define key domain.key define cert domain.crt define timeout 2000 diff --git a/src/main.c b/src/main.c index 4dcbafd..ab30176 100644 --- a/src/main.c +++ b/src/main.c @@ -20,10 +20,13 @@ #include #include +#include #include #include #include +#include #include +#include #include #include @@ -31,15 +34,91 @@ #include #include -int main(int argc, char **argv) { - Sitefile *site; - Listener *listener; - int processes; +typedef struct { + pid_t pid; + int fd; +} Runner; - int *pending, pendingid, (*notify)[2]; - pthread_t *threads; +static Runner *runners; +static int processes; +static int mainfd; +static int *pending; +static Listener *listener; +static Sitefile *site; +static struct sockaddr_un addr; +/* We want to be able to handle a signal at any time, so some global variables + * are needed. */ +static void createProcess(int id) { + pid_t pid; + int connfd; + socklen_t addrlen; + + createLog("Creating a new process"); + pending[id] = 0; + + pid = fork(); + switch (pid) { + case -1: + createLog("fork() failed"); + exit(EXIT_FAILURE); + case 0: + break; + default: + addrlen = sizeof(addr); + runners[id].pid = pid; + runners[id].fd = accept(mainfd, + (struct sockaddr *) &addr, &addrlen); + return; + } + + connfd = socket(AF_UNIX, SOCK_STREAM, 0); + if (connfd < 0) + exit(EXIT_FAILURE); + if (connect(connfd, (struct sockaddr *) &addr, sizeof(addr))) { + createLog("connect() failed, killing child"); + exit(EXIT_FAILURE); + } + close(mainfd); + runServer(connfd, site, listener, pending, id); + createLog("child runServer() finished"); + exit(EXIT_SUCCESS); +} + +static void remakeChild(int signal) { + pid_t pid; int i; + pid = wait(NULL); + createLog("A child has died, recreating"); + for (i = 0; i < processes - 1; i++) { + if (runners[i].pid == pid) { + close(runners[i].fd); + createProcess(i); + return; + } + } +} + +static void exitClean(int signal) { + freeListener(listener); + close(mainfd); + remove(addr.sun_path); + exit(EXIT_SUCCESS); +} + +static void setsignal(int signal, void (*handler)(int)) { + struct sigaction action; + sigset_t sigset; + sigemptyset(&sigset); + action.sa_handler = handler; + action.sa_mask = sigset; + action.sa_flags = SA_NODEFER; + sigaction(SIGCHLD, &action, NULL); +} + +int main(int argc, char **argv) { + int i; + int pendingid; setup(argc, argv, &site, &listener, &processes); @@ -48,63 +127,53 @@ int main(int argc, char **argv) { createLog("smalloc() failed"); exit(EXIT_FAILURE); } - pending = saddr(pendingid);; + pending = saddr(pendingid); if (pending == NULL) { createLog("saddr() failed"); exit(EXIT_FAILURE); } memset(pending, 0, sizeof(int) * (processes - 1)); - notify = malloc(sizeof(int[2]) * (processes - 1)); - if (notify == NULL) { - createLog("malloc() failed"); - exit(EXIT_FAILURE); - } + mainfd = socket(AF_UNIX, SOCK_STREAM, 0); + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, tmpnam(NULL), sizeof(addr.sun_path) - 1); + /* I know that tmpname is deprecated, I think this usage is safe + * though. */ + addr.sun_path[sizeof(addr.sun_path) - 1] = '\0'; + bind(mainfd, (struct sockaddr *) &addr, sizeof(addr)); + listen(mainfd, processes); - threads = malloc(sizeof(pthread_t) * (processes - 1)); - if (threads == NULL) { - createLog("malloc() failed"); - exit(EXIT_FAILURE); - } + runners = malloc(sizeof(Runner) * (processes - 1)); + for (i = 0; i < processes - 1; i++) + createProcess(i); - for (i = 0; i < processes - 1; i++) { - RunnerArgs *args = malloc(sizeof(RunnerArgs)); - if (args == NULL) { - createLog("malloc() failed"); - exit(EXIT_FAILURE); - } - if (pipe(notify[i])) { - createLog("pipe() failed"); - exit(EXIT_FAILURE); - } - args->site = site; - args->pendingid = pendingid; - args->notify = notify[i][0]; - args->id = i; - pthread_create(threads + i, NULL, - (void*(*)(void*)) runServer, args); - } - - signal(SIGPIPE, SIG_IGN); + setsignal(SIGPIPE, SIG_IGN); + setsignal(SIGKILL, exitClean); + setsignal(SIGINT, exitClean); + setsignal(SIGCHLD, remakeChild); createLog("swebs started"); for (;;) { - Stream *stream = acceptStream(listener, O_NONBLOCK); - int lowestThread; - createLog("Accepted stream"); - if (stream == NULL) { - createLog("Accepting a stream failed"); + int fd; + int lowestProc; + + fd = accept(listener->fd, (struct sockaddr *) &listener->addr, + &listener->addrlen); + if (fd < 0) { + if (errno == ENOTSOCK || errno == EOPNOTSUPP || + errno == EINVAL) { + createLog("You've majorly screwed up"); + exit(EXIT_FAILURE); + } continue; } + createLog("Accepted stream"); - lowestThread = 0; + lowestProc = 0; for (i = 1; i < processes - 1; i++) - if (pending[i] < pending[lowestThread]) - lowestThread = i; - if (write(notify[lowestThread][1], &stream, sizeof(&stream)) - < sizeof(&stream)) - continue; + if (pending[i] < pending[lowestProc]) + lowestProc = i; + sendFd(fd, runners[lowestProc].fd); } - } diff --git a/src/runner.c b/src/runner.c index 0544e31..b5f5959 100644 --- a/src/runner.c +++ b/src/runner.c @@ -15,12 +15,14 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . */ +#include #include #include #include #include #include +#include #include #include @@ -28,12 +30,8 @@ #include #include -void *runServer(RunnerArgs *args) { - Sitefile *site = args->site; - int *pending = saddr(args->pendingid); - int notify = args->notify; - int id = args->id; - +void runServer(int connfd, Sitefile *site, Listener *listener, + int *pending, int id) { int allocConns = 100; struct pollfd *fds = malloc(sizeof(struct pollfd) * allocConns); Connection *connections = malloc(sizeof(Connection) * allocConns); @@ -41,7 +39,7 @@ void *runServer(RunnerArgs *args) { /* connections are 1 indexed because fds[0] is the notify fd. */ assert(fds != NULL); assert(connections != NULL); - fds[0].fd = notify; + fds[0].fd = connfd; fds[0].events = POLLIN; for (;;) { @@ -71,6 +69,21 @@ remove: if (fds[0].revents & POLLIN) { Stream *newstream; + int newfd; + newfd = recvFd(connfd); + if (newfd < 0) { + createLog("Message received that included an invalid fd"); + continue; + } + newstream = createStream(listener, O_NONBLOCK, newfd); + + createLog("Obtained file descriptor from child"); + if (newstream == NULL) { + createLog("Stream couldn't be created from file descriptor"); + close(newfd); + continue; + } + if (connCount >= allocConns) { struct pollfd *newfds; Connection *newconns; @@ -91,17 +104,15 @@ remove: } connections = newconns; } - if (read(notify, &newstream, sizeof(newstream)) - < sizeof(newstream)) - continue; - fds[connCount].fd = newstream->fd; - fds[connCount].events = POLLIN; - if (newConnection(newstream, connections + connCount)) + if (newConnection(newstream, connections + connCount)) { + createLog("Couldn't initialize connection from stream"); continue; + } + fds[connCount].fd = newfd; + fds[connCount].events = POLLIN; connCount++; pending[id]++; } } - return NULL; } diff --git a/src/sockets.c b/src/sockets.c index 33b0180..620702c 100644 --- a/src/sockets.c +++ b/src/sockets.c @@ -25,6 +25,7 @@ #include #include +#include #include int initTLS() { @@ -34,12 +35,20 @@ int initTLS() { Listener *createListener(SocketType type, unsigned short port, int backlog, ...) { + int shmid; Listener *ret = malloc(sizeof(Listener)); va_list ap; - if (ret == NULL) + shmid = smalloc(sizeof(Listener)); + if (shmid < 0) return NULL; + ret = saddr(shmid); + if (ret == NULL) { + sdestroy(shmid); + return NULL; + } ret->type = type; ret->fd = socket(AF_INET, SOCK_STREAM, 0); + ret->shmid = shmid; if (ret->fd < 0) { free(ret); return NULL; @@ -89,22 +98,17 @@ Listener *createListener(SocketType type, unsigned short port, return ret; error: close(ret->fd); - free(ret); + sfree(ret); + sdestroy(shmid); return NULL; } -Stream *acceptStream(Listener *listener, int flags) { +Stream *createStream(Listener *listener, int flags, int fd) { Stream *ret = malloc(sizeof(Stream)); if (ret == NULL) return NULL; ret->type = listener->type; - ret->fd = accept(listener->fd, (struct sockaddr *) &listener->addr, - &listener->addrlen); - - if (ret->fd < 0) { - free(ret); - return NULL; - } + ret->fd = fd; { int oldflags = fcntl(ret->fd, F_GETFL); @@ -140,13 +144,23 @@ error: return NULL; } +Stream *acceptStream(Listener *listener, int flags) { + int fd; + fd = accept(listener->fd, (struct sockaddr *) &listener->addr, + &listener->addrlen); + return createStream(listener, flags, fd); +} + void freeListener(Listener *listener) { + int shmid; if (listener->type == TLS) { gnutls_certificate_free_credentials(listener->creds); gnutls_priority_deinit(listener->priority); } close(listener->fd); - free(listener); + shmid = listener->shmid; + sfree(listener); + sdestroy(shmid); } void freeStream(Stream *stream) { diff --git a/src/swebs/runner.h b/src/swebs/runner.h index 0f6e5c7..77f21f4 100644 --- a/src/swebs/runner.h +++ b/src/swebs/runner.h @@ -19,25 +19,10 @@ #define HAVE_RUNNER #include +#include #include #include -typedef struct { - Sitefile *site; - int pendingid; - /* int *pending */ - /* - * pending[thread id] = the number of connections being handled by that - * thread - * */ - int notify; - /* - * When this runner should accept a connection, notify will contain an - * int ready to be read. notify is an fd - * */ - int id; -} RunnerArgs; -/* my least favourite anti pattern */ - -void *runServer(RunnerArgs *args); +void runServer(int connfd, Sitefile *site, Listener *listener, + int *pending, int id); #endif diff --git a/src/swebs/sockets.h b/src/swebs/sockets.h index 4c0468a..afdf484 100644 --- a/src/swebs/sockets.h +++ b/src/swebs/sockets.h @@ -27,6 +27,7 @@ typedef struct { SocketType type; int fd; + int shmid; struct sockaddr_in addr; socklen_t addrlen; gnutls_certificate_credentials_t creds; @@ -46,6 +47,7 @@ Listener *createListener(SocketType type, uint16_t port, int backlog, ...); * tcp: (void) * tls: (char *keyfile, char *certfile, char *ocspfile) * */ +Stream *createStream(Listener *listener, int flags, int fd); Stream *acceptStream(Listener *listener, int flags); /* returns 1 on error, accepts fcntl flags */ diff --git a/src/swebs/util.h b/src/swebs/util.h index e100b61..66435ff 100644 --- a/src/swebs/util.h +++ b/src/swebs/util.h @@ -32,4 +32,7 @@ int createLog(char *msg); int istrcmp(char *s1, char *s2); /* case insensitive strcmp */ RequestType getType(char *str); + +void sendFd(int fd, int dest); +int recvFd(int source); #endif diff --git a/src/util.c b/src/util.c index 6a04ec8..8818957 100644 --- a/src/util.c +++ b/src/util.c @@ -22,6 +22,7 @@ #include #include +#include #include #include @@ -39,7 +40,11 @@ int smalloc(size_t size) { } void *saddr(int id) { - return shmat(id, NULL, 0); + void *addr; + addr = shmat(id, NULL, 0); + if (addr == (void *) -1) + return NULL; + return addr; } void sfree(void *addr) { @@ -82,35 +87,59 @@ int istrcmp(char *s1, char *s2) { } RequestType getType(char *str) { - unsigned long type; - int i; - if (strlen(str) >= 8) - return INVALID; - type = 0; - for (i = 0; str[i]; i++) { - type <<= 8; - type |= str[i]; - } - switch (type) { - case 0x474554l: - return GET; - case 0x504f5354l: - return POST; - case 0x505554l: - return PUT; - case 0x48454144l: - return HEAD; - case 0x44454c455445l: - return DELETE; - case 0x5041544348l: - return PATCH; - case 0x4f5054494f4e53l: - return OPTIONS; - default: - return INVALID; - } - /* - * This would actually be far nicer in HolyC of all languages. I feel - * like the context immediately following each magic number is enough. - * */ + if (strcmp(str, "GET") == 0) + return GET; + if (strcmp(str, "POST") == 0) + return POST; + if (strcmp(str, "PUT") == 0) + return PUT; + if (strcmp(str, "HEAD") == 0) + return HEAD; + if (strcmp(str, "DELETE") == 0) + return DELETE; + if (strcmp(str, "PATCH") == 0) + return PATCH; + if (strcmp(str, "OPTIONS") == 0) + return OPTIONS; + return INVALID; +} + +void sendFd(int fd, int dest) { + struct msghdr msg; + struct cmsghdr *cmsg; + char iobuf[1]; + struct iovec io; + union { + char buf[CMSG_SPACE(sizeof(fd))]; + struct cmsghdr align; + } u; + memset(&msg, 0, sizeof(msg)); + io.iov_base = iobuf; + io.iov_len = sizeof(iobuf); + msg.msg_iov = &io; + msg.msg_iovlen = 1; + msg.msg_control = u.buf; + msg.msg_controllen = sizeof(u.buf); + cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + cmsg->cmsg_len = CMSG_LEN(sizeof(fd)); + memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd)); + sendmsg(dest, &msg, 0); +} + +int recvFd(int source) { + struct msghdr msg; + struct cmsghdr *cmsg; + char cmsgbuf[CMSG_SPACE(sizeof(int))]; + unsigned char *data; + int ret; + memset(&msg, 0, sizeof(msg)); + msg.msg_control = cmsgbuf; + msg.msg_controllen = sizeof(cmsgbuf); + recvmsg(source, &msg, 0); + cmsg = CMSG_FIRSTHDR(&msg); + data = CMSG_DATA(cmsg); + memcpy(&ret, data, sizeof(ret)); + return ret; }