Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable multiple dispatch for compact layers #14

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions src/compact.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import Flux: _big_show

"""
@compact(forward::Function; name=nothing, parameters...)
@compact(forward::Function[, layer_type]; name=nothing, parameters...)

Creates a layer by specifying some `parameters`, in the form of keywords,
and (usually as a `do` block) a function for the forward pass.
You may think of `@compact` as a specialized `let` block creating local variables
that are trainable in Flux.
Declared variable names may be used within the body of the `forward` function.
that are trainable in Flux. Declared variable names may be used within the
body of the `forward` function.

Here is a linear model:

Expand All @@ -28,7 +28,7 @@ end
d(ones(5, 10)) # 7×10 Matrix as output.
```

Finally, here is a simple MLP:
Here is a simple MLP:

```
using Flux
Expand Down Expand Up @@ -78,11 +78,33 @@ println(model) # "Linear(3 => 1)"

This can be useful when using `@compact` to hierarchically construct
complex models to be used inside a `Chain`.

You can also specify a symbol to identify the type of layer, which is
useful for dispatching on types of layers (since the function block
will generate a new type each time it is evaluated):

```
model = @compact(MyLayer, w=rand(3)) do x
sum(w .* x)
end

f(::CompactLayer{:MyLayer}) = 1
f(::CompactLayer{:Default}) = 0

println(f(model)) # 1
```
"""
macro compact(fex, kwexs...)
# check input
Meta.isexpr(fex, :(->)) || error("expects a do block")
isempty(kwexs) && error("expects keyword arguments")
# Check if first kwexes is just a Symbol:
(layer_symbol, kwexs) = if first(kwexs) isa Symbol
(first(kwexs), Base.tail(kwexs))
else
(:Default, kwexs)
end
layer_symbol = QuoteNode(layer_symbol)
all(ex -> Meta.isexpr(ex, (:kw,:(=))), kwexs) || error("expects only keyword argumens")

# check if user has named layer:
Expand Down Expand Up @@ -112,7 +134,7 @@ macro compact(fex, kwexs...)
return esc(quote
let
$(assigns...)
$CompactLayer($fex, $name, ($layer, $input, $block), $setup; $(vars...))
$CompactLayer(Val($layer_symbol), $fex, $name, ($layer, $input, $block), $setup; $(vars...))
end
end)
end
Expand All @@ -128,17 +150,21 @@ function addprefix!(ex::Expr, self, vars)
end
addprefix!(not_ex, self, vars) = nothing

struct CompactLayer{F,NT1<:NamedTuple,NT2<:NamedTuple}
# Layer type produced by the `@compact` macro.
#
# Type parameters:
#   S   — user-chosen dispatch Symbol (`:Default` when none was given), so
#         callers can write methods like `f(::CompactLayer{:MyLayer})`.
#   F   — type of the stored forward function.
#   NT1 — NamedTuple of per-variable setup strings (for printing).
#   NT2 — NamedTuple of the trainable variables declared as macro keywords.
struct CompactLayer{S,F,NT1<:NamedTuple,NT2<:NamedTuple}
symbol::Val{S} # NOTE(review): value is redundant with type parameter `S`; kept so `S` is set from the `Val` passed at construction
fun::F # forward pass; invoked as `fun(variables, inputs...)`
name::Union{String,Nothing} # optional user-supplied display name
strings::NTuple{3,String} # (layer, input, block) source strings used when showing the layer
setup_strings::NT1
variables::NT2 # trainable parameters / sub-layers, stored as a NamedTuple
end
CompactLayer(f::Function, name::Union{String,Nothing}, str::Tuple, setup_str::NamedTuple; kw...) = CompactLayer(f, name, str, setup_str, NamedTuple(kw))
# Keyword-accepting convenience constructor used by `@compact`: collects the
# trainable variables passed as keywords into a `NamedTuple` and forwards
# everything to the default positional struct constructor.
function CompactLayer(symbol::Val, fun::Function, name::Union{String,Nothing},
                      strings::Tuple, setup_strings::NamedTuple; kws...)
    variables = NamedTuple(kws)
    return CompactLayer(symbol, fun, name, strings, setup_strings, variables)
end
# Calling the layer runs the stored forward function; the declared trainable
# variables are passed as the (hidden) first argument.
function (layer::CompactLayer)(inputs...)
    return layer.fun(layer.variables, inputs...)
end
# Catch-all guard: direct positional construction with arbitrary arguments is
# not supported — users must go through the `@compact` macro.
function CompactLayer(args...)
    return error("CompactLayer is meant to be constructed by the macro")
end
# Register CompactLayer with Flux's functor machinery so its fields are
# traversed by Flux utilities (e.g. parameter collection).
Flux.@functor CompactLayer
# Return the dispatch Symbol `S` of a `CompactLayer{S}` (`:Default` when the
# user supplied none to `@compact`).
function layer_symbol(::CompactLayer{S}) where {S}
    return S
end

# Restrict Flux's pretty-printing to the trainable variables only, hiding the
# auxiliary fields (symbol, fun, name, strings) from the printed tree.
Flux._show_children(m::CompactLayer) = m.variables

Expand Down Expand Up @@ -167,6 +193,9 @@ function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
layer, input, block = obj.strings
pre, post = ("(", ")")
println(io, " "^indent, isnothing(name) ? "" : "$name = ", layer, pre)
if layer_symbol(obj) != :Default
println(io, " "^(indent+2), string(layer_symbol(obj)), ",")
end
for k in keys(obj.variables)
v = obj.variables[k]
if Flux._show_leaflike(v)
Expand Down
95 changes: 94 additions & 1 deletion test/compact.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import Fluxperimental: @compact
import Fluxperimental: @compact, CompactLayer

# Strip both strings of spaces, and then test:
function similar_strings(s1, s2)
Expand Down Expand Up @@ -182,3 +182,96 @@ end
@test similar_strings(get_model_string(model), expected_string)

end

# Tests for symbol-based dispatch of `@compact`: passing a leading Symbol to
# the macro should become the first type parameter of the resulting
# `CompactLayer`, with `:Default` used when no symbol is given.
@testset "Dispatch using symbols" begin
# No symbol supplied -> layer should carry the `:Default` parameter.
model1 = @compact(W=randn(32)) do x
W .* x
end
# Explicit symbol supplied as the first macro argument.
model2 = @compact(MyCustomLayer, W=randn(32)) do x
W .* x
end
# `@eval` pushes these method definitions to module top level; methods
# cannot be defined inside the testset's local scope.
@eval my_custom_function(::CompactLayer{:Default}) = :Default
@eval my_custom_function(::CompactLayer{:MyCustomLayer}) = :MyCustomLayer

@test model1 isa CompactLayer{:Default}
@test model2 isa CompactLayer{:MyCustomLayer}
# Multiple dispatch now distinguishes the two layers by their symbol.
@test my_custom_function(model1) == :Default
@test my_custom_function(model2) == :MyCustomLayer

# The custom symbol should also appear in the printed representation.
expected_string2 = """@compact(
MyCustomLayer,
W = randn(32), # 32 parameters
) do x
W .* x
end"""
@test similar_strings(get_model_string(model2), expected_string2)

# Symbols should also work for `@compact` layers nested inside another
# `@compact` layer (a small multi-head self-attention example).
@testset "Nested symbolic layers" begin
num_features = 1
num_out = 2
d_attn = 4
d_value = 12
num_heads = 3

model3 = @compact(
SelfAttention,
out = Dense(num_heads * d_value => num_out),
heads = [
@compact(
Head,
K = Dense(num_features => d_attn),
V = Dense(num_features => d_value),
Q = Dense(num_features => d_attn)
) do x
k, v, q = K(x), V(x), Q(x)
x = sum(k .* q; dims=1) ./ sqrt(d_attn)
softmax(x; dims=2) .* v
end for _ in 1:num_heads
]
) do x
out(vcat([h(x) for h in heads]...))
end
# The outer layer and every nested head carry their own dispatch symbol.
@test model3 isa CompactLayer{:SelfAttention}
@test all(t -> isa(t, CompactLayer{:Head}), model3.variables.heads)

# Nested printing should show each layer's symbol on its own line.
expected_string3 = """@compact(
SelfAttention,
out = Dense(36 => 2), # 74 parameters
heads = Array(
@compact(
Head,
K = Dense(1 => 4), # 8 parameters
V = Dense(1 => 12), # 24 parameters
Q = Dense(1 => 4), # 8 parameters
) do x
(k, v, q) = (K(x), V(x), Q(x))
x = sum(k .* q; dims = 1) ./ sqrt(d_attn)
softmax(x; dims = 2) .* v
end,
@compact(
Head,
K = Dense(1 => 4), # 8 parameters
V = Dense(1 => 12), # 24 parameters
Q = Dense(1 => 4), # 8 parameters
) do x
(k, v, q) = (K(x), V(x), Q(x))
x = sum(k .* q; dims = 1) ./ sqrt(d_attn)
softmax(x; dims = 2) .* v
end,
@compact(
Head,
K = Dense(1 => 4), # 8 parameters
V = Dense(1 => 12), # 24 parameters
Q = Dense(1 => 4), # 8 parameters
) do x
(k, v, q) = (K(x), V(x), Q(x))
x = sum(k .* q; dims = 1) ./ sqrt(d_attn)
softmax(x; dims = 2) .* v
end,
),
) do x
out(vcat([h(x) for h = heads]...))
end # Total: 20 arrays, 194 parameters, 3.515 KiB."""
@test similar_strings(get_model_string(model3), expected_string3)
end
end