/*
 * Decompiled with CFR 0.152.
 */
package com.oracle.svm.enterprise.hosted.ml.phases;

import ai.onnxruntime.OrtException;
import com.oracle.graal.compiler.enterprise.ml.features.c;
import com.oracle.graal.compiler.enterprise.ml.features.provider.b;
import com.oracle.graal.compiler.enterprise.ml.phases.InferGraphProfilesBasePhase;
import com.oracle.svm.core.util.VMError;
import com.oracle.svm.enterprise.hosted.ml.features.provider.a;
import com.oracle.svm.enterprise.hosted.profiling.phases.ProfilingInstrumentationPhase;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.graph.iterators.FilteredNodeIterable;
import org.graalvm.compiler.nodes.AbstractBeginNode;
import org.graalvm.compiler.nodes.ControlSplitNode;
import org.graalvm.compiler.nodes.IfNode;
import org.graalvm.compiler.nodes.ProfileData;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.spi.CoreProviders;
import org.graalvm.nativeimage.ImageSingletons;

/**
 * Replaces unknown branch profiles of {@link IfNode}s with probabilities predicted by an ML model.
 *
 * <p>The model is looked up from {@link ImageSingletons} at construction time. Only control splits
 * whose current profile source is unknown are updated. Injected, profiled or adopted profiles are
 * never overridden.
 */
public class InferControlSplitProfilesPhase
extends InferGraphProfilesBasePhase<Node, c> {
    /**
     * Lower bound applied to predictions when clamping is enabled.
     * This is a subnormal float equal to 10 * {@link Float#MIN_VALUE}.
     */
    private static final float MIN_PROBABILITY = 1.4E-44f;
    /** Upper bound applied to predictions when clamping is enabled. */
    private static final float MAX_PROBABILITY = 1.0f;
    /** Whether raw model predictions are clamped to [MIN_PROBABILITY, MAX_PROBABILITY]. */
    private final boolean clampPredictions;

    /**
     * @param featureProvider provider handed to the base phase; presumably supplies per-graph node
     *            features (TODO confirm against InferGraphProfilesBasePhase)
     * @param clampPredictions if {@code true}, predictions are clamped to
     *            [{@value #MIN_PROBABILITY}, {@value #MAX_PROBABILITY}] before being applied
     */
    public InferControlSplitProfilesPhase(com.oracle.graal.compiler.enterprise.ml.features.provider.c<Node, c> featureProvider, boolean clampPredictions) {
        super(featureProvider, (com.oracle.graal.compiler.enterprise.ml.models.a)ImageSingletons.lookup(com.oracle.svm.enterprise.hosted.ml.models.a.class));
        this.clampPredictions = clampPredictions;
    }

    /**
     * Predicts probabilities for the graph's IfNodes and writes them into every eligible control
     * split whose profile source is still unknown.
     */
    protected void run(StructuredGraph structuredGraph, CoreProviders coreProviders) {
        String graphKey = Long.toString(structuredGraph.graphId());
        // NOTE(review): h/g appear to be "has features for graph" / "get features for graph".
        if (!this.cG.h(graphKey)) {
            return;
        }
        b<Node, c> graphFeatures = (b<Node, c>) this.cG.g(graphKey);
        Map<Node, Float> predictions = this.inferProbabilities(graphFeatures);
        for (Node node : InferControlSplitProfilesPhase.d(structuredGraph)) {
            assert (node instanceof ControlSplitNode) : "Can not infer profiles for the non control split node: " + String.valueOf(node);
            // The map never stores null values, so a null lookup means "no prediction for this node".
            Float probability = predictions.get(node);
            if (probability == null) {
                continue;
            }
            ControlSplitNode controlSplit = (ControlSplitNode) node;
            // Only fill in profiles that no other source has provided.
            if (!controlSplit.getProfileData().getProfileSource().isUnknown()) {
                continue;
            }
            InferControlSplitProfilesPhase.applyInferredProbability(controlSplit, probability);
        }
    }

    /**
     * Returns the control splits of {@code structuredGraph} that are candidates for profile
     * inference: those yielded by {@code ProfilingInstrumentationPhase.p} and rejected by
     * {@code ProfilingInstrumentationPhase.b}.
     */
    public static FilteredNodeIterable<Node> d(StructuredGraph structuredGraph) {
        return ProfilingInstrumentationPhase.p(structuredGraph).filter(node -> !ProfilingInstrumentationPhase.b((ControlSplitNode)node));
    }

    /**
     * Runs the model on the features of every IfNode in {@code graphFeatures}.
     *
     * @return a map from each IfNode to its predicted true-branch probability. The map is empty if
     *         there are no IfNodes or if inference failed.
     */
    private Map<Node, Float> inferProbabilities(b<Node, c> graphFeatures) {
        List<Node> nodes = new ArrayList<>();
        ArrayList<c> nodeFeatures = new ArrayList<>();
        for (Object candidate : graphFeatures.s()) {
            if (!(candidate instanceof IfNode)) {
                continue;
            }
            Node node = (Node) candidate;
            nodes.add(node);
            Optional<?> feature = graphFeatures.a(node);
            assert (feature.isPresent()) : "Can not infer conditional probabilities: missing feature for the node: " + String.valueOf(candidate);
            nodeFeatures.add((c) feature.get());
        }
        Map<Node, Float> result = new HashMap<>();
        if (nodeFeatures.isEmpty()) {
            return result;
        }
        float[] predictions = new float[0];
        try {
            predictions = this.cH.a(nodeFeatures);
        }
        catch (OrtException ortException) {
            // NOTE(review): meaning of the boolean flag is not visible here; presumably "non-fatal".
            a.a(ortException, "Invalid ONNX Profile Inference", false);
        }
        catch (Exception exception) {
            a.a(exception, "General Inference Exception", false);
        }
        assert (nodes.size() == predictions.length) : "Invalid model run - broken inference for: " + String.valueOf(nodes);
        // Bound by both sizes so a mismatched model output cannot index past the node list when
        // assertions are disabled.
        int count = Math.min(nodes.size(), predictions.length);
        for (int i = 0; i < count; ++i) {
            float probability = this.clampPredictions ? InferControlSplitProfilesPhase.clampProbability(predictions[i]) : predictions[i];
            result.put(nodes.get(i), Float.valueOf(probability));
        }
        return result;
    }

    /**
     * Clamps {@code probability} to [{@value #MIN_PROBABILITY}, {@value #MAX_PROBABILITY}].
     * NaN fails both comparisons and is therefore returned unchanged.
     */
    private static float clampProbability(float probability) {
        if (probability < MIN_PROBABILITY) {
            return MIN_PROBABILITY;
        }
        if (probability > MAX_PROBABILITY) {
            return MAX_PROBABILITY;
        }
        return probability;
    }

    /**
     * Sets {@code probability} as the inferred probability of the true successor of
     * {@code controlSplitNode}, which must be an {@link IfNode}. Reaching this with an injected,
     * profiled or adopted profile is a fatal error.
     */
    private static void applyInferredProbability(ControlSplitNode controlSplitNode, Float probability) {
        assert (controlSplitNode instanceof IfNode) : "Can not update conditional probabilities: can infer only the if nodes.";
        AbstractBeginNode trueSuccessor = ((IfNode)controlSplitNode).trueSuccessor();
        Map<Node, Double> successorProbabilities = new HashMap<>();
        successorProbabilities.put(trueSuccessor, Double.valueOf(probability.floatValue()));
        List<Node> liveSuccessors = controlSplitNode.successors().snapshot().stream().filter(Node::isAlive).collect(Collectors.toList());
        for (Node successor : InferControlSplitProfilesPhase.a(liveSuccessors, successorProbabilities)) {
            ProfileData.ProfileSource profileSource = controlSplitNode.getProfileData().getProfileSource();
            if (profileSource.isInjected() || profileSource.isProfiled() || profileSource.isAdopted()) {
                VMError.shouldNotReachHere("Fatal Error: can not override injected/profiled/adopted profiles (node = " + String.valueOf(controlSplitNode) + ").");
            }
            controlSplitNode.setProbability((AbstractBeginNode)successor, ProfileData.BranchProbabilityData.inferred(successorProbabilities.get(successor).doubleValue()));
        }
    }

    /**
     * Returns the elements of {@code list} that are keys of {@code map}, preserving list order.
     */
    public static List<Node> a(List<Node> list, Map<Node, Double> map) {
        return list.stream().filter(map::containsKey).collect(Collectors.toList());
    }
}

