From 12281a030dd0831be004e57b36f022e9096fcaf1 Mon Sep 17 00:00:00 2001 From: Nate Choe Date: Sun, 3 Apr 2022 12:04:32 -0500 Subject: [PATCH] Reworked sockets library --- Makefile | 4 +-- src/runner.c | 16 ++++++++- src/setup.c | 11 +----- src/sockets.c | 86 ++++++++++++++++++++++----------------------- src/swebs/sockets.h | 22 +++++++----- 5 files changed, 75 insertions(+), 64 deletions(-) diff --git a/Makefile b/Makefile index f2d8ae0..5c4cc2e 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ SRC = $(wildcard src/*.c) OBJ = $(subst .c,.o,$(subst src,work,$(SRC))) -LIBS = -pthread -pie -lrt -ldl $(shell pkg-config --libs gnutls) -CFLAGS := -O2 -pipe -Wall -Wpedantic -Wshadow -ansi -ggdb +LIBS = -pie -lrt -ldl $(shell pkg-config --libs gnutls) +CFLAGS := -O2 -pipe -Wall -Wpedantic -Wshadow -ansi CFLAGS += -Isrc/ -fpie -D_POSIX_C_SOURCE=200809L $(shell pkg-config --cflags gnutls) INSTALLDIR := /usr/sbin HEADERDIR := /usr/include/ diff --git a/src/runner.c b/src/runner.c index b5f5959..929d181 100644 --- a/src/runner.c +++ b/src/runner.c @@ -37,11 +37,24 @@ void runServer(int connfd, Sitefile *site, Listener *listener, Connection *connections = malloc(sizeof(Connection) * allocConns); int connCount = 1; /* connections are 1 indexed because fds[0] is the notify fd. */ + Context *context; assert(fds != NULL); assert(connections != NULL); fds[0].fd = connfd; fds[0].events = POLLIN; + switch (site->type) { + case TCP: + context = createContext(TCP); + break; + case TLS: + context = createContext(TLS, site->key, site->cert); + break; + default: + createLog("Socket type is somehow invalid"); + return; + } + for (;;) { int i; poll(fds, connCount, -1); @@ -75,9 +88,10 @@ remove: createLog("Message received that included an invalid fd"); continue; } - newstream = createStream(listener, O_NONBLOCK, newfd); createLog("Obtained file descriptor from child"); + + newstream = createStream(context, O_NONBLOCK, newfd); if (newstream == NULL) { createLog("Stream couldn't be created from file descriptor"); close(newfd); diff --git a/src/setup.c b/src/setup.c index 6a8aac7..76a9daf 100644 --- a/src/setup.c +++ b/src/setup.c @@ -139,16 +139,7 @@ NULL exit(EXIT_FAILURE); } - switch ((*site)->type) { - case TCP: default: - *listener = createListener(TCP, (*site)->port, backlog); - break; - case TLS: - initTLS(); - *listener = createListener(TLS, (*site)->port, backlog, - (*site)->key, (*site)->cert); - break; - } + *listener = createListener((*site)->port, backlog); if (listener == NULL) { fprintf(stderr, "Failed to create socket\n"); exit(EXIT_FAILURE); diff --git a/src/sockets.c b/src/sockets.c index 620702c..28afe1d 100644 --- a/src/sockets.c +++ b/src/sockets.c @@ -33,22 +33,11 @@ int initTLS() { return 0; } -Listener *createListener(SocketType type, unsigned short port, - int backlog, ...) { - int shmid; +Listener *createListener(unsigned short port, int backlog) { Listener *ret = malloc(sizeof(Listener)); - va_list ap; - shmid = smalloc(sizeof(Listener)); - if (shmid < 0) + if (ret == NULL) return NULL; - ret = saddr(shmid); - if (ret == NULL) { - sdestroy(shmid); - return NULL; - } - ret->type = type; ret->fd = socket(AF_INET, SOCK_STREAM, 0); - ret->shmid = shmid; if (ret->fd < 0) { free(ret); return NULL; @@ -69,23 +58,38 @@ Listener *createListener(SocketType type, unsigned short port, goto error; if (listen(ret->fd, backlog) < 0) goto error; + return ret; +error: + close(ret->fd); + free(ret); + return NULL; +} - va_start(ap, backlog); +Context *createContext(SocketType type, ...) { + Context *ret; + va_list ap; + ret = malloc(sizeof(Context)); + if (ret == NULL) + return NULL; + va_start(ap, type); + ret->type = type; switch (type) { - case TCP: default: + case TCP: break; case TLS: { - char *keyfile = va_arg(ap, char *); - char *certfile = va_arg(ap, char *); + char *keyfile, *certfile; + keyfile = va_arg(ap, char *); + certfile = va_arg(ap, char *); + if (gnutls_certificate_allocate_credentials(&ret->creds) - < 0) + < 0) goto error; if (gnutls_certificate_set_x509_key_file(ret->creds, certfile, keyfile, GNUTLS_X509_FMT_PEM) < 0) goto error; - if (gnutls_priority_init(&ret->priority, - NULL, NULL) < 0) + if (gnutls_priority_init(&ret->priority, NULL, NULL) + < 0) goto error; #if GNUTLS_VERSION_NUMBER >= 0x030506 gnutls_certificate_set_known_dh_params(ret->creds, @@ -97,17 +101,21 @@ Listener *createListener(SocketType type, unsigned short port, va_end(ap); return ret; error: - close(ret->fd); - sfree(ret); - sdestroy(shmid); + free(ret); return NULL; } -Stream *createStream(Listener *listener, int flags, int fd) { +int acceptConnection(Listener *listener) { + return accept(listener->fd, + (struct sockaddr *) &listener->addr, + &listener->addrlen); +} + +Stream *createStream(Context *context, int flags, int fd) { Stream *ret = malloc(sizeof(Stream)); if (ret == NULL) return NULL; - ret->type = listener->type; + ret->type = context->type; ret->fd = fd; { @@ -115,18 +123,18 @@ Stream *createStream(Listener *listener, int flags, int fd) { fcntl(ret->fd, F_SETFL, oldflags | flags); } - switch (listener->type) { + switch (context->type) { case TCP: default: break; case TLS: if (gnutls_init(&ret->session, GNUTLS_SERVER) < 0) goto error; if (gnutls_priority_set(ret->session, - listener->priority) < 0) + context->priority) < 0) goto error; if (gnutls_credentials_set(ret->session, GNUTLS_CRD_CERTIFICATE, - listener->creds) < 0) + context->creds) < 0) goto error; gnutls_certificate_server_set_request(ret->session, GNUTLS_CERT_IGNORE); @@ -144,23 +152,15 @@ error: return NULL; } -Stream *acceptStream(Listener *listener, int flags) { - int fd; - fd = accept(listener->fd, (struct sockaddr *) &listener->addr, - &listener->addrlen); - return createStream(listener, flags, fd); +void freeListener(Listener *listener) { + close(listener->fd); + free(listener); } -void freeListener(Listener *listener) { - int shmid; - if (listener->type == TLS) { - gnutls_certificate_free_credentials(listener->creds); - gnutls_priority_deinit(listener->priority); - } - close(listener->fd); - shmid = listener->shmid; - sfree(listener); - sdestroy(shmid); +void freeContext(Context *context) { + gnutls_certificate_free_credentials(context->creds); + gnutls_priority_deinit(context->priority); + free(context); } void freeStream(Stream *stream) { diff --git a/src/swebs/sockets.h b/src/swebs/sockets.h index afdf484..8d09952 100644 --- a/src/swebs/sockets.h +++ b/src/swebs/sockets.h @@ -25,14 +25,17 @@ #include typedef struct { - SocketType type; int fd; - int shmid; struct sockaddr_in addr; socklen_t addrlen; +} Listener; + +typedef struct { + SocketType type; gnutls_certificate_credentials_t creds; gnutls_priority_t priority; -} Listener; + /* creds and priority are only used in TLS structs. */ +} Context; typedef struct { SocketType type; @@ -41,17 +44,20 @@ typedef struct { } Stream; int initTLS(); -Listener *createListener(SocketType type, uint16_t port, int backlog, ...); +Listener *createListener(uint16_t port, int backlog); +Context *createContext(SocketType type, ...); /* * extra arguments depend on type (similar to fcntl): * tcp: (void) - * tls: (char *keyfile, char *certfile, char *ocspfile) + * tls: (char *keyfile, char *certfile) * */ -Stream *createStream(Listener *listener, int flags, int fd); -Stream *acceptStream(Listener *listener, int flags); -/* returns 1 on error, accepts fcntl flags */ +int acceptConnection(Listener *listener); +/* Returns a file descriptor from the listener */ +Stream *createStream(Context *context, int flags, int fd); +/* flags are fcntl flags */ void freeListener(Listener *listener); +void freeContext(Context *context); void freeStream(Stream *stream); ssize_t sendStream(Stream *stream, const void *data, size_t len);