diff --git a/src/core/utils.jl b/src/core/utils.jl
index 25d1b1d69f1b21cf1a2035d459f64f2272435518..067c966a49ad86bb4aea24298478f693e0974133 100644
--- a/src/core/utils.jl
+++ b/src/core/utils.jl
@@ -164,14 +164,10 @@ end
 A utility function to make sure that a number is within a given set of bounds.
 Returns `max`/`min` if `x` is greater/less than this.
 """
-function bounds(x::Number; max::Number=Inf, min::Number=0)
-    if unit(x) != NoUnits
-        max = max*unit(x)
-        min = min*unit(x)
-    end
-    x > max ? max :
-        x < min ? min :
-        x
+function bounds(x::T; max::T=typemax(T), min::T=zero(T)) where {T<:Union{Real,Unitful.AbstractQuantity}}
+    x > max && return max
+    x < min && return min
+    return x
 end
 
 """
diff --git a/test/io_tests.jl b/test/io_tests.jl
index deebec3c249b7df52368432deb209001ec02b525..17b18ebde9078c8312a531ddc7221e5b94a93ffb 100644
--- a/test/io_tests.jl
+++ b/test/io_tests.jl
@@ -122,7 +122,19 @@ end
     @test length(AnnualDate(heute):birthday) == 14
     # bounds
     @test Ps.bounds(3) == 3
+    @test Ps.bounds(3.0) == 3.0
     @test Ps.bounds(-3) == 0
+    @test Ps.bounds(-3.0) == 0.0
     @test Ps.bounds(20, max=10) == 10
+    @test Ps.bounds(20.0, max=10.0) == 10.0
     @test Ps.bounds(-3, min=-10) == -3
+    @test Ps.bounds(-3.0, min=-10.0) == -3.0
+    @test Ps.bounds(3u"m") == 3u"m"
+    @test Ps.bounds(3.0u"m") == 3.0u"m"
+    @test Ps.bounds(-3u"m") == 0u"m"
+    @test Ps.bounds(-3.0u"m") == 0.0u"m"
+    @test Ps.bounds(20u"m", max=10u"m") == 10u"m"
+    @test Ps.bounds(20.0u"m", max=10.0u"m") == 10.0u"m"
+    @test Ps.bounds(-3u"m", min=-10u"m") == -3u"m"
+    @test Ps.bounds(-3.0u"m", min=-10.0u"m") == -3.0u"m"
 end