Skip to content

Commit 86b981d

Browse files
authored
Din implementation based on gorgonia (#12)
* Fix frontend/.husky/pre-commit x perm * Add gorgonia.org/gorgonia * Finished 30% Din model * Changing din train * Add SampleInfo * Fix OuterProd * Fix att0 size * Maybe Din.Train works * Din forward backprop work now, still some bugs * Grad sames work but maybe wrong direction * Fix att layer name and test data * Fix final act func and dropout, works now but not very stable * Din and simple MLP training test * Fix din loss * Fix cost function! * Add Graph() interface * Batch predict to be tested * Model Marshal and new from json, predict done! * Refactor common part of model * Add UserBehavior interface * More test data for din
1 parent 5d7c1ec commit 86b981d

File tree

9 files changed

+1093
-52
lines changed

9 files changed

+1093
-52
lines changed

frontend/.husky/commit-msg

Lines changed: 0 additions & 4 deletions
This file was deleted.

frontend/.husky/pre-commit

100644100755
File mode changed.

go.mod

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module github.com/auxten/edgeRec
33
go 1.18
44

55
require (
6-
github.com/chewxy/math32 v1.0.4
6+
github.com/chewxy/math32 v1.0.8
77
github.com/gin-gonic/gin v1.8.1
88
github.com/go-sql-driver/mysql v1.6.0
99
github.com/karlseguin/ccache/v2 v2.0.8
@@ -13,7 +13,6 @@ require (
1313
github.com/pa-m/sklearn v0.0.0-20200711083454-beb861ee48b1
1414
github.com/peterh/liner v1.2.0
1515
github.com/pkg/errors v0.9.1
16-
github.com/sashabaranov/go-fastapi v0.0.0-20211231174335-50b05b1379e1
1716
github.com/sirupsen/logrus v1.2.0
1817
github.com/smartystreets/goconvey v1.7.2
1918
github.com/spf13/cobra v1.1.1
@@ -23,35 +22,37 @@ require (
2322
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
2423
gonum.org/v1/gonum v0.11.0
2524
gonum.org/v1/plot v0.10.1
25+
gopkg.in/cheggaaa/pb.v1 v1.0.27
26+
gorgonia.org/gorgonia v0.9.17
27+
gorgonia.org/tensor v0.9.24
2628
)
2729

2830
require (
2931
git.sr.ht/~sbinet/gg v0.3.1 // indirect
30-
github.com/PuerkitoBio/purell v1.1.1 // indirect
31-
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect
3232
github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b // indirect
33+
github.com/apache/arrow/go/arrow v0.0.0-20210105145422-88aaea5262db // indirect
34+
github.com/awalterschulze/gographviz v0.0.0-20190221210632-1e9ccb565bca // indirect
35+
github.com/chewxy/hm v1.0.0 // indirect
3336
github.com/davecgh/go-spew v1.1.1 // indirect
3437
github.com/gin-contrib/sse v0.1.0 // indirect
3538
github.com/go-fonts/liberation v0.2.0 // indirect
3639
github.com/go-latex/latex v0.0.0-20210823091927-c0d11ff05a81 // indirect
37-
github.com/go-openapi/jsonpointer v0.19.5 // indirect
38-
github.com/go-openapi/jsonreference v0.19.6 // indirect
39-
github.com/go-openapi/spec v0.20.4 // indirect
40-
github.com/go-openapi/swag v0.19.15 // indirect
4140
github.com/go-pdf/fpdf v0.6.0 // indirect
4241
github.com/go-playground/locales v0.14.0 // indirect
4342
github.com/go-playground/universal-translator v0.18.0 // indirect
4443
github.com/go-playground/validator/v10 v10.10.0 // indirect
4544
github.com/goccy/go-json v0.9.7 // indirect
45+
github.com/gogo/protobuf v1.3.2 // indirect
4646
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
47+
github.com/golang/protobuf v1.5.0 // indirect
48+
github.com/google/flatbuffers v1.12.0 // indirect
4749
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 // indirect
4850
github.com/inconshreveable/mousetrap v1.0.0 // indirect
49-
github.com/josharian/intern v1.0.0 // indirect
5051
github.com/json-iterator/go v1.1.12 // indirect
5152
github.com/jtolds/gls v4.20.0+incompatible // indirect
5253
github.com/konsorten/go-windows-terminal-sequences v1.0.1 // indirect
54+
github.com/leesper/go_rng v0.0.0-20171009123644-5344a9259b21 // indirect
5355
github.com/leodido/go-urn v1.2.1 // indirect
54-
github.com/mailru/easyjson v0.7.6 // indirect
5556
github.com/mattn/go-isatty v0.0.14 // indirect
5657
github.com/mattn/go-runewidth v0.0.7 // indirect
5758
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
@@ -62,14 +63,21 @@ require (
6263
github.com/smartystreets/assertions v1.2.0 // indirect
6364
github.com/spf13/pflag v1.0.5 // indirect
6465
github.com/ugorji/go/codec v1.2.7 // indirect
66+
github.com/xtgo/set v1.0.0 // indirect
67+
go4.org/unsafe/assume-no-moving-gc v0.0.0-20220617031537-928513b29760 // indirect
6568
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 // indirect
6669
golang.org/x/image v0.0.0-20220302094943-723b81ca9867 // indirect
6770
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f // indirect
6871
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 // indirect
6972
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 // indirect
7073
golang.org/x/text v0.3.7 // indirect
7174
golang.org/x/tools v0.1.9 // indirect
75+
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
7276
google.golang.org/protobuf v1.28.0 // indirect
7377
gopkg.in/yaml.v2 v2.4.0 // indirect
7478
gopkg.in/yaml.v3 v3.0.1 // indirect
79+
gorgonia.org/cu v0.9.3 // indirect
80+
gorgonia.org/dawson v1.2.0 // indirect
81+
gorgonia.org/vecf32 v0.9.0 // indirect
82+
gorgonia.org/vecf64 v0.9.0 // indirect
7583
)

go.sum

Lines changed: 130 additions & 28 deletions
Large diffs are not rendered by default.

model/din/din.go

Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
package din
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
_ "net/http/pprof"
7+
8+
"github.com/pkg/errors"
9+
log "github.com/sirupsen/logrus"
10+
G "gorgonia.org/gorgonia"
11+
"gorgonia.org/tensor"
12+
)
13+
14+
type DinNet struct {
15+
uProfileDim, uBehaviorSize, uBehaviorDim int
16+
iFeatureDim int
17+
cFeatureDim int
18+
19+
g *G.ExprGraph
20+
21+
vm G.VM
22+
23+
//input nodes
24+
xUserProfile, xUbMatrix, xItemFeature, xCtxFeature *G.Node
25+
26+
mlp0, mlp1, mlp2 *G.Node // weights of MLP layers
27+
d0, d1 float64 // dropout probabilities
28+
att0, att1 []*G.Node // weights of Attention layers
29+
30+
out *G.Node
31+
}
32+
33+
// dinModel is the JSON representation of a DinNet used by Marshal and
// NewDinNetFromJson: the five feature dimensions followed by the flattened
// weights of every layer.
type dinModel struct {
	// Feature dimensions.
	UProfileDim   int `json:"uProfileDim"`
	UBehaviorSize int `json:"uBehaviorSize"`
	UBehaviorDim  int `json:"uBehaviorDim"`
	IFeatureDim   int `json:"iFeatureDim"`
	CFeatureDim   int `json:"cFeatureDim"`

	// Flattened MLP weights.
	Mlp0 []float64 `json:"mlp0"`
	Mlp1 []float64 `json:"mlp1"`
	Mlp2 []float64 `json:"mlp2"`

	// Flattened attention weights, indexed by user behavior slot.
	Att0 [][]float64 `json:"att0"`
	Att1 [][]float64 `json:"att1"`
}
45+
46+
func (din *DinNet) Vm() G.VM {
47+
return din.vm
48+
}
49+
50+
func (din *DinNet) SetVM(vm G.VM) {
51+
din.vm = vm
52+
}
53+
54+
func (din *DinNet) Marshal() (data []byte, err error) {
55+
modelData := dinModel{
56+
UProfileDim: din.uProfileDim,
57+
UBehaviorSize: din.uBehaviorSize,
58+
UBehaviorDim: din.uBehaviorDim,
59+
IFeatureDim: din.iFeatureDim,
60+
CFeatureDim: din.cFeatureDim,
61+
Mlp0: din.mlp0.Value().Data().([]float64),
62+
Mlp1: din.mlp1.Value().Data().([]float64),
63+
Mlp2: din.mlp2.Value().Data().([]float64),
64+
}
65+
modelData.Att0 = make([][]float64, din.uBehaviorSize)
66+
modelData.Att1 = make([][]float64, din.uBehaviorSize)
67+
for i := 0; i < din.uBehaviorSize; i++ {
68+
modelData.Att0[i] = din.att0[i].Value().Data().([]float64)
69+
modelData.Att1[i] = din.att1[i].Value().Data().([]float64)
70+
}
71+
//marshal to json
72+
data, err = json.Marshal(modelData)
73+
74+
return
75+
}
76+
77+
func NewDinNetFromJson(data []byte) (din *DinNet, err error) {
78+
var m dinModel
79+
if err = json.Unmarshal(data, &m); err != nil {
80+
return
81+
}
82+
var (
83+
g = G.NewGraph()
84+
uProfileDim = m.UProfileDim
85+
uBehaviorSize = m.UBehaviorSize
86+
uBehaviorDim = m.UBehaviorDim
87+
iFeatureDim = m.IFeatureDim
88+
cFeatureDim = m.CFeatureDim
89+
att0_0 = uBehaviorDim + iFeatureDim + uBehaviorSize*uBehaviorDim*iFeatureDim
90+
)
91+
92+
// attention layer
93+
att0 := make([]*G.Node, m.UBehaviorSize)
94+
att1 := make([]*G.Node, m.UBehaviorSize)
95+
for i := 0; i < m.UBehaviorSize; i++ {
96+
att0[i] = G.NewMatrix(
97+
g,
98+
dt,
99+
G.WithShape(att0_0, att0_1),
100+
G.WithValue(tensor.New(tensor.WithShape(att0_0, att0_1), tensor.WithBacking(m.Att0[i]))),
101+
G.WithName(fmt.Sprintf("att0-%d", i)),
102+
)
103+
att1[i] = G.NewMatrix(
104+
g,
105+
dt,
106+
G.WithShape(att0_1, 1),
107+
G.WithValue(tensor.New(tensor.WithShape(att0_1, 1), tensor.WithBacking(m.Att1[i]))),
108+
G.WithName(fmt.Sprintf("att1-%d", i)),
109+
)
110+
}
111+
mlp0 := G.NewMatrix(g, dt,
112+
G.WithShape(uProfileDim+uBehaviorDim+iFeatureDim+cFeatureDim, mlp0_1),
113+
G.WithName("mlp0"),
114+
G.WithValue(tensor.New(
115+
tensor.WithShape(uProfileDim+uBehaviorDim+iFeatureDim+cFeatureDim, mlp0_1),
116+
tensor.WithBacking(m.Mlp0)),
117+
),
118+
)
119+
120+
mlp1 := G.NewMatrix(g, dt,
121+
G.WithShape(mlp0_1, mlp1_2),
122+
G.WithName("mlp1"),
123+
G.WithValue(tensor.New(tensor.WithShape(mlp0_1, mlp1_2), tensor.WithBacking(m.Mlp1))),
124+
)
125+
126+
mlp2 := G.NewMatrix(g, dt,
127+
G.WithShape(mlp1_2, 1),
128+
G.WithName("mlp2"),
129+
G.WithValue(tensor.New(tensor.WithShape(mlp1_2, 1), tensor.WithBacking(m.Mlp2))),
130+
)
131+
132+
din = &DinNet{
133+
uProfileDim: m.UProfileDim,
134+
uBehaviorSize: m.UBehaviorSize,
135+
uBehaviorDim: m.UBehaviorDim,
136+
iFeatureDim: m.IFeatureDim,
137+
cFeatureDim: m.CFeatureDim,
138+
g: g,
139+
att0: att0,
140+
att1: att1,
141+
mlp0: mlp0,
142+
mlp1: mlp1,
143+
mlp2: mlp2,
144+
}
145+
return
146+
}
147+
148+
func (din *DinNet) Graph() *G.ExprGraph {
149+
return din.g
150+
}
151+
152+
func (din *DinNet) Out() *G.Node {
153+
return din.out
154+
}
155+
156+
func (din *DinNet) In() G.Nodes {
157+
return G.Nodes{din.xUserProfile, din.xUbMatrix, din.xItemFeature, din.xCtxFeature}
158+
}
159+
160+
func (din *DinNet) learnable() G.Nodes {
161+
ret := make(G.Nodes, 3, 3+2*din.uBehaviorSize)
162+
ret[0] = din.mlp0
163+
ret[1] = din.mlp1
164+
ret[2] = din.mlp2
165+
ret = append(ret, din.att0...)
166+
ret = append(ret, din.att1...)
167+
return ret
168+
}
169+
170+
func NewDinNet(
171+
uProfileDim, uBehaviorSize, uBehaviorDim int,
172+
iFeatureDim int,
173+
cFeatureDim int,
174+
) *DinNet {
175+
if uBehaviorDim != iFeatureDim {
176+
log.Fatalf("uBehaviorDim %d != iFeatureDim %d", uBehaviorDim, iFeatureDim)
177+
}
178+
g := G.NewGraph()
179+
// attention layer
180+
att0 := make([]*G.Node, uBehaviorSize)
181+
att1 := make([]*G.Node, uBehaviorSize)
182+
for i := 0; i < uBehaviorSize; i++ {
183+
att0[i] = G.NewTensor(g, dt, 2, G.WithShape(uBehaviorDim+iFeatureDim+uBehaviorSize*uBehaviorDim*iFeatureDim, att0_1), G.WithName(fmt.Sprintf("att0-%d", i)), G.WithInit(G.Gaussian(0, 1)))
184+
att1[i] = G.NewTensor(g, dt, 2, G.WithShape(att0_1, 1), G.WithName(fmt.Sprintf("att1-%d", i)), G.WithInit(G.Gaussian(0, 1)))
185+
}
186+
187+
// user behaviors are represented as a sequence of item embeddings. Before
188+
// being fed into the MLP, we need to flatten the sequence into a single with
189+
// sum pooling with Attention as the weights which is the key point of DIN model.
190+
mlp0 := G.NewMatrix(g, dt, G.WithShape(uProfileDim+uBehaviorDim+iFeatureDim+cFeatureDim, mlp0_1), G.WithName("mlp0"), G.WithInit(G.Gaussian(0, 1)))
191+
192+
mlp1 := G.NewMatrix(g, dt, G.WithShape(mlp0_1, mlp1_2), G.WithName("mlp1"), G.WithInit(G.Gaussian(0, 1)))
193+
194+
mlp2 := G.NewMatrix(g, dt, G.WithShape(mlp1_2, 1), G.WithName("mlp2"), G.WithInit(G.Gaussian(0, 1)))
195+
196+
return &DinNet{
197+
uProfileDim: uProfileDim,
198+
uBehaviorSize: uBehaviorSize,
199+
uBehaviorDim: uBehaviorDim,
200+
iFeatureDim: iFeatureDim,
201+
cFeatureDim: cFeatureDim,
202+
203+
g: g,
204+
att0: att0,
205+
att1: att1,
206+
207+
d0: 0.01,
208+
d1: 0.01,
209+
210+
mlp0: mlp0,
211+
mlp1: mlp1,
212+
mlp2: mlp2,
213+
}
214+
}
215+
216+
//Fwd performs the forward pass
217+
// xUserProfile: [batchSize, userProfileDim]
218+
// xUserBehaviors: [batchSize, uBehaviorSize, uBehaviorDim]
219+
// xItemFeature: [batchSize, iFeatureDim]
220+
// xContextFeature: [batchSize, cFeatureDim]
221+
func (din *DinNet) Fwd(xUserProfile, xUbMatrix, xItemFeature, xCtxFeature *G.Node, batchSize, uBehaviorSize, uBehaviorDim int) (err error) {
222+
iFeatureDim := xItemFeature.Shape()[1]
223+
if uBehaviorDim != iFeatureDim {
224+
return errors.Errorf("uBehaviorDim %d != iFeatureDim %d", uBehaviorDim, iFeatureDim)
225+
}
226+
xUserBehaviors := G.Must(G.Reshape(xUbMatrix, tensor.Shape{batchSize, uBehaviorSize, uBehaviorDim}))
227+
228+
// outProduct should computed batch by batch!!!!
229+
outProdVecs := make([]*G.Node, batchSize)
230+
for i := 0; i < batchSize; i++ {
231+
// ubVec.Shape() = [uBehaviorSize * uBehaviorDim]
232+
ubVec := G.Must(G.Slice(xUbMatrix, G.S(i)))
233+
// item.Shape() = [iFeatureDim]
234+
itemVec := G.Must(G.Slice(xItemFeature, G.S(i)))
235+
// outProd.Shape() = [uBehaviorSize * uBehaviorDim, iFeatureDim]
236+
outProd := G.Must(G.OuterProd(ubVec, itemVec))
237+
outProdVecs[i] = G.Must(G.Reshape(outProd, tensor.Shape{uBehaviorSize * uBehaviorDim * iFeatureDim}))
238+
}
239+
//outProductsVec.Shape() = [batchSize * uBehaviorSize * uBehaviorDim * iFeatureDim]
240+
outProductsVec := G.Must(G.Concat(0, outProdVecs...))
241+
outProducts := G.Must(G.Reshape(outProductsVec, tensor.Shape{batchSize, uBehaviorSize * uBehaviorDim * iFeatureDim}))
242+
243+
actOuts := G.NewTensor(din.Graph(), dt, 2, G.WithShape(batchSize, uBehaviorDim), G.WithName("actOuts"), G.WithInit(G.Zeroes()))
244+
for i := 0; i < uBehaviorSize; i++ {
245+
// xUserBehaviors[:, i, :], ub.shape: [batchSize, uBehaviorDim]
246+
ub := G.Must(G.Slice(xUserBehaviors, []tensor.Slice{nil, G.S(i)}...))
247+
// Concat all xUserBehaviors[i], outProducts, xItemFeature
248+
// actConcat.Shape() = [batchSize, uBehaviorDim+iFeatureDim+uBehaviorSize*uBehaviorDim*iFeatureDim]
249+
actConcat := G.Must(G.Concat(1, ub, outProducts, xItemFeature))
250+
actOut := G.Must(G.BroadcastHadamardProd(
251+
ub,
252+
G.Must(G.Rectify(
253+
G.Must(G.Mul(
254+
G.Must(G.Mul(actConcat, din.att0[i])),
255+
din.att1[i],
256+
)))), // [batchSize, 1]
257+
nil, []byte{1},
258+
)) // [batchSize, uBehaviorDim]
259+
260+
// Sum pooling
261+
actOuts = G.Must(G.Add(actOuts, actOut))
262+
}
263+
264+
// Concat all xUserProfile, actOuts, xItemFeature, xCtxFeature
265+
concat := G.Must(G.Concat(1, xUserProfile, actOuts, xItemFeature, xCtxFeature))
266+
267+
// MLP
268+
269+
// mlp0.Shape: [userProfileDim+userBehaviorDim+itemFeatureDim+contextFeatureDim, 200]
270+
// out.Shape: [batchSize, 200]
271+
mlp0Out := G.Must(G.LeakyRelu(G.Must(G.Mul(concat, din.mlp0)), 0.1))
272+
mlp0Out = G.Must(G.Dropout(mlp0Out, din.d0))
273+
// mlp1.Shape: [200, 80]
274+
// out.Shape: [batchSize, 80]
275+
mlp1Out := G.Must(G.LeakyRelu(G.Must(G.Mul(mlp0Out, din.mlp1)), 0.1))
276+
mlp1Out = G.Must(G.Dropout(mlp1Out, din.d1))
277+
// mlp2.Shape: [80, 1]
278+
// out.Shape: [batchSize, 1]
279+
mlp2Out := G.Must(G.Sigmoid(G.Must(G.Mul(mlp1Out, din.mlp2))))
280+
281+
din.out = mlp2Out
282+
din.xUserProfile = xUserProfile
283+
din.xItemFeature = xItemFeature
284+
din.xCtxFeature = xCtxFeature
285+
din.xUbMatrix = xUbMatrix
286+
return
287+
}

0 commit comments

Comments
 (0)