/*
 * Copyright (c) 2015, 2016, Oracle and/or its affiliates. All rights reserved.
 *
 * This software is dual-licensed to you under the MIT License (MIT) and
 * the Universal Permissive License (UPL). See the LICENSE file in the root
 * directory for license terms. You may choose either license, or both.
 */

#include <stdio.h>
#include <string.h>
#include <new>
#include "EthernetInterface.h"
#include "mbedtls/platform.h"
#include "Mutex.h"
#include "mbedtls/ssl.h"
#include "mbedtls/net.h"
#include "mbedtls/entropy.h"
#include "mbedtls/ctr_drbg.h"
#include "mbedtls/certs.h"
#include "mbedtls/threading.h"
#include "mbedtls/threading_alt.h"
#include "iotcs_port_ssl.h"
#include "iotcs_port_mqtt.h"
#include "iotcs_port_system.h"
#include "iotcs_port_thread.h"
#include "iotcs_hostname_verifier.h"
#include "trusted_assets_manager/iotcs_tam.h"
#include "util/util.h"
#include "fsl_clock.h"
#include "Socket/TCPSocketConnection.h"
#include "util/util_memory.h"
#include "log/log.h"
#define IOTCSP_MODULE_LOG_CHANNEL LOG_CHANNEL_PORT_SSL
#include "log/log_template.h"

#define IOTCS_MBED_SOCKET_READ_TIMEOUT_MS 5000
#define IOTCS_MBED_SOCKET_CONNECT_TIMEOUT_MS 5000

extern "C" int mbedtls_hardware_poll(void *data,
        unsigned char *output, size_t len, size_t *olen) {

    size_t i;
    int ret;
    UNUSED_ARG(data);

    CLOCK_EnableClock( kCLOCK_Rnga0 );
    CLOCK_DisableClock( kCLOCK_Rnga0 );
    CLOCK_EnableClock( kCLOCK_Rnga0 );

    RNG->CR = RNG_CR_INTM_MASK | RNG_CR_HA_MASK | RNG_CR_GO_MASK;

    for (i = 0; i < len; ++i) {
        int bit;
        unsigned char byte = 0;
        for (bit = 0; bit < 8; ++bit) {
            while ((RNG->SR & RNG_SR_OREG_LVL_MASK) == 0);
            byte |= (RNG->OR & 1) << bit;
        }
        *output++ = byte;
    }

    if ((RNG->SR & RNG_SR_SECV_MASK) != 0) {
        ret = -1;
        goto error;
    }

    *olen = len;
    ret = 0;

error:
    CLOCK_DisableClock( kCLOCK_Rnga0 );

    return ( ret);
}

/*
 * Host of the IoT server. iotcs_port_ssl_init copies it and
 * iotcs_port_ssl_finalize releases it. This is the address passed to the
 * socket connect for the primary and long-polling connections.
 */
static char* g_server_host = NULL;
/*
 * IOTCSP_SERVER_HOST is the name given to mbedtls_ssl_set_hostname for the
 * IoT server contexts. Defining IOTCS_MBED_HOST_NAME_OVERRIDE lets a build
 * use a certificate host name that differs from the connect address
 * (e.g. when the connect address is an IP literal).
 */
#ifdef IOTCS_MBED_HOST_NAME_OVERRIDE
#define IOTCSP_SERVER_HOST UTIL_STRINGIFY(IOTCS_MBED_HOST_NAME_OVERRIDE)
#else
#define IOTCSP_SERVER_HOST g_server_host
#endif
/* Port of the IoT server, set by iotcs_port_ssl_init. */
static unsigned short g_server_port;
/* Randomness sources: entropy pool and the CTR-DRBG seeded from it. */
static mbedtls_entropy_context entropy;
static mbedtls_ctr_drbg_context ctr_drbg;
/* Primary TLS context (IoT server). */
static mbedtls_ssl_context ssl;
/* Client configuration shared by every TLS context in this file. */
static mbedtls_ssl_config conf;
/* Trusted CA chain, filled from the TAM trust anchors. */
static mbedtls_x509_crt cacert;
/* Copy of the default certificate profile with a relaxed minimum RSA key size. */
static mbedtls_x509_crt_profile cert_profile;
/* Socket of the primary connection.
 * NOTE(review): has external linkage (not static) — confirm whether other
 * files reference it before changing that. */
TCPSocketConnection socket;

#ifdef IOTCS_LONG_POLLING
/* Separate TLS context and socket for the long-polling connection. */
static mbedtls_ssl_context lp_ssl;
TCPSocketConnection lp_socket;
#endif

#ifdef IOTCS_STORAGE_SUPPORT
/* Storage server connection: TLS when is_ssl_storage, plain TCP otherwise. */
static iotcs_bool is_ssl_storage = IOTCS_TRUE;
static mbedtls_ssl_context scs_ssl;
static TCPSocketConnection scs_socket;
static char* storage_server_host = NULL;
static unsigned short storage_server_port;
#endif


/*
 * mbed TLS send callback, installed with mbedtls_ssl_set_bio() in ssl_connect.
 * mbedSockConPtr is the TCPSocketConnection bound to the TLS context.
 *
 * Returns:
 *  - the number of bytes sent (> 0);
 *  - MBEDTLS_ERR_SSL_WANT_WRITE when the socket timeout expired while the
 *    connection is still up;
 *  - MBEDTLS_ERR_NET_SEND_FAILED on a socket error, or when nothing was sent
 *    because the connection is gone.
 *
 * It never returns 0. mbed TLS expects a byte count or an error code, and its
 * output-flush loop stops on 0 without reporting an error. Returning 0 for a
 * dead connection would therefore make a write look successful.
 */
int mbedtls_net_send(void* mbedSockConPtr, const unsigned char* buf, size_t length) {
    TCPSocketConnection *socket = (TCPSocketConnection*)mbedSockConPtr;
    int rv = socket->send_all((char *) buf, length);
    LOG_DBG("mbedtls_net_send %d(%d) %s %s", rv, (int) length, socket==&::socket ? "" : "lp", socket->is_connected() ? "" : "DISCONNECTED");
    if (rv < 0) {
        return MBEDTLS_ERR_NET_SEND_FAILED;
    }
    if (rv == 0) {
        /* Still connected: the timeout has expired, so let mbed TLS retry.
         * Disconnected: report a hard failure instead of "0 bytes sent". */
        return socket->is_connected() ? MBEDTLS_ERR_SSL_WANT_WRITE : MBEDTLS_ERR_NET_SEND_FAILED;
    }
    return rv;
}

/* mbed socket receive function */
int mbedtls_net_recv(void* mbedSockConPtr, unsigned char* buf, size_t length) {
    TCPSocketConnection *socket = (TCPSocketConnection*)mbedSockConPtr;
    int rv = socket->receive_all((char *) buf, length);
    LOG_DBG("mbedtls_net_recv %d(%d) %s %s", rv, length, socket==&::socket ? "" : "lp", socket->is_connected() ? "" : "DISCONNECTED");
    if (rv < 0) {
        return MBEDTLS_ERR_NET_RECV_FAILED;
    } else if (rv == 0 && socket->is_connected()) {
        return MBEDTLS_ERR_SSL_WANT_READ; /* timeout has expired */
    } else {
        return rv;
    }
}

/*
 * Mutex callbacks handed to mbed TLS via mbedtls_threading_set_alt().
 *
 * Builds that have an additional TLS context (IOTCS_LONG_POLLING or
 * IOTCS_STORAGE_SUPPORT) use real locks. Each mbedtls_threading_mutex_t is
 * used as storage for a Mutex constructed in place. In all other builds the
 * callbacks do nothing. lock/unlock always report success (0).
 *
 * NOTE(review): relies on mbedtls_threading_mutex_t (threading_alt.h) being
 * large and aligned enough to hold a Mutex — confirm.
 */
#if defined(IOTCS_LONG_POLLING) || defined(IOTCS_STORAGE_SUPPORT)

static void mutex_init(mbedtls_threading_mutex_t *mutex) {
    /* Construct the Mutex inside the storage provided by mbed TLS. */
    (void) new (static_cast<void*> (mutex)) Mutex;
}

static void mutex_free(mbedtls_threading_mutex_t *mutex) {
    /* Only run the destructor: the storage itself was not allocated here. */
    reinterpret_cast<Mutex*> (mutex)->~Mutex();
}

static int mutex_lock(mbedtls_threading_mutex_t *mutex) {
    reinterpret_cast<Mutex*> (mutex)->lock();
    return 0;
}

static int mutex_unlock(mbedtls_threading_mutex_t *mutex) {
    reinterpret_cast<Mutex*> (mutex)->unlock();
    return 0;
}

#else

static void mutex_init(mbedtls_threading_mutex_t *mutex) {
    UNUSED_ARG(mutex);
}

static void mutex_free(mbedtls_threading_mutex_t *mutex) {
    UNUSED_ARG(mutex);
}

static int mutex_lock(mbedtls_threading_mutex_t *mutex) {
    UNUSED_ARG(mutex);
    return 0;
}

static int mutex_unlock(mbedtls_threading_mutex_t *mutex) {
    UNUSED_ARG(mutex);
    return 0;
}

#endif

/*
 * Prepares one TLS context: binds it to the shared configuration `conf` and
 * sets the host name mbed TLS checks the server certificate against.
 * Returns IOTCS_RESULT_OK on success, or IOTCS_RESULT_FAIL after logging
 * the mbed TLS error.
 */
static iotcs_result setup_ssl_context(mbedtls_ssl_context *ctx, const char *hostname) {
    int ret;

    mbedtls_ssl_init(ctx);
    if ((ret = mbedtls_ssl_setup(ctx, &conf)) != 0) {
        if (ret == MBEDTLS_ERR_SSL_ALLOC_FAILED) {
            LOG_CRITS("mbedtls_ssl_setup failed with out of memory");
        } else {
            LOG_ERR("mbedtls_ssl_setup failed with code %d.", ret);
        }
        return IOTCS_RESULT_FAIL;
    }

    if ((ret = mbedtls_ssl_set_hostname(ctx, hostname)) != 0) {
        LOG_ERR("mbedtls_ssl_set_hostname failed with code %d.", ret);
        return IOTCS_RESULT_FAIL;
    }

    return IOTCS_RESULT_OK;
}

/*
 * Initializes the TLS layer for the IoT server at addr:port.
 *
 * The function:
 *  - copies the host name;
 *  - registers the mbed TLS mutex callbacks;
 *  - seeds the CTR-DRBG from the entropy pool;
 *  - builds a client configuration that trusts the TAM trust anchors and
 *    accepts RSA keys down to 1024 bits;
 *  - prepares the primary TLS context (and the long-polling one when
 *    IOTCS_LONG_POLLING is defined).
 *
 * Returns IOTCS_RESULT_INVALID_ARGUMENT if addr is NULL, IOTCS_RESULT_FAIL on
 * an allocation or mbed TLS error, and IOTCS_RESULT_OK otherwise.
 * NOTE(review): on failure, state initialized so far is not released here.
 */
iotcs_result iotcs_port_ssl_init(const char* addr, unsigned short port) {
    size_t certs;
    const char *pers = "https seed";
    int ret;

    if (!addr) {
        return IOTCS_RESULT_INVALID_ARGUMENT;
    }

    g_server_port = port;

    /* Release a copy left by an earlier init instead of leaking it. */
    util_free(g_server_host);
    g_server_host = util_safe_strcpy(addr);
    if (g_server_host == NULL) {
        LOG_CRITS("Failed to allocate server host name");
        return IOTCS_RESULT_FAIL;
    }

    mbedtls_threading_set_alt(mutex_init, mutex_free, mutex_lock, mutex_unlock);

    mbedtls_ssl_config_init(&conf);
    mbedtls_x509_crt_init(&cacert);
    mbedtls_ctr_drbg_init(&ctr_drbg);
    mbedtls_entropy_init(&entropy);
    if ((ret = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy,
            (const unsigned char *) pers,
            strlen(pers))) != 0) {
        LOG_ERR("mbedtls_ctr_drbg_seed failed with code %d", ret);
        return IOTCS_RESULT_FAIL;
    }

    if ((ret = mbedtls_ssl_config_defaults(&conf,
            MBEDTLS_SSL_IS_CLIENT,
            MBEDTLS_SSL_TRANSPORT_STREAM,
            MBEDTLS_SSL_PRESET_DEFAULT)) != 0) {
        LOG_ERR("mbedtls_ssl_config_defaults failed with code %d", ret);
        return IOTCS_RESULT_FAIL;
    }

    mbedtls_ssl_conf_ca_chain(&conf, &cacert, NULL);
    mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ctr_drbg);

#if defined(UNSAFE)
    mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_OPTIONAL); /* TODO: check if required */
#endif

    // The default config rejects certificates with keys shorter than 2048
    // bits; in our case we need to accept 1024 bits as a minimum.
    cert_profile = *conf.cert_profile;
    cert_profile.rsa_min_bitlen = 1024; /* TODO: check magic number */
    conf.cert_profile = &cert_profile;

    // Retrieve the trust anchors from the TAM and inject them into the
    // TLS stack as trusted certificates.
    certs = tam_get_trust_anchors_count();
    while (certs) {
        size_t len;
        const char *buf = tam_get_trust_anchor_certificate(--certs, &len);
        ret = mbedtls_x509_crt_parse_der(&cacert, reinterpret_cast<const unsigned char*> (buf), len);
        if (ret != 0) {
            LOG_ERR("mbedtls_x509_crt_parse_der failed with code %d", ret);
            return IOTCS_RESULT_FAIL;
        }
    }

    if (setup_ssl_context(&ssl, IOTCSP_SERVER_HOST) != IOTCS_RESULT_OK) {
        return IOTCS_RESULT_FAIL;
    }

#ifdef IOTCS_LONG_POLLING
    if (setup_ssl_context(&lp_ssl, IOTCSP_SERVER_HOST) != IOTCS_RESULT_OK) {
        return IOTCS_RESULT_FAIL;
    }
#endif

    return IOTCS_RESULT_OK;
}

/*
 * Undoes iotcs_port_ssl_init. It closes the primary connection (and the
 * long-polling one when IOTCS_LONG_POLLING is defined), releases the mbed TLS
 * contexts, configuration and RNG state, and clears the stored server address.
 */
void iotcs_port_ssl_finalize() {
    /* Close the connections before releasing any TLS state. */
    iotcs_port_ssl_disconnect();
#ifdef IOTCS_LONG_POLLING
    iotcs_port_ssl_disconnect_lp();
#endif

    /* Trust anchors and per-connection contexts. */
    mbedtls_x509_crt_free(&cacert);
    mbedtls_ssl_free(&ssl);
#ifdef IOTCS_LONG_POLLING
    mbedtls_ssl_free(&lp_ssl);
#endif

    /* Shared configuration and randomness sources. */
    mbedtls_ssl_config_free(&conf);
    mbedtls_ctr_drbg_free(&ctr_drbg);
    mbedtls_entropy_free(&entropy);

    /* Server address. */
    g_server_port = 0;
    util_free(g_server_host);
    g_server_host = NULL;
}
/*
 * Resets the TLS session on `ctx` so the context can be used for a new
 * handshake, then closes `sock` without a shutdown. No close_notify alert is
 * sent. The result of the session reset is ignored.
 * Always returns IOTCS_RESULT_OK.
 */
static iotcs_result ssl_disconnect(mbedtls_ssl_context *ctx, TCPSocketConnection *sock) {
    (void) mbedtls_ssl_session_reset(ctx);

    LOG_DBGS("close socket");
    /* 'false' means no shutdown: no further data is expected from the server. */
    sock->close(false);

    return IOTCS_RESULT_OK;
}

/*
 * Number of extra connection attempts made after the first one fails, in
 * ssl_connect and in the plain-TCP storage connect. The total number of
 * attempts is IOTCS_MBED_CONNECT_RETRY_NUM + 1. Can be overridden at build time.
 */
#ifndef IOTCS_MBED_CONNECT_RETRY_NUM
#define IOTCS_MBED_CONNECT_RETRY_NUM 5
#endif

/*
 * Opens a TCP connection to server_host:server_port on `socket` and performs
 * the TLS handshake on `ssl`. On failure the whole sequence is retried up to
 * IOTCS_MBED_CONNECT_RETRY_NUM more times.
 *
 * The handshake runs with an IOTCS_MBED_SOCKET_CONNECT_TIMEOUT_MS socket
 * timeout. Afterwards the socket timeout is set to timeout_ms. Once connected,
 * the certificate verification result is checked, and so is the server host
 * name unless IOTCS_DISABLE_VERIFIER_CERT is defined.
 *
 * Returns IOTCS_RESULT_OK on success. On failure the session is reset, the
 * socket is closed, and IOTCS_RESULT_FAIL is returned.
 */
static iotcs_result ssl_connect(mbedtls_ssl_context *ssl, TCPSocketConnection *socket, int32_t timeout_ms, char* server_host, unsigned short server_port) {
    int ret;
    uint32_t flags;
    int retry = 0;
    /*
     * On ARM mbed, server_host may be the server's IP address, while the
     * certificate check needs the host name (e.g. iotserver) that was passed
     * to mbedtls_ssl_set_hostname for this context.
     */
    do {
        if (retry) {
            LOG_INFO("Retrying (%d attempts left)", IOTCS_MBED_CONNECT_RETRY_NUM - retry + 1);
            // shutdown the socket properly and give it some time before retrying
            ssl_disconnect(ssl, socket);
            iotcs_port_sleep_millis(1000); /* TODO: magic number */
        }
        retry++;
        LOG_INFO("Connecting to %s:%d...", server_host, server_port);
        if ((ret = socket->connect(server_host, server_port)) < 0) {
            LOG_ERR("Socket connect failed (%d)", ret);
        } else {
            LOG_INFOS("Socket connected");
            /* Use blocking mode: YES 'false' argument turns on blocking mode */
            socket->set_blocking(false, IOTCS_MBED_SOCKET_CONNECT_TIMEOUT_MS);
            mbedtls_ssl_set_bio(ssl, socket, mbedtls_net_send, mbedtls_net_recv, NULL);
            /* NOTE(review): WANT_READ/WANT_WRITE (socket timeout while still
             * connected) are retried without limit, so a peer that stalls in
             * the middle of the handshake keeps this loop running. */
            while ((ret = mbedtls_ssl_handshake(ssl)) != 0) {
                if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
                    LOG_ERR("mbedtls_ssl_handshake failed 0x%x", -ret);
                    break;
                }
            }
            socket->set_blocking(false, timeout_ms);
        }
    } while ((ret != 0) && (retry <= IOTCS_MBED_CONNECT_RETRY_NUM));
    GOTO_ERR_MSG(ret != 0, "Too many retries, bailing out...");

    if ((flags = mbedtls_ssl_get_verify_result(ssl)) != 0) {
        char vrfy_buf[512]; /* TODO: FIXME I'M TOO BIG!!! */

        mbedtls_x509_crt_verify_info(vrfy_buf, sizeof ( vrfy_buf), "  ! ", flags);

        LOG_ERRS(vrfy_buf);
        GOTO_ERR(1);
    }

// IOTCS_DISABLE_VERIFIER_CERT is set in the makefile to disable certificate verification
#ifndef IOTCS_DISABLE_VERIFIER_CERT
    {
        /* Verify against the name this context was configured with through
         * mbedtls_ssl_set_hostname: IOTCSP_SERVER_HOST for the IoT server
         * contexts, the storage host for the storage context. Checking
         * IOTCSP_SERVER_HOST for every context would reject a correct
         * storage server certificate. */
        char *verify_host = (ssl->hostname != NULL) ? ssl->hostname : server_host;
        GOTO_ERR_MSG(hostname_verifier_check(ssl, verify_host) != IOTCS_RESULT_OK, "Host name verifier failed.");
    }
#endif
    return IOTCS_RESULT_OK;
error:
    ssl_disconnect(ssl, socket);
    return IOTCS_RESULT_FAIL;
}

/*
 * Writes all `length` bytes of `request` over the TLS context `ssl`.
 *
 * mbedtls_ssl_write may accept fewer bytes than requested (e.g. when the data
 * exceeds the maximum record size), so the call is repeated until everything
 * has been written. Any error aborts with IOTCS_RESULT_FAIL. That includes a
 * socket timeout, which mbedtls_net_send reports as MBEDTLS_ERR_SSL_WANT_WRITE.
 * A zero length is a no-op that returns IOTCS_RESULT_OK.
 */
static iotcs_result ssl_write(mbedtls_ssl_context *ssl, char* request, size_t length) {
    const unsigned char *pos = reinterpret_cast<const unsigned char*> (request);
    size_t remaining = length;

    while (remaining > 0) {
        int ret = mbedtls_ssl_write(ssl, pos, remaining);
        if (ret <= 0) {
            LOG_ERR("mbedtls_ssl_write failed (%d)", ret);
            return IOTCS_RESULT_FAIL;
        }
        pos += ret;
        remaining -= (size_t) ret;
    }
    return IOTCS_RESULT_OK;
}

/*
 * Reads from the TLS context `ssl` into `buffer` (capacity `len` bytes). It
 * stops when the buffer is full, the peer closes the connection, or an
 * error/timeout occurs. *bytes_read always receives the number of bytes
 * stored.
 *
 * Returns:
 *  - IOTCS_RESULT_OK when the buffer was filled, the connection reached EOF,
 *    or a close_notify alert arrived (also when len <= 0);
 *  - IOTCS_RESULT_SSL_WANT_READ when the socket timed out, so more data may
 *    follow;
 *  - IOTCS_RESULT_FAIL on any other mbed TLS error.
 */
static iotcs_result ssl_read(mbedtls_ssl_context *ssl, char* buffer, int len, int *bytes_read) {
    int read = 0;
    int tot_read = 0;

    /* Stop as soon as the buffer is full. A zero-length mbedtls_ssl_read can
     * still block waiting for the next record while never returning data,
     * and a non-positive len must not be converted to a huge size_t. */
    while (len > 0) {
        read = mbedtls_ssl_read(ssl, reinterpret_cast<unsigned char*> (buffer), (size_t) len);
        if (read <= 0) {
            break;
        }
        len -= read;
        buffer += read;
        tot_read += read;
    }

    *bytes_read = tot_read;
    if (read >= 0 || read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
        return IOTCS_RESULT_OK;
    } else if (read == MBEDTLS_ERR_SSL_WANT_READ || read == MBEDTLS_ERR_SSL_WANT_WRITE) {
        return IOTCS_RESULT_SSL_WANT_READ;
    } else {
        LOG_ERR("mbedtls_ssl_read failed (%d)", read);
        return IOTCS_RESULT_FAIL;
    }
}


/*
 * Primary IoT server connection: thin wrappers that bind the generic TLS
 * helpers to the global `ssl` context and `socket`.
 */

/* Connects to g_server_host:g_server_port. After the handshake the socket
 * timeout is IOTCS_MBED_SOCKET_READ_TIMEOUT_MS. */
iotcs_result iotcs_port_ssl_connect() {
    const int32_t io_timeout_ms = IOTCS_MBED_SOCKET_READ_TIMEOUT_MS;

    LOG_DBGS("ssl_connect");
    return ssl_connect(&ssl, &socket, io_timeout_ms,
                       g_server_host, g_server_port);
}

/* Resets the TLS session and closes the socket. */
iotcs_result iotcs_port_ssl_disconnect(void) {
    LOG_DBGS("ssl_disconnect");
    return ssl_disconnect(&ssl, &socket);
}

/* Sends `length` bytes of `request` on the primary connection. */
iotcs_result iotcs_port_ssl_write(char* request, size_t length) {
    iotcs_result status = ssl_write(&ssl, request, length);
    return status;
}

/* Reads up to `len` bytes from the primary connection. */
iotcs_result iotcs_port_ssl_read(char* buffer, int len, int *bytes_read) {
    iotcs_result status = ssl_read(&ssl, buffer, len, bytes_read);
    return status;
}

#ifdef IOTCS_LONG_POLLING

/*
 * Long-polling connection: the same operations as the primary connection,
 * applied to lp_ssl / lp_socket against the same server.
 */

/* Connects. After the handshake the socket timeout becomes timeout_ms. */
iotcs_result iotcs_port_ssl_connect_lp(int32_t timeout_ms) {
    LOG_DBGS("ssl_connect lp");
    return ssl_connect(&lp_ssl, &lp_socket, timeout_ms,
                       g_server_host, g_server_port);
}

/* Resets the long-polling TLS session and closes its socket. */
iotcs_result iotcs_port_ssl_disconnect_lp(void) {
    LOG_DBGS("ssl_disconnect lp");
    return ssl_disconnect(&lp_ssl, &lp_socket);
}

/* Sends `length` bytes of `request` on the long-polling connection. */
iotcs_result iotcs_port_ssl_write_lp(char* request, size_t length) {
    iotcs_result status = ssl_write(&lp_ssl, request, length);
    return status;
}

/* Reads up to `len` bytes from the long-polling connection. */
iotcs_result iotcs_port_ssl_read_lp(char* buffer, int len, int *bytes_read) {
    iotcs_result status = ssl_read(&lp_ssl, buffer, len, bytes_read);
    return status;
}

#endif /* IOTCS_LONG_POLLING */

/*
 * MQTT transport callbacks. Both operate on the primary TLS connection
 * (`ssl` / `socket`) and ignore their Network argument.
 */
static int iotcs_mbed_mqtt_read(Network* n, unsigned char* buffer, int len, int timeout_ms);
static int iotcs_mbed_mqtt_write(Network* n, unsigned char* buffer, int len, int timeout_ms);
/* Returned by iotcs_port_mqtt_network_connect once the TLS connection is up. */
static Network g_network = {iotcs_mbed_mqtt_read, iotcs_mbed_mqtt_write};

/*
 * Connects the MQTT transport. addr, port and ssl_support are ignored: the
 * primary TLS connection configured by iotcs_port_ssl_init is always used.
 * Returns &g_network on success, NULL if the TLS connection failed.
 */
Network* iotcs_port_mqtt_network_connect(char* addr, int port, int ssl_support) {
    (void) addr;
    (void) port;
    (void) ssl_support;

    if (iotcs_port_ssl_connect() != IOTCS_RESULT_OK) {
        return NULL;
    }
    return &g_network;
}

/* Closes the MQTT transport by closing the primary TLS connection. */
void iotcs_port_mqtt_network_disconnect() {
    (void) iotcs_port_ssl_disconnect();
}

/*
 * MQTT transport read over the primary TLS connection.
 *
 * Tries to read exactly `len` bytes into `buffer`, using timeout_ms as the
 * socket timeout for each underlying read. mbedtls_ssl_read returns data from
 * at most one record per call, so a single call can return fewer bytes than
 * requested. Reading is therefore repeated until `len` bytes have arrived.
 *
 * Returns the number of bytes read. This is less than len if the socket timed
 * out or the connection reached EOF first, and 0 if nothing arrived or
 * len <= 0. Returns -1 on a TLS error.
 */
static int iotcs_mbed_mqtt_read(Network* n, unsigned char* buffer, int len, int timeout_ms) {
    int bytes_read = 0;
    (void) n;
    socket.set_blocking(false, timeout_ms);

    while (bytes_read < len) {
        int ret = mbedtls_ssl_read(&ssl, buffer + bytes_read, (size_t) (len - bytes_read));
        if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
            /* Socket timeout: hand back whatever has arrived so far. */
            return bytes_read;
        }
        if (ret < 0) {
            LOG_ERR("mbedtls_ssl_read failed with code %d", ret);
            LOG_INFO("mqtt_read %d ERROR", len);
            return -1;
        }
        if (ret == 0) {
            /* EOF: no more data will arrive on this connection. */
            break;
        }
        bytes_read += ret;
    }

    LOG_DBG("mqtt_read %d ok", len);
    return bytes_read;
}

/*
 * MQTT transport write over the primary TLS connection, using timeout_ms as
 * the socket timeout.
 *
 * Returns the number of bytes mbed TLS accepted (possibly fewer than len).
 * Returns 0 on a socket timeout (MBEDTLS_ERR_SSL_WANT_READ/WRITE) or when
 * len <= 0, and -1 on any other TLS error. "ok" is logged only on success.
 */
static int iotcs_mbed_mqtt_write(Network* n, unsigned char* buffer, int len, int timeout_ms) {
    int ret;
    (void) n;
    socket.set_blocking(false, timeout_ms);

    /* A negative len would otherwise become a huge size_t. */
    if (len <= 0) {
        return 0;
    }

    ret = mbedtls_ssl_write(&ssl, reinterpret_cast<const unsigned char*> (buffer), (size_t) len);
    if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
        return 0;
    }
    if (ret < 0) {
        LOG_ERR("mbedtls_ssl_write failed with code %d", ret);
        LOG_INFO("mqtt_write %d ERROR", len);
        return -1;
    }

    LOG_INFO("mqtt_write %d ok", len);
    return ret;
}

#ifdef IOTCS_STORAGE_SUPPORT
/*
 * Configures the storage server connection at addr:port.
 *
 * The host name is copied (any previous copy is released). When is_ssl is
 * true, the scs_ssl context is bound to the shared TLS configuration and its
 * certificate host name is set to addr. Otherwise the connection is plain TCP.
 *
 * Returns IOTCS_RESULT_INVALID_ARGUMENT if addr is NULL, IOTCS_RESULT_FAIL on
 * an allocation or mbed TLS error, and IOTCS_RESULT_OK otherwise.
 */
iotcs_result iotcs_port_storage_ssl_init(const char* addr, unsigned short port, iotcs_bool is_ssl) {
    int ret;
    if (!addr) {
        return IOTCS_RESULT_INVALID_ARGUMENT;
    }

    storage_server_port = port;
    util_free(storage_server_host);
    storage_server_host = util_safe_strcpy(addr);
    is_ssl_storage = is_ssl;
    if (storage_server_host == NULL) {
        LOG_CRITS("Failed to allocate storage server host name");
        return IOTCS_RESULT_FAIL;
    }

    if (is_ssl_storage) {
        mbedtls_ssl_init(&scs_ssl);
        if ((ret = mbedtls_ssl_setup(&scs_ssl, &conf)) != 0) {
            if (ret == MBEDTLS_ERR_SSL_ALLOC_FAILED) {
                LOG_CRITS("mbedtls_ssl_setup failed with out of memory");
            } else {
                LOG_ERR("mbedtls_ssl_setup failed with code %d.", ret);
            }
            return IOTCS_RESULT_FAIL;
        }

        if ((ret = mbedtls_ssl_set_hostname(&scs_ssl, addr)) != 0) {
            LOG_ERR("mbedtls_ssl_set_hostname failed with code %d.", ret);
            return IOTCS_RESULT_FAIL;
        }
    }
    return IOTCS_RESULT_OK;
}

/*
 * Closes the storage connection, releases its TLS context (when TLS is in
 * use) and frees the stored host name.
 */
void iotcs_port_storage_ssl_finalize(void) {
    iotcs_port_storage_ssl_disconnect();
    if (is_ssl_storage) {
        mbedtls_ssl_free(&scs_ssl);
    }
    util_free(storage_server_host);
    /* Clear the pointer: iotcs_port_storage_ssl_init frees the previous value,
     * so leaving it dangling would cause a double free on re-init. */
    storage_server_host = NULL;
}

/*
 * Connects to the storage server. With TLS this delegates to ssl_connect.
 * Otherwise a plain TCP connection is attempted with the same retry policy:
 * up to IOTCS_MBED_CONNECT_RETRY_NUM extra attempts, closing the socket and
 * pausing between them.
 * Returns IOTCS_RESULT_OK on success, IOTCS_RESULT_FAIL otherwise.
 */
iotcs_result iotcs_port_storage_ssl_connect(void) {
    int ret;
    int retry = 0;

    LOG_DBGS("connect scs");
    if (is_ssl_storage) {
        return ssl_connect(&scs_ssl, &scs_socket, IOTCS_MBED_SOCKET_READ_TIMEOUT_MS, storage_server_host, storage_server_port);
    }

    do {
        if (retry) {
            LOG_INFO("Retrying (%d attempts left)", IOTCS_MBED_CONNECT_RETRY_NUM - retry + 1);
            // Close whatever the failed attempt left open (as ssl_connect does
            // through ssl_disconnect) and give it some time before retrying.
            scs_socket.close(false);
            iotcs_port_sleep_millis(IOTCS_MBED_SOCKET_CONNECT_TIMEOUT_MS / 5); /* TODO: magic number */
        }
        retry++;
        LOG_INFO("Connecting to %s:%d...", storage_server_host, storage_server_port);
        if ((ret = scs_socket.connect(storage_server_host, storage_server_port)) < 0) {
            LOG_ERR("Socket connect failed (%d)", ret);
        } else {
            LOG_INFOS("Socket connected");
            /* Use blocking mode: YES 'false' argument turns on blocking mode */
            scs_socket.set_blocking(false, IOTCS_MBED_SOCKET_CONNECT_TIMEOUT_MS);
        }
    } while ((ret != 0) && (retry <= IOTCS_MBED_CONNECT_RETRY_NUM));

    return (ret == 0) ? IOTCS_RESULT_OK : IOTCS_RESULT_FAIL;
}

/*
 * Closes the storage connection. With TLS the session is reset via
 * ssl_disconnect. Otherwise the plain socket is closed without a shutdown.
 * Returns IOTCS_RESULT_OK.
 */
iotcs_result iotcs_port_storage_ssl_disconnect(void) {
    LOG_DBGS("disconnect scs");
    if (!is_ssl_storage) {
        /* 'false' arg means no shutdown: no data is expected from the server here. */
        scs_socket.close(false);
        return IOTCS_RESULT_OK;
    }
    return ssl_disconnect(&scs_ssl, &scs_socket);
}

/*
 * Reads up to `len` bytes from the storage connection into `buffer`.
 * With TLS this delegates to ssl_read. On plain TCP, any positive byte count
 * is IOTCS_RESULT_OK. Zero bytes or a socket error is IOTCS_RESULT_FAIL.
 * *bytes_read is always set (0 on failure), matching ssl_read, so callers
 * never see a stale count.
 */
iotcs_result iotcs_port_storage_ssl_read(char* buffer, int len, int *bytes_read) {
    int rv;

    LOG_DBGS("read scs");
    if (is_ssl_storage) {
        return ssl_read(&scs_ssl, buffer, len, bytes_read);
    }

    *bytes_read = 0;
    rv = scs_socket.receive_all((char *) buffer, len);
    if (rv <= 0) {
        LOG_ERR("Socket read failed. 0x%x.", -rv);
        return IOTCS_RESULT_FAIL;
    }
    *bytes_read = rv;
    return IOTCS_RESULT_OK;
}

/*
 * Sends `length` bytes of `request` on the storage connection. With TLS this
 * delegates to ssl_write. On plain TCP, both a socket error and a short write
 * are reported as IOTCS_RESULT_FAIL, matching ssl_write's treatment of an
 * incomplete transfer.
 */
iotcs_result iotcs_port_storage_ssl_write(char* request, size_t length) {
    int rv;

    LOG_DBGS("write scs");
    if (is_ssl_storage) {
        return ssl_write(&scs_ssl, request, length);
    }

    rv = scs_socket.send_all((char *) request, length);
    if (rv < 0) {
        LOG_ERR("Socket write failed. 0x%x.", -rv);
        return IOTCS_RESULT_FAIL;
    }
    /* send_all can return fewer bytes than requested when the socket times out
     * (see the timeout handling in mbedtls_net_send). An incomplete request is
     * an error. */
    if ((size_t) rv != length) {
        LOG_ERR("Socket write incomplete (%d of %d bytes).", rv, (int) length);
        return IOTCS_RESULT_FAIL;
    }
    return IOTCS_RESULT_OK;
}
#endif
