-- module will not return anything, only register formatters with the main assert engine
local assert = require('luassert.assert')
local match = require('luassert.match')
local util = require('luassert.util')

-- Terminal detection and colorizers.
-- `isatty` is truthy when the optional `term` module reports that stdout is a
-- tty, or when running on Windows with the ANSICON environment variable set.
-- `colors` maps a dot-separated color spec (each token is looked up in
-- `term.colors` and applied in turn) to a function that decorates a string;
-- `colors.none` returns its argument unchanged.
local isatty, colors do
  local ok, term = pcall(require, 'term')
  isatty = io.type(io.stdout) == 'file' and ok and term.isatty(io.stdout)
  if not isatty then
    local isWindows = package.config:sub(1,1) == '\\'
    if isWindows and os.getenv("ANSICON") then
      isatty = true
    end
  end

  -- `term` is optional: when the require failed, `term` holds the error
  -- message instead of a module, so it must never be indexed for colors.
  local palette = ok and type(term) == "table" and term.colors or nil

  local function identity(c)
    return c
  end

  colors = setmetatable({
    none = identity
  },{ __index = function(_, key)
    -- Without a usable palette (or with a non-string color spec, e.g. an
    -- unset parameter) fall back to "no coloring" instead of raising an
    -- error in the middle of formatting a failure message.
    if type(palette) ~= "table" or type(key) ~= "string" then
      return identity
    end
    return function(c)
      for token in key:gmatch("[^%.]+") do
        c = palette[token](c)
      end
      return c
    end
  end
  })
end

-- Formatter for strings: yields "(string) '<value>'".
-- Returns nothing for any other type.
local function fmt_string(arg)
  if type(arg) ~= "string" then
    return
  end
  return string.format("(string) '%s'", arg)
end

-- A version of tostring which formats numbers more precisely.
-- Non-numbers are passed straight to `tostring`. NaN and the infinities are
-- rendered as "NaN", "Inf" and "-Inf". Where the integer subtype exists
-- (`math.type` is available), integers are printed exactly; all other numbers
-- are printed with up to 20 significant digits, and floats whose text would
-- look like an integer get a ".0" inserted after the first run of digits.
-- @param arg any value
-- @return string representation of `arg`
local function tostr(arg)
  if type(arg) ~= "number" then
    return tostring(arg)
  end

  if arg ~= arg then
    return "NaN"
  elseif arg == 1/0 then
    return "Inf"
  elseif arg == -1/0 then
    return "-Inf"
  end

  if math.type and math.type(arg) == "integer" then
    -- "%.20g" converts an integer to a float first, which silently loses
    -- precision for magnitudes beyond 2^53; "%d" prints every integer exactly.
    return string.format("%d", arg)
  end

  local str = string.format("%.20g", arg)

  if math.type and math.type(arg) == "float" and not str:find("[%.,]") then
    -- Number is a float but looks like an integer.
    -- Insert ".0" after first run of digits.
    str = str:gsub("%d+", "%0.0", 1)
  end

  return str
end

-- Formatter for numbers: yields "(number) <value>", with the value rendered
-- by `tostr`. Returns nothing for any other type.
local function fmt_number(arg)
  if type(arg) ~= "number" then
    return
  end
  local rendered = tostr(arg)
  return string.format("(number) %s", rendered)
end

-- Formatter for booleans: yields "(boolean) true" or "(boolean) false".
-- Returns nothing for any other type.
local function fmt_boolean(arg)
  if type(arg) ~= "boolean" then
    return
  end
  local rendered = tostring(arg)
  return string.format("(boolean) %s", rendered)
end

-- Formatter for nil: yields "(nil)". Returns nothing for any other value.
local function fmt_nil(arg)
  if arg ~= nil then
    return
  end
  return "(nil)"
end

-- Sort rank per key type; a lower rank is listed earlier when table keys are
-- sorted. The rank of a type is its position in the list below.
local type_priorities = {}
for rank, type_name in ipairs({
  "number",
  "boolean",
  "string",
  "table",
  "function",
  "userdata",
  "thread",
}) do
  type_priorities[type_name] = rank
end

-- Tells whether `key` is a whole number within 1..length, i.e. an index that
-- belongs to the leading sequence of a table of that length.
local function is_in_array_part(key, length)
  if type(key) ~= "number" then
    return false
  end
  local in_range = 1 <= key and key <= length
  return in_range and math.floor(key) == key
end

-- Collects every key of `t` (as visited by `pairs`) and sorts them so that
-- table dumps are listed in a predictable order:
--   1. whole-number keys within 1..#t, ascending
--   2. remaining keys grouped by `type_priorities` (unknown types last);
--      numbers and strings ascending, `true` before `false`
-- Keys of any other type are treated as equal to each other, so their
-- relative order is whatever the sort leaves them in.
-- @param t table to inspect
-- @return keys array of sorted keys
-- @return nkeys number of keys in the array
local function get_sorted_keys(t)
  local keys = {}
  local nkeys = 0

  for key in pairs(t) do
    nkeys = nkeys + 1
    keys[nkeys] = key
  end

  local length = #t

  local function key_comparator(key1, key2)
    local type1, type2 = type(key1), type(key2)
    local priority1 = is_in_array_part(key1, length) and 0 or type_priorities[type1] or 8
    local priority2 = is_in_array_part(key2, length) and 0 or type_priorities[type2] or 8

    if priority1 ~= priority2 then
      return priority1 < priority2
    end

    if type1 == "string" or type1 == "number" then
      return key1 < key2
    elseif type1 == "boolean" then
      -- put true before false; this must be false when both keys are equal,
      -- because table.sort requires a strict ordering and may otherwise
      -- raise "invalid order function for sorting"
      return key1 and not key2
    end

    -- no meaningful order between these keys: neither sorts before the other
    return false
  end

  table.sort(keys, key_comparator)
  return keys, nkeys
end

-- Formatter for tables. Returns nothing for any other type.
-- The output starts with a type description ("(table: 0x...) ", or
-- "(table) " when the metatable is protected) followed by a dump of the
-- contents with one "[key] = value" entry per line.
-- Behavior is controlled by these parameters read from the assert engine:
--   TableFormatLevel             maximum nesting depth (negative: unlimited)
--   TableFormatShowRecursion     replace tables already being dumped with
--                                "{ ... recursive }"
--   TableErrorHighlightCharacter marker put in front of highlighted entries
--   TableErrorHighlightColor     color spec applied to that marker
-- @param arg value to format
-- @param fmtargs optional table; `fmtargs.crumbs` lists the keys to
--        highlight, the key for nesting depth d being crumbs[#crumbs - d + 1]
local function fmt_table(arg, fmtargs)
  if type(arg) ~= "table" then
    return
  end

  local max_depth = assert:get_parameter("TableFormatLevel")
  local show_recursion = assert:get_parameter("TableFormatShowRecursion")
  local highlight_char = assert:get_parameter("TableErrorHighlightCharacter") or ""
  local highlight_color = assert:get_parameter("TableErrorHighlightColor")
  local crumbs = fmtargs and fmtargs.crumbs or {}
  local visiting = {}  -- table -> number of enclosing dumps in progress
  local header

  local mt = getmetatable(arg)
  if mt == nil then
    header = "(" .. tostring(arg) .. ") "
  elseif pcall(setmetatable, arg, mt) then
    -- unprotected metatable: take it off temporarily so tostring yields the
    -- plain table id, then put it back
    setmetatable(arg, nil)
    header = "(" .. tostring(arg) .. ") "
    setmetatable(arg, mt)
  else
    -- the same metatable cannot be set again, so it is protected; skip the id
    header = "(table) "
  end

  local function render(tbl, depth, on_crumb_path)
    if show_recursion and (visiting[tbl] or 0) > 0 then
      return "{ ... recursive }"
    end

    if next(tbl) == nil then
      return "{ }"
    end

    if depth > max_depth and max_depth >= 0 then
      return "{ ... more }"
    end

    local keys, nkeys = get_sorted_keys(tbl)

    visiting[tbl] = (visiting[tbl] or 0) + 1
    local crumb = crumbs[#crumbs - depth + 1]
    local pieces = { "{" }

    for i = 1, nkeys do
      local key = keys[i]
      local value = tbl[key]
      local highlight = on_crumb_path and key == crumb

      local value_type = type(value)
      if value_type == "table" then
        value = render(value, depth + 1, highlight)
      elseif value_type == "string" then
        value = "'" .. value .. "'"
      end

      -- the marker takes the place of part of the indentation
      local ch = highlight and highlight_char or ""
      local ch_len = ch:len()
      local indent = string.rep(" ", depth * 2 - ch_len)
      local mark = ""
      if ch_len ~= 0 then
        mark = colors[highlight_color](ch)
      end

      pieces[i + 1] = string.format("\n%s%s[%s] = %s", indent, mark, tostr(key), tostr(value))
    end

    visiting[tbl] = visiting[tbl] - 1

    pieces[nkeys + 2] = " }"
    return table.concat(pieces)
  end

  return header .. render(arg, 1, true)
end

-- Formatter for functions: yields "<tostring> @ line <linedefined> in
-- <source>", taken from debug.getinfo. Returns nothing for any other type.
local function fmt_function(arg)
  if type(arg) ~= "function" then
    return
  end
  local info = debug.getinfo(arg)
  local line = tostring(info.linedefined)
  local source = tostring(info.source)
  return string.format("%s @ line %s in %s", tostring(arg), line, source)
end

-- Formatter for userdata: yields "(userdata) '<tostring>'".
-- Returns nothing for any other type.
local function fmt_userdata(arg)
  if type(arg) ~= "userdata" then
    return
  end
  local rendered = tostring(arg)
  return string.format("(userdata) '%s'", rendered)
end

-- Formatter for coroutines: yields "(thread) '<tostring>'".
-- Returns nothing for any other type.
local function fmt_thread(arg)
  if type(arg) ~= "thread" then
    return
  end
  local rendered = tostring(arg)
  return string.format("(thread) '%s'", rendered)
end

-- Formatter for matcher objects (as recognized by match.is_matcher).
-- Yields "(matcher) <prefix><name>(<args>)", where the prefix is "is." when
-- `arg.mod` is true and "no." when it is false, and each of the
-- `arg.arguments.n` arguments is rendered through assert:format.
-- Returns nothing for non-matchers.
local function fmt_matcher(arg)
  if not match.is_matcher(arg) then
    return
  end

  local rendered = {}
  local arguments = arg.arguments
  for idx = 1, arguments.n do
    local formatted = assert:format({ arguments[idx], n = 1, })
    rendered[#rendered + 1] = formatted[1]
  end

  local mod = arg.mod
  local prefix
  if mod == true then
    prefix = "is."
  elseif mod == false then
    prefix = "no."
  end

  local name = tostring(arg.name)
  local arg_list = table.concat(rendered, ", ")
  return string.format("(matcher) %s%s(%s)", prefix, name, arg_list)
end

-- Formatter for argument lists (as recognized by util.is_arglist).
-- Renders each of the `arglist.n` values through assert:format and yields
-- "(values list) (<v1>, <v2>, ...)". Returns nothing for other values.
local function fmt_arglist(arglist)
  if not util.is_arglist(arglist) then
    return
  end

  local rendered = {}
  for idx = 1, arglist.n do
    local formatted = assert:format({ arglist[idx], n = 1, })
    rendered[#rendered + 1] = formatted[1]
  end

  local joined = table.concat(rendered, ", ")
  return "(values list) (" .. joined .. ")"
end

-- Register all formatters with the assert engine, in this exact order.
for _, formatter in ipairs({
  fmt_string,
  fmt_number,
  fmt_boolean,
  fmt_nil,
  fmt_table,
  fmt_function,
  fmt_userdata,
  fmt_thread,
  fmt_matcher,
  fmt_arglist,
}) do
  assert:add_formatter(formatter)
end

-- Default settings for the table formatter: display depth, recursion
-- marking, and the marker (and its color) used to highlight entries.
for _, default in ipairs({
  { "TableFormatLevel", 3 },
  { "TableFormatShowRecursion", false },
  { "TableErrorHighlightCharacter", "*" },
  { "TableErrorHighlightColor", isatty and "red" or "none" },
}) do
  assert:set_parameter(default[1], default[2])
end
