Revert "Add generic fallback to all scalar functions" #86

Open
wants to merge 2 commits into
base: master
from

Conversation

devmotion
Member

@devmotion devmotion commented Jan 22, 2025

I propose reverting #71, which added generally unsafe fallback definitions. In general there is no guarantee that they do not error, since argument values are not checked and no type constraints are applied.


codecov bot commented Jan 22, 2025

Codecov Report

Attention: Patch coverage is 50.00000% with 1 line in your changes missing coverage. Please review.

Project coverage is 92.19%. Comparing base (291ca27) to head (edfd502).

Files with missing lines Patch % Lines
src/NaNMath.jl 50.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master      #86      +/-   ##
==========================================
- Coverage   92.66%   92.19%   -0.47%     
==========================================
  Files           1        1              
  Lines         150      141       -9     
==========================================
- Hits          139      130       -9     
  Misses         11       11              


@ChrisRackauckas
Member

This has merge conflicts and doesn't make sense. There are checks in the functions for the domains.

@ChrisRackauckas ChrisRackauckas deleted the dw/revert_fallback branch January 22, 2025 21:36
@devmotion
Member Author

There are checks in the functions for the domains.

The point of NaNMath for me (and others I guess) is that sqrt, pow etc. don't error. That's not guaranteed if NaNMath just falls back to Base.sqrt or ^. The arguments have to be checked in NaNMath, either based on the type or the value, to ensure that whatever is called does not error.
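
The distinction drawn here can be illustrated with a minimal sketch. The module `SafeSketch` and both function names are hypothetical stand-ins, not the NaNMath source; the sketch only contrasts a method that checks the argument value with a fallback that forwards unchecked.

```julia
# Hypothetical stand-in module; not the actual NaNMath source.
module SafeSketch

# Checked method: the argument value is tested before Base.sqrt is called,
# so a negative input yields NaN instead of a DomainError.
checked_sqrt(x::AbstractFloat) = x < zero(x) ? oftype(x, NaN) : Base.sqrt(x)

# Unchecked generic fallback: forwards anything to Base.sqrt,
# so whether it throws depends entirely on the argument.
unchecked_sqrt(x) = Base.sqrt(x)

end # module

SafeSketch.checked_sqrt(-1.0)          # NaN
try
    SafeSketch.unchecked_sqrt(-1)      # Base.sqrt(-1) throws
catch err
    println(typeof(err))               # DomainError
end
```

The checked method restricts both the type and the value it accepts; the fallback restricts neither, which is the concern raised about #71.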

The pow changes in #71 also predated #83, so it's not even clear that there's any immediate benefit from the changes of pow in #71.

@devmotion devmotion restored the dw/revert_fallback branch January 22, 2025 21:56
@devmotion devmotion reopened this Jan 22, 2025
@devmotion
Member Author

This has merge conflicts

Only after #85 was merged. When the PR was opened it had no conflicts.

@oameye

oameye commented Jan 22, 2025

The point of NaNMath for me (and others I guess) is that sqrt, pow etc. don't error.

But before, they also errored with MethodError: NaNMath.sqrt(::ComplexF64). So nothing really changed, except that it will now error with ArgumentError.

@devmotion
Member Author

That's beside the point - of course, if something is not supported by e.g. NaNMath.sqrt then calling it with such an argument will error. If Complex numbers are safe, they must be supported explicitly as in #83.

But falling back to some function in all cases completely breaks the safety promises of NaNMath. The main point is to not throw ArgumentErrors or DomainErrors.

@ChrisRackauckas
Member

I can see the philosophical argument, but I can't see a practical application other than real-time controls that wouldn't just want to have the fallback. What about a middle ground: a Preference "strict=false" that can be turned to true that removes the fallbacks?

@MilesCranmer

MilesCranmer commented Jan 22, 2025

I agree with @ChrisRackauckas, the fallbacks are good. NaNs are an instance of floats, so it makes sense to only define NaNs for those and then pass thru for everything else. NaNMath provides an interface if a package does want to add special handling for their type, but the default is to propagate the DomainError. It just feels like what I'd expect for any interface on an abstract number.

Maybe another interface is to just define a nan method so it's easier to extend:

function log(x)
    if applicable(zero, x) && applicable(<, x, zero(x)) && x < zero(x)
        return nan(x)
    else
        return Base.log(x)
    end
end

nan(::Float64) = NaN
# etc

Then a custom type only needs to ensure that zero, <, and nan are defined. And by default it would just be AbstractFloat.

@mlubin
Collaborator

mlubin commented Jan 22, 2025

I can't see a practical application other than real-time controls that wouldn't just want to have the fallback.

If I understand your point, I guess that same argument could be applied for NaNMath as a whole?

To throw in my two cents as the original NaNMath author, I agree with @devmotion's point. Providing fallbacks that could throw domain errors breaks what I thought was the API contract. (If you'd like to change the contract, that's potentially reasonable but should be more explicit.)

@MilesCranmer

Or maybe an interface like

valid_domain(::typeof(log), _) = true
valid_domain(::typeof(log), x::AbstractFloat) = x >= zero(x)
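
A minimal sketch of how such a trait could drive a NaN-returning wrapper. It assumes `valid_domain` is read as "x lies inside the domain of the function"; the module `DomainSketch` and the wrapper are hypothetical and not part of NaNMath.

```julia
module DomainSketch

# Hypothetical trait: `true` means "x lies in the domain of f".
# Unknown types default to `true`, i.e. the call is passed through unchecked.
valid_domain(::typeof(Base.log), _) = true
valid_domain(::typeof(Base.log), x::AbstractFloat) = x >= zero(x)

# NaN-returning wrapper built on the trait.
# `oftype(float(x), NaN)` is only evaluated on the invalid branch.
log(x) = valid_domain(Base.log, x) ? Base.log(x) : oftype(float(x), NaN)

end # module

DomainSketch.log(-1.0)  # NaN
DomainSketch.log(1.0)   # 0.0
```

Because the default method returns `true`, any type without its own `valid_domain` method is forwarded to `Base.log` unchecked, so this design still has the pass-through behaviour under discussion.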

@devmotion
Member Author

I want to use NaNMath in my MTK-based ODE models to ensure that function evaluation never crashes but only returns NaN if there's a numerical issue. This safety aspect is the main point of using NaNMath IMO - otherwise I could directly use the Base methods and hope for the best. The main API guarantee of NaNMath.sqrt, NaNMath.pow etc. is that they do not error if values are outside of the supported domain.

Thus IMO NaNMath functions must only be extended explicitly and in such a way that this API guarantee is satisfied. This means that

  1. support for Base types such as FloatXX and Complex etc. must be done explicitly in NaNMath
  2. support for non-standard number types in other packages must be added by extending NaNMath functions (nowadays most likely in a Pkg extension)

It's just not feasible to do 2. in NaNMath because the maintenance burden on NaNMath would be too high; it seems much easier to deal with this (there are only a handful of functions anyway!) in the respective packages and thereby distribute the maintenance burden (maintainers of these packages should also be the ones most familiar with their number types).

I also don't expect that there are too many packages that would have to do 2. since the number of relevant custom number types is not too large in my experience.
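
As a rough illustration of option 2, the sketch below shows a downstream type opting in explicitly. `StrictSketch` stands in for NaNMath with a deliberately restricted method table, and `Tagged` is a hypothetical number type; neither exists in any real package.

```julia
# Stand-in for NaNMath with a deliberately restricted method table (assumption).
module StrictSketch
sqrt(x::AbstractFloat) = x < zero(x) ? oftype(x, NaN) : Base.sqrt(x)
end

# Hypothetical downstream number type, e.g. a value carrying a tag.
struct Tagged{T<:AbstractFloat} <: Real
    value::T
end

# The owning package opts in explicitly and delegates to the safe method,
# so the "never throws a DomainError" guarantee is preserved.
StrictSketch.sqrt(x::Tagged) = Tagged(StrictSketch.sqrt(x.value))

StrictSketch.sqrt(Tagged(-1.0)).value   # NaN
applicable(StrictSketch.sqrt, 1 + 2im)  # false: unsupported types get a MethodError
```

In a real package this method would most likely live in a package extension that loads only when NaNMath is present.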

@MilesCranmer

I'd personally vote to change the contract.

Because of the speed issues noted in #63 I ended up implementing an in-house version of NaNMath from scratch (actually started before I noticed this library existed):
https://github.com/MilesCranmer/SymbolicRegression.jl/blob/9fabc303dd33c624739e759d960374c7f85e56f7/src/Operators.jl#L36-L116. Without knowledge of the API rules here, this organically grew a generic fallback function because of some issues due to compatibility with LoopVectorization and other user libraries. LoopVectorization skips DomainErrors automatically anyways, and it was a huge pain to write all the different signatures/ambiguities.

As mentioned, DomainErrors will still get hit when necessary, should such a user type even generate them. I don't think it's a pragmatic goal to pursue one error over another because both tracebacks give you a clear error with information about what method needs to be added.

@devmotion
Member Author

I'll vote against a change for the reasons I outlined above. If a different behaviour is desired in SymbolicRegression, then NaNMath is probably not the right tool for its use case and a different package is needed.

I want the safety guarantees of NaNMath and will happily accept if the number of supported types is restricted to a safe set. I'll also accept if it's slower than unsafe alternatives.

Otherwise, in my use case, errors might be thrown deep inside the ODE solver, even within subsequent time steps, and there's no good way for me to handle these. I mean, there's surely a reason why MTK by default creates ODE functions with NaNMath.

@MilesCranmer

MilesCranmer commented Jan 22, 2025

I wouldn't say it's necessarily safer in either setting. It's just hitting a different error. I do not think a MethodError deep in a call stack is any safer or more helpful than a DomainError. If some end user is passing a non-Real type then I think a DomainError (or some crazy HyperbolicTopologyError) would even be more helpful than a MethodError from a package which is only meant to modify fallback behavior for floats, which might otherwise be confusing.

@ChrisRackauckas
Member

I think the instances that have come up are ForwardDiff, BigFloat, Tracker, and ReverseDiff. For the autodiff ones, the issue is that not supporting the autodiff types entirely makes using NaNMath more generally pretty difficult. On the other hand, one of the big reasons for NaNMath is for a solver to "try a step", and then pull back / change dt / d alpha, whatever the step size is in an ODE, Newton method, gradient descent, etc. In that context, the gradient calculation is unlikely to be the part producing the NaNs; it's the "try at a new point" phase that has most likely stepped out of bounds. NaNMath's lack of full support for these AD types then causes an error where the algorithm would otherwise have worked, which is why we've had 3 or 4 PRs / issues over the years to add these fallbacks to the library. People keep running into it, seeing that there is a trivial fix, and wondering why we won't just add it.
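
The "try a step, then pull back" pattern can be sketched as follows. This is a toy explicit Euler step with step halving, not code from any solver package; `safe_sqrt`, `rhs`, and `try_step` are hypothetical names.

```julia
# NaN-returning stand-in for a model right-hand side (hypothetical).
safe_sqrt(x::Float64) = x < 0 ? NaN : sqrt(x)
rhs(u) = -safe_sqrt(u)          # du/dt = -sqrt(u)

# Explicit Euler trial step that halves dt until the trial point and the
# right-hand side evaluated there are finite, or dt drops below dtmin.
function try_step(u, dt; dtmin = 1e-12)
    while dt >= dtmin
        unew = u + dt * rhs(u)
        if isfinite(unew) && isfinite(rhs(unew))
            return unew, dt      # accept the step
        end
        dt /= 2                  # reject: trial point left the domain
    end
    error("step size underflow")
end

unew, dt = try_step(0.01, 1.0)   # accepted with dt == 0.0625
```

If `rhs` threw a DomainError instead of returning NaN, the loop would abort on the first rejected trial point rather than shrinking the step.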

But on the other side, I do agree that if you do that, you no longer can guarantee that NaNMath will never error. Once one fallback can error, there is no 100% guarantee. And that's not great either.

But if you go with a Preference system, then at least it's possible for both choices to be made.

@devmotion
Member Author

ForwardDiff, DiffRules, ReverseDiff, and Tracker have all supported NaNMath for many years, and have also defined derivatives for NaNMath functions. So I don't think AD is necessarily a problem with NaNMath. The only problem I ran into, which was recently fixed in both ForwardDiff and Symbolics, is that the derivative of NaNMath.pow used Base.log instead of NaNMath.log: JuliaDiff/ForwardDiff.jl#717 JuliaSymbolics/Symbolics.jl#1400 So the problem was not missing code but unsafe code!

But if you go with a Preference system, then at least it's possible for both choices to be made.

I'm not generally against it but given the experience with ForwardDiff and the original/current API contract of NaNMath the default setting should be the NaN-safe one without fallbacks. I'm a bit worried that it will be difficult for users to discover the preference setting - e.g., my impression is that NaN-safe mode in ForwardDiff is not known very widely even though it is explained in the official docs. Another consequence of a preference setting might be that other packages would be pressured into adding support for NaNMath anyway (to support the setting without fallbacks), which might make the unsafe setting less attractive and useful in general.

@longemen3000

longemen3000 commented Jan 23, 2025

Please correct me if I'm wrong, but if I remember correctly, preferences are meant to be set by end users: a package developer can offer preferences and act according to them, but cannot change preferences in other packages. So, for package developers, changing the preferences in ForwardDiff (or NaNMath, if they are enabled) is a problem.

In particular, I would like opt-in support, but more along the lines of ForwardDiff.can_dual than a preference system. Maybe a two-step verification system?

#signalling support for NaNMath, that is, 
#if there is a function f(x::T1)::T2 with a domain for x, 
#then NaNMath.f(x::T1)::T2 will return f(x) if x is in the domain of f, and T2(NaN) otherwise. 

can_nan(::Type{Float64}) = true
can_nan(::typeof(Base.log), ::Type{Float64}) = true #if we only support a limited set of functions

#we need to convert to a valid input type
nan_promote(x::Float64) = x
nan_promote(x::Int) = float(x)
#TODO: reasonable defaults for Number, error on anything else
can_nan(x::Tuple) = mapreduce(can_nan, &, x)
can_nan(x) = can_nan(typeof(nan_promote(x)))
can_nan(x::Type) = false
#TODO: maybe define another function for more specific type selection? Also, is this granularity ok? 
can_nan(f, x) = can_nan(x)
log(x::T) where T = nan_log(x, Val{can_nan(Base.log, T)}())

#maybe this function should be inside a module? 
function nan_log(x::T, ::Val{true}) where T
  x < 0 ? T(NaN) : Base.log(x)
end

function nan_log(x::T, ::Val{false}) where T
  throw(ArgumentError("""
NaNMath.log explicitly does not support $T. If this type 
can support calculating NaN, define NaNMath.can_nan(::typeof(log), x::Type{$T})
"""))
end

In this way, NaNMath acts as just another composable layer, instead of another set of rules that need to be defined for packages and intersection of packages.

@MilesCranmer

MilesCranmer commented Jan 23, 2025

Okay, how about the following idea? I like this approach a lot.

Basically, you could have a submodule Generic which defines a method with a generic fallback. The top-level module imports that and calls it when available.

NaNMath.log          # no fallback
NaNMath.Generic.log  # has fallback

These are separate functions. The Generic submodule would then simply be the following:

module NaNMath

#= existing code =#

module Generic
    import ..log as _log
    
    log(x) = applicable(_log, x) ? _log(x) : Base.log(x)
    
    #= other functions =#
end
end

Thus, should a user define a custom NaNMath extension, it will also be accessible to the generic method.

  • The top-level module would continue to define the same API contract as before.
  • The generic method will continue to be compatible with everything that Base.log is compatible with.
  • applicable will get inlined by the compiler, so this won't affect performance.

Wdyt?

Could also be inverted and have a Strict submodule. Whatever makes more sense.

@devmotion
Member Author

Please let's just merge this PR and move this discussion to an issue. Breaking the core API guarantees and putting this change in a non-breaking release is basically the worst scenario, regardless of which API design you'd prefer.

@devmotion
Member Author

In particular, i would like opt-in support

Well, then it seems you'd want the API without generic fallback 😄

I should also emphasize that even without a generic fallback, by no means does every number type T have to extend NaNMath.sqrt etc. By default, NaNMath already promotes arguments to floating point numbers, so in general you only have to define methods for float(T). Then you'll also naturally have NaNs supported by the involved types. If you actually want to, and if the computations are safe (i.e. guaranteed to never error, e.g. as ^(::Number, ::Integer)), then you can avoid the float promotions and work with number types that do not (necessarily) support NaNs.

Wdyt?

You can just define such a method in your package(s) if you want this generality. No need to put it in NaNMath and make the API messy/confusing for users and downstream packages (eg which version should MTK use?).

@ChrisRackauckas
Member

Let's simplify this discussion a little bit: can you give one example using the current master/release version where you cause sqrt to throw a DomainError instead of returning NaN? Instead of a hypothetical break, show an example of how bad a break you think has occurred in the sqrt function.

@devmotion
Member Author

It broke the API guarantees for any special number type that had not opted in to the NaNMath API. For instance,

julia> using NaNMath, Unitful

julia> NaNMath.sqrt(u"-1s^2")
ERROR: DomainError with -1.0:
sqrt was called with a negative real argument but will only return a complex result if called with a complex argument. Try sqrt(Complex(x)).
Stacktrace:
 [1] throw_complex_domainerror(f::Symbol, x::Float64)
   @ Base.Math ./math.jl:33
 [2] sqrt
   @ ./math.jl:608 [inlined]
 [3] sqrt
   @ ./math.jl:1531 [inlined]
 [4] sqrt
   @ ~/.julia/packages/Unitful/nwwOk/src/quantities.jl:205 [inlined]
 [5] sqrt(x::Quantity{Int64, 𝐓², Unitful.FreeUnits{(s²,), 𝐓², nothing}})
   @ NaNMath ~/.julia/packages/NaNMath/h9tir/src/NaNMath.jl:19
 [6] top-level scope
   @ REPL[19]:1

julia> using NaNMath, DynamicQuantities

julia> NaNMath.sqrt(u"-1s^2")
ERROR: DomainError with -1.0:
sqrt was called with a negative real argument but will only return a complex result if called with a complex argument. Try sqrt(Complex(x)).
Stacktrace:
 [1] throw_complex_domainerror(f::Symbol, x::Float64)
   @ Base.Math ./math.jl:33
 [2] sqrt
   @ ./math.jl:608 [inlined]
 [3] sqrt
   @ ~/.julia/packages/DynamicQuantities/OrITh/src/math.jl:146 [inlined]
 [4] sqrt(x::Quantity{Float64, Dimensions{FixedRational{Int32, 25200}}})
   @ NaNMath ~/.julia/packages/NaNMath/rftLo/src/NaNMath.jl:19
 [5] top-level scope
   @ REPL[4]:1

-1 is a bit extreme but in a numerical solver it's easy to hit slightly negative values. Without the NaNMath API guarantees, why should I use NaNMath at all?

But I'm digressing. My main point right now is just: This change completely broke existing API guarantees and hence should not have been released in a non-breaking release. It should be reverted.

@ChrisRackauckas
Member

How about:

using NaNMath, Unitful

NaNMath.sqrt(x::Real) = sqrt(float(x))
NaNMath.sqrt(x::Complex) = Base.sqrt(float(x))
NaNMath.sqrt(x::T) where {T<:Number} = x < zero(x) ? T(NaN) : Base.sqrt(float(x))

julia> NaNMath.sqrt(u"-1.0s^2")
NaN

?

Maybe the thing that NaNMath actually needs is just an interface nangen(x::Number) where nangen(x::Float64) = NaN, which would then also allow NaN to be defined for BigFloat etc. Then the generic fallback can be:

function NaNMath.sqrt(x::T) where {T<:Number}
  hasnan(T) || error("NaNMath not supported for this type as no NaN for this type is possible")
  x < zero(x) ? nangen(x) : Base.sqrt(float(x))
end

and so then you need to define hasnan and nangen for a new type to work? Would that be sufficiently safe for your concerns?

@devmotion
Member Author

devmotion commented Jan 23, 2025

That's similar to how I would imagine it.

Maybe the thing that NaNMath actually needs is just an interface

I don't think that's needed. I think that

NaNMath.sqrt(x::Complex) = Base.sqrt(float(x))
NaNMath.sqrt(x::Number) = x < zero(x) ? typeof(Base.sqrt(-float(x)))(NaN) : Base.sqrt(float(x))

should always be fine:

julia> using Unitful

julia> NaNMath.sqrt(-1)
NaN

julia> NaNMath.sqrt(-1.0)
NaN

julia> NaNMath.sqrt(big"-1")
NaN

julia> typeof(NaNMath.sqrt(big"-1"))
BigFloat

julia> NaNMath.sqrt(u"-1s^2")
NaN s

Edit: I edited the initial version such that unitful numbers are supported as well. Note that the units should be s, not s^2.


Or maybe slightly simpler:

NaNMath.sqrt(x::Complex) = Base.sqrt(float(x))
NaNMath.sqrt(x::Number) = x < zero(x) ? Base.sqrt(typeof(float(x))(NaN)) : Base.sqrt(float(x))
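
For readers without Unitful installed, the simpler form above can be checked on Base number types with a self-contained sketch. `FinalSketch` is a hypothetical stand-in module so that nothing in the real NaNMath is redefined.

```julia
module FinalSketch
# Complex inputs have no excluded domain, so they pass straight through.
sqrt(x::Complex) = Base.sqrt(float(x))
# Everything else is promoted with `float` and checked by value; the NaN is
# pushed through Base.sqrt so the result type matches the non-NaN branch.
sqrt(x::Number) = x < zero(x) ? Base.sqrt(typeof(float(x))(NaN)) : Base.sqrt(float(x))
end

FinalSketch.sqrt(-1)              # NaN (Float64)
FinalSketch.sqrt(-1.0f0)          # NaN32
FinalSketch.sqrt(big"-1")         # NaN (BigFloat)
FinalSketch.sqrt(-1.0 + 0.0im)    # 0.0 + 1.0im
```

The `x::Number` method still relies on `<`, `zero`, and `float` being defined for the argument type and on `typeof(float(x))` accepting a NaN, so a type lacking any of these would raise a MethodError rather than a DomainError.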

@ChrisRackauckas
Member

I would be fine with this form. It's close to what's in there but handles the non AbstractFloat stuff better. Could you change to this form? If we all agree this is sufficiently safe then I think we get the best of both worlds?

6 participants