diff --git a/modules/_private_utils.lua b/modules/_private_utils.lua new file mode 100644 index 0000000..de92786 --- /dev/null +++ b/modules/_private_utils.lua @@ -0,0 +1,11 @@ +-- Functions exported by utils.lua but needed by vec2 or vec3 (which utils.lua requires) + +local private = {} +local floor = math.floor + +function private.round(value, precision) + if precision then return utils.round(value / precision) * precision end + return value >= 0 and floor(value+0.5) or ceil(value-0.5) +end + +return private diff --git a/modules/bound2.lua b/modules/bound2.lua index 8c049a4..319f8ed 100644 --- a/modules/bound2.lua +++ b/modules/bound2.lua @@ -152,6 +152,14 @@ function bound2.contains(a, v) and a.max.x >= v.x and a.max.y >= v.y end +-- Round all components of all vectors to nearest int (or other precision). +-- @tparam vec3 a bound to round. +-- @tparam precision Digits after the decimal (round number if unspecified) +-- @treturn vec3 Rounded bound +function bound2.round(a, precision) + return bound2.new(a.min:round(precision), a.max.round(precision)) +end + --- Return a formatted string. -- @tparam bound2 a bound to be turned into a string -- @treturn string formatted diff --git a/modules/bound3.lua b/modules/bound3.lua index 397640c..7f02ed2 100644 --- a/modules/bound3.lua +++ b/modules/bound3.lua @@ -152,6 +152,14 @@ function bound3.contains(a, v) and a.max.x >= v.x and a.max.y >= v.y and a.max.z >= v.z end +-- Round all components of all vectors to nearest int (or other precision). +-- @tparam vec3 a bound to round. +-- @tparam precision Digits after the decimal (round number if unspecified) +-- @treturn vec3 Rounded bound +function bound3.round(a, precision) + return bound3.new(a.min:round(precision), a.max:round(precision)) +end + --- Return a formatted string. -- @tparam bound3 a bound to be turned into a string -- @treturn string formatted diff --git a/modules/utils.lua b/modules/utils.lua index 933cbca..e6d0b8f 100644 --- a/modules/utils.lua +++ b/modules/utils.lua @@ -4,6 +4,7 @@ local modules = (...): gsub('%.[^%.]+$', '') .. "." local vec2 = require(modules .. "vec2") local vec3 = require(modules .. "vec3") +local private = require(modules .. "_private_utils") local abs = math.abs local ceil = math.ceil local floor = math.floor @@ -103,10 +104,7 @@ end -- @param value -- @param precision -- @return number -function utils.round(value, precision) - if precision then return utils.round(value / precision) * precision end - return value >= 0 and floor(value+0.5) or ceil(value-0.5) -end +utils.round = private.round --- Wrap `value` around if it exceeds `limit`. -- @param value diff --git a/modules/vec2.lua b/modules/vec2.lua index 01a0b97..8347b16 100644 --- a/modules/vec2.lua +++ b/modules/vec2.lua @@ -3,6 +3,7 @@ local modules = (...):gsub('%.[^%.]+$', '') .. "." local vec3 = require(modules .. "vec3") +local private = require(modules .. "_private_utils") local acos = math.acos local atan2 = math.atan2 local sqrt = math.sqrt @@ -321,6 +322,14 @@ function vec2.to_polar(a) return radius, theta end +-- Round all components to nearest int (or other precision). +-- @tparam vec2 a Vector to round. +-- @tparam precision Digits after the decimal (round numebr if unspecified) +-- @treturn vec2 Rounded vector +function vec2.round(a, precision) + return vec2.new(private.round(a.x, precision), private.round(a.y, precision)) +end + --- Return a formatted string. -- @tparam vec2 a Vector to be turned into a string -- @treturn string formatted diff --git a/modules/vec3.lua b/modules/vec3.lua index bf3d6bd..8e21caa 100644 --- a/modules/vec3.lua +++ b/modules/vec3.lua @@ -1,6 +1,8 @@ --- A 3 component vector. -- @module vec3 +local modules = (...):gsub('%.[^%.]+$', '') .. "." +local private = require(modules .. "_private_utils") local sqrt = math.sqrt local cos = math.cos local sin = math.sin @@ -254,6 +256,14 @@ function vec3.lerp(a, b, s) return a + (b - a) * s end +-- Round all components to nearest int (or other precision). +-- @tparam vec3 a Vector to round. +-- @tparam precision Digits after the decimal (round numebr if unspecified) +-- @treturn vec3 Rounded vector +function vec3.round(a, precision) + return vec3.new(private.round(a.x, precision), private.round(a.y, precision), private.round(a.z, precision)) +end + --- Unpack a vector into individual components. -- @tparam vec3 a Vector to unpack -- @treturn number x diff --git a/spec/bound2_spec.lua b/spec/bound2_spec.lua index a23322e..717ca1a 100644 --- a/spec/bound2_spec.lua +++ b/spec/bound2_spec.lua @@ -171,6 +171,15 @@ describe("bound2:", function() assert.is_not_true(a:contains(vec2(2,3))) end) + it("rounds a bound2", function() + local a = bound2(vec2(1.1,1.9), vec2(3.9,5.1)):round() + + assert.is.equal(1, a.min.x) + assert.is.equal(2, a.min.y) + assert.is.equal(4, a.max.x) + assert.is.equal(5, a.max.y) + end) + it("checks for bound2.zero", function() assert.is.equal(0, bound2.zero.min.x) assert.is.equal(0, bound2.zero.min.y) diff --git a/spec/bound3_spec.lua b/spec/bound3_spec.lua index e3ae232..110159c 100644 --- a/spec/bound3_spec.lua +++ b/spec/bound3_spec.lua @@ -208,6 +208,17 @@ describe("bound3:", function() assert.is_not_true(a:contains(vec3(2,3,7))) end) + it("rounds a bound3", function() + local a = bound3(vec3(1.1,1.9,3), vec3(3.9,5.1,6)):round() + + assert.is.equal(1, a.min.x) + assert.is.equal(2, a.min.y) + assert.is.equal(3, a.min.z) + assert.is.equal(4, a.max.x) + assert.is.equal(5, a.max.y) + assert.is.equal(6, a.max.z) + end) + it("checks for bound3.zero", function() assert.is.equal(0, bound3.zero.min.x) assert.is.equal(0, bound3.zero.min.y) diff --git a/spec/vec2_spec.lua b/spec/vec2_spec.lua index 8b826ad..29432bd 100644 --- a/spec/vec2_spec.lua +++ b/spec/vec2_spec.lua @@ -189,6 +189,12 @@ describe("vec2:", function() assert.is.equal("(+0.000,+0.000)", b) end) + it("rounds a 2-vector", function() + local a = vec2(1.1,1.9):round() + assert.is.equal(a.x, 1) + assert.is.equal(a.y, 2) + end) + -- Do this last, to insulate tests from accidental state contamination it("converts a vec3 to vec2 using the constructor", function() local vec3 = require "modules.vec3" diff --git a/spec/vec3_spec.lua b/spec/vec3_spec.lua index 79eca55..fc87135 100644 --- a/spec/vec3_spec.lua +++ b/spec/vec3_spec.lua @@ -202,4 +202,11 @@ describe("vec3:", function() local b = a:to_string() assert.is.equal("(+0.000,+0.000,+0.000)", b) end) + + it("rounds a 3-vector", function() + local a = vec3(1.1,1.9,3):round() + assert.is.equal(a.x, 1) + assert.is.equal(a.y, 2) + assert.is.equal(a.z, 3) + end) end)