/*
 * Decompiled with CFR 0.152.
 */
package cadyts.utilities.math.metropolishastings;

import cadyts.utilities.math.metropolishastings.MHProposal;
import cadyts.utilities.math.metropolishastings.MHStateProcessor;
import cadyts.utilities.math.metropolishastings.MHTransition;
import cadyts.utilities.math.metropolishastings.MHWeight;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Metropolis-Hastings sampler over states of type {@code S}.
 * <p>
 * Starting from an initial state, the sampler repeatedly asks the
 * {@link MHProposal} for a transition, evaluates the log weight of the
 * proposed state with the {@link MHWeight}, and accepts or rejects the
 * proposal with the Metropolis-Hastings acceptance rule. Every state of the
 * resulting chain (a rejected proposal repeats the current state) is passed to
 * all registered {@link MHStateProcessor}s.
 * <p>
 * Instances are mutable and not synchronized.
 *
 * @param <S> the state type
 */
public class MHAlgorithm<S> {

    /** Nanoseconds per millisecond; used to report the accumulated computation time. */
    private static final long NANOS_PER_MILLI = 1000000L;

    private final MHProposal<S> proposal;
    private final MHWeight<S> weight;
    private final Random rnd;
    private S initialState = null;
    private final List<MHStateProcessor<S>> stateProcessors = new ArrayList<MHStateProcessor<S>>();
    private int msgInterval = 1;
    /*
     * Accumulated computation time of the current/most recent run, in
     * nanoseconds. It is measured with the monotonic System.nanoTime() and kept in
     * nanoseconds because a single MH step may take less than one millisecond.
     * Summing per-step currentTimeMillis() differences would quantise every step
     * to whole milliseconds, and it is also affected by wall-clock adjustments.
     */
    private long lastCompTime_ns = 0L;

    /**
     * Creates a sampler.
     *
     * @param proposal creates initial states and proposal transitions
     * @param weight   evaluates the (unnormalised) log weight of a state
     * @param rnd      source of randomness for the accept/reject decisions
     * @throws IllegalArgumentException if any argument is {@code null}
     */
    public MHAlgorithm(MHProposal<S> proposal, MHWeight<S> weight, Random rnd) {
        if (proposal == null) {
            throw new IllegalArgumentException("proposal is null");
        }
        if (weight == null) {
            throw new IllegalArgumentException("weight is null");
        }
        if (rnd == null) {
            throw new IllegalArgumentException("rnd is null");
        }
        this.proposal = proposal;
        this.weight = weight;
        this.rnd = rnd;
    }

    /**
     * Sets the state from which {@link #run(int)} starts.
     *
     * @param initialState the start state, or {@code null} to let
     *                     {@code proposal.newInitialState()} create one at the
     *                     beginning of each run
     */
    public void setInitialState(S initialState) {
        this.initialState = initialState;
    }

    /**
     * Returns the configured start state.
     *
     * @return the start state, or {@code null} if none was set
     */
    public S getInitialState() {
        return this.initialState;
    }

    /**
     * Sets how often progress is printed to standard output: a message is
     * printed at every iteration index that is a multiple of this value.
     *
     * @param msgInterval the message interval, at least 1
     * @throws IllegalArgumentException if {@code msgInterval < 1}
     */
    public void setMsgInterval(int msgInterval) {
        if (msgInterval < 1) {
            throw new IllegalArgumentException("message interval < 1");
        }
        this.msgInterval = msgInterval;
    }

    /**
     * Returns the progress message interval.
     *
     * @return the message interval (at least 1)
     */
    public int getMsgInterval() {
        return this.msgInterval;
    }

    /**
     * Registers a processor that will receive every state of the chain.
     *
     * @param stateProcessor the processor to add
     * @throws IllegalArgumentException if {@code stateProcessor} is {@code null}
     */
    public void addStateProcessor(MHStateProcessor<S> stateProcessor) {
        if (stateProcessor == null) {
            throw new IllegalArgumentException("state processor is null");
        }
        this.stateProcessors.add(stateProcessor);
    }

    /**
     * Returns the time, in milliseconds, that the most recent (or currently
     * executing) {@link #run(int)} spent on sampling. That covers creating and
     * weighting the initial state, generating proposals, weighting them and
     * making the accept/reject decisions. Time spent in state processors and in
     * progress printing is excluded.
     *
     * @return the accumulated computation time in milliseconds
     */
    public long getLastCompTime_ms() {
        return this.lastCompTime_ns / NANOS_PER_MILLI;
    }

    /**
     * Runs the Metropolis-Hastings chain.
     * <p>
     * All processors are first started. Then the initial state is processed,
     * followed by one proposal/accept-reject step for each {@code i} in
     * {@code [1, iterations)}. After every step the current state is processed.
     * Finally, all processors are ended. The processors therefore receive
     * {@code max(iterations, 1)} states in total.
     *
     * @param iterations the total number of chain states to produce, including
     *                   the initial state
     */
    public void run(int iterations) {
        this.lastCompTime_ns = 0L;
        for (MHStateProcessor<S> processor : this.stateProcessors) {
            processor.start();
        }

        long tick_ns = System.nanoTime();
        S currentState = (this.initialState != null) ? this.initialState : this.proposal.newInitialState();
        double currentLogWeight = this.weight.logWeight(currentState);
        this.lastCompTime_ns += System.nanoTime() - tick_ns;
        for (MHStateProcessor<S> processor : this.stateProcessors) {
            processor.processState(currentState);
        }

        for (int i = 1; i < iterations; i++) {
            if (i % this.msgInterval == 0) {
                System.out.println("MH iteration " + i);
                System.out.println("  state  = " + currentState);
                System.out.println("  weight = " + Math.exp(currentLogWeight));
            }

            tick_ns = System.nanoTime();
            final MHTransition<S> transition = this.proposal.newTransition(currentState);
            final S proposalState = transition.getNewState();
            final double proposalLogWeight = this.weight.logWeight(proposalState);
            // Log of the MH acceptance ratio: target ratio times reverse/forward
            // proposal ratio. This assumes getBwdLogProb()/getFwdLogProb() are the
            // log proposal probabilities of the reverse and forward moves.
            final double logAlpha = proposalLogWeight - currentLogWeight
                    + (transition.getBwdLogProb() - transition.getFwdLogProb());
            // Accept with probability min(1, exp(logAlpha)), compared in log space.
            // nextDouble() may return 0, giving -Infinity, which accepts any
            // proposal with logAlpha > -Infinity.
            if (Math.log(this.rnd.nextDouble()) < logAlpha) {
                currentState = proposalState;
                currentLogWeight = proposalLogWeight;
            }
            this.lastCompTime_ns += System.nanoTime() - tick_ns;

            for (MHStateProcessor<S> processor : this.stateProcessors) {
                processor.processState(currentState);
            }
        }

        for (MHStateProcessor<S> processor : this.stateProcessors) {
            processor.end();
        }
    }
}

