package org.clazzes.util.sql.dao;

import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import org.clazzes.util.aop.DAOException;
import org.clazzes.util.sql.transactionprovider.ClosableTransaction;
import org.clazzes.util.sql.transactionprovider.TransactionException;
import org.clazzes.util.sql.transactionprovider.TransactionProvider;

@SuppressWarnings("unchecked")
public class MockIdDAO<T, I extends Serializable> implements IIdDAO<T> {
	private static final class TransactionValues<T, I extends Serializable> {
		public final long id;
		public final Map<I, T> values = new HashMap<>();
		public final Set<I> deletedValues = new HashSet<>();
		public TransactionValues(long id) {
			this.id = id;
		}
	}

	private final Deque<TransactionValues<T, I>> transactionValues = new ArrayDeque<>();
	private final Map<I, T> rootValues = new HashMap<>();
	private final Function<T, I> getIdFunction;
	private final BiConsumer<T, I> setIdFunction;
	private final Function<T, T> copy;
	private final Class<I> idClass;
	private Consumer<T> validator = null;
	private Supplier<DAOException> errorInjector = null;

	public MockIdDAO(Function<T, I> getIdFunction,
			BiConsumer<T, I> setIdFunction, Function<T, T> copy,
			Class<I> idClass) {
		this.getIdFunction = getIdFunction;
		this.setIdFunction = setIdFunction;
		this.copy = copy;
		this.idClass = idClass;
	}

	private T getStoredDto(T dto) {
		I id = this.getId(dto);
		if (id == null && this.idClass == Long.class || this.idClass == long.class) {
			this.setIdFunction.accept(dto, (I) (Long) MockIdGenerator.next());
		}
		T ret = this.copy.apply(dto);
		if (this.validator != null) {
			this.validator.accept(ret);
		}
		return ret;
	}

	private void checkErrorInjector() {
		if (this.errorInjector != null) {
			DAOException error = this.errorInjector.get();
			if (error != null) {
				throw error;
			}
		}
	}

	@Override
	public T save(T dto_) {
		this.checkErrorInjector();
		if (this.getId(dto_) != null && this.get(this.getId(dto_)) != null) {
			throw new DAOException("id must be unique");
		}
		T dto = this.getStoredDto(dto_);

		TransactionValues<T, I> transaction = this.transactionValues.peekLast();
		Map<I, T> map = transaction == null ? this.rootValues : transaction.values;
		Set<I> deletedValues = transaction == null ? null : transaction.deletedValues;
		map.put(this.getId(dto), dto);
		if (deletedValues != null) {
			deletedValues.remove(this.getId(dto));
		}
		return dto;
	}

	@Override
	public List<T> saveBatch(List<T> dtos) {
		this.checkErrorInjector();
		Set<I> ids = new HashSet<>();
		for (T dto: dtos) {
			I id = this.getId(dto);
			if (id != null) {
				if (ids.contains(id)) {
					throw new DAOException("id must be unique");
				} else {
					ids.add(id);
				}
			}
		}
		if (!this.getBatch(ids).isEmpty()) {
			throw new DAOException("id must be unique");
		}
		TransactionValues<T, I> transaction = this.transactionValues.peekLast();
		Map<I, T> map = transaction == null ? this.rootValues : transaction.values;
		Set<I> deletedValues = transaction == null ? null : transaction.deletedValues;
		for (T dto_: dtos) {
			T dto = this.getStoredDto(dto_);
			map.put(this.getId(dto), dto);
			if (deletedValues != null) {
				deletedValues.remove(this.getId(dto));
			}
		}
		return dtos;
	}

	private void commit(long id) {
		TransactionValues<T, I> commitedTransaction = this.transactionValues.pollLast();
		if (commitedTransaction == null || commitedTransaction.id != id) {
			throw new DAOException("mismatched commit/begin/rollback.");
		}
		TransactionValues<T, I> upperTransaction = this.transactionValues.peekLast();
		Map<I, T> map = upperTransaction == null ? this.rootValues : upperTransaction.values;
		Set<I> deletedValues = upperTransaction == null ? null : upperTransaction.deletedValues;

		for (I deletedId: commitedTransaction.deletedValues) {
			map.remove(deletedId);
			if (deletedValues != null) {
				deletedValues.add(deletedId);
			}
		}

		for (T dto: commitedTransaction.values.values()) {
			map.put(this.getId(dto), dto);
		}
	}

	private void rollback(long id) {
		TransactionValues<T, I> rolledbackTransaction = this.transactionValues.pollLast();
		if (rolledbackTransaction == null || rolledbackTransaction.id != id) {
			throw new DAOException("mismatched commit/begin/rollback.");
		}
	}

	@Override
	public int update(T dto) {
		this.checkErrorInjector();
		if (this.getId(dto) == null || this.get(this.getId(dto)) == null) {
			throw new DAOException("update on non-existent dto.");
		}

		TransactionValues<T, I> transaction = this.transactionValues.peekLast();
		Map<I, T> map = transaction == null ? this.rootValues : transaction.values;

		map.put(this.getId(dto), dto);

		return 1;
	}

	@Override
	public int[] updateBatch(Collection<T> dtos) {
		this.checkErrorInjector();
		Set<I> ids = new HashSet<>();
		for (T dto: dtos) {
			I id = this.getId(dto);
			if (id == null) {
				throw new DAOException("update must specifiy id.");
			} else if (ids.contains(id)) {
				throw new DAOException("id must be unique");
			} else {
				ids.add(id);
			}
		}
		if (this.getBatch(ids).size() != ids.size()) {
			throw new DAOException("update on non-existent dto.");
		}

		TransactionValues<T, I> transaction = this.transactionValues.peekLast();
		Map<I, T> map = transaction == null ? this.rootValues : transaction.values;

		for (T dto: dtos) {
			map.put(this.getId(dto), dto);
		}

		return dtos.stream()
			.mapToInt(a -> 1)
			.toArray();
	}

	private static interface TriFunction<A, B, C, R> {
		public R apply(A a, B b, C c);
	}

	private <T1, T2, T3, T4> T2 reduceValues(TriFunction<Set<I>, Function<T4, T2>, T3, T1> processDeletedValues, TriFunction<Map<I, T>, Function<T3, T1>, T4, T2> processMap, T4 init) {
		Iterator<TransactionValues<T, I>> it = this.transactionValues.iterator();
		return reduceValuesImpl(processDeletedValues, processMap, it, init);
	}

	private <T4, T3, T1, T2> T2 reduceValuesImpl(
			TriFunction<Set<I>, Function<T4, T2>, T3, T1> processDeletedValues,
			TriFunction<Map<I, T>, Function<T3, T1>, T4, T2> processMap,
			Iterator<TransactionValues<T, I>> it, T4 v1) {
		if (it.hasNext()) {
			TransactionValues<T, I> transactionValues = it.next();
			return processMap.apply(transactionValues.values, v2 -> processDeletedValues.apply(transactionValues.deletedValues, v3 -> reduceValuesImpl(processDeletedValues, processMap, it, v3), v2), v1);
		} else {
			return processMap.apply(this.rootValues, null, v1);
		}
	}

	@Override
	public List<T> getAll() {
		this.checkErrorInjector();
		Set<I> seenIds = new HashSet<>();
		List<T> ret = new ArrayList<>();

		this.reduceValues((deletedValues, rec, v) -> {
				seenIds.addAll(deletedValues);
				return rec == null ? null : rec.apply(null);
			}, (map, rec, v) -> {
				for (T dto: map.values()) {
					if (!seenIds.contains(this.getId(dto))) {
						seenIds.add(this.getId(dto));
						ret.add(dto);
					}
				}
				return rec == null ? null : rec.apply(null);
			}, null);

		return ret;
	}


	@Override
	public T get(Serializable id_) {
		this.checkErrorInjector();
		return this.<T, T, I, I>reduceValues((deletedValues, rec, id) -> {
				if (deletedValues.contains(id)) {
					return null;
				} else {
					return rec == null ? null : rec.apply(id);
				}
			}, (map, rec, id) -> {
				if (map.containsKey(id)) {
					return map.get(id);
				} else {
					return rec == null ? null : rec.apply(id);
				}
			}, (I) id_);
	}

	@Override
	public List<T> getBatch(Serializable... ids) {
		return this.getBatch(Arrays.asList(ids));
	}

	@Override
	public List<T> getBatch(Collection<? extends Serializable> ids_) {
		this.checkErrorInjector();
		return this.<List<T>, List<T>, Collection<I>, Collection<I>>reduceValues((deletedValues, rec, ids) -> {
				if (ids.isEmpty()) {
					return new ArrayList<>();
				}

				return rec == null ? new ArrayList<>() : rec.apply(ids.stream()
																   .filter(id -> !deletedValues.contains(id))
																   .collect(Collectors.toList()));
			}, (map, rec, ids) -> {
				if (ids.isEmpty()) {
					return new ArrayList<>();
				}

				List<T> ret = rec == null ? new ArrayList<>() : rec.apply(ids.stream()
																		  .filter(id -> !map.containsKey(id))
																		  .collect(Collectors.toList()));

				ids.stream()
					.filter(id -> map.containsKey(id))
					.map(id -> map.get(id))
					.forEach(ret::add);

				return ret;
			}, (Collection<I>) ids_);
	}

	@Override
	public boolean delete(Serializable id) {
		this.checkErrorInjector();
		TransactionValues<T, I> transaction = this.transactionValues.peekLast();
		Map<I, T> map = transaction == null ? this.rootValues : transaction.values;
		Set<I> deletedValues = transaction == null ? null : transaction.deletedValues;
		map.remove(id);
		if (deletedValues != null) {
			deletedValues.add((I) id);
		}
		return true;
	}

	@Override
	public int[] deleteBatch(Collection<? extends Serializable> ids) {
		this.checkErrorInjector();
		TransactionValues<T, I> transaction = this.transactionValues.peekLast();
		Map<I, T> map = transaction == null ? this.rootValues : transaction.values;
		Set<I> deletedValues = transaction == null ? null : transaction.deletedValues;
		for (I id: (Collection<I>) ids) {
			map.remove(id);
			if (deletedValues != null) {
				deletedValues.add((I) id);
			}
		}
		return ids.stream().mapToInt(a -> 1).toArray();
	}

	@Override
	public I getId(T dto) {
		return this.getIdFunction.apply(dto);
	}

	@Override
	public Class<I> getIdClass() {
		return this.idClass;
	}

	private ClosableTransaction transaction(Runnable closer) {
		long id = MockIdGenerator.next();

		this.transactionValues.addLast(new TransactionValues<>(id));

		return new ClosableTransaction() {
			boolean willCommit = false;

			@Override
			public void willCommit() {
				this.willCommit = true;
			}

			@Override
			public void willRollback() {
				this.willCommit = false;
			}

			@Override
			public void close() throws TransactionException {
				if (this.willCommit) {
					commit(id);
				} else {
					rollback(id);
				}
				if (closer != null) {
					closer.run();
				}
			}
		};
	}

	private boolean isInTransaction = false;

	public TransactionProvider transactionProvider() {
		return new TransactionProvider() {

			@Override
			public ClosableTransaction getTransaction() {
				if (MockIdDAO.this.isInTransaction) {
					throw new TransactionException("nested transaction");
				} else {
					MockIdDAO.this.isInTransaction = true;
					return transaction(() -> MockIdDAO.this.isInTransaction = false);
				}
			}

			@Override
			public ClosableTransaction getTransaction(int isolationLevel) {
				return this.getTransaction();
			}

			@Override
			public ClosableTransaction getTransaction(
					String useThreadLocalKey) {
				throw new TransactionException("useThreadLocalKey should not be used.");
			}

			@Override
			public ClosableTransaction getTransaction(String useThreadLocalKey,
					int isolationLevel) {
				throw new TransactionException("useThreadLocalKey should not be used.");
			}

			@Override
			public boolean isActive() {
				return true;
			}
		};
	}

	public static TransactionProvider transactionProvider(MockIdDAO<?, ?>... daos) {
		List<TransactionProvider> providers = Arrays.stream(daos)
			.map(a -> a.transactionProvider())
			.collect(Collectors.toList());

		return new TransactionProvider() {

            @Override
            public ClosableTransaction getTransaction() {
				List<ClosableTransaction> transactions = providers.stream()
					.map(a -> a.getTransaction())
					.collect(Collectors.toList());

				return new ClosableTransaction() {

                    @Override
                    public void willCommit() {
						transactions.forEach(ClosableTransaction::willCommit);
                    }

                    @Override
                    public void willRollback() {
						transactions.forEach(ClosableTransaction::willRollback);
                    }

                    @Override
                    public void close() throws TransactionException {
						transactions.forEach(ClosableTransaction::close);
                    }
				};
            }

            @Override
            public ClosableTransaction getTransaction(int isolationLevel) {
				return getTransaction();
            }

            @Override
            public ClosableTransaction getTransaction(
                    String useThreadLocalKey) {
				throw new TransactionException("useThreadLocalKey should not be used.");
            }

            @Override
            public ClosableTransaction getTransaction(String useThreadLocalKey,
                    int isolationLevel) {
				throw new TransactionException("useThreadLocalKey should not be used.");
            }

            @Override
            public boolean isActive() {
				return true;
            }
		};

	}

	public ClosableTransaction getTransaction() {
		return this.transaction(null);
	}

	public List<T> getAllNoIdValues() {
		return this.getAll()
			.stream()
			.map(dto -> {
					T ret = MockIdDAO.this.copy.apply(dto);
					MockIdDAO.this.setIdFunction.accept(ret, null);
					return ret;
				})
			.collect(Collectors.toList());
	}

	public T withErrorInjection(Supplier<DAOException> error, Supplier<T> fun) {
		Supplier<DAOException> prev = this.errorInjector;
		this.errorInjector = error;
		try {
			return fun.get();
		} finally {
			this.errorInjector = prev;
		}
	}

	public <R> R withValidator(Consumer<T> validator, Supplier<R> fun) {
		Consumer<T> prev = this.validator;
		this.validator = validator;
		try {
			return fun.get();
		} finally {
			this.validator = prev;
		}

	}
}
