/*
 * Decompiled with CFR 0.152.
 */
package com.oracle.svm.hosted.webimage.wasm.ast.visitors;

import com.oracle.svm.hosted.webimage.wasm.ast.Export;
import com.oracle.svm.hosted.webimage.wasm.ast.Function;
import com.oracle.svm.hosted.webimage.wasm.ast.Global;
import com.oracle.svm.hosted.webimage.wasm.ast.ImportDescriptor;
import com.oracle.svm.hosted.webimage.wasm.ast.Instruction;
import com.oracle.svm.hosted.webimage.wasm.ast.Limit;
import com.oracle.svm.hosted.webimage.wasm.ast.Memory;
import com.oracle.svm.hosted.webimage.wasm.ast.StartFunction;
import com.oracle.svm.hosted.webimage.wasm.ast.Table;
import com.oracle.svm.hosted.webimage.wasm.ast.Tag;
import com.oracle.svm.hosted.webimage.wasm.ast.TypeUse;
import com.oracle.svm.hosted.webimage.wasm.ast.WasmModule;
import com.oracle.svm.hosted.webimage.wasm.ast.id.WasmId;
import com.oracle.svm.hosted.webimage.wasm.ast.visitors.WasmVisitor;
import com.oracle.svm.webimage.wasm.types.WasmLMUtil;
import com.oracle.svm.webimage.wasm.types.WasmPrimitiveType;
import com.oracle.svm.webimage.wasm.types.WasmValType;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import jdk.graal.compiler.debug.GraalError;

/**
 * Validates a {@link WasmModule} while walking it with {@link WasmVisitor}.
 * <p>
 * Operand types are tracked on an abstract value stack ({@link #vals}) and enclosing blocks on a stack of
 * {@link CtrlFrame control frames} ({@link #ctrls}), in the style of the validation algorithm given in the
 * appendix of the WebAssembly specification. A {@code null} type obtained from {@link #popVal()} stands for an
 * operand of unknown type and is only produced in unreachable code.
 * <p>
 * Every validation failure is reported by throwing the exception created by
 * {@link GraalError#shouldNotReachHere(String)}.
 */
public class WasmValidator extends WasmVisitor {
    /** Largest value accepted in a memory limit (2^16 pages). */
    private static final long MAX_MEMORY_LIMIT = 65536L;
    /** Largest value accepted in a table limit (2^32 - 1). */
    private static final long MAX_TABLE_LIMIT = 0xFFFFFFFFL;

    private final Context ctxt = new ContextImpl();
    /** Abstract operand stack; the head of the deque is the top of the Wasm value stack. */
    private final ArrayDeque<WasmValType> vals = new ArrayDeque<>();
    /** Enclosing control frames; the head of the deque is the innermost frame. */
    private final Deque<CtrlFrame> ctrls = new ArrayDeque<>();

    /**
     * Builds an error message mentioning the current function and instruction (if any) and throws it.
     *
     * @param msg optional detail message, may be {@code null}
     * @return never returns normally; the return type only lets callers write {@code throw error(...)}
     */
    private RuntimeException error(String msg) {
        StringBuilder errorMsg = new StringBuilder("Error");
        Function currentFunction = ctxt.getCurrentFunc();
        if (currentFunction != null) {
            errorMsg.append(" in function: ").append(currentFunction.getId());
        }
        Instruction currentInstruction = ctxt.getCurrentInstruction();
        if (currentInstruction != null) {
            errorMsg.append(" at ").append(currentInstruction);
        }
        if (msg != null) {
            errorMsg.append(": ").append(msg);
        }
        throw GraalError.shouldNotReachHere(errorMsg.toString());
    }

    /** Throws a validation error with {@code msg} if {@code condition} holds. */
    private void errorIf(boolean condition, String msg) {
        if (condition) {
            throw error(msg);
        }
    }

    /** Null-safe equality of two ids (despite the name, this only compares and never throws). */
    private static boolean assertIdsEqual(WasmId first, WasmId second) {
        return Objects.equals(first, second);
    }

    /** Returns the innermost control frame; throws {@link NullPointerException} if there is none. */
    private CtrlFrame topFrame() {
        return Objects.requireNonNull(ctrls.peek());
    }

    /** Returns whether there is an innermost control frame and it has been marked unreachable. */
    private boolean inUnreachableCode() {
        CtrlFrame top = ctrls.peek();
        return top != null && top.unreachable;
    }

    private void pushVal(WasmValType t) {
        vals.push(t);
    }

    /**
     * Pops one operand type.
     *
     * @return the popped type, or {@code null} (unknown) if the stack is empty in an unreachable frame
     */
    private WasmValType popVal() {
        CtrlFrame top = topFrame();
        if (vals.isEmpty() && top.unreachable) {
            return null;
        }
        errorIf(vals.isEmpty(), "Expected some value on stack");
        return vals.pop();
    }

    /** Reports a type mismatch against the current operand stack. */
    private RuntimeException typeMismatch(WasmValType[] expected) {
        return typeMismatch(expected, vals.clone());
    }

    /** Reports a type mismatch between {@code expected} (bottom-to-top) and the stack snapshot {@code actual}. */
    private RuntimeException typeMismatch(WasmValType[] expected, Deque<WasmValType> actual) {
        List<WasmValType> actualTypes = Arrays.asList(actual.toArray(WasmValType[]::new));
        // The deque head is the stack top; reverse so both lists are printed bottom-to-top.
        Collections.reverse(actualTypes);
        throw error("type mismatch, expected " + Arrays.toString(expected) + " but got " + actualTypes);
    }

    /**
     * Pops operands and checks them against {@code expected}, which lists types bottom-to-top (e.g. address
     * before value for a store). A {@code null} on either side matches any type.
     */
    private void popVals(WasmValType... expected) {
        int numTypes = expected.length;
        // In unreachable code the stack is polymorphic: popVal() yields null (unknown) once the stack runs dry,
        // so a short stack is only an error in reachable code.
        if (vals.size() < numTypes && !inUnreachableCode()) {
            throw typeMismatch(expected);
        }
        Deque<WasmValType> oldStack = vals.clone();
        // Match from the top of the stack down, i.e. from the last expected type to the first.
        for (int i = numTypes - 1; i >= 0; --i) {
            WasmValType expectedType = expected[i];
            WasmValType actualType = popVal();
            if (actualType != null && expectedType != null && !Objects.equals(expectedType, actualType)) {
                throw typeMismatch(expected, oldStack);
            }
        }
    }

    private void assertStackEmpty() {
        if (!vals.isEmpty()) {
            throw typeMismatch(new WasmValType[0]);
        }
    }

    private void assertLocalExists(WasmId.Local local) {
        errorIf(!ctxt.hasLocal(local), "Local does not exist: " + local);
    }

    /** Looks up a registered global, failing validation if it is unknown. */
    private Global getAndAssertGlobalExists(WasmId.Global global) {
        Global g = ctxt.getGlobal(global);
        errorIf(g == null, "global does not exist: " + global);
        return g;
    }

    /**
     * Opens a new control frame. The operand stack must be empty, and a non-null label must have a name distinct
     * from every label of the enclosing frames.
     */
    private void pushCtrl(WasmId.Label label) {
        assertStackEmpty();
        if (label != null) {
            assertIdUniqueName(label, ctrls.stream().map(frame -> frame.label).filter(Objects::nonNull).collect(Collectors.toList()));
        }
        ctrls.push(new CtrlFrame(label));
    }

    /** Closes the innermost control frame, which must carry {@code expectedLabel} and leave an empty stack. */
    private void popCtrl(WasmId.Label expectedLabel) {
        errorIf(ctrls.isEmpty(), "Expected a control frame");
        CtrlFrame frame = topFrame();
        assertStackEmpty();
        // Identity comparison: the same label object is passed to the matching pushCtrl call.
        errorIf(frame.label != expectedLabel, "Expected control frame " + expectedLabel + " but got " + frame.label);
        ctrls.pop();
    }

    /** Marks the rest of the innermost frame as unreachable and discards all tracked operands. */
    private void unreachable() {
        CtrlFrame top = topFrame();
        vals.clear();
        top.unreachable = true;
    }

    private void applyBinary(WasmValType left, WasmValType right, WasmValType result) {
        applyTypeUse(TypeUse.forBinary(result, left, right));
    }

    private void applyUnary(WasmValType input, WasmValType result) {
        applyTypeUse(TypeUse.forUnary(result, input));
    }

    /** Pops the parameters of {@code typeUse} and pushes its results. */
    private void applyTypeUse(TypeUse typeUse) {
        popVals(typeUse.params.toArray(WasmValType[]::new));
        typeUse.results.forEach(this::pushVal);
    }

    private void assertLabelExists(WasmId.Label label) {
        errorIf(ctrls.stream().noneMatch(frame -> assertIdsEqual(label, frame.label)), "Label " + label + " does not exist.");
    }

    /** Checks that {@code id} is resolved and that its name differs from the names of all {@code others}. */
    private <I extends WasmId> void assertIdUniqueName(I id, Collection<I> others) {
        errorIf(!id.isResolved(), "Found unresolved id: " + id);
        errorIf(others.stream().map(WasmId::getName).anyMatch(name -> name.equals(id.getName())), "Id does not have unique name: " + id);
    }

    /** Checks that the limit's min (and max, if present) do not exceed {@code maxValue} and that min <= max. */
    private void checkLimit(Limit limit, long maxValue) {
        errorIf((long) limit.getMin() > maxValue, "Limit min larger than " + maxValue + ": " + limit);
        if (limit.hasMax()) {
            errorIf((long) limit.getMax() > maxValue, "Limit max larger than " + maxValue + ": " + limit);
            errorIf(limit.getMin() > limit.getMax(), "Limit min larger than max: " + limit);
        }
    }

    @Override
    public void visitModule(WasmModule m) {
        // Register the signatures of all defined and imported functions before visiting the module, so that
        // calls, exports and the start function can be resolved during the visit.
        m.getFunctions().forEach(func -> ctxt.addFunc(func.getId(), func.getSignature()));
        m.getImports().values().forEach(importDecl -> {
            if (importDecl.getDescriptor() instanceof ImportDescriptor.Function funcImport) {
                ctxt.addFunc(importDecl.getFunctionId(), funcImport.typeUse);
            }
        });
        super.visitModule(m);
        // NOTE(review): only functions are reset; globals, tables, memories, tags and export names stay
        // registered, so an instance presumably should not be reused for another module.
        ctxt.clearFuncs();
    }

    @Override
    public void visitMemory(Memory m) {
        super.visitMemory(m);
        ctxt.addMemory(m.id);
        checkLimit(m.limit, MAX_MEMORY_LIMIT);
    }

    @Override
    public void visitTag(Tag tag) {
        ctxt.addTag(tag.id);
        super.visitTag(tag);
    }

    /** Validates the initializer in its own frame; it must leave exactly one value of the global's type. */
    @Override
    public void visitGlobal(Global global) {
        ctxt.addGlobal(global);
        pushCtrl(null);
        super.visitGlobal(global);
        popVals(global.getType());
        popCtrl(null);
    }

    @Override
    public void visitExport(Export e) {
        super.visitExport(e);
        ctxt.addExport(e);
    }

    /** The start function must exist and take no parameters and produce no results. */
    @Override
    public void visitStartFunction(StartFunction startFunction) {
        WasmId.Func function = startFunction.function;
        TypeUse typeUse = ctxt.getFunc(function);
        errorIf(typeUse == null, "Start function not found: " + function);
        errorIf(!typeUse.params.isEmpty() || !typeUse.results.isEmpty(), "Start function must not take arguments or produce results: " + function + ", " + typeUse);
    }

    /** Validates a function body in a fresh frame, with its params and locals in scope; at most one result. */
    @Override
    public void visitFunction(Function f) {
        f.getParams().forEach(ctxt::addLocal);
        f.getLocals().forEach(ctxt::addLocal);
        errorIf(f.getResults().size() > 1, "Function has multiple return values: " + f);
        ctxt.setReturnType(f.getResults().isEmpty() ? null : f.getResults().get(0));
        ctxt.setCurrentFunc(f);
        pushCtrl(null);
        super.visitFunction(f);
        popCtrl(null);
        ctxt.setCurrentFunc(null);
        ctxt.setReturnType(null);
        ctxt.clearLocals();
    }

    @Override
    public void visitTable(Table t) {
        ctxt.addTable(t.id, t);
        checkLimit(t.limit, MAX_TABLE_LIMIT);
        errorIf(!t.elementType.isRef(), "Table has non reference element type: " + t);
    }

    /** Tracks the instruction being visited so that error messages can point at it. */
    @Override
    public void visitInstruction(Instruction inst) {
        Instruction parentInst = ctxt.getCurrentInstruction();
        ctxt.setCurrentInstruction(inst);
        super.visitInstruction(inst);
        ctxt.setCurrentInstruction(parentInst);
    }

    @Override
    public void visitBlock(Instruction.Block block) {
        pushCtrl(block.getLabel());
        super.visitBlock(block);
        popCtrl(block.getLabel());
    }

    @Override
    public void visitLoop(Instruction.Loop loop) {
        pushCtrl(loop.getLabel());
        super.visitLoop(loop);
        popCtrl(loop.getLabel());
    }

    /** The i32 condition is consumed first; the then and else branches each get their own labelled frame. */
    @Override
    public void visitIf(Instruction.If ifBlock) {
        visitInstruction(ifBlock.condition);
        popVals(WasmPrimitiveType.i32);
        pushCtrl(ifBlock.getLabel());
        visitInstructions(ifBlock.thenInstructions);
        if (ifBlock.hasElse()) {
            popCtrl(ifBlock.getLabel());
            pushCtrl(ifBlock.getLabel());
            visitInstructions(ifBlock.elseInstructions);
        }
        popCtrl(ifBlock.getLabel());
    }

    /** Each catch clause is validated in a fresh unlabelled frame seeded with the caught tag's parameters. */
    @Override
    public void visitTry(Instruction.Try tryBlock) {
        pushCtrl(tryBlock.getLabel());
        visitInstructions(tryBlock.instructions);
        for (Instruction.Try.Catch catchBlock : tryBlock.catchBlocks) {
            errorIf(!ctxt.hasTag(catchBlock.tag), "No matching tag for catch: " + catchBlock);
            pushCtrl(null);
            catchBlock.tag.typeUse.params.forEach(this::pushVal);
            visitInstructions(catchBlock.instructions);
            popCtrl(null);
        }
        popCtrl(tryBlock.getLabel());
    }

    @Override
    public void visitUnreachable(Instruction.Unreachable unreachable) {
        unreachable();
        super.visitUnreachable(unreachable);
    }

    @Override
    public void visitDrop(Instruction.Drop inst) {
        super.visitDrop(inst);
        popVal();
    }

    @Override
    public void visitBinary(Instruction.Binary inst) {
        super.visitBinary(inst);
        applyBinary(inst.op.leftInputType, inst.op.rightInputType, inst.op.outputType);
    }

    /** A conditional break consumes an i32; an unconditional one makes the rest of the frame unreachable. */
    @Override
    public void visitBreak(Instruction.Break inst) {
        super.visitBreak(inst);
        WasmId.Label targetLabel = inst.getTarget();
        assertLabelExists(targetLabel);
        if (inst.condition == null) {
            unreachable();
        } else {
            popVals(WasmPrimitiveType.i32);
        }
    }

    @Override
    public void visitBreakTable(Instruction.BreakTable inst) {
        super.visitBreakTable(inst);
        popVals(WasmPrimitiveType.i32);
        WasmId.Label defaultLabel = inst.getDefaultTarget();
        assertLabelExists(defaultLabel);
        for (int i = 0; i < inst.numTargets(); ++i) {
            assertLabelExists(inst.getTarget(i));
        }
        unreachable();
    }

    @Override
    public void visitConst(Instruction.Const constValue) {
        super.visitConst(constValue);
        pushVal(constValue.literal.type);
    }

    @Override
    public void visitRelocation(Instruction.Relocation relocation) {
        errorIf(!relocation.wasProcessed(), "Unhandled relocation: " + relocation);
        super.visitRelocation(relocation);
    }

    @Override
    public void visitLocalGet(Instruction.LocalGet localGet) {
        super.visitLocalGet(localGet);
        assertLocalExists(localGet.getLocal());
        pushVal(localGet.getType());
    }

    @Override
    public void visitLocalSet(Instruction.LocalSet localSet) {
        super.visitLocalSet(localSet);
        assertLocalExists(localSet.getLocal());
        popVals(localSet.getType());
    }

    @Override
    public void visitLocalTee(Instruction.LocalTee localTee) {
        super.visitLocalTee(localTee);
        assertLocalExists(localTee.getLocal());
        applyUnary(localTee.getType(), localTee.getType());
    }

    @Override
    public void visitGlobalGet(Instruction.GlobalGet globalGet) {
        super.visitGlobalGet(globalGet);
        getAndAssertGlobalExists(globalGet.getGlobal());
        pushVal(globalGet.getType());
    }

    @Override
    public void visitGlobalSet(Instruction.GlobalSet globalSet) {
        super.visitGlobalSet(globalSet);
        Global g = getAndAssertGlobalExists(globalSet.getGlobal());
        errorIf(!g.mutable, "Found global.set for immutable global " + g);
        popVals(globalSet.getType());
    }

    /** Consumes the return value (if any); like an unconditional break, nothing after it is reachable. */
    @Override
    public void visitReturn(Instruction.Return ret) {
        super.visitReturn(ret);
        if (!ret.isVoid()) {
            popVals(ctxt.getReturnType());
        }
        unreachable();
    }

    @Override
    public void visitCall(Instruction.Call inst) {
        super.visitCall(inst);
        TypeUse typeUse = ctxt.getFunc(inst.getTarget());
        errorIf(typeUse == null, "Function call target does not exist: " + inst.getTarget());
        applyTypeUse(typeUse);
    }

    /** The target table must hold function references; the i32 table index is popped before the arguments. */
    @Override
    public void visitCallIndirect(Instruction.CallIndirect inst) {
        super.visitCallIndirect(inst);
        Table table = ctxt.getTable(inst.table);
        errorIf(table == null, "No matching table for call_indirect: " + inst);
        errorIf(!table.elementType.isFuncRef(), "Target table is not a function table: " + table);
        popVals(WasmPrimitiveType.i32);
        applyTypeUse(inst.signature);
    }

    /** Consumes the tag's parameters; like an unconditional break, nothing after it is reachable. */
    @Override
    public void visitThrow(Instruction.Throw inst) {
        super.visitThrow(inst);
        errorIf(!ctxt.hasTag(inst.tag), "No matching tag for throw: " + inst);
        applyTypeUse(inst.tag.typeUse);
        unreachable();
    }

    @Override
    public void visitUnary(Instruction.Unary inst) {
        super.visitUnary(inst);
        errorIf(inst.op == Instruction.Unary.Op.Nop, "Unary Nop left in module");
        applyUnary(inst.op.inputType, inst.op.outputType);
    }

    /** Pops the i32 condition and two operands of matching type, then pushes that type back. */
    @Override
    public void visitSelect(Instruction.Select inst) {
        super.visitSelect(inst);
        popVals(WasmPrimitiveType.i32);
        WasmValType t1 = popVal();
        WasmValType t2 = popVal();
        errorIf(t1 != null && t2 != null && !Objects.equals(t1, t2), "Select operand types do not match: " + t1 + ", " + t2);
        WasmValType result = t1 == null ? t2 : t1;
        // Both operands can only be unknown in an unreachable frame. ArrayDeque rejects null elements, and an
        // empty stack in an unreachable frame already yields unknown operands, so there is nothing to push.
        if (result != null) {
            pushVal(result);
        }
    }

    @Override
    public void visitLoad(Instruction.Load inst) {
        visitInstruction(inst.getOffset());
        popVals(WasmPrimitiveType.i32);
        visitInstruction(inst.baseAddress);
        applyUnary(WasmPrimitiveType.i32, inst.stackType);
    }

    @Override
    public void visitStore(Instruction.Store inst) {
        visitInstruction(inst.getOffset());
        popVals(WasmPrimitiveType.i32);
        visitInstruction(inst.baseAddress);
        visitInstruction(inst.value);
        popVals(WasmPrimitiveType.i32, inst.stackType);
    }

    @Override
    public void visitMemoryGrow(Instruction.MemoryGrow inst) {
        super.visitMemoryGrow(inst);
        applyUnary(WasmLMUtil.POINTER_TYPE, WasmLMUtil.POINTER_TYPE);
    }

    @Override
    public void visitMemoryFill(Instruction.MemoryFill inst) {
        super.visitMemoryFill(inst);
        popVals(WasmPrimitiveType.i32, WasmPrimitiveType.i32, WasmPrimitiveType.i32);
    }

    @Override
    public void visitMemoryCopy(Instruction.MemoryCopy inst) {
        super.visitMemoryCopy(inst);
        popVals(WasmPrimitiveType.i32, WasmPrimitiveType.i32, WasmPrimitiveType.i32);
    }

    @Override
    public void visitMemorySize(Instruction.MemorySize inst) {
        super.visitMemorySize(inst);
        pushVal(WasmLMUtil.POINTER_TYPE);
    }

    /**
     * Default {@link Context}: keeps the ids declared so far and reports duplicates or unknown references
     * through the enclosing validator (hence a non-static inner class).
     */
    class ContextImpl implements Context {
        private Instruction inst = null;
        private Function func = null;
        private WasmValType returnType = null;
        private final Set<WasmId.Local> locals = new HashSet<>();
        private final Map<WasmId.Func, TypeUse> funcs = new HashMap<>();
        private final Map<WasmId.Global, Global> globals = new HashMap<>();
        private final Map<WasmId.Table, Table> tables = new HashMap<>();
        private final Set<WasmId.Memory> memories = new HashSet<>();
        private final Set<WasmId.Tag> tags = new HashSet<>();
        private final Set<String> exportNames = new HashSet<>();

        @Override
        public Instruction getCurrentInstruction() {
            return inst;
        }

        @Override
        public void setCurrentInstruction(Instruction inst) {
            this.inst = inst;
        }

        @Override
        public Function getCurrentFunc() {
            return func;
        }

        @Override
        public void setCurrentFunc(Function func) {
            this.func = func;
        }

        @Override
        public WasmValType getReturnType() {
            return returnType;
        }

        @Override
        public void setReturnType(WasmValType returnType) {
            this.returnType = returnType;
        }

        @Override
        public void addLocal(WasmId.Local id) {
            addId(id, locals);
        }

        @Override
        public boolean hasLocal(WasmId.Local id) {
            return locals.contains(id);
        }

        @Override
        public void clearLocals() {
            locals.clear();
        }

        @Override
        public void addFunc(WasmId.Func id, TypeUse typeUse) {
            addIdMapping(id, typeUse, funcs);
        }

        @Override
        public TypeUse getFunc(WasmId.Func id) {
            return funcs.get(id);
        }

        @Override
        public void clearFuncs() {
            funcs.clear();
        }

        @Override
        public void addGlobal(Global global) {
            addIdMapping(global.getId(), global, globals);
        }

        @Override
        public Global getGlobal(WasmId.Global id) {
            return globals.get(id);
        }

        @Override
        public void addTable(WasmId.Table id, Table table) {
            addIdMapping(id, table, tables);
        }

        @Override
        public Table getTable(WasmId.Table id) {
            return tables.get(id);
        }

        @Override
        public void addMemory(WasmId.Memory id) {
            addId(id, memories);
        }

        @Override
        public boolean hasMemory(WasmId.Memory id) {
            return memories.contains(id);
        }

        /** Registers a tag; tags must not declare results. */
        @Override
        public void addTag(WasmId.Tag id) {
            addId(id, tags);
            WasmValidator.this.errorIf(!id.typeUse.results.isEmpty(), "Tag has return value: " + id);
        }

        @Override
        public boolean hasTag(WasmId.Tag id) {
            return tags.contains(id);
        }

        /** Registers an export; names must be unique and only function, memory and tag exports are supported. */
        @Override
        public void addExport(Export export) {
            WasmValidator.this.errorIf(exportNames.contains(export.name), "Duplicate export: " + export);
            exportNames.add(export.name);
            switch (export.type) {
                case FUNC -> WasmValidator.this.errorIf(getFunc(export.getFuncId()) == null, "No matching function for export: " + export);
                case MEM -> WasmValidator.this.errorIf(!hasMemory(export.getMemoryId()), "No matching memory for export: " + export);
                case TAG -> WasmValidator.this.errorIf(!hasTag(export.getTagId()), "No matching tag for export: " + export);
                default -> throw WasmValidator.this.error("Unsupported export " + export);
            }
        }

        /** Adds {@code id -> data} after checking that the id is resolved, uniquely named and not yet present. */
        private <I extends WasmId, T> void addIdMapping(I id, T data, Map<I, T> map) {
            WasmValidator.this.assertIdUniqueName(id, map.keySet());
            T old = map.putIfAbsent(id, data);
            WasmValidator.this.errorIf(old != null, "Duplicate id: " + id);
        }

        /** Adds {@code id} after checking that it is resolved, uniquely named and not yet present. */
        private <I extends WasmId> void addId(I id, Collection<I> others) {
            WasmValidator.this.assertIdUniqueName(id, others);
            WasmValidator.this.errorIf(others.contains(id), "Duplicate id: " + id);
            others.add(id);
        }
    }

    /** Validation state shared across the visit: current position, declared ids and the current return type. */
    interface Context {
        Instruction getCurrentInstruction();

        void setCurrentInstruction(Instruction inst);

        Function getCurrentFunc();

        void setCurrentFunc(Function func);

        WasmValType getReturnType();

        void setReturnType(WasmValType returnType);

        void addLocal(WasmId.Local id);

        boolean hasLocal(WasmId.Local id);

        void clearLocals();

        void addFunc(WasmId.Func id, TypeUse typeUse);

        TypeUse getFunc(WasmId.Func id);

        void clearFuncs();

        void addGlobal(Global global);

        Global getGlobal(WasmId.Global id);

        void addTable(WasmId.Table id, Table table);

        Table getTable(WasmId.Table id);

        void addMemory(WasmId.Memory id);

        boolean hasMemory(WasmId.Memory id);

        void addTag(WasmId.Tag id);

        boolean hasTag(WasmId.Tag id);

        void addExport(Export export);
    }

    /** A control frame: an optional label plus whether the rest of the frame has become unreachable. */
    static class CtrlFrame {
        final WasmId.Label label;
        boolean unreachable = false;

        CtrlFrame(WasmId.Label label) {
            this.label = label;
        }
    }
}

