local fmt = string.format
local type = type
local ipairs = ipairs
local tostring = tostring
local get_mt = getmetatable
local set_mt = setmetatable

-- Schema
local schema = {}

--- Whether `val` is a plain table, i.e. a table without a metatable.
---@param val any
---@return boolean
local function is_raw_table(val)
    return type(val) == 'table' and not get_mt(val)
end

--- Filter an array: build a new sequence holding the elements of `t` (walked
--- with `ipairs`) for which `filter` returns a truthy value, in their original
--- order.
---@param t any[]
---@param filter fun(v: any): any
---@return any[] new the kept elements, renumbered from 1
---@return integer filtered_num how many elements were dropped
local function ifilter(t, filter)
    local new = {}
    local filtered_num = 0
    for i, v in ipairs(t) do
        if filter(v) then
            new[i - filtered_num] = v
        else
            filtered_num = filtered_num + 1
        end
    end
    return new, filtered_num
end

--- Build a `test` method that accepts exactly the values whose `type()` is `ty`.
---@param ty string type name as returned by `type()`
---@param a string? article placed before the type name in the failure message (e.g. 'a')
---@return fun(self: any, testee: any): boolean, string?
local function type_checker(ty, a)
    local fmt_str = "%s (type: %s) isn't "
    if a then
        fmt_str = fmt_str .. a .. ' '
    end
    fmt_str = fmt_str .. ty
    return function(self, testee)
        if type(testee) == ty then
            return true
        end
        -- tostring() first: Lua 5.1's `%s` raises on nil/boolean/table arguments
        -- (5.2+ applies the same conversion itself, so output is unchanged there).
        return false, fmt(fmt_str, tostring(testee), type(testee))
    end
end

--- Normalize `val` to a plain table: a plain table is returned as is (the very
--- same object); anything else is wrapped as `{val}` (so nil yields `{}`).
local function ensure_wrapped(val)
    return is_raw_table(val) and val or {val}
end

--- Registry of every schema metatable created by `reg_mt`. Keys are weak, so
--- being registered here does not by itself keep a metatable alive.
local mts = set_mt({}, {__mode = 'k'}) ---@type {[metatable]: true}

--- Create and register the metatable of a schema type.
--- `__index` falls back to `super_mt.__index` (or to an empty table when there
--- is no super). By default it is a fresh table chained to the super's, so
--- methods defined on it do not leak into the super type; with
--- `without_override` the super's `__index` table is shared directly.
---@param name string type name, stored as `__name`
---@param super_mt metatable?
---@param without_override boolean?
---@return metatable
local function reg_mt(name, super_mt, without_override)
    local index = super_mt and super_mt.__index or {}
    if not without_override then
        index = set_mt({}, {__index = index})
    end
    local mt = {
        __name = name,
        __index = index,
    }
    mts[mt] = true
    return mt
end

--- Schema type name of `v` (the `__name` of its metatable), or nil when `v` is
--- not an instance of a type registered through `reg_mt`.
---@param v any
---@return string | nil
local function get_scm_type(v)
    local mt = get_mt(v)
    if not mts[mt] then
        return nil
    end
    return mt.__name
end

--- Collect the validator functions given as `constraints.validators` or
--- `constraints.validator` (either a single function or a plain table of
--- functions). Raises if any element is not a function.
---@param constraints table? may be nil, e.g. for `schema.Any()`
---@return function[] | nil validators, or nil when there are none
local function get_validators_from_constraints(constraints)
    -- Constructors may be called without a constraints table.
    if constraints == nil then
        return nil
    end
    local t = ifilter(
        ensure_wrapped(constraints.validators or constraints.validator),
        function(v)
            -- Used as a check rather than a filter: it either passes or raises.
            -- (Message: "validator must be a function or a table of functions".)
            return assert(type(v) == 'function', 'validator需要是函数或元素为函数的表')
        end
    )
    return t[1] and t or nil
end

--- Run `validators` in order against `val`, stopping at the first failure.
---@param validators function[]?
---@param val any
---@return boolean, string?
local function validate_all(validators, val)
    if not validators then
        return true
    end
    for _, validator in ipairs(validators) do
        local valid, msg = validator(val)
        if not valid then
            -- A validator may fail without giving a message; supply one so that
            -- callers (and `assert`) never receive a nil error.
            return false, msg or fmt("%s fails a custom validator", tostring(val))
        end
    end
    return true
end

--- Run the `test` of the schema this one was derived from, if any.
---@return boolean, string?
local function test_super(self, testee)
    if self.super then
        return self.super:test(testee)
    end
    return true
end

-- Any: accepts every value; the root of all schema types.
local Any_mt = reg_mt('Any', nil)
schema.Any = set_mt({
    test = function() return true end,
}, Any_mt)

--- Derive a schema of the same type as `self` that additionally runs the given
--- validators.
---@param constraints table? `{validator = fn}` or `{validators = {fn, ...}}`
function Any_mt:__call(constraints)
    constraints = constraints or {}
    return set_mt({
        super = self,
        validators = get_validators_from_constraints(constraints),
    }, get_mt(self))
end

--- Default `test` for derived schemas: the super's test, then the validators.
function Any_mt.__index:test(testee)
    local valid, msg = test_super(self, testee)
    if not valid then
        return false, msg
    end
    return validate_all(self.validators, testee)
end

--- Like `test`, but raises the failure message at the caller's position.
function Any_mt.__index:assert(testee)
    local valid, msg = self:test(testee)
    if valid then
        return true
    end
    -- Fallback for sub-schemas (e.g. user objects with a `test` method) that
    -- fail without a message; `error(nil)` would hide what went wrong.
    error(msg or fmt("%s doesn't match the schema", tostring(testee)), 2)
end

-- Nil and Boolean share Any's methods directly. Metamethods are not inherited
-- through `__index`, so Any's constructor is shared explicitly (as Function
-- does) to allow e.g. `schema.Boolean{validator = f}`.
local Nil_mt = reg_mt('Nil', Any_mt, true)
schema.Nil = set_mt({
    test = type_checker('nil'),
}, Nil_mt)
Nil_mt.__call = Any_mt.__call

local Boolean_mt = reg_mt('Boolean', Any_mt, true)
schema.Boolean = set_mt({
    test = type_checker('boolean', 'a'),
}, Boolean_mt)
Boolean_mt.__call = Any_mt.__call

local Number_mt = reg_mt('Number', Any_mt)
schema.Number = set_mt({
    test = type_checker('number', 'a'),
}, Number_mt)

--- Derive a constrained number schema.
---@param constraints table? fields: int, lt, gt, le (or max), ge (or min), ne, validator(s)
function Number_mt:__call(constraints)
    constraints = constraints or {}
    return set_mt({
        super = self,
        int = constraints.int,
        lt = constraints.lt,
        gt = constraints.gt,
        le = constraints.le or constraints.max,
        ge = constraints.ge or constraints.min,
        ne = constraints.ne,
        validators = get_validators_from_constraints(constraints),
    }, Number_mt)
end

function Number_mt.__index:test(testee)
    local valid, msg = test_super(self, testee)
    if not valid then
        return false, msg
    end
    if self.int and math.fmod(testee, 1) ~= 0 then
        return false, fmt("%s isn't an integer", testee)
    end
    -- Bounds are checked as `not (testee OP bound)` so that NaN, which compares
    -- false against everything, is rejected rather than passing every bound.
    if self.lt and not (testee < self.lt) then
        return false, fmt("%s isn't < %s", testee, self.lt)
    end
    if self.gt and not (testee > self.gt) then
        return false, fmt("%s isn't > %s", testee, self.gt)
    end
    if self.le and not (testee <= self.le) then
        return false, fmt("%s isn't <= %s", testee, self.le)
    end
    if self.ge and not (testee >= self.ge) then
        return false, fmt("%s isn't >= %s", testee, self.ge)
    end
    if self.ne and testee == self.ne then
        return false, fmt('testee equals %s', self.ne)
    end
    return validate_all(self.validators, testee)
end

local String_mt = reg_mt('String', Any_mt)
schema.String = set_mt({
    test = type_checker('string', 'a'),
}, String_mt)

--- Derive a constrained string schema.
---@param constraints table? fields: max_len, min_len (byte lengths), pattern (Lua pattern), validator(s)
function String_mt:__call(constraints)
    constraints = constraints or {}
    return set_mt({
        super = self,
        max_len = constraints.max_len,
        min_len = constraints.min_len,
        pattern = constraints.pattern,
        validators = get_validators_from_constraints(constraints),
    }, String_mt)
end

function String_mt.__index:test(testee)
    local valid, msg = test_super(self, testee)
    if not valid then
        return false, msg
    end
    if self.max_len and #testee > self.max_len then
        return false, fmt("the length of %q (%d) exceeds %s", testee, #testee, self.max_len)
    end
    if self.min_len and #testee < self.min_len then
        return false, fmt("the length of %q (%d) is under %s", testee, #testee, self.min_len)
    end
    if self.pattern and not testee:match(self.pattern) then
        return false, fmt("%q doesn't match the pattern %q", testee, self.pattern)
    end
    return validate_all(self.validators, testee)
end

-- Function has no constraints of its own: derived instances use the `test`
-- inherited from Any (super check + validators).
local Function_mt = reg_mt('Function', Any_mt)
schema.Function = set_mt({
    test = type_checker('function', 'a'),
}, Function_mt)
Function_mt.__call = Any_mt.__call

local Table_mt = reg_mt('Table', Any_mt)
schema.Table = set_mt({
    test = type_checker('table', 'a'),
}, Table_mt)

--- Derive a table schema from a shape description.
--- Plain keys (and keys that are `Literal` schemas, by their `val`) name
--- specific fields; keys that are other schema objects constrain every testee
--- key they match. `validator`/`validators` are taken as validators.
---@param constraints table?
function Table_mt:__call(constraints)
    constraints = constraints or {}
    local specific = {}
    local generic = {}
    for k, v in pairs(constraints) do
        local scm_type = get_scm_type(k)
        if scm_type then
            if scm_type == 'Literal' then
                specific[k.val] = v
            else
                generic[k] = v
            end
        elseif k ~= 'validators' and k ~= 'validator' then
            specific[k] = v
        end
    end
    return set_mt({
        super = self,
        specific = specific,
        generic = generic,
        validators = get_validators_from_constraints(constraints),
    }, Table_mt)
end

function Table_mt.__index:test(testee)
    local valid, msg = test_super(self, testee)
    if not valid then
        return false, msg
    end
    for key_scm, val_scm in pairs(self.generic) do
        for testee_key, testee_val in pairs(testee) do
            if key_scm:test(testee_key) then
                local ok, err = val_scm:test(testee_val)
                if not ok then
                    return false, err
                end
            end
        end
    end
    for key, val_scm in pairs(self.specific) do
        local field = testee[key]
        local ok, err = val_scm:test(field)
        if not ok then
            if field == nil then
                return false, fmt('`%s` misses field `%s`', tostring(testee), tostring(key))
            end
            return false, err
        end
    end
    return validate_all(self.validators, testee)
end

local Union_mt = reg_mt('Union', Any_mt)

--- Build a schema matching any of its members. Members are schema objects
--- (matched via `:test`) or plain values (matched with `==`); a nil argument
--- stands for `schema.Nil`, and nested unions are flattened.
---@param ... any
---@return table
function schema.Union(...)
    local union = {}
    for i = 1, select('#', ...) do
        local member = select(i, ...)
        if member == nil then
            union[schema.Nil] = true
        elseif get_scm_type(member) == 'Union' then
            for nested in next, member do
                union[nested] = true
            end
        else
            union[member] = true
        end
    end
    return set_mt(union, Union_mt)
end

function Union_mt.__index:test(testee)
    for allowed_val in next, self do
        if get_scm_type(allowed_val) then
            if allowed_val:test(testee) then
                return true
            end
        elseif testee == allowed_val then
            return true
        end
    end
    return false, fmt('testee `%s` fails to match each value in the union: %s',
        tostring(testee), tostring(self))
end

-- `a | b` (Lua 5.3+) and `a / b` both build a union of schemas.
for mt in next, mts do
    mt.__bor = schema.Union
    mt.__div = schema.Union
end

schema.Truthy = schema.Any{validator = function(v)
    if v then
        return true
    end
    return false, fmt("%s isn't truthy", tostring(v))
end}

schema.Falsy = schema.Any{validator = function(v)
    if not v then
        return true
    end
    return false, fmt("%s isn't falsy", tostring(v))
end}

return schema