aboutsummaryrefslogtreecommitdiffstats
path: root/swarm/bmt/bmt_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'swarm/bmt/bmt_test.go')
-rw-r--r--swarm/bmt/bmt_test.go122
1 files changed, 98 insertions, 24 deletions
diff --git a/swarm/bmt/bmt_test.go b/swarm/bmt/bmt_test.go
index e074d90e7..ae40eadab 100644
--- a/swarm/bmt/bmt_test.go
+++ b/swarm/bmt/bmt_test.go
@@ -34,12 +34,12 @@ import (
// the actual data length generated (could be longer than max datalength of the BMT)
const BufferSize = 4128
+var counts = []int{1, 2, 3, 4, 5, 8, 9, 15, 16, 17, 32, 37, 42, 53, 63, 64, 65, 111, 127, 128}
+
+// calculates the Keccak256 SHA3 hash of the data
func sha3hash(data ...[]byte) []byte {
h := sha3.NewKeccak256()
- for _, v := range data {
- h.Write(v)
- }
- return h.Sum(nil)
+ return doHash(h, nil, data...)
}
// TestRefHasher tests that the RefHasher computes the expected BMT hash for
@@ -129,31 +129,48 @@ func TestRefHasher(t *testing.T) {
}
}
-func TestHasherCorrectness(t *testing.T) {
- err := testHasher(testBaseHasher)
- if err != nil {
- t.Fatal(err)
+// tests if hasher responds with correct hash
+func TestHasherEmptyData(t *testing.T) {
+ hasher := sha3.NewKeccak256
+ var data []byte
+ for _, count := range counts {
+ t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) {
+ pool := NewTreePool(hasher, count, PoolSize)
+ defer pool.Drain(0)
+ bmt := New(pool)
+ rbmt := NewRefHasher(hasher, count)
+ refHash := rbmt.Hash(data)
+ expHash := Hash(bmt, nil, data)
+ if !bytes.Equal(expHash, refHash) {
+ t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash)
+ }
+ })
}
}
-func testHasher(f func(BaseHasherFunc, []byte, int, int) error) error {
+func TestHasherCorrectness(t *testing.T) {
data := newData(BufferSize)
hasher := sha3.NewKeccak256
size := hasher().Size()
- counts := []int{1, 2, 3, 4, 5, 8, 16, 32, 64, 128}
var err error
for _, count := range counts {
- max := count * size
- incr := 1
- for n := 1; n <= max; n += incr {
- err = f(hasher, data, n, count)
- if err != nil {
- return err
+ t.Run(fmt.Sprintf("segments_%v", count), func(t *testing.T) {
+ max := count * size
+ incr := 1
+ capacity := 1
+ pool := NewTreePool(hasher, count, capacity)
+ defer pool.Drain(0)
+ for n := 0; n <= max; n += incr {
+ incr = 1 + rand.Intn(5)
+ bmt := New(pool)
+ err = testHasherCorrectness(bmt, hasher, data, n, count)
+ if err != nil {
+ t.Fatal(err)
+ }
}
- }
+ })
}
- return nil
}
// Tests that the BMT hasher can be synchronously reused with poolsizes 1 and PoolSize
@@ -215,12 +232,69 @@ LOOP:
}
}
-// helper function that creates a tree pool
-func testBaseHasher(hasher BaseHasherFunc, d []byte, n, count int) error {
- pool := NewTreePool(hasher, count, 1)
- defer pool.Drain(0)
- bmt := New(pool)
- return testHasherCorrectness(bmt, hasher, d, n, count)
+// Tests BMT Hasher io.Writer interface is working correctly
+// even multiple short random write buffers
+func TestBMTHasherWriterBuffers(t *testing.T) {
+ hasher := sha3.NewKeccak256
+
+ for _, count := range counts {
+ t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) {
+ errc := make(chan error)
+ pool := NewTreePool(hasher, count, PoolSize)
+ defer pool.Drain(0)
+ n := count * 32
+ bmt := New(pool)
+ data := newData(n)
+ rbmt := NewRefHasher(hasher, count)
+ refHash := rbmt.Hash(data)
+ expHash := Hash(bmt, nil, data)
+ if !bytes.Equal(expHash, refHash) {
+ t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash)
+ }
+ attempts := 10
+ f := func() error {
+ bmt := New(pool)
+ bmt.Reset()
+ var buflen int
+ for offset := 0; offset < n; offset += buflen {
+ buflen = rand.Intn(n-offset) + 1
+ read, err := bmt.Write(data[offset : offset+buflen])
+ if err != nil {
+ return err
+ }
+ if read != buflen {
+ return fmt.Errorf("incorrect read. expected %v bytes, got %v", buflen, read)
+ }
+ }
+ hash := bmt.Sum(nil)
+ if !bytes.Equal(hash, expHash) {
+ return fmt.Errorf("hash mismatch. expected %x, got %x", hash, expHash)
+ }
+ return nil
+ }
+
+ for j := 0; j < attempts; j++ {
+ go func() {
+ errc <- f()
+ }()
+ }
+ timeout := time.NewTimer(2 * time.Second)
+ for {
+ select {
+ case err := <-errc:
+ if err != nil {
+ t.Fatal(err)
+ }
+ attempts--
+ if attempts == 0 {
+ return
+ }
+ case <-timeout.C:
+ t.Fatalf("timeout")
+ }
+ }
+ })
+ }
}
// helper function that compares reference and optimised implementations on