/*
 * Copyright (c) 2016, Oracle and/or its affiliates.  All rights reserved.
 *
 * This software is dual-licensed to you under the MIT License (MIT) and
 * the Universal Permissive License (UPL).  See the LICENSE file in the root
 * directory for license terms.  You may choose either license, or both.
 */

package com.oracle.iot.sample.quickstart;

import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import javax.json.Json;

import java.security.*;
import java.security.spec.InvalidKeySpecException;
import java.io.FileNotFoundException;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.Base64;

/**
 * Builds OAuth 2.0 JWT-bearer client assertions for the IoT quickstart and turns them
 * into ready-to-run cURL commands that request either an activation token (HS256,
 * keyed by the activation secret) or a messaging token (RS256, keyed by an RSA
 * private key).
 */
public class ClientAssertionHelper {

	/**
	 * Template for the form-encoded POST body, wrapped in double quotes for the shell.
	 * The literal placeholders {@code ASSERTION} and {@code SCOPE} are substituted per request.
	 */
	public static String data = "\"grant_type=client_credentials&client_assertion_type=urn:ietf:params:oauth:client-assertion-type:jwt-bearer&"
			+ "client_assertion=ASSERTION&scope=SCOPE\"";

	/**
	 * Fixed prefix of every generated cURL command. It is copied, never appended to,
	 * so building several commands in one JVM does not accumulate earlier output.
	 */
	public static StringBuilder curlCommand = new StringBuilder("curl -s -X POST -H 'Accept:application/json' -H 'Content-Type: application/x-www-form-urlencoded' "
			+ "--data ");

	private static final String ACTIVATION_SCOPE = "oracle/iot/activation";
	private static final String MESSAGE_SCOPE = "";
	private static final String ACTIVATION_TOKEN_LOCATION = "../json/activation_token.json";
	private static final String MESSAGE_TOKEN_LOCATION = "../json/message_token.json";

	/** Value placed in the JWT {@code aud} claim. */
	private static final String TOKEN_AUDIENCE = "oracle/iot/oauth2/token";

	/** Lifetime of a generated assertion, in seconds (15 minutes). */
	private static final long ASSERTION_LIFETIME_SECONDS = 15 * 60;

	/**
	 * Generates an HS256-signed client assertion, prints it together with the cURL
	 * command that exchanges it for an activation token, and writes that command to
	 * {@code activation_token_curl.sh} (relative to the working directory).
	 * Signing failures are reported with a stack trace rather than thrown.
	 *
	 * @param activationSecret shared secret that the client provides; its UTF-8 bytes are the HMAC-SHA256 key
	 * @param id               activation id, used as the JWT {@code iss} claim
	 * @param URL              iot server url the cURL command posts to
	 */
	public static void clientAssertionForActivation(String activationSecret, String id, String URL) {
		// UTF-8 explicitly, matching every other byte conversion in this class;
		// String.getBytes() would depend on the platform default charset.
		Key key = new SecretKeySpec(toUTF8(activationSecret), "HmacSHA256");
		try {
			String clientAssertion = getClientAssertion(id, "HS256", key);
			String curl = printAssertionAndCurl(clientAssertion, ACTIVATION_SCOPE, URL, ACTIVATION_TOKEN_LOCATION);
			writeCurlToFile(curl, "activation_token_curl.sh");
		} catch (InvalidKeyException | NoSuchAlgorithmException | SignatureException | InvalidKeySpecException e) {
			e.printStackTrace();
		}
	}

	/**
	 * Generates an RS256-signed client assertion, prints it together with the cURL
	 * command that exchanges it for a messaging token, and writes that command to
	 * {@code message_token_curl.sh} (relative to the working directory).
	 * Any failure (including reading the key file) is reported with a stack trace.
	 *
	 * @param privateKeyFileName file from which {@code RSAKeyHelper.getPrivateKeyFromFile} loads the private key
	 * @param id                 endpoint id, used as the JWT {@code iss} claim
	 * @param URL                iot server url the cURL command posts to
	 */
	public static void clientAssertionForMessaging(String privateKeyFileName, String id, String URL) {
		try {
			PrivateKey key = RSAKeyHelper.getPrivateKeyFromFile(privateKeyFileName);
			String clientAssertion = getClientAssertion(id, "RS256", key);
			String curl = printAssertionAndCurl(clientAssertion, MESSAGE_SCOPE, URL, MESSAGE_TOKEN_LOCATION);
			writeCurlToFile(curl, "message_token_curl.sh");
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

	/**
	 * Composes the cURL command for one token request and prints both the assertion
	 * and the command to standard output.
	 *
	 * @param clientAssertion signed JWT to send as {@code client_assertion}
	 * @param scope           value for the {@code scope} form field
	 * @param URL             url the command posts to
	 * @param location        file the response is tee'd into
	 * @return the complete cURL command line
	 */
	private static String printAssertionAndCurl(String clientAssertion, String scope, String URL, String location) {
		// Substitute SCOPE before ASSERTION: the base64url assertion could by chance
		// contain the text "SCOPE", and substituting it second would corrupt the token.
		String realData = replaceFirstLiteral(data, "SCOPE", scope);
		realData = replaceFirstLiteral(realData, "ASSERTION", clientAssertion);

		// Copy the shared prefix instead of appending to it, so repeated calls stay independent.
		String curl = new StringBuilder(curlCommand)
				.append(realData)
				.append(" \"").append(URL).append('"')
				.append(" | tee ").append(location)
				.toString();
		System.out.println("Client Assertion:\n" + clientAssertion + "\n");
		System.out.println("cURL command:\n" + curl + "\n");
		return curl;
	}

	/** Replaces the first literal (non-regex) occurrence of {@code placeholder} in {@code text}. */
	private static String replaceFirstLiteral(String text, String placeholder, String value) {
		int at = text.indexOf(placeholder);
		if (at < 0) {
			return text;
		}
		return text.substring(0, at) + value + text.substring(at + placeholder.length());
	}

	/**
	 * Builds a compact-serialized JWS:
	 * {@code base64url(header) "." base64url(claims) "." base64url(signature)}.
	 *
	 * @param id        value of the {@code iss} claim
	 * @param algorithm JWS algorithm name, {@code "HS256"} or {@code "RS256"}
	 * @param key       HMAC secret key for HS256, RSA private key for RS256
	 * @return the signed assertion
	 */
	private static String getClientAssertion(String id, String algorithm, Key key) throws InvalidKeyException,
			NoSuchAlgorithmException, SignatureException, InvalidKeySpecException {
		String signingInput = Base64Encode(toUTF8(getJWTHeader(algorithm)))
				+ "."
				+ Base64Encode(toUTF8(getJWTClaimSet(id)));
		String signature = Base64Encode(sign(toUTF8(signingInput), algorithm, key));
		return signingInput + "." + signature;
	}

	/**
	 * Encodes bytes as unpadded base64url (RFC 4648 section 5), the encoding JWS uses
	 * for each segment of a compact-serialized token.
	 *
	 * @param input bytes to encode
	 * @return base64url text without trailing {@code '='} padding
	 */
	public static String Base64Encode(byte[] input) {
		return Base64.getUrlEncoder().withoutPadding().encodeToString(input);
	}

	/**
	 * Returns the UTF-8 encoding of {@code input}.
	 *
	 * @param input text to encode
	 * @return UTF-8 bytes
	 */
	public static byte[] toUTF8(String input) {
		return input.getBytes(StandardCharsets.UTF_8);
	}

	/**
	 * Signs {@code material} with the JCA primitive matching the JWS algorithm name.
	 *
	 * @throws NoSuchAlgorithmException if {@code algorithm} is neither "RS256" nor "HS256",
	 *                                  or the provider lacks the underlying primitive
	 */
	private static byte[] sign(byte[] material, String algorithm, Key key) throws NoSuchAlgorithmException,
			InvalidKeyException, SignatureException, InvalidKeySpecException {
		switch (algorithm) {
			case "RS256": {
				Signature sig = Signature.getInstance("SHA256withRSA");
				sig.initSign((PrivateKey) key);
				sig.update(material);
				return sig.sign();
			}
			case "HS256": {
				Mac mac = Mac.getInstance("HmacSHA256");
				mac.init(key);
				return mac.doFinal(material);
			}
			default:
				// Previously any unknown name silently fell through to HMAC; fail loudly instead.
				throw new NoSuchAlgorithmException("Unsupported JWS algorithm: " + algorithm);
		}
	}

	/** Returns the JSON JOSE header {@code {"typ":"JWT","alg":algorithm}}. */
	private static String getJWTHeader(String algorithm) {
		return Json.createObjectBuilder()
				.add("typ", "JWT")
				.add("alg", algorithm)
				.build().toString();
	}

	/** Returns the JSON claim set with {@code iss}, {@code aud} and an {@code exp} 15 minutes from now. */
	private static String getJWTClaimSet(String id) {
		// long, not int: epoch seconds overflow a signed 32-bit int in January 2038.
		long expiryTime = System.currentTimeMillis() / 1000 + ASSERTION_LIFETIME_SECONDS;
		return Json.createObjectBuilder()
				.add("iss", id)
				.add("aud", TOKEN_AUDIENCE)
				.add("exp", expiryTime)
				.build().toString();
	}

	/**
	 * Writes {@code curl} as a single line to {@code fileName}, replacing any existing
	 * file. If the file cannot be opened, the error is reported with a stack trace.
	 */
	private static void writeCurlToFile(String curl, String fileName) {
		try (PrintWriter out = new PrintWriter(fileName)) {
			out.println(curl);
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		}
	}

}
