/*
 * Copyright (c) 2015, 2017, Oracle and/or its affiliates. All rights reserved.
 *
 * This software is dual-licensed to you under the MIT License (MIT) and
 * the Universal Permissive License (UPL). See the LICENSE file in the root
 * directory for license terms. You may choose either license, or both.
 */

#include <stdio.h>
#include <string.h>
#include <limits.h>
#include <errno.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netdb.h>
#include <openssl/ssl.h>
#ifdef __CYGWIN__
#include <openssl/err.h>
#endif
#include "iotcs_port_ssl.h"
#include "iotcs_port_thread.h"
#include "iotcs_thread_private.h"
#include "trusted_assets_manager/iotcs_tam.h"
#include "iotcs_hostname_verifier.h"
#include "util/util_thread_private.h"
#include "util/util_memory.h"
#include "iotcs_port_system.h"
#include "util/util.h"

#include "log/log.h"
#define IOTCSP_MODULE_LOG_CHANNEL LOG_CHANNEL_PORT_SSL
#include "log/log_template.h"

#ifdef IOTCS_DEBUG
/* Debug builds dump SSL_read() payloads longer than this many bytes to stdout
 * (see iotcs_posix_mqtt_ssl_read). */
#define MIN_SIZE_OF_READ_DATA 6
#endif
/* Size of the decimal port strings handed to getaddrinfo(): "65535" plus NUL. */
#define PORT_BUFFER_SIZE 6
#ifndef PROXY_BUFFER_SIZE
/* Scratch buffer for the HTTP CONNECT request sent to a proxy and for the
 * proxy's reply; may be overridden at build time. */
#define PROXY_BUFFER_SIZE 512
#endif

/* We have one server and its certificates, so a single client context is
 * shared by every connection (main, long-polling and storage).
 */
static SSL_CTX* ctx = NULL;

#ifdef IOTCS_MESSAGING_THREAD_SAFETY
/* One mutex per OpenSSL lock index (CRYPTO_num_locks()); used by the locking
 * callback installed in ssl_init_locks(). */
static iotcs_port_mutex* ssl_mutex_pool = NULL;
#endif

/* Properties that help create the connection to the server.
 * ssl_port / user_host_addr: the IoT server to reach (also the CONNECT target
 *   when a proxy is used); user_host_addr stores the caller's pointer, not a copy.
 * hints / servinfo: getaddrinfo() input and result (the proxy's addresses when a
 *   proxy is configured); p: cursor over servinfo used by connect_socket().
 * proxy_buffer: shared by all connect paths and not guarded by a lock.
 *   NOTE(review): confirm connects never run concurrently.
 * NOTE(review): hints, servinfo and p (and scs_hints, scs_servinfo, scs_p below)
 *   have external linkage; consider making them static to avoid link-time
 *   symbol clashes (a global named `p` is particularly likely to collide).
 */
static int ssl_port = 0;
static const char* user_host_addr = NULL;
struct addrinfo hints, *servinfo, *p;
static char proxy_buffer[PROXY_BUFFER_SIZE];

/* This connection is used for pushing messages to the server and for the
 * short-polling implementation, and needs its own socket. The
 * iotcs_posix_mqtt_ssl_* helpers use it as well.
 */
static SSL* ssl = NULL;
static int sockfd = 0;

/* This connection is used by the long-polling implementation and needs its
 * own socket.
 */
#ifdef IOTCS_LONG_POLLING
static SSL* lp_ssl = NULL;
static int lp_socketfd = 0;
#endif

/* This connection is used by the storage (SCS) implementation and needs its
 * own setup: resolver state, target host/port and socket. is_ssl_storage
 * selects TLS or plain TCP; storage_server_host is a heap copy owned here.
 */
#ifdef IOTCS_STORAGE_SUPPORT
static iotcs_bool is_ssl_storage = IOTCS_TRUE;
struct addrinfo scs_hints, *scs_servinfo, *scs_p;
static int storage_server_port = 0;
static char* storage_server_host = NULL;
static SSL* scs_ssl = NULL;
static int scs_socketfd = 0;
#endif
/* HTTP proxy parsed from iotcs_port_get_http_proxy() by whichever init runs
 * first. proxy_port points into proxy_host just after the ':' or, when no port
 * is given, to a separate copy of the default "8080". */
static char *proxy_host = NULL;
static char *proxy_port = NULL;

#ifdef IOTCS_MESSAGING_THREAD_SAFETY

static void locking_function(int mode, int n, const char *file, int line) {
    (void) file;
    (void) line;

    if (mode & CRYPTO_LOCK)
        iotcs_port_mutex_lock(ssl_mutex_pool[n]);
    else
        iotcs_port_mutex_unlock(ssl_mutex_pool[n]);
}

/* OpenSSL thread-id callback: identifies the calling thread. */
static unsigned long id_function(void) {
    unsigned long thread_id = iotcs_get_thread_id();
    return thread_id;
}

/* Allocates one mutex per OpenSSL lock index and registers the id and locking
 * callbacks. Allocation failures are reported through CHECK_OOM. */
static iotcs_result ssl_init_locks(void) {
    int lock_count = CRYPTO_num_locks();
    int idx;

    CHECK_OOM(ssl_mutex_pool = (iotcs_port_mutex*) util_calloc(lock_count, sizeof (iotcs_port_mutex)));
    for (idx = 0; idx < lock_count; ++idx) {
        CHECK_OOM(ssl_mutex_pool[idx] = iotcs_port_mutex_create());
    }

    /* Install the callbacks only once every mutex they rely on exists. */
    CRYPTO_set_id_callback(id_function);
    CRYPTO_set_locking_callback(locking_function);
    return IOTCS_RESULT_OK;
}

/* Undoes ssl_init_locks(): unregisters the callbacks, then destroys and frees
 * the mutex pool. Safe to call when the pool was never allocated. */
static iotcs_result ssl_finit_locks(void) {
    int idx = 0;

    if (ssl_mutex_pool == NULL) {
        return IOTCS_RESULT_OK;
    }

    /* Detach OpenSSL before the mutexes used by its callbacks go away. */
    CRYPTO_set_id_callback(NULL);
    CRYPTO_set_locking_callback(NULL);

    while (idx < CRYPTO_num_locks()) {
        iotcs_port_mutex_destroy(ssl_mutex_pool[idx]);
        idx++;
    }

    util_free(ssl_mutex_pool);
    ssl_mutex_pool = NULL;
    return IOTCS_RESULT_OK;
}
#endif

/* Sends all `length` bytes of `request` over a plain TCP socket.
 *
 * send() may transmit only part of the buffer, so keep sending until all of it
 * has gone out. An interrupted call (EINTR) is retried.
 * Returns IOTCS_RESULT_OK when everything was sent, IOTCS_RESULT_FAIL otherwise.
 */
static iotcs_result do_write(int socketfd, char* request, size_t length) {
    size_t sent = 0;

    while (sent < length) {
        ssize_t status = send(socketfd, request + sent, length - sent, 0);

        if (status < 0) {
            int errsv = errno;
            if (errsv == EINTR) {
                continue;
            }
            LOG_ERR("Socket write failed. %s.", strerror(errsv));
            return IOTCS_RESULT_FAIL;
        }
        if (status == 0) {
            /* No progress and no error: bail out instead of spinning forever. */
            LOG_ERRS("Socket write failed. Nothing was sent.");
            return IOTCS_RESULT_FAIL;
        }
        sent += (size_t) status;
    }

    return IOTCS_RESULT_OK;
}

/* Writes `length` bytes of `request` to the TLS connection `conn`.
 *
 * Returns IOTCS_RESULT_OK on success (a zero-length request is a no-op) and
 * IOTCS_RESULT_FAIL on an SSL error or a request larger than INT_MAX bytes.
 */
static iotcs_result do_ssl_write(SSL* conn, char* request, size_t length) {
    int ssl_status;

    /* Nothing to send; SSL_write() with num == 0 is an error or undefined,
     * depending on the OpenSSL version. */
    if (length == 0) {
        return IOTCS_RESULT_OK;
    }
    /* SSL_write() takes an int length; never let the size silently truncate. */
    if (length > (size_t) INT_MAX) {
        LOG_ERRS("SSL write failed. Request too large.");
        return IOTCS_RESULT_FAIL;
    }

    ssl_status = SSL_write(conn, request, (int) length);
    /* SSL_write() reports failure with any value <= 0, not only < 0. */
    if (ssl_status <= 0) {
        LOG_ERR("SSL write failed. %d.", SSL_get_error(conn, ssl_status));
        return IOTCS_RESULT_FAIL;
    }

    return IOTCS_RESULT_OK;
}

/* Reads from a plain TCP socket into `buffer` (at most `len` bytes).
 *
 * use_while - when true, keep reading until the buffer is full or the peer
 *             closes the connection; when false, perform a single recv().
 * is_long_polling - currently unused.
 * *bytes_read receives the number of bytes stored in `buffer`, on every path.
 * Returns IOTCS_RESULT_OK, or IOTCS_RESULT_FAIL when recv() fails.
 *
 * Every chunk is counted exactly once. The previous version counted the last
 * chunk twice when the buffer filled up, and read an uninitialized value when
 * len <= 0.
 */
static iotcs_result do_read(int socketfd, char* buffer, int len, int *bytes_read, int is_long_polling, int use_while) {
    int tot_read = 0;

    (void) is_long_polling;

    while (len > 0) {
        ssize_t chunk = recv(socketfd, buffer, (size_t) len, 0);

        if (chunk < 0) {
            int errsv = errno;
            LOG_ERR("Socket read failed. %s.", strerror(errsv));
            *bytes_read = tot_read;
            return IOTCS_RESULT_FAIL;
        }

        if (chunk == 0) { /* socket is closed */
            break;
        }

        len -= (int) chunk;
        buffer += chunk;
        tot_read += (int) chunk;

        if (!use_while) { /* single-shot read requested */
            break;
        }
    }

    *bytes_read = tot_read;
    return IOTCS_RESULT_OK;
}

/* Reads from the TLS connection `conn` until `buffer` (len bytes) is full or
 * the peer closes the connection.
 *
 * is_long_polling - currently unused.
 * *bytes_read receives the number of bytes stored in `buffer` on every path,
 * including the error paths, so bytes consumed before an error are not lost.
 * Returns IOTCS_RESULT_OK, IOTCS_RESULT_SSL_WANT_READ when OpenSSL reports
 * WANT_READ/WANT_WRITE (for example a receive timeout), or IOTCS_RESULT_FAIL.
 */
static iotcs_result do_ssl_read(SSL* conn, char* buffer, int len, int *bytes_read, int is_long_polling) {
    int tot_read = 0;

    (void) is_long_polling;

    while (len > 0) {
        int chunk = SSL_read(conn, buffer, len);

        if (chunk < 0) {
            int err = SSL_get_error(conn, chunk);
            *bytes_read = tot_read;
            if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
                return IOTCS_RESULT_SSL_WANT_READ;
            }
            LOG_ERR("SSL read failed. %d.", err);
            return IOTCS_RESULT_FAIL;
        }

        if (chunk == 0) { /* socket is closed */
            break;
        }

        len -= chunk;
        buffer += chunk;
        tot_read += chunk;
    }

    *bytes_read = tot_read;
    return IOTCS_RESULT_OK;
}

/* Prepares the shared TLS state used by the main and long-polling connections.
 *
 * addr - host name of the IoT server. The pointer is stored as-is in
 *        user_host_addr, so it must stay valid until iotcs_port_ssl_finalize().
 * port - TCP port of the IoT server.
 *
 * When iotcs_port_get_http_proxy() reports a proxy, the proxy address is
 * resolved instead of the server's, and connect_socket() later tunnels through
 * it with HTTP CONNECT. Trust anchors come from the trusted assets manager,
 * falling back to OpenSSL's default verify paths when there are none.
 *
 * Returns IOTCS_RESULT_INVALID_ARGUMENT for a NULL addr, IOTCS_RESULT_FAIL on
 * lock setup, allocation, name resolution or OpenSSL setup failures, and
 * IOTCS_RESULT_OK otherwise.
 */
iotcs_result iotcs_port_ssl_init(const char* addr, unsigned short port) {
    size_t certs;
    char port_str[PORT_BUFFER_SIZE];
    const char *proxy = iotcs_port_get_http_proxy();

    if (!addr) {
        return IOTCS_RESULT_INVALID_ARGUMENT;
    }

    ssl_port = port;
    user_host_addr = addr;

    SSL_library_init();

#ifdef IOTCS_MESSAGING_THREAD_SAFETY
    if (IOTCS_RESULT_OK != ssl_init_locks()) {
        LOG_ERRS("ssl_thread_setup failed");
        return IOTCS_RESULT_FAIL;
    }
#endif

    memset(&hints, 0, sizeof hints);
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;
    util_safe_snprintf(port_str, PORT_BUFFER_SIZE, "%d", (int) port);

    if (proxy) {
        if (!proxy_host) {
            static const char scheme[] = "http://";
            const char *proxy_addr = proxy;
            char *path;

            /* Skip exactly the scheme prefix. The former "+ 1" also dropped
             * the first character of the proxy host name. */
            if (strncmp(proxy_addr, scheme, sizeof scheme - 1) == 0) {
                proxy_addr += sizeof scheme - 1;
            }
            proxy_host = util_safe_strcpy(proxy_addr);
            if (!proxy_host) {
                LOG_ERRS("Failed to copy the HTTP proxy address");
                return IOTCS_RESULT_FAIL;
            }
            /* Ignore a trailing path such as the "/" in "proxy:3128/";
             * otherwise it would end up in the port string. */
            path = strchr(proxy_host, '/');
            if (path) {
                *path = '\0';
            }
            proxy_port = strchr(proxy_host, ':');
            if (proxy_port) {
                *proxy_port = '\0';
                proxy_port++;
            } else {
                proxy_port = util_safe_strcpy("8080");
                if (!proxy_port) {
                    LOG_ERRS("Failed to copy the default HTTP proxy port");
                    util_free(proxy_host);
                    proxy_host = NULL;
                    return IOTCS_RESULT_FAIL;
                }
            }
        }

        if (getaddrinfo(proxy_host, proxy_port, &hints, &servinfo) != 0) {
            LOG_ERR("Invalid host name: %s", proxy_host);
            return IOTCS_RESULT_FAIL;
        }
    } else {
        if (getaddrinfo(addr, port_str, &hints, &servinfo) != 0) {
            LOG_ERR("Invalid host name: %s", addr);
            return IOTCS_RESULT_FAIL;
        }
    }

    if (!ctx) {
        ctx = SSL_CTX_new(SSLv23_client_method());
        if (!ctx) {
            LOG_ERRS("SSL_CTX_new failed");
            return IOTCS_RESULT_FAIL;
        }
    }

    SSL_CTX_set_options(ctx, SSL_OP_NO_SSLv2);
    SSL_CTX_set_options(ctx, SSL_OP_NO_SSLv3);

    /* NOTE(review): with SSL_VERIFY_NONE the handshake does not fail on an
     * untrusted certificate chain; the trust anchors loaded below only affect
     * SSL_get_verify_result(). Confirm that something else (e.g. the hostname
     * verifier) enforces chain validation. */
    SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, NULL);


    if (!ssl) {
        ssl = SSL_new(ctx);
    }
#ifdef IOTCS_LONG_POLLING
    if (!lp_ssl) {
        lp_ssl = SSL_new(ctx);
    }
#endif

    if (0 == (certs = tam_get_trust_anchors_count())) {
        if (SSL_CTX_set_default_verify_paths(ctx) == 0) {
            LOG_ERRS("SSL_CTX_set_default_verify_pathsv failed");
            return IOTCS_RESULT_FAIL;
        }
    } else {
        X509_STORE *store = SSL_CTX_get_cert_store(ctx);

        while (certs) {
            size_t len;
            const char *buf = tam_get_trust_anchor_certificate(--certs, &len);
            /* d2i_X509 advances the pointer it is given. Use a correctly typed
             * copy rather than type-punning the address of buf. */
            const unsigned char *der = (const unsigned char *) buf;
            X509 *cert = d2i_X509(NULL, &der, (long) len);

            if (!cert) {
                LOG_ERRS("Failed to parse a trust anchor certificate");
                continue;
            }
#ifdef IOTCS_DEBUG
            PEM_write_X509(stdout, cert);
#endif
            X509_STORE_add_cert(store, cert);
            X509_free(cert);
        }
    }

    return IOTCS_RESULT_OK;
}

void iotcs_port_ssl_finalize() {
    iotcs_port_ssl_disconnect();
#ifdef IOTCS_LONG_POLLING
    iotcs_port_ssl_disconnect_lp();
#endif

    if (ssl) {
        SSL_free(ssl);
    }
#ifdef IOTCS_LONG_POLLING
    if (lp_ssl) {
        SSL_free(lp_ssl);
    }
#endif

    if (ctx) {
        SSL_CTX_free(ctx);
    }

    freeaddrinfo(servinfo);

#ifdef IOTCS_MESSAGING_THREAD_SAFETY
    ssl_finit_locks();
#endif

    //Bug #2561: Memory leak with SSL built-in compressions (https://rt.openssl.org/Ticket/Display.html?id=2561).
    //Fixed in 1.0.2 release.
#if OPENSSL_VERSION_NUMBER < 0x10002000L
    sk_SSL_COMP_free(SSL_COMP_get_compression_methods());
#else
    SSL_COMP_free_compression_methods();
#endif

#ifdef IOTCS_LONG_POLLING
    lp_ssl = NULL;
#endif
    ssl = NULL;
    ctx = NULL;
    user_host_addr = NULL;
    ssl_port = 0;
}

/* Tears down one connection and leaves its state ready for reuse.
 *
 * fd - in/out: descriptor to close; values <= 0 are treated as "not open".
 *      Always reset to 0.
 * s  - SSL object bound to the socket, or NULL for plain-TCP connections.
 * Always returns IOTCS_RESULT_OK.
 */
static iotcs_result disconnect_socket(int *fd, SSL* s) {
    /* Shut TLS down while the descriptor is still open. Closing first made
     * SSL_shutdown() write its close_notify to a closed descriptor number,
     * which another thread might already have reused for a new socket. */
    if (s) {
        SSL_shutdown(s);
    }

    if (*fd > 0) {
        close(*fd);
    }
    *fd = 0;

    if (s) {
        /* Reset the session state so the same SSL object can reconnect. */
        SSL_clear(s);
    }

    return IOTCS_RESULT_OK;
}

/* Connects *fd to the first reachable address in servinfo, tunnels through the
 * configured HTTP proxy with a CONNECT request when one is set, then performs
 * the TLS handshake on `conn`.
 *
 * fd   - in: a freshly created socket (negative if socket() failed).
 *        out: the connected socket. It is replaced by a new descriptor when a
 *        candidate address fails, and reset to 0 on failure.
 * conn - SSL object to bind to the socket.
 *
 * Returns IOTCS_RESULT_OK on success. On any failure the connection is torn
 * down and IOTCS_RESULT_FAIL is returned.
 */
static iotcs_result connect_socket(int *fd, SSL* conn) {
    int length;

    if (*fd < 0) {
        return IOTCS_RESULT_FAIL;
    }

    /*loop through all the results and connect to the first we can*/
    for (p = servinfo; p != NULL; p = p->ai_next) {
        if (connect(*fd, p->ai_addr, p->ai_addrlen) == -1) {
            int errsv = errno;
            LOG_ERR("Connect to server failed. %s.", strerror(errsv));
            if (p->ai_next != NULL) {
                /* POSIX leaves the socket state unspecified after a failed
                 * connect(), so use a fresh socket for the next candidate. */
                close(*fd);
                *fd = socket(p->ai_next->ai_family, p->ai_next->ai_socktype, p->ai_next->ai_protocol);
                if (*fd < 0) {
                    errsv = errno;
                    LOG_ERR("Socket creation failed. %s.", strerror(errsv));
                    goto error;
                }
            }
            continue;
        }
        if (proxy_host) {
            LOG_INFO("Connected to proxy %s:%s.", proxy_host, proxy_port);
            length = util_safe_snprintf(proxy_buffer, PROXY_BUFFER_SIZE, "CONNECT %s:%d HTTP/1.1\r\nHost:%s:%d\r\nProxy-Connection: Keep-Alive\r\n\r\n", user_host_addr, ssl_port, user_host_addr, ssl_port);
            /* Never send a truncated request or read past the buffer. */
            GOTO_ERR_MSG(length <= 0 || length >= PROXY_BUFFER_SIZE, "Proxy CONNECT request does not fit the buffer");
            GOTO_ERR_MSG(do_write(*fd, proxy_buffer, (size_t) length), "Socket write failed to proxy");
            /* Keep one byte free so the reply can be NUL-terminated before it
             * is compared and logged as a C string. */
            GOTO_ERR_MSG(do_read(*fd, proxy_buffer, PROXY_BUFFER_SIZE - 1, &length, IOTCS_FALSE, IOTCS_FALSE), "Socket read failed from proxy");
            proxy_buffer[length] = '\0';
            if (strncmp(proxy_buffer, "HTTP/1.0 200", strlen("HTTP/1.0 200")) != 0 && strncmp(proxy_buffer, "HTTP/1.1 200", strlen("HTTP/1.1 200")) != 0) {
                /* The reply is untrusted text: pass it as an argument, never as a message/format. */
                LOG_ERR("Proxy refused the tunnel: %s", proxy_buffer);
                goto error;
            }
        }
        break; /*if we get here, we must have connected successfully*/
    }

    GOTO_ERR_MSG(p == NULL, "Socket connection failed.");
    GOTO_ERR_MSG(conn == NULL, "Failed creating SSL");
    GOTO_ERR_MSG(SSL_set_fd(conn, *fd) != 1, "Failed initializing SSL for the socket");
    /* NOTE(review): SNI is only sent when tunnelling through a proxy; confirm
     * direct connections really do not need it. */
    if (proxy_host) {
        SSL_set_tlsext_host_name(conn, user_host_addr);
    }

    if (SSL_connect(conn) != 1) {
        long result_v = SSL_get_verify_result(conn);
        LOG_ERR("Verify status=%ld", result_v);
        /* Every GOTO_ERR_MSG(1, ...) below jumps to error; there is no fall-through. */
        switch (result_v) {
            case X509_V_OK:
                GOTO_ERR_MSG(1, "The verification succeeded or no peer certificate was presented");
            case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT:
                GOTO_ERR_MSG(1, "Verification failed because the server's certificate is"
                        " self-signed and isn't part of the trust store.");
            case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
                GOTO_ERR_MSG(1, "Verification failed because the CA certificate up the "
                        "certificate-chain is not part of the trust store.");
            default:
                GOTO_ERR_MSG(1, "Unknown SSL error");
        }
    }
    // IOTCS_DISABLE_VERIFIER_CERT variable sets in makefile to disable certificate verification
#ifndef IOTCS_DISABLE_VERIFIER_CERT
    //custom verifier check
    /* NOTE(review): the host name check is skipped when a proxy is used, so
     * the proxy path has no server identity check here; confirm intent. */
    if (!proxy_host) {
        GOTO_ERR_MSG(iotcs_hostname_verifier_check(conn, user_host_addr) != IOTCS_RESULT_OK,
                "Host name verifier failed.");
    }
#endif

    return IOTCS_RESULT_OK;
error:
    disconnect_socket(fd, conn);
    return IOTCS_RESULT_FAIL;
}

/* Opens the main (push / short-polling) connection on a new IPv4 TCP socket. */
iotcs_result iotcs_port_ssl_connect(void) {
    int fresh_fd = socket(AF_INET, SOCK_STREAM, 0);

    sockfd = fresh_fd;
    return connect_socket(&sockfd, ssl);
}

#ifdef IOTCS_LONG_POLLING

/* Opens the long-polling connection. A positive timeout_ms is then applied as
 * the socket's receive timeout (SO_RCVTIMEO). Failing to set that timeout is
 * only logged. Returns IOTCS_RESULT_FAIL if connect_socket() fails. */
iotcs_result iotcs_port_ssl_connect_lp(int32_t timeout_ms) {
    lp_socketfd = socket(AF_INET, SOCK_STREAM, 0);

    if (connect_socket(&lp_socketfd, lp_ssl) != IOTCS_RESULT_OK) {
        LOG_ERRS("connect_socket method failed!");
        return IOTCS_RESULT_FAIL;
    }

    /* A non-positive timeout leaves the receive timeout untouched. */
    if (timeout_ms > 0) {
        struct timeval recv_timeout;

        recv_timeout.tv_sec = timeout_ms / 1000;
        recv_timeout.tv_usec = (timeout_ms % 1000) * 1000;

        if (setsockopt(lp_socketfd, SOL_SOCKET, SO_RCVTIMEO, (char *) &recv_timeout,
                sizeof (recv_timeout)) < 0) {
            LOG_ERRS("setsockopt failed");
        }
    }

    return IOTCS_RESULT_OK;
}
#endif

/* Closes the main connection (socket and TLS session). */
iotcs_result iotcs_port_ssl_disconnect(void) {
    iotcs_result rv = disconnect_socket(&sockfd, ssl);
    return rv;
}

#ifdef IOTCS_LONG_POLLING

/* Closes the long-polling connection (socket and TLS session). */
iotcs_result iotcs_port_ssl_disconnect_lp(void) {
    iotcs_result rv = disconnect_socket(&lp_socketfd, lp_ssl);
    return rv;
}
#endif

/* Sends `length` bytes of `request` over the main TLS connection. */
iotcs_result iotcs_port_ssl_write(char* request, size_t length) {
    iotcs_result rv = do_ssl_write(ssl, request, length);
    return rv;
}

#ifdef IOTCS_LONG_POLLING

/* Sends `length` bytes of `request` over the long-polling TLS connection. */
iotcs_result iotcs_port_ssl_write_lp(char* request, size_t length) {
    return do_ssl_write(lp_ssl, request, length);
}
#endif

/* Reads from the main TLS connection until `buffer` is full or the peer closes. */
iotcs_result iotcs_port_ssl_read(char* buffer, int len, int *bytes_read) {
    iotcs_result rv = do_ssl_read(ssl, buffer, len, bytes_read, IOTCS_FALSE);
    return rv;
}
#ifdef IOTCS_LONG_POLLING

/* Reads from the long-polling TLS connection until `buffer` is full or the peer closes. */
iotcs_result iotcs_port_ssl_read_lp(char* buffer, int len, int *bytes_read) {
    iotcs_result rv = do_ssl_read(lp_ssl, buffer, len, bytes_read, IOTCS_TRUE);
    return rv;
}
#endif

/* Exposes the descriptor of the main connection (0 when not connected). */
int iotcs_posix_mqtt_ssl_get_socket(void) {
    int main_fd = sockfd;
    return main_fd;
}

/* MQTT transport read over the main TLS connection (ssl): reads up to len
 * bytes into buf.
 *
 * Returns the number of bytes read. Outside Cygwin, any SSL_read() result <= 0
 * (error, would-block or closed connection) is reported as -1. On Cygwin, an
 * SSL_read() result of -1 is reported as 0 after the error handling below.
 */
int iotcs_posix_mqtt_ssl_read(unsigned char* buf, int len) {
    int rc;
    LOG_DBG("read to 0x%p", buf);
    rc = SSL_read(ssl, buf, len);
    LOG_DBG("read %d", rc);

#ifdef IOTCS_DEBUG
    /* Dump payloads longer than MIN_SIZE_OF_READ_DATA bytes to stdout. */
    if (rc > MIN_SIZE_OF_READ_DATA) {
        int i = 0;
        LOG_DBGS("Next data was read:");
        for (i = 0; i < rc; i++) {
            printf("%c", buf[i]);
        }
        printf("\n");
    }
#endif

#ifdef __CYGWIN__
    if (rc == -1) {
        unsigned long error;
        int ssl_err = SSL_get_error(ssl, -1);
        // 0x2007507E is the code for an SSL bug: a write to a read-only BIO
        // when SSL_read is done on a NON_BLOCK socket.
        // When ssl_err is SSL_ERROR_WANT_READ the || short-circuits and `error`
        // stays unassigned, but that path returns immediately without using it.
        if (ssl_err == SSL_ERROR_WANT_READ || 0x2007507E == (error = ERR_get_error())) {
            return 0;
        }
        // Log the error already popped into `error`, then drain and log the
        // rest of the OpenSSL error queue.
        do {
            LOG_ERR(" SSL_read error = %lu: %s", error, ERR_error_string(error, NULL));
        } while (0 != (error = ERR_get_error()));
        return 0;
    }
#endif
    return rc <= 0 ? -1 : rc;
}

/* MQTT transport write over the main TLS connection. Returns SSL_write()'s
 * result unchanged (bytes written, or <= 0 on failure). */
int iotcs_posix_mqtt_ssl_write(const unsigned char* buf, int len) {
    int written;

    LOG_DBG("write to 0x%p", buf);
    written = SSL_write(ssl, buf, len);
    LOG_DBG("written %d", written);

    return written;
}

#ifdef IOTCS_STORAGE_SUPPORT

/* Prepares the storage (SCS) connection.
 *
 * addr   - storage server host name. A private copy is kept in storage_server_host.
 * port   - storage server TCP port.
 * is_ssl - use TLS (a new SSL object from the shared ctx) or plain TCP.
 *
 * When an HTTP proxy is configured, the proxy is resolved instead and
 * iotcs_port_storage_ssl_connect() tunnels through it with CONNECT.
 * Returns IOTCS_RESULT_INVALID_ARGUMENT for a NULL addr, IOTCS_RESULT_FAIL on
 * allocation or name-resolution failure, and IOTCS_RESULT_OK otherwise.
 */
iotcs_result iotcs_port_storage_ssl_init(const char* addr, unsigned short port, iotcs_bool is_ssl) {
    char port_str[PORT_BUFFER_SIZE];
    const char *proxy = iotcs_port_get_http_proxy();
    if (!addr) {
        return IOTCS_RESULT_INVALID_ARGUMENT;
    }

    is_ssl_storage = is_ssl;
    storage_server_port = port;
    util_free(storage_server_host);
    storage_server_host = util_safe_strcpy(addr);
    if (!storage_server_host) {
        LOG_ERRS("Failed to copy the storage server address");
        return IOTCS_RESULT_FAIL;
    }

    memset(&scs_hints, 0, sizeof scs_hints);
    scs_hints.ai_family = AF_INET;
    scs_hints.ai_socktype = SOCK_STREAM;
    util_safe_snprintf(port_str, PORT_BUFFER_SIZE, "%d", (int) port);

    if (proxy) {
        if (!proxy_host) {
            static const char scheme[] = "http://";
            const char *proxy_addr = proxy;
            char *path;

            /* Skip exactly the scheme prefix. The former "+ 1" also dropped
             * the first character of the proxy host name. */
            if (strncmp(proxy_addr, scheme, sizeof scheme - 1) == 0) {
                proxy_addr += sizeof scheme - 1;
            }
            proxy_host = util_safe_strcpy(proxy_addr);
            if (!proxy_host) {
                LOG_ERRS("Failed to copy the HTTP proxy address");
                return IOTCS_RESULT_FAIL;
            }
            /* Ignore a trailing path such as the "/" in "proxy:3128/";
             * otherwise it would end up in the port string. */
            path = strchr(proxy_host, '/');
            if (path) {
                *path = '\0';
            }
            proxy_port = strchr(proxy_host, ':');
            if (proxy_port) {
                *proxy_port = '\0';
                proxy_port++;
            } else {
                proxy_port = util_safe_strcpy("8080");
                if (!proxy_port) {
                    LOG_ERRS("Failed to copy the default HTTP proxy port");
                    util_free(proxy_host);
                    proxy_host = NULL;
                    return IOTCS_RESULT_FAIL;
                }
            }
        }

        if (getaddrinfo(proxy_host, proxy_port, &scs_hints, &scs_servinfo) != 0) {
            LOG_ERR("Invalid host name: %s", proxy_host);
            return IOTCS_RESULT_FAIL;
        }
    } else {
        if (getaddrinfo(addr, port_str, &scs_hints, &scs_servinfo) != 0) {
            LOG_ERR("Invalid host name: %s", addr);
            return IOTCS_RESULT_FAIL;
        }
    }

    /* A NULL result (e.g. ctx not created yet) is reported by the connect step. */
    if (is_ssl_storage) {
        if (!scs_ssl) {
            scs_ssl = SSL_new(ctx);
        }
    }

    return IOTCS_RESULT_OK;
}

void iotcs_port_storage_ssl_finalize(void) {
    if (scs_ssl) {
        SSL_free(scs_ssl);
    }

    freeaddrinfo(scs_servinfo);

    scs_ssl = NULL;
    storage_server_host = NULL;
    storage_server_port = 0;
    scs_socketfd = 0;
}

/* Connects to the storage server (or tunnels to it through the HTTP proxy) and,
 * when is_ssl_storage is set, performs the TLS handshake on scs_ssl.
 *
 * Returns IOTCS_RESULT_OK on success. On failure the socket is closed and
 * IOTCS_RESULT_FAIL is returned, for plain-TCP storage as well: previously a
 * failed plain-TCP connect was reported as success.
 */
iotcs_result iotcs_port_storage_ssl_connect(void) {
    int length;

    scs_socketfd = socket(AF_INET, SOCK_STREAM, 0);
    if (scs_socketfd < 0) {
        int errsv = errno;
        LOG_ERR("Socket creation failed. %s.", strerror(errsv));
        goto error;
    }

    /*loop through all the results and connect to the first we can*/
    for (scs_p = scs_servinfo; scs_p != NULL; scs_p = scs_p->ai_next) {
        if (connect(scs_socketfd, scs_p->ai_addr, scs_p->ai_addrlen) == -1) {
            int errsv = errno;
            LOG_ERR("Connect to server failed. %s.", strerror(errsv));
            if (scs_p->ai_next != NULL) {
                /* POSIX leaves the socket state unspecified after a failed
                 * connect(), so use a fresh socket for the next candidate. */
                close(scs_socketfd);
                scs_socketfd = socket(scs_p->ai_next->ai_family, scs_p->ai_next->ai_socktype, scs_p->ai_next->ai_protocol);
                if (scs_socketfd < 0) {
                    errsv = errno;
                    LOG_ERR("Socket creation failed. %s.", strerror(errsv));
                    goto error;
                }
            }
            continue;
        }
        if (proxy_host) {
            LOG_INFO("Connected to proxy %s:%s.", proxy_host, proxy_port);
            length = util_safe_snprintf(proxy_buffer, PROXY_BUFFER_SIZE, "CONNECT %s:%d HTTP/1.1\r\nHost:%s:%d\r\nProxy-Connection: Keep-Alive\r\n\r\n", storage_server_host, storage_server_port, storage_server_host, storage_server_port);
            /* Never send a truncated request or read past the buffer. */
            GOTO_ERR_MSG(length <= 0 || length >= PROXY_BUFFER_SIZE, "Proxy CONNECT request does not fit the buffer");
            GOTO_ERR_MSG(do_write(scs_socketfd, proxy_buffer, (size_t) length), "Socket write failed to proxy");
            /* Keep one byte free so the reply can be NUL-terminated before it
             * is compared and logged as a C string. */
            GOTO_ERR_MSG(do_read(scs_socketfd, proxy_buffer, PROXY_BUFFER_SIZE - 1, &length, IOTCS_FALSE, IOTCS_FALSE), "Socket read failed from proxy");
            proxy_buffer[length] = '\0';
            if (strncmp(proxy_buffer, "HTTP/1.0 200", strlen("HTTP/1.0 200")) != 0 && strncmp(proxy_buffer, "HTTP/1.1 200", strlen("HTTP/1.1 200")) != 0) {
                /* The reply is untrusted text: pass it as an argument, never as a message/format. */
                LOG_ERR("Proxy refused the tunnel: %s", proxy_buffer);
                goto error;
            }
        }
        break; /*if we get here, we must have connected successfully*/
    }

    /* Applies to both TLS and plain-TCP storage. */
    GOTO_ERR_MSG(scs_p == NULL, "Socket connection failed.");

    if (is_ssl_storage) {
        GOTO_ERR_MSG(scs_ssl == NULL, "Failed creating SSL");
        GOTO_ERR_MSG(SSL_set_fd(scs_ssl, scs_socketfd) != 1, "Failed initializing SSL for the socket");
        if (proxy_host) {
            SSL_set_tlsext_host_name(scs_ssl, storage_server_host);
        }

        if (SSL_connect(scs_ssl) != 1) {
            long result_v = SSL_get_verify_result(scs_ssl);
            LOG_ERR("Verify status=%ld", result_v);
            /* Every GOTO_ERR_MSG(1, ...) below jumps to error; there is no fall-through. */
            switch (result_v) {
                case X509_V_OK:
                    GOTO_ERR_MSG(1, "The verification succeeded or no peer certificate was presented");
                case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT:
                    GOTO_ERR_MSG(1, "Verification failed because the server's certificate is"
                            " self-signed and isn't part of the trust store.");
                case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
                    GOTO_ERR_MSG(1, "Verification failed because the CA certificate up the "
                            "certificate-chain is not part of the trust store.");
                default:
                    GOTO_ERR_MSG(1, "Unknown SSL error");
            }
        }
        /* NOTE(review): unlike connect_socket(), no host name verification is
         * done for storage connections; confirm this is intended. */
    }

    return IOTCS_RESULT_OK;
error:
    disconnect_socket(&scs_socketfd, scs_ssl);
    return IOTCS_RESULT_FAIL;
}

/* Closes the storage connection (socket and, if any, TLS session). */
iotcs_result iotcs_port_storage_ssl_disconnect(void) {
    iotcs_result rv = disconnect_socket(&scs_socketfd, scs_ssl);
    return rv;
}

/* Reads from the storage connection over TLS or plain TCP, as selected at init.
 * Both paths keep reading until `buffer` is full or the peer closes. */
iotcs_result iotcs_port_storage_ssl_read(char* buffer, int len, int *bytes_read) {
    return is_ssl_storage
            ? do_ssl_read(scs_ssl, buffer, len, bytes_read, IOTCS_FALSE)
            : do_read(scs_socketfd, buffer, len, bytes_read, IOTCS_FALSE, IOTCS_TRUE);
}

/* Writes `length` bytes of `request` to the storage connection over TLS or
 * plain TCP, as selected at init. */
iotcs_result iotcs_port_storage_ssl_write(char* request, size_t length) {
    return is_ssl_storage
            ? do_ssl_write(scs_ssl, request, length)
            : do_write(scs_socketfd, request, length);
}
#endif
