/***********************************************************
 *
 * ASN.1 Tools of the clazzes.org project
 * https://www.clazzes.org
 *
 * Created: 24.07.2024
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 ***********************************************************/

package org.clazzes.login.asn1.tools;

import java.beans.IntrospectionException;
import java.beans.PropertyDescriptor;
import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.math.BigInteger;
import java.util.BitSet;
import java.util.SortedMap;
import java.util.TreeMap;

import org.clazzes.login.asn1.annotations.ASN1Field;
import org.clazzes.login.asn1.annotations.ASN1Sequence;
import org.clazzes.login.asn1.annotations.ASN1Type;
import org.clazzes.login.asn1.annotations.ASN1TypeCase;
import org.clazzes.login.asn1.annotations.ASN1TypeSwitch;
import org.clazzes.login.asn1.streams.ConfinedInputStream;
import org.ietf.jgss.GSSException;
import org.ietf.jgss.Oid;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers for reading and writing a subset of ASN.1 BER/DER encoded
 * data: lengths, BIT STRING, OCTET STRING, INTEGER, OBJECT IDENTIFIER, NULL
 * and SEQUENCEs mapped onto classes annotated with {@link ASN1Sequence}.
 *
 * <p>Only single-byte tags and definite lengths with up to four length
 * bytes are supported.</p>
 */
public abstract class ASN1Helper {

    private static final Logger log = LoggerFactory.getLogger(ASN1Helper.class);

    /**
     * Upper bound for the buffer allocated up front when reading content
     * octets. Larger contents are accumulated as the bytes actually arrive,
     * so a forged length header cannot force a huge allocation.
     */
    private static final int MAX_INITIAL_BUFFER = 64 * 1024;

    /**
     * Write an ASN.1 definite length in its shortest form: a single byte for
     * lengths below 128, otherwise a 0x81..0x84 prefix followed by one to
     * four big-endian length bytes.
     *
     * @param os The stream to write to.
     * @param len The non-negative length to encode.
     * @throws IllegalArgumentException if {@code len} is negative.
     * @throws IOException if writing to {@code os} fails.
     */
    public static void writeLength(OutputStream os, int len) throws IOException {

        if (len < 0) {
            throw new IllegalArgumentException("Negative length for ASN.1");
        }

        if (len < 128) {
            os.write(len);
        }
        else {
            // tmp[1..4] holds the big-endian length; the prefix byte is
            // placed directly in front of the first significant byte.
            byte[] tmp = new byte[5];
            tmp[1] = (byte)(len >>> 24);
            tmp[2] = (byte)(len >>> 16);
            tmp[3] = (byte)(len >>> 8);
            tmp[4] = (byte)(len);

            if (tmp[1] != 0) {
                tmp[0] = (byte)0x84;
                os.write(tmp);
            }
            else {
                if (tmp[2] != 0) {
                    tmp[1] = (byte)0x83;
                    os.write(tmp,1,4);
                }
                else {
                    if (tmp[3] != 0) {
                        tmp[2] = (byte)0x82;
                        os.write(tmp,2,3);
                    }
                    else {
                        tmp[3] = (byte)0x81;
                        os.write(tmp,3,2);
                    }
                }
            }
        }
    }

    /**
     * Read an ASN.1 definite length in short form or in long form with one
     * to four length bytes.
     *
     * @param is The stream to read from.
     * @return The decoded, non-negative length.
     * @throws EOFException if the stream ends inside the length octets.
     * @throws IOException if the length form is unsupported (including the
     *         indefinite form 0x80) or the value does not fit into a
     *         non-negative {@code int}.
     */
    public static int readLength(InputStream is) throws IOException {

        int b1 = is.read();

        if (b1 < 0) {
            throw new EOFException("EOF in first ASN.1 length byte.");
        }

        // Short form: the byte itself is the length.
        if (b1 < 128) {
            return b1;
        }

        // Long form: the low 7 bits give the number of following length bytes.
        int n = b1 & 0x7f;

        if (n < 1 || n > 4) {
            throw new ASN1Exception("Extended length byte ["+ b1 +"] is not in the range [0x81,0x84].");
        }

        int ret = 0;

        while (--n >= 0) {

            int b = is.read();

            if (b < 0) {
                throw new EOFException("EOF in followup ASN.1 length byte.");
            }

            ret = (ret << 8) | b;
        }

        // Four length bytes with the high bit set do not fit into a
        // non-negative int and would otherwise surface as a negative length.
        if (ret < 0) {
            throw new ASN1Exception("ASN.1 length ["+ (ret & 0xffffffffL) +"] exceeds the supported maximum.");
        }

        return ret;
    }

    /**
     * Read one tag byte and check that it matches the tag of {@code type}.
     *
     * @param is The stream to read from.
     * @param type The expected ASN.1 type.
     * @throws EOFException if the stream is already exhausted.
     * @throws IOException if a different tag is read.
     */
    protected static void readExpectedTag(InputStream is, ASN1Type type) throws IOException {

        int rtag = is.read();

        if (rtag < 0) {
            throw new EOFException("EOF while reading ASN.1 tag, expected ["+type+"].");
        }

        if (rtag != type.getTag()) {
            throw new ASN1Exception("Read tag ["+rtag+"] is not the expected tag ["+type+"].");
        }
    }

    /**
     * Read exactly {@code len} content bytes.
     *
     * <p>{@link InputStream#read(byte[], int, int)} may return fewer bytes
     * than requested, so this loops until all bytes have arrived. The
     * buffer grows with the data actually read instead of trusting the
     * declared length for a single up-front allocation.</p>
     *
     * @param is The stream to read from.
     * @param len The number of content bytes, non-negative.
     * @param what A description of the value for error messages.
     * @return The content bytes, of length {@code len}.
     * @throws EOFException if the stream ends before {@code len} bytes are read.
     */
    private static byte[] readContent(InputStream is, int len, String what) throws IOException {

        int chunkSize = Math.min(len,MAX_INITIAL_BUFFER);

        ByteArrayOutputStream bos = new ByteArrayOutputStream(chunkSize);
        byte[] chunk = new byte[chunkSize];

        int remaining = len;

        while (remaining > 0) {

            int n = is.read(chunk,0,Math.min(chunk.length,remaining));

            if (n < 0) {
                throw new EOFException("EOF after ["+(len-remaining)+"] of ["+len+
                                       "] content bytes of ASN.1 "+what+".");
            }

            bos.write(chunk,0,n);
            remaining -= n;
        }

        return bos.toByteArray();
    }

    /**
     * Read a primitive value of the given type: tag, length and raw content.
     *
     * @param is The stream to read from.
     * @param type The expected ASN.1 type.
     * @return The content bytes.
     * @throws IOException on a tag mismatch, an invalid length or premature EOF.
     */
    protected static byte[] readBytes(InputStream is, ASN1Type type) throws IOException {

        readExpectedTag(is,type);

        int len = readLength(is);

        return readContent(is,len,String.valueOf(type));
    }

    /**
     * Write a primitive value of the given type: tag, length and raw content.
     *
     * @param os The stream to write to.
     * @param type The ASN.1 type whose tag is written.
     * @param bytes The content bytes.
     * @throws IOException if writing fails.
     */
    protected static void writeBytes(OutputStream os, ASN1Type type, byte[] bytes) throws IOException {

        os.write(type.getTag());
        writeLength(os, bytes.length);
        os.write(bytes);
    }

    /**
     * Read an OCTET STRING.
     *
     * @param is The stream to read from.
     * @return The content bytes.
     * @throws IOException on malformed input or premature EOF.
     */
    public static byte[] readOctetString(InputStream is) throws IOException {
        return readBytes(is,ASN1Type.OCTET_STRING);
    }

    /**
     * Write an OCTET STRING.
     *
     * @param os The stream to write to.
     * @param bytes The content bytes.
     * @throws IOException if writing fails.
     */
    public static void writeOctetString(OutputStream os, byte[] bytes) throws IOException {
        writeBytes(os,ASN1Type.OCTET_STRING,bytes);
    }

    /**
     * Read a BIT STRING as raw bytes. A non-zero unused-bit count is logged
     * and otherwise ignored.
     *
     * @param is The stream to read from.
     * @return The content bytes following the unused-bit count.
     * @throws IOException on malformed input or premature EOF.
     */
    public static byte[] readBitStringBytes(InputStream is) throws IOException {

        readExpectedTag(is,ASN1Type.BIT_STRING);

        int len = readLength(is);

        if (len <= 0) {
            throw new ASN1Exception("Zero length in BIT STRING.");
        }

        int unused = is.read();

        if (unused < 0) {
            throw new EOFException("EOF in unused bit count of BIT STRING.");
        }

        if (unused != 0) {
            log.warn("Ignoring unused bit count [{}] in BIT STRING.",unused);
        }

        return readContent(is,len-1,"BIT STRING");
    }

    /**
     * Write raw bytes as a BIT STRING with an unused-bit count of zero.
     *
     * @param os The stream to write to.
     * @param bytes The content bytes.
     * @throws IOException if writing fails.
     */
    public static void writeBitStringBytes(OutputStream os, byte[] bytes) throws IOException {

        os.write(ASN1Type.BIT_STRING.getTag());
        writeLength(os,bytes.length+1);
        os.write(0);
        os.write(bytes);
    }

    /**
     * Read a BIT STRING into a {@link BitSet}. Bit 0 of the result is the
     * most significant bit of the first content byte.
     *
     * @param is The stream to read from.
     * @return The decoded bits.
     * @throws IOException on malformed input or premature EOF.
     */
    public static BitSet readBitString(InputStream is) throws IOException {

        readExpectedTag(is,ASN1Type.BIT_STRING);

        int len = readLength(is);

        if (len <= 0) {
            throw new ASN1Exception("Zero length in BIT STRING.");
        }

        int unused = is.read();

        if (unused < 0) {
            throw new EOFException("EOF in unused bit count of BIT STRING.");
        }

        if (unused > 7) {
            throw new ASN1Exception("Invalid unused bits ["+unused+"] in BIT STRING.");
        }

        // Without content bytes there are no bits that could be unused;
        // accepting this would produce a negative bit count below.
        if (len == 1 && unused != 0) {
            throw new ASN1Exception("Invalid unused bits ["+unused+"] in empty BIT STRING.");
        }

        byte[] bytes = readContent(is,len-1,"BIT STRING");

        int nbits = bytes.length * 8 - unused;

        BitSet ret = new BitSet(nbits);

        for (int i=0;i<nbits;++i) {

            ret.set(i,(bytes[i/8] & (1 << (7-(i&0x7)))) != 0);
        }

        return ret;
    }

    /**
     * Write a {@link BitSet} as BIT STRING. Bit 0 becomes the most
     * significant bit of the first content byte; the encoded bit count is
     * {@link BitSet#length()}, so trailing clear bits are not encoded.
     *
     * @param os The stream to write to.
     * @param bs The bits to write.
     * @throws IOException if writing fails.
     */
    public static void writeBitString(OutputStream os, BitSet bs) throws IOException {

        os.write(ASN1Type.BIT_STRING.getTag());

        int nbits = bs.length();

        byte[] bytes = new byte[(nbits+7)/8];

        for (int i=0;i<nbits;++i) {

            if (bs.get(i)) {
                bytes[i/8] |= (1 << (7-(i&0x7)));
            }
        }

        writeLength(os,bytes.length+1);
        os.write(bytes.length * 8 - nbits);
        os.write(bytes);
    }

    /**
     * Read an INTEGER as two's-complement big-endian value.
     *
     * @param is The stream to read from.
     * @return The decoded integer.
     * @throws IOException on malformed input (including zero length) or
     *         premature EOF.
     */
    public static BigInteger readInteger(InputStream is) throws IOException {

        byte[] bytes = readBytes(is,ASN1Type.INTEGER);

        // BigInteger rejects an empty magnitude with a NumberFormatException.
        if (bytes.length == 0) {
            throw new ASN1Exception("Zero length in INTEGER.");
        }

        return new BigInteger(bytes);
    }

    /**
     * Write an INTEGER in minimal two's-complement form.
     *
     * @param os The stream to write to.
     * @param i The value to write.
     * @throws IOException if writing fails.
     */
    public static void writeInteger(OutputStream os, BigInteger i) throws IOException {

        byte[] bytes = i.toByteArray();
        writeBytes(os,ASN1Type.INTEGER,bytes);
    }

    /**
     * Read an OBJECT IDENTIFIER (tag, length and content) via {@link Oid}.
     *
     * @param is The stream to read from.
     * @return The parsed OID.
     * @throws IOException if the OID cannot be parsed.
     */
    public static Oid readOid(InputStream is) throws IOException {

        try {
            return new Oid(is);
        } catch (GSSException e) {
            throw new ASN1Exception("Invalid OBJECT IDENTIFIER",e);
        }
    }

    /**
     * Write an OBJECT IDENTIFIER using the DER form provided by {@link Oid}.
     *
     * @param os The stream to write to.
     * @param oid The OID to write.
     * @throws IOException if the DER form is unavailable or writing fails.
     */
    public static void writeOid(OutputStream os, Oid oid) throws IOException {

        try {
            os.write(oid.getDER());
        } catch (GSSException e) {
            throw new ASN1Exception("Invalid Oid instance given for OBJECT IDENTIFIER",e);
        }
    }

    /**
     * Read a NULL value, which must have a zero length byte.
     *
     * @param is The stream to read from.
     * @return Always {@code null}.
     * @throws IOException on a tag mismatch or a non-zero length byte.
     */
    public static Void readNull(InputStream is) throws IOException {

        readExpectedTag(is,ASN1Type.NULL);
        int c = is.read();

        if (c != 0) {
            throw new ASN1Exception("Invalid content byte ["+c+"] for NULL");
        }

        return null;
    }

    /**
     * Write a NULL value (tag and zero length).
     *
     * @param os The stream to write to.
     * @throws IOException if writing fails.
     */
    public static void writeNull(OutputStream os) throws IOException {

        os.write(ASN1Type.NULL.getTag());
        os.write(0);
    }

    /**
     * Collect all fields annotated with {@link ASN1Field} from the class and
     * its superclasses, keyed and sorted by {@link ASN1Field#order()}.
     *
     * <p>NOTE(review): fields sharing the same order silently replace each
     * other; since superclasses are scanned after the class itself, a
     * superclass field wins over a subclass field with the same order.</p>
     *
     * @param cls A class with an {@link ASN1Sequence} annotation.
     * @return The ASN.1 fields in encoding order.
     * @throws ASN1Exception if {@code cls} is not annotated with {@link ASN1Sequence}.
     */
    protected static SortedMap<Integer,Field> getASN1Fields(Class<?> cls) throws ASN1Exception {

        if (cls.getAnnotation(ASN1Sequence.class) == null) {
            throw new ASN1Exception("Object of type ["+cls+"] is no ASN.1 sequence.");
        }

        SortedMap<Integer,Field> fields = new TreeMap<Integer,Field>();

        while (cls != null && cls != Object.class) {
            for (Field f : cls.getDeclaredFields()) {

                ASN1Field asn1Field = f.getAnnotation(ASN1Field.class);

                if (asn1Field != null) {
                    int order = asn1Field.order();
                    fields.put(order,f);
                }
            }

            cls = cls.getSuperclass();
        }

        return fields;
    }

    /**
     * Find the first method annotated with {@link ASN1TypeSwitch} in the class
     * or its superclasses.
     *
     * @param cls A class with an {@link ASN1Sequence} annotation.
     * @return The type switch method or {@code null} if there is none.
     * @throws ASN1Exception if {@code cls} is not annotated with {@link ASN1Sequence}.
     */
    protected static Method getASN1TypeSwitch(Class<?> cls) throws ASN1Exception {

        if (cls.getAnnotation(ASN1Sequence.class) == null) {
            throw new ASN1Exception("Object of type ["+cls+"] is no ASN.1 sequence.");
        }

        while (cls != null && cls != Object.class) {

            for (Method m : cls.getDeclaredMethods()) {
                if (m.getAnnotation(ASN1TypeSwitch.class) != null) {
                    return m;
                }
            }

            cls = cls.getSuperclass();
        }

        return null;
    }

    /**
     * Read a value of the given ASN.1 type into the given Java type.
     *
     * @param is The stream to read from.
     * @param type The ASN.1 type to read.
     * @param cls The Java type to produce.
     * @return The decoded value ({@code null} for NULL).
     * @throws IllegalArgumentException if {@code cls} is not supported for {@code type}.
     * @throws IOException on malformed input or an unsupported ASN.1 type.
     */
    public static Object readObject(InputStream is, ASN1Type type, Class<?> cls) throws IOException {

        switch (type) {
            case BIT_STRING:
                if (BitSet.class == cls) {
                    return readBitString(is);
                }
                else if (byte[].class == cls) {
                    return readBitStringBytes(is);
                }
                else {
                    throw new IllegalArgumentException("Unsupported type ["+cls+"] for ASN.1 BIT STRING.");
                }
            case INTEGER:
                if (BigInteger.class == cls) {
                    return readInteger(is);
                }
                else if (int.class == cls || Integer.class == cls) {
                    BigInteger bi = readInteger(is);
                    return bi.intValueExact();
                }
                else if (long.class == cls || Long.class == cls) {
                    BigInteger bi = readInteger(is);
                    return bi.longValueExact();
                }
                else {
                    throw new IllegalArgumentException("Unsupported type ["+cls+"] for ASN.1 INTEGER.");
                }
            case OBJECT_IDENTIFIER:
                if (Oid.class == cls) {
                    return readOid(is);
                }
                else {
                    throw new IllegalArgumentException("Unsupported type ["+cls+"] for ASN.1 OBJECT IDENTIFIER.");
                }
            case OCTET_STRING:
                if (byte[].class == cls) {
                    return readOctetString(is);
                }
                else {
                    throw new IllegalArgumentException("Unsupported type ["+cls+"] for ASN.1 OCTET STRING.");
                }
            case SEQUENCE:
                return readComplex(is,cls);
            case NULL:
                return readNull(is);
            default:
                throw new ASN1Exception("Reading objects of type ["+type+"] is not supported.");
        }
    }

    /**
     * Read the value of one {@link ASN1Field} annotated field and store it
     * through the bean setter of that field.
     *
     * @param is The stream to read from.
     * @param x The object whose property is set.
     * @param f The annotated field.
     * @throws IOException on malformed input or reflection errors.
     */
    public static void readField(InputStream is, Object x, Field f) throws IOException {

        ASN1Field asn1Field = f.getAnnotation(ASN1Field.class);

        if (log.isDebugEnabled()) {
            log.debug("Reading field [{}]",f);
        }

        try {
            PropertyDescriptor pd = new PropertyDescriptor(f.getName(),f.getDeclaringClass());

            Object value = readObject(is,asn1Field.type(),f.getType());

            pd.getWriteMethod().invoke(x,value);

        } catch (IntrospectionException |
                 IllegalAccessException |
                 IllegalArgumentException |
                 InvocationTargetException e) {
            throw new ASN1Exception("Reflection error writing field ["+f+"]",e);
        }
    }

    /**
     * Read a complex object annotated by {@link ASN1Sequence}.
     *
     * <p>If the class declares an {@link ASN1TypeSwitch} method, the first
     * element of the SEQUENCE is read as designator (OBJECT IDENTIFIER if
     * the switch method returns {@link Oid}, INTEGER otherwise) and selects
     * the concrete subclass to instantiate.</p>
     *
     * @param is The stream to read from.
     * @param cls A class with an {@link ASN1Sequence} annotation.
     * @return The parsed instance of the given class or one of its subclasses
     *         for polymorphic structures.
     * @throws IOException on malformed input, an unknown designator or
     *         reflection errors.
     */
    @SuppressWarnings("unchecked")
    public static <T> T readComplex(InputStream is, Class<T> cls) throws IOException {

        Method typeSwitch = getASN1TypeSwitch(cls);

        readExpectedTag(is,ASN1Type.SEQUENCE);

        int len = readLength(is);

        // presumably restricts nested reads to the len content bytes of this SEQUENCE
        try (ConfinedInputStream cis = new ConfinedInputStream(is,len)) {

            Class<? extends T> subCls;

            if (typeSwitch == null) {
                subCls = cls;
            }
            else {
                ASN1TypeSwitch ts =
                    typeSwitch.getAnnotation(ASN1TypeSwitch.class);

                boolean oidSwitch = typeSwitch.getReturnType() == Oid.class;

                String designator;
                int intDesignator;

                if (oidSwitch) {
                    designator = readOid(cis).toString();
                    intDesignator = Integer.MIN_VALUE;
                }
                else {
                    intDesignator = readInteger(cis).intValueExact();
                    designator = "";
                }

                subCls = null;

                // NOTE(review): a case matches only if both oid and value agree,
                // so OID cases presumably leave value at Integer.MIN_VALUE and
                // INTEGER cases leave oid at "" - confirm against ASN1TypeCase.
                for (ASN1TypeCase tc : ts.value()) {

                    if (designator.equals(tc.oid()) && intDesignator == tc.value()) {
                        subCls = (Class<? extends T>)tc.clazz();
                        break;
                    }
                }

                if (subCls == null) {
                    throw new ASN1Exception("No subclass for type switch value ["+
                                            (oidSwitch ? designator : String.valueOf(intDesignator))+
                                            "] of class ["+cls+"]");
                }
            }

            T x;
            try {
                x = subCls.getDeclaredConstructor().newInstance();
            } catch (InstantiationException | IllegalAccessException | IllegalArgumentException
                    | InvocationTargetException | NoSuchMethodException | SecurityException e) {
                throw new ASN1Exception(
                            "Reflection error instantiating ["+subCls+"]",e);
            }

            SortedMap<Integer, Field> fields = getASN1Fields(subCls);

            for (Field f : fields.values()) {
                readField(cis,x,f);
            }

            return x;
        }

    }

    /**
     * Decode a complex object annotated by {@link ASN1Sequence}.
     * This method throws an exception, if not all of the provided input is read.
     *
     * @param encoded The BER encoded byte array to read from.
     * @param cls A class with an {@link ASN1Sequence} annotation.
     * @return The parsed instance of the given class or one of its subclasses
     *         for polymorphic structures.
     * @throws IOException on malformed input or reflection errors.
     */
    public static <T> T decodeComplex(byte[] encoded, Class<T> cls) throws IOException {

        try (InputStream is = ConfinedInputStream.wrap(encoded)) {
            return ASN1Helper.readComplex(is,cls);
        }
    }

    /**
     * Write a value as the given ASN.1 type.
     *
     * @param os The stream to write to.
     * @param type The ASN.1 type to write.
     * @param x The value; must be {@code null} for NULL.
     * @throws IllegalArgumentException if the value's class is not supported for {@code type}.
     * @throws IOException if writing fails or the ASN.1 type is unsupported.
     */
    public static void writeObject(OutputStream os, ASN1Type type, Object x) throws IOException {

        Class<?> cls = x == null ? null : x.getClass();

        switch (type) {
            case BIT_STRING:
                if (BitSet.class == cls) {
                    writeBitString(os,(BitSet)x);
                }
                else if (byte[].class == cls) {
                    writeBitStringBytes(os,(byte[])x);
                }
                else {
                    throw new IllegalArgumentException("Unsupported type ["+cls+"] for ASN.1 BIT STRING.");
                }
                break;
            case INTEGER:
                if (BigInteger.class == cls) {
                    writeInteger(os,(BigInteger)x);
                }
                else if (int.class == cls || Integer.class == cls) {
                    writeInteger(os,BigInteger.valueOf((Integer)x));
                }
                else if (long.class == cls || Long.class == cls) {
                    writeInteger(os,BigInteger.valueOf((Long)x));
                }
                else {
                    throw new IllegalArgumentException("Unsupported type ["+cls+"] for ASN.1 INTEGER.");
                }
                break;
            case OBJECT_IDENTIFIER:
                if (Oid.class == cls) {
                    writeOid(os,(Oid)x);
                }
                else {
                    throw new IllegalArgumentException("Unsupported type ["+cls+"] for ASN.1 OBJECT IDENTIFIER.");
                }
                break;
            case OCTET_STRING:
                if (byte[].class == cls) {
                    writeOctetString(os,(byte[])x);
                }
                else {
                    throw new IllegalArgumentException("Unsupported type ["+cls+"] for ASN.1 OCTET STRING.");
                }
                break;
            case SEQUENCE:
                writeComplex(os,x);
                break;
            case NULL:
                if (x == null) {
                    writeNull(os);
                }
                else {
                    throw new IllegalArgumentException("Unsupported type ["+cls+"] for ASN.1 NULL.");
                }
                break;
            default:
                throw new ASN1Exception("Writing objects of type ["+type+"] is not supported.");
        }
    }

    /**
     * Fetch one {@link ASN1Field} annotated property through its bean getter
     * and write it.
     *
     * @param os The stream to write to.
     * @param x The object whose property is read.
     * @param f The annotated field.
     * @throws IOException if writing fails or on reflection errors.
     */
    public static void writeField(OutputStream os, Object x, Field f) throws IOException {

        if (log.isDebugEnabled()) {
            log.debug("Writing field [{}]",f);
        }

        ASN1Field asn1Field = f.getAnnotation(ASN1Field.class);

        try {
            PropertyDescriptor pd = new PropertyDescriptor(f.getName(),f.getDeclaringClass());

            Object v = pd.getReadMethod().invoke(x);

            writeObject(os,asn1Field.type(),v);

        } catch (IntrospectionException |
                 IllegalAccessException |
                 IllegalArgumentException |
                 InvocationTargetException e) {
            throw new ASN1Exception("Reflection error reading field ["+f+"]",e);
        }
    }

    /**
     * Write a complex object annotated by {@link ASN1Sequence}.
     *
     * <p>If the class declares an {@link ASN1TypeSwitch} method, its result
     * is written as first element of the SEQUENCE (OBJECT IDENTIFIER if the
     * method returns {@link Oid}, INTEGER otherwise).</p>
     *
     * @param os The stream to write to.
     * @param x The object to write, which must be an instance of a class with
     *         an {@link ASN1Sequence} annotation.
     * @throws IOException if writing fails or on reflection errors.
     */
    public static void writeComplex(OutputStream os, Object x) throws IOException {

        SortedMap<Integer, Field> fields = getASN1Fields(x.getClass());

        Method typeSwitch = getASN1TypeSwitch(x.getClass());

        // The content is buffered first because the SEQUENCE length must
        // precede it.
        try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) {

            if (typeSwitch != null) {
                try {

                    Object designator = typeSwitch.invoke(x);

                    if (typeSwitch.getReturnType() == Oid.class) {
                        writeOid(bos,(Oid)designator);
                    }
                    else {
                        writeInteger(bos,BigInteger.valueOf((Integer)designator));
                    }

                }
                catch (IllegalAccessException |
                       IllegalArgumentException |
                       InvocationTargetException e) {
                    throw new ASN1Exception(
                       "Reflection error invoking type switch ["+typeSwitch+"]",e);
                }
            }

            for (Field f : fields.values()) {
                writeField(bos,x,f);
            }

            byte[] bytes = bos.toByteArray();

            writeBytes(os,ASN1Type.SEQUENCE,bytes);
        }
    }

    /**
     * Encode a complex object annotated by {@link ASN1Sequence}.
     *
     * @param x The object to write, which must be an instance of a class with
     *         an {@link ASN1Sequence} annotation.
     * @return the encoded form of the given complex object.
     * @throws IOException on reflection errors.
     */
    public static byte[] encodeComplex(Object x) throws IOException {

        try (ByteArrayOutputStream bos = new ByteArrayOutputStream(1024)) {
            writeComplex(bos,x);
            return bos.toByteArray();
        }
    }

}