aboutsummaryrefslogtreecommitdiffstats
path: root/swarm/pss/pss.go
diff options
context:
space:
mode:
Diffstat (limited to 'swarm/pss/pss.go')
-rw-r--r--swarm/pss/pss.go171
1 files changed, 131 insertions, 40 deletions
diff --git a/swarm/pss/pss.go b/swarm/pss/pss.go
index e1e24e1f5..d0986d280 100644
--- a/swarm/pss/pss.go
+++ b/swarm/pss/pss.go
@@ -23,11 +23,13 @@ import (
"crypto/rand"
"errors"
"fmt"
+ "hash"
"sync"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/crypto/sha3"
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
@@ -136,10 +138,10 @@ type Pss struct {
symKeyDecryptCacheCapacity int // max amount of symkeys to keep.
// message handling
- handlers map[Topic]map[*Handler]bool // topic and version based pss payload handlers. See pss.Handle()
- handlersMu sync.RWMutex
- allowRaw bool
- hashPool sync.Pool
+ handlers map[Topic]map[*handler]bool // topic and version based pss payload handlers. See pss.Handle()
+ handlersMu sync.RWMutex
+ hashPool sync.Pool
+ topicHandlerCaps map[Topic]*handlerCaps // caches capabilities of each topic's handlers (see handlerCap* consts in types.go)
// process
quitC chan struct{}
@@ -180,11 +182,12 @@ func NewPss(k *network.Kademlia, params *PssParams) (*Pss, error) {
symKeyDecryptCache: make([]*string, params.SymKeyCacheCapacity),
symKeyDecryptCacheCapacity: params.SymKeyCacheCapacity,
- handlers: make(map[Topic]map[*Handler]bool),
- allowRaw: params.AllowRaw,
+ handlers: make(map[Topic]map[*handler]bool),
+ topicHandlerCaps: make(map[Topic]*handlerCaps),
+
hashPool: sync.Pool{
New: func() interface{} {
- return storage.MakeHashFunc(storage.DefaultHash)()
+ return sha3.NewKeccak256()
},
},
}
@@ -313,30 +316,54 @@ func (p *Pss) PublicKey() *ecdsa.PublicKey {
//
// Returns a deregister function which needs to be called to
// deregister the handler,
-func (p *Pss) Register(topic *Topic, handler Handler) func() {
+func (p *Pss) Register(topic *Topic, hndlr *handler) func() {
p.handlersMu.Lock()
defer p.handlersMu.Unlock()
handlers := p.handlers[*topic]
if handlers == nil {
- handlers = make(map[*Handler]bool)
+ handlers = make(map[*handler]bool)
p.handlers[*topic] = handlers
+ log.Debug("registered handler", "caps", hndlr.caps)
+ }
+ if hndlr.caps == nil {
+ hndlr.caps = &handlerCaps{}
+ }
+ handlers[hndlr] = true
+ if _, ok := p.topicHandlerCaps[*topic]; !ok {
+ p.topicHandlerCaps[*topic] = &handlerCaps{}
}
- handlers[&handler] = true
- return func() { p.deregister(topic, &handler) }
+ if hndlr.caps.raw {
+ p.topicHandlerCaps[*topic].raw = true
+ }
+ if hndlr.caps.prox {
+ p.topicHandlerCaps[*topic].prox = true
+ }
+ return func() { p.deregister(topic, hndlr) }
}
-func (p *Pss) deregister(topic *Topic, h *Handler) {
+func (p *Pss) deregister(topic *Topic, hndlr *handler) {
p.handlersMu.Lock()
defer p.handlersMu.Unlock()
handlers := p.handlers[*topic]
- if len(handlers) == 1 {
+ if len(handlers) > 1 {
delete(p.handlers, *topic)
+ // topic caps might have changed now that a handler is gone
+ caps := &handlerCaps{}
+ for h := range handlers {
+ if h.caps.raw {
+ caps.raw = true
+ }
+ if h.caps.prox {
+ caps.prox = true
+ }
+ }
+ p.topicHandlerCaps[*topic] = caps
return
}
- delete(handlers, h)
+ delete(handlers, hndlr)
}
// get all registered handlers for respective topics
-func (p *Pss) getHandlers(topic Topic) map[*Handler]bool {
+func (p *Pss) getHandlers(topic Topic) map[*handler]bool {
p.handlersMu.RLock()
defer p.handlersMu.RUnlock()
return p.handlers[topic]
@@ -348,12 +375,11 @@ func (p *Pss) getHandlers(topic Topic) map[*Handler]bool {
// Only passes error to pss protocol handler if payload is not valid pssmsg
func (p *Pss) handlePssMsg(ctx context.Context, msg interface{}) error {
metrics.GetOrRegisterCounter("pss.handlepssmsg", nil).Inc(1)
-
pssmsg, ok := msg.(*PssMsg)
-
if !ok {
return fmt.Errorf("invalid message type. Expected *PssMsg, got %T ", msg)
}
+ log.Trace("handler", "self", label(p.Kademlia.BaseAddr()), "topic", label(pssmsg.Payload.Topic[:]))
if int64(pssmsg.Expire) < time.Now().Unix() {
metrics.GetOrRegisterCounter("pss.expire", nil).Inc(1)
log.Warn("pss filtered expired message", "from", common.ToHex(p.Kademlia.BaseAddr()), "to", common.ToHex(pssmsg.To))
@@ -365,13 +391,34 @@ func (p *Pss) handlePssMsg(ctx context.Context, msg interface{}) error {
}
p.addFwdCache(pssmsg)
- if !p.isSelfPossibleRecipient(pssmsg) {
- log.Trace("pss was for someone else :'( ... forwarding", "pss", common.ToHex(p.BaseAddr()))
+ psstopic := Topic(pssmsg.Payload.Topic)
+
+ // raw is simplest handler contingency to check, so check that first
+ var isRaw bool
+ if pssmsg.isRaw() {
+ if !p.topicHandlerCaps[psstopic].raw {
+ log.Debug("No handler for raw message", "topic", psstopic)
+ return nil
+ }
+ isRaw = true
+ }
+
+ // check if we can be recipient:
+ // - no prox handler on message and partial address matches
+ // - prox handler on message and we are in prox regardless of partial address match
+ // store this result so we don't calculate again on every handler
+ var isProx bool
+ if _, ok := p.topicHandlerCaps[psstopic]; ok {
+ isProx = p.topicHandlerCaps[psstopic].prox
+ }
+ isRecipient := p.isSelfPossibleRecipient(pssmsg, isProx)
+ if !isRecipient {
+ log.Trace("pss was for someone else :'( ... forwarding", "pss", common.ToHex(p.BaseAddr()), "prox", isProx)
return p.enqueue(pssmsg)
}
- log.Trace("pss for us, yay! ... let's process!", "pss", common.ToHex(p.BaseAddr()))
- if err := p.process(pssmsg); err != nil {
+ log.Trace("pss for us, yay! ... let's process!", "pss", common.ToHex(p.BaseAddr()), "prox", isProx, "raw", isRaw, "topic", label(pssmsg.Payload.Topic[:]))
+ if err := p.process(pssmsg, isRaw, isProx); err != nil {
qerr := p.enqueue(pssmsg)
if qerr != nil {
return fmt.Errorf("process fail: processerr %v, queueerr: %v", err, qerr)
@@ -384,7 +431,7 @@ func (p *Pss) handlePssMsg(ctx context.Context, msg interface{}) error {
// Entry point to processing a message for which the current node can be the intended recipient.
// Attempts symmetric and asymmetric decryption with stored keys.
// Dispatches message to all handlers matching the message topic
-func (p *Pss) process(pssmsg *PssMsg) error {
+func (p *Pss) process(pssmsg *PssMsg, raw bool, prox bool) error {
metrics.GetOrRegisterCounter("pss.process", nil).Inc(1)
var err error
@@ -397,10 +444,8 @@ func (p *Pss) process(pssmsg *PssMsg) error {
envelope := pssmsg.Payload
psstopic := Topic(envelope.Topic)
- if pssmsg.isRaw() {
- if !p.allowRaw {
- return errors.New("raw message support disabled")
- }
+
+ if raw {
payload = pssmsg.Payload.Data
} else {
if pssmsg.isSym() {
@@ -422,19 +467,27 @@ func (p *Pss) process(pssmsg *PssMsg) error {
return err
}
}
- p.executeHandlers(psstopic, payload, from, asymmetric, keyid)
+ p.executeHandlers(psstopic, payload, from, raw, prox, asymmetric, keyid)
return nil
}
-func (p *Pss) executeHandlers(topic Topic, payload []byte, from *PssAddress, asymmetric bool, keyid string) {
+func (p *Pss) executeHandlers(topic Topic, payload []byte, from *PssAddress, raw bool, prox bool, asymmetric bool, keyid string) {
handlers := p.getHandlers(topic)
peer := p2p.NewPeer(enode.ID{}, fmt.Sprintf("%x", from), []p2p.Cap{})
- for f := range handlers {
- err := (*f)(payload, peer, asymmetric, keyid)
+ for h := range handlers {
+ if !h.caps.raw && raw {
+ log.Warn("norawhandler")
+ continue
+ }
+ if !h.caps.prox && prox {
+ log.Warn("noproxhandler")
+ continue
+ }
+ err := (h.f)(payload, peer, asymmetric, keyid)
if err != nil {
- log.Warn("Pss handler %p failed: %v", f, err)
+ log.Warn("Pss handler failed", "err", err)
}
}
}
@@ -445,9 +498,23 @@ func (p *Pss) isSelfRecipient(msg *PssMsg) bool {
}
// test match of leftmost bytes in given message to node's Kademlia address
-func (p *Pss) isSelfPossibleRecipient(msg *PssMsg) bool {
+func (p *Pss) isSelfPossibleRecipient(msg *PssMsg, prox bool) bool {
local := p.Kademlia.BaseAddr()
- return bytes.Equal(msg.To, local[:len(msg.To)])
+
+ // if a partial address matches we are possible recipient regardless of prox
+ // if not and prox is not set, we are surely not
+ if bytes.Equal(msg.To, local[:len(msg.To)]) {
+
+ return true
+ } else if !prox {
+ return false
+ }
+
+ depth := p.Kademlia.NeighbourhoodDepth()
+ po, _ := p.Kademlia.Pof(p.Kademlia.BaseAddr(), msg.To, 0)
+ log.Trace("selfpossible", "po", po, "depth", depth)
+
+ return depth <= po
}
/////////////////////////////////////////////////////////////////////
@@ -684,9 +751,6 @@ func (p *Pss) enqueue(msg *PssMsg) error {
//
// Will fail if raw messages are disallowed
func (p *Pss) SendRaw(address PssAddress, topic Topic, msg []byte) error {
- if !p.allowRaw {
- return errors.New("Raw messages not enabled")
- }
pssMsgParams := &msgParams{
raw: true,
}
@@ -699,7 +763,17 @@ func (p *Pss) SendRaw(address PssAddress, topic Topic, msg []byte) error {
pssMsg.Expire = uint32(time.Now().Add(p.msgTTL).Unix())
pssMsg.Payload = payload
p.addFwdCache(pssMsg)
- return p.enqueue(pssMsg)
+ err := p.enqueue(pssMsg)
+ if err != nil {
+ return err
+ }
+
+ // if we have a proxhandler on this topic
+ // also deliver message to ourselves
+ if p.isSelfPossibleRecipient(pssMsg, true) && p.topicHandlerCaps[topic].prox {
+ return p.process(pssMsg, true, true)
+ }
+ return nil
}
// Send a message using symmetric encryption
@@ -800,7 +874,16 @@ func (p *Pss) send(to []byte, topic Topic, msg []byte, asymmetric bool, key []by
pssMsg.To = to
pssMsg.Expire = uint32(time.Now().Add(p.msgTTL).Unix())
pssMsg.Payload = envelope
- return p.enqueue(pssMsg)
+ err = p.enqueue(pssMsg)
+ if err != nil {
+ return err
+ }
+ if _, ok := p.topicHandlerCaps[topic]; ok {
+ if p.isSelfPossibleRecipient(pssMsg, true) && p.topicHandlerCaps[topic].prox {
+ return p.process(pssMsg, true, true)
+ }
+ }
+ return nil
}
// Forwards a pss message to the peer(s) closest to the to recipient address in the PssMsg struct
@@ -895,6 +978,10 @@ func (p *Pss) cleanFwdCache() {
}
}
+func label(b []byte) string {
+ return fmt.Sprintf("%04x", b[:2])
+}
+
// add a message to the cache
func (p *Pss) addFwdCache(msg *PssMsg) error {
metrics.GetOrRegisterCounter("pss.addfwdcache", nil).Inc(1)
@@ -934,10 +1021,14 @@ func (p *Pss) checkFwdCache(msg *PssMsg) bool {
// Digest of message
func (p *Pss) digest(msg *PssMsg) pssDigest {
- hasher := p.hashPool.Get().(storage.SwarmHash)
+ return p.digestBytes(msg.serialize())
+}
+
+func (p *Pss) digestBytes(msg []byte) pssDigest {
+ hasher := p.hashPool.Get().(hash.Hash)
defer p.hashPool.Put(hasher)
hasher.Reset()
- hasher.Write(msg.serialize())
+ hasher.Write(msg)
digest := pssDigest{}
key := hasher.Sum(nil)
copy(digest[:], key[:digestLength])