diff --git a/doc.go b/doc.go index 8e8d1f04..30780327 100644 --- a/doc.go +++ b/doc.go @@ -34,6 +34,7 @@ var ( errSerializeHashedNode = errors.New("trying to serialize a hashed internal node") errInsertIntoOtherStem = errors.New("insert splits a stem where it should not happen") errUnknownNodeType = errors.New("unknown node type detected") + errNilNodeType = errors.New("nil node type detected") errMissingNodeInStateless = errors.New("trying to access a node that is missing from the stateless view") errIsPOAStub = errors.New("trying to read/write a proof of absence leaf node") ) diff --git a/tree.go b/tree.go index 44d61c85..0c17dbd2 100644 --- a/tree.go +++ b/tree.go @@ -329,9 +329,14 @@ func (n *InternalNode) Children() []VerkleNode { // SetChild *replaces* the child at the given index with the given node. func (n *InternalNode) SetChild(i int, c VerkleNode) error { + if c == nil { + return errNilNodeType + } + if i >= NodeWidth { return errors.New("child index higher than node width") } + n.children[i] = c return nil } @@ -608,8 +613,14 @@ func (n *InternalNode) Delete(key []byte, resolver NodeResolverFn) (bool, error) // signal that this node should be deleted // as well. for _, c := range n.children { - if _, ok := c.(Empty); !ok { + switch c.(type) { + case *InternalNode, *LeafNode: break + case Empty, HashedNode: + case UnknownNode: + panic(errUnknownNodeType) + default: + panic(errNilNodeType) } } @@ -639,14 +650,20 @@ func (n *InternalNode) Flush(flush NodeFlushFn) { n.Commit() for i, child := range n.children { - if c, ok := child.(*InternalNode); ok { + switch c := child.(type) { + case *InternalNode: c.Commit() c.Flush(flushAndCapturePath) n.children[i] = HashedNode{} - } else if c, ok := child.(*LeafNode); ok { + case *LeafNode: c.Commit() - flushAndCapturePath(c.stem[:n.depth+1], n.children[i]) + flushAndCapturePath(c.stem[:n.depth+1], c) n.children[i] = HashedNode{} + case Empty, HashedNode: + case UnknownNode: + panic(errUnknownNodeType) + default: + panic(errNilNodeType) } } flush(path, n) @@ -875,33 +892,32 @@ func (n *InternalNode) GetProofItems(keys keylist, resolver NodeResolverFn) (*Pr var fiPtrs [NodeWidth]*Fr var points [NodeWidth]*Point for i, child := range n.children { + var c VerkleNode + fiPtrs[i] = &fi[i] - if child != nil { - var c VerkleNode - if _, ok := child.(HashedNode); ok { - childpath := make([]byte, n.depth+1) - copy(childpath[:n.depth+1], keys[0][:n.depth]) - childpath[n.depth] = byte(i) - if resolver == nil { - return nil, nil, nil, fmt.Errorf("no resolver for path %x", childpath) - } - serialized, err := resolver(childpath) - if err != nil { - return nil, nil, nil, fmt.Errorf("error resolving for path %x: %w", childpath, err) - } - c, err = ParseNode(serialized, n.depth+1) - if err != nil { - return nil, nil, nil, err - } - n.children[i] = c - } else { - c = child + switch child := child.(type) { + case HashedNode: + childpath := make([]byte, n.depth+1) + copy(childpath[:n.depth+1], keys[0][:n.depth]) + childpath[n.depth] = byte(i) + if resolver == nil { + return nil, nil, nil, fmt.Errorf("no resolver for path %x", childpath) } - points[i] = c.Commitment() - } else { - // TODO: add a test case to cover this scenario. - points[i] = new(Point) + serialized, err := resolver(childpath) + if err != nil { + return nil, nil, nil, fmt.Errorf("error resolving for path %x: %w", childpath, err) + } + c, err = ParseNode(serialized, n.depth+1) + if err != nil { + return nil, nil, nil, err + } + n.children[i] = c + case *InternalNode, *LeafNode, Empty, UnknownNode: + c = child + default: + panic(errNilNodeType) } + points[i] = c.Commitment() } if err := banderwagon.BatchMapToScalarField(fiPtrs[:], points[:]); err != nil { return nil, nil, nil, fmt.Errorf("batch mapping to scalar fields: %s", err) @@ -972,8 +988,14 @@ func (n *InternalNode) Serialize() ([]byte, error) { // Write the . bitlist := ret[internalBitlistOffset:internalCommitmentOffset] for i, c := range n.children { - if _, ok := c.(Empty); !ok { + switch c.(type) { + case *InternalNode, *LeafNode: setBit(bitlist, i) + case Empty, HashedNode: + case UnknownNode: + panic(errUnknownNodeType) + default: + panic(errNilNodeType) } } @@ -995,6 +1017,9 @@ func (n *InternalNode) Copy() VerkleNode { } for i, child := range n.children { + if child == nil { + panic(errNilNodeType) + } ret.children[i] = child.Copy() } @@ -1024,7 +1049,7 @@ func (n *InternalNode) toDot(parent, path string) string { for i, child := range n.children { if child == nil { - continue + panic(errNilNodeType) } ret = fmt.Sprintf("%s%s", ret, child.toDot(me, fmt.Sprintf("%s%02x", path, i))) } @@ -1719,6 +1744,11 @@ func (n *InternalNode) collectNonHashedNodes(list []VerkleNode, paths [][]byte, copy(childpath, path) childpath[len(path)] = byte(i) list, paths = childNode.collectNonHashedNodes(list, paths, childpath) + case Empty, HashedNode: + case UnknownNode: + panic(errUnknownNodeType) + default: + panic(errNilNodeType) } } return list, paths @@ -1729,8 +1759,14 @@ func (n *InternalNode) serializeInternalWithUncompressedCommitment(pointsIdx map serialized := make([]byte, nodeTypeSize+bitlistSize+banderwagon.UncompressedSize) bitlist := serialized[internalBitlistOffset:internalCommitmentOffset] for i, c := range n.children { - if _, ok := c.(Empty); !ok { + switch c.(type) { + case *InternalNode, *LeafNode: setBit(bitlist, i) + case Empty, HashedNode: + case UnknownNode: + panic(errUnknownNodeType) + default: + panic(errNilNodeType) } } serialized[nodeTypeOffset] = internalRLPType