local M = {}

local ls = require("luasnip")
local extras = require("luasnip.extras")
local conds = require("luasnip.extras.expand_conditions")
local t = ls.text_node
local i = ls.insert_node
local d = ls.dynamic_node
local c = ls.choice_node
local f = ls.function_node
local sn = ls.snippet_node
local fmt = require("luasnip.extras.fmt").fmt
local n = extras.nonempty

local autosnippets = {}

-- {{{ Helpers
--- Join the given pieces into one string, separated by newlines.
--- No trailing newline is added.
---@param ... string
---@return string
local function unlines(...)
  local pieces = { ... }
  return table.concat(pieces, "\n")
end

--- Flatten a list by exactly one level.
--- Table elements are replaced by their own sequence elements (in order);
--- non-table elements are copied through unchanged. Deeper nesting is kept.
---@param arr any[]
---@return any[]
local function flatten(arr)
  local flat = {}
  for _, item in ipairs(arr) do
    if type(item) ~= "table" then
      flat[#flat + 1] = item
    else
      for _, nested in ipairs(item) do
        flat[#flat + 1] = nested
      end
    end
  end
  return flat
end

--- Build a snippet context table whose trigger and display name are both
--- `name`.
---@param name string
---@return { name: string, trig: string }
local function trig(name)
  local context = {}
  context.trig = name
  context.name = name
  return context
end
-- }}}
-- {{{ Math mode detection
-- Taken from: https://github.com/iurimateus/luasnip-latex-snippets.nvim/blob/main/lua/luasnip-latex-snippets/util/ts_utils.lua
-- Set of tree-sitter node types that the zone checks below treat as math.
---@type table<string, boolean>
local MATH_NODES = {}
for _, node_type in ipairs({
  "displayed_equation",
  "inline_formula",
  "math_environment",
}) do
  MATH_NODES[node_type] = true
end

-- Set of tree-sitter node types that `in_mathzone` treats as text, so that
-- reaching one of them while walking up the tree means "not in math".
---@type table<string, boolean>
local TEXT_NODES = {}
for _, node_type in ipairs({
  "text_mode",
  "label_definition",
  "label_reference",
}) do
  TEXT_NODES[node_type] = true
end

--- Look up the named tree-sitter node at the cursor of the current window,
--- using the "latex" parser on the current buffer.
---
--- Returns nil (instead of raising) when the parser cannot be obtained or
--- when parsing produces no tree, so callers can treat "no syntax
--- information" as an ordinary case.
---@return TSNode|nil
local function get_node_at_cursor()
  local pos = vim.api.nvim_win_get_cursor(0)
  -- Subtract one to account for 1-based row indexing in nvim_win_get_cursor
  local row, col = pos[1] - 1, pos[2]

  -- Depending on the Neovim version, get_parser either raises or returns nil
  -- when the "latex" parser is unavailable. This function runs as part of
  -- autosnippet conditions, so a raised error would surface on every
  -- keystroke; pcall folds both behaviours into the nil return below.
  local ok, parser = pcall(vim.treesitter.get_parser, 0, "latex")
  if not ok or not parser then
    return
  end

  -- Only the range around the cursor is requested from the parser.
  local root_tree = parser:parse({ row, col, row, col })[1]
  local root = root_tree and root_tree:root()
  if not root then
    return
  end

  return root:named_descendant_for_range(row, col, row, col)
end

--- Decide whether the cursor is in text (as opposed to math).
---
--- Walks from the node at the cursor towards the root. The first `text_mode`
--- node found means "text" and the first node listed in `MATH_NODES` means
--- "math". When neither is found (or there is no node), the result is true.
---
---@param check_parent boolean|nil when truthy, a `text_mode` node whose
---   direct parent is a math node (the `\text{}` case) counts as math
---@return boolean
local function in_text(check_parent)
  local current = get_node_at_cursor()
  while current do
    local kind = current:type()

    if kind == "text_mode" then
      if not check_parent then
        return true
      end

      -- For \text{}
      local enclosing = current:parent()
      local directly_in_math = enclosing and MATH_NODES[enclosing:type()]
      return not directly_in_math
    end

    if MATH_NODES[kind] then
      return false
    end

    current = current:parent()
  end
  return true
end

--- Decide whether the cursor is inside a math zone.
---
--- Walks from the node at the cursor towards the root and stops at the first
--- node listed in `TEXT_NODES` (returns false) or `MATH_NODES` (returns
--- true). Returns false when neither is found or there is no node.
---@return boolean
local function in_mathzone()
  local current = get_node_at_cursor()
  while current do
    local kind = current:type()
    if TEXT_NODES[kind] then
      return false
    end
    if MATH_NODES[kind] then
      return true
    end
    current = current:parent()
  end
  return false
end

--- Snippet condition: the result of `in_text` with its parent check enabled.
--- Any arguments passed by the caller are ignored.
---@return boolean
local function not_math()
  local check_parent = true
  return in_text(check_parent)
end
-- }}}
-- {{{ Start of line & non-math
--- Snippet condition combining two checks: `conds.line_begin(...)` must be
--- truthy for the forwarded arguments, and `not_math()` must be truthy.
--- `not_math` is only evaluated when the first check passes.
local beginTextCondition = function(...)
  local at_line_begin = conds.line_begin(...)
  if not at_line_begin then
    -- Hand back the exact falsy value produced by the first check.
    return at_line_begin
  end
  return not_math()
end

-- `ls.parser.parse_snippet` decorated so that `condition = beginTextCondition`
-- is supplied as a default for the snippet context.
local parseBeginText = ls.extend_decorator.apply(
  ls.parser.parse_snippet,
  { condition = beginTextCondition }
) --[[@as function]]

-- `ls.snippet` decorated with the same default condition.
local snipBeginText = ls.extend_decorator.apply(
  ls.snippet,
  { condition = beginTextCondition }
) --[[@as function]]

--- Build the snippet body for a LaTeX environment: a `\begin{name}` line, a
--- tab-indented line holding the final tabstop `$0`, and an `\end{name}` line.
---@param name string environment name, inserted verbatim in both tags
---@return string
local function env(name)
  local opening = "\\begin{" .. name .. "}"
  local closing = "\\end{" .. name .. "}"
  return opening .. "\n\t$0\n" .. closing
end

--- Build a snippet node holding one insert node flanked by `[` and `]`.
--- The brackets are produced with `extras.nonempty` bound to that insert
--- node, with the empty string as the alternative text.
---@param index integer jump index of the returned snippet node
---@param default string|nil initial text of the insert node
local function optional_square_bracket_arg(index, default)
  local open_bracket = n(1, "[", "")
  local argument = i(1, default)
  local close_bracket = n(1, "]", "")
  return sn(index, { open_bracket, argument, close_bracket })
end

--- Build a begin-of-line text snippet for a theorem-like environment.
---
--- The trigger and snippet name are both `name`. The body is `\begin{name}`,
--- an optional square-bracket argument (jump index 1), an optional
--- `\label{prefix:...}` built around insert node 2, then a tab-indented line
--- with the final cursor position and a closing `\end{name}` line.
---@param name string environment name, also used as the trigger
---@param prefix string label prefix placed before the `:` in `\label{}`
local function theorem_env(name, prefix)
  local begin_tag = "\\begin{" .. name .. "}"
  local end_tag = "\\end{" .. name .. "}"
  local label_head = "\\label{" .. prefix .. ":"

  local nodes = {
    t(begin_tag),
    optional_square_bracket_arg(1),
    n(2, label_head, ""),
    i(2),
    n(2, "}", ""),
    t({ "", "\t" }),
    i(0),
    t({ "", end_tag }),
  }

  return snipBeginText(trig(name), nodes)
end

-- Register the snippets that only expand at the beginning of a line and
-- outside math. They are collected in declaration order and appended to
-- `autosnippets` in one go.
do
  local snippets = {}
  local function add(snippet)
    snippets[#snippets + 1] = snippet
  end

  add(
    parseBeginText(
      { trig = "begin", name = "Begin / end environment" },
      env("$1")
    )
  )

  -- {{{ Chapters / sections
  local sectioning = { "chapter", "section", "subsection", "subsubsection" }
  for _, command in ipairs(sectioning) do
    add(parseBeginText(trig(command), "\\" .. command .. "{$1}\n$0"))
  end
  -- }}}
  -- {{{ Lists
  add(parseBeginText({ trig = "item", name = "List item" }, "\\item"))
  add(
    parseBeginText({ trig = "olist", name = "Ordered list" }, env("enumerate"))
  )
  add(
    parseBeginText({ trig = "ulist", name = "Unordered list" }, env("itemize"))
  )
  -- }}}
  -- {{{ Theorem envs
  -- Each entry is { environment name, label prefix }.
  local theorem_kinds = {
    { "theorem", "thm" },
    { "lemma", "lem" },
    { "exercise", "exe" },
    { "definition", "def" },
    { "corollary", "cor" },
    { "example", "exa" },
  }
  for _, kind in ipairs(theorem_kinds) do
    add(theorem_env(kind[1], kind[2]))
  end

  -- Proofs take the optional bracket argument but no label.
  add(snipBeginText(trig("proof"), {
    t("\\begin{proof}"),
    optional_square_bracket_arg(1),
    t({ "", "\t" }),
    i(0),
    t({ "", "\\end{proof}" }),
  }))
  -- }}}
  -- {{{ Special structures
  local iff_cases = unlines(
    "\\begin{enumerate}",
    "\t\\item[$\\implies$]$1",
    "\t\\item[$\\impliedby$]$2",
    "\\end{enumerate}",
    "$0"
  )
  add(
    parseBeginText(
      { trig = "ciff", name = "If and only if cases" },
      iff_cases
    )
  )
  -- }}}

  vim.list_extend(autosnippets, snippets)
end
-- }}}
-- {{{ Non-math
-- `ls.parser.parse_snippet` decorated so that `condition = not_math` is
-- supplied as a default for the snippet context.
local parseText = ls.extend_decorator.apply(
  ls.parser.parse_snippet,
  { condition = not_math }
) --[[@as function]]

-- `ls.snippet` decorated with the same default condition.
local snipText = ls.extend_decorator.apply(
  ls.snippet,
  { condition = not_math }
) --[[@as function]]

--- Build the two reference snippets for one kind of labelled environment.
---
--- `r<name>` expands to `\ref{<prefix>:$1}` and `pr<name>` expands to the
--- same reference wrapped in parentheses. Both use "<name> reference" as
--- their snippet name.
---@param name string environment name, appended to the `r` / `pr` triggers
---@param prefix string label prefix placed before the `:` inside `\ref{}`
---@return table list holding the plain snippet, then the parenthesized one
local function ref(name, prefix)
  local label = name .. " reference"
  local target = "\\ref{" .. prefix .. ":$1}"

  local plain = parseText({ trig = "r" .. name, name = label }, target .. "$0")
  local parenthesized = parseText(
    { trig = "pr" .. name, name = label },
    "(" .. target .. ")$0"
  )

  return { plain, parenthesized }
end

-- Register the snippets that expand anywhere outside math. Snippet groups are
-- collected in declaration order, flattened into one list, and appended to
-- `autosnippets`.
do
  local groups = {}

  -- {{{ References
  -- Each entry is { environment name, label prefix }.
  local reference_kinds = {
    { "theorem", "thm" },
    { "lemma", "lem" },
    { "exercise", "exe" },
    { "definition", "def" },
    { "corollary", "cor" },
    { "example", "exa" },
  }
  for _, kind in ipairs(reference_kinds) do
    groups[#groups + 1] = ref(kind[1], kind[2])
  end

  -- Generic references without a label prefix.
  groups[#groups + 1] = {
    parseText({ trig = "ref", name = "reference" }, "\\ref{$1}$0"),
    parseText({ trig = "pref", name = "reference" }, "(\\ref{$1})$0"),
  }
  -- }}}

  groups[#groups + 1] = {
    -- {{{ Misc
    parseText(trig("quote"), "``$1''$0"),
    parseText(trig("forcecr"), "{\\ \\\\\\\\}"),
    -- }}}
    -- {{{ Let ...
    snipText(
      { trig = "([Ll]et)", trigEngine = "pattern", name = "definition" },
      {
        -- Re-emit the matched word so its original capitalisation is kept.
        f(function(_, snip)
          return snip.captures[1]
        end),
        t(" "),
        sn(1, fmt("${} ≔ {}$", { i(1), i(2) })),
      }
    ),
    -- }}}
    -- {{{ Display / inline math
    parseText({ trig = "dm", name = "display math" }, env("align*")),
    parseText({ trig = "im", name = "inline math" }, "\\$$1\\$$0"),
    -- }}}
  }

  vim.list_extend(autosnippets, flatten(groups))
end
-- }}}

--- Register every snippet collected in `autosnippets` with LuaSnip for the
--- "tex" filetype, as autosnippets with a default priority of 0.
function M.setup()
  local opts = {
    type = "autosnippets",
    default_priority = 0,
  }
  ls.add_snippets("tex", autosnippets, opts)
end

return M