package org.clazzes.svc.runner.sshd;

import java.io.IOException;
import java.io.InputStream;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * An {@link InputStream} handle that shares one underlying stream with other
 * handles and closes the underlying stream only when the last handle is closed.
 * <p>
 * Wrapping another {@code RefCountedInputStream} does not nest wrappers: the new
 * handle shares the same underlying stream and adds one reference. Each handle has
 * its own "closed" flag, so closing a handle releases its reference exactly once.
 * <p>
 * Reads are forwarded to the underlying stream without extra synchronization, so
 * all handles of one underlying stream share its read position.
 */
public class RefCountedInputStream extends InputStream {
    private static final org.slf4j.Logger log = org.slf4j.LoggerFactory.getLogger(RefCountedInputStream.class);

    /** State shared by all handles that wrap the same underlying stream. */
    private static final class SharedData {
        private final InputStream underlying;
        /** Number of handles that have not been closed yet; 0 means the underlying stream was closed. */
        private final AtomicInteger refs = new AtomicInteger(1);

        public SharedData(InputStream underlying) {
            this.underlying = underlying;
        }

        /**
         * Adds a reference for a new handle.
         * <p>
         * Uses a CAS loop so that a count that has already dropped to zero (the
         * underlying stream is closed or being closed) is never revived.
         *
         * @throws IllegalStateException if all handles were already closed
         */
        void acquire() {
            int current;
            do {
                current = refs.get();
                if (current <= 0) {
                    throw new IllegalStateException("Underlying stream already closed: " + underlying);
                }
            } while (!refs.compareAndSet(current, current + 1));
        }
    }

    private final SharedData data;
    /** Per-handle flag; guarantees this handle releases its reference only once. */
    private final AtomicBoolean closed = new AtomicBoolean(false);

    /**
     * Creates an additional handle on the stream shared by {@code other}.
     *
     * @param other an existing handle whose underlying stream is still open
     * @throws NullPointerException if {@code other} is {@code null}
     * @throws IllegalStateException if every handle of {@code other}'s stream was already closed
     */
    public RefCountedInputStream(RefCountedInputStream other) {
        this((InputStream) other);
    }

    /**
     * Wraps {@code other}. If {@code other} is itself a {@code RefCountedInputStream},
     * its underlying stream is shared and its reference count is incremented;
     * otherwise {@code other} becomes a new underlying stream with one reference.
     *
     * @param other the stream to wrap or share
     * @throws NullPointerException if {@code other} is {@code null}
     * @throws IllegalStateException if {@code other} is a {@code RefCountedInputStream}
     *         whose handles have all been closed
     */
    public RefCountedInputStream(InputStream other) {
        Objects.requireNonNull(other, "other");
        if (other instanceof RefCountedInputStream refCounted) {
            refCounted.data.acquire();
            this.data = refCounted.data;
        } else {
            this.data = new SharedData(other);
        }
    }

    /** Throws if this handle has been closed (regardless of other handles). */
    private void ensureOpen() throws IOException {
        if (closed.get()) {
            throw new IOException("Stream closed");
        }
    }

    /**
     * Reads from the underlying stream.
     *
     * @throws IOException if this handle was closed, or if the underlying read fails
     */
    @Override
    public int read(byte[] b, int off, int len) throws IOException {
        ensureOpen();
        return this.data.underlying.read(b, off, len);
    }

    /**
     * Reads one byte from the underlying stream.
     *
     * @return the byte as an unsigned value in {@code 0..255}, or {@code -1} at end of stream
     * @throws IOException if this handle was closed, or if the underlying read fails
     */
    @Override
    public int read() throws IOException {
        ensureOpen();
        // Delegate directly: the underlying read() already honours the 0..255 / -1 contract.
        // The old approach returned the raw signed byte, so 0x80..0xFF came back negative
        // and 0xFF was indistinguishable from end of stream.
        return this.data.underlying.read();
    }

    /**
     * Forwards to the underlying stream's {@code available()}. Without this override
     * the {@link InputStream} default would always report 0.
     *
     * @throws IOException if this handle was closed, or if the underlying call fails
     */
    @Override
    public int available() throws IOException {
        ensureOpen();
        return this.data.underlying.available();
    }

    /**
     * Releases this handle's reference. The underlying stream is closed when the last
     * open handle is released. Calling this more than once on the same handle has no
     * further effect.
     *
     * @throws IOException if closing the underlying stream fails
     */
    @Override
    public void close() throws IOException {
        if (closed.getAndSet(true)) {
            return;
        }

        var endRefs = this.data.refs.decrementAndGet();
        if (endRefs <= 0) {
            log.trace("Real close in ref counted input stream for {}", this.data.underlying);
            this.data.underlying.close();
        }
    }

    @Override
    public String toString() {
        return "RefCountedInputStream [refCount="+this.data.refs.get()+", underlying="+this.data.underlying+"]";
    }

}
