/*
swebs - a simple web server
Copyright (C) 2022 Nate Choe
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
/*
 * Performs the process-wide GnuTLS initialization by calling
 * gnutls_global_init().
 *
 * The call is made outside of assert() so that it still happens when the
 * program is compiled with NDEBUG (assert() discards its argument then).
 *
 * Returns 0 on success. On failure the assertion aborts the program; if
 * assertions are compiled out, -1 is returned instead.
 */
int initTLS(void) {
	const int status = gnutls_global_init();
	assert(status >= 0);
	if (status < 0)
		return -1;
	return 0;
}
/*
 * Allocates a Listener, creates an IPv4 stream socket with SO_REUSEPORT set,
 * binds it to INADDR_ANY on the given port and starts listening.
 *
 * type    - TCP or TLS; any value other than TLS gets no TLS setup.
 * port    - port number in host byte order (converted with htons()).
 * backlog - passed straight to listen().
 * ...     - for TLS listeners only, two `char *` arguments in this order:
 *           the private key file, then the certificate file. Both are
 *           loaded as PEM.
 *
 * Returns the new listener, or NULL on any failure. On failure every
 * resource acquired so far (socket, credentials, the allocation itself) is
 * released.
 */
Listener *createListener(SocketType type, uint16_t port, int backlog, ...) {
	Listener *ret = malloc(sizeof(Listener));
	if (ret == NULL)
		return NULL;
	ret->type = type;
	ret->fd = socket(AF_INET, SOCK_STREAM, 0);
	if (ret->fd < 0) {
		free(ret);
		return NULL;
	}

	int opt = 1;
	if (setsockopt(ret->fd, SOL_SOCKET, SO_REUSEPORT,
				&opt, sizeof(opt)) < 0)
		goto error;

	ret->addr.sin_family = AF_INET;
	ret->addr.sin_addr.s_addr = INADDR_ANY;
	ret->addr.sin_port = htons(port);
	ret->addrlen = sizeof(ret->addr);
	if (bind(ret->fd, (struct sockaddr *) &ret->addr, ret->addrlen) < 0)
		goto error;
	if (listen(ret->fd, backlog) < 0)
		goto error;

	if (type == TLS) {
		/*
		 * Read both variadic arguments and close the va_list right
		 * away, so that no error path below can skip va_end().
		 */
		va_list ap;
		va_start(ap, backlog);
		char *keyfile = va_arg(ap, char *);
		char *certfile = va_arg(ap, char *);
		va_end(ap);

		if (gnutls_certificate_allocate_credentials(&ret->creds) < 0)
			goto error;
		/* From here on the credentials must be freed on failure. */
		if (gnutls_certificate_set_x509_key_file(ret->creds,
					certfile, keyfile,
					GNUTLS_X509_FMT_PEM) < 0)
			goto error_creds;
		if (gnutls_priority_init(&ret->priority, NULL, NULL) < 0)
			goto error_creds;
#if GNUTLS_VERSION_NUMBER >= 0x030506
		gnutls_certificate_set_known_dh_params(ret->creds,
				GNUTLS_SEC_PARAM_MEDIUM);
#endif
	}
	return ret;

error_creds:
	gnutls_certificate_free_credentials(ret->creds);
error:
	close(ret->fd);
	free(ret);
	return NULL;
}
/*
 * Accepts one connection on the listener and wraps it in a newly allocated
 * Stream of the same type as the listener.
 *
 * accept() is given listener->addr and listener->addrlen as its address
 * output arguments, so both fields are overwritten by this call.
 *
 * For TLS listeners a GnuTLS server session is created, configured with the
 * listener's priority and credentials (client certificates are not
 * requested), and the handshake is performed before returning.
 *
 * Returns the new stream, or NULL on any failure. On failure the accepted
 * socket is closed and, if a TLS session was already created, it is
 * deinitialized.
 */
Stream *acceptStream(Listener *listener) {
	Stream *ret = malloc(sizeof(Stream));
	if (ret == NULL)
		return NULL;
	ret->type = listener->type;
	ret->fd = accept(listener->fd, (struct sockaddr *) &listener->addr,
			&listener->addrlen);
	if (ret->fd < 0) {
		free(ret);
		return NULL;
	}

	if (listener->type == TLS) {
		if (gnutls_init(&ret->session, GNUTLS_SERVER) < 0)
			goto error;
		/* The session now exists; later failures must release it. */
		if (gnutls_priority_set(ret->session,
					listener->priority) < 0)
			goto error_session;
		if (gnutls_credentials_set(ret->session,
					GNUTLS_CRD_CERTIFICATE,
					listener->creds) < 0)
			goto error_session;
		gnutls_certificate_server_set_request(ret->session,
				GNUTLS_CERT_IGNORE);
		gnutls_handshake_set_timeout(ret->session,
				GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT);
		gnutls_transport_set_int(ret->session, ret->fd);
		if (gnutls_handshake(ret->session) < 0)
			goto error_session;
	}
	return ret;

error_session:
	gnutls_deinit(ret->session);
error:
	close(ret->fd);
	free(ret);
	return NULL;
}
/*
 * Destroys a listener: for TLS listeners the certificate credentials and
 * the priority cache are released first, then the socket is closed and the
 * Listener itself is freed. The pointer must not be used afterwards.
 */
void freeListener(Listener *listener) {
	switch (listener->type) {
	case TLS:
		gnutls_certificate_free_credentials(listener->creds);
		gnutls_priority_deinit(listener->priority);
		break;
	default:
		/* No TLS state to release. */
		break;
	}
	close(listener->fd);
	free(listener);
}
/*
 * Destroys a stream: for TLS streams gnutls_bye() is called with
 * GNUTLS_SHUT_RDWR and the session is deinitialized, then the socket is
 * closed and the Stream itself is freed. The pointer must not be used
 * afterwards.
 */
void freeStream(Stream *stream) {
	switch (stream->type) {
	case TLS:
		gnutls_bye(stream->session, GNUTLS_SHUT_RDWR);
		gnutls_deinit(stream->session);
		break;
	default:
		/* No TLS session to tear down. */
		break;
	}
	close(stream->fd);
	free(stream);
}
/*
 * Sends up to len bytes from data over the stream.
 *
 * TCP streams use write() on the socket; TLS streams use
 * gnutls_record_send() on the session. The result of that call is returned
 * unchanged. A stream of any other type yields -1.
 */
ssize_t sendStream(Stream *stream, void *data, size_t len) {
	if (stream->type == TCP)
		return write(stream->fd, data, len);
	if (stream->type == TLS)
		return gnutls_record_send(stream->session, data, len);
	return -1;
}
/*
 * Receives up to len bytes from the stream into data.
 *
 * TCP streams use read() on the socket; TLS streams use
 * gnutls_record_recv() on the session. The result of that call is returned
 * unchanged. A stream of any other type yields -1.
 */
ssize_t recvStream(Stream *stream, void *data, size_t len) {
	if (stream->type == TCP)
		return read(stream->fd, data, len);
	if (stream->type == TLS)
		return gnutls_record_recv(stream->session, data, len);
	return -1;
}