/***********************************************************
 * $Id$
 *
 * OAuth Login Services of the clazzes.org project
 * http://www.clazzes.org
 *
 * Created: 28.05.2017
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 ***********************************************************/

package org.clazzes.login.jbo.jwt;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.math.BigInteger;
import java.security.AlgorithmParameters;
import java.security.GeneralSecurityException;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.PublicKey;
import java.security.cert.Certificate;
import java.security.cert.CertificateEncodingException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.ECGenParameterSpec;
import java.security.spec.ECParameterSpec;
import java.security.spec.ECPoint;
import java.security.spec.ECPublicKeySpec;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.RSAPublicKeySpec;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.List;

import org.clazzes.login.jbo.common.Algorithm;
import org.clazzes.login.jbo.common.CurveType;
import org.clazzes.login.jbo.common.Helpers;
import org.clazzes.login.jbo.common.KeyOperation;
import org.clazzes.login.jbo.common.KeyType;
import org.clazzes.login.jbo.common.PubKeyInfo;
import org.clazzes.login.jbo.json.JsonHelper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.fasterxml.jackson.core.JacksonException;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.module.SimpleModule;

/**
 * Parser and formatter for JSON Web Keys ({@link JWKPubKey}) and JSON Web Key Sets
 * as defined in RFC 7517.
 */
public class JWKeyParser {

    private static final Logger log = LoggerFactory.getLogger(JWKeyParser.class);

    /**
     * The object mapper used by all parse/format methods of this class.
     * <p>
     * It has the {@link JWKPubKey} serializer and deserializer registered.
     * Jackson object mappers are thread-safe once their configuration is
     * complete, so one instance is shared instead of building a new mapper
     * (and module) on every call.
     */
    private static final ObjectMapper OBJECT_MAPPER = createObjectMapper();

    /**
     * Create and configure the shared object mapper.
     *
     * @return A mapper with the JWK serializer/deserializer registered that
     *         ignores unknown members of a JSON Web Key Set.
     */
    private static ObjectMapper createObjectMapper() {

        SimpleModule module = new SimpleModule();
        module.addSerializer(JWKPubKey.class,new JWKPubKeySerializer());
        module.addDeserializer(JWKPubKey.class,new JWKPubKeyDeserializer());

        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.registerModule(module);

        // RFC 7517, section 5: additional members of a JWK Set which are not
        // understood MUST be ignored, so members next to "keys" must not make
        // parsing of the key set fail.
        objectMapper.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES);

        return objectMapper;
    }

    /**
     * Writes a {@link JWKPubKey} as a single JSON Web Key object.
     */
    private static class JWKPubKeySerializer extends JsonSerializer<JWKPubKey> {

        @Override
        public Class<JWKPubKey> handledType() {

            return JWKPubKey.class;
        }

        /**
         * Serialize the given key.
         *
         * @throws IllegalArgumentException If the key type is neither RSA nor EC2,
         *         or a certificate of the chain cannot be encoded.
         */
        @Override
        public void serialize(JWKPubKey pubKey, JsonGenerator gen, SerializerProvider serializers) throws IOException {

            gen.writeStartObject();

            PubKeyInfo keyInfo = pubKey.getPubKeyInfo();

            // "kid" is OPTIONAL and must be a string if present (RFC 7517, section 4.5),
            // so omit it instead of writing a JSON null.
            if (keyInfo.getKeyId() != null) {
                gen.writeStringField("kid",keyInfo.getKeyId());
            }

            gen.writeStringField("kty",keyInfo.getKeyType().getJwkType());

            boolean mayVerify = false;
            boolean mayEncrypt = false;

            if (keyInfo.getKeyOperations() != null) {

                gen.writeFieldName("key_ops");

                gen.writeStartArray();

                for (KeyOperation ko : keyInfo.getKeyOperations()) {
                    gen.writeString(ko.toString());

                    if (ko == KeyOperation.verify) {
                        mayVerify = true;
                    }
                    if (ko == KeyOperation.encrypt || ko == KeyOperation.wrapKey) {
                        mayEncrypt = true;
                    }
                }

                gen.writeEndArray();
            }

            // "enc" only for pure encryption keys; everything else (including keys
            // without any key operations) is advertised as a signature key.
            // NOTE(review): RFC 7517, section 4.3 says "use" and "key_ops" SHOULD NOT
            // be used together, and the deserializer of this class logs a warning when
            // it sees both.
            gen.writeStringField("use",(mayEncrypt && !mayVerify) ? "enc" : "sig");

            // "alg" is OPTIONAL (RFC 7517, section 4.4); don't fail on a missing algorithm.
            if (keyInfo.getAlgorithm() != null) {
                gen.writeStringField("alg",keyInfo.getAlgorithm().toString());
            }

            if (keyInfo.getKeyType() == KeyType.RSA) {

                RSAPublicKey rsaPubKey = (RSAPublicKey)keyInfo.getPublicKey();

                gen.writeStringField("n",Helpers.formatPositiveBigInt(rsaPubKey.getModulus()));
                gen.writeStringField("e",Helpers.formatPositiveBigInt(rsaPubKey.getPublicExponent()));
            }
            else if (keyInfo.getKeyType() == KeyType.EC2) {

                ECPublicKey ecPubKey = (ECPublicKey)keyInfo.getPublicKey();

                gen.writeStringField("crv",keyInfo.getCurve().getJwkType());
                gen.writeStringField("x",Helpers.formatPositiveBigInt(ecPubKey.getW().getAffineX()));
                gen.writeStringField("y",Helpers.formatPositiveBigInt(ecPubKey.getW().getAffineY()));
            }
            else {

                throw new IllegalArgumentException("Invalid JSON Web Key type ["+keyInfo.getKeyType()+"].");
            }

            if (pubKey.getCertificateUrl() != null) {
                gen.writeStringField("x5u",pubKey.getCertificateUrl());
            }

            if (pubKey.getCertificateThumbprint() != null) {
                gen.writeStringField("x5t",Helpers.formatBase64(pubKey.getCertificateThumbprint()));
            }

            if (pubKey.getCertificateThumbprintSha256() != null) {
                gen.writeStringField("x5t#S256",Helpers.formatBase64(pubKey.getCertificateThumbprintSha256()));
            }

            if (pubKey.getCertificateChain() != null) {

                gen.writeFieldName("x5c");

                gen.writeStartArray();

                for (Certificate cert : pubKey.getCertificateChain()) {
                    try {
                        // as per https://tools.ietf.org/html/rfc7517#section-4.7 certificates are
                        // base64-encoded, but not in the URL-safe variant.
                        gen.writeString(Base64.getEncoder().encodeToString(cert.getEncoded()));
                    } catch (CertificateEncodingException e) {
                        throw new IllegalArgumentException("JWK public key with unencodable certificate in chain.",e);
                    }
                }

                gen.writeEndArray();
            }

            gen.writeEndObject();
        }
    }

    /**
     * Reads a single JSON Web Key object into a {@link JWKPubKey}.
     */
    private static class JWKPubKeyDeserializer extends JsonDeserializer<JWKPubKey> {

        @Override
        public Class<?> handledType() {
            return JWKPubKey.class;
        }

        /**
         * Read one JSON Web Key object, build the contained public key and
         * cross-check it against the optional X.509 certificate data.
         * <p>
         * Unknown members are logged and skipped. A missing algorithm defaults to
         * RS256 for RSA keys and to the curve's algorithm for EC keys. Missing
         * <code>key_ops</code> are derived from <code>use</code>.
         *
         * @throws IllegalArgumentException If the key is incomplete, has an unsupported
         *         key type or usage, or its certificate data is inconsistent.
         */
        @Override
        public JWKPubKey deserialize(JsonParser p, DeserializationContext dctxt) throws IOException, JacksonException {

            String ctxt = "JWKPubKey";

            JsonHelper.beginObject(p,ctxt);

            String keyId = null;
            KeyType keyType = null;
            CurveType curve = null;
            Algorithm algorithm = null;
            String usage = null;
            KeyOperation[] keyOperations = null;
            String certificateUrl = null;
            Certificate[] certificateChain = null;
            byte[] certificateThumbprint = null;
            byte[] certificateThumbprintSha256 = null;

            // ECC point coordinates.
            BigInteger x = null;
            BigInteger y = null;

            // RSA modulus/exponent.
            BigInteger modulus = null;
            BigInteger exponent = null;

            String name;

            while ((name = JsonHelper.nextName(p,ctxt)) != null) {

                if ("kid".equals(name)) {
                    keyId = JsonHelper.nextString(p,ctxt);
                }
                else if ("kty".equals(name)) {
                    keyType = KeyType.getByJwkType(JsonHelper.nextString(p,ctxt));
                }
                else if ("alg".equals(name)) {
                    String alg = JsonHelper.nextString(p,ctxt);

                    if ("RSA-OAEP".equals(alg)) {
                        // keycloak issues keys with algorithm RSA-OAEP, which we interpret as
                        // RSAES-OAEP-SHA-256 for the moment.
                        // NOTE(review): RFC 7518, section 4.3 defines "RSA-OAEP" with SHA-1/MGF1-SHA-1
                        // and "RSA-OAEP-256" with SHA-256; confirm this mapping is intended.
                        algorithm = Algorithm.RSAES_OAEP_SHA_256;
                    }
                    else {
                        algorithm = Algorithm.valueOf(alg);
                    }
                }
                else if ("use".equals(name)) {
                    usage = JsonHelper.nextString(p,ctxt);
                }
                else if ("crv".equals(name)) {
                    curve = CurveType.getByJwkType(JsonHelper.nextString(p,ctxt));
                }
                else if ("x5u".equals(name)) {
                    certificateUrl = JsonHelper.nextString(p,ctxt);
                }
                else if ("key_ops".equals(name)) {

                    List<KeyOperation> ops = new ArrayList<KeyOperation>();
                    JsonHelper.beginArray(p,ctxt);

                    String ko;

                    while ((ko=p.nextTextValue()) != null) {
                        ops.add(KeyOperation.valueOf(ko));
                    }

                    JsonHelper.endArray(p,ctxt);

                    keyOperations = ops.toArray(new KeyOperation[ops.size()]);
                }
                else if ("x5c".equals(name)) {

                    List<Certificate> certs = new ArrayList<Certificate>();

                    try {
                        CertificateFactory cf = CertificateFactory.getInstance("X.509");

                        JsonHelper.beginArray(p,ctxt);

                        String crt;

                        while ((crt=p.nextTextValue()) != null) {

                            byte[] encoded = Helpers.parseBase64(crt);

                            certs.add(cf.generateCertificate(new ByteArrayInputStream(encoded)));
                        }

                        JsonHelper.endArray(p,ctxt);

                    } catch (CertificateException e) {
                        throw new IllegalArgumentException("Unable to parse x5c certificate chain",e);
                    }

                    certificateChain = certs.toArray(new Certificate[certs.size()]);
                }
                else if ("x5t".equals(name)) {

                    byte[] sha1Thumb = Helpers.parseBase64(JsonHelper.nextString(p,ctxt));

                    if (sha1Thumb.length != 20) {
                        throw new IllegalArgumentException("x5t SHA1 certificate thumbprint has length ["+sha1Thumb.length+"], which is not 20.");
                    }

                    certificateThumbprint = sha1Thumb;
                }
                else if ("x5t#S256".equals(name)) {

                    byte[] sha256Thumb = Helpers.parseBase64(JsonHelper.nextString(p,ctxt));

                    if (sha256Thumb.length != 32) {
                        throw new IllegalArgumentException("x5t#S256 SHA256 certificate thumbprint has length ["+sha256Thumb.length+"], which is not 32.");
                    }

                    certificateThumbprintSha256 = sha256Thumb;
                }
                else if ("e".equals(name)) {
                    exponent = Helpers.parsePositiveBigInt(JsonHelper.nextString(p,ctxt));
                }
                else if ("n".equals(name)) {
                    modulus = Helpers.parsePositiveBigInt(JsonHelper.nextString(p,ctxt));
                }
                else if ("x".equals(name)) {
                    x = Helpers.parsePositiveBigInt(JsonHelper.nextString(p,ctxt));
                }
                else if ("y".equals(name)) {
                    y = Helpers.parsePositiveBigInt(JsonHelper.nextString(p,ctxt));
                }
                else {
                    // unknown members are skipped including any nested structure.
                    log.warn("Invalid attribute [{}] in JSON Web Key.",name);
                    p.nextToken();
                    p.skipChildren();
                }
            }

            JsonHelper.endObject(p,ctxt);

            PublicKey pubKey;

            if (keyType == KeyType.RSA) {

                pubKey = createRsaPublicKey(modulus,exponent);

                if (algorithm == null) {
                    algorithm = Algorithm.RS256;
                }
            }
            else if (keyType == KeyType.EC2) {

                // createEcPublicKey() rejects a missing curve, so curve is non-null below.
                pubKey = createEcPublicKey(curve,x,y);

                if (algorithm == null) {
                    algorithm = curve.getAlgorithm();
                }
            }
            else {

                throw new IllegalArgumentException("Invalid JSON Web Key type ["+keyType+"].");
            }

            if (certificateChain != null && certificateChain.length == 0) {
                throw new IllegalArgumentException("Public JSON Web Key has an empty X.509 certificate chain.");
            }

            // SHA-1 thumbprint: derive it from the leaf certificate if absent,
            // otherwise verify it against the leaf certificate.
            if (certificateThumbprint == null) {

                if (certificateChain != null) {
                    certificateThumbprint = Helpers.getSha1Fingerprint(certificateChain[0]);
                }
            }
            else {

                if (certificateChain == null && certificateUrl == null) {
                    throw new IllegalArgumentException("Public JSON Web Key has an empty X.509 certificate chain but a given thumbprint.");
                }

                if (certificateChain != null) {
                    byte[] check = Helpers.getSha1Fingerprint(certificateChain[0]);

                    if (!Arrays.equals(certificateThumbprint,check)) {
                        throw new IllegalArgumentException("Public JSON Web Key has a X.509 certificate and a differing x5t thumbprint.");
                    }
                }
            }

            // SHA-256 thumbprint: same handling as the SHA-1 thumbprint above.
            if (certificateThumbprintSha256 == null) {

                if (certificateChain != null) {
                    certificateThumbprintSha256 = Helpers.getSha256Fingerprint(certificateChain[0]);
                }
            }
            else {

                if (certificateChain == null && certificateUrl == null) {
                    throw new IllegalArgumentException("Public JSON Web Key has an empty X.509 certificate chain but a given SHA-256 thumbprint.");
                }

                if (certificateChain != null) {
                    byte[] check = Helpers.getSha256Fingerprint(certificateChain[0]);

                    if (!Arrays.equals(certificateThumbprintSha256,check)) {
                        throw new IllegalArgumentException("Public JSON Web Key has a X.509 certificate and a differing x5t#S256 thumbprint.");
                    }
                }
            }

            // "key_ops" takes precedence over "use"; without "key_ops" the
            // operations are derived from "use", which defaults to "sig".
            if (keyOperations == null) {

                if ("enc".equals(usage)) {
                    keyOperations = new KeyOperation[] { KeyOperation.encrypt, KeyOperation.wrapKey };
                }
                else if (usage == null || "sig".equals(usage)) {
                    keyOperations = new KeyOperation[] { KeyOperation.verify };
                }
                else {
                    throw new IllegalArgumentException("JSON Web Key has invalid key usage ["+usage+"]");
                }
            }
            else if (usage != null) {

                log.warn("JSON Web Key has key usage [{}], which will be ignored because of the presence of key_ops.",usage);
            }

            if (certificateChain != null) {

                PublicKey check = certificateChain[0].getPublicKey();

                if (!Helpers.pubKeysAreEqual(pubKey,check)) {
                    throw new IllegalArgumentException("Public JSON Web Key has a X.509 certificate with a non-matching public key.");
                }
            }

            PubKeyInfo keyInfo = new PubKeyInfo(keyId,keyType,curve,algorithm,keyOperations,pubKey);

            return new JWKPubKey(keyInfo,certificateUrl,
                    certificateChain,certificateThumbprint,certificateThumbprintSha256);
        }

        /**
         * Build an RSA public key from its JWK modulus and exponent.
         *
         * @param modulus The RSA modulus ("n"), may be <code>null</code> if absent.
         * @param exponent The RSA public exponent ("e"), may be <code>null</code> if absent.
         * @return The instantiated RSA public key.
         * @throws IllegalArgumentException If a parameter is missing or the key
         *         cannot be instantiated.
         */
        private static PublicKey createRsaPublicKey(BigInteger modulus, BigInteger exponent) {

            if (exponent == null || modulus == null) {
                throw new IllegalArgumentException("Public RSA JSON Web Key with missing modulus or exponent.");
            }

            try {
                KeyFactory kf = KeyFactory.getInstance("RSA");

                return kf.generatePublic(new RSAPublicKeySpec(modulus,exponent));

            } catch (NoSuchAlgorithmException | InvalidKeySpecException e) {
                throw new IllegalArgumentException("Public RSA JSON Web Key could not be instantiated.",e);
            }
        }

        /**
         * Build an EC public key from its JWK curve and point coordinates.
         *
         * @param curve The curve ("crv"), may be <code>null</code> if absent.
         * @param x The affine x coordinate, may be <code>null</code> if absent.
         * @param y The affine y coordinate, may be <code>null</code> if absent.
         * @return The instantiated EC public key.
         * @throws IllegalArgumentException If a parameter is missing or the key
         *         cannot be instantiated.
         */
        private static PublicKey createEcPublicKey(CurveType curve, BigInteger x, BigInteger y) {

            if (x == null || y == null) {
                throw new IllegalArgumentException("EC JSON Web Key with missing x or y base.");
            }

            // reject a missing curve explicitly instead of failing with a
            // NullPointerException hidden behind a generic error.
            if (curve == null) {
                throw new IllegalArgumentException("EC JSON Web Key with missing curve [crv].");
            }

            try {
                AlgorithmParameters ap = AlgorithmParameters.getInstance("EC");
                ap.init(new ECGenParameterSpec(curve.getJceType()));

                ECParameterSpec ecps = ap.getParameterSpec(ECParameterSpec.class);

                ECPublicKeySpec ecpks = new ECPublicKeySpec(new ECPoint(x,y),ecps);

                KeyFactory kf = KeyFactory.getInstance("EC");

                return kf.generatePublic(ecpks);

            } catch (GeneralSecurityException e) {
                throw new IllegalArgumentException("Public EC JSON Web Key could not be instantiated.",e);
            }
        }
    }

    /**
     * Jackson binding for a JSON Web Key Set, i.e. a JSON object holding the
     * list of keys under the member <code>keys</code>.
     */
    private static class JWKPubKeyListHolder implements Serializable {

        private static final long serialVersionUID = -8420183466318633281L;

        private List<JWKPubKey> keys;

        public List<JWKPubKey> getKeys() {
            return this.keys;
        }

        public void setKeys(List<JWKPubKey> keys) {
            this.keys = keys;
        }
    }

    /**
     * Parse a single JSON Web public key from the given input stream.
     * <p>
     * With Jackson's default settings, which this class does not change, the
     * stream is closed after reading. Invalid key contents are reported by the
     * deserializer as {@link IllegalArgumentException}. Jackson may wrap it into
     * a {@code JsonMappingException}.
     *
     * @param is An input stream containing a single JSON Web Key object.
     * @return The parsed JWT public key.
     * @throws IOException Upon I/O or JSON parse errors.
     */
    public static JWKPubKey parsePubKey(InputStream is) throws IOException {

        return OBJECT_MAPPER.readValue(is,JWKPubKey.class);
    }

    /**
     * Format a JSON Web public key and write it to the given output stream.
     * <p>
     * With Jackson's default settings, which this class does not change, the
     * stream is closed after writing.
     *
     * @param os An output stream to write a single JSON Web Key object to.
     * @param pubKey A JWT public key to write to the output stream.
     * @throws IOException Upon I/O errors or if the key cannot be serialized.
     */
    public static void formatPubKey(OutputStream os, JWKPubKey pubKey) throws IOException {

        OBJECT_MAPPER.writeValue(os,pubKey);
    }

    /**
     * Parse a JSON Web Key Set from the given input stream.
     * <p>
     * Top-level members other than <code>keys</code> are ignored, as required by
     * RFC 7517, section 5.
     *
     * @param is An input stream containing a single JSON object containing a list
     *           of Web Key objects under key <code>keys</code>.
     * @return The parsed list of JWT public keys, or <code>null</code> if the
     *         document is the JSON literal <code>null</code> or has no
     *         <code>keys</code> member.
     * @throws IOException Upon I/O or JSON parse errors, including invalid keys
     *         in the list.
     */
    public static List<JWKPubKey> parsePubKeyList(InputStream is) throws IOException {

        JWKPubKeyListHolder holder = OBJECT_MAPPER.readValue(is,JWKPubKeyListHolder.class);

        // readValue() yields null for a JSON null document.
        return holder == null ? null : holder.getKeys();
    }

    /**
     * Format a list of JSON Web public keys as a JSON Web Key Set and write it
     * to the given output stream.
     *
     * @param os An output stream to write a single JSON object to, which holds
     *           the keys in a list under key <code>keys</code>.
     * @param pubKeyList A list of JWT public keys to write to the output stream.
     * @throws IOException Upon I/O errors or if a key cannot be serialized.
     */
    public static void formatPubKeyList(OutputStream os, List<JWKPubKey> pubKeyList) throws IOException {

        JWKPubKeyListHolder holder = new JWKPubKeyListHolder();
        holder.setKeys(pubKeyList);

        OBJECT_MAPPER.writeValue(os,holder);
    }
}
