/*
 * Copyright (c) 2015, 2016, Oracle and/or its affiliates. All rights reserved.
 *
 * This software is dual-licensed to you under the MIT License (MIT) and
 * the Universal Permissive License (UPL). See the LICENSE file in the root
 * directory for license terms. You may choose either license, or both.
 */

#include <stdio.h>
#include <string.h>
#include <openssl/ssl.h>
#include <openssl/x509_vfy.h>
#include <openssl/x509v3.h>
#include "iotcs_hostname_verifier.h"

#include "log/log.h"
#define IOTCSP_MODULE_LOG_CHANNEL LOG_CHANNEL_PORT_SSL
#include "log/log_template.h"

/**
 * Returns the subject commonName (CN) of a certificate.
 *
 * Only the first CN entry of the subject name is considered.
 *
 * @param cert certificate to inspect; NULL yields NULL.
 * @return NUL-terminated CN string, or NULL when @p cert is NULL, has no
 *         subject name or CN entry, the entry data cannot be read, or the CN
 *         contains an embedded NUL byte. Rejecting embedded NULs matters
 *         because a CN such as "host.example\0.attacker.example" would
 *         otherwise compare equal to "host.example" with strcmp().
 *         The returned pointer refers to data inside @p cert (it is not a
 *         copy): do not free it, and do not use it after the certificate has
 *         been freed.
 */
char* verifier_get_common_name(X509* cert) {
    X509_NAME_ENTRY* entry = NULL;
    ASN1_STRING* data = NULL;
    char* entry_string = NULL;
    X509_NAME* name = NULL;
    int cn_index;
    int entry_len;

    if (!cert) {
        return NULL;
    }

    name = X509_get_subject_name(cert);
    if (!name) {
        LOG_ERRS("verifier_get_common_name: X509_get_subject_name method returns NULL.");
        return NULL;
    }

    /* A start position of -1 searches from the beginning, i.e. finds the first CN. */
    cn_index = X509_NAME_get_index_by_NID(name, NID_commonName, -1);
    if (cn_index < 0) {
        LOG_ERRS("verifier_get_common_name: X509_NAME_get_index_by_NID method returns -1.");
        return NULL;
    }

    entry = X509_NAME_get_entry(name, cn_index);
    if (!entry) {
        LOG_ERRS("verifier_get_common_name: X509_NAME_get_entry method returns NULL.");
        return NULL;
    }

    data = X509_NAME_ENTRY_get_data(entry);
    if (!data) {
        LOG_ERRS("verifier_get_common_name: X509_NAME_ENTRY_get_data method returns NULL.");
        return NULL;
    }

    entry_string = (char*) ASN1_STRING_data(data);
    if (!entry_string) {
        LOG_ERRS("verifier_get_common_name: ASN1_STRING_data method returns NULL.");
        return NULL;
    }

    /* The encoded length must match the C-string length; otherwise the CN has
     * an embedded NUL and callers using str* functions would see a truncated name. */
    entry_len = ASN1_STRING_length(data);
    if (entry_len < 0 || (size_t) entry_len != strlen(entry_string)) {
        LOG_ERRS("verifier_get_common_name: common name contains an embedded NUL byte.");
        return NULL;
    }

    return entry_string;
}

/**
 * Checks whether @p server_cert was issued for @p host_addr.
 *
 * Matching order:
 *  1. Every subjectAltName dNSName entry is compared to @p host_addr with an
 *     exact, case-sensitive comparison. Entries whose encoded length differs
 *     from their C-string length (embedded NUL) never match.
 *  2. If no dNSName matched, the subject commonName is checked:
 *     - a CN of the form "*.XYZ" matches any host whose trailing characters
 *       equal ".XYZ" (so "*.abc" matches ".abc" and "XYZ.abc");
 *     - any other CN must equal @p host_addr exactly.
 *
 * @param server_cert certificate presented by the server.
 * @param host_addr   host name the client intended to connect to.
 * @return IOTCS_RESULT_OK on a match;
 *         IOTCS_RESULT_INVALID_ARGUMENT when an argument is NULL;
 *         IOTCS_RESULT_FAIL when nothing matched or a certificate field could
 *         not be read.
 */
static iotcs_result hostname_verifier_check_common_name(X509* server_cert, const char* host_addr) {
    STACK_OF(GENERAL_NAME)* alt_names = NULL;
    char* server_entry_string = NULL;
    iotcs_result result = IOTCS_RESULT_FAIL;
    int alt_names_nb;
    int i;

    if (!host_addr || !server_cert) {
        LOG_ERRS("verifier_check_common_name: some of the input parameters are NULL.");
        return IOTCS_RESULT_INVALID_ARGUMENT;
    }

    /* check alternative names.
     * X509_get_ext_d2i() returns a newly allocated stack owned by this function,
     * so every path after this point exits through "done", which frees it. */
    alt_names = X509_get_ext_d2i(server_cert, NID_subject_alt_name, NULL, NULL);
    alt_names_nb = alt_names ? sk_GENERAL_NAME_num(alt_names) : 0;

    for (i = 0; i < alt_names_nb; i++) {
        const GENERAL_NAME* current_name = sk_GENERAL_NAME_value(alt_names, i);
        const char* dns_name = NULL;

        if (!current_name) {
            LOG_ERRS("verifier_check_common_name: sk_GENERAL_NAME_value method returns NULL.");
            goto done;
        }

        if (current_name->type != GEN_DNS) {
            continue;
        }

        dns_name = (const char*) ASN1_STRING_data(current_name->d.dNSName);
        if (!dns_name) {
            LOG_ERRS("verifier_check_common_name: ASN1_STRING_data method returns NULL.");
            goto done;
        }

        /* Only compare names without an embedded NUL; otherwise strcmp() would
         * see a truncated name (null-prefix attack). */
        if ((size_t) ASN1_STRING_length(current_name->d.dNSName) == strlen(dns_name) &&
                strcmp(dns_name, host_addr) == 0) {
            result = IOTCS_RESULT_OK;
            goto done;
        }
    }

    /* check server certificate CN */
    server_entry_string = verifier_get_common_name(server_cert);
    if (!server_entry_string) {
        LOG_ERRS("verifier_check_common_name: verifier_get_common_name method returns NULL.");
        goto done;
    }

    if (server_entry_string[0] == '*' && server_entry_string[1] == '.') {
        /* Wildcard certificate (only a leading wildcard is supported) in the form "*.XYZ".
         * Only the trailing strlen(".XYZ") characters of the host are compared; the
         * remaining leading characters are ignored.
         * NOTE(review): RFC 6125 section 6.4.3 limits a wildcard to a single label,
         * but this also accepts hosts such as "a.b.XYZ". Confirm this is intended. */
        const char* cn_suffix = server_entry_string + 1; /* skip '*' symbol */
        size_t suffix_len = strlen(cn_suffix);
        size_t host_addr_len = strlen(host_addr);
        size_t ignored_offset = host_addr_len > suffix_len ? host_addr_len - suffix_len : 0;

        if (strcmp(cn_suffix, host_addr + ignored_offset) == 0) {
            result = IOTCS_RESULT_OK;
        }
    } else if (strcmp(server_entry_string, host_addr) == 0) {
        /* Not a wildcard cert - exact match */
        result = IOTCS_RESULT_OK;
    }

done:
    /* NULL-safe: the stack and every GENERAL_NAME it holds are released. */
    sk_GENERAL_NAME_pop_free(alt_names, GENERAL_NAME_free);
    return result;
}

/**
 * Verifies that the certificate presented by the peer of @p ssl was issued for
 * @p host_addr.
 *
 * @param ssl       TLS connection whose peer certificate is checked (presumably
 *                  after the handshake has completed; confirm with callers).
 * @param host_addr host name the client intended to connect to.
 * @return IOTCS_RESULT_OK when the certificate matches @p host_addr;
 *         IOTCS_RESULT_INVALID_ARGUMENT when an argument is NULL;
 *         IOTCS_RESULT_FAIL when the peer presented no certificate; otherwise
 *         the result of the certificate/host name comparison.
 */
iotcs_result iotcs_hostname_verifier_check(SSL* ssl, const char* host_addr) {
    X509* server_cert = NULL;
    iotcs_result result;

    if (!host_addr || !ssl) {
        LOG_ERRS("verifier_check: some of the input parameters are NULL.");
        return IOTCS_RESULT_INVALID_ARGUMENT;
    }

    /* SSL_get_peer_certificate() returns a reference the caller must release
     * with X509_free(). */
    server_cert = SSL_get_peer_certificate(ssl);
    if (!server_cert) {
        /* A missing peer certificate is a verification failure, not an
         * allocation failure, so do not report IOTCS_RESULT_OUT_OF_MEMORY. */
        LOG_ERRS("verifier_check: SSL_get_peer_certificate() returns NULL.");
        return IOTCS_RESULT_FAIL;
    }

    //check host_name
    result = hostname_verifier_check_common_name(server_cert, host_addr);
    X509_free(server_cert);

    return result;
}
