Reworked sockets library

This commit is contained in:
Nate Choe
2022-04-03 12:04:32 -05:00
parent 88d52ca830
commit 12281a030d
5 changed files with 75 additions and 64 deletions

View File

@@ -1,7 +1,7 @@
SRC = $(wildcard src/*.c)
OBJ = $(subst .c,.o,$(subst src,work,$(SRC)))
LIBS = -pthread -pie -lrt -ldl $(shell pkg-config --libs gnutls)
CFLAGS := -O2 -pipe -Wall -Wpedantic -Wshadow -ansi -ggdb
LIBS = -pie -lrt -ldl $(shell pkg-config --libs gnutls)
CFLAGS := -O2 -pipe -Wall -Wpedantic -Wshadow -ansi
CFLAGS += -Isrc/ -fpie -D_POSIX_C_SOURCE=200809L $(shell pkg-config --cflags gnutls)
INSTALLDIR := /usr/sbin
HEADERDIR := /usr/include/

View File

@@ -37,11 +37,24 @@ void runServer(int connfd, Sitefile *site, Listener *listener,
Connection *connections = malloc(sizeof(Connection) * allocConns);
int connCount = 1;
/* connections are 1 indexed because fds[0] is the notify fd. */
Context *context;
assert(fds != NULL);
assert(connections != NULL);
fds[0].fd = connfd;
fds[0].events = POLLIN;
switch (site->type) {
case TCP:
context = createContext(TCP);
break;
case TLS:
context = createContext(TLS, site->key, site->cert);
break;
default:
createLog("Socket type is somehow invalid");
return;
}
for (;;) {
int i;
poll(fds, connCount, -1);
@@ -75,9 +88,10 @@ remove:
createLog("Message received that included an invalid fd");
continue;
}
newstream = createStream(listener, O_NONBLOCK, newfd);
createLog("Obtained file descriptor from child");
newstream = createStream(context, O_NONBLOCK, newfd);
if (newstream == NULL) {
createLog("Stream couldn't be created from file descriptor");
close(newfd);

View File

@@ -139,16 +139,7 @@ NULL
exit(EXIT_FAILURE);
}
switch ((*site)->type) {
case TCP: default:
*listener = createListener(TCP, (*site)->port, backlog);
break;
case TLS:
initTLS();
*listener = createListener(TLS, (*site)->port, backlog,
(*site)->key, (*site)->cert);
break;
}
*listener = createListener((*site)->port, backlog);
if (listener == NULL) {
fprintf(stderr, "Failed to create socket\n");
exit(EXIT_FAILURE);

View File

@@ -33,22 +33,11 @@ int initTLS() {
return 0;
}
Listener *createListener(SocketType type, unsigned short port,
int backlog, ...) {
int shmid;
Listener *createListener(unsigned short port, int backlog) {
Listener *ret = malloc(sizeof(Listener));
va_list ap;
shmid = smalloc(sizeof(Listener));
if (shmid < 0)
if (ret == NULL)
return NULL;
ret = saddr(shmid);
if (ret == NULL) {
sdestroy(shmid);
return NULL;
}
ret->type = type;
ret->fd = socket(AF_INET, SOCK_STREAM, 0);
ret->shmid = shmid;
if (ret->fd < 0) {
free(ret);
return NULL;
@@ -69,23 +58,38 @@ Listener *createListener(SocketType type, unsigned short port,
goto error;
if (listen(ret->fd, backlog) < 0)
goto error;
return ret;
error:
close(ret->fd);
free(ret);
return NULL;
}
va_start(ap, backlog);
Context *createContext(SocketType type, ...) {
Context *ret;
va_list ap;
ret = malloc(sizeof(Context));
if (ret == NULL)
return NULL;
va_start(ap, type);
ret->type = type;
switch (type) {
case TCP: default:
case TCP:
break;
case TLS: {
char *keyfile = va_arg(ap, char *);
char *certfile = va_arg(ap, char *);
char *keyfile, *certfile;
keyfile = va_arg(ap, char *);
certfile = va_arg(ap, char *);
if (gnutls_certificate_allocate_credentials(&ret->creds)
< 0)
< 0)
goto error;
if (gnutls_certificate_set_x509_key_file(ret->creds,
certfile, keyfile,
GNUTLS_X509_FMT_PEM) < 0)
goto error;
if (gnutls_priority_init(&ret->priority,
NULL, NULL) < 0)
if (gnutls_priority_init(&ret->priority, NULL, NULL)
< 0)
goto error;
#if GNUTLS_VERSION_NUMBER >= 0x030506
gnutls_certificate_set_known_dh_params(ret->creds,
@@ -97,17 +101,21 @@ Listener *createListener(SocketType type, unsigned short port,
va_end(ap);
return ret;
error:
close(ret->fd);
sfree(ret);
sdestroy(shmid);
free(ret);
return NULL;
}
Stream *createStream(Listener *listener, int flags, int fd) {
int acceptConnection(Listener *listener) {
return accept(listener->fd,
(struct sockaddr *) &listener->addr,
&listener->addrlen);
}
Stream *createStream(Context *context, int flags, int fd) {
Stream *ret = malloc(sizeof(Stream));
if (ret == NULL)
return NULL;
ret->type = listener->type;
ret->type = context->type;
ret->fd = fd;
{
@@ -115,18 +123,18 @@ Stream *createStream(Listener *listener, int flags, int fd) {
fcntl(ret->fd, F_SETFL, oldflags | flags);
}
switch (listener->type) {
switch (context->type) {
case TCP: default:
break;
case TLS:
if (gnutls_init(&ret->session, GNUTLS_SERVER) < 0)
goto error;
if (gnutls_priority_set(ret->session,
listener->priority) < 0)
context->priority) < 0)
goto error;
if (gnutls_credentials_set(ret->session,
GNUTLS_CRD_CERTIFICATE,
listener->creds) < 0)
context->creds) < 0)
goto error;
gnutls_certificate_server_set_request(ret->session,
GNUTLS_CERT_IGNORE);
@@ -144,23 +152,15 @@ error:
return NULL;
}
Stream *acceptStream(Listener *listener, int flags) {
int fd;
fd = accept(listener->fd, (struct sockaddr *) &listener->addr,
&listener->addrlen);
return createStream(listener, flags, fd);
void freeListener(Listener *listener) {
close(listener->fd);
free(listener);
}
void freeListener(Listener *listener) {
int shmid;
if (listener->type == TLS) {
gnutls_certificate_free_credentials(listener->creds);
gnutls_priority_deinit(listener->priority);
}
close(listener->fd);
shmid = listener->shmid;
sfree(listener);
sdestroy(shmid);
void freeContext(Context *context) {
gnutls_certificate_free_credentials(context->creds);
gnutls_priority_deinit(context->priority);
free(context);
}
void freeStream(Stream *stream) {

View File

@@ -25,14 +25,17 @@
#include <swebs/types.h>
typedef struct {
SocketType type;
int fd;
int shmid;
struct sockaddr_in addr;
socklen_t addrlen;
} Listener;
typedef struct {
SocketType type;
gnutls_certificate_credentials_t creds;
gnutls_priority_t priority;
} Listener;
/* creds and priority are only used in TLS structs. */
} Context;
typedef struct {
SocketType type;
@@ -41,17 +44,20 @@ typedef struct {
} Stream;
int initTLS();
Listener *createListener(SocketType type, uint16_t port, int backlog, ...);
Listener *createListener(uint16_t port, int backlog);
Context *createContext(SocketType type, ...);
/*
* extra arguments depend on type (similar to fcntl):
* tcp: (void)
* tls: (char *keyfile, char *certfile, char *ocspfile)
* tls: (char *keyfile, char *certfile)
* */
Stream *createStream(Listener *listener, int flags, int fd);
Stream *acceptStream(Listener *listener, int flags);
/* returns 1 on error, accepts fcntl flags */
int acceptConnection(Listener *listener);
/* Returns a file descriptor from the listener */
Stream *createStream(Context *context, int flags, int fd);
/* flags are fcntl flags */
void freeListener(Listener *listener);
void freeContext(Context *context);
void freeStream(Stream *stream);
ssize_t sendStream(Stream *stream, const void *data, size_t len);