/***********************************************************
 * $Id$
 * 
 * LDAP single-sign-on for GWT applications of the clazzes.org project
 * http://www.clazzes.org
 *
 * Created: 11.08.2011
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * 
 ***********************************************************/

package org.clazzes.login.ldap;

import java.net.PasswordAuthentication;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Enumeration;
import java.util.Hashtable;
import java.util.StringTokenizer;

import javax.naming.*;
import javax.naming.directory.*;
import javax.naming.ldap.InitialLdapContext;
import javax.naming.ldap.LdapContext;
import javax.naming.ldap.LdapName;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Helper functions for finding Active Directory servers via DNS, connecting to
 * LDAP servers and converting LDAP search results into principals and groups.
 * 
 * @author wglas
 */
public abstract class AdsHelper {

    private static final Logger log = LoggerFactory.getLogger(AdsHelper.class);

    /** Prefix of TXT records announcing the LDAP service URI of a host. */
    private static final String TXT_SERVICE_PREFIX = "service:";

    /** Well-known port of the Active Directory global catalog via plain LDAP. */
    private static final int GC_PORT = 3268;

    /** Well-known port of the Active Directory global catalog via LDAPS. */
    private static final int GC_SSL_PORT = 3269;

    /**
     * Queries the default DNS server for a given SRV record and resolves it to an LDAP URI.
     * 
     * Among all SRV records found, the one with the minimal priority is selected. If
     * multiple records share that priority, the one with the maximal weight wins.
     * Malformed records and records with the target <code>.</code> (service not available,
     * see RFC 2782) are skipped.
     * 
     * Afterwards, the TXT records of the selected target host are queried. If one of them has
     * the form <code>service:ldap://...</code> or <code>service:ldaps://...</code>, the URI
     * following the <code>service:</code> prefix is returned.
     * 
     * @param dnsDomain The DNS domain to search.
     * @param srvName The name for the SRV record to search like <code>_ldap._tcp</code>
     * @return An ldap or ldaps URI of the resolved SRV record. If no TXT record announces a
     *         service URI, an <code>ldap</code> URI is built from the SRV target host and
     *         port. The DNS domain is appended as path in the form returned by
     *         {@link #convertDomainToDN(String)}.
     * @throws NamingException Upon errors resolving the SRV or TXT records, if no usable
     *         SRV record exists, or if no valid URI can be built.
     */
    public static URI querySRV(String dnsDomain, String srvName) throws NamingException {

        Hashtable<String,String> dnsEnv = new Hashtable<String,String>();
        dnsEnv.put(Context.INITIAL_CONTEXT_FACTORY,"com.sun.jndi.dns.DnsContextFactory");

        DirContext dnsContext = new InitialDirContext(dnsEnv);

        try {

            if (log.isDebugEnabled())
                log.debug("Starting DNS query for [SRV "+srvName+"]...");

            Attributes dnsQueryResult =
                    dnsContext.getAttributes(srvName+"."+dnsDomain,new String[] {"SRV"});

            if (dnsQueryResult == null)
                throw new NamingException("SRV record ["+srvName+"] could not be found.");

            if (log.isDebugEnabled())
                log.debug("Parsing DNS query result for [SRV "+srvName+"].");

            long minPriority = Long.MAX_VALUE;
            long maxWeight = Long.MIN_VALUE;
            String found = null;
            int foundPort = -1;

            for (NamingEnumeration<? extends Attribute> dnsRR=dnsQueryResult.getAll();dnsRR.hasMoreElements();) {

                Attribute rr = dnsRR.next();

                for (Enumeration<?> vals = rr.getAll();vals.hasMoreElements();) {

                    String srvRecord = String.valueOf(vals.nextElement());

                    if (log.isDebugEnabled())
                        log.debug("Found record [SRV "+srvRecord+"].");

                    // An SRV record reads "<priority> <weight> <port> <target>" (RFC 2782).
                    String[] fields = srvRecord.trim().split("\\s+");

                    if (fields.length < 4) {
                        log.warn("Ignoring malformed record [SRV "+srvRecord+"].");
                        continue;
                    }

                    long priority;
                    long weight;
                    int port;

                    try {
                        priority = Long.parseLong(fields[0]);
                        weight = Long.parseLong(fields[1]);
                        port = Integer.parseInt(fields[2]);
                    } catch (NumberFormatException e) {
                        log.warn("Ignoring malformed record [SRV "+srvRecord+"].");
                        continue;
                    }

                    // A target of "." means that the service is decidedly not available.
                    if (".".equals(fields[3])) {
                        if (log.isDebugEnabled())
                            log.debug("Ignoring record [SRV "+srvRecord+"] announcing an unavailable service.");
                        continue;
                    }

                    if (priority < minPriority || (priority == minPriority && weight > maxWeight)) {

                        if (log.isDebugEnabled())
                            log.debug("Select record [SRV "+srvRecord+"].");

                        minPriority = priority;
                        maxWeight = weight;
                        foundPort = port;
                        found = fields[3];
                    }
                }
            }

            if (found == null)
                throw new NamingException("Query for SRV ["+srvName+"] returned an empty result set.");

            // A fully qualified target ends with a dot; otherwise it is relative to dnsDomain.
            if (found.endsWith(".")) {
                found = found.substring(0,found.length()-1);
            }
            else {
                found = found + "." + dnsDomain;
            }

            log.info("Resolved SRV record ["+srvName+"] to ["+found+"].");

            URI ret = queryServiceTXT(dnsContext,found);

            if (ret == null) {

                try {
                    ret = new URI("ldap",null,found,foundPort,convertDomainToDN(dnsDomain),null,null);
                } catch (URISyntaxException e) {
                    NamingException ne = new NamingException("Unable to build a valid URI from SRV record ["+found+"].");
                    ne.setRootCause(e);
                    throw ne;
                }

                if (log.isDebugEnabled())
                    log.debug("Record [TXT "+found+"] was not of type service:ldap(s), final URI from SRV record is ["+ret+"].");
            }

            return ret;

        } finally {
            dnsContext.close();
        }
    }

    /**
     * Queries the TXT records of a host for an announced LDAP service URI.
     * 
     * @param dnsContext The DNS context used for the query.
     * @param host The fully qualified host name whose TXT records are queried.
     * @return The URI following the <code>service:</code> prefix of a TXT record of the form
     *         <code>service:ldap://...</code> or <code>service:ldaps://...</code>. If there are
     *         several well-formed records, the last one is returned. Returns <code>null</code>
     *         if there is no such record.
     * @throws NamingException Upon errors querying the TXT records.
     */
    private static URI queryServiceTXT(DirContext dnsContext, String host) throws NamingException {

        if (log.isDebugEnabled())
            log.debug("Starting DNS query for [TXT "+host+"]...");

        Attributes dnsQueryResult = dnsContext.getAttributes(host,new String[] {"TXT"});

        if (dnsQueryResult == null)
            throw new NamingException("TXT record ["+host+"] could not be found.");

        if (log.isDebugEnabled())
            log.debug("Parsing DNS query result for [TXT "+host+"].");

        URI ret = null;

        for (NamingEnumeration<? extends Attribute> dnsRR=dnsQueryResult.getAll();dnsRR.hasMoreElements();) {

            Attribute rr = dnsRR.next();

            for (Enumeration<?> vals = rr.getAll();vals.hasMoreElements();) {

                String txtRecord = String.valueOf(vals.nextElement());

                if (log.isDebugEnabled())
                    log.debug("Found record [TXT "+txtRecord+"].");

                if (txtRecord.startsWith(TXT_SERVICE_PREFIX+"ldap://")
                        || txtRecord.startsWith(TXT_SERVICE_PREFIX+"ldaps://")) {

                    try {
                        ret = new URI(txtRecord.substring(TXT_SERVICE_PREFIX.length()));

                        log.info("Resolved TXT record ["+host+"] to ["+ret+"].");

                    } catch (URISyntaxException e) {
                        log.warn("Record [TXT "+txtRecord+"] does not contain a well-formatted ldap-URI.");
                    }
                }
            }
        }

        return ret;
    }

    /**
     * Finds the active directory server for a given DNS domain.
     * 
     * This method queries the default DNS server for an SRV record of the form
     * <code>_ldap._tcp.&lt;dnsDomain&gt;</code>. It returns the retrieved record with the
     * minimal priority. If multiple records share that priority, the one with the maximal
     * weight is returned.
     *
     * @param dnsDomain The DNS domain name. 
     * @return The most relevant ADS server for the given domain.
     * @throws NamingException Upon errors resolving the server, see {@link #querySRV(String, String)}.
     */
    public static URI findAds(String dnsDomain) throws NamingException {

        return querySRV(dnsDomain,"_ldap._tcp");
    }

    /**
     * Finds the active directory global catalog server for a given DNS domain.
     * 
     * This method queries the default DNS server for an SRV record of the form
     * <code>_ldap._tcp.gc._msdcs.&lt;dnsDomain&gt;</code>. It returns the retrieved record
     * with the minimal priority. If multiple records share that priority, the one with the
     * maximal weight is returned.
     *
     * @param dnsDomain The DNS domain name. 
     * @return The most relevant GC server for the given domain.
     * @throws NamingException Upon errors resolving the server, see {@link #querySRV(String, String)}.
     */
    public static URI findGCServer(String dnsDomain) throws NamingException {

        return querySRV(dnsDomain,"_ldap._tcp.gc._msdcs");
    }

    /**
     * Converts a dot-separated DNS domain to an LDAP DN usable as an LDAP URI path.
     * 
     * For example, the domain <code>your.domain.com</code> is converted to
     * <code>/dc=your,dc=domain,dc=com</code>. The leading slash allows the result to be used
     * directly as the path of an LDAP URI. Empty labels, caused by leading or doubled dots,
     * are skipped. An empty domain yields an empty string.
     * 
     * @param dnsDomain A fully qualified DNS domain with or without trailing dot.
     * @return An LDAP DN for the given domain, prefixed by a slash.
     */
    public static String convertDomainToDN(String dnsDomain) {

        StringBuilder dn = new StringBuilder();

        for (String label : dnsDomain.split("\\.")) {

            if (label.isEmpty()) {
                continue;
            }

            dn.append(dn.length() == 0 ? '/' : ',');
            dn.append("dc=").append(label);
        }

        return dn.toString();
    }

    /**
     * Resolves a given server URI to a real LDAP URL.
     * 
     * See {@link #resolveServerURI(URI)} for the supported schemes.
     * 
     * @param uriString An URI describing the LDAP server to be connected.
     * @return An ldap or ldaps URL used for establishing an {@link LdapContext}.
     * @throws NamingException Upon an unsupported scheme or DNS resolution errors.
     * @throws URISyntaxException If <code>uriString</code> is not a valid URI.
     */
    public static URI resolveServerURI(String uriString) throws NamingException, URISyntaxException {

        URI uri = new URI(uriString);

        return resolveServerURI(uri);
    }

    /**
     * Prepends a sub context to the base DN contained in the path of a resolved server URI.
     * 
     * @param uri The resolved server URI, whose path holds the base DN.
     * @param subContext The path of the original <code>ads</code>/<code>gc</code> URI. It is
     *                   only treated as a sub context if it starts with <code>/</code> and
     *                   contains more than that slash.
     * @return A URI with the scheme, host, port and query of <code>uri</code> and the path
     *         <code>/&lt;subContext&gt;,&lt;baseDN&gt;</code>. Returns <code>uri</code> itself
     *         if no sub context is given.
     * @throws NamingException If the resulting URI is not well-formed.
     */
    private static URI addSubContext(URI uri, String subContext) throws NamingException {

        if (subContext == null || !subContext.startsWith("/") || subContext.length() == 1) {
            return uri;
        }

        String basePath = uri.getPath();

        String subPath = (basePath == null || !basePath.startsWith("/") || basePath.length() == 1) ?
                subContext :
                subContext + "," + basePath.substring(1);

        try {
            if (log.isDebugEnabled()) {
                log.debug("Resolved ADS subContext [{}] to absolute path [{}].",subContext,subPath);
            }

            // Keep the port; it is significant, e.g. for the global catalog (3268).
            return new URI(uri.getScheme(),null,uri.getHost(),uri.getPort(),subPath,uri.getQuery(),null);
        } catch (URISyntaxException e) {
            NamingException ne = new NamingException("Cannot append subcontext ["+subContext+"] to URI ["+uri+"]: "+e.getMessage());
            ne.setRootCause(e);
            throw ne;
        }
    }

    /**
     * Switches an <code>ldap</code> URI to the <code>ldaps</code> scheme.
     * 
     * The port of the plain LDAP service cannot be reused for LDAPS. The global catalog
     * port 3268 is mapped to its SSL counterpart 3269. Any other port is dropped, so that
     * the default LDAPS port is used.
     * 
     * @param uri The URI to convert, may be <code>null</code>.
     * @return The converted URI, or <code>uri</code> itself if it is <code>null</code> or
     *         its scheme is not <code>ldap</code>.
     * @throws URISyntaxException If the resulting URI is not well-formed.
     */
    private static URI changeLDAPToLDAPS(URI uri) throws URISyntaxException {

        if (uri == null || !"ldap".equals(uri.getScheme())) {
            return uri;
        }

        int port = uri.getPort() == GC_PORT ? GC_SSL_PORT : -1;

        return new URI("ldaps",null,uri.getHost(),port,uri.getPath(),uri.getQuery(),null);
    }

    /**
     * Resolves a given server URI to a real LDAP URL.
     * 
     * An URI with an <code>ldap</code> or <code>ldaps</code> scheme is returned unmodified.
     * 
     * For an URI with an <code>ads</code> or <code>adss</code> scheme, the host part is
     * passed to {@link #findAds(String)}.
     * 
     * For an URI with a <code>gc</code> or <code>gcs</code> scheme, the host part is
     * passed to {@link #findGCServer(String)}.
     * 
     * In both cases, the resolved URI is used as the result. A path of the given URI is
     * prepended to the resolved base DN as a sub context. The <code>adss</code> and
     * <code>gcs</code> schemes additionally switch the result to <code>ldaps</code>.
     * 
     * @param uri An URI describing the LDAP server to be connected.
     * @return An ldap or ldaps URL used for establishing an {@link LdapContext}.
     * @throws NamingException Upon an unsupported scheme or DNS resolution errors.
     * @throws URISyntaxException If the resulting ldaps URI is not well-formed.
     */
    public static URI resolveServerURI(URI uri) throws NamingException, URISyntaxException {

        String scheme = uri.getScheme();

        if ("ldap".equals(scheme) || "ldaps".equals(scheme)) {
            return uri;
        }
        if ("ads".equals(scheme)) {
            return addSubContext(findAds(uri.getHost()),uri.getPath());
        }
        if ("adss".equals(scheme)) {
            return changeLDAPToLDAPS(addSubContext(findAds(uri.getHost()),uri.getPath()));
        }
        if ("gc".equals(scheme)) {
            return addSubContext(findGCServer(uri.getHost()),uri.getPath());
        }
        if ("gcs".equals(scheme)) {
            return changeLDAPToLDAPS(addSubContext(findGCServer(uri.getHost()),uri.getPath()));
        }

        throw new NamingException("Unsupported URI scheme ["+scheme+"] given.");
    }


    /**
     * Connects to an LDAP server.
     * 
     * @param url The server to connect to, which must have a <code>ldap</code> or
     *               <code>ldaps</code> scheme.
     * @param auth The DN of the user and a password for authentication.
     *             If set to <code>null</code>,
     *             the connection is set up without authentication.
     * @param authMechanism The mechanism to use if <code>auth</code> is not null. It is
     *             passed as parameter {@link Context#SECURITY_AUTHENTICATION} to
     *             the initial LDAP context. If <code>null</code>, <code>simple</code> is used.
     * @return A connected LDAP context.
     * @throws NamingException Upon errors.
     */
    public static LdapContext connectToADS(URI url,
            PasswordAuthentication auth,
            String authMechanism) throws NamingException {

        Hashtable<String,String> env = new Hashtable<String,String>();
        env.put(Context.INITIAL_CONTEXT_FACTORY, "com.sun.jndi.ldap.LdapCtxFactory");
        env.put(Context.PROVIDER_URL,url.toString());
        env.put(Context.REFERRAL,"follow");

        if (auth == null) {

            if (log.isDebugEnabled()) {
                log.debug("Binding anonymously to URL ["+url+"].");
            }

            env.put(Context.SECURITY_AUTHENTICATION,"none");
        } else {

            // Hashtable rejects null values, so fall back to simple authentication.
            String mechanism = authMechanism == null ? "simple" : authMechanism;

            if (log.isDebugEnabled()) {
                log.debug("Binding as user ["+auth.getUserName()+"] to URL ["+url+"] with mechanism ["+mechanism+"].");
            }

            env.put(Context.SECURITY_AUTHENTICATION,mechanism);
            env.put(Context.SECURITY_PRINCIPAL,auth.getUserName());
            env.put(Context.SECURITY_CREDENTIALS,new String(auth.getPassword()));
        }

        return new InitialLdapContext(env, null);
    }

    /**
     * Reads the first value of an attribute as a string.
     * 
     * @param attributes The attributes to search, may be <code>null</code>.
     * @param attrId The ID of the attribute, may be <code>null</code>.
     * @return The string representation of the first value. Returns <code>null</code> if
     *         <code>attributes</code> or <code>attrId</code> is <code>null</code>, or if the
     *         attribute is missing or has no non-null value.
     * @throws NamingException Upon errors reading the attribute value.
     */
    private static String getFirstAttributeValue(Attributes attributes, String attrId) throws NamingException {

        if (attributes == null || attrId == null) {
            return null;
        }

        Attribute attribute = attributes.get(attrId);

        // Attribute.get() throws NoSuchElementException on an attribute without values.
        if (attribute == null || attribute.size() == 0) {
            return null;
        }

        Object value = attribute.get();
        return value == null ? null : value.toString();
    }

    /**
     * Looks up an attribute of a user search result and logs the outcome.
     * 
     * @param attributes The attributes of the search result, may be <code>null</code>. If
     *                   <code>null</code>, <code>null</code> is returned without logging.
     * @param dn The DN of the search result, used for logging only.
     * @param label A descriptive name of the attribute, used for logging only.
     * @param attrId The configured attribute ID. If <code>null</code>, the attribute is
     *               treated as not configured and <code>null</code> is returned without logging.
     * @return The string value of the attribute, or <code>null</code> if it is not available.
     * @throws NamingException Upon errors reading the attribute value.
     */
    private static String lookupUserAttribute(Attributes attributes, String dn, String label, String attrId) throws NamingException {

        if (attributes == null || attrId == null) {
            return null;
        }

        String value = getFirstAttributeValue(attributes,attrId);

        if (value == null) {
            log.warn("Search result with DN ["+dn+"] without "+label+" attribute "+attrId);
        } else if (log.isDebugEnabled()) {
            log.debug("Search result with DN ["+dn+"] has "+label+" attribute "+attrId+" with value "+value);
        }

        return value;
    }

    /**
     * Creates an AdsPrincipal with the details filled as well as possible.
     * 
     * Missing user name and pretty name attributes fall back to the DN of the search result.
     * If a mobile or token IDs attribute is configured in <code>domainConfig</code>, an
     * {@link MFAAdsPrincipal} is returned. Its token IDs are the whitespace-separated values
     * of the token IDs attribute.
     * 
     * @param domainConfig Domain configuration belonging to the search result.
     * @param result       The SearchResult record, i.e. after calling result.next().
     * @return             A new AdsPrincipal, or null if any param was null.
     * @throws NamingException Upon errors reading attribute values.
     */
    public static AdsPrincipal createPrincipal(DomainConfig domainConfig, SearchResult result) throws NamingException {

        if (domainConfig == null || result == null) {
            return null;
        }

        String userDn = result.getName();
        Attributes userAttributes = result.getAttributes();

        if (userAttributes == null) {
            log.warn("Search result with DN ["+userDn+"] has no attributes");
        }

        String userName = lookupUserAttribute(userAttributes,userDn,"userName",domainConfig.getUserAttribute());
        if (userName == null) {
            userName = userDn;
        }

        String prettyName = lookupUserAttribute(userAttributes,userDn,"prettyName",domainConfig.getPrettyNameAttribute());
        if (prettyName == null) {
            prettyName = userDn;
        }

        String eMailAddress = lookupUserAttribute(userAttributes,userDn,"eMailAddress",domainConfig.getEMailAddressAttribute());

        String ma = domainConfig.getMobileAttribute();
        String tka = domainConfig.getTokenIdsAttribute();

        if (ma == null && tka == null) {
            return new AdsPrincipal(userName, domainConfig.getDomain(), prettyName, eMailAddress);
        }

        // Only the configured MFA attributes are looked up; the others stay null.
        String mobileNumber = lookupUserAttribute(userAttributes,userDn,"mobile",ma);
        String tokenIdsValue = lookupUserAttribute(userAttributes,userDn,"tokenIds",tka);

        String[] tokenIds = tokenIdsValue == null ? null : tokenIdsValue.trim().split("\\s+");

        return new MFAAdsPrincipal(userName,domainConfig.getDomain(),prettyName,eMailAddress,mobileNumber,tokenIds);
    }


    /**
     * Creates an AdsGroup with the details filled as well as possible.
     * 
     * Missing group name and pretty name attributes fall back to the DN of the search result.
     * 
     * @param domainConfig Domain configuration belonging to the search result.
     * @param result       The SearchResult record, i.e. after calling result.next().
     * @return             A new AdsGroup, or null if any param was null.
     * @throws NamingException Upon errors reading attribute values.
     */
    public static AdsGroup createGroup(DomainConfig domainConfig, SearchResult result) throws NamingException {

        if (domainConfig == null || result == null) {
            return null;
        }

        String groupDn = result.getName();

        Attributes groupAttributes = result.getAttributes();
        String groupName = groupDn;
        String prettyName = groupDn;

        if (groupAttributes == null) {
            log.warn("Search result with DN ["+groupDn+"] has no attributes");
        } else {
            String groupNameValue = getFirstAttributeValue(groupAttributes,domainConfig.getGroupAttribute());
            if (groupNameValue == null) {
                log.warn("Search result with DN ["+groupDn+"] without attribute "+domainConfig.getGroupAttribute());
            } else {
                groupName = groupNameValue;
                if (log.isDebugEnabled())
                    log.debug("Search result with DN ["+groupDn+"] has groupName attribute "+domainConfig.getGroupAttribute()+" with value "+groupName);
            }

            String prettyNameValue = getFirstAttributeValue(groupAttributes,domainConfig.getPrettyNameAttribute());
            if (prettyNameValue == null) {
                log.warn("Search result with DN ["+groupDn+"] without attribute "+domainConfig.getPrettyNameAttribute());
            } else {
                prettyName = prettyNameValue;
                if (log.isDebugEnabled())
                    log.debug("Search result with DN ["+groupDn+"] has prettyName attribute "+domainConfig.getPrettyNameAttribute()+" with value "+prettyName);
            }
        }

        return new AdsGroup(groupName, domainConfig.getDomain(), prettyName);
    }

    /**
     * Parses the base DN contained in the path of an LDAP server URI.
     * 
     * @param uri The LDAP server URI.
     * @return The base DN without the leading slash of the path; empty if there is no path.
     * @throws InvalidNameException If the path is not a valid DN.
     */
    private static LdapName getBaseDnFromUri(URI uri) throws InvalidNameException {

        String path = uri.getPath();

        if (path == null) {
            return new LdapName("");
        }

        return new LdapName(path.startsWith("/") ? path.substring(1) : path);
    }

    /**
     * Parses a DN, treating <code>null</code> as the empty DN.
     * 
     * @param dn The DN to parse, may be <code>null</code>.
     * @return The parsed DN.
     * @throws InvalidNameException If <code>dn</code> is not a valid DN.
     */
    private static LdapName parseDn(String dn) throws InvalidNameException {

        return new LdapName(dn == null ? "" : dn);
    }

    /**
     * Gets the absolute DN of a SearchResult using the base DN of uri and additionalBasePath.
     *
     * @param uri The LDAP server URI with its path being the basic baseDN.
     * @param additionalBasePath The subtree base DN you passed as the first parameter to the
     *                           search method. May be <code>null</code> or empty.
     * @param gotResult The result for which you like to get the absolute DN.
     * @return The absolute DN with additional base DN if it was relative in the first place.
     * @throws InvalidNameException If either of the two base DNs or the resulting DN of
     *                              gotResult is invalid.
     */
    public static String getAbsoluteDn(URI uri, String additionalBasePath, SearchResult gotResult) throws InvalidNameException {

        // it was absolute in the first place
        if (!gotResult.isRelative()) {
            return gotResult.getName();
        }

        // The returned DN is relative to uri.path + additionalBasePath, so add both as suffix.
        Name fullDn = getBaseDnFromUri(uri)
                .addAll(parseDn(additionalBasePath))
                .addAll(new LdapName(gotResult.getName()));

        return fullDn.toString();
    }

    /**
     * Gets the DN relative to the base DN of uri and additionalBasePath.
     *
     * @param uri The LDAP server URI with its path being the basic baseDN.
     * @param additionalBasePath The subtree base DN to be relative in. May be
     *                           <code>null</code> or empty.
     * @param absolutePath The absolute path to make relative.
     * @return The relative DN (to uri.path + additionalBasePath). Returns an empty string if
     *         absolutePath equals the combined base DN. Returns <code>null</code> if
     *         absolutePath is not located below that base DN.
     * @throws InvalidNameException If either of the two base DNs or absolutePath is invalid.
     */
    public static String getRelativeDn(URI uri, String additionalBasePath, String absolutePath) throws InvalidNameException {

        Name fullBaseDn = getBaseDnFromUri(uri).addAll(parseDn(additionalBasePath));
        Name absoluteDn = new LdapName(absolutePath);

        // the base DN itself is the empty relative DN
        if (absoluteDn.equals(fullBaseDn)) {
            return "";
        }

        // a DN outside the base DN, including any shorter DN, has no relative form
        if (absoluteDn.size() <= fullBaseDn.size() || !absoluteDn.startsWith(fullBaseDn)) {
            return null;
        }

        return absoluteDn.getSuffix(fullBaseDn.size()).toString();
    }

    /**
     * Flag of the Active Directory attribute <code>userAccountControl</code> that marks a
     * disabled account (<code>ADS_UF_ACCOUNTDISABLE</code>, value 2).
     */
    public static final int USER_ACCOUNT_DISABLED = 2;

    /**
     * Checks the flag {@link #USER_ACCOUNT_DISABLED} (value 2) of the Active Directory
     * attribute <code>userAccountControl</code>, which means "account disabled".
     * https://ldapwiki.com/wiki/USER_ACCOUNT_DISABLED
     * https://docs.microsoft.com/en-us/windows/win32/api/iads/ne-iads-ads_user_flag_enum
     * 
     * @param userAttributes The attributes of a user entry, may be <code>null</code>.
     * @return <code>true</code> if the account is flagged as disabled. <code>false</code>
     *         otherwise, including when the attribute is missing.
     * @throws NamingException Upon errors reading the attribute.
     * @throws NumberFormatException If the attribute value is not a decimal number.
     */
    public static boolean isDisabledAdUser(Attributes userAttributes) throws NamingException {

        String attributeValue = getFirstAttributeValue(userAttributes,"userAccountControl");

        if (attributeValue == null) {
            log.trace("userAccountControl was null");
            return false;
        }

        long flags = Long.parseLong(attributeValue.trim());
        return (flags & USER_ACCOUNT_DISABLED) != 0;
    }
}
