/*
 * Decompiled with CFR 0.152.
 */
package com.oracle.svm.enterprise.hosted.ml.phases;

import com.oracle.graal.compiler.enterprise.ml.features.c;
import com.oracle.graal.compiler.enterprise.ml.models.a;
import com.oracle.graal.compiler.enterprise.ml.phases.InferGraphProfilesBasePhase;
import com.oracle.svm.core.util.VMError;
import com.oracle.svm.enterprise.hosted.ml.features.provider.ProfileInferenceFeature;
import com.oracle.svm.enterprise.hosted.ml.features.provider.b;
import com.oracle.svm.enterprise.hosted.ml.inferencer.Inferencer;
import com.oracle.svm.enterprise.hosted.profiling.utilities.a;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import jdk.graal.compiler.graph.Node;
import jdk.graal.compiler.graph.iterators.FilteredNodeIterable;
import jdk.graal.compiler.nodes.AbstractBeginNode;
import jdk.graal.compiler.nodes.ControlSplitNode;
import jdk.graal.compiler.nodes.IfNode;
import jdk.graal.compiler.nodes.ProfileData;
import jdk.graal.compiler.nodes.StructuredGraph;
import jdk.graal.compiler.nodes.spi.CoreProviders;
import org.graalvm.nativeimage.ImageSingletons;

/**
 * Phase that assigns ML-predicted branch probabilities to the control split nodes of a graph.
 *
 * <p>For every node yielded by {@link #j(StructuredGraph)} that has a prediction, the predicted
 * value is installed as an "inferred" probability on the node's true successor when the node's
 * current profile source is unknown. When the current profile source is "profiled" and ML profile
 * inference logging is enabled, the prediction is logged next to the existing probability and then
 * installed as well.
 *
 * <p>NOTE(review): this file is decompiled and the imports declare two different types with the
 * simple name {@code a} ({@code ...ml.models.a} and {@code ...profiling.utilities.a}). The
 * unqualified uses of {@code a} below are kept exactly as decompiled; which type each one refers to
 * should be confirmed against the original sources.
 */
public class InferControlSplitProfilesPhase
extends InferGraphProfilesBasePhase<Node, c> {
    /** Lower bound applied to predictions when clamping is enabled. */
    private static final float MIN_PROBABILITY = 1.4E-44f;
    /** Upper bound applied to predictions when clamping is enabled. */
    private static final float MAX_PROBABILITY = 1.0f;
    /** When {@code true}, every prediction is clamped into {@code [MIN_PROBABILITY, MAX_PROBABILITY]}. */
    private final boolean clampPredictions;

    /**
     * Creates the phase.
     *
     * @param c2 feature provider handed to the base phase
     * @param bl2 whether predictions are clamped into the valid range before being used
     */
    public InferControlSplitProfilesPhase(com.oracle.graal.compiler.enterprise.ml.features.provider.c<Node, c> c2, boolean bl2) {
        // The model instance is obtained from the image singleton registry.
        super(c2, (com.oracle.graal.compiler.enterprise.ml.models.a)ImageSingletons.lookup(com.oracle.svm.enterprise.hosted.ml.models.a.class));
        this.clampPredictions = bl2;
    }

    /**
     * Runs inference for the given graph and applies the predictions to its control split nodes.
     *
     * @param structuredGraph graph whose control split nodes are updated
     * @param coreProviders not used by this implementation
     */
    protected void run(StructuredGraph structuredGraph, CoreProviders coreProviders) {
        // Key under which the feature provider stores data for this graph
        // (presumably a method identifier; confirm against b.i).
        String graphKey = b.i(structuredGraph);
        if (!this.cy.h(graphKey)) {
            return;
        }
        // Kept raw exactly as decompiled: the declared return type of cy.g is not visible here.
        com.oracle.graal.compiler.enterprise.ml.features.provider.b graphFeatures = this.cy.g(graphKey);
        Map<Node, Float> predictions = this.predictProbabilities((com.oracle.graal.compiler.enterprise.ml.features.provider.b<Node, c>)graphFeatures, structuredGraph);
        FilteredNodeIterable<Node> candidates = InferControlSplitProfilesPhase.j(structuredGraph);
        for (Node node : candidates) {
            assert (node instanceof ControlSplitNode) : "Can not infer profiles for the non control split node: " + String.valueOf(node);
            // The prediction map never holds null values, so a null result means "no prediction".
            Float predicted = predictions.get(node);
            if (predicted == null) {
                continue;
            }
            ControlSplitNode split = (ControlSplitNode)node;
            // Read once, before any update, so both checks below see the original source.
            ProfileData.ProfileSource originalSource = split.getProfileData().getProfileSource();
            if (originalSource.isUnknown()) {
                InferControlSplitProfilesPhase.a(split, predicted);
            }
            if (!originalSource.isProfiled() || !ProfileInferenceFeature.shouldLogMLProfileInference()) {
                continue;
            }
            // Logging mode: record the existing probability next to the prediction, then install the prediction.
            InferControlSplitProfilesPhase.a(structuredGraph, node, (com.oracle.graal.compiler.enterprise.ml.features.provider.b<Node, ? extends c>)graphFeatures, predictions);
            InferControlSplitProfilesPhase.a(split, predicted);
        }
    }

    /**
     * Runs the model on the given features and collects one predicted value per node.
     *
     * <p>Any exception raised by inference is passed to
     * {@code ProfileInferenceFeature.handleMLProfileInferenceThrowable}; if no result is available
     * afterwards, an empty map is returned.
     *
     * @param features features of the graph's nodes
     * @param structuredGraph graph being processed; only its debug context is used here
     * @return mutable map from node to predicted value, never {@code null}
     */
    private Map<Node, Float> predictProbabilities(com.oracle.graal.compiler.enterprise.ml.features.provider.b<Node, c> features, StructuredGraph structuredGraph) {
        a.a inference = null;
        try {
            inference = this.cz.a(features, structuredGraph.getDebug());
        }
        catch (Inferencer.InferenceException inferenceException) {
            ProfileInferenceFeature.handleMLProfileInferenceThrowable(inferenceException, "Invalid ONNX Profile Inference", false);
        }
        catch (Exception exception) {
            ProfileInferenceFeature.handleMLProfileInferenceThrowable(exception, "General Inference Exception", false);
        }
        Map<Node, Float> predictions = new HashMap<>();
        if (inference == null) {
            return predictions;
        }
        // v() supplies the nodes and w() the matching values, index by index. The accessors are
        // deliberately re-invoked exactly as in the decompiled code because their semantics
        // (field getter vs. copy) are not visible here.
        for (int index = 0; index < inference.v().size(); ++index) {
            if (this.clampPredictions) {
                inference.w()[index] = InferControlSplitProfilesPhase.a(inference.w()[index]);
            }
            predictions.put((Node)inference.v().get(index), Float.valueOf(inference.w()[index]));
        }
        return predictions;
    }

    /**
     * Logs the existing designated-successor probability of {@code node} together with its
     * predicted value and the label count taken from the node's features.
     *
     * @param structuredGraph graph containing the node
     * @param node node to log; it is cast to {@link IfNode} and its profile data to
     *        {@link ProfileData.BranchProbabilityData}
     * @param b2 feature container queried for the node's features
     * @param map predictions; must contain an entry for {@code node}
     */
    protected static void a(StructuredGraph structuredGraph, Node node, com.oracle.graal.compiler.enterprise.ml.features.provider.b<Node, ? extends c> b2, Map<Node, Float> map) {
        Optional<?> nodeFeatures = b2.a((Object)node);
        assert (nodeFeatures.isPresent()) : "Features must be present as we have the ML predicted profile for the node: " + String.valueOf(node);
        long labelCount = ((com.oracle.svm.enterprise.hosted.ml.features.b)nodeFeatures.get()).getLabelCount();
        ProfileData.BranchProbabilityData existingProfile = (ProfileData.BranchProbabilityData)((ControlSplitNode)node).getProfileData();
        com.oracle.svm.enterprise.hosted.ml.logging.a.a(structuredGraph, (IfNode)node, existingProfile.getDesignatedSuccessorProbability(), map.get(node).floatValue(), labelCount);
    }

    /**
     * Returns the nodes of the graph that this phase considers for profile inference.
     *
     * @param structuredGraph graph to scan
     * @return the nodes produced by {@code a.r(structuredGraph)} that pass the
     *         {@code a.a(ControlSplitNode)} filter
     */
    public static FilteredNodeIterable<Node> j(StructuredGraph structuredGraph) {
        return a.r(structuredGraph).filter(node -> a.a((ControlSplitNode)node));
    }

    /**
     * Clamps a predicted value into {@code [MIN_PROBABILITY, MAX_PROBABILITY]}.
     *
     * @param f2 raw predicted value
     * @return {@code f2} limited to the range; a NaN input is returned unchanged because both
     *         comparisons are false for NaN
     */
    protected static float a(float f2) {
        if (f2 < MIN_PROBABILITY) {
            return MIN_PROBABILITY;
        }
        if (f2 > MAX_PROBABILITY) {
            return MAX_PROBABILITY;
        }
        return f2;
    }

    /**
     * Installs the predicted value as an inferred probability on the matching live successors of
     * the given node.
     *
     * <p>If the node's current profile source is injected, profiled or adopted and ML profile
     * inference logging is disabled, {@code VMError.shouldNotReachHere} is invoked before the
     * probability would be overwritten.
     *
     * @param controlSplitNode node to update
     * @param f2 predicted value for the node's true successor; must not be {@code null}
     */
    protected static void a(ControlSplitNode controlSplitNode, Float f2) {
        Map<Node, Double> probabilities = InferControlSplitProfilesPhase.trueSuccessorProbability(controlSplitNode, (double)f2.floatValue());
        // Work on a snapshot so the successor list is fixed before any probability is changed.
        List<? extends Node> successors = controlSplitNode.successors().snapshot();
        List<Node> aliveSuccessors = successors.stream().filter(Node::isAlive).collect(Collectors.toList());
        List<Node> targets = InferControlSplitProfilesPhase.successorsMatchingProfiles(aliveSuccessors, probabilities);
        for (Node successor : targets) {
            // Re-read on every iteration, as in the original, because setProbability replaces the profile data.
            ProfileData.ProfileSource currentSource = controlSplitNode.getProfileData().getProfileSource();
            boolean hasTrustedProfile = currentSource.isInjected() || currentSource.isProfiled() || currentSource.isAdopted();
            if (hasTrustedProfile && !ProfileInferenceFeature.shouldLogMLProfileInference()) {
                VMError.shouldNotReachHere((String)("Fatal Error: can not override injected/profiled/adopted profiles (node = " + String.valueOf(controlSplitNode) + ")."));
            }
            controlSplitNode.setProbability((AbstractBeginNode)successor, ProfileData.BranchProbabilityData.inferred(probabilities.get(successor).doubleValue()));
        }
    }

    /**
     * Builds the successor-to-probability map for a node: a single entry that maps the true
     * successor of the {@link IfNode} to the given probability.
     *
     * @param controlSplitNode node to inspect; must be an {@link IfNode}
     * @param probability probability for the true successor
     * @return mutable map with exactly one entry
     */
    private static Map<Node, Double> trueSuccessorProbability(ControlSplitNode controlSplitNode, double probability) {
        assert (controlSplitNode instanceof IfNode) : "Can not update conditional probabilities for node: " + String.valueOf(controlSplitNode) + ": can infer only the if nodes.";
        AbstractBeginNode trueSuccessor = ((IfNode)controlSplitNode).trueSuccessor();
        Map<Node, Double> probabilities = new HashMap<>();
        probabilities.put((Node)trueSuccessor, probability);
        return probabilities;
    }

    /**
     * Keeps only the successors that have an entry in the probability map.
     *
     * @param list candidate successors
     * @param map successor-to-probability map
     * @return new list with the successors that are keys of {@code map}, in their original order
     */
    private static List<Node> successorsMatchingProfiles(List<Node> list, Map<Node, Double> map) {
        return list.stream().filter(map::containsKey).collect(Collectors.toList());
    }
}

