diff --git a/src/core/utils.jl b/src/core/utils.jl index 25d1b1d69f1b21cf1a2035d459f64f2272435518..067c966a49ad86bb4aea24298478f693e0974133 100644 --- a/src/core/utils.jl +++ b/src/core/utils.jl @@ -164,14 +164,10 @@ end A utility function to make sure that a number is within a given set of bounds. Returns `max`/`min` if `x` is greater/less than this. """ -function bounds(x::Number; max::Number=Inf, min::Number=0) - if unit(x) != NoUnits - max = max*unit(x) - min = min*unit(x) - end - x > max ? max : - x < min ? min : - x +function bounds(x::T; max::T=typemax(T), min::T=zero(T)) where {T<:Union{Real,Unitful.AbstractQuantity}} + x > max && return max + x < min && return min + return x end """ diff --git a/test/io_tests.jl b/test/io_tests.jl index deebec3c249b7df52368432deb209001ec02b525..17b18ebde9078c8312a531ddc7221e5b94a93ffb 100644 --- a/test/io_tests.jl +++ b/test/io_tests.jl @@ -122,7 +122,19 @@ end @test length(AnnualDate(heute):birthday) == 14 # bounds @test Ps.bounds(3) == 3 + @test Ps.bounds(3.0) == 3.0 @test Ps.bounds(-3) == 0 + @test Ps.bounds(-3.0) == 0.0 @test Ps.bounds(20, max=10) == 10 + @test Ps.bounds(20.0, max=10.0) == 10.0 @test Ps.bounds(-3, min=-10) == -3 + @test Ps.bounds(-3.0, min=-10.0) == -3.0 + @test Ps.bounds(3u"m") == 3u"m" + @test Ps.bounds(3.0u"m") == 3.0u"m" + @test Ps.bounds(-3u"m") == 0u"m" + @test Ps.bounds(-3.0u"m") == 0.0u"m" + @test Ps.bounds(20u"m", max=10u"m") == 10u"m" + @test Ps.bounds(20.0u"m", max=10.0u"m") == 10.0u"m" + @test Ps.bounds(-3u"m", min=-10u"m") == -3u"m" + @test Ps.bounds(-3.0u"m", min=-10.0u"m") == -3.0u"m" end