-- UnitTests provides unit testing for other Lua scripts. For details see [[Wikipedia:Lua#Unit_testing]].
-- For user documentation see talk page.
-- Class table. Objects created by UnitTests:new() use it as their metatable,
-- so the methods defined below are reachable from every test object.
local UnitTests = {}
-- Library functions cached in locals for use throughout the module.
local insert, concat = table.insert, table.concat;
local gsub, text_nowiki = string.gsub, mw.text.nowiki;

--- Find the first character position at which two values differ as text.
-- Arguments are converted with tostring, so numeric or boolean results can be
-- compared too (previously `#` raised an error for non-strings).
-- @param s1 first value (normally a string)
-- @param s2 second value (normally a string)
-- @return 1-based index of the first differing character. If the strings are
--   identical, or one is a prefix of the other, the result is one past the end
--   of the shorter string.
local function first_difference(s1, s2)
	s1, s2 = tostring(s1), tostring(s2)
	local shortest = math.min(#s1, #s2)
	for i = 1, shortest do
		if s1:sub(i, i) ~= s2:sub(i, i) then return i end
	end
	return shortest + 1
end

--- Escape a value for literal display inside a wikitable cell.
-- Applies mw.text.nowiki, then replaces newlines with <br/> and spaces with
-- &#32; so that whitespace differences remain visible in the rendered table.
-- The input is passed through tostring first, so numbers and other
-- non-string values are accepted.
-- @param text value to escape
-- @return the escaped string only (gsub's match count is deliberately dropped)
local function nowiki(text)
	local escaped = gsub(text_nowiki(tostring(text)), '[\n ]', { ['\n'] = '<br/>', [' '] = '&#32;' })
	return escaped
end

--- Record a test case that compares `got` with `want` using plain ==.
-- @param name label shown in the "Text" column (always nowiki-escaped)
-- @param got actual value
-- @param want expected value
-- @param options optional table; when options.nowiki is truthy, the expected
--   and actual values are nowiki-escaped for display as well
-- @return self, so calls can be chained
function UnitTests:equals(name, got, want, options)
	local escape = options and options.nowiki and nowiki or tostring

	self.results[#self.results + 1] = {
		equals = want == got,
		name = nowiki(name),
		-- Store display text as strings: run() joins the cells with
		-- table.concat, which rejects booleans, nil and tables.
		got = escape(tostring(got)),
		want = escape(tostring(want)),
	}

	return self
end

--- Expand the wikitext `got` with the current frame and check it equals `want`.
-- The unexpanded wikitext doubles as the test case's name.
-- @return the result of self:equals (i.e. self)
function UnitTests:preprocess_equals(got, want, options)
	local expanded = self.frame:preprocess(got)
	return self:equals(got, expanded, want, options)
end

--- Check many wikitext snippets that share a common prefix and suffix.
-- Each entry of `cases` is { input, expected }. The text prefix..input..suffix
-- is expanded with the current frame and compared with `expected`.
-- @return self
function UnitTests:preprocess_equals_many(prefix, suffix, cases, options)
	local frame = self.frame
	for _, test_case in ipairs(cases) do
		local wikitext = concat({ prefix, test_case[1], suffix })
		local expanded = frame:preprocess(wikitext)
		self:equals(wikitext, expanded, test_case[2], options)
	end
	return self
end

--- Expand both `got` and `want` with the current frame and compare the results.
-- `got` is expanded before `want`; the unexpanded `got` is the case name.
-- @return the result of self:equals (i.e. self)
function UnitTests:preprocess_equals_preprocess(got, want, options)
	local frame = self.frame
	local actual = frame:preprocess(got)
	local expected = frame:preprocess(want)
	return self:equals(got, actual, expected, options)
end

--- Compare the expansions of two wikitext wrappers over many inputs.
-- For each { input, expected } in `cases`, prefix1..input..suffix1 and
-- prefix2..(expected or input)..suffix2 are both expanded and compared.
-- @return self
function UnitTests:preprocess_equals_preprocess_many(prefix1, suffix1, prefix2, suffix2, cases, options)
	local frame = self.frame
	for _, test_case in ipairs(cases) do
		local input = test_case[1]
		local got_text = concat({ prefix1, input, suffix1 })
		-- When no second element is given, the input is reused for the expected side.
		local want_text = concat({ prefix2, test_case[2] or input, suffix2 })
		local actual = frame:preprocess(got_text)
		local expected = frame:preprocess(want_text)
		self:equals(got_text, actual, expected, options)
	end
	return self
end

--- Structurally compare two values.
-- Values of different types are unequal. Non-table values are compared with ==.
-- If either table has a metatable defining __eq (and ignore_mt is falsy), == is
-- used so that the metamethod decides. Otherwise tables are compared key by
-- key, recursively.
-- NOTE: cyclic tables are not detected and will overflow the stack.
-- @param t1 first value
-- @param t2 second value
-- @param ignore_mt when truthy, __eq metamethods are ignored
-- @return true if the values are considered equal
local function deep_compare(t1, t2, ignore_mt)
	local ty1, ty2 = type(t1), type(t2)
	if ty1 ~= ty2 then return false end
	if ty1 ~= 'table' then return t1 == t2 end
	local mt = getmetatable(t1) or getmetatable(t2)
	if not ignore_mt and mt and mt.__eq then return t1 == t2 end
	for k, v1 in pairs(t1) do
		local v2 = t2[k]
		if v2 == nil or not deep_compare(v1, v2, ignore_mt) then return false end
	end
	-- Keys that t1 owns directly were fully compared above. Only keys t1 lacks
	-- need checking here. Comparing every value a second time doubles the work
	-- at each nesting level, which makes the cost exponential in depth.
	for k, v2 in pairs(t2) do
		if rawget(t1, k) == nil then
			local v1 = t1[k]
			if v1 == nil or not deep_compare(v1, v2, ignore_mt) then return false end
		end
	end
	return true
end
 
--- Render a value as Lua-like source text for display in test results.
-- Rules:
--   * Tables with a __tostring metamethod use it.
--   * Otherwise the consecutive integer keys 1, 2, ... are written in order,
--     without keys.
--   * Every other key is written as [key] = value, sorted so the output does
--     not depend on pairs() order.
--   * Strings are quoted with %q, with embedded newlines shown as \n.
-- NOTE: cyclic tables are not detected and will overflow the stack.
-- @param v value to serialize
-- @return string
local function serialize(v)
	local kind = type(v)
	if kind == 'table' then
		local mt = getmetatable(v)
		if mt and mt.__tostring then
			return tostring(v)
		end
		local parts = {}
		-- Positional part: keys 1..n with no holes, emitted in order.
		local n = 0
		while rawget(v, n + 1) ~= nil do
			n = n + 1
			parts[n] = serialize(rawget(v, n))
		end
		-- All remaining keys are written explicitly, so sparse numeric keys such
		-- as {[5] = 'x'} are not displayed as if they were {'x'}.
		local keyed = {}
		for key, val in pairs(v) do
			local positional = type(key) == 'number' and key >= 1 and key <= n and key % 1 == 0
			if not positional then
				keyed[#keyed + 1] = table.concat{ '[', serialize(key), '] = ', serialize(val) }
			end
		end
		table.sort(keyed)
		for _, entry in ipairs(keyed) do
			parts[#parts + 1] = entry
		end
		return table.concat{ '{', table.concat(parts, ','), '}' }
	elseif kind == 'string' then
		-- %q writes a newline as a backslash followed by a real newline; turn
		-- that into the two-character escape \n so the output stays on one line.
		return (string.gsub(string.format('%q', v), '\\\n', '\\n'))
	else
		return tostring(v)
	end
end

--- Record a test case that compares `got` with `want` structurally (deep_compare).
-- Both values are rendered with serialize() for display.
-- @param name label shown in the "Text" column (always nowiki-escaped)
-- @param got actual value
-- @param want expected value
-- @param options optional table; when options.nowiki is truthy, the
--   serialized values are nowiki-escaped as well
-- @return self, so calls can be chained
function UnitTests:equals_deep(name, got, want, options)
	-- serialize() always yields a string, so tostring acts as the identity here.
	local escape = options and options.nowiki and nowiki or tostring

	self.results[#self.results + 1] = {
		equals = deep_compare(got, want),
		name = nowiki(name),
		got = escape(serialize(got)),
		want = escape(serialize(want)),
	}

	return self
end

--- Run every registered test method and render the results as wikitext.
-- Output format:
--   * One table per test method: a caption with the method name, a header
--     row, then one row per case the method recorded in self.results.
--   * A coloured summary line giving the number of failed cases, placed first.
-- @param frame frame object. frame:preprocess expands {{tick}}/{{cross}}.
--   When frame.args.differs_at is set, a "Differs at" column shows the first
--   differing character position for each failing case.
-- @return wikitext string
function UnitTests:run(frame)
	local table_start = '{| class="wikitable unittests"'
	local header_row = '! !! Text !! Expected !! Actual'
	local tick, cross = frame:preprocess('{{tick}}'), frame:preprocess('{{cross}}')
	local failures, differs_at = 0, frame.args.differs_at
	local results = {}

	self.frame = frame

	if differs_at then header_row = header_row .. ' !! Differs at' end
	for _, test in ipairs(self.tests) do
		-- The caption (|+) must come directly after the table start, before any row.
		results[#results + 1] = concat{ table_start, "\n|+ '''", test, "''':\n", header_row }

		self.results = {}
		self[test](self, frame)

		for _, case in ipairs(self.results) do
			local differs
			if case.equals then
				results[#results + 1] = concat{ '|- class="test-pass"\n| ', tick }
			else
				results[#results + 1] = concat{ '|- class="test-fail"\n| ', cross }
				differs = differs_at and first_difference(case.want, case.got)
				failures = failures + 1
			end
			-- Build the cells explicitly: table.concat needs strings or numbers with
			-- no nil holes, and the "Differs at" cell exists only for failures.
			local cells = { '', tostring(case.name), tostring(case.want), tostring(case.got) }
			if differs then cells[#cells + 1] = differs end
			results[#results + 1] = concat(cells, '\n|')
		end

		results[#results + 1] = "|}\n"
	end

	insert(results, 1, string.format(
		'<span style="color:%s; font-weight:bold;">%s</span>\n',
		failures == 0 and "#008000" or "#800000",
		failures == 0 and 'All tests passed.' or string.format('%d tests failed.', failures)
	))
	local output = concat(results, '\n')
	self.results = nil
	return output
end

--- Metamethod fired when a new key is assigned on a test object.
-- Function values are appended to self.tests, in assignment order, so run()
-- will execute them. The value is always stored with rawset.
function UnitTests:__newindex(k, v)
	if type(v) ~= 'function' then
		rawset(self, k, v)
		return
	end
	local tests = self.tests
	tests[#tests + 1] = k
	rawset(self, k, v)
end

-- Test objects look up missing fields (the helper methods) on UnitTests itself.
UnitTests.__index = UnitTests;

--- Create a new test object.
-- The object gets its own empty `tests` list and a `run_tests(frame)`
-- function that forwards to run(). Both fields are placed in the table
-- constructor, before the metatable is attached, so __newindex does not
-- register run_tests as a test.
-- @return the new object, with `self` (normally UnitTests) as its metatable
function UnitTests:new()
	local instance
	instance = {
		tests = {},
		run_tests = function(frame) return instance:run(frame) end,
	}
	return setmetatable(instance, self)
end

-- The module exports a single pre-built test object rather than the class table.
local module_instance = UnitTests:new()
return module_instance