diff options
Diffstat (limited to 'swarm/bmt/bmt_test.go')
-rw-r--r-- | swarm/bmt/bmt_test.go | 122 |
1 files changed, 98 insertions, 24 deletions
diff --git a/swarm/bmt/bmt_test.go b/swarm/bmt/bmt_test.go index e074d90e7..ae40eadab 100644 --- a/swarm/bmt/bmt_test.go +++ b/swarm/bmt/bmt_test.go @@ -34,12 +34,12 @@ import ( // the actual data length generated (could be longer than max datalength of the BMT) const BufferSize = 4128 +var counts = []int{1, 2, 3, 4, 5, 8, 9, 15, 16, 17, 32, 37, 42, 53, 63, 64, 65, 111, 127, 128} + +// calculates the Keccak256 SHA3 hash of the data func sha3hash(data ...[]byte) []byte { h := sha3.NewKeccak256() - for _, v := range data { - h.Write(v) - } - return h.Sum(nil) + return doHash(h, nil, data...) } // TestRefHasher tests that the RefHasher computes the expected BMT hash for @@ -129,31 +129,48 @@ func TestRefHasher(t *testing.T) { } } -func TestHasherCorrectness(t *testing.T) { - err := testHasher(testBaseHasher) - if err != nil { - t.Fatal(err) +// tests if hasher responds with correct hash +func TestHasherEmptyData(t *testing.T) { + hasher := sha3.NewKeccak256 + var data []byte + for _, count := range counts { + t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) { + pool := NewTreePool(hasher, count, PoolSize) + defer pool.Drain(0) + bmt := New(pool) + rbmt := NewRefHasher(hasher, count) + refHash := rbmt.Hash(data) + expHash := Hash(bmt, nil, data) + if !bytes.Equal(expHash, refHash) { + t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash) + } + }) } } -func testHasher(f func(BaseHasherFunc, []byte, int, int) error) error { +func TestHasherCorrectness(t *testing.T) { data := newData(BufferSize) hasher := sha3.NewKeccak256 size := hasher().Size() - counts := []int{1, 2, 3, 4, 5, 8, 16, 32, 64, 128} var err error for _, count := range counts { - max := count * size - incr := 1 - for n := 1; n <= max; n += incr { - err = f(hasher, data, n, count) - if err != nil { - return err + t.Run(fmt.Sprintf("segments_%v", count), func(t *testing.T) { + max := count * size + incr := 1 + capacity := 1 + pool := NewTreePool(hasher, count, capacity) + defer pool.Drain(0) + for n := 0; n <= max; n += incr { + incr = 1 + rand.Intn(5) + bmt := New(pool) + err = testHasherCorrectness(bmt, hasher, data, n, count) + if err != nil { + t.Fatal(err) + } } - } + }) } - return nil } // Tests that the BMT hasher can be synchronously reused with poolsizes 1 and PoolSize @@ -215,12 +232,69 @@ LOOP: } } -// helper function that creates a tree pool -func testBaseHasher(hasher BaseHasherFunc, d []byte, n, count int) error { - pool := NewTreePool(hasher, count, 1) - defer pool.Drain(0) - bmt := New(pool) - return testHasherCorrectness(bmt, hasher, d, n, count) +// Tests BMT Hasher io.Writer interface is working correctly +// even multiple short random write buffers +func TestBMTHasherWriterBuffers(t *testing.T) { + hasher := sha3.NewKeccak256 + + for _, count := range counts { + t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) { + errc := make(chan error) + pool := NewTreePool(hasher, count, PoolSize) + defer pool.Drain(0) + n := count * 32 + bmt := New(pool) + data := newData(n) + rbmt := NewRefHasher(hasher, count) + refHash := rbmt.Hash(data) + expHash := Hash(bmt, nil, data) + if !bytes.Equal(expHash, refHash) { + t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash) + } + attempts := 10 + f := func() error { + bmt := New(pool) + bmt.Reset() + var buflen int + for offset := 0; offset < n; offset += buflen { + buflen = rand.Intn(n-offset) + 1 + read, err := bmt.Write(data[offset : offset+buflen]) + if err != nil { + return err + } + if read != buflen { + return fmt.Errorf("incorrect read. expected %v bytes, got %v", buflen, read) + } + } + hash := bmt.Sum(nil) + if !bytes.Equal(hash, expHash) { + return fmt.Errorf("hash mismatch. expected %x, got %x", hash, expHash) + } + return nil + } + + for j := 0; j < attempts; j++ { + go func() { + errc <- f() + }() + } + timeout := time.NewTimer(2 * time.Second) + for { + select { + case err := <-errc: + if err != nil { + t.Fatal(err) + } + attempts-- + if attempts == 0 { + return + } + case <-timeout.C: + t.Fatalf("timeout") + } + } + }) + } } // helper function that compares reference and optimised implementations on |