package org.clazzes.util.lang;

import java.util.AbstractMap;
import java.util.AbstractSet;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;

/**
 * A lazily evaluated {@link Map} view over an underlying map whose keys and
 * values are transformed by mapping functions.
 * <p>
 * Nothing is copied: every lookup delegates to the underlying map, and the
 * value mapper is invoked on each access (e.g. each {@link #get(Object)} call),
 * so changes to the underlying map are visible through this view.
 * <p>
 * Mapper contract:
 * <ul>
 * <li>{@code keyMapper} converts underlying keys ({@code K}) into view keys
 * ({@code KM}). It should be injective, because {@link #size()} reports the
 * underlying map's size.</li>
 * <li>{@code reverseKeyMapper} converts a view key back into the underlying
 * key. It is applied to arbitrary objects passed to {@link #get(Object)} and
 * {@link #containsKey(Object)}, so it should tolerate foreign key types.</li>
 * </ul>
 * {@code put} is not overridden and therefore throws
 * {@link UnsupportedOperationException} (inherited from {@link AbstractMap}).
 * Whether removal through the views works depends on {@code MappedSet},
 * {@code MappedCollection} and {@code MappedIterator}, which are not shown
 * here.
 *
 * @param <K>  key type of the underlying map
 * @param <KM> key type exposed by this view
 * @param <V>  value type of the underlying map
 * @param <VM> value type exposed by this view
 */
public final class MappedMap<K,KM,V,VM> extends AbstractMap<KM,VM> {
    private final Map<K,V> underlying;
    private final Function<V,VM> valueMapper;
    private final Function<K,KM> keyMapper;
    private final Function<Object,Object> reverseKeyMapper;

    /**
     * Creates the view; use one of the {@code newInstance} factories.
     *
     * @throws NullPointerException if any argument is {@code null}
     */
    private MappedMap(Map<K, V> underlying, Function<V, VM> valueMapper,
            Function<K, KM> keyMapper,
            Function<Object, Object> reverseKeyMapper) {
        // Fail fast here instead of with an obscure NPE on first access.
        this.underlying = Objects.requireNonNull(underlying, "underlying");
        this.valueMapper = Objects.requireNonNull(valueMapper, "valueMapper");
        this.keyMapper = Objects.requireNonNull(keyMapper, "keyMapper");
        this.reverseKeyMapper = Objects.requireNonNull(reverseKeyMapper, "reverseKeyMapper");
    }

    /**
     * Tests whether the underlying map contains the reverse-mapped key.
     */
    @Override
    public boolean containsKey(Object arg0) {
        return this.underlying.containsKey(this.reverseKeyMapper.apply(arg0));
    }

    /**
     * Tests whether any mapped value equals {@code arg0}. This applies the
     * value mapper to each underlying value (linear scan).
     */
    @Override
    public boolean containsValue(Object arg0) {
        return this.values().contains(arg0);
    }

    /**
     * Returns a fresh set view of the mapped entries. Each iterated entry is an
     * immutable snapshot built from the underlying entry's mapped key and value.
     */
    @Override
    public Set<Map.Entry<KM, VM>> entrySet() {

        return new AbstractSet<Map.Entry<KM,VM>> () {
            private final Set<Entry<K,V>> underlyingSet = MappedMap.this.underlying.entrySet();
            private final Function<Entry<K,V>,Entry<KM,VM>> entryMapper = entry -> new AbstractMap.SimpleImmutableEntry<>(MappedMap.this.keyMapper.apply(entry.getKey()),MappedMap.this.valueMapper.apply(entry.getValue()));

            /**
             * Looks the entry up by key instead of scanning, then compares the
             * mapped value.
             */
            @Override
            public boolean contains(Object arg0) {
                if (!(arg0 instanceof Entry<?,?>)) {
                    return false;
                }
                Entry<?,?> entry = (Entry<?,?>) arg0;
                Object mappedKey = MappedMap.this.reverseKeyMapper.apply(entry.getKey());
                V rawValue = MappedMap.this.underlying.get(mappedKey);
                // A null from get() is ambiguous: absent key or null value.
                // Only then pay for the second lookup.
                if (rawValue == null && !MappedMap.this.underlying.containsKey(mappedKey)) {
                    return false;
                }
                VM value = MappedMap.this.valueMapper.apply(rawValue);
                return Objects.equals(value, entry.getValue());
            }

            @Override
            public Iterator<Entry<KM, VM>> iterator() {
                return new MappedIterator<>(this.underlyingSet.iterator(),this.entryMapper);
            }

            @Override
            public int size() {
                return this.underlyingSet.size();
            }
        };
    }

    /**
     * Returns the mapped value for the given view key, or {@code null} if the
     * reverse-mapped key is absent from the underlying map. For a present key
     * whose underlying value is {@code null}, the value mapper is still applied
     * to {@code null}.
     */
    @Override
    public VM get(Object arg0) {
        Object mappedKey = this.reverseKeyMapper.apply(arg0);
        V rawValue = this.underlying.get(mappedKey);

        // Single lookup on the common path; disambiguate only a null result.
        if (rawValue == null && !this.underlying.containsKey(mappedKey)) {
            return null;
        }

        return this.valueMapper.apply(rawValue);
    }

    /**
     * Returns a key-set view backed by the underlying map's key set.
     */
    @Override
    public Set<KM> keySet() {
        return MappedSet.newInstance(this.underlying.keySet(),this.keyMapper,this.reverseKeyMapper);
    }

    /**
     * Returns the size of the underlying map.
     */
    @Override
    public int size() {
        return this.underlying.size();
    }

    /**
     * Returns a collection view backed by the underlying map's values.
     */
    @Override
    public Collection<VM> values() {
        return MappedCollection.newInstance(this.underlying.values(), this.valueMapper);
    }

    /**
     * Creates a view that keeps the underlying keys unchanged and maps only the
     * values.
     *
     * @param underlying  the backing map
     * @param valueMapper function applied to each underlying value on access
     * @return a new view over {@code underlying}
     * @throws NullPointerException if any argument is {@code null}
     */
    public static <K,V,VM> MappedMap<K,K,V,VM> newInstance(Map<K,V> underlying, Function<V,VM> valueMapper) {
        return new MappedMap<K,K,V,VM>(underlying, valueMapper, Function.identity(), Function.identity());
    }


    /**
     * Creates a view that maps both keys and values.
     *
     * @param underlying       the backing map
     * @param valueMapper      function applied to each underlying value on access
     * @param keyMapper        converts underlying keys to view keys
     * @param reverseKeyMapper converts a view key (or any object passed to a
     *                         lookup method) back to an underlying key
     * @return a new view over {@code underlying}
     * @throws NullPointerException if any argument is {@code null}
     */
    public static <K,KM,V,VM> MappedMap<K,KM,V,VM> newInstance(Map<K,V> underlying, Function<V,VM> valueMapper,
            Function<K, KM> keyMapper,
            Function<Object, Object> reverseKeyMapper) {
        return new MappedMap<K,KM,V,VM>(underlying, valueMapper, keyMapper, reverseKeyMapper);
    }
}
