/***********************************************************
 * $Id$
 *
 * JSON CBOR Login Tools of the clazzes.org project
 * http://www.clazzes.org
 *
 * Created: 04.06.2020
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 ***********************************************************/

package org.clazzes.login.jbo.u2f.impl;

import java.io.IOException;
import java.io.InputStream;
import java.io.UnsupportedEncodingException;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.cert.CertPathBuilder;
import java.security.cert.CertPathBuilderResult;
import java.security.cert.CertPathValidator;
import java.security.cert.CertStore;
import java.security.cert.CertificateEncodingException;
import java.security.cert.CollectionCertStoreParameters;
import java.security.cert.PKIXBuilderParameters;
import java.security.cert.PKIXParameters;
import java.security.cert.TrustAnchor;
import java.security.cert.X509CertSelector;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.UUID;

import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509TrustManager;

import org.clazzes.login.jbo.bc.BCTools;
import org.clazzes.login.jbo.common.Algorithm;
import org.clazzes.login.jbo.common.ExtensionFormat;
import org.clazzes.login.jbo.common.IExtensionConsumer;
import org.clazzes.login.jbo.common.PubKeyInfo;
import org.clazzes.login.jbo.u2f.AttestationCertInfo;
import org.clazzes.login.jbo.u2f.AttestationCertValidator;
import org.clazzes.login.jbo.u2f.AttestationResult;
import org.clazzes.login.jbo.u2f.AttestationType;
import org.clazzes.login.jbo.u2f.DeviceInfo;
import org.clazzes.login.jbo.u2f.DeviceRegistry;
import org.clazzes.login.jbo.u2f.json.DeviceRegistryParser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Certificate validation engine for attestation certificates.
 * <p>
 * Attestation certificate chains are validated against the trust anchors of the
 * first registered {@link DeviceRegistry} (in registration order) that recognizes
 * the leaf certificate. Server certificate chains are validated against the trust
 * anchors of the platform's default PKIX trust manager. Revocation checking is
 * disabled in both cases.
 * <p>
 * Every validation call obtains its own {@link CertPathBuilder} and
 * {@link CertPathValidator}, because the JDK documents instances of these classes
 * as not safe for unsynchronized concurrent use. The registry maps themselves are
 * not synchronized, so registries should be added before an instance is shared
 * between threads.
 */
public class AttestationCertValidatorImpl implements AttestationCertValidator {

    private static final Logger log = LoggerFactory.getLogger(AttestationCertValidatorImpl.class);

    /** Algorithm name used for trust manager lookup, path building and path validation. */
    private static final String PKIX = "PKIX";

    /** Classpath resource (relative to this class) holding the bundled Yubico device metadata. */
    private static final String YUBICO_METADATA_RESOURCE = "yubico-metadata.json";

    /** Device registry parsed from the bundled Yubico metadata. */
    private static final DeviceRegistry yubicoRegistry;

    /** Trust anchors of the platform default trust manager, used for server certificate chains. */
    private static final Set<TrustAnchor> jceTrustAnchors;

    /** OIDs of the certificate extensions extracted by {@link #parseCertificate}, with their format. */
    private static final Map<String,ExtensionFormat> extensionFormats;

    static {
        extensionFormats = new HashMap<String,ExtensionFormat>();
        extensionFormats.put(AttestationCertInfo.idFidoGenCeAaguid,ExtensionFormat.OCTET_STRING);
        extensionFormats.put(AttestationCertInfo.idFidoU2FTransports,ExtensionFormat.BIT_STRING);
        extensionFormats.put(AttestationCertInfo.idYubicoTokenType,ExtensionFormat.RAW_DATA);

        yubicoRegistry = loadYubicoRegistry();
        jceTrustAnchors = loadJceTrustAnchors();
    }

    /**
     * Parses the device registry bundled as {@value #YUBICO_METADATA_RESOURCE}.
     *
     * @return The parsed registry.
     * @throws RuntimeException if the resource is missing or cannot be parsed.
     */
    private static DeviceRegistry loadYubicoRegistry() {

        InputStream in = AttestationCertValidatorImpl.class.getResourceAsStream(YUBICO_METADATA_RESOURCE);

        if (in == null) {
            throw new RuntimeException("Yubico device registry resource ["+YUBICO_METADATA_RESOURCE+"] could not be found.");
        }

        try {
            return DeviceRegistryParser.parseJson(in);
        } catch (IOException e) {
            throw new RuntimeException("Yubico device registry could not be instantiated.",e);
        } finally {
            // Always release the classpath stream, whether or not parsing succeeded.
            try {
                in.close();
            } catch (IOException e) {
                log.warn("Closing Yubico device registry resource failed.",e);
            }
        }
    }

    /**
     * Collects the accepted issuers of all X.509 trust managers of a PKIX
     * {@link TrustManagerFactory} initialized with the default key store.
     *
     * @return The trust anchors of the platform default trust store.
     * @throws RuntimeException if the trust manager factory cannot be set up.
     */
    private static Set<TrustAnchor> loadJceTrustAnchors() {

        try {
            TrustManagerFactory tmf = TrustManagerFactory.getInstance(PKIX);
            // A null key store selects the platform default trust store.
            tmf.init((KeyStore)null);

            Set<TrustAnchor> anchors = new HashSet<TrustAnchor>();

            for (TrustManager tm : tmf.getTrustManagers()) {
                if (tm instanceof X509TrustManager) {
                    for (X509Certificate ca : ((X509TrustManager) tm).getAcceptedIssuers()) {
                        anchors.add(new TrustAnchor(ca,null));
                    }
                }
            }
            return anchors;

        } catch (Exception e) {
            throw new RuntimeException("PKIX TrustManagerFactory could not be instantiated.",e);
        }
    }

    /** Registered device registries keyed by registry identifier; iteration follows registration order. */
    private final Map<UUID,DeviceRegistry> deviceRegistries;

    /** Trust anchors built from each registry's trusted certificates, keyed by registry identifier. */
    private final Map<UUID,Set<TrustAnchor>> trustAnchorsByRegistryUUID;

    /** Validation time for server certificate chains; {@code null} means the current time. */
    private Date validationDate;

    /**
     * Creates a validator preloaded with the bundled Yubico device registry.
     */
    public AttestationCertValidatorImpl() {
        this(true);
    }

    /**
     * @param loadDefaultRegistries Whether to register the bundled Yubico device registry.
     */
    public AttestationCertValidatorImpl(boolean loadDefaultRegistries) {

        this.deviceRegistries = new LinkedHashMap<UUID,DeviceRegistry>();
        this.trustAnchorsByRegistryUUID = new HashMap<UUID,Set<TrustAnchor>>();

        if (loadDefaultRegistries) {
            this.addDeviceRegistry(yubicoRegistry);
        }
    }

    /**
     * Registers a device registry. Its trusted certificates become the trust anchors
     * for attestation chains whose leaf certificate the registry recognizes. A registry
     * with an already known identifier replaces the previous one.
     *
     * @param deviceRegistry The registry to add.
     */
    public void addDeviceRegistry(DeviceRegistry deviceRegistry) {

        this.deviceRegistries.put(deviceRegistry.getIdentifier(),deviceRegistry);

        Set<TrustAnchor> trustAnchors = new HashSet<TrustAnchor>();

        for (X509Certificate cert : deviceRegistry.getTrustedCertificates()) {
            trustAnchors.add(new TrustAnchor(cert,null));
        }

        this.trustAnchorsByRegistryUUID.put(deviceRegistry.getIdentifier(),trustAnchors);
    }

    /**
     * Extracts the public key and the FIDO/Yubico extensions of an attestation certificate.
     *
     * @param algorithm An optional algorithm hint, mostly used for RSA keys to specify the padding used for signatures.
     * @param cert The certificate to parse.
     * @return A parsed certificate information.
     * @throws CertificateEncodingException if {@link BCTools#getPubKeyInfo} cannot decode the certificate.
     */
    @Override
    public AttestationCertInfo parseCertificate(Algorithm algorithm, X509Certificate cert) throws CertificateEncodingException {

        final Map<String,Object> extensionValues = new HashMap<String,Object>();
        // The subject DN is used as key identifier of the extracted public key.
        String keyId = cert.getSubjectX500Principal().getName();

        PubKeyInfo pubKeyInfo = BCTools.getPubKeyInfo(algorithm,keyId,cert,new IExtensionConsumer() {

            @Override
            public void consumeOctetString(String oid, byte[] data) {
                if (AttestationCertInfo.idFidoGenCeAaguid.equals(oid)) {
                    extensionValues.put(oid,data);
                }
                else if (AttestationCertInfo.idYubicoTokenType.equals(oid)) {
                    try {
                        extensionValues.put(oid,new String(data,"US-ASCII"));
                    } catch (UnsupportedEncodingException e) {
                        throw new RuntimeException("US-ASCII not supported.",e);
                    }
                }
            }

            @Override
            public void consumeBitString(String oid, byte[] data, int padBits) {
                if (AttestationCertInfo.idFidoU2FTransports.equals(oid)) {
                    if (data == null || data.length == 0) {
                        // An empty BIT STRING carries no transport flags; skip it instead of
                        // failing with an ArrayIndexOutOfBoundsException.
                        return;
                    }
                    // BIT STRING bit 0 is the most significant bit of the first byte;
                    // reversing moves it to bit 0 of the resulting int. Only the first
                    // byte is evaluated.
                    extensionValues.put(oid,Integer.reverse((data[0] & 0xff) << 24));
                }
            }

            @Override
            public Iterator<String> getSupportedOids() {
                return extensionFormats.keySet().iterator();
            }

            @Override
            public ExtensionFormat getExtensionFormat(String oid) {
                return extensionFormats.get(oid);
            }
        });

        return new AttestationCertInfo(extensionValues,pubKeyInfo);
    }

    /**
     * Validates an attestation certificate chain.
     * <p>
     * The leaf certificate {@code chain[0]} is parsed and offered to the registered
     * device registries. The first registry that recognizes it supplies the trust
     * anchors for PKIX path building and validation without revocation checks.
     *
     * @param algorithm Optional algorithm hint passed to {@link #parseCertificate}.
     * @param chain The attestation chain, leaf certificate first.
     * @return The attestation result for the recognized device.
     * @throws IllegalArgumentException if the chain is {@code null} or empty.
     * @throws SecurityException if no registry recognizes the leaf certificate.
     * @throws Exception if path building or validation fails.
     */
    @Override
    public AttestationResult validateCertificateChain(Algorithm algorithm, X509Certificate[] chain) throws Exception {

        if (chain == null || chain.length < 1) {
            throw new IllegalArgumentException("Attestation certificate chain is empty.");
        }

        AttestationCertInfo aci = this.parseCertificate(algorithm,chain[0]);

        DeviceInfo di = null;
        Set<TrustAnchor> trustAnchors = null;
        DeviceRegistry registry = null;

        for (Entry<UUID,DeviceRegistry> e : this.deviceRegistries.entrySet()) {

            di = e.getValue().selectDevice(aci);

            if (di != null) {
                trustAnchors = this.trustAnchorsByRegistryUUID.get(e.getKey());
                registry = e.getValue();
                break;
            }
        }

        if (di == null || trustAnchors == null) {
            throw new SecurityException("Device for attestation certificate ["+aci+"] could not be found.");
        }

        log.info("Found device [{}] for attestation certificate [{}] in registry [{}].",
                new Object[] {di.getDisplayName(),aci,registry.getVendorInfo().getName()});

        X509CertSelector cs = new X509CertSelector();
        cs.setCertificate(chain[0]);

        // NOTE(review): the configured validation date is not applied here (null = current time),
        // unlike in validateServerCertificateChain; confirm whether this is intended.
        buildAndValidatePath(trustAnchors,cs,chain,null);

        log.info("Successfully validated attestion certificate chain for device [{}].",
                di.getDisplayName());

        return new AttestationResult(aci,registry.getIdentifier(),di,registry.getVendorInfo(),AttestationType.CERTIFICATE);
    }

    /**
     * Validates a server certificate chain against the platform default trust anchors.
     * <p>
     * The target certificate is the certificate in {@code chain} that carries
     * {@code fqdn} as a dNSName subjectAlternativeName. The chain is placed in a
     * collection cert store, so certificate order does not matter. Validation uses
     * {@link #getValidationDate()} and no revocation checks.
     * NOTE(review): X509CertSelector compares names literally, so wildcard
     * certificates presumably do not match; confirm if wildcard support is needed.
     *
     * @param chain The server certificate chain.
     * @param fqdn The expected DNS name of the server.
     * @throws IllegalArgumentException if {@code chain} is {@code null} or empty, or {@code fqdn} is {@code null}.
     * @throws Exception if path building or validation fails.
     */
    @Override
    public void validateServerCertificateChain(X509Certificate[] chain, String fqdn) throws Exception {

        if (chain == null || chain.length < 1) {
            throw new IllegalArgumentException("Server certificate chain is empty.");
        }
        if (fqdn == null) {
            throw new IllegalArgumentException("Server FQDN must not be null.");
        }

        X509CertSelector cs = new X509CertSelector();

        // dNSName acc. to https://tools.ietf.org/html/rfc5280#section-4.2.1.6
        cs.addSubjectAlternativeName(2,fqdn);

        buildAndValidatePath(jceTrustAnchors,cs,chain,this.validationDate);
    }

    /**
     * Builds a PKIX path from the certificates of {@code chain} to one of
     * {@code trustAnchors} and validates it, both with revocation checking disabled.
     * Fresh builder and validator instances are used per call because those
     * classes are not safe for unsynchronized concurrent use.
     *
     * @param trustAnchors The trust anchors to build the path to.
     * @param targetSelector Selects the target (end entity) certificate.
     * @param chain The certificates available for path building.
     * @param date The validation time, or {@code null} for the current time.
     * @throws GeneralSecurityException if no valid path can be built or validated.
     */
    private static void buildAndValidatePath(Set<TrustAnchor> trustAnchors, X509CertSelector targetSelector,
            X509Certificate[] chain, Date date) throws GeneralSecurityException {

        CertStore store = CertStore.getInstance("Collection",new CollectionCertStoreParameters(Arrays.asList(chain)));

        PKIXBuilderParameters pbp = new PKIXBuilderParameters(trustAnchors,targetSelector);
        pbp.addCertStore(store);
        pbp.setRevocationEnabled(false);
        pbp.setDate(date);

        CertPathBuilderResult certPath = CertPathBuilder.getInstance(PKIX).build(pbp);

        PKIXParameters params = new PKIXParameters(trustAnchors);
        params.addCertStore(store);
        params.setTargetCertConstraints(targetSelector);
        params.setRevocationEnabled(false);
        params.setDate(date);

        CertPathValidator.getInstance(PKIX).validate(certPath.getCertPath(),params);
    }

    /**
     * @return A copy of the time for which the validity of server certification
     * paths is determined, or {@code null} if the current time is used.
     */
    public Date getValidationDate() {
        return this.validationDate == null ? null : new Date(this.validationDate.getTime());
    }

    /**
     * Sets the time for which the validity of server certification paths
     * should be determined. If {@code null}, the current time is used.
     * <p>
     * Note that the {@code Date} supplied here is copied to protect
     * against subsequent modifications.
     *
     * @param validationDate the {@code Date}, or {@code null} for the
     * current time
     * @see #getValidationDate()
     */
    public void setValidationDate(Date validationDate) {
        this.validationDate = validationDate == null ? null : new Date(validationDate.getTime());
    }

}
