aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--core/blockchain.go12
-rw-r--r--core/headerchain.go37
-rw-r--r--eth/handler.go37
-rw-r--r--les/handler.go35
-rw-r--r--light/lightchain.go12
5 files changed, 106 insertions, 27 deletions
diff --git a/core/blockchain.go b/core/blockchain.go
index ea26fa034..34832252a 100644
--- a/core/blockchain.go
+++ b/core/blockchain.go
@@ -1524,6 +1524,18 @@ func (bc *BlockChain) GetBlockHashesFromHash(hash common.Hash, max uint64) []com
return bc.hc.GetBlockHashesFromHash(hash, max)
}
+// GetAncestor retrieves the Nth ancestor of a given block. It assumes that either the given block or
+// a close ancestor of it is canonical. maxNonCanonical points to a downwards counter limiting the
+// number of blocks to be individually checked before we reach the canonical chain.
+//
+// Note: ancestor == 0 returns the same block, 1 returns its parent and so on.
+func (bc *BlockChain) GetAncestor(hash common.Hash, number, ancestor uint64, maxNonCanonical *uint64) (common.Hash, uint64) {
+ bc.chainmu.Lock()
+ defer bc.chainmu.Unlock()
+
+ return bc.hc.GetAncestor(hash, number, ancestor, maxNonCanonical)
+}
+
// GetHeaderByNumber retrieves a block header from the database by number,
// caching it (associated with its hash) if found.
func (bc *BlockChain) GetHeaderByNumber(number uint64) *types.Header {
diff --git a/core/headerchain.go b/core/headerchain.go
index 2ac0cccc7..6e759ed1c 100644
--- a/core/headerchain.go
+++ b/core/headerchain.go
@@ -307,6 +307,43 @@ func (hc *HeaderChain) GetBlockHashesFromHash(hash common.Hash, max uint64) []co
return chain
}
+// GetAncestor retrieves the Nth ancestor of a given block. It assumes that either the given block or
+// a close ancestor of it is canonical. maxNonCanonical points to a downwards counter limiting the
+// number of blocks to be individually checked before we reach the canonical chain.
+//
+// Note: ancestor == 0 returns the same block, 1 returns its parent and so on.
+func (hc *HeaderChain) GetAncestor(hash common.Hash, number, ancestor uint64, maxNonCanonical *uint64) (common.Hash, uint64) {
+ if ancestor > number {
+ return common.Hash{}, 0
+ }
+ if ancestor == 1 {
+ // in this case it is cheaper to just read the header
+ if header := hc.GetHeader(hash, number); header != nil {
+ return header.ParentHash, number - 1
+ } else {
+ return common.Hash{}, 0
+ }
+ }
+ for ancestor != 0 {
+ if rawdb.ReadCanonicalHash(hc.chainDb, number) == hash {
+ number -= ancestor
+ return rawdb.ReadCanonicalHash(hc.chainDb, number), number
+ }
+ if *maxNonCanonical == 0 {
+ return common.Hash{}, 0
+ }
+ *maxNonCanonical--
+ ancestor--
+ header := hc.GetHeader(hash, number)
+ if header == nil {
+ return common.Hash{}, 0
+ }
+ hash = header.ParentHash
+ number--
+ }
+ return hash, number
+}
+
// GetTd retrieves a block's total difficulty in the canonical chain from the
// database by hash and number, caching it if found.
func (hc *HeaderChain) GetTd(hash common.Hash, number uint64) *big.Int {
diff --git a/eth/handler.go b/eth/handler.go
index 918d71088..a46e7f13c 100644
--- a/eth/handler.go
+++ b/eth/handler.go
@@ -340,6 +340,8 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
return errResp(ErrDecode, "%v: %v", msg, err)
}
hashMode := query.Origin.Hash != (common.Hash{})
+ first := true
+ maxNonCanonical := uint64(100)
// Gather headers until the fetch or network limits is reached
var (
@@ -351,31 +353,36 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
// Retrieve the next header satisfying the query
var origin *types.Header
if hashMode {
- origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash)
+ if first {
+ first = false
+ origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash)
+ if origin != nil {
+ query.Origin.Number = origin.Number.Uint64()
+ }
+ } else {
+ origin = pm.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number)
+ }
} else {
origin = pm.blockchain.GetHeaderByNumber(query.Origin.Number)
}
if origin == nil {
break
}
- number := origin.Number.Uint64()
headers = append(headers, origin)
bytes += estHeaderRlpSize
// Advance to the next header of the query
switch {
- case query.Origin.Hash != (common.Hash{}) && query.Reverse:
+ case hashMode && query.Reverse:
// Hash based traversal towards the genesis block
- for i := 0; i < int(query.Skip)+1; i++ {
- if header := pm.blockchain.GetHeader(query.Origin.Hash, number); header != nil {
- query.Origin.Hash = header.ParentHash
- number--
- } else {
- unknown = true
- break
- }
+ ancestor := query.Skip + 1
+ if ancestor == 0 {
+ unknown = true
+ } else {
+ query.Origin.Hash, query.Origin.Number = pm.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical)
+ unknown = (query.Origin.Hash == common.Hash{})
}
- case query.Origin.Hash != (common.Hash{}) && !query.Reverse:
+ case hashMode && !query.Reverse:
// Hash based traversal towards the leaf block
var (
current = origin.Number.Uint64()
@@ -387,8 +394,10 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
unknown = true
} else {
if header := pm.blockchain.GetHeaderByNumber(next); header != nil {
- if pm.blockchain.GetBlockHashesFromHash(header.Hash(), query.Skip+1)[query.Skip] == query.Origin.Hash {
- query.Origin.Hash = header.Hash()
+ nextHash := header.Hash()
+ expOldHash, _ := pm.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical)
+ if expOldHash == query.Origin.Hash {
+ query.Origin.Hash, query.Origin.Number = nextHash, next
} else {
unknown = true
}
diff --git a/les/handler.go b/les/handler.go
index 38f810d72..a1c16cb87 100644
--- a/les/handler.go
+++ b/les/handler.go
@@ -83,7 +83,7 @@ type BlockChain interface {
InsertHeaderChain(chain []*types.Header, checkFreq int) (int, error)
Rollback(chain []common.Hash)
GetHeaderByNumber(number uint64) *types.Header
- GetBlockHashesFromHash(hash common.Hash, max uint64) []common.Hash
+ GetAncestor(hash common.Hash, number, ancestor uint64, maxNonCanonical *uint64) (common.Hash, uint64)
Genesis() *types.Block
SubscribeChainHeadEvent(ch chan<- core.ChainHeadEvent) event.Subscription
}
@@ -419,6 +419,8 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
}
hashMode := query.Origin.Hash != (common.Hash{})
+ first := true
+ maxNonCanonical := uint64(100)
// Gather headers until the fetch or network limits is reached
var (
@@ -430,14 +432,21 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
// Retrieve the next header satisfying the query
var origin *types.Header
if hashMode {
- origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash)
+ if first {
+ first = false
+ origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash)
+ if origin != nil {
+ query.Origin.Number = origin.Number.Uint64()
+ }
+ } else {
+ origin = pm.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number)
+ }
} else {
origin = pm.blockchain.GetHeaderByNumber(query.Origin.Number)
}
if origin == nil {
break
}
- number := origin.Number.Uint64()
headers = append(headers, origin)
bytes += estHeaderRlpSize
@@ -445,14 +454,12 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
switch {
case hashMode && query.Reverse:
// Hash based traversal towards the genesis block
- for i := 0; i < int(query.Skip)+1; i++ {
- if header := pm.blockchain.GetHeader(query.Origin.Hash, number); header != nil {
- query.Origin.Hash = header.ParentHash
- number--
- } else {
- unknown = true
- break
- }
+ ancestor := query.Skip + 1
+ if ancestor == 0 {
+ unknown = true
+ } else {
+ query.Origin.Hash, query.Origin.Number = pm.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical)
+ unknown = (query.Origin.Hash == common.Hash{})
}
case hashMode && !query.Reverse:
// Hash based traversal towards the leaf block
@@ -466,8 +473,10 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
unknown = true
} else {
if header := pm.blockchain.GetHeaderByNumber(next); header != nil {
- if pm.blockchain.GetBlockHashesFromHash(header.Hash(), query.Skip+1)[query.Skip] == query.Origin.Hash {
- query.Origin.Hash = header.Hash()
+ nextHash := header.Hash()
+ expOldHash, _ := pm.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical)
+ if expOldHash == query.Origin.Hash {
+ query.Origin.Hash, query.Origin.Number = nextHash, next
} else {
unknown = true
}
diff --git a/light/lightchain.go b/light/lightchain.go
index 9d0a4e4f7..30b9bd89a 100644
--- a/light/lightchain.go
+++ b/light/lightchain.go
@@ -433,6 +433,18 @@ func (self *LightChain) GetBlockHashesFromHash(hash common.Hash, max uint64) []c
return self.hc.GetBlockHashesFromHash(hash, max)
}
+// GetAncestor retrieves the Nth ancestor of a given block. It assumes that either the given block or
+// a close ancestor of it is canonical. maxNonCanonical points to a downwards counter limiting the
+// number of blocks to be individually checked before we reach the canonical chain.
+//
+// Note: ancestor == 0 returns the same block, 1 returns its parent and so on.
+func (bc *LightChain) GetAncestor(hash common.Hash, number, ancestor uint64, maxNonCanonical *uint64) (common.Hash, uint64) {
+ bc.chainmu.Lock()
+ defer bc.chainmu.Unlock()
+
+ return bc.hc.GetAncestor(hash, number, ancestor, maxNonCanonical)
+}
+
// GetHeaderByNumber retrieves a block header from the database by number,
// caching it (associated with its hash) if found.
func (self *LightChain) GetHeaderByNumber(number uint64) *types.Header {