11package nestedset
22
33import (
4- "context"
54 "database/sql"
65 "fmt"
76 "reflect"
87 "strings"
8+ "sync"
99
1010 "gorm.io/gorm"
11- "gorm.io/gorm/clause "
11+ "gorm.io/gorm/schema "
1212)
1313
1414// MoveDirection means where the node is going to be located
@@ -25,7 +25,7 @@ const (
2525 MoveDirectionInner MoveDirection = 0
2626)
2727
28- type nodeItem struct {
28+ type nestedItem struct {
2929 ID int64
3030 ParentID sql.NullInt64
3131 Depth int
@@ -36,31 +36,25 @@ type nodeItem struct {
3636 DbNames map [string ]string
3737}
3838
39- // parseNode parse a gorm structure into an internal source structure
40- // for bring in all required data attribute like scope, left, righ etc.
41- func parseNode (db * gorm.DB , source interface {}) (tx * gorm.DB , item nodeItem , err error ) {
42- tx = db
43- stmt := & gorm.Statement {
44- DB : tx ,
45- ConnPool : tx .ConnPool ,
46- Context : context .Background (),
47- Clauses : map [string ]clause.Clause {},
48- }
49-
50- err = stmt .Parse (source )
39+ // parseNode parse a gorm struct into an internal nested item struct
40+ // bring in all required data attribute like scope, left, righ etc.
41+ func parseNode (db * gorm.DB , source interface {}) (tx * gorm.DB , item nestedItem , err error ) {
42+ scm , err := schema .Parse (source , & sync.Map {}, schema.NamingStrategy {})
5143 if err != nil {
5244 err = fmt .Errorf ("Invalid source, must be a valid Gorm Model instance, %v" , source )
5345 return
5446 }
5547
56- item = nodeItem {TableName : stmt .Table , DbNames : map [string ]string {}}
48+ tx = db .Table (scm .Table )
49+
50+ item = nestedItem {TableName : scm .Table , DbNames : map [string ]string {}}
5751 sourceValue := reflect .Indirect (reflect .ValueOf (source ))
5852 sourceType := sourceValue .Type ()
5953 for i := 0 ; i < sourceType .NumField (); i ++ {
6054 t := sourceType .Field (i )
6155 v := sourceValue .Field (i )
6256
63- schemaField := stmt . Schema .LookUpField (t .Name )
57+ schemaField := scm .LookUpField (t .Name )
6458 dbName := schemaField .DBName
6559
6660 switch t .Tag .Get ("nestedset" ) {
@@ -98,27 +92,26 @@ func parseNode(db *gorm.DB, source interface{}) (tx *gorm.DB, item nodeItem, err
9892 return
9993}
10094
101- // Create a new node by parent with Gorm original Create()
102- // ```nestedset.Create(db, &Category{...}, nil)```` will create a new category in root level
95+ // Create a new node within its parent by Gorm original Create() method
96+ // ```nestedset.Create(db, &Category{...}, nil)``` will create a new category in root level
10397// ```nestedset.Create(db, &Category{...}, &parent)``` will create a new category under parent node as its last child
10498func Create (db * gorm.DB , source , parent interface {}) error {
105- return db .Transaction (func (db * gorm.DB ) (err error ) {
106- tx , target , err := parseNode (db , source )
107- if err != nil {
108- return err
109- }
99+ tx , target , err := parseNode (db , source )
100+ if err != nil {
101+ return err
102+ }
110103
111- // for totally blank table / scope default init root would be [1 - 2]
112- setDepth , setToLft , setToRgt := 0 , 1 , 2
113- tableName , dbNames := target . TableName , target .DbNames
104+ // for totally blank table / scope default init root would be [1 - 2]
105+ setToDepth , setToLft , setToRgt := 0 , 1 , 2
106+ dbNames := target .DbNames
114107
115- // put node into root level when parent is nil
108+ return tx .Transaction (func (tx * gorm.DB ) (err error ) {
109+ // create node in root level when parent is nil
116110 if parent == nil {
117111 lastNode := make (map [string ]interface {})
118- orderSQL := formatSQL (":rgt desc" , target )
119- rst := tx .Model (source ).Select (dbNames ["rgt" ]).Order (orderSQL ).First (& lastNode )
112+ rst := tx .Select (dbNames ["rgt" ]).Order (formatSQL (":rgt DESC" , target )).Take (& lastNode )
120113 if rst .Error == nil {
121- setToLft = lastNode [dbNames ["rgt" ]].(int ) + 1
114+ setToLft = int ( lastNode [dbNames ["rgt" ]].(int64 ) + 1 )
122115 setToRgt = setToLft + 1
123116 }
124117 } else {
@@ -129,36 +122,31 @@ func Create(db *gorm.DB, source, parent interface{}) error {
129122
130123 setToLft = targetParent .Rgt
131124 setToRgt = targetParent .Rgt + 1
132- setDepth = targetParent .Depth + 1
125+ setToDepth = targetParent .Depth + 1
133126
134127 // UPDATE tree SET rgt = rgt + 2 WHERE rgt >= new_lft;
135- err = tx .Table (tableName ).
136- Where (formatSQL (":rgt >= ?" , target ), setToLft ).
137- UpdateColumn (dbNames ["rgt" ], gorm .Expr (formatSQL (":rgt + 2" , target ))).
138- Error
128+ err = tx .Where (formatSQL (":rgt >= ?" , target ), setToLft ).
129+ UpdateColumn (dbNames ["rgt" ], gorm .Expr (formatSQL (":rgt + 2" , target ))).Error
139130 if err != nil {
140131 return err
141132 }
142133
143134 // UPDATE tree SET lft = lft + 2 WHERE lft > new_lft;
144- err = tx .Table (tableName ).
145- Where (formatSQL (":lft > ?" , target ), setToLft ).
146- UpdateColumn (dbNames ["lft" ], gorm .Expr (formatSQL (":lft + 2" , target ))).
147- Error
135+ err = tx .Where (formatSQL (":lft > ?" , target ), setToLft ).
136+ UpdateColumn (dbNames ["lft" ], gorm .Expr (formatSQL (":lft + 2" , target ))).Error
148137 if err != nil {
149138 return err
150139 }
151140
152- // UPDATE tree SET children_count = children_count + 1 WHERE is = parent.id;
153- err = db .Model (parent ).Update (
154- dbNames ["children_count" ], gorm .Expr (formatSQL (":children_count + 1" , target )),
155- ).Error
141+ // UPDATE tree SET children_count = children_count + 1 WHERE id = parent.id;
142+ err = tx .Model (parent ).Update (
143+ dbNames ["children_count" ], gorm .Expr (formatSQL (":children_count + 1" , target ))).Error
156144 if err != nil {
157145 return err
158146 }
159147 }
160148
161- // Set Lft, Rgt, Depth dynamically by refect
149+ // Set Lft, Rgt, Depth dynamically
162150 v := reflect .Indirect (reflect .ValueOf (source ))
163151 t := v .Type ()
164152 for i := 0 ; i < t .NumField (); i ++ {
@@ -174,7 +162,7 @@ func Create(db *gorm.DB, source, parent interface{}) error {
174162 break
175163 case "depth" :
176164 f := v .FieldByName (f .Name )
177- f .SetInt (int64 (setDepth ))
165+ f .SetInt (int64 (setToDepth ))
178166 break
179167 }
180168 }
@@ -184,6 +172,51 @@ func Create(db *gorm.DB, source, parent interface{}) error {
184172 })
185173}
186174
175+ // Delete a node from scoped list and its all descendent
176+ // ```nestedset.Delete(db, &Category{...})```
177+ func Delete (db * gorm.DB , source interface {}) error {
178+ tx , target , err := parseNode (db , source )
179+ if err != nil {
180+ return err
181+ }
182+
183+ // Batch Delete Method in GORM requires an instance of current source type without ID
184+ // to avoid GORM style Delete interface, we hacked here by set source ID to 0
185+ dbNames := target .DbNames
186+ v := reflect .Indirect (reflect .ValueOf (source ))
187+ t := v .Type ()
188+ for i := 0 ; i < t .NumField (); i ++ {
189+ f := t .Field (i )
190+ if f .Tag .Get ("nestedset" ) == "id" {
191+ f := v .FieldByName (f .Name )
192+ f .SetInt (0 )
193+ break
194+ }
195+ }
196+
197+ return tx .Transaction (func (tx * gorm.DB ) (err error ) {
198+ err = tx .Where (formatSQL (":lft >= ? AND :rgt <= ?" , target ), target .Lft , target .Rgt ).
199+ Delete (source ).Error
200+ if err != nil {
201+ return err
202+ }
203+
204+ // UPDATE tree SET rgt = rgt - width WHERE rgt > target_rgt;
205+ // UPDATE tree SET lft = lft - width WHERE lft > target_rgt;
206+ width := target .Rgt - target .Lft + 1
207+ for _ , d := range []string {"rgt" , "lft" } {
208+ err = tx .Where (formatSQL (":" + d + " > ?" , target ), target .Rgt ).
209+ Update (dbNames [d ], gorm .Expr (formatSQL (":" + d + " - ?" , target ), width )).
210+ Error
211+ if err != nil {
212+ return err
213+ }
214+ }
215+
216+ return nil
217+ })
218+ }
219+
187220// MoveTo move node to a position which is related a target node
188221// ```nestedset.MoveTo(db, &node, &to, nestedset.MoveDirectionInner)``` will move [&node] to [&to] node's child_list as its first child
189222func MoveTo (db * gorm.DB , node , to interface {}, direction MoveDirection ) error {
@@ -197,8 +230,6 @@ func MoveTo(db *gorm.DB, node, to interface{}, direction MoveDirection) error {
197230 return err
198231 }
199232
200- tx = db .Table (targetNode .TableName )
201-
202233 var right , depthChange int
203234 var newParentID sql.NullInt64
204235 if direction == MoveDirectionLeft || direction == MoveDirectionRight {
@@ -218,7 +249,7 @@ func MoveTo(db *gorm.DB, node, to interface{}, direction MoveDirection) error {
218249 return moveToRightOfPosition (tx , targetNode , right , depthChange , newParentID )
219250}
220251
221- func moveToRightOfPosition (tx * gorm.DB , targetNode nodeItem , position , depthChange int , newParentID sql.NullInt64 ) error {
252+ func moveToRightOfPosition (tx * gorm.DB , targetNode nestedItem , position , depthChange int , newParentID sql.NullInt64 ) error {
222253 return tx .Transaction (func (tx * gorm.DB ) (err error ) {
223254 oldParentID := targetNode .ParentID
224255 targetRight := targetNode .Rgt
@@ -261,7 +292,7 @@ func moveToRightOfPosition(tx *gorm.DB, targetNode nodeItem, position, depthChan
261292 })
262293}
263294
264- func syncChildrenCount (tx * gorm.DB , targetNode nodeItem , oldParentID , newParentID sql.NullInt64 ) (err error ) {
295+ func syncChildrenCount (tx * gorm.DB , targetNode nestedItem , oldParentID , newParentID sql.NullInt64 ) (err error ) {
265296 var oldParentCount , newParentCount int64
266297
267298 if oldParentID .Valid {
@@ -289,7 +320,7 @@ func syncChildrenCount(tx *gorm.DB, targetNode nodeItem, oldParentID, newParentI
289320 return nil
290321}
291322
292- func moveTarget (tx * gorm.DB , targetNode nodeItem , targetID int64 , targetIds []int64 , step , depthChange int , newParentID sql.NullInt64 ) (err error ) {
323+ func moveTarget (tx * gorm.DB , targetNode nestedItem , targetID int64 , targetIds []int64 , step , depthChange int , newParentID sql.NullInt64 ) (err error ) {
293324 dbNames := targetNode .DbNames
294325
295326 if len (targetIds ) > 0 {
@@ -307,7 +338,7 @@ func moveTarget(tx *gorm.DB, targetNode nodeItem, targetID int64, targetIds []in
307338 return tx .Where (formatSQL (":id = ?" , targetNode ), targetID ).Update (dbNames ["parent_id" ], newParentID ).Error
308339}
309340
310- func moveAffected (tx * gorm.DB , targetNode nodeItem , gte , lte , step int ) (err error ) {
341+ func moveAffected (tx * gorm.DB , targetNode nestedItem , gte , lte , step int ) (err error ) {
311342 dbNames := targetNode .DbNames
312343
313344 return tx .Where (formatSQL ("(:lft BETWEEN ? AND ?) OR (:rgt BETWEEN ? AND ?)" , targetNode ), gte , lte , gte , lte ).
@@ -317,7 +348,7 @@ func moveAffected(tx *gorm.DB, targetNode nodeItem, gte, lte, step int) (err err
317348 }).Error
318349}
319350
320- func formatSQL (placeHolderSQL string , node nodeItem ) (out string ) {
351+ func formatSQL (placeHolderSQL string , node nestedItem ) (out string ) {
321352 out = placeHolderSQL
322353
323354 out = strings .ReplaceAll (out , ":table_name" , node .TableName )
0 commit comments