Skip to content

Commit 1033f63

Browse files
authored
Merge pull request #12 from calebx/main
build Delete() for delete a node and all its descendant
2 parents 43484f0 + 9e32fb7 commit 1033f63

File tree

2 files changed

+104
-53
lines changed

2 files changed

+104
-53
lines changed

nested_set.go

Lines changed: 84 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
package nestedset
22

33
import (
4-
"context"
54
"database/sql"
65
"fmt"
76
"reflect"
87
"strings"
8+
"sync"
99

1010
"gorm.io/gorm"
11-
"gorm.io/gorm/clause"
11+
"gorm.io/gorm/schema"
1212
)
1313

1414
// MoveDirection means where the node is going to be located
@@ -25,7 +25,7 @@ const (
2525
MoveDirectionInner MoveDirection = 0
2626
)
2727

28-
type nodeItem struct {
28+
type nestedItem struct {
2929
ID int64
3030
ParentID sql.NullInt64
3131
Depth int
@@ -36,31 +36,25 @@ type nodeItem struct {
3636
DbNames map[string]string
3737
}
3838

39-
// parseNode parse a gorm structure into an internal source structure
40-
// for bring in all required data attribute like scope, left, righ etc.
41-
func parseNode(db *gorm.DB, source interface{}) (tx *gorm.DB, item nodeItem, err error) {
42-
tx = db
43-
stmt := &gorm.Statement{
44-
DB: tx,
45-
ConnPool: tx.ConnPool,
46-
Context: context.Background(),
47-
Clauses: map[string]clause.Clause{},
48-
}
49-
50-
err = stmt.Parse(source)
39+
// parseNode parse a gorm struct into an internal nested item struct
40+
// bring in all required data attribute like scope, left, righ etc.
41+
func parseNode(db *gorm.DB, source interface{}) (tx *gorm.DB, item nestedItem, err error) {
42+
scm, err := schema.Parse(source, &sync.Map{}, schema.NamingStrategy{})
5143
if err != nil {
5244
err = fmt.Errorf("Invalid source, must be a valid Gorm Model instance, %v", source)
5345
return
5446
}
5547

56-
item = nodeItem{TableName: stmt.Table, DbNames: map[string]string{}}
48+
tx = db.Table(scm.Table)
49+
50+
item = nestedItem{TableName: scm.Table, DbNames: map[string]string{}}
5751
sourceValue := reflect.Indirect(reflect.ValueOf(source))
5852
sourceType := sourceValue.Type()
5953
for i := 0; i < sourceType.NumField(); i++ {
6054
t := sourceType.Field(i)
6155
v := sourceValue.Field(i)
6256

63-
schemaField := stmt.Schema.LookUpField(t.Name)
57+
schemaField := scm.LookUpField(t.Name)
6458
dbName := schemaField.DBName
6559

6660
switch t.Tag.Get("nestedset") {
@@ -98,27 +92,26 @@ func parseNode(db *gorm.DB, source interface{}) (tx *gorm.DB, item nodeItem, err
9892
return
9993
}
10094

101-
// Create a new node by parent with Gorm original Create()
102-
// ```nestedset.Create(db, &Category{...}, nil)```` will create a new category in root level
95+
// Create a new node within its parent by Gorm original Create() method
96+
// ```nestedset.Create(db, &Category{...}, nil)``` will create a new category in root level
10397
// ```nestedset.Create(db, &Category{...}, &parent)``` will create a new category under parent node as its last child
10498
func Create(db *gorm.DB, source, parent interface{}) error {
105-
return db.Transaction(func(db *gorm.DB) (err error) {
106-
tx, target, err := parseNode(db, source)
107-
if err != nil {
108-
return err
109-
}
99+
tx, target, err := parseNode(db, source)
100+
if err != nil {
101+
return err
102+
}
110103

111-
// for totally blank table / scope default init root would be [1 - 2]
112-
setDepth, setToLft, setToRgt := 0, 1, 2
113-
tableName, dbNames := target.TableName, target.DbNames
104+
// for totally blank table / scope default init root would be [1 - 2]
105+
setToDepth, setToLft, setToRgt := 0, 1, 2
106+
dbNames := target.DbNames
114107

115-
// put node into root level when parent is nil
108+
return tx.Transaction(func(tx *gorm.DB) (err error) {
109+
// create node in root level when parent is nil
116110
if parent == nil {
117111
lastNode := make(map[string]interface{})
118-
orderSQL := formatSQL(":rgt desc", target)
119-
rst := tx.Model(source).Select(dbNames["rgt"]).Order(orderSQL).First(&lastNode)
112+
rst := tx.Select(dbNames["rgt"]).Order(formatSQL(":rgt DESC", target)).Take(&lastNode)
120113
if rst.Error == nil {
121-
setToLft = lastNode[dbNames["rgt"]].(int) + 1
114+
setToLft = int(lastNode[dbNames["rgt"]].(int64) + 1)
122115
setToRgt = setToLft + 1
123116
}
124117
} else {
@@ -129,36 +122,31 @@ func Create(db *gorm.DB, source, parent interface{}) error {
129122

130123
setToLft = targetParent.Rgt
131124
setToRgt = targetParent.Rgt + 1
132-
setDepth = targetParent.Depth + 1
125+
setToDepth = targetParent.Depth + 1
133126

134127
// UPDATE tree SET rgt = rgt + 2 WHERE rgt >= new_lft;
135-
err = tx.Table(tableName).
136-
Where(formatSQL(":rgt >= ?", target), setToLft).
137-
UpdateColumn(dbNames["rgt"], gorm.Expr(formatSQL(":rgt + 2", target))).
138-
Error
128+
err = tx.Where(formatSQL(":rgt >= ?", target), setToLft).
129+
UpdateColumn(dbNames["rgt"], gorm.Expr(formatSQL(":rgt + 2", target))).Error
139130
if err != nil {
140131
return err
141132
}
142133

143134
// UPDATE tree SET lft = lft + 2 WHERE lft > new_lft;
144-
err = tx.Table(tableName).
145-
Where(formatSQL(":lft > ?", target), setToLft).
146-
UpdateColumn(dbNames["lft"], gorm.Expr(formatSQL(":lft + 2", target))).
147-
Error
135+
err = tx.Where(formatSQL(":lft > ?", target), setToLft).
136+
UpdateColumn(dbNames["lft"], gorm.Expr(formatSQL(":lft + 2", target))).Error
148137
if err != nil {
149138
return err
150139
}
151140

152-
// UPDATE tree SET children_count = children_count + 1 WHERE is = parent.id;
153-
err = db.Model(parent).Update(
154-
dbNames["children_count"], gorm.Expr(formatSQL(":children_count + 1", target)),
155-
).Error
141+
// UPDATE tree SET children_count = children_count + 1 WHERE id = parent.id;
142+
err = tx.Model(parent).Update(
143+
dbNames["children_count"], gorm.Expr(formatSQL(":children_count + 1", target))).Error
156144
if err != nil {
157145
return err
158146
}
159147
}
160148

161-
// Set Lft, Rgt, Depth dynamically by refect
149+
// Set Lft, Rgt, Depth dynamically
162150
v := reflect.Indirect(reflect.ValueOf(source))
163151
t := v.Type()
164152
for i := 0; i < t.NumField(); i++ {
@@ -174,7 +162,7 @@ func Create(db *gorm.DB, source, parent interface{}) error {
174162
break
175163
case "depth":
176164
f := v.FieldByName(f.Name)
177-
f.SetInt(int64(setDepth))
165+
f.SetInt(int64(setToDepth))
178166
break
179167
}
180168
}
@@ -184,6 +172,51 @@ func Create(db *gorm.DB, source, parent interface{}) error {
184172
})
185173
}
186174

175+
// Delete a node from scoped list and its all descendent
176+
// ```nestedset.Delete(db, &Category{...})```
177+
func Delete(db *gorm.DB, source interface{}) error {
178+
tx, target, err := parseNode(db, source)
179+
if err != nil {
180+
return err
181+
}
182+
183+
// Batch Delete Method in GORM requires an instance of current source type without ID
184+
// to avoid GORM style Delete interface, we hacked here by set source ID to 0
185+
dbNames := target.DbNames
186+
v := reflect.Indirect(reflect.ValueOf(source))
187+
t := v.Type()
188+
for i := 0; i < t.NumField(); i++ {
189+
f := t.Field(i)
190+
if f.Tag.Get("nestedset") == "id" {
191+
f := v.FieldByName(f.Name)
192+
f.SetInt(0)
193+
break
194+
}
195+
}
196+
197+
return tx.Transaction(func(tx *gorm.DB) (err error) {
198+
err = tx.Where(formatSQL(":lft >= ? AND :rgt <= ?", target), target.Lft, target.Rgt).
199+
Delete(source).Error
200+
if err != nil {
201+
return err
202+
}
203+
204+
// UPDATE tree SET rgt = rgt - width WHERE rgt > target_rgt;
205+
// UPDATE tree SET lft = lft - width WHERE lft > target_rgt;
206+
width := target.Rgt - target.Lft + 1
207+
for _, d := range []string{"rgt", "lft"} {
208+
err = tx.Where(formatSQL(":"+d+" > ?", target), target.Rgt).
209+
Update(dbNames[d], gorm.Expr(formatSQL(":"+d+" - ?", target), width)).
210+
Error
211+
if err != nil {
212+
return err
213+
}
214+
}
215+
216+
return nil
217+
})
218+
}
219+
187220
// MoveTo move node to a position which is related a target node
188221
// ```nestedset.MoveTo(db, &node, &to, nestedset.MoveDirectionInner)``` will move [&node] to [&to] node's child_list as its first child
189222
func MoveTo(db *gorm.DB, node, to interface{}, direction MoveDirection) error {
@@ -197,8 +230,6 @@ func MoveTo(db *gorm.DB, node, to interface{}, direction MoveDirection) error {
197230
return err
198231
}
199232

200-
tx = db.Table(targetNode.TableName)
201-
202233
var right, depthChange int
203234
var newParentID sql.NullInt64
204235
if direction == MoveDirectionLeft || direction == MoveDirectionRight {
@@ -218,7 +249,7 @@ func MoveTo(db *gorm.DB, node, to interface{}, direction MoveDirection) error {
218249
return moveToRightOfPosition(tx, targetNode, right, depthChange, newParentID)
219250
}
220251

221-
func moveToRightOfPosition(tx *gorm.DB, targetNode nodeItem, position, depthChange int, newParentID sql.NullInt64) error {
252+
func moveToRightOfPosition(tx *gorm.DB, targetNode nestedItem, position, depthChange int, newParentID sql.NullInt64) error {
222253
return tx.Transaction(func(tx *gorm.DB) (err error) {
223254
oldParentID := targetNode.ParentID
224255
targetRight := targetNode.Rgt
@@ -261,7 +292,7 @@ func moveToRightOfPosition(tx *gorm.DB, targetNode nodeItem, position, depthChan
261292
})
262293
}
263294

264-
func syncChildrenCount(tx *gorm.DB, targetNode nodeItem, oldParentID, newParentID sql.NullInt64) (err error) {
295+
func syncChildrenCount(tx *gorm.DB, targetNode nestedItem, oldParentID, newParentID sql.NullInt64) (err error) {
265296
var oldParentCount, newParentCount int64
266297

267298
if oldParentID.Valid {
@@ -289,7 +320,7 @@ func syncChildrenCount(tx *gorm.DB, targetNode nodeItem, oldParentID, newParentI
289320
return nil
290321
}
291322

292-
func moveTarget(tx *gorm.DB, targetNode nodeItem, targetID int64, targetIds []int64, step, depthChange int, newParentID sql.NullInt64) (err error) {
323+
func moveTarget(tx *gorm.DB, targetNode nestedItem, targetID int64, targetIds []int64, step, depthChange int, newParentID sql.NullInt64) (err error) {
293324
dbNames := targetNode.DbNames
294325

295326
if len(targetIds) > 0 {
@@ -307,7 +338,7 @@ func moveTarget(tx *gorm.DB, targetNode nodeItem, targetID int64, targetIds []in
307338
return tx.Where(formatSQL(":id = ?", targetNode), targetID).Update(dbNames["parent_id"], newParentID).Error
308339
}
309340

310-
func moveAffected(tx *gorm.DB, targetNode nodeItem, gte, lte, step int) (err error) {
341+
func moveAffected(tx *gorm.DB, targetNode nestedItem, gte, lte, step int) (err error) {
311342
dbNames := targetNode.DbNames
312343

313344
return tx.Where(formatSQL("(:lft BETWEEN ? AND ?) OR (:rgt BETWEEN ? AND ?)", targetNode), gte, lte, gte, lte).
@@ -317,7 +348,7 @@ func moveAffected(tx *gorm.DB, targetNode nodeItem, gte, lte, step int) (err err
317348
}).Error
318349
}
319350

320-
func formatSQL(placeHolderSQL string, node nodeItem) (out string) {
351+
func formatSQL(placeHolderSQL string, node nestedItem) (out string) {
321352
out = placeHolderSQL
322353

323354
out = strings.ReplaceAll(out, ":table_name", node.TableName)

nested_set_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,26 @@ func TestCreateSource(t *testing.T) {
130130
assert.Equal(t, c2.ChildrenCount, 1)
131131
}
132132

133+
func TestDeleteSource(t *testing.T) {
134+
initData()
135+
136+
c1 := Category{Title: "c1s"}
137+
Create(db, &c1, nil)
138+
139+
cp := Category{Title: "cp"}
140+
Create(db, &cp, c1)
141+
142+
c2 := Category{Title: "c2s"}
143+
Create(db, &c2, nil)
144+
145+
db.First(&c1)
146+
Delete(db, &c1)
147+
148+
db.Model(&c2).First(&c2)
149+
assert.Equal(t, c2.Lft, 1)
150+
assert.Equal(t, c2.Rgt, 2)
151+
}
152+
133153
func TestMoveToRight(t *testing.T) {
134154
// case 1
135155
initData()

0 commit comments

Comments
 (0)