/***********************************************************
 * $Id$
 *
 * JSON CBOR Login Tools of the clazzes.org project
 * http://www.clazzes.org
 *
 * Created: 21.05.2020
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 ***********************************************************/

package org.clazzes.login.jbo.bc;

import java.math.BigInteger;
import java.security.cert.CertificateEncodingException;
import java.security.cert.X509Certificate;
import java.security.spec.ECPoint;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.bouncycastle.asn1.ASN1BitString;
import org.bouncycastle.asn1.ASN1Encodable;
import org.bouncycastle.asn1.ASN1ObjectIdentifier;
import org.bouncycastle.asn1.ASN1OctetString;
import org.bouncycastle.asn1.nist.NISTNamedCurves;
import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers;
import org.bouncycastle.asn1.x509.Certificate;
import org.bouncycastle.asn1.x509.Extension;
import org.bouncycastle.asn1.x509.KeyUsage;
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo;
import org.bouncycastle.asn1.x509.TBSCertificate;
import org.bouncycastle.asn1.x9.X9ECParameters;
import org.bouncycastle.asn1.x9.X9ObjectIdentifiers;
import org.bouncycastle.math.ec.ECCurve;
import org.bouncycastle.util.BigIntegers;
import org.clazzes.login.jbo.common.Algorithm;
import org.clazzes.login.jbo.common.CurveType;
import org.clazzes.login.jbo.common.IExtensionConsumer;
import org.clazzes.login.jbo.common.KeyOperation;
import org.clazzes.login.jbo.common.KeyType;
import org.clazzes.login.jbo.common.PubKeyInfo;

/**
 * Bouncycastle tools for JBO Login Tools.
 *
 * <p>Provides helpers to decompress and encode points on NIST elliptic curves
 * and to extract a {@link PubKeyInfo} from an X509 certificate.</p>
 */
public abstract class BCTools {

    /**
     * Key operations assumed for certificates without a key usage extension.
     * Arrays are mutable, so this array is never handed out directly; callers
     * always receive a copy.
     */
    private static final KeyOperation[] defaultKeyOperations = new KeyOperation[] {KeyOperation.verify};

    // copied over from bouncycastle, because the package
    // org.bouncycastle.asn1.edec.EdECObjectIdentifiers has been moved to
    // org.bouncycastle.internal.asn1.edec.EdECObjectIdentifiers with bc-1.78
    // without any notice.
    private static final ASN1ObjectIdentifier id_edwards_curve_algs = new ASN1ObjectIdentifier("1.3.101");
    private static final ASN1ObjectIdentifier id_Ed25519 = id_edwards_curve_algs.branch("112").intern();
    private static final ASN1ObjectIdentifier id_Ed448 = id_edwards_curve_algs.branch("113").intern();

    /**
     * Look up the NIST named curve parameters for the given curve type.
     *
     * @param ct The curve type.
     * @return The curve parameters, never <code>null</code>.
     * @throws IllegalArgumentException If <code>ct</code> does not denote a NIST named curve
     *         (for example an Edwards curve).
     */
    private static X9ECParameters getNistCurveParameters(CurveType ct) {

        X9ECParameters p = NISTNamedCurves.getByOID(new ASN1ObjectIdentifier(ct.getOid()));

        // NISTNamedCurves.getByOID() returns null for OIDs it does not know.
        if (p == null) {
            throw new IllegalArgumentException("Curve ["+ct+"] is not a NIST named curve.");
        }
        return p;
    }

    /**
     * Decompress a compressed elliptic curve point
     *
     * @param ct The curve type, which must denote a NIST named curve.
     * @param x The x coordinate of the point. If <code>null</code> is passed in,
     *         {@link ECPoint#POINT_INFINITY} is returned.
     * @param ytilde If <code>true</code>, the y value is odd, otherwise it is even.
     *               This parameter is ignored, if x is <code>null</code>.
     * @return The decompressed point.
     * @throws IllegalArgumentException If <code>ct</code> is not a NIST named curve.
     *         Bouncycastle may also raise this exception if <code>x</code> does not
     *         fit into the field or does not yield a valid curve point.
     */
    public static final ECPoint decompressPoint(CurveType ct, BigInteger x, boolean ytilde) {

        if (x == null) {
            return ECPoint.POINT_INFINITY;
        }

        ECCurve curve = getNistCurveParameters(ct).getCurve();

        // Build a compressed point encoding (0x02/0x03 followed by X) and let
        // ECCurve.decodePoint() decompress it, because ECCurve.decompressPoint() is protected.
        int fieldLength = (curve.getFieldSize() + 7) / 8;

        byte[] encoded = new byte[fieldLength + 1];
        encoded[0] = (byte)(ytilde ? 0x03 : 0x02);
        System.arraycopy(BigIntegers.asUnsignedByteArray(fieldLength, x),0,encoded,1,fieldLength);

        org.bouncycastle.math.ec.ECPoint bcPoint = curve.decodePoint(encoded);

        return new ECPoint(bcPoint.getAffineXCoord().toBigInteger(),bcPoint.getAffineYCoord().toBigInteger());
    }

    /**
     * Encode an elliptic curve point in uncompressed format
     * (<code>0x04</code> followed by X and Y, each padded to the field length).
     *
     * @param ct The curve type, which must denote a NIST named curve.
     * @param point The EC point to encode. <code>null</code> and
     *        {@link ECPoint#POINT_INFINITY} are encoded as a single zero byte.
     * @return The encoded point in uncompressed format.
     * @throws IllegalArgumentException If <code>ct</code> is not a NIST named curve.
     */
    public static final byte[] encodePoint(CurveType ct, ECPoint point) {

        if (point == null || point.equals(ECPoint.POINT_INFINITY)) {
            // The point at infinity is encoded as a single 0x00 byte.
            return new byte[1];
        }

        ECCurve curve = getNistCurveParameters(ct).getCurve();

        // Both coordinates are left-padded to the byte length of the field.
        int fieldLength = (curve.getFieldSize() + 7) / 8;

        byte[] encoded = new byte[2*fieldLength + 1];
        encoded[0] = (byte)0x04;
        System.arraycopy(BigIntegers.asUnsignedByteArray(fieldLength,point.getAffineX()),0,encoded,1,fieldLength);
        System.arraycopy(BigIntegers.asUnsignedByteArray(fieldLength,point.getAffineY()),0,encoded,fieldLength+1,fieldLength);

        return encoded;
    }

    /**
     * Extract all public key information from an X509 Certificate.
     *
     * @param algorithm An optional algorithm hint, mostly used for RSA keys to specify the padding used for signatures.
     * @param keyId The key ID
     * @param certificate An X509 certificate to parse.
     * @return A public key info with correct curve, algorithm and key type information.
     * @throws CertificateEncodingException If the given certificate could not be encoded.
     */
    public static PubKeyInfo getPubKeyInfo(Algorithm algorithm, String keyId, X509Certificate certificate) throws CertificateEncodingException {

        return getPubKeyInfo(algorithm,keyId,certificate,null);
    }

    /**
     * Extract all public key information from an X509 Certificate.
     *
     * @param algorithm An optional algorithm hint, mostly used for RSA keys to specify the padding used for signatures.
     * @param keyId The key ID
     * @param certificate An X509 certificate to parse.
     * @param extensionConsumer An optional consumer for extension values.
     * @return A public key info with correct curve, algorithm and key type information.
     * @throws CertificateEncodingException If the given certificate could not be encoded.
     * @throws IllegalArgumentException If the public key algorithm or EC curve is not supported,
     *         or if <code>algorithm</code> is not compatible with the certificate's key.
     */
    public static PubKeyInfo getPubKeyInfo(Algorithm algorithm, String keyId, X509Certificate certificate, IExtensionConsumer extensionConsumer) throws CertificateEncodingException {

        Certificate cert = Certificate.getInstance(certificate.getEncoded());

        TBSCertificate tbs = cert.getTBSCertificate();

        // getExtensions() returns null for certificates carrying no extensions
        // at all (e.g. X509 v1 certificates), so guard every access.
        boolean hasExtensions = tbs.getExtensions() != null;

        KeyUsage ku = hasExtensions ? KeyUsage.fromExtensions(tbs.getExtensions()) : null;

        SubjectPublicKeyInfo spki = tbs.getSubjectPublicKeyInfo();

        ASN1ObjectIdentifier algorithmId = spki.getAlgorithm().getAlgorithm();

        KeyType keyType;
        CurveType curveType;

        if (X9ObjectIdentifiers.id_ecPublicKey.equals(algorithmId)) {

            ASN1Encodable params = spki.getAlgorithm().getParameters();

            if (!(params instanceof ASN1ObjectIdentifier)) {
                throw new IllegalArgumentException("EC public key without a curve parameter sepcified.");
            }

            keyType = KeyType.EC2;
            curveType = CurveType.getByOid(params.toString());

            // Guard against getByOid() reporting unknown curves by returning null.
            if (curveType == null) {
                throw new IllegalArgumentException("EC public key with unsupported curve ["+params+"].");
            }

            if (algorithm == null) {
                algorithm = curveType.getAlgorithm();
            }
            else if (algorithm != curveType.getAlgorithm()) {
                throw new IllegalArgumentException("Algorithm ["+algorithm+"] is not compatible with EC curve ["+curveType+"].");
            }
        }
        else if (PKCSObjectIdentifiers.rsaEncryption.equals(algorithmId)) {

            keyType = KeyType.RSA;
            curveType = null;

            if (algorithm == null) {
                algorithm = Algorithm.RS256;
            }
            else if (algorithm.getKeyType() != KeyType.RSA) {
                throw new IllegalArgumentException("Algorithm ["+algorithm+"] is not compatible with RSA public key.");
            }
        }
        else if (id_Ed25519.equals(algorithmId)) {

            keyType = KeyType.OKP;
            curveType = CurveType.Ed25519;
            if (algorithm == null) {
                algorithm = Algorithm.EdDSA;
            }
            else if (algorithm != Algorithm.EdDSA) {
                throw new IllegalArgumentException("Algorithm ["+algorithm+"] is not compatible with EdDSA curve ["+curveType+"].");
            }
        }
        else if (id_Ed448.equals(algorithmId)) {

            keyType = KeyType.OKP;
            curveType = CurveType.Ed448;
            if (algorithm == null) {
                algorithm = Algorithm.EdDSA;
            }
            else if (algorithm != Algorithm.EdDSA) {
                throw new IllegalArgumentException("Algorithm ["+algorithm+"] is not compatible with EdDSA curve ["+curveType+"].");
            }
        }
        else {
            throw new IllegalArgumentException("Unknown certificate public key algorithm ["+algorithmId+"].");
        }

        KeyOperation[] keyOperations = toKeyOperations(ku);

        if (extensionConsumer != null && hasExtensions) {
            consumeExtensions(tbs,extensionConsumer);
        }

        return new PubKeyInfo(keyId,keyType,curveType,algorithm,keyOperations,certificate.getPublicKey());
    }

    /**
     * Map an X509 key usage to key operations.
     *
     * @param ku The key usage of the certificate, or <code>null</code> if absent.
     * @return A freshly allocated array of key operations; a copy of the defaults
     *         if <code>ku</code> is <code>null</code>.
     */
    private static KeyOperation[] toKeyOperations(KeyUsage ku) {

        if (ku == null) {
            // Hand out a copy, so nobody can modify the shared default array.
            return defaultKeyOperations.clone();
        }

        List<KeyOperation> usages = new ArrayList<KeyOperation>();

        if (ku.hasUsages(KeyUsage.digitalSignature) ||
            ku.hasUsages(KeyUsage.nonRepudiation)) {
            usages.add(KeyOperation.verify);
        }
        if (ku.hasUsages(KeyUsage.dataEncipherment)) {
            usages.add(KeyOperation.encrypt);
        }
        if (ku.hasUsages(KeyUsage.keyEncipherment)) {
            usages.add(KeyOperation.wrapKey);
        }

        return usages.toArray(new KeyOperation[usages.size()]);
    }

    /**
     * Feed the extensions requested by the consumer to it, in the format it asks for.
     *
     * @param tbs The to-be-signed part of the certificate; its extensions must not be <code>null</code>.
     * @param extensionConsumer The consumer to feed extension values to.
     */
    private static void consumeExtensions(TBSCertificate tbs, IExtensionConsumer extensionConsumer) {

        for (Iterator<String> it = extensionConsumer.getSupportedOids();it.hasNext();) {

            String oid = it.next();

            Extension ext = tbs.getExtensions().getExtension(new ASN1ObjectIdentifier(oid));

            // extensions requested by the consumer but absent from the certificate are skipped.
            if (ext == null) {
                continue;
            }

            switch (extensionConsumer.getExtensionFormat(oid)) {

            case BIT_STRING:
                ASN1BitString bs = (ASN1BitString)ext.getParsedValue();
                extensionConsumer.consumeBitString(oid,bs.getBytes(),bs.getPadBits());
                break;
            case OCTET_STRING:
                ASN1OctetString os = (ASN1OctetString)ext.getParsedValue();
                extensionConsumer.consumeOctetString(oid,os.getOctets());
                break;
            case RAW_DATA:
                // pass the undecoded DER content of the extension value.
                extensionConsumer.consumeOctetString(oid,ext.getExtnValue().getOctets());
                break;
            }
        }
    }
}
