"""Fix up various things after deserialization."""

from __future__ import annotations

from typing import Any, Final

from mypy.lookup import lookup_fully_qualified
from mypy.nodes import (
    Block,
    ClassDef,
    Decorator,
    FuncDef,
    MypyFile,
    OverloadedFuncDef,
    ParamSpecExpr,
    SymbolTable,
    SymbolTableNode,
    TypeAlias,
    TypeInfo,
    TypeVarExpr,
    TypeVarTupleExpr,
    Var,
)
from mypy.types import (
    NOT_READY,
    AnyType,
    CallableType,
    Instance,
    LiteralType,
    Overloaded,
    Parameters,
    ParamSpecType,
    ProperType,
    TupleType,
    TypeAliasType,
    TypedDictType,
    TypeOfAny,
    TypeType,
    TypeVarTupleType,
    TypeVarType,
    TypeVisitor,
    UnboundType,
    UnionType,
    UnpackType,
)
from mypy.visitor import NodeVisitor


class NodeFixer(NodeVisitor[None]):
    """Patch up symbol nodes after they have been deserialized.

    The fixer walks symbol tables and the nodes stored in them, replaces
    module cross-references with the actual module nodes, and hands every
    type it encounters to a TypeFixer so that the type gets re-linked too.
    """

    # TypeInfo whose contents are currently being visited (None when we are
    # not inside any TypeInfo). Saved and restored by visit_type_info().
    current_info: TypeInfo | None = None

    def __init__(self, modules: dict[str, MypyFile], allow_missing: bool) -> None:
        """Create a fixer resolving names against `modules`.

        Args:
            modules: Loaded modules keyed by module name.
            allow_missing: If true, tolerate cross-references whose target
                cannot be found and substitute placeholders instead.
        """
        self.modules = modules
        # N.B: we do an allow_missing fixup when fixing up a fine-grained
        # incremental cache load (since there may be cross-refs into deleted
        # modules)
        self.allow_missing = allow_missing
        self.type_fixer = TypeFixer(self.modules, allow_missing)

    # NOTE: This method isn't (yet) part of the NodeVisitor API.
    def visit_type_info(self, info: TypeInfo) -> None:
        """Fix up a TypeInfo: its definition, names, types and MRO.

        While `info` is being processed it is exposed as `self.current_info`;
        the previous value is restored afterwards even if an error is raised.
        """
        save_info = self.current_info
        try:
            self.current_info = info
            if info.defn:
                info.defn.accept(self)
            if info.names:
                self.visit_symbol_table(info.names)
            if info.bases:
                for base in info.bases:
                    base.accept(self.type_fixer)
            if info._promote:
                for p in info._promote:
                    p.accept(self.type_fixer)
            if info.tuple_type:
                info.tuple_type.accept(self.type_fixer)
                info.update_tuple_type(info.tuple_type)
                self._sync_special_alias_tvars(info)
            if info.typeddict_type:
                info.typeddict_type.accept(self.type_fixer)
                info.update_typeddict_type(info.typeddict_type)
                self._sync_special_alias_tvars(info)
            if info.declared_metaclass:
                info.declared_metaclass.accept(self.type_fixer)
            if info.metaclass_type:
                info.metaclass_type.accept(self.type_fixer)
            if info.self_type:
                info.self_type.accept(self.type_fixer)
            if info.alt_promote:
                info.alt_promote.accept(self.type_fixer)
                instance = Instance(info, [])
                # Hack: We may also need to add a backwards promotion (from int to native int),
                # since it might not be serialized.
                if instance not in info.alt_promote.type._promote:
                    info.alt_promote.type._promote.append(instance)
            if info._mro_refs:
                # The MRO is stored as a list of full names; turn it back into
                # TypeInfos and drop the names so this happens only once.
                info.mro = [
                    lookup_fully_qualified_typeinfo(
                        self.modules, name, allow_missing=self.allow_missing
                    )
                    for name in info._mro_refs
                ]
                info._mro_refs = None
        finally:
            self.current_info = save_info

    @staticmethod
    def _sync_special_alias_tvars(info: TypeInfo) -> None:
        """Copy the class type variables of `info` onto its special alias.

        Also records, as `tvar_tuple_index`, the position of a
        TypeVarTupleType among those type variables (the last one if there
        are several). Does nothing if `info` has no special alias.
        """
        if info.special_alias:
            info.special_alias.alias_tvars = list(info.defn.type_vars)
            for i, t in enumerate(info.defn.type_vars):
                if isinstance(t, TypeVarTupleType):
                    info.special_alias.tvar_tuple_index = i

    # NOTE: This method *definitely* isn't part of the NodeVisitor API.
    def visit_symbol_table(self, symtab: SymbolTable) -> None:
        """Fix up the entries of a symbol table.

        Module cross-references are resolved right away, nested TypeInfos are
        visited recursively, and every other entry only gets the enclosing
        TypeInfo recorded in `stored_info`.
        """
        for key in symtab:
            value = symtab[key]
            cross_ref = value.cross_ref
            # Fix up module cross-reference eagerly because it is very cheap.
            if cross_ref is not None:
                if cross_ref in self.modules:
                    value.cross_ref = None
                    value.unfixed = False
                    value._node = self.modules[cross_ref]
                # TODO: this should not be needed, looks like a daemon bug.
                elif self.allow_missing:
                    self.resolve_cross_ref(value)
                # Otherwise the cross-reference is left in place here.
            # Look at private attribute to avoid triggering fixup eagerly.
            elif isinstance(value._node, TypeInfo):
                self.visit_type_info(value._node)
            else:
                value.stored_info = self.current_info

    def resolve_cross_ref(self, value: SymbolTableNode) -> None:
        """Replace cross-reference with an actual referred node.

        `value.cross_ref` must be set on entry; it is cleared, `value.unfixed`
        is reset and `value._node` is pointed at the looked-up target. If the
        target is missing, a placeholder is stored when `allow_missing` is
        true, otherwise an assertion fails.
        """
        assert value.cross_ref is not None
        cross_ref = value.cross_ref
        value.cross_ref = None
        value.unfixed = False
        stnode = lookup_fully_qualified(
            cross_ref, self.modules, raise_on_missing=not self.allow_missing
        )
        if stnode is not None:
            if stnode is value:
                # The node seems to refer to itself, which can mean that
                # the target is a deleted submodule of the current module,
                # and thus lookup falls back to the symbol table of the parent
                # package. Here's how this may happen:
                #
                #   pkg/__init__.py:
                #     from pkg import sub
                #
                # Now if pkg.sub is deleted, the pkg.sub symbol table entry
                # appears to refer to itself. Replace the entry with a
                # placeholder to avoid a crash. We can't delete the entry,
                # as it would stop dependency propagation.
                short_name = cross_ref.rsplit(".", maxsplit=1)[-1]
                value._node = Var(short_name + "@deleted")
            else:
                assert stnode.node is not None, cross_ref
                value._node = stnode.node
        elif not self.allow_missing:
            assert False, f"Could not find cross-ref {cross_ref}"
        else:
            # We have a missing crossref in allow missing mode, need to put something
            value._node = missing_info(self.modules)

    def visit_func_def(self, func: FuncDef) -> None:
        """Fix the function's type and link a callable type back to `func`."""
        if func.type is not None:
            func.type.accept(self.type_fixer)
            if isinstance(func.type, CallableType):
                func.type.definition = func

    def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None:
        """Fix an overloaded function: its type, its items and its implementation."""
        if o.type:
            o.type.accept(self.type_fixer)
        for item in o.items:
            item.accept(self)
        if o.impl:
            o.impl.accept(self)
        if isinstance(o.type, Overloaded):
            # For error messages we link the original definition for each item.
            for typ, item in zip(o.type.items, o.items):
                typ.definition = item

    def visit_decorator(self, d: Decorator) -> None:
        """Fix the decorated function and its variable.

        If the variable's type is a callable, its `definition` is pointed at
        the decorated function.
        """
        if d.func:
            d.func.accept(self)
        if d.var:
            d.var.accept(self)
        typ = d.var.type
        if isinstance(typ, ProperType) and isinstance(typ, CallableType):
            typ.definition = d.func

    def visit_class_def(self, c: ClassDef) -> None:
        """Fix the type variables of a class definition."""
        for v in c.type_vars:
            v.accept(self.type_fixer)

    def visit_type_var_expr(self, tv: TypeVarExpr) -> None:
        """Fix the value restrictions, upper bound and default of a TypeVar."""
        for value in tv.values:
            value.accept(self.type_fixer)
        tv.upper_bound.accept(self.type_fixer)
        tv.default.accept(self.type_fixer)

    def visit_paramspec_expr(self, p: ParamSpecExpr) -> None:
        """Fix the upper bound and default of a ParamSpec."""
        p.upper_bound.accept(self.type_fixer)
        p.default.accept(self.type_fixer)

    def visit_type_var_tuple_expr(self, tv: TypeVarTupleExpr) -> None:
        """Fix the upper bound, tuple fallback and default of a TypeVarTuple."""
        tv.upper_bound.accept(self.type_fixer)
        tv.tuple_fallback.accept(self.type_fixer)
        tv.default.accept(self.type_fixer)

    def visit_var(self, v: Var) -> None:
        """Fix the type and the setter type of a variable, when present."""
        if v.type is not None:
            v.type.accept(self.type_fixer)
        if v.setter_type is not None:
            v.setter_type.accept(self.type_fixer)

    def visit_type_alias(self, a: TypeAlias) -> None:
        """Fix the target and the type variables of a type alias."""
        a.target.accept(self.type_fixer)
        for v in a.alias_tvars:
            v.accept(self.type_fixer)


class TypeFixer(TypeVisitor[None]):
    """Type visitor that re-links deserialized types to real nodes.

    Instances and alias types carry the full name of their target in
    `type_ref`; visiting them looks the target up in `modules` and stores it.
    All other visitors just descend into component types.
    """

    def __init__(self, modules: dict[str, MypyFile], allow_missing: bool) -> None:
        """Remember the module map and whether missing targets are tolerated."""
        self.allow_missing = allow_missing
        self.modules = modules

    def visit_instance(self, inst: Instance) -> None:
        """Resolve the TypeInfo of an Instance and fix everything nested in it."""
        ref = inst.type_ref
        if ref is None:
            # A cleared reference means this instance was processed before.
            return
        inst.type_ref = None
        inst.type = lookup_fully_qualified_typeinfo(
            self.modules, ref, allow_missing=self.allow_missing
        )
        # TODO: Is this needed or redundant?
        # Bases of the resolved class that are not linked yet get fixed as well.
        for base_inst in inst.type.bases:
            if base_inst.type is NOT_READY:
                base_inst.accept(self)
        for arg in inst.args:
            arg.accept(self)
        last_known = inst.last_known_value
        if last_known is not None:
            last_known.accept(self)
        extra = inst.extra_attrs
        if extra:
            for attr_type in extra.attrs.values():
                attr_type.accept(self)

    def visit_type_alias_type(self, t: TypeAliasType) -> None:
        """Resolve the alias node of an alias type and fix its arguments."""
        ref = t.type_ref
        if ref is None:
            # A cleared reference means this alias type was processed before.
            return
        t.type_ref = None
        t.alias = lookup_fully_qualified_alias(
            self.modules, ref, allow_missing=self.allow_missing
        )
        for arg in t.args:
            arg.accept(self)

    def visit_any(self, o: Any) -> None:
        """Leaf type: there is nothing to descend into."""

    def visit_callable_type(self, ct: CallableType) -> None:
        """Fix the fallback, argument/return types, variables and type guards."""
        if ct.fallback:
            ct.fallback.accept(self)
        for arg_type in ct.arg_types:
            # An argument type may be None, e.g. for __self in NamedTuple constructors.
            if arg_type is None:
                continue
            arg_type.accept(self)
        if ct.ret_type is not None:
            ct.ret_type.accept(self)
        for tvar in ct.variables:
            tvar.accept(self)
        for narrowed in (ct.type_guard, ct.type_is):
            if narrowed is not None:
                narrowed.accept(self)

    def visit_overloaded(self, t: Overloaded) -> None:
        """Fix every item of an overloaded type."""
        for overload_item in t.items:
            overload_item.accept(self)

    def visit_erased_type(self, o: Any) -> None:
        """Reject erased types."""
        # This type should exist only temporarily during type inference
        raise RuntimeError("Shouldn't get here", o)

    def visit_deleted_type(self, o: Any) -> None:
        """Leaf type: there is nothing to descend into."""

    def visit_none_type(self, o: Any) -> None:
        """Leaf type: there is nothing to descend into."""

    def visit_uninhabited_type(self, o: Any) -> None:
        """Leaf type: there is nothing to descend into."""

    def visit_partial_type(self, o: Any) -> None:
        """Reject partial types."""
        raise RuntimeError("Shouldn't get here", o)

    def visit_tuple_type(self, tt: TupleType) -> None:
        """Fix the item types and the partial fallback of a tuple type."""
        for item in tt.items or ():
            item.accept(self)
        fallback = tt.partial_fallback
        if fallback is not None:
            fallback.accept(self)

    def visit_typeddict_type(self, tdt: TypedDictType) -> None:
        """Fix the value types and the fallback of a TypedDict type.

        A fallback whose referenced type cannot be found is redirected to
        "typing._TypedDict" before it is fixed.
        """
        if tdt.items:
            for value_type in tdt.items.values():
                value_type.accept(self)
        fallback = tdt.fallback
        if fallback is None:
            return
        ref = fallback.type_ref
        if (
            ref is not None
            and lookup_fully_qualified(
                ref, self.modules, raise_on_missing=not self.allow_missing
            )
            is None
        ):
            # Fake TypeInfos are not acceptable as TypedDict fallbacks: the
            # fallback is used in type checking and therefore must be valid.
            fallback.type_ref = "typing._TypedDict"
        fallback.accept(self)

    def visit_literal_type(self, lt: LiteralType) -> None:
        """Fix the fallback of a literal type."""
        lt.fallback.accept(self)

    def visit_type_var(self, tvt: TypeVarType) -> None:
        """Fix the value restrictions, upper bound and default of a type variable."""
        for restriction in tvt.values or ():
            restriction.accept(self)
        tvt.upper_bound.accept(self)
        tvt.default.accept(self)

    def visit_param_spec(self, p: ParamSpecType) -> None:
        """Fix the upper bound, default and prefix of a ParamSpec type."""
        p.upper_bound.accept(self)
        p.default.accept(self)
        p.prefix.accept(self)

    def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
        """Fix the tuple fallback, upper bound and default of a TypeVarTuple type."""
        t.tuple_fallback.accept(self)
        t.upper_bound.accept(self)
        t.default.accept(self)

    def visit_unpack_type(self, u: UnpackType) -> None:
        """Fix the type wrapped by an unpack type."""
        u.type.accept(self)

    def visit_parameters(self, p: Parameters) -> None:
        """Fix the argument types and the variables of a parameter list."""
        for arg_type in p.arg_types:
            if arg_type is None:
                continue
            arg_type.accept(self)
        for tvar in p.variables:
            tvar.accept(self)

    def visit_unbound_type(self, o: UnboundType) -> None:
        """Fix the arguments of an unbound type."""
        for arg in o.args:
            arg.accept(self)

    def visit_union_type(self, ut: UnionType) -> None:
        """Fix every item of a union type."""
        for member in ut.items or ():
            member.accept(self)

    def visit_type_type(self, t: TypeType) -> None:
        """Fix the item of a Type[...] type."""
        t.item.accept(self)


def lookup_fully_qualified_typeinfo(
    modules: dict[str, MypyFile], name: str, *, allow_missing: bool
) -> TypeInfo:
    """Return the TypeInfo registered under the fully qualified `name`.

    If the name does not resolve to a TypeInfo, a placeholder TypeInfo is
    returned when `allow_missing` is true; otherwise an assertion fails.
    """
    entry = lookup_fully_qualified(name, modules, raise_on_missing=not allow_missing)
    if entry:
        node = entry.node
    else:
        node = None
    if not isinstance(node, TypeInfo):
        # Looks like a missing TypeInfo during an initial daemon load, put something there
        assert (
            allow_missing
        ), "Should never get here in normal mode, got {}:{} instead of TypeInfo".format(
            type(node).__name__, node.fullname if node else ""
        )
        return missing_info(modules)
    return node


def lookup_fully_qualified_alias(
    modules: dict[str, MypyFile], name: str, *, allow_missing: bool
) -> TypeAlias:
    """Return the TypeAlias registered under the fully qualified `name`.

    A name that resolves to a TypeInfo yields that info's special alias,
    which is created from its tuple or TypedDict type (and cached on the
    info) if it does not exist yet. If no alias can be produced, a
    placeholder alias is returned when `allow_missing` is true; otherwise
    an assertion fails.
    """
    entry = lookup_fully_qualified(name, modules, raise_on_missing=not allow_missing)
    node = entry.node if entry else None

    if isinstance(node, TypeAlias):
        return node

    if not isinstance(node, TypeInfo):
        # Looks like a missing TypeAlias during an initial daemon load, put something there
        assert (
            allow_missing
        ), "Should never get here in normal mode, got {}:{} instead of TypeAlias".format(
            type(node).__name__, node.fullname if node else ""
        )
        return missing_alias()

    if node.special_alias:
        # Already fixed up.
        return node.special_alias

    if node.tuple_type:
        special = TypeAlias.from_tuple_type(node)
    elif node.typeddict_type:
        special = TypeAlias.from_typeddict_type(node)
    else:
        # Neither a tuple type nor a TypedDict: there is nothing to build an alias from.
        assert allow_missing
        return missing_alias()
    node.special_alias = special
    return special


# Template for the name given to placeholder nodes that stand in for targets
# which could not be found; "{}" is filled in with the kind of the missing
# node ("info" or "alias").
_SUGGESTION: Final = "<missing {}: *should* have gone away during fine-grained update>"


def missing_info(modules: dict[str, MypyFile]) -> TypeInfo:
    """Create a placeholder TypeInfo standing in for one that cannot be found.

    The placeholder has an empty symbol table and an empty class body, is
    named after the "missing info" suggestion text, and derives directly
    from builtins.object, which must be present in `modules`.
    """
    placeholder_name = _SUGGESTION.format("info")
    class_def = ClassDef(placeholder_name, Block([]))
    class_def.fullname = placeholder_name

    placeholder = TypeInfo(SymbolTable(), class_def, "<missing>")
    # builtins.object is looked up strictly (allow_missing=False).
    object_info = lookup_fully_qualified_typeinfo(
        modules, "builtins.object", allow_missing=False
    )
    placeholder.bases = [Instance(object_info, [])]
    placeholder.mro = [placeholder, object_info]
    return placeholder


def missing_alias() -> TypeAlias:
    """Create a placeholder TypeAlias standing in for one that cannot be found.

    The alias targets a special-form Any and is named after the
    "missing alias" suggestion text.
    """
    placeholder_name = _SUGGESTION.format("alias")
    any_target = AnyType(TypeOfAny.special_form)
    return TypeAlias(any_target, placeholder_name, "<missing>", line=-1, column=-1)
