Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions core/felt/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,29 @@ func (h *TransactionHash) Unmarshal(e []byte) {
func (h *TransactionHash) SetBytesCanonical(data []byte) error {
return (*Hash)(h).SetBytesCanonical(data)
}

type StateRootHash Hash

func (h *StateRootHash) String() string {
return (*Hash)(h).String()
}

func (h *StateRootHash) UnmarshalJSON(data []byte) error {
return (*Hash)(h).UnmarshalJSON(data)
}

func (h *StateRootHash) MarshalJSON() ([]byte, error) {
return (*Hash)(h).MarshalJSON()
}

func (h *StateRootHash) Marshal() []byte {
return (*Hash)(h).Marshal()
}

func (h *StateRootHash) Unmarshal(e []byte) {
(*Hash)(h).Unmarshal(e)
}

func (h *StateRootHash) SetBytesCanonical(data []byte) error {
return (*Hash)(h).SetBytesCanonical(data)
}
16 changes: 13 additions & 3 deletions core/state/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,29 @@ func NewStateDB(disk db.KeyValueStore, triedb database.TrieDB) *StateDB {

// Opens a class trie for the given state root
func (s *StateDB) ClassTrie(stateComm *felt.Felt) (*trie2.Trie, error) {
return trie2.New(trieutils.NewClassTrieID(*stateComm), ClassTrieHeight, crypto.Poseidon, s.triedb)
return trie2.New(
trieutils.NewClassTrieID(felt.StateRootHash(*stateComm)),
ClassTrieHeight,
crypto.Poseidon,
s.triedb,
)
}

// Opens a contract trie for the given state root
func (s *StateDB) ContractTrie(stateComm *felt.Felt) (*trie2.Trie, error) {
return trie2.New(trieutils.NewContractTrieID(*stateComm), ContractTrieHeight, crypto.Pedersen, s.triedb)
return trie2.New(
trieutils.NewContractTrieID(felt.StateRootHash(*stateComm)),
ContractTrieHeight,
crypto.Pedersen,
s.triedb,
)
}

// Opens a contract storage trie for the given state root and contract address
func (s *StateDB) ContractStorageTrie(stateComm, owner *felt.Felt) (*trie2.Trie, error) {
return trie2.New(
trieutils.NewContractStorageTrieID(
*stateComm,
felt.StateRootHash(*stateComm),
felt.Address(*owner),
),
ContractStorageTrieHeight,
Expand Down
8 changes: 7 additions & 1 deletion core/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,13 @@ func (s *State) flush(
p := pool.New().WithMaxGoroutines(runtime.GOMAXPROCS(0)).WithErrors()

p.Go(func() error {
return s.db.triedb.Update(&update.curComm, &update.prevComm, blockNum, update.classNodes, update.contractNodes)
return s.db.triedb.Update(
(*felt.StateRootHash)(&update.curComm),
(*felt.StateRootHash)(&update.prevComm),
blockNum,
update.classNodes,
update.contractNodes,
)
})

batch := s.db.disk.NewBatch()
Expand Down
37 changes: 21 additions & 16 deletions core/trie2/databasetest.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func newTestNodeReader(id trieutils.TrieID, nodes []*trienode.MergeNodeSet, db d
func (n *testNodeReader) Node(
owner *felt.Address,
path *trieutils.Path,
hash *felt.Felt,
hash *felt.Hash,
isLeaf bool,
) ([]byte, error) {
for _, nodes := range n.nodes {
Expand All @@ -42,7 +42,8 @@ func (n *testNodeReader) Node(
continue
}
if _, ok := node.(*trienode.DeletedNode); ok {
return nil, &MissingNodeError{owner: *owner, path: *path, hash: node.Hash()}
hash := node.Hash()
return nil, &MissingNodeError{owner: *owner, path: *path, hash: felt.Hash(hash)}
}
return node.Blob(), nil
}
Expand All @@ -54,7 +55,7 @@ func readNode(
id trieutils.TrieID,
scheme dbScheme,
path *trieutils.Path,
hash *felt.Felt,
hash *felt.Hash,
isLeaf bool,
) ([]byte, error) {
owner := id.Owner()
Expand All @@ -68,20 +69,21 @@ func readNode(
}

type TestNodeDatabase struct {
disk db.KeyValueStore
root felt.Felt
scheme dbScheme
nodes map[felt.Felt]*trienode.MergeNodeSet
rootLinks map[felt.Felt]felt.Felt // map[child_root]parent_root - keep track of the parent root for each child root
disk db.KeyValueStore
root felt.StateRootHash
scheme dbScheme
nodes map[felt.StateRootHash]*trienode.MergeNodeSet
// map[child_root]parent_root - keep track of the parent root for each child root
rootLinks map[felt.StateRootHash]felt.StateRootHash
}

func NewTestNodeDatabase(disk db.KeyValueStore, scheme dbScheme) TestNodeDatabase {
return TestNodeDatabase{
disk: disk,
root: felt.Zero,
root: felt.StateRootHash{},
scheme: scheme,
nodes: make(map[felt.Felt]*trienode.MergeNodeSet),
rootLinks: make(map[felt.Felt]felt.Felt),
nodes: make(map[felt.StateRootHash]*trienode.MergeNodeSet),
rootLinks: make(map[felt.StateRootHash]felt.StateRootHash),
}
}

Expand All @@ -90,8 +92,8 @@ func (d *TestNodeDatabase) Update(root, parent *felt.Felt, nodes *trienode.Merge
return nil
}

rootVal := *root
parentVal := *parent
rootVal := felt.StateRootHash(*root)
parentVal := felt.StateRootHash(*parent)

if _, ok := d.nodes[rootVal]; ok { // already exists
return nil
Expand All @@ -109,10 +111,13 @@ func (d *TestNodeDatabase) NodeReader(id trieutils.TrieID) (database.NodeReader,
return newTestNodeReader(id, nodes, d.disk, d.scheme), nil
}

func (d *TestNodeDatabase) dirties(root *felt.Felt, newerFirst bool) ([]*trienode.MergeNodeSet, []felt.Felt) {
func (d *TestNodeDatabase) dirties(
root *felt.StateRootHash,
newerFirst bool,
) ([]*trienode.MergeNodeSet, []felt.StateRootHash) {
var (
pending []*trienode.MergeNodeSet
roots []felt.Felt
roots []felt.StateRootHash
)

rootVal := *root
Expand All @@ -128,7 +133,7 @@ func (d *TestNodeDatabase) dirties(root *felt.Felt, newerFirst bool) ([]*trienod
roots = append(roots, rootVal)
} else {
pending = append([]*trienode.MergeNodeSet{nodes}, pending...)
roots = append([]felt.Felt{rootVal}, roots...)
roots = append([]felt.StateRootHash{rootVal}, roots...)
}

rootVal = d.rootLinks[rootVal]
Expand Down
2 changes: 1 addition & 1 deletion core/trie2/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type MissingNodeError struct {
tt trieutils.TrieType
owner felt.Address
path trieutils.Path
hash felt.Felt
hash felt.Hash
err error
}

Expand Down
4 changes: 2 additions & 2 deletions core/trie2/node_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func newNodeReader(id trieutils.TrieID, nodeDB database.NodeDatabase) (nodeReade
return nodeReader{id: id, reader: reader}, nil
}

func (r *nodeReader) node(path trieutils.Path, hash *felt.Felt, isLeaf bool) ([]byte, error) {
func (r *nodeReader) node(path trieutils.Path, hash *felt.Hash, isLeaf bool) ([]byte, error) {
if r.reader == nil {
return nil, &MissingNodeError{tt: r.id.Type(), owner: r.id.Owner(), path: path, hash: *hash}
}
Expand All @@ -28,5 +28,5 @@ func (r *nodeReader) node(path trieutils.Path, hash *felt.Felt, isLeaf bool) ([]
}

func NewEmptyNodeReader() nodeReader {
return nodeReader{id: trieutils.NewEmptyTrieID(felt.Zero), reader: nil}
return nodeReader{id: trieutils.NewEmptyTrieID(felt.StateRootHash{}), reader: nil}
}
36 changes: 28 additions & 8 deletions core/trie2/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func New(
}

stateComm := id.StateComm()
if stateComm.IsZero() {
if felt.IsZero(&stateComm) {
return tr, nil
}

Expand Down Expand Up @@ -107,7 +107,7 @@ func NewFromRootHash(
}

stateComm := id.StateComm()
if stateComm.IsZero() {
if felt.IsZero(&stateComm) {
return tr, nil
}

Expand Down Expand Up @@ -533,7 +533,7 @@ func (t *Trie) resolveNode(hn *trienode.HashNode, path Path) (trienode.Node, err
hash = felt.Felt(*hn)
}

blob, err := t.nodeReader.node(path, &hash, path.Len() == t.height)
blob, err := t.nodeReader.node(path, (*felt.Hash)(&hash), path.Len() == t.height)
if err != nil {
return nil, err
}
Expand All @@ -544,7 +544,7 @@ func (t *Trie) resolveNode(hn *trienode.HashNode, path Path) (trienode.Node, err
// Resolves the node at the given path from the database
func (t *Trie) resolveNodeWithHash(path *Path, hash *felt.Felt) (trienode.Node, error) {
isLeaf := path.Len() == t.height
blob, err := t.nodeReader.node(*path, hash, isLeaf)
blob, err := t.nodeReader.node(*path, (*felt.Hash)(hash), isLeaf)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -580,23 +580,43 @@ func (t *Trie) String() string {
}

func NewEmptyPedersen() (*Trie, error) {
return New(trieutils.NewEmptyTrieID(felt.Zero), contractClassTrieHeight, crypto.Pedersen, triedb.NewEmptyNodeDatabase())
return New(
trieutils.NewEmptyTrieID(felt.StateRootHash{}),
contractClassTrieHeight,
crypto.Pedersen,
triedb.NewEmptyNodeDatabase(),
)
}

func NewEmptyPoseidon() (*Trie, error) {
return New(trieutils.NewEmptyTrieID(felt.Zero), contractClassTrieHeight, crypto.Poseidon, triedb.NewEmptyNodeDatabase())
return New(
trieutils.NewEmptyTrieID(felt.StateRootHash{}),
contractClassTrieHeight,
crypto.Poseidon,
triedb.NewEmptyNodeDatabase(),
)
}

func RunOnTempTriePedersen(height uint8, do func(*Trie) error) error {
trie, err := New(trieutils.NewEmptyTrieID(felt.Zero), height, crypto.Pedersen, triedb.NewEmptyNodeDatabase())
trie, err := New(
trieutils.NewEmptyTrieID(felt.StateRootHash{}),
height,
crypto.Pedersen,
triedb.NewEmptyNodeDatabase(),
)
if err != nil {
return err
}
return do(trie)
}

func RunOnTempTriePoseidon(height uint8, do func(*Trie) error) error {
trie, err := New(trieutils.NewEmptyTrieID(felt.Zero), height, crypto.Poseidon, triedb.NewEmptyNodeDatabase())
trie, err := New(
trieutils.NewEmptyTrieID(felt.StateRootHash{}),
height,
crypto.Poseidon,
triedb.NewEmptyNodeDatabase(),
)
if err != nil {
return err
}
Expand Down
14 changes: 12 additions & 2 deletions core/trie2/trie_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,12 @@ func runRandTest(rt randTest) error {
db := memory.New()
curRoot := felt.Zero
trieDB := NewTestNodeDatabase(db, scheme)
tr, err := New(trieutils.NewContractTrieID(curRoot), contractClassTrieHeight, crypto.Pedersen, &trieDB)
tr, err := New(
trieutils.NewContractTrieID(felt.StateRootHash(curRoot)),
contractClassTrieHeight,
crypto.Pedersen,
&trieDB,
)
if err != nil {
return err
}
Expand Down Expand Up @@ -345,7 +350,12 @@ func runRandTest(rt randTest) error {
}
}

newtr, err := New(trieutils.NewContractTrieID(root), contractClassTrieHeight, crypto.Pedersen, &trieDB)
newtr, err := New(
trieutils.NewContractTrieID(felt.StateRootHash(root)),
contractClassTrieHeight,
crypto.Pedersen,
&trieDB,
)
if err != nil {
rt[i].err = fmt.Errorf("new trie failed: %w", err)
}
Expand Down
6 changes: 3 additions & 3 deletions core/trie2/triedb/database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ const (

// Represents a reader for trie nodes
type NodeReader interface {
Node(owner *felt.Address, path *trieutils.Path, hash *felt.Felt, isLeaf bool) ([]byte, error)
Node(owner *felt.Address, path *trieutils.Path, hash *felt.Hash, isLeaf bool) ([]byte, error)
}

// Represents a database that produces a node reader for a given trie id
Expand All @@ -37,10 +37,10 @@ type TrieDB interface {
NodeIterator
io.Closer

Commit(stateComm *felt.Felt) error
Commit(stateComm *felt.StateRootHash) error
Update(
root,
parent *felt.Felt,
parent *felt.StateRootHash,
blockNum uint64,
mergeClassNodes,
mergeContractNodes *trienode.MergeNodeSet,
Expand Down
2 changes: 1 addition & 1 deletion core/trie2/triedb/empty.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type EmptyNodeReader struct{}
func (EmptyNodeReader) Node(
owner *felt.Address,
path *trieutils.Path,
hash *felt.Felt,
hash *felt.Hash,
isLeaf bool,
) ([]byte, error) {
return nil, nil
Expand Down
4 changes: 2 additions & 2 deletions core/trie2/triedb/hashdb/clean_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ func newCleanCache(size uint64) cleanCache {
}
}

func (c *cleanCache) getNode(path *trieutils.Path, hash *felt.Felt) []byte {
func (c *cleanCache) getNode(path *trieutils.Path, hash *felt.Hash) []byte {
key := nodeKey(path, hash)
value := c.cache.Get(nil, key)

return value
}

func (c *cleanCache) putNode(path *trieutils.Path, hash *felt.Felt, value []byte) {
func (c *cleanCache) putNode(path *trieutils.Path, hash *felt.Hash, value []byte) {
key := nodeKey(path, hash)
c.cache.Set(key, value)
}
Loading
Loading