/*
 * Decompiled with CFR 0.152.
 */
package org.graalvm.compiler.lir.processor;

import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.processing.RoundEnvironment;
import javax.lang.model.element.AnnotationMirror;
import javax.lang.model.element.Element;
import javax.lang.model.element.ElementKind;
import javax.lang.model.element.ExecutableElement;
import javax.lang.model.element.Name;
import javax.lang.model.element.PackageElement;
import javax.lang.model.element.TypeElement;
import javax.lang.model.element.VariableElement;
import javax.lang.model.type.TypeKind;
import javax.lang.model.type.TypeMirror;
import javax.tools.Diagnostic;
import javax.tools.JavaFileObject;
import org.graalvm.compiler.processor.AbstractProcessor;

public class IntrinsicStubProcessor
extends AbstractProcessor {
    // Fully qualified names of the types this processor resolves by name in doProcess.
    private static final String NODE_INTRINSIC_CLASS_NAME = "org.graalvm.compiler.graph.Node.NodeIntrinsic";
    private static final String GENERATE_STUB_CLASS_NAME = "org.graalvm.compiler.lir.GenerateStub";
    private static final String GENERATE_STUBS_CLASS_NAME = "org.graalvm.compiler.lir.GenerateStubs";
    private static final String GENERATED_STUBS_HOLDER_CLASS_NAME = "org.graalvm.compiler.lir.GeneratedStubsHolder";
    private static final String CONSTANT_NODE_PARAMETER_CLASS_NAME = "org.graalvm.compiler.graph.Node.ConstantNodeParameter";
    // Resolved type handles for the names above; (re)assigned at the start of every
    // non-final processing round in doProcess and read by the other methods of this class.
    private TypeElement nodeIntrinsic;
    private TypeElement generatedStubsHolder;
    private TypeElement generateStub;
    private TypeElement generateStubs;
    // Held as a TypeMirror (not a TypeElement) because it is only ever passed to getAnnotation.
    private TypeMirror constantNodeParameter;

    /**
     * Names the annotation types this processor declares support for.
     *
     * @return an immutable set with the fully qualified names of the {@code GenerateStub},
     *         {@code GenerateStubs} and {@code GeneratedStubsHolder} annotations
     */
    @Override
    public Set<String> getSupportedAnnotationTypes() {
        Set<String> supported = Set.of(
                        GENERATE_STUB_CLASS_NAME,
                        GENERATE_STUBS_CLASS_NAME,
                        GENERATED_STUBS_HOLDER_CLASS_NAME);
        return supported;
    }

    /**
     * Processes one annotation-processing round.
     *
     * <p>
     * For every element annotated with {@code GeneratedStubsHolder}, the classes listed in the
     * annotation's {@code sources} value are scanned. Each enclosed element carrying a
     * {@code GenerateStub} annotation, or a {@code GenerateStubs} annotation whose {@code value}
     * is a list of annotations, contributes stub descriptions. One source file is then generated
     * per holder via {@link #createStubs}.
     *
     * @param annotations the annotation types requested for this round (not used here)
     * @param roundEnv the current round; nothing is done once processing is over
     * @return always {@code false}
     */
    @Override
    protected boolean doProcess(Set<? extends TypeElement> annotations, RoundEnvironment roundEnv) {
        if (roundEnv.processingOver()) {
            return false;
        }

        // Resolve the types we need up front; the helper methods read these fields.
        this.nodeIntrinsic = this.getTypeElement(NODE_INTRINSIC_CLASS_NAME);
        this.generatedStubsHolder = this.getTypeElement(GENERATED_STUBS_HOLDER_CLASS_NAME);
        this.generateStub = this.getTypeElement(GENERATE_STUB_CLASS_NAME);
        this.generateStubs = this.getTypeElement(GENERATE_STUBS_CLASS_NAME);
        this.constantNodeParameter = this.getType(CONSTANT_NODE_PARAMETER_CLASS_NAME);

        for (Element holder : roundEnv.getElementsAnnotatedWith(this.generatedStubsHolder)) {
            AnnotationMirror holderAnnotation = this.getAnnotation(holder, this.generatedStubsHolder.asType());
            String targetVMName = IntrinsicStubProcessor.getAnnotationValue(holderAnnotation, "targetVM", String.class);
            TargetVM targetVM = TargetVM.valueOf(targetVMName);

            ArrayList<GenerateStubClass> stubClasses = new ArrayList<>();
            List<TypeMirror> sourceTypes = IntrinsicStubProcessor.getAnnotationValueList(holderAnnotation, "sources", TypeMirror.class);
            for (TypeMirror sourceType : sourceTypes) {
                TypeElement source = this.asTypeElement(sourceType);
                ArrayList<GenerateStub> collectedStubs = new ArrayList<>();
                // Maps each distinct getter to itself so equal getters share one instance per source class.
                HashMap<MinimumFeaturesGetter, MinimumFeaturesGetter> getters = new HashMap<>();

                for (Element member : source.getEnclosedElements()) {
                    // A single annotation is handled as a one-element list.
                    AnnotationMirror single = this.getAnnotation(member, this.generateStub.asType());
                    if (single != null) {
                        this.extractStubs(targetVM, source, collectedStubs, getters, (ExecutableElement) member, single, List.of(single));
                    }

                    // The plural annotation carries its individual annotations in "value".
                    AnnotationMirror container = this.getAnnotation(member, this.generateStubs.asType());
                    if (container != null) {
                        List<AnnotationMirror> contained = IntrinsicStubProcessor.getAnnotationValueList(container, "value", AnnotationMirror.class);
                        this.extractStubs(targetVM, source, collectedStubs, getters, (ExecutableElement) member, container, contained);
                    }
                }
                stubClasses.add(new GenerateStubClass(source, collectedStubs, getters.keySet()));
            }
            this.createStubs(this, targetVM, (TypeElement) holder, stubClasses);
        }
        return false;
    }

    /**
     * Appends one {@link GenerateStub} description to {@code stubs} for every annotation in
     * {@code generateStubAnnotations}.
     *
     * <p>
     * An error is reported (processing continues) if {@code method} does not also carry the node
     * intrinsic annotation. The runtime-checked-flags variant is looked up once per method and
     * shared by all stubs created here; it may be {@code null} if the lookup failed.
     *
     * @param targetVM VM flavour the stubs are generated for
     * @param source class declaring {@code method}
     * @param stubs output list receiving the new stub descriptions
     * @param minimumFeatureGetters per-source-class pool of distinct minimum-features getters
     * @param method the annotated method
     * @param annotation the annotation found on {@code method}; used to position diagnostics
     * @param generateStubAnnotations the individual stub annotations to turn into stubs
     */
    private void extractStubs(TargetVM targetVM, TypeElement source, ArrayList<GenerateStub> stubs, HashMap<MinimumFeaturesGetter, MinimumFeaturesGetter> minimumFeatureGetters, ExecutableElement method, AnnotationMirror annotation, List<AnnotationMirror> generateStubAnnotations) {
        boolean isNodeIntrinsic = this.getAnnotation(method, this.nodeIntrinsic.asType()) != null;
        if (!isNodeIntrinsic) {
            this.env().getMessager().printMessage(
                            Diagnostic.Kind.ERROR,
                            String.format("methods annotated with %s must also be annotated with %s", annotation, this.nodeIntrinsic),
                            method,
                            annotation);
        }

        RuntimeCheckedFlagsMethod runtimeCheckedVariant = this.findRuntimeCheckedFlagsVariant(this, source, method, annotation);
        for (AnnotationMirror stubAnnotation : generateStubAnnotations) {
            MinimumFeaturesGetter getter = IntrinsicStubProcessor.extractMinimumFeaturesGetter(targetVM, minimumFeatureGetters, stubAnnotation);
            GenerateStub stub = new GenerateStub(stubAnnotation, method, runtimeCheckedVariant, getter);
            stubs.add(stub);
        }
    }

    /**
     * Determines the minimum-CPU-features getter requested by a stub annotation.
     *
     * <p>
     * Only the {@code substrate} target uses such getters. The annotation values
     * {@code minimumCPUFeaturesAMD64} and {@code minimumCPUFeaturesAARCH64} are read; if both are
     * empty there is no getter. Otherwise the getter is interned in
     * {@code minimumFeatureGetters} so that equal getters are represented by a single instance.
     *
     * @param targetVM VM flavour the stubs are generated for
     * @param minimumFeatureGetters pool of getters already seen (updated by this call)
     * @param genStub the stub annotation to inspect
     * @return the pooled getter, or {@code null} if none applies
     */
    private static MinimumFeaturesGetter extractMinimumFeaturesGetter(TargetVM targetVM, HashMap<MinimumFeaturesGetter, MinimumFeaturesGetter> minimumFeatureGetters, AnnotationMirror genStub) {
        if (targetVM != TargetVM.substrate) {
            return null;
        }
        String amd64 = IntrinsicStubProcessor.getAnnotationValue(genStub, "minimumCPUFeaturesAMD64", String.class);
        String aarch64 = IntrinsicStubProcessor.getAnnotationValue(genStub, "minimumCPUFeaturesAARCH64", String.class);
        if (amd64.isEmpty() && aarch64.isEmpty()) {
            return null;
        }
        // Returns the instance already in the pool, or registers and returns the new one.
        MinimumFeaturesGetter candidate = new MinimumFeaturesGetter(amd64, aarch64);
        return minimumFeatureGetters.computeIfAbsent(candidate, key -> key);
    }

    /**
     * Searches {@code clazz} for the runtime-checked-flags variant of {@code method}.
     *
     * <p>
     * Candidates are the methods of {@code clazz} that carry the node intrinsic annotation; the
     * first one accepted by {@link #checkRuntimeCheckedFlagsVariant} wins. If none matches, an
     * error is reported on {@code method}.
     *
     * @param processor processor used for annotation lookup and diagnostics
     * @param clazz class whose enclosed elements are searched
     * @param method the method whose variant is wanted
     * @param annotation annotation used to position the diagnostic
     * @return the matching variant, or {@code null} after reporting an error
     */
    private RuntimeCheckedFlagsMethod findRuntimeCheckedFlagsVariant(AbstractProcessor processor, TypeElement clazz, ExecutableElement method, AnnotationMirror annotation) {
        for (Element candidate : clazz.getEnclosedElements()) {
            if (candidate.getKind() != ElementKind.METHOD) {
                continue;
            }
            if (processor.getAnnotation(candidate, this.nodeIntrinsic.asType()) == null) {
                continue;
            }
            RuntimeCheckedFlagsMethod variant = IntrinsicStubProcessor.checkRuntimeCheckedFlagsVariant(method, (ExecutableElement) candidate);
            if (variant != null) {
                return variant;
            }
        }

        String message = method + ": Could not find runtime checked flags variant. For every method annotated with @GenerateStub, a second @NodeIntrinsic method with the same signature + an additional @ConstantNodeParameter EnumSet<CPUFeature> parameter for runtime checked CPU flags is required.";
        processor.env().getMessager().printMessage(Diagnostic.Kind.ERROR, message, method, annotation);
        return null;
    }

    /**
     * Checks whether {@code cur} has the same signature as {@code method} plus exactly one
     * additional parameter whose type name starts with {@code java.util.EnumSet}.
     *
     * <p>
     * The two parameter lists are walked in lockstep. At the first type mismatch, the parameter
     * of {@code cur} is accepted as the extra one if it is an {@code EnumSet} and the following
     * parameter of {@code cur} matches the current parameter of {@code method}. If no mismatch
     * occurs, the extra parameter must be the last parameter of {@code cur}.
     *
     * <p>
     * NOTE(review): parameter and return types are compared with {@code TypeMirror.equals}; the
     * {@code javax.lang.model} documentation recommends {@code Types.isSameType} for semantic
     * comparison, which is not reachable from this static context.
     *
     * @param method the method whose variant is wanted
     * @param cur the candidate variant
     * @return the candidate together with the index of its extra parameter, or {@code null} if
     *         {@code cur} is not a runtime-checked-flags variant of {@code method}
     */
    private static RuntimeCheckedFlagsMethod checkRuntimeCheckedFlagsVariant(ExecutableElement method, ExecutableElement cur) {
        List<? extends VariableElement> methodParams = method.getParameters();
        List<? extends VariableElement> curParams = cur.getParameters();
        if (!cur.getReturnType().equals(method.getReturnType()) || curParams.size() != methodParams.size() + 1) {
            return null;
        }
        int iCur = 0;
        int iDiff = -1;
        for (VariableElement p : methodParams) {
            VariableElement pCur = curParams.get(iCur++);
            if (p.asType().equals(pCur.asType())) {
                continue;
            }
            // First mismatch: pCur may be the extra parameter if the next one lines up with p.
            // iCur is in range here because no parameter of cur has been skipped yet.
            if (iDiff < 0 && p.asType().equals(curParams.get(iCur).asType()) && isEnumSetParameter(pCur)) {
                iDiff = iCur - 1;
                ++iCur;
                continue;
            }
            return null;
        }
        if (iDiff < 0) {
            // No mismatch seen: exactly the trailing parameter of cur is left over.
            assert iCur == curParams.size() - 1;
            iDiff = iCur;
            if (!isEnumSetParameter(curParams.get(iDiff))) {
                return null;
            }
        } else {
            // The extra parameter was skipped inside the loop, so all of cur has been consumed.
            assert iCur == curParams.size();
        }
        return new RuntimeCheckedFlagsMethod(cur, iDiff);
    }

    /** Tells whether the textual type of {@code parameter} starts with {@code java.util.EnumSet}. */
    private static boolean isEnumSetParameter(VariableElement parameter) {
        return parameter.asType().toString().startsWith("java.util.EnumSet");
    }

    /**
     * Generates the source file {@code <holder simple name>Gen} in the holder's package,
     * containing one stub method per collected {@link GenerateStub} (plus an {@code RTC} variant
     * and the minimum-features getters on the {@code substrate} target).
     *
     * <p>
     * Duplicate stub names are reported as errors. Stubs whose runtime-checked-flags variant
     * could not be found are skipped on the {@code substrate} target; the missing variant has
     * already been reported by {@link #findRuntimeCheckedFlagsVariant}. I/O failures, including
     * write errors that {@link PrintWriter} would otherwise swallow, are reported as errors on
     * {@code holder}.
     *
     * @param processor processor used to obtain the filer and messager
     * @param targetVM VM flavour the stubs are generated for
     * @param holder the annotated holder class; determines package and name of the generated class
     * @param classes the source classes with their collected stubs
     */
    private void createStubs(AbstractProcessor processor, TargetVM targetVM, TypeElement holder, ArrayList<GenerateStubClass> classes) {
        PackageElement pkg = (PackageElement) holder.getEnclosingElement();
        String genClassName = holder.getSimpleName() + "Gen";
        String pkgQualifiedName = pkg.getQualifiedName().toString();
        String qualifiedGenClassName = pkgQualifiedName + "." + genClassName;
        HashSet<String> uniqueNames = new HashSet<>();
        try {
            JavaFileObject factory = processor.env().getFiler().createSourceFile(qualifiedGenClassName, holder);
            try (PrintWriter out = new PrintWriter(factory.openWriter())) {
                out.print("// CheckStyle: stop header check\n");
                out.print("// CheckStyle: stop line length check\n");
                out.print("// GENERATED CONTENT - DO NOT EDIT\n");
                out.printf("// GENERATOR: %s\n", this.getClass().getName());
                out.printf("package %s;\n", pkgQualifiedName);
                out.print("\n");
                for (String importedName : this.collectStubImports(targetVM, classes)) {
                    // Types living directly in the generated class's package need no import.
                    int lastDot = importedName.lastIndexOf('.');
                    if (pkgQualifiedName.length() == lastDot && importedName.startsWith(pkgQualifiedName)) {
                        continue;
                    }
                    out.printf("import %s;\n", importedName);
                }
                out.print("\n");
                out.printf("public class %s", genClassName);
                if (targetVM == TargetVM.hotspot) {
                    out.print(" extends SnippetStub ");
                }
                out.print("{\n");
                if (targetVM == TargetVM.hotspot) {
                    out.print("\n");
                    out.printf("    public %s(OptionValues options, HotSpotProviders providers, HotSpotForeignCallLinkage linkage) {\n", genClassName);
                    out.print("        super(linkage.getDescriptor().getName(), options, providers, linkage);\n");
                    out.print("    }\n");
                }
                out.print("\n");
                for (GenerateStubClass genClass : classes) {
                    if (targetVM == TargetVM.substrate) {
                        emitMinimumFeaturesGetters(out, genClass);
                    }
                    for (GenerateStub stub : genClass.stubs) {
                        String name = IntrinsicStubProcessor.getAnnotationValue(stub.annotation, "name", String.class);
                        if (name.isEmpty()) {
                            name = stub.method.getSimpleName().toString();
                        }
                        if (!uniqueNames.add(name)) {
                            processor.env().getMessager().printMessage(Diagnostic.Kind.ERROR, "duplicate stub name: " + name, stub.method, stub.annotation);
                        }
                        List<String> params = IntrinsicStubProcessor.getAnnotationValueList(stub.annotation, "parameters", String.class);
                        Name className = genClass.clazz.getSimpleName();
                        switch (targetVM) {
                            case hotspot: {
                                this.generateStub(targetVM, out, className, name, params, stub.method, null, 0);
                                break;
                            }
                            case substrate: {
                                RuntimeCheckedFlagsMethod rtc = stub.runtimeCheckedFlagsMethod;
                                if (rtc == null) {
                                    // The lookup already reported an error for this method; skip it
                                    // instead of crashing the processor with a NullPointerException.
                                    break;
                                }
                                if (stub.minimumFeaturesGetter == null) {
                                    this.generateStub(targetVM, out, className, name, params, stub.method, null, 0);
                                } else {
                                    this.generateStub(targetVM, out, className, name, params, rtc.method, stub.minimumFeaturesGetter.getName() + "()", rtc.runtimeCheckedFlagsParameterIndex);
                                }
                                this.generateStub(targetVM, out, className, name + "RTC", params, rtc.method, String.format("Stubs.getRuntimeCheckedCPUFeatures(%s.class)", className), rtc.runtimeCheckedFlagsParameterIndex);
                                break;
                            }
                        }
                    }
                }
                out.print("}\n");
                // PrintWriter never throws on write failures; ask it explicitly.
                if (out.checkError()) {
                    processor.env().getMessager().printMessage(Diagnostic.Kind.ERROR, "I/O error while writing " + qualifiedGenClassName, holder);
                }
            }
        } catch (IOException e) {
            // getMessage() may be null; fall back to the exception's string form.
            String reason = e.getMessage() != null ? e.getMessage() : e.toString();
            processor.env().getMessager().printMessage(Diagnostic.Kind.ERROR, "could not generate " + qualifiedGenClassName + ": " + reason, holder);
        }
    }

    /**
     * Collects the names to import in the generated file: a fixed set per target VM, every
     * source class, and the type of every non-primitive constant-node parameter of a stub method.
     */
    private HashSet<String> collectStubImports(TargetVM targetVM, ArrayList<GenerateStubClass> classes) {
        HashSet<String> imports = new HashSet<>();
        switch (targetVM) {
            case hotspot: {
                imports.addAll(List.of(
                                "org.graalvm.compiler.api.replacements.Snippet",
                                "org.graalvm.compiler.hotspot.HotSpotForeignCallLinkage",
                                "org.graalvm.compiler.hotspot.meta.HotSpotProviders",
                                "org.graalvm.compiler.options.OptionValues",
                                "org.graalvm.compiler.hotspot.stubs.SnippetStub"));
                break;
            }
            case substrate: {
                imports.addAll(List.of(
                                "com.oracle.svm.core.SubstrateTargetDescription",
                                "com.oracle.svm.core.Uninterruptible",
                                "com.oracle.svm.core.snippets.SubstrateForeignCallTarget",
                                "com.oracle.svm.core.cpufeature.Stubs",
                                "com.oracle.svm.graal.RuntimeCPUFeatureRegion",
                                "org.graalvm.compiler.api.replacements.Fold",
                                "org.graalvm.compiler.debug.GraalError",
                                "org.graalvm.nativeimage.ImageSingletons",
                                "java.util.EnumSet",
                                "jdk.vm.ci.code.Architecture"));
                break;
            }
        }
        for (GenerateStubClass genClass : classes) {
            imports.add(genClass.clazz.toString());
            for (GenerateStub stub : genClass.stubs) {
                for (VariableElement parameter : stub.method.getParameters()) {
                    if (this.getAnnotation(parameter, this.constantNodeParameter) == null || parameter.asType().getKind().isPrimitive()) {
                        continue;
                    }
                    imports.add(parameter.asType().toString());
                }
            }
        }
        return imports;
    }

    /**
     * Names every minimum-features getter of {@code genClass} and emits its {@code @Fold} method,
     * which dispatches on the target architecture to the configured per-architecture getter.
     */
    private static void emitMinimumFeaturesGetters(PrintWriter out, GenerateStubClass genClass) {
        int count = 0;
        for (MinimumFeaturesGetter featuresGetter : genClass.featureGetters) {
            // The first getter has no suffix; later ones are suffixed "_2", "_3", ...
            count++;
            String suffix = count > 1 ? "_" + count : "";
            featuresGetter.setName(String.format("%s_getMinimumFeatures%s", genClass.clazz.getSimpleName(), suffix));
            out.print("    @Fold\n");
            out.printf("    public static EnumSet<?> %s() {\n", featuresGetter.getName());
            out.print("        Architecture arch = ImageSingletons.lookup(SubstrateTargetDescription.class).arch;\n");
            out.print("        if (arch instanceof jdk.vm.ci.amd64.AMD64) {\n");
            if (featuresGetter.amd64Getter.isEmpty()) {
                out.print("            throw GraalError.shouldNotReachHere(\"not implemented\");\n");
            } else {
                out.printf("            return %s.%s();\n", genClass.clazz.getSimpleName(), featuresGetter.amd64Getter);
            }
            out.print("        }\n");
            out.print("        if (arch instanceof jdk.vm.ci.aarch64.AArch64) {\n");
            if (featuresGetter.aarch64Getter.isEmpty()) {
                out.print("            throw GraalError.shouldNotReachHere(\"not implemented\");\n");
            } else {
                out.printf("            return %s.%s();\n", genClass.clazz.getSimpleName(), featuresGetter.aarch64Getter);
            }
            out.print("        }\n");
            out.print("        throw GraalError.shouldNotReachHere();\n");
            out.print("    }\n");
            out.print("\n");
        }
    }

    /**
     * Emits one generated stub method that forwards to {@code className.m(...)}.
     *
     * <p>
     * The stub's own parameters are the parameters of {@code m} that do not carry the
     * constant-node-parameter annotation. In the forwarding call, each constant parameter is
     * filled in from {@code params} in order (prefixed with the simple name of its type when that
     * type is not primitive), except for the parameter at
     * {@code runtimeCheckedFeaturesParameterIndex}, which receives {@code runtimeCheckedFeatures}
     * when that is non-null. In that case the call is also wrapped in a
     * {@code RuntimeCPUFeatureRegion} enter/leave block.
     *
     * <p>
     * All computed text is written with {@code print}, never as a format string, so values
     * containing {@code %} are emitted verbatim. If {@code params} has fewer entries than there
     * are constant parameters to fill, an error is reported on {@code m} and nothing is emitted
     * for the missing argument.
     *
     * @param targetVM VM flavour; selects the annotations placed on the stub
     * @param out writer for the generated source
     * @param className simple name of the class declaring {@code m}
     * @param methodName name of the generated stub method
     * @param params values for the constant parameters, in declaration order
     * @param m the method the stub forwards to
     * @param runtimeCheckedFeatures expression yielding the CPU feature set, or {@code null}
     * @param runtimeCheckedFeaturesParameterIndex index in {@code m}'s parameter list that
     *            receives {@code runtimeCheckedFeatures}; ignored when that is {@code null}
     */
    private void generateStub(TargetVM targetVM, PrintWriter out, Name className, String methodName, List<String> params, ExecutableElement m, String runtimeCheckedFeatures, int runtimeCheckedFeaturesParameterIndex) {
        switch (targetVM) {
            case hotspot: {
                out.print("    @Snippet\n");
                break;
            }
            case substrate: {
                out.print("    @Uninterruptible(reason = \"Must not do a safepoint check.\")\n");
                out.print("    @SubstrateForeignCallTarget(stubCallingConvention = false, fullyUninterruptible = true)\n");
                break;
            }
        }
        out.printf("    private static %s %s(", m.getReturnType(), methodName);
        String stubParameters = m.getParameters().stream()
                        .filter(p -> this.getAnnotation(p, this.constantNodeParameter) == null)
                        .map(p -> p.asType() + " " + p.getSimpleName())
                        .collect(Collectors.joining(", "));
        out.print(stubParameters);
        out.print(") {\n");
        if (runtimeCheckedFeatures != null) {
            out.printf("        RuntimeCPUFeatureRegion region = RuntimeCPUFeatureRegion.enterSet(%s);\n", runtimeCheckedFeatures);
            // The trailing spaces indent the forwarding call one level deeper inside the try.
            out.print("        try {\n    ");
        }
        out.printf("        %s%s.%s(", m.getReturnType().getKind() == TypeKind.VOID ? "" : "return ", className, m.getSimpleName());
        int iConst = 0;
        List<? extends VariableElement> parameters = m.getParameters();
        for (int i = 0; i < parameters.size(); ++i) {
            VariableElement parameter = parameters.get(i);
            if (i > 0) {
                out.print(", ");
            }
            if (this.getAnnotation(parameter, this.constantNodeParameter) == null) {
                // Regular parameter: pass the stub's own argument through.
                out.print(parameter.getSimpleName().toString());
                continue;
            }
            if (runtimeCheckedFeatures != null && i == runtimeCheckedFeaturesParameterIndex) {
                out.print(runtimeCheckedFeatures);
                continue;
            }
            if (iConst >= params.size()) {
                // Report a proper diagnostic instead of failing with IndexOutOfBoundsException.
                String message = String.format("stub %s: no value for constant parameter %s of %s; 'parameters' has only %d entries", methodName, parameter.getSimpleName(), m, params.size());
                this.env().getMessager().printMessage(Diagnostic.Kind.ERROR, message, m);
                iConst++;
                continue;
            }
            if (!parameter.asType().getKind().isPrimitive()) {
                // Non-primitive constants are qualified with their type's simple name.
                out.printf("%s.", this.asTypeElement(parameter.asType()).getSimpleName());
            }
            out.print(params.get(iConst++));
        }
        out.print(");\n");
        if (runtimeCheckedFeatures != null) {
            out.print("        } finally {\n");
            out.print("            region.leave();\n");
            out.print("        }\n");
        }
        out.print("    }\n");
        out.print("\n");
    }

    /**
     * VM flavours stubs can be generated for. The constant names are matched against the
     * {@code targetVM} annotation value via {@code valueOf}, so they must not be renamed.
     */
    enum TargetVM {
        hotspot,
        substrate
    }

    /** Groups everything collected for one source class: its stubs and its distinct feature getters. */
    private static final class GenerateStubClass {
        /** The source class declaring the annotated methods. */
        private final TypeElement clazz;
        /** Stub descriptions collected from {@link #clazz}. */
        private final ArrayList<GenerateStub> stubs;
        /** Distinct minimum-features getters referenced by {@link #stubs}. */
        private final Set<MinimumFeaturesGetter> featureGetters;

        private GenerateStubClass(TypeElement sourceClass, ArrayList<GenerateStub> collectedStubs, Set<MinimumFeaturesGetter> distinctGetters) {
            this.clazz = sourceClass;
            this.stubs = collectedStubs;
            this.featureGetters = distinctGetters;
        }
    }

    /** A runtime-checked-flags variant method together with the position of its extra flags parameter. */
    private static final class RuntimeCheckedFlagsMethod {
        /** The variant method. */
        private final ExecutableElement method;
        /** Index, in {@link #method}'s parameter list, of the additional flags parameter. */
        private final int runtimeCheckedFlagsParameterIndex;

        private RuntimeCheckedFlagsMethod(ExecutableElement variant, int flagsParameterIndex) {
            this.method = variant;
            this.runtimeCheckedFlagsParameterIndex = flagsParameterIndex;
        }
    }

    /**
     * A pair of per-architecture getter names, used as a hash key to de-duplicate getters.
     * Equality and hash code depend only on the two getter names; the mutable {@code name} is
     * deliberately excluded, so it can be assigned after the instance is stored in a map or set.
     */
    private static final class MinimumFeaturesGetter {
        private final String amd64Getter;
        private final String aarch64Getter;
        /** Name of the generated getter method; unset until {@link #setName} is called. */
        private String name;

        private MinimumFeaturesGetter(String amd64Getter, String aarch64Getter) {
            this.amd64Getter = amd64Getter;
            this.aarch64Getter = aarch64Getter;
        }

        public String getName() {
            return name;
        }

        public void setName(String name) {
            this.name = name;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            // The class is final, so instanceof is equivalent to an exact class comparison.
            if (!(o instanceof MinimumFeaturesGetter)) {
                return false;
            }
            MinimumFeaturesGetter other = (MinimumFeaturesGetter) o;
            return amd64Getter.equals(other.amd64Getter) && aarch64Getter.equals(other.aarch64Getter);
        }

        @Override
        public int hashCode() {
            return 31 * amd64Getter.hashCode() + aarch64Getter.hashCode();
        }
    }

    /** Description of a single stub to generate. */
    private static final class GenerateStub {
        /** The stub annotation this description was created from. */
        private final AnnotationMirror annotation;
        /** The annotated method. */
        private final ExecutableElement method;
        /** Runtime-checked-flags variant of {@link #method}; {@code null} if none was found. */
        private final RuntimeCheckedFlagsMethod runtimeCheckedFlagsMethod;
        /** Minimum-features getter requested by the annotation; {@code null} if none. */
        private final MinimumFeaturesGetter minimumFeaturesGetter;

        private GenerateStub(AnnotationMirror stubAnnotation, ExecutableElement annotatedMethod, RuntimeCheckedFlagsMethod runtimeCheckedVariant, MinimumFeaturesGetter featuresGetter) {
            this.annotation = stubAnnotation;
            this.method = annotatedMethod;
            this.runtimeCheckedFlagsMethod = runtimeCheckedVariant;
            this.minimumFeaturesGetter = featuresGetter;
        }
    }
}

