package org.renjin.pipeliner.fusion;

import java.lang.reflect.Method;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import org.renjin.eval.EvalException;
import org.renjin.pipeliner.fusion.kernel.CompiledKernel;
import org.renjin.pipeliner.fusion.kernel.LoopKernel;
import org.renjin.pipeliner.fusion.node.BinaryVectorOpNode;
import org.renjin.pipeliner.fusion.node.DistanceMatrixNode;
import org.renjin.pipeliner.fusion.node.DoubleArrayNode;
import org.renjin.pipeliner.fusion.node.IntArrayNode;
import org.renjin.pipeliner.fusion.node.IntBufferNode;
import org.renjin.pipeliner.fusion.node.IntSeqNode;
import org.renjin.pipeliner.fusion.node.LoopNode;
import org.renjin.pipeliner.fusion.node.RepeatingNode;
import org.renjin.pipeliner.fusion.node.TransposeNode;
import org.renjin.pipeliner.fusion.node.UnaryVectorOpNode;
import org.renjin.pipeliner.fusion.node.VirtualVectorNode;
import org.renjin.pipeliner.node.DeferredNode;
import org.renjin.pipeliner.node.FunctionNode;
import org.renjin.pipeliner.node.NodeShape;
import org.renjin.primitives.sequence.IntSequence;
import org.renjin.primitives.vector.MemoizedComputation;
import org.renjin.repackaged.asm.Type;
import org.renjin.sexp.DoubleArrayVector;
import org.renjin.sexp.IntArrayVector;
import org.renjin.sexp.IntBufferVector;
import org.renjin.sexp.LogicalArrayVector;
import org.renjin.sexp.Vector;

/**
 * A deferred node that replaces a {@link FunctionNode} with a single fused loop kernel.
 *
 * <p>At construction time the function's operands are walked recursively. Operations that can be
 * expressed as loop nodes are inlined into a tree of {@link LoopNode}s. These are:
 * {@code dist}, {@code rep}, {@code t}, and unary/binary operations for which a vector-op
 * method is found. Every other operand is registered as an input of this node.
 *
 * <p>The kernel is compiled via {@link #startCompilation(LoopKernelCache)} and evaluated by
 * {@link #run()}. {@code run()} stores the result both in this node and in the original
 * function's {@link MemoizedComputation}.
 */
public class FusedNode extends DeferredNode implements Runnable {

    /** Kernel obtained from {@code LoopKernels.INSTANCE} for the fused function node. */
    private final LoopKernel kernel;

    /** One loop-node tree per operand of the original function node. */
    private final LoopNode[] kernelOperands;

    /** The original function node's vector, which receives the computed result. */
    private final MemoizedComputation memoizedComputation;

    /** Result computed by {@link #run()}; {@code null} until then. */
    private DoubleArrayVector resultVector;

    /** Pending kernel obtained in {@link #startCompilation}; {@code null} until then. */
    private Future<CompiledKernel> compiledKernel;

    /**
     * Builds a fused node for {@code functionNode}, translating each of its operands into a
     * loop-node tree.
     *
     * @param functionNode the function node to fuse; its vector is cast to
     *     {@link MemoizedComputation}
     */
    public FusedNode(FunctionNode functionNode) {
        this.kernel = LoopKernels.INSTANCE.get(functionNode);
        this.kernelOperands = new LoopNode[functionNode.getOperands().size()];
        this.memoizedComputation = (MemoizedComputation) functionNode.getVector();
        for (int i = 0; i < kernelOperands.length; i++) {
            kernelOperands[i] = addLoopNode(functionNode.getOperand(i));
        }
    }

    /**
     * Converts {@code node} into a loop node.
     *
     * <p>Other fused nodes and non-inlinable nodes become inputs of this node. Inlinable function
     * nodes are expanded recursively.
     *
     * @param node the operand to translate
     * @return the loop node representing {@code node} inside the kernel
     */
    private LoopNode addLoopNode(DeferredNode node) {
        if (node instanceof FusedNode) {
            // Another fused node is consumed as an input through its DoubleArrayVector result.
            int inputIndex = addInput(node);
            node.addOutput(this);
            return new DoubleArrayNode(inputIndex, Type.getType(DoubleArrayVector.class));
        }
        if (node instanceof FunctionNode) {
            LoopNode inlined = inlineFunction((FunctionNode) node);
            if (inlined != null) {
                return inlined;
            }
        }
        return addLoopInput(node);
    }

    /**
     * Attempts to express {@code functionNode} directly as a loop node.
     *
     * @param functionNode the function node to inline
     * @return the inlined loop node, or {@code null} if this function cannot be inlined and must be
     *     treated as an input instead
     */
    private LoopNode inlineFunction(FunctionNode functionNode) {
        String computationName = functionNode.getComputationName();
        if (computationName.equals("dist")) {
            return new DistanceMatrixNode(addLoopNode(functionNode.getOperand(0)));
        }
        if (computationName.equals("rep")) {
            return new RepeatingNode(
                addLoopNode(functionNode.getOperand(0)),
                addLoopNode(functionNode.getOperand(1)));
        }
        if (computationName.equals("t")) {
            return new TransposeNode(
                addLoopNode(functionNode.getOperand(0)),
                addLoopNode(functionNode.getOperand(1)));
        }

        int operandCount = functionNode.getOperands().size();
        if (operandCount == 1) {
            Method method = UnaryVectorOpNode.findMethod(functionNode.getVector());
            if (method != null) {
                return new UnaryVectorOpNode(
                    computationName, method, addLoopNode(functionNode.getOperand(0)));
            }
        } else if (operandCount == 2) {
            Method method = BinaryVectorOpNode.findMethod(functionNode.getVector());
            if (method != null) {
                return new BinaryVectorOpNode(
                    computationName,
                    method,
                    addLoopNode(functionNode.getOperand(0)),
                    addLoopNode(functionNode.getOperand(1)));
            }
        }
        return null;
    }

    /**
     * Registers {@code node} as an input of this node.
     *
     * @param node the operand to register as an input
     * @return a loop node that reads the input, specialised on the concrete vector class
     */
    private LoopNode addLoopInput(DeferredNode node) {
        int inputIndex = addInput(node);
        node.addOutput(this);

        Vector vector = node.getVector();
        if (vector instanceof IntBufferVector) {
            return new IntBufferNode(inputIndex);
        }
        if (vector instanceof IntSequence) {
            return new IntSeqNode(inputIndex);
        }
        if (vector instanceof DoubleArrayVector) {
            return new DoubleArrayNode(inputIndex, node.getResultVectorType());
        }
        if (vector instanceof IntArrayVector || vector instanceof LogicalArrayVector) {
            return new IntArrayNode(inputIndex, node.getResultVectorType());
        }
        // Fall back to generic element access for any other vector implementation.
        return new VirtualVectorNode(inputIndex, vector);
    }

    @Override // org.renjin.pipeliner.node.DeferredNode
    public String getDebugLabel() {
        return kernel.debugLabel(kernelOperands);
    }

    @Override // org.renjin.pipeliner.node.DeferredNode
    public NodeShape getShape() {
        return NodeShape.ELLIPSE;
    }

    /** The fused kernel always produces a {@link DoubleArrayVector}. */
    @Override // org.renjin.pipeliner.node.DeferredNode
    public Type getResultVectorType() {
        return Type.getType(DoubleArrayVector.class);
    }

    /**
     * Requests the compiled form of this node's kernel from {@code loopKernelCache}.
     *
     * <p>Must be called before {@link #run()}.
     *
     * @param loopKernelCache cache supplying the (possibly still pending) compiled kernel
     */
    public void startCompilation(LoopKernelCache loopKernelCache) {
        this.compiledKernel = loopKernelCache.get(kernel, kernelOperands);
    }

    /**
     * Waits for kernel compilation and evaluates the kernel over this node's input vectors.
     *
     * <p>It then stores the result in this node and in the memoized computation.
     *
     * @throws IllegalStateException if {@link #startCompilation} has not been called
     * @throws EvalException if compilation failed or the waiting thread was interrupted; in the
     *     latter case the thread's interrupt status is restored
     */
    @Override // java.lang.Runnable
    public void run() {
        if (compiledKernel == null) {
            throw new IllegalStateException("Kernel compilation has not been started.");
        }
        try {
            CompiledKernel compiled = compiledKernel.get();
            Vector[] inputs = new Vector[getOperands().size()];
            for (int i = 0; i < inputs.length; i++) {
                inputs[i] = getOperand(i).getVector();
            }
            this.resultVector =
                DoubleArrayVector.unsafe(compiled.compute(inputs), memoizedComputation.getAttributes());
            memoizedComputation.setResult(resultVector);
        } catch (InterruptedException e) {
            // Preserve the interrupt so callers further up the stack can observe it.
            Thread.currentThread().interrupt();
            throw new EvalException("Exception compiling kernel", e);
        } catch (ExecutionException e) {
            throw new EvalException("Exception compiling kernel", e);
        }
    }

    /**
     * Returns the vector computed by {@link #run()}.
     *
     * @throws IllegalStateException if {@link #run()} has not completed yet
     */
    @Override // org.renjin.pipeliner.node.DeferredNode
    public DoubleArrayVector getVector() {
        if (resultVector == null) {
            throw new IllegalStateException("Not computed yet.");
        }
        return resultVector;
    }
}
