/***********************************************************
 * $Id$
 *
 * OAuth Login Services of the clazzes.org project
 * http://www.clazzes.org
 *
 * Created: 31.05.2017
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 ***********************************************************/

package org.clazzes.login.jbo.jwt;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.security.Signature;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import org.clazzes.login.jbo.common.Algorithm;
import org.clazzes.login.jbo.common.Helpers;
import org.clazzes.login.jbo.json.JsonHelper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.fasterxml.jackson.core.JacksonException;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonToken;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.module.SimpleModule;

/**
 * <p>
 * Parse <a href="https://tools.ietf.org/html/rfc7519">RFC 7519</a> compatible
 * JSON Web Tokens and format the signature payload of tokens to be issued.
 * </p>
 * <p>
 * All methods are static. The Jackson factory and object mapper used
 * internally are configured once during class initialization and are only
 * read afterwards, so they are shared between calls.
 * </p>
 */
public class JWTokenParser {

    private static final Logger log = LoggerFactory.getLogger(JWTokenParser.class);

    /**
     * Media type prefix which RFC 7515, section 4.1.9 allows to be omitted
     * from the <code>typ</code> header parameter.
     */
    private static final String APPLICATION_PREFIX = "application/";

    /** Streaming factory used to read and write the token header. */
    private static final JsonFactory JSON_FACTORY = new JsonFactory();

    /**
     * Mapper with the custom (de)serializer for {@link JWTokenClaims} registered.
     * Jackson's <code>ObjectMapper</code> is thread-safe once configured, so a
     * single instance is shared instead of being rebuilt on every call.
     */
    private static final ObjectMapper CLAIMS_MAPPER = createClaimsMapper();

    /** Build the shared mapper which knows how to (de)serialize the claim set. */
    private static ObjectMapper createClaimsMapper() {

        SimpleModule module = new SimpleModule();
        module.addDeserializer(JWTokenClaims.class,new JWTTokenClaimsDeserializer());
        module.addSerializer(new JWTTokenClaimsSerializer());

        ObjectMapper mapper = new ObjectMapper();
        mapper.registerModule(module);
        return mapper;
    }

    /**
     * Parse the members of a JSON object and accept only primitive values,
     * string arrays and nested objects of the same kind, as can be found in
     * JWT tokens. No arrays of complex types are accepted.
     * <p>
     * This method does not consume a <code>START_OBJECT</code> token itself; it is
     * invoked with the parser positioned on that token (as done for nested
     * objects by the claim parser). Members with an unsupported value type are
     * skipped with a warning and do not appear in the result.
     * </p>
     * @param p A JSON parser.
     * @param ctxt A context description handed to {@link JsonHelper}, e.g. for error reporting.
     * @return A map of the found keys with values. Values of type <code>List</code>
     *         are always instances of <code>List&lt;String&gt;</code>.
     * @throws IOException If reading from the JSON parser fails.
     */
    public static Map<String,Object> parseObject(JsonParser p, String ctxt) throws IOException {

        Map<String,Object> kvs = new HashMap<String,Object>();

        String key;

        while ((key = JsonHelper.nextName(p,ctxt)) != null) {

            Object value = parsePrimitive(p,key);

            // null means "unsupported type, already skipped and logged".
            if (value != null)  {
                kvs.put(key,value);
            }
        }

        JsonHelper.endObject(p,ctxt);

        return kvs;
    }

    /**
     * Read the value following the member <code>name</code>.
     *
     * @param reader The parser, positioned just after the member name.
     * @param name The member name, used for logging and error contexts.
     * @return A <code>Long</code>, <code>Double</code>, <code>String</code>,
     *         <code>Boolean</code>, <code>List&lt;String&gt;</code> or
     *         <code>Map&lt;String,Object&gt;</code>, or <code>null</code> if
     *         the value has an unsupported type (e.g. JSON <code>null</code>)
     *         and has been skipped.
     * @throws IOException If reading from the JSON parser fails.
     */
    private static Object parsePrimitive(JsonParser reader, String name) throws IOException {

        String ctxt = "JWToken.additionalClaims."+name;

        switch (reader.nextToken()) {

        case VALUE_NUMBER_INT:
            return reader.getLongValue();

        case VALUE_NUMBER_FLOAT:
            return reader.getDoubleValue();

        case VALUE_STRING:
            return reader.getText();

        case VALUE_TRUE:
        case VALUE_FALSE:
            return reader.getBooleanValue();

        case START_ARRAY:
            {
                List<String> values = new ArrayList<String>();

                String a;

                // nextTextValue() returns null on END_ARRAY and on any non-string
                // element; JsonHelper.endArray presumably rejects the latter.
                while ((a=reader.nextTextValue()) != null) {
                    values.add(a);
                }

                JsonHelper.endArray(reader,ctxt);
                return values;
            }

        case START_OBJECT:
            return parseObject(reader,ctxt);

        default:
            log.warn("Ignoring token claim [{}], which is not of type string, number, string-array, map-of-primitives or boolean",name);
            reader.skipChildren();
            return null;
        }
    }

    /**
     * Convert a NumericDate (seconds since the epoch) into milliseconds,
     * rejecting values whose millisecond representation would overflow.
     *
     * @param seconds The claim value in seconds.
     * @param claim The claim name, used in the error message.
     * @return The claim value in milliseconds.
     * @throws IllegalArgumentException If the value is out of range.
     */
    private static long secondsToMillis(long seconds, String claim) {

        try {
            return Math.multiplyExact(seconds,1000L);
        }
        catch (ArithmeticException e) {
            throw new IllegalArgumentException("JWT Token claim '"+claim+"' is out of range ["+seconds+"].",e);
        }
    }

    /**
     * Deserializer for the JWT claim set. The registered claims
     * (<code>aud</code>, <code>iss</code>, <code>sub</code>, <code>jti</code>,
     * <code>iat</code>, <code>nbf</code>, <code>exp</code>) are mapped to
     * dedicated fields, with the time claims converted from seconds to
     * milliseconds. All other supported claims end up in the additional claims.
     */
    private static final class JWTTokenClaimsDeserializer extends JsonDeserializer<JWTokenClaims> {

        @Override
        public JWTokenClaims deserialize(JsonParser p, DeserializationContext dctxt)
                throws IOException, JacksonException {

            String ctxt = "JWTTokenClaims";

            String jwtId = null;
            String issuer = null;
            String subject = null;
            List<String> audience = null;
            Long issuedAt = null;
            Long notBefore = null;
            Long expiration = null;
            Map<String,Object> additionalClaims = new TreeMap<String,Object>();

            JsonHelper.beginObject(p,ctxt);

            String name;

            while ((name = JsonHelper.nextName(p,ctxt)) != null) {

                if ("aud".equals(name)) {

                    JsonToken token = p.nextToken();

                    // RFC 7519, section 4.1.3: "aud" is either a single string
                    // or an array of strings.
                    if (token == JsonToken.START_ARRAY) {

                        audience = new ArrayList<String>();
                        String a;

                        while ((a=p.nextTextValue()) != null) {
                            audience.add(a);
                        }

                        JsonHelper.endArray(p,ctxt);
                    }
                    else if (token == JsonToken.VALUE_STRING) {
                        audience = Collections.singletonList(p.getText());
                    }
                    else {
                        throw new IllegalArgumentException("JWT Token claim 'aud' is neither a string nor an array.");
                    }
                }
                else if ("iss".equals(name)) {
                    issuer = JsonHelper.nextString(p,ctxt);
                }
                else if ("sub".equals(name)) {
                    subject = JsonHelper.nextString(p,ctxt);
                }
                else if ("jti".equals(name)) {
                    jwtId = JsonHelper.nextString(p,ctxt);
                }
                else if ("iat".equals(name)) {
                    issuedAt = secondsToMillis(JsonHelper.nextLong(p,ctxt),name);
                }
                else if ("nbf".equals(name)) {
                    notBefore = secondsToMillis(JsonHelper.nextLong(p,ctxt),name);
                }
                else if ("exp".equals(name)) {
                    expiration = secondsToMillis(JsonHelper.nextLong(p,ctxt),name);
                }
                else {
                    Object v = parsePrimitive(p,name);

                    // parsePrimitive returns null for claims it has logged as
                    // ignored; do not keep them as null entries.
                    if (v != null) {
                        additionalClaims.put(name,v);
                    }
                }
            }

            return new JWTokenClaims(jwtId,issuer,subject,audience,issuedAt,notBefore,expiration,additionalClaims);
        }
    }


    /** Write a string member, or nothing at all if <code>value</code> is <code>null</code>. */
    private static void writeIfNotNull(JsonGenerator writer, String key, String value) throws IOException {

        if (value != null) {

            writer.writeFieldName(key);
            writer.writeString(value);
        }
    }

    /**
     * Write a string list member, or nothing at all if <code>value</code> is
     * <code>null</code>. A single-element list is written as a plain string,
     * any other list as a JSON array.
     */
    private static void writeIfNotNull(JsonGenerator writer, String key, List<String> value) throws IOException {

        if (value != null) {

            writer.writeFieldName(key);

            if (value.size() == 1) {

                writer.writeString(value.get(0));
            }
            else {
                writer.writeStartArray();
                for (String v : value) {
                    writer.writeString(v);
                }
                writer.writeEndArray();
            }
        }
    }

    /**
     * Write a timestamp member given in milliseconds as seconds (NumericDate),
     * or nothing at all if <code>value</code> is <code>null</code>.
     */
    private static void writeIfNotNull(JsonGenerator writer, String key, Long value) throws IOException {

        if (value != null) {

            writer.writeFieldName(key);
            writer.writeNumber(value.longValue()/1000L);
        }
    }


    /**
     * Serializer for the JWT claim set; the inverse of
     * {@link JWTTokenClaimsDeserializer}. Absent (<code>null</code>) registered
     * claims are omitted, and additional claims are written via the mapper's
     * default serialization.
     */
    private static final class JWTTokenClaimsSerializer extends JsonSerializer<JWTokenClaims> {

        @Override
        public Class<JWTokenClaims> handledType() {
            return JWTokenClaims.class;
        }

        @Override
        public void serialize(JWTokenClaims claims, JsonGenerator gen, SerializerProvider serializers)
                throws IOException {

            gen.writeStartObject();

            writeIfNotNull(gen,"aud",claims.getAudience());
            writeIfNotNull(gen,"iss",claims.getIssuer());
            writeIfNotNull(gen,"sub",claims.getSubject());
            writeIfNotNull(gen,"jti",claims.getJwtId());

            writeIfNotNull(gen,"iat",claims.getIssuedAt());
            writeIfNotNull(gen,"nbf",claims.getNotBefore());
            writeIfNotNull(gen,"exp",claims.getExpiration());

            if (claims.getAdditionalClaims() != null) {

                for (Map.Entry<String,Object> e : claims.getAdditionalClaims().entrySet()) {

                    gen.writeFieldName(e.getKey());
                    gen.writePOJO(e.getValue());
                }
            }

            gen.writeEndObject();
        }
    }

    /**
     * Check a <code>typ</code> header value for the JWT media type. Media types
     * are case-insensitive (RFC 7519, section 5.1), and the <code>application/</code>
     * prefix may be omitted (RFC 7515, section 4.1.9).
     */
    private static boolean isJwtType(String typ) {

        if (typ == null) {
            return false;
        }

        String subtype = typ.regionMatches(true,0,APPLICATION_PREFIX,0,APPLICATION_PREFIX.length())
                ? typ.substring(APPLICATION_PREFIX.length())
                : typ;

        return "JWT".equalsIgnoreCase(subtype);
    }

    /**
     * Parse a serialized JWT. The signature is NOT verified here.
     *
     * @param token The JWT ID Token as transferred in a OAuth authentication response.
     * @return The parsed token, of which the signature may be validated.
     * @throws IOException If reading the claim set fails.
     * @throws IllegalArgumentException If the token does not consist of three
     *         parts or the header is invalid.
     */
    public static final JWToken parseJWToken(String token) throws IOException {

        byte[][] parts = Helpers.splitToken(token,3);

        if (parts.length != 3) {
            throw new IllegalArgumentException("JWT Token is not composed of three dot-separated parts.");
        }

        Algorithm algorithm = null;
        String keyId = null;
        // The signature covers the ASCII text "<header>.<claims>", i.e. all
        // characters before the last dot.
        byte[] signaturePayload = token.substring(0,token.lastIndexOf('.')).getBytes(StandardCharsets.US_ASCII);
        byte[] signature = parts[2];

        try (JsonParser reader = JSON_FACTORY.createParser(new ByteArrayInputStream(parts[0]))) {

            String ctxt = "JWTTokenHeader";

            JsonHelper.beginObject(reader,ctxt);

            String name;

            while ((name = JsonHelper.nextName(reader,ctxt)) != null) {

                if ("kid".equals(name)) {
                    keyId = JsonHelper.nextString(reader,ctxt);
                }
                else if ("alg".equals(name)) {

                    String rawAlgorithm = JsonHelper.nextString(reader,ctxt);

                    algorithm = Algorithm.valueOf(rawAlgorithm);
                }
                else if ("typ".equals(name)) {
                    String typ = JsonHelper.nextString(reader,ctxt);

                    if (!isJwtType(typ)) {
                        throw new IllegalArgumentException("JWT Token has invalid type ["+typ+"].");
                    }
                }
                else {
                    log.warn("Ignoring unknown attribute [{}] in JWT Token header.",name);
                    reader.nextToken();
                    reader.skipChildren();
                }
            }

        } catch (Exception e) {
            // Boundary: any header problem (I/O, unknown algorithm, bad type,
            // missing values) is reported uniformly to the caller.
            throw new IllegalArgumentException("Error parsing JWT Token header.",e);
        }

        JWTokenClaims claims = CLAIMS_MAPPER.readValue(parts[1],JWTokenClaims.class);

        return new JWToken(algorithm,keyId,claims,signaturePayload,signature);
    }

    /**
     * Calculate the signature payload to be passed to
     * {@link JWToken#JWToken(Algorithm, String, JWTokenClaims, byte[], byte[])}
     * before actually signing the token, e.g. with a {@link Signature} instance.
     *
     * @param algorithm The signing algorithm; its <code>toString()</code> value is
     *                  written verbatim as the <code>alg</code> header parameter.
     * @param keyId The ID of the key, which is used for signing. If <code>null</code>,
     *              the optional <code>kid</code> header parameter is omitted.
     * @param claims The claims to be included in the token.
     * @return The ASCII representation of the base64-encoded token header plus
     *         a dot plus the base64-encoded JSON serialized version of the claim set.
     * @throws IOException If JSON serialization fails.
     */
    public static byte[] formatSignaturePayload(Algorithm algorithm, String keyId, JWTokenClaims claims) throws IOException {

        ByteArrayOutputStream bos = new ByteArrayOutputStream();

        try (JsonGenerator writer = JSON_FACTORY.createGenerator(bos)) {

            writer.writeStartObject();

            writer.writeFieldName("typ");
            writer.writeString("JWT");

            writer.writeFieldName("alg");
            writer.writeString(algorithm.toString());

            // Omit "kid" instead of emitting "kid":null when no key ID is given.
            writeIfNotNull(writer,"kid",keyId);

            writer.writeEndObject();
        }

        String header = Helpers.formatBase64(bos.toByteArray());
        String body = Helpers.formatBase64(CLAIMS_MAPPER.writeValueAsBytes(claims));

        // Both parts are base64 text, hence pure ASCII.
        return (header + "." + body).getBytes(StandardCharsets.US_ASCII);
    }

    /**
     * @param token The token with the signature payload and the signature set.
     * @return The serialized ID token.
     * @throws UnsupportedEncodingException Never thrown any more; still declared
     *         for source compatibility with existing callers.
     */
    public static String formatSignedToken(JWToken token) throws UnsupportedEncodingException {

        return new String(token.getSignaturePayload(),StandardCharsets.US_ASCII) + "." + Helpers.formatBase64(token.getSignature());
    }
}
