diff options
56 files changed, 7630 insertions, 356 deletions
diff --git a/.gitignore b/.gitignore index e53e461dc..cb2c2d14d 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,6 @@ build/_vendor/pkg # travis profile.tmp profile.cov + +# IdeaIDE +.idea diff --git a/cmd/evm/json_logger.go b/cmd/evm/json_logger.go index 2cfeaa795..eb7b0c466 100644 --- a/cmd/evm/json_logger.go +++ b/cmd/evm/json_logger.go @@ -40,7 +40,7 @@ func (l *JSONLogger) CaptureState(env *vm.EVM, pc uint64, op vm.OpCode, gas, cos log := vm.StructLog{ Pc: pc, Op: op, - Gas: gas + cost, + Gas: gas, GasCost: cost, MemorySize: memory.Len(), Storage: nil, diff --git a/cmd/p2psim/main.go b/cmd/p2psim/main.go new file mode 100644 index 000000000..56b74d135 --- /dev/null +++ b/cmd/p2psim/main.go @@ -0,0 +1,414 @@ +// p2psim provides a command-line client for a simulation HTTP API. +// +// Here is an example of creating a 2 node network with the first node +// connected to the second: +// +// $ p2psim node create +// Created node01 +// +// $ p2psim node start node01 +// Started node01 +// +// $ p2psim node create +// Created node02 +// +// $ p2psim node start node02 +// Started node02 +// +// $ p2psim node connect node01 node02 +// Connected node01 to node02 +// +package main + +import ( + "context" + "encoding/json" + "fmt" + "io" + "os" + "strings" + "text/tabwriter" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/simulations" + "github.com/ethereum/go-ethereum/p2p/simulations/adapters" + "github.com/ethereum/go-ethereum/rpc" + "gopkg.in/urfave/cli.v1" +) + +var client *simulations.Client + +func main() { + app := cli.NewApp() + app.Usage = "devp2p simulation command-line client" + app.Flags = []cli.Flag{ + cli.StringFlag{ + Name: "api", + Value: "http://localhost:8888", + Usage: "simulation API URL", + EnvVar: "P2PSIM_API_URL", + }, + } + app.Before = func(ctx *cli.Context) error { + client = simulations.NewClient(ctx.GlobalString("api")) + return nil + } + app.Commands = []cli.Command{ + { + Name: "show", + Usage: "show network information", + Action: showNetwork, + }, + { + Name: "events", + Usage: "stream network events", + Action: streamNetwork, + Flags: []cli.Flag{ + cli.BoolFlag{ + Name: "current", + Usage: "get existing nodes and conns first", + }, + cli.StringFlag{ + Name: "filter", + Value: "", + Usage: "message filter", + }, + }, + }, + { + Name: "snapshot", + Usage: "create a network snapshot to stdout", + Action: createSnapshot, + }, + { + Name: "load", + Usage: "load a network snapshot from stdin", + Action: loadSnapshot, + }, + { + Name: "node", + Usage: "manage simulation nodes", + Action: listNodes, + Subcommands: []cli.Command{ + { + Name: "list", + Usage: "list nodes", + Action: listNodes, + }, + { + Name: "create", + Usage: "create a node", + Action: createNode, + Flags: []cli.Flag{ + cli.StringFlag{ + Name: "name", + Value: "", + Usage: "node name", + }, + cli.StringFlag{ + Name: "services", + Value: "", + Usage: "node services (comma separated)", + }, + cli.StringFlag{ + Name: "key", + Value: "", + Usage: "node private key (hex encoded)", + }, + }, + }, + { + Name: "show", + ArgsUsage: "<node>", + Usage: "show node information", + Action: showNode, + }, + { + Name: "start", + ArgsUsage: "<node>", + Usage: "start a node", + Action: startNode, + }, + { + Name: "stop", + ArgsUsage: "<node>", + Usage: "stop a node", + Action: stopNode, + }, + { + Name: "connect", + ArgsUsage: "<node> <peer>", + Usage: "connect a node to a peer node", + Action: connectNode, + }, + { + Name: "disconnect", + ArgsUsage: "<node> <peer>", + Usage: "disconnect a node from a peer node", + Action: disconnectNode, + }, + { + Name: "rpc", + ArgsUsage: "<node> <method> [<args>]", + Usage: "call a node RPC method", + Action: rpcNode, + Flags: []cli.Flag{ + cli.BoolFlag{ + Name: "subscribe", + Usage: "method is a subscription", + }, + }, + }, + }, + }, + } + app.Run(os.Args) +} + +func showNetwork(ctx *cli.Context) error { + if len(ctx.Args()) != 0 { + return cli.ShowCommandHelp(ctx, ctx.Command.Name) + } + network, err := client.GetNetwork() + if err != nil { + return err + } + w := tabwriter.NewWriter(ctx.App.Writer, 1, 2, 2, ' ', 0) + defer w.Flush() + fmt.Fprintf(w, "NODES\t%d\n", len(network.Nodes)) + fmt.Fprintf(w, "CONNS\t%d\n", len(network.Conns)) + return nil +} + +func streamNetwork(ctx *cli.Context) error { + if len(ctx.Args()) != 0 { + return cli.ShowCommandHelp(ctx, ctx.Command.Name) + } + events := make(chan *simulations.Event) + sub, err := client.SubscribeNetwork(events, simulations.SubscribeOpts{ + Current: ctx.Bool("current"), + Filter: ctx.String("filter"), + }) + if err != nil { + return err + } + defer sub.Unsubscribe() + enc := json.NewEncoder(ctx.App.Writer) + for { + select { + case event := <-events: + if err := enc.Encode(event); err != nil { + return err + } + case err := <-sub.Err(): + return err + } + } +} + +func createSnapshot(ctx *cli.Context) error { + if len(ctx.Args()) != 0 { + return cli.ShowCommandHelp(ctx, ctx.Command.Name) + } + snap, err := client.CreateSnapshot() + if err != nil { + return err + } + return json.NewEncoder(os.Stdout).Encode(snap) +} + +func loadSnapshot(ctx *cli.Context) error { + if len(ctx.Args()) != 0 { + return cli.ShowCommandHelp(ctx, ctx.Command.Name) + } + snap := &simulations.Snapshot{} + if err := json.NewDecoder(os.Stdin).Decode(snap); err != nil { + return err + } + return client.LoadSnapshot(snap) +} + +func listNodes(ctx *cli.Context) error { + if len(ctx.Args()) != 0 { + return cli.ShowCommandHelp(ctx, ctx.Command.Name) + } + nodes, err := client.GetNodes() + if err != nil { + return err + } + w := tabwriter.NewWriter(ctx.App.Writer, 1, 2, 2, ' ', 0) + defer w.Flush() + fmt.Fprintf(w, "NAME\tPROTOCOLS\tID\n") + for _, node := range nodes { + fmt.Fprintf(w, "%s\t%s\t%s\n", node.Name, strings.Join(protocolList(node), ","), node.ID) + } + return nil +} + +func protocolList(node *p2p.NodeInfo) []string { + protos := make([]string, 0, len(node.Protocols)) + for name := range node.Protocols { + protos = append(protos, name) + } + return protos +} + +func createNode(ctx *cli.Context) error { + if len(ctx.Args()) != 0 { + return cli.ShowCommandHelp(ctx, ctx.Command.Name) + } + config := &adapters.NodeConfig{ + Name: ctx.String("name"), + } + if key := ctx.String("key"); key != "" { + privKey, err := crypto.HexToECDSA(key) + if err != nil { + return err + } + config.ID = discover.PubkeyID(&privKey.PublicKey) + config.PrivateKey = privKey + } + if services := ctx.String("services"); services != "" { + config.Services = strings.Split(services, ",") + } + node, err := client.CreateNode(config) + if err != nil { + return err + } + fmt.Fprintln(ctx.App.Writer, "Created", node.Name) + return nil +} + +func showNode(ctx *cli.Context) error { + args := ctx.Args() + if len(args) != 1 { + return cli.ShowCommandHelp(ctx, ctx.Command.Name) + } + nodeName := args[0] + node, err := client.GetNode(nodeName) + if err != nil { + return err + } + w := tabwriter.NewWriter(ctx.App.Writer, 1, 2, 2, ' ', 0) + defer w.Flush() + fmt.Fprintf(w, "NAME\t%s\n", node.Name) + fmt.Fprintf(w, "PROTOCOLS\t%s\n", strings.Join(protocolList(node), ",")) + fmt.Fprintf(w, "ID\t%s\n", node.ID) + fmt.Fprintf(w, "ENODE\t%s\n", node.Enode) + for name, proto := range node.Protocols { + fmt.Fprintln(w) + fmt.Fprintf(w, "--- PROTOCOL INFO: %s\n", name) + fmt.Fprintf(w, "%v\n", proto) + fmt.Fprintf(w, "---\n") + } + return nil +} + +func startNode(ctx *cli.Context) error { + args := ctx.Args() + if len(args) != 1 { + return cli.ShowCommandHelp(ctx, ctx.Command.Name) + } + nodeName := args[0] + if err := client.StartNode(nodeName); err != nil { + return err + } + fmt.Fprintln(ctx.App.Writer, "Started", nodeName) + return nil +} + +func stopNode(ctx *cli.Context) error { + args := ctx.Args() + if len(args) != 1 { + return cli.ShowCommandHelp(ctx, ctx.Command.Name) + } + nodeName := args[0] + if err := client.StopNode(nodeName); err != nil { + return err + } + fmt.Fprintln(ctx.App.Writer, "Stopped", nodeName) + return nil +} + +func connectNode(ctx *cli.Context) error { + args := ctx.Args() + if len(args) != 2 { + return cli.ShowCommandHelp(ctx, ctx.Command.Name) + } + nodeName := args[0] + peerName := args[1] + if err := client.ConnectNode(nodeName, peerName); err != nil { + return err + } + fmt.Fprintln(ctx.App.Writer, "Connected", nodeName, "to", peerName) + return nil +} + +func disconnectNode(ctx *cli.Context) error { + args := ctx.Args() + if len(args) != 2 { + return cli.ShowCommandHelp(ctx, ctx.Command.Name) + } + nodeName := args[0] + peerName := args[1] + if err := client.DisconnectNode(nodeName, peerName); err != nil { + return err + } + fmt.Fprintln(ctx.App.Writer, "Disconnected", nodeName, "from", peerName) + return nil +} + +func rpcNode(ctx *cli.Context) error { + args := ctx.Args() + if len(args) < 2 { + return cli.ShowCommandHelp(ctx, ctx.Command.Name) + } + nodeName := args[0] + method := args[1] + rpcClient, err := client.RPCClient(context.Background(), nodeName) + if err != nil { + return err + } + if ctx.Bool("subscribe") { + return rpcSubscribe(rpcClient, ctx.App.Writer, method, args[3:]...) + } + var result interface{} + params := make([]interface{}, len(args[3:])) + for i, v := range args[3:] { + params[i] = v + } + if err := rpcClient.Call(&result, method, params...); err != nil { + return err + } + return json.NewEncoder(ctx.App.Writer).Encode(result) +} + +func rpcSubscribe(client *rpc.Client, out io.Writer, method string, args ...string) error { + parts := strings.SplitN(method, "_", 2) + namespace := parts[0] + method = parts[1] + ch := make(chan interface{}) + subArgs := make([]interface{}, len(args)+1) + subArgs[0] = method + for i, v := range args { + subArgs[i+1] = v + } + sub, err := client.Subscribe(context.Background(), namespace, ch, subArgs...) + if err != nil { + return err + } + defer sub.Unsubscribe() + enc := json.NewEncoder(out) + for { + select { + case v := <-ch: + if err := enc.Encode(v); err != nil { + return err + } + case err := <-sub.Err(): + return err + } + } +} diff --git a/contracts/chequebook/contract/chequebook.sol b/contracts/chequebook/contract/chequebook.sol index eefe6c063..845ba464b 100644 --- a/contracts/chequebook/contract/chequebook.sol +++ b/contracts/chequebook/contract/chequebook.sol @@ -27,10 +27,11 @@ contract chequebook is mortal { if(owner != ecrecover(hash, sig_v, sig_r, sig_s)) return; // Attempt sending the difference between the cumulative amount on the cheque // and the cumulative amount on the last cashed cheque to beneficiary. - if (amount - sent[beneficiary] >= this.balance) { + uint256 diff = amount - sent[beneficiary]; + if (diff <= this.balance) { // update the cumulative amount before sending sent[beneficiary] = amount; - if (!beneficiary.send(amount - sent[beneficiary])) { + if (!beneficiary.send(diff)) { // Upon failure to execute send, revert everything throw; } diff --git a/core/vm/interpreter.go b/core/vm/interpreter.go index b0d796a44..94b922c79 100644 --- a/core/vm/interpreter.go +++ b/core/vm/interpreter.go @@ -137,12 +137,17 @@ func (in *Interpreter) Run(snapshot int, contract *Contract, input []byte) (ret // to be uint256. Practically much less so feasible. pc = uint64(0) // program counter cost uint64 + // copies used by tracer + stackCopy = newstack() // stackCopy needed for Tracer since stack is mutated by 63/64 gas rule + pcCopy uint64 // needed for the deferred Tracer + gasCopy uint64 // for Tracer to log gas remaining before execution + logged bool // deferred Tracer should ignore already logged steps ) contract.Input = input defer func() { - if err != nil && in.cfg.Debug { - in.cfg.Tracer.CaptureState(in.evm, pc, op, contract.Gas, cost, mem, stack, contract, in.evm.depth, err) + if err != nil && !logged && in.cfg.Debug { + in.cfg.Tracer.CaptureState(in.evm, pcCopy, op, gasCopy, cost, mem, stackCopy, contract, in.evm.depth, err) } }() @@ -154,6 +159,16 @@ func (in *Interpreter) Run(snapshot int, contract *Contract, input []byte) (ret // Get the memory location of pc op = contract.GetOp(pc) + if in.cfg.Debug { + logged = false + pcCopy = uint64(pc) + gasCopy = uint64(contract.Gas) + stackCopy = newstack() + for _, val := range stack.data { + stackCopy.push(val) + } + } + // get the operation from the jump table matching the opcode operation := in.cfg.JumpTable[op] if err := in.enforceRestrictions(op, operation, stack); err != nil { @@ -199,7 +214,8 @@ func (in *Interpreter) Run(snapshot int, contract *Contract, input []byte) (ret } if in.cfg.Debug { - in.cfg.Tracer.CaptureState(in.evm, pc, op, contract.Gas, cost, mem, stack, contract, in.evm.depth, err) + in.cfg.Tracer.CaptureState(in.evm, pc, op, gasCopy, cost, mem, stackCopy, contract, in.evm.depth, err) + logged = true } // execute the operation diff --git a/eth/api.go b/eth/api.go index a5b6e7076..d64e4e6c7 100644 --- a/eth/api.go +++ b/eth/api.go @@ -51,7 +51,7 @@ type PublicEthereumAPI struct { e *Ethereum } -// NewPublicEthereumAPI creates a new Etheruem protocol API for full nodes. +// NewPublicEthereumAPI creates a new Ethereum protocol API for full nodes. func NewPublicEthereumAPI(e *Ethereum) *PublicEthereumAPI { return &PublicEthereumAPI{e} } @@ -205,7 +205,7 @@ func (api *PrivateMinerAPI) GetHashrate() uint64 { return uint64(api.e.miner.HashRate()) } -// PrivateAdminAPI is the collection of Etheruem full node-related APIs +// PrivateAdminAPI is the collection of Ethereum full node-related APIs // exposed over the private admin endpoint. type PrivateAdminAPI struct { eth *Ethereum @@ -298,7 +298,7 @@ func (api *PrivateAdminAPI) ImportChain(file string) (bool, error) { return true, nil } -// PublicDebugAPI is the collection of Etheruem full node APIs exposed +// PublicDebugAPI is the collection of Ethereum full node APIs exposed // over the public debugging endpoint. type PublicDebugAPI struct { eth *Ethereum @@ -335,7 +335,7 @@ func (api *PublicDebugAPI) DumpBlock(blockNr rpc.BlockNumber) (state.Dump, error return stateDb.RawDump(), nil } -// PrivateDebugAPI is the collection of Etheruem full node APIs exposed over +// PrivateDebugAPI is the collection of Ethereum full node APIs exposed over // the private debugging endpoint. type PrivateDebugAPI struct { config *params.ChainConfig diff --git a/eth/bind.go b/eth/bind.go index 0385db1f9..d09977dbc 100644 --- a/eth/bind.go +++ b/eth/bind.go @@ -43,7 +43,7 @@ type ContractBackend struct { } // NewContractBackend creates a new native contract backend using an existing -// Etheruem object. +// Ethereum object. func NewContractBackend(apiBackend ethapi.Backend) *ContractBackend { return &ContractBackend{ eapi: ethapi.NewPublicEthereumAPI(apiBackend), diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 0775749e7..8d1a6f746 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -54,7 +54,7 @@ type PublicEthereumAPI struct { b Backend } -// NewPublicEthereumAPI creates a new Etheruem protocol API. +// NewPublicEthereumAPI creates a new Ethereum protocol API. func NewPublicEthereumAPI(b Backend) *PublicEthereumAPI { return &PublicEthereumAPI{b} } @@ -448,7 +448,7 @@ type PublicBlockChainAPI struct { b Backend } -// NewPublicBlockChainAPI creates a new Etheruem blockchain API. +// NewPublicBlockChainAPI creates a new Ethereum blockchain API. func NewPublicBlockChainAPI(b Backend) *PublicBlockChainAPI { return &PublicBlockChainAPI{b} } @@ -1081,7 +1081,10 @@ func submitTransaction(ctx context.Context, b Backend, tx *types.Transaction) (c } if tx.To() == nil { signer := types.MakeSigner(b.ChainConfig(), b.CurrentBlock().Number()) - from, _ := types.Sender(signer, tx) + from, err := types.Sender(signer, tx) + if err != nil { + return common.Hash{}, err + } addr := crypto.CreateAddress(from, tx.Nonce()) log.Info("Submitted contract creation", "fullhash", tx.Hash().Hex(), "contract", addr.Hex()) } else { @@ -1129,29 +1132,12 @@ func (s *PublicTransactionPoolAPI) SendTransaction(ctx context.Context, args Sen // SendRawTransaction will add the signed transaction to the transaction pool. // The sender is responsible for signing the transaction and using the correct nonce. -func (s *PublicTransactionPoolAPI) SendRawTransaction(ctx context.Context, encodedTx hexutil.Bytes) (string, error) { +func (s *PublicTransactionPoolAPI) SendRawTransaction(ctx context.Context, encodedTx hexutil.Bytes) (common.Hash, error) { tx := new(types.Transaction) if err := rlp.DecodeBytes(encodedTx, tx); err != nil { - return "", err - } - - if err := s.b.SendTx(ctx, tx); err != nil { - return "", err - } - - signer := types.MakeSigner(s.b.ChainConfig(), s.b.CurrentBlock().Number()) - if tx.To() == nil { - from, err := types.Sender(signer, tx) - if err != nil { - return "", err - } - addr := crypto.CreateAddress(from, tx.Nonce()) - log.Info("Submitted contract creation", "fullhash", tx.Hash().Hex(), "contract", addr.Hex()) - } else { - log.Info("Submitted transaction", "fullhash", tx.Hash().Hex(), "recipient", tx.To()) + return common.Hash{}, err } - - return tx.Hash().Hex(), nil + return submitTransaction(ctx, s.b, tx) } // Sign calculates an ECDSA signature for: @@ -1275,7 +1261,7 @@ func (s *PublicTransactionPoolAPI) Resend(ctx context.Context, sendArgs SendTxAr return common.Hash{}, fmt.Errorf("Transaction %#x not found", matchTx.Hash()) } -// PublicDebugAPI is the collection of Etheruem APIs exposed over the public +// PublicDebugAPI is the collection of Ethereum APIs exposed over the public // debugging endpoint. type PublicDebugAPI struct { b Backend @@ -1318,7 +1304,7 @@ func (api *PublicDebugAPI) SeedHash(ctx context.Context, number uint64) (string, return fmt.Sprintf("0x%x", ethash.SeedHash(number)), nil } -// PrivateDebugAPI is the collection of Etheruem APIs exposed over the private +// PrivateDebugAPI is the collection of Ethereum APIs exposed over the private // debugging endpoint. type PrivateDebugAPI struct { b Backend diff --git a/miner/agent.go b/miner/agent.go index 855892a07..e3cebbd2e 100644 --- a/miner/agent.go +++ b/miner/agent.go @@ -53,7 +53,19 @@ func (self *CpuAgent) Work() chan<- *Work { return self.workCh } func (self *CpuAgent) SetReturnCh(ch chan<- *Result) { self.returnCh = ch } func (self *CpuAgent) Stop() { + if !atomic.CompareAndSwapInt32(&self.isMining, 1, 0) { + return // agent already stopped + } self.stop <- struct{}{} +done: + // Empty work channel + for { + select { + case <-self.workCh: + default: + break done + } + } } func (self *CpuAgent) Start() { @@ -85,17 +97,6 @@ out: break out } } - -done: - // Empty work channel - for { - select { - case <-self.workCh: - default: - break done - } - } - atomic.StoreInt32(&self.isMining, 0) } func (self *CpuAgent) mine(work *Work, stop <-chan struct{}) { diff --git a/node/api.go b/node/api.go index 570cb9d98..1b04b7093 100644 --- a/node/api.go +++ b/node/api.go @@ -17,6 +17,7 @@ package node import ( + "context" "fmt" "strings" "time" @@ -25,6 +26,7 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/rpc" "github.com/rcrowley/go-metrics" ) @@ -73,6 +75,44 @@ func (api *PrivateAdminAPI) RemovePeer(url string) (bool, error) { return true, nil } +// PeerEvents creates an RPC subscription which receives peer events from the +// node's p2p.Server +func (api *PrivateAdminAPI) PeerEvents(ctx context.Context) (*rpc.Subscription, error) { + // Make sure the server is running, fail otherwise + server := api.node.Server() + if server == nil { + return nil, ErrNodeStopped + } + + // Create the subscription + notifier, supported := rpc.NotifierFromContext(ctx) + if !supported { + return nil, rpc.ErrNotificationsUnsupported + } + rpcSub := notifier.CreateSubscription() + + go func() { + events := make(chan *p2p.PeerEvent) + sub := server.SubscribeEvents(events) + defer sub.Unsubscribe() + + for { + select { + case event := <-events: + notifier.Notify(rpcSub.ID, event) + case <-sub.Err(): + return + case <-rpcSub.Err(): + return + case <-notifier.Closed(): + return + } + } + }() + + return rpcSub, nil +} + // StartRPC starts the HTTP RPC API server. func (api *PrivateAdminAPI) StartRPC(host *string, port *int, cors *string, apis *string) (bool, error) { api.node.lock.Lock() @@ -163,7 +203,7 @@ func (api *PrivateAdminAPI) StartWS(host *string, port *int, allowedOrigins *str } } - if err := api.node.startWS(fmt.Sprintf("%s:%d", *host, *port), api.node.rpcAPIs, modules, origins); err != nil { + if err := api.node.startWS(fmt.Sprintf("%s:%d", *host, *port), api.node.rpcAPIs, modules, origins, api.node.config.WSExposeAll); err != nil { return false, err } return true, nil diff --git a/node/config.go b/node/config.go index b9b5e5b92..be9e21b4f 100644 --- a/node/config.go +++ b/node/config.go @@ -128,6 +128,13 @@ type Config struct { // If the module list is empty, all RPC API endpoints designated public will be // exposed. WSModules []string `toml:",omitempty"` + + // WSExposeAll exposes all API modules via the WebSocket RPC interface rather + // than just the public ones. + // + // *WARNING* Only set this if the node is running in a trusted network, exposing + // private APIs to untrusted users is a major security risk. + WSExposeAll bool `toml:",omitempty"` } // IPCEndpoint resolves an IPC endpoint based on a configured value, taking into diff --git a/node/node.go b/node/node.go index 86cfb29ba..6f189d8fe 100644 --- a/node/node.go +++ b/node/node.go @@ -261,7 +261,7 @@ func (n *Node) startRPC(services map[reflect.Type]Service) error { n.stopInProc() return err } - if err := n.startWS(n.wsEndpoint, apis, n.config.WSModules, n.config.WSOrigins); err != nil { + if err := n.startWS(n.wsEndpoint, apis, n.config.WSModules, n.config.WSOrigins, n.config.WSExposeAll); err != nil { n.stopHTTP() n.stopIPC() n.stopInProc() @@ -412,7 +412,7 @@ func (n *Node) stopHTTP() { } // startWS initializes and starts the websocket RPC endpoint. -func (n *Node) startWS(endpoint string, apis []rpc.API, modules []string, wsOrigins []string) error { +func (n *Node) startWS(endpoint string, apis []rpc.API, modules []string, wsOrigins []string, exposeAll bool) error { // Short circuit if the WS endpoint isn't being exposed if endpoint == "" { return nil @@ -425,7 +425,7 @@ func (n *Node) startWS(endpoint string, apis []rpc.API, modules []string, wsOrig // Register all the APIs exposed by the services handler := rpc.NewServer() for _, api := range apis { - if whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) { + if exposeAll || whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) { if err := handler.RegisterName(api.Namespace, api.Service); err != nil { return err } @@ -441,7 +441,7 @@ func (n *Node) startWS(endpoint string, apis []rpc.API, modules []string, wsOrig return err } go rpc.NewWSServer(wsOrigins, handler).Serve(listener) - log.Info(fmt.Sprintf("WebSocket endpoint opened: ws://%s", endpoint)) + log.Info(fmt.Sprintf("WebSocket endpoint opened: ws://%s", listener.Addr())) // All listeners booted successfully n.wsEndpoint = endpoint @@ -556,6 +556,17 @@ func (n *Node) Attach() (*rpc.Client, error) { return rpc.DialInProc(n.inprocHandler), nil } +// RPCHandler returns the in-process RPC request handler. +func (n *Node) RPCHandler() (*rpc.Server, error) { + n.lock.RLock() + defer n.lock.RUnlock() + + if n.inprocHandler == nil { + return nil, ErrNodeStopped + } + return n.inprocHandler, nil +} + // Server retrieves the currently running P2P network layer. This method is meant // only to inspect fields of the currently running server, life cycle management // should be left to this Node entity. diff --git a/p2p/dial.go b/p2p/dial.go index b77971396..2d9e3a0ed 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -47,6 +47,24 @@ const ( maxResolveDelay = time.Hour ) +// NodeDialer is used to connect to nodes in the network, typically by using +// an underlying net.Dialer but also using net.Pipe in tests +type NodeDialer interface { + Dial(*discover.Node) (net.Conn, error) +} + +// TCPDialer implements the NodeDialer interface by using a net.Dialer to +// create TCP connections to nodes in the network +type TCPDialer struct { + *net.Dialer +} + +// Dial creates a TCP connection to the node +func (t TCPDialer) Dial(dest *discover.Node) (net.Conn, error) { + addr := &net.TCPAddr{IP: dest.IP, Port: int(dest.TCP)} + return t.Dialer.Dial("tcp", addr.String()) +} + // dialstate schedules dials and discovery lookups. // it get's a chance to compute new tasks on every iteration // of the main loop in Server.run. @@ -318,14 +336,13 @@ func (t *dialTask) resolve(srv *Server) bool { // dial performs the actual connection attempt. func (t *dialTask) dial(srv *Server, dest *discover.Node) bool { - addr := &net.TCPAddr{IP: dest.IP, Port: int(dest.TCP)} - fd, err := srv.Dialer.Dial("tcp", addr.String()) + fd, err := srv.Dialer.Dial(dest) if err != nil { log.Trace("Dial error", "task", t, "err", err) return false } mfd := newMeteredConn(fd, false) - srv.setupConn(mfd, t.flags, dest) + srv.SetupConn(mfd, t.flags, dest) return true } diff --git a/p2p/dial_test.go b/p2p/dial_test.go index 08e863bae..ad18ef9ab 100644 --- a/p2p/dial_test.go +++ b/p2p/dial_test.go @@ -597,7 +597,7 @@ func TestDialResolve(t *testing.T) { } // Now run the task, it should resolve the ID once. - config := Config{Dialer: &net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}} + config := Config{Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}}} srv := &Server{ntab: table, Config: config} tasks[0].Do(srv) if !reflect.DeepEqual(table.resolveCalls, []discover.NodeID{dest.ID}) { diff --git a/p2p/discover/node.go b/p2p/discover/node.go index d9cbd9448..fc928a91a 100644 --- a/p2p/discover/node.go +++ b/p2p/discover/node.go @@ -225,6 +225,11 @@ func (n *Node) UnmarshalText(text []byte) error { // The node identifier is a marshaled elliptic curve public key. type NodeID [NodeIDBits / 8]byte +// Bytes returns a byte slice representation of the NodeID +func (n NodeID) Bytes() []byte { + return n[:] +} + // NodeID prints as a long hexadecimal number. func (n NodeID) String() string { return fmt.Sprintf("%x", n[:]) @@ -240,6 +245,41 @@ func (n NodeID) TerminalString() string { return hex.EncodeToString(n[:8]) } +// MarshalText implements the encoding.TextMarshaler interface. +func (n NodeID) MarshalText() ([]byte, error) { + return []byte(hex.EncodeToString(n[:])), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +func (n *NodeID) UnmarshalText(text []byte) error { + id, err := HexID(string(text)) + if err != nil { + return err + } + *n = id + return nil +} + +// BytesID converts a byte slice to a NodeID +func BytesID(b []byte) (NodeID, error) { + var id NodeID + if len(b) != len(id) { + return id, fmt.Errorf("wrong length, want %d bytes", len(id)) + } + copy(id[:], b) + return id, nil +} + +// MustBytesID converts a byte slice to a NodeID. +// It panics if the byte slice is not a valid NodeID. +func MustBytesID(b []byte) NodeID { + id, err := BytesID(b) + if err != nil { + panic(err) + } + return id +} + // HexID converts a hex string to a NodeID. // The string may be prefixed with 0x. func HexID(in string) (NodeID, error) { diff --git a/p2p/discover/node_test.go b/p2p/discover/node_test.go index 3d1662d0b..ed8db4dc6 100644 --- a/p2p/discover/node_test.go +++ b/p2p/discover/node_test.go @@ -17,6 +17,7 @@ package discover import ( + "bytes" "fmt" "math/big" "math/rand" @@ -192,6 +193,35 @@ func TestHexID(t *testing.T) { } } +func TestNodeID_textEncoding(t *testing.T) { + ref := NodeID{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x10, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x20, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x30, + 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x40, + 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x50, + 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x60, + 0x61, 0x62, 0x63, 0x64, + } + hex := "01020304050607080910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364" + + text, err := ref.MarshalText() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(text, []byte(hex)) { + t.Fatalf("text encoding did not match\nexpected: %s\ngot: %s", hex, text) + } + + id := new(NodeID) + if err := id.UnmarshalText(text); err != nil { + t.Fatal(err) + } + if *id != ref { + t.Fatalf("text decoding did not match\nexpected: %s\ngot: %s", ref, id) + } +} + func TestNodeID_recover(t *testing.T) { prv := newkey() hash := make([]byte, 32) diff --git a/p2p/message.go b/p2p/message.go index 1292d2121..5690494bf 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -27,6 +27,8 @@ import ( "sync/atomic" "time" + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/rlp" ) @@ -271,3 +273,67 @@ func ExpectMsg(r MsgReader, code uint64, content interface{}) error { } return nil } + +// msgEventer wraps a MsgReadWriter and sends events whenever a message is sent +// or received +type msgEventer struct { + MsgReadWriter + + feed *event.Feed + peerID discover.NodeID + Protocol string +} + +// newMsgEventer returns a msgEventer which sends message events to the given +// feed +func newMsgEventer(rw MsgReadWriter, feed *event.Feed, peerID discover.NodeID, proto string) *msgEventer { + return &msgEventer{ + MsgReadWriter: rw, + feed: feed, + peerID: peerID, + Protocol: proto, + } +} + +// ReadMsg reads a message from the underlying MsgReadWriter and emits a +// "message received" event +func (self *msgEventer) ReadMsg() (Msg, error) { + msg, err := self.MsgReadWriter.ReadMsg() + if err != nil { + return msg, err + } + self.feed.Send(&PeerEvent{ + Type: PeerEventTypeMsgRecv, + Peer: self.peerID, + Protocol: self.Protocol, + MsgCode: &msg.Code, + MsgSize: &msg.Size, + }) + return msg, nil +} + +// WriteMsg writes a message to the underlying MsgReadWriter and emits a +// "message sent" event +func (self *msgEventer) WriteMsg(msg Msg) error { + err := self.MsgReadWriter.WriteMsg(msg) + if err != nil { + return err + } + self.feed.Send(&PeerEvent{ + Type: PeerEventTypeMsgSend, + Peer: self.peerID, + Protocol: self.Protocol, + MsgCode: &msg.Code, + MsgSize: &msg.Size, + }) + return nil +} + +// Close closes the underlying MsgReadWriter if it implements the io.Closer +// interface +func (self *msgEventer) Close() error { + if v, ok := self.MsgReadWriter.(io.Closer); ok { + return v.Close() + } + return nil +} diff --git a/p2p/peer.go b/p2p/peer.go index fb4b39e95..1d2b726e8 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -25,16 +25,19 @@ import ( "time" "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/rlp" ) const ( - baseProtocolVersion = 4 + baseProtocolVersion = 5 baseProtocolLength = uint64(16) baseProtocolMaxMsgSize = 2 * 1024 + snappyProtocolVersion = 5 + pingInterval = 15 * time.Second ) @@ -60,6 +63,38 @@ type protoHandshake struct { Rest []rlp.RawValue `rlp:"tail"` } +// PeerEventType is the type of peer events emitted by a p2p.Server +type PeerEventType string + +const ( + // PeerEventTypeAdd is the type of event emitted when a peer is added + // to a p2p.Server + PeerEventTypeAdd PeerEventType = "add" + + // PeerEventTypeDrop is the type of event emitted when a peer is + // dropped from a p2p.Server + PeerEventTypeDrop PeerEventType = "drop" + + // PeerEventTypeMsgSend is the type of event emitted when a + // message is successfully sent to a peer + PeerEventTypeMsgSend PeerEventType = "msgsend" + + // PeerEventTypeMsgRecv is the type of event emitted when a + // message is received from a peer + PeerEventTypeMsgRecv PeerEventType = "msgrecv" +) + +// PeerEvent is an event emitted when peers are either added or dropped from +// a p2p.Server or when a message is sent or received on a peer connection +type PeerEvent struct { + Type PeerEventType `json:"type"` + Peer discover.NodeID `json:"peer"` + Error string `json:"error,omitempty"` + Protocol string `json:"protocol,omitempty"` + MsgCode *uint64 `json:"msg_code,omitempty"` + MsgSize *uint32 `json:"msg_size,omitempty"` +} + // Peer represents a connected remote node. type Peer struct { rw *conn @@ -71,6 +106,9 @@ type Peer struct { protoErr chan error closed chan struct{} disc chan DiscReason + + // events receives message send / receive events if set + events *event.Feed } // NewPeer returns a peer for testing purposes. @@ -297,9 +335,13 @@ func (p *Peer) startProtocols(writeStart <-chan struct{}, writeErr chan<- error) proto.closed = p.closed proto.wstart = writeStart proto.werr = writeErr + var rw MsgReadWriter = proto + if p.events != nil { + rw = newMsgEventer(rw, p.events, p.ID(), proto.Name) + } p.log.Trace(fmt.Sprintf("Starting protocol %s/%d", proto.Name, proto.Version)) go func() { - err := proto.Run(p, proto) + err := proto.Run(p, rw) if err == nil { p.log.Trace(fmt.Sprintf("Protocol %s/%d returned", proto.Name, proto.Version)) err = errProtocolReturned diff --git a/p2p/rlpx.go b/p2p/rlpx.go index b2775cacd..24037ecc1 100644 --- a/p2p/rlpx.go +++ b/p2p/rlpx.go @@ -29,6 +29,7 @@ import ( "fmt" "hash" "io" + "io/ioutil" mrand "math/rand" "net" "sync" @@ -40,6 +41,7 @@ import ( "github.com/ethereum/go-ethereum/crypto/sha3" "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/rlp" + "github.com/golang/snappy" ) const ( @@ -68,6 +70,10 @@ const ( discWriteTimeout = 1 * time.Second ) +// errPlainMessageTooLarge is returned if a decompressed message length exceeds +// the allowed 24 bits (i.e. length >= 16MB). +var errPlainMessageTooLarge = errors.New("message length >= 16MB") + // rlpx is the transport protocol used by actual (non-test) connections. // It wraps the frame encoder with locks and read/write deadlines. type rlpx struct { @@ -127,6 +133,9 @@ func (t *rlpx) doProtoHandshake(our *protoHandshake) (their *protoHandshake, err if err := <-werr; err != nil { return nil, fmt.Errorf("write error: %v", err) } + // If the protocol version supports Snappy encoding, upgrade immediately + t.rw.snappy = their.Version >= snappyProtocolVersion + return their, nil } @@ -556,6 +565,8 @@ type rlpxFrameRW struct { macCipher cipher.Block egressMAC hash.Hash ingressMAC hash.Hash + + snappy bool } func newRLPXFrameRW(conn io.ReadWriter, s secrets) *rlpxFrameRW { @@ -583,6 +594,17 @@ func newRLPXFrameRW(conn io.ReadWriter, s secrets) *rlpxFrameRW { func (rw *rlpxFrameRW) WriteMsg(msg Msg) error { ptype, _ := rlp.EncodeToBytes(msg.Code) + // if snappy is enabled, compress message now + if rw.snappy { + if msg.Size > maxUint24 { + return errPlainMessageTooLarge + } + payload, _ := ioutil.ReadAll(msg.Payload) + payload = snappy.Encode(nil, payload) + + msg.Payload = bytes.NewReader(payload) + msg.Size = uint32(len(payload)) + } // write header headbuf := make([]byte, 32) fsize := uint32(len(ptype)) + msg.Size @@ -668,6 +690,26 @@ func (rw *rlpxFrameRW) ReadMsg() (msg Msg, err error) { } msg.Size = uint32(content.Len()) msg.Payload = content + + // if snappy is enabled, verify and decompress message + if rw.snappy { + payload, err := ioutil.ReadAll(msg.Payload) + if err != nil { + return msg, err + } + size, err := snappy.DecodedLen(payload) + if err != nil { + return msg, err + } + if size > int(maxUint24) { + return msg, errPlainMessageTooLarge + } + payload, err = snappy.Decode(nil, payload) + if err != nil { + return msg, err + } + msg.Size, msg.Payload = uint32(size), bytes.NewReader(payload) + } return msg, nil } diff --git a/p2p/server.go b/p2p/server.go index d7909d53a..d1d578401 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -27,6 +27,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discv5" @@ -130,10 +131,14 @@ type Config struct { // If Dialer is set to a non-nil value, the given Dialer // is used to dial outbound peer connections. - Dialer *net.Dialer `toml:"-"` + Dialer NodeDialer `toml:"-"` // If NoDial is true, the server will not dial any peers. NoDial bool `toml:",omitempty"` + + // If EnableMsgEvents is set then the server will emit PeerEvents + // whenever a message is sent to or received from a peer + EnableMsgEvents bool } // Server manages all peer connections. @@ -166,6 +171,7 @@ type Server struct { addpeer chan *conn delpeer chan peerDrop loopWG sync.WaitGroup // loop, listenLoop + peerFeed event.Feed } type peerOpFunc func(map[discover.NodeID]*Peer) @@ -191,7 +197,7 @@ type conn struct { fd net.Conn transport flags connFlag - cont chan error // The run loop uses cont to signal errors to setupConn. + cont chan error // The run loop uses cont to signal errors to SetupConn. id discover.NodeID // valid after the encryption handshake caps []Cap // valid after the protocol handshake name string // valid after the protocol handshake @@ -291,6 +297,11 @@ func (srv *Server) RemovePeer(node *discover.Node) { } } +// SubscribePeers subscribes the given channel to peer events +func (srv *Server) SubscribeEvents(ch chan *PeerEvent) event.Subscription { + return srv.peerFeed.Subscribe(ch) +} + // Self returns the local node's endpoint information. func (srv *Server) Self() *discover.Node { srv.lock.Lock() @@ -358,7 +369,7 @@ func (srv *Server) Start() (err error) { srv.newTransport = newRLPX } if srv.Dialer == nil { - srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout} + srv.Dialer = TCPDialer{&net.Dialer{Timeout: defaultDialTimeout}} } srv.quit = make(chan struct{}) srv.addpeer = make(chan *conn) @@ -536,7 +547,11 @@ running: c.flags |= trustedConn } // TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them. - c.cont <- srv.encHandshakeChecks(peers, c) + select { + case c.cont <- srv.encHandshakeChecks(peers, c): + case <-srv.quit: + break running + } case c := <-srv.addpeer: // At this point the connection is past the protocol handshake. // Its capabilities are known and the remote identity is verified. @@ -544,6 +559,11 @@ running: if err == nil { // The handshakes are done and it passed all checks. p := newPeer(c, srv.Protocols) + // If message events are enabled, pass the peerFeed + // to the peer + if srv.EnableMsgEvents { + p.events = &srv.peerFeed + } name := truncateName(c.name) log.Debug("Adding p2p peer", "id", c.id, "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1) peers[c.id] = p @@ -552,7 +572,11 @@ running: // The dialer logic relies on the assumption that // dial tasks complete after the peer has been added or // discarded. Unblock the task last. - c.cont <- err + select { + case c.cont <- err: + case <-srv.quit: + break running + } case pd := <-srv.delpeer: // A peer disconnected. d := common.PrettyDuration(mclock.Now() - pd.created) @@ -665,16 +689,16 @@ func (srv *Server) listenLoop() { // Spawn the handler. It will give the slot back when the connection // has been established. go func() { - srv.setupConn(fd, inboundConn, nil) + srv.SetupConn(fd, inboundConn, nil) slots <- struct{}{} }() } } -// setupConn runs the handshakes and attempts to add the connection +// SetupConn runs the handshakes and attempts to add the connection // as a peer. It returns when the connection has been added as a peer // or the handshakes have failed. -func (srv *Server) setupConn(fd net.Conn, flags connFlag, dialDest *discover.Node) { +func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *discover.Node) { // Prevent leftover pending conns from entering the handshake. srv.lock.Lock() running := srv.running @@ -755,7 +779,23 @@ func (srv *Server) runPeer(p *Peer) { if srv.newPeerHook != nil { srv.newPeerHook(p) } + + // broadcast peer add + srv.peerFeed.Send(&PeerEvent{ + Type: PeerEventTypeAdd, + Peer: p.ID(), + }) + + // run the protocol remoteRequested, err := p.run() + + // broadcast peer drop + srv.peerFeed.Send(&PeerEvent{ + Type: PeerEventTypeDrop, + Peer: p.ID(), + Error: err.Error(), + }) + // Note: run waits for existing peers to be sent on srv.delpeer // before returning, so this send should not select on srv.quit. srv.delpeer <- peerDrop{p, err, remoteRequested} diff --git a/p2p/server_test.go b/p2p/server_test.go index 971faf002..11dd83e5d 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -435,7 +435,7 @@ func TestServerSetupConn(t *testing.T) { } } p1, _ := net.Pipe() - srv.setupConn(p1, test.flags, test.dialDest) + srv.SetupConn(p1, test.flags, test.dialDest) if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) { t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr) } diff --git a/p2p/simulations/README.md b/p2p/simulations/README.md new file mode 100644 index 000000000..d1f8649ea --- /dev/null +++ b/p2p/simulations/README.md @@ -0,0 +1,181 @@ +# devp2p Simulations + +The `p2p/simulations` package implements a simulation framework which supports +creating a collection of devp2p nodes, connecting them together to form a +simulation network, performing simulation actions in that network and then +extracting useful information. + +## Nodes + +Each node in a simulation network runs multiple services by wrapping a collection +of objects which implement the `node.Service` interface meaning they: + +* can be started and stopped +* run p2p protocols +* expose RPC APIs + +This means that any object which implements the `node.Service` interface can be +used to run a node in the simulation. + +## Services + +Before running a simulation, a set of service initializers must be registered +which can then be used to run nodes in the network. + +A service initializer is a function with the following signature: + +```go +func(ctx *adapters.ServiceContext) (node.Service, error) +``` + +These initializers should be registered by calling the `adapters.RegisterServices` +function in an `init()` hook: + +```go +func init() { + adapters.RegisterServices(adapters.Services{ + "service1": initService1, + "service2": initService2, + }) +} +``` + +## Node Adapters + +The simulation framework includes multiple "node adapters" which are +responsible for creating an environment in which a node runs. + +### SimAdapter + +The `SimAdapter` runs nodes in-memory, connecting them using an in-memory, +synchronous `net.Pipe` and connecting to their RPC server using an in-memory +`rpc.Client`. + +### ExecAdapter + +The `ExecAdapter` runs nodes as child processes of the running simulation. + +It does this by executing the binary which is running the simulation but +setting `argv[0]` (i.e. the program name) to `p2p-node` which is then +detected by an init hook in the child process which runs the `node.Service` +using the devp2p node stack rather than executing `main()`. + +The nodes listen for devp2p connections and WebSocket RPC clients on random +localhost ports. + +### DockerAdapter + +The `DockerAdapter` is similar to the `ExecAdapter` but executes `docker run` +to run the node in a Docker container using a Docker image containing the +simulation binary at `/bin/p2p-node`. + +The Docker image is built using `docker build` when the adapter is initialised, +meaning no prior setup is necessary other than having a working Docker client. + +Each node listens on the external IP of the container and the default p2p and +RPC ports (`30303` and `8546` respectively). + +## Network + +A simulation network is created with an ID and default service (which is used +if a node is created without an explicit service), exposes methods for +creating, starting, stopping, connecting and disconnecting nodes, and emits +events when certain actions occur. + +### Events + +A simulation network emits the following events: + +* node event - when nodes are created / started / stopped +* connection event - when nodes are connected / disconnected +* message event - when a protocol message is sent between two nodes + +The events have a "control" flag which when set indicates that the event is the +outcome of a controlled simulation action (e.g. creating a node or explicitly +connecting two nodes together). + +This is in contrast to a non-control event, otherwise called a "live" event, +which is the outcome of something happening in the network as a result of a +control event (e.g. a node actually started up or a connection was actually +established between two nodes). + +Live events are detected by the simulation network by subscribing to node peer +events via RPC when the nodes start up. + +## Testing Framework + +The `Simulation` type can be used in tests to perform actions in a simulation +network and then wait for expectations to be met. + +With a running simulation network, the `Simulation.Run` method can be called +with a `Step` which has the following fields: + +* `Action` - a function which performs some action in the network + +* `Expect` - an expectation function which returns whether or not a + given node meets the expectation + +* `Trigger` - a channel which receives node IDs which then trigger a check + of the expectation function to be performed against that node + +As a concrete example, consider a simulated network of Ethereum nodes. An +`Action` could be the sending of a transaction, `Expect` it being included in +a block, and `Trigger` a check for every block that is mined. + +On return, the `Simulation.Run` method returns a `StepResult` which can be used +to determine if all nodes met the expectation, how long it took them to meet +the expectation and what network events were emitted during the step run. + +## HTTP API + +The simulation framework includes a HTTP API which can be used to control the +simulation. + +The API is initialised with a particular node adapter and has the following +endpoints: + +``` +GET / Get network information +POST /start Start all nodes in the network +POST /stop Stop all nodes in the network +GET /events Stream network events +GET /snapshot Take a network snapshot +POST /snapshot Load a network snapshot +POST /nodes Create a node +GET /nodes Get all nodes in the network +GET /nodes/:nodeid Get node information +POST /nodes/:nodeid/start Start a node +POST /nodes/:nodeid/stop Stop a node +POST /nodes/:nodeid/conn/:peerid Connect two nodes +DELETE /nodes/:nodeid/conn/:peerid Disconnect two nodes +GET /nodes/:nodeid/rpc Make RPC requests to a node via WebSocket +``` + +For convenience, `nodeid` in the URL can be the name of a node rather than its +ID. + +## Command line client + +`p2psim` is a command line client for the HTTP API, located in +`cmd/p2psim`. + +It provides the following commands: + +``` +p2psim show +p2psim events [--current] [--filter=FILTER] +p2psim snapshot +p2psim load +p2psim node create [--name=NAME] [--services=SERVICES] [--key=KEY] +p2psim node list +p2psim node show <node> +p2psim node start <node> +p2psim node stop <node> +p2psim node connect <node> <peer> +p2psim node disconnect <node> <peer> +p2psim node rpc <node> <method> [<args>] [--subscribe] +``` + +## Example + +See [p2p/simulations/examples/README.md](examples/README.md). diff --git a/p2p/simulations/adapters/docker.go b/p2p/simulations/adapters/docker.go new file mode 100644 index 000000000..022314b3d --- /dev/null +++ b/p2p/simulations/adapters/docker.go @@ -0,0 +1,182 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package adapters + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + + "github.com/docker/docker/pkg/reexec" + "github.com/ethereum/go-ethereum/node" + "github.com/ethereum/go-ethereum/p2p/discover" +) + +// DockerAdapter is a NodeAdapter which runs simulation nodes inside Docker +// containers. +// +// A Docker image is built which contains the current binary at /bin/p2p-node +// which when executed runs the underlying service (see the description +// of the execP2PNode function for more details) +type DockerAdapter struct { + ExecAdapter +} + +// NewDockerAdapter builds the p2p-node Docker image containing the current +// binary and returns a DockerAdapter +func NewDockerAdapter() (*DockerAdapter, error) { + // Since Docker containers run on Linux and this adapter runs the + // current binary in the container, it must be compiled for Linux. + // + // It is reasonable to require this because the caller can just + // compile the current binary in a Docker container. + if runtime.GOOS != "linux" { + return nil, errors.New("DockerAdapter can only be used on Linux as it uses the current binary (which must be a Linux binary)") + } + + if err := buildDockerImage(); err != nil { + return nil, err + } + + return &DockerAdapter{ + ExecAdapter{ + nodes: make(map[discover.NodeID]*ExecNode), + }, + }, nil +} + +// Name returns the name of the adapter for logging purposes +func (d *DockerAdapter) Name() string { + return "docker-adapter" +} + +// NewNode returns a new DockerNode using the given config +func (d *DockerAdapter) NewNode(config *NodeConfig) (Node, error) { + if len(config.Services) == 0 { + return nil, errors.New("node must have at least one service") + } + for _, service := range config.Services { + if _, exists := serviceFuncs[service]; !exists { + return nil, fmt.Errorf("unknown node service %q", service) + } + } + + // generate the config + conf := &execNodeConfig{ + Stack: node.DefaultConfig, + Node: config, + } + conf.Stack.DataDir = "/data" + conf.Stack.WSHost = "0.0.0.0" + conf.Stack.WSOrigins = []string{"*"} + conf.Stack.WSExposeAll = true + conf.Stack.P2P.EnableMsgEvents = false + conf.Stack.P2P.NoDiscovery = true + conf.Stack.P2P.NAT = nil + conf.Stack.NoUSB = true + + node := &DockerNode{ + ExecNode: ExecNode{ + ID: config.ID, + Config: conf, + adapter: &d.ExecAdapter, + }, + } + node.newCmd = node.dockerCommand + d.ExecAdapter.nodes[node.ID] = &node.ExecNode + return node, nil +} + +// DockerNode wraps an ExecNode but exec's the current binary in a docker +// container rather than locally +type DockerNode struct { + ExecNode +} + +// dockerCommand returns a command which exec's the binary in a Docker +// container. +// +// It uses a shell so that we can pass the _P2P_NODE_CONFIG environment +// variable to the container using the --env flag. +func (n *DockerNode) dockerCommand() *exec.Cmd { + return exec.Command( + "sh", "-c", + fmt.Sprintf( + `exec docker run --interactive --env _P2P_NODE_CONFIG="${_P2P_NODE_CONFIG}" %s p2p-node %s %s`, + dockerImage, strings.Join(n.Config.Node.Services, ","), n.ID.String(), + ), + ) +} + +// dockerImage is the name of the Docker image which gets built to run the +// simulation node +const dockerImage = "p2p-node" + +// buildDockerImage builds the Docker image which is used to run the simulation +// node in a Docker container. +// +// It adds the current binary as "p2p-node" so that it runs execP2PNode +// when executed. +func buildDockerImage() error { + // create a directory to use as the build context + dir, err := ioutil.TempDir("", "p2p-docker") + if err != nil { + return err + } + defer os.RemoveAll(dir) + + // copy the current binary into the build context + bin, err := os.Open(reexec.Self()) + if err != nil { + return err + } + defer bin.Close() + dst, err := os.OpenFile(filepath.Join(dir, "self.bin"), os.O_WRONLY|os.O_CREATE, 0755) + if err != nil { + return err + } + defer dst.Close() + if _, err := io.Copy(dst, bin); err != nil { + return err + } + + // create the Dockerfile + dockerfile := []byte(` +FROM ubuntu:16.04 +RUN mkdir /data +ADD self.bin /bin/p2p-node + `) + if err := ioutil.WriteFile(filepath.Join(dir, "Dockerfile"), dockerfile, 0644); err != nil { + return err + } + + // run 'docker build' + cmd := exec.Command("docker", "build", "-t", dockerImage, dir) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("error building docker image: %s", err) + } + + return nil +} diff --git a/p2p/simulations/adapters/exec.go b/p2p/simulations/adapters/exec.go new file mode 100644 index 000000000..bdb92cc1d --- /dev/null +++ b/p2p/simulations/adapters/exec.go @@ -0,0 +1,504 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package adapters + +import ( + "bufio" + "context" + "crypto/ecdsa" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "os" + "os/exec" + "os/signal" + "path/filepath" + "regexp" + "strings" + "sync" + "syscall" + "time" + + "github.com/docker/docker/pkg/reexec" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/node" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/rpc" + "golang.org/x/net/websocket" +) + +// ExecAdapter is a NodeAdapter which runs simulation nodes by executing the +// current binary as a child process. +// +// An init hook is used so that the child process executes the node services +// (rather than whataver the main() function would normally do), see the +// execP2PNode function for more information. +type ExecAdapter struct { + // BaseDir is the directory under which the data directories for each + // simulation node are created. + BaseDir string + + nodes map[discover.NodeID]*ExecNode +} + +// NewExecAdapter returns an ExecAdapter which stores node data in +// subdirectories of the given base directory +func NewExecAdapter(baseDir string) *ExecAdapter { + return &ExecAdapter{ + BaseDir: baseDir, + nodes: make(map[discover.NodeID]*ExecNode), + } +} + +// Name returns the name of the adapter for logging purposes +func (e *ExecAdapter) Name() string { + return "exec-adapter" +} + +// NewNode returns a new ExecNode using the given config +func (e *ExecAdapter) NewNode(config *NodeConfig) (Node, error) { + if len(config.Services) == 0 { + return nil, errors.New("node must have at least one service") + } + for _, service := range config.Services { + if _, exists := serviceFuncs[service]; !exists { + return nil, fmt.Errorf("unknown node service %q", service) + } + } + + // create the node directory using the first 12 characters of the ID + // as Unix socket paths cannot be longer than 256 characters + dir := filepath.Join(e.BaseDir, config.ID.String()[:12]) + if err := os.Mkdir(dir, 0755); err != nil { + return nil, fmt.Errorf("error creating node directory: %s", err) + } + + // generate the config + conf := &execNodeConfig{ + Stack: node.DefaultConfig, + Node: config, + } + conf.Stack.DataDir = filepath.Join(dir, "data") + conf.Stack.WSHost = "127.0.0.1" + conf.Stack.WSPort = 0 + conf.Stack.WSOrigins = []string{"*"} + conf.Stack.WSExposeAll = true + conf.Stack.P2P.EnableMsgEvents = false + conf.Stack.P2P.NoDiscovery = true + conf.Stack.P2P.NAT = nil + conf.Stack.NoUSB = true + + // listen on a random localhost port (we'll get the actual port after + // starting the node through the RPC admin.nodeInfo method) + conf.Stack.P2P.ListenAddr = "127.0.0.1:0" + + node := &ExecNode{ + ID: config.ID, + Dir: dir, + Config: conf, + adapter: e, + } + node.newCmd = node.execCommand + e.nodes[node.ID] = node + return node, nil +} + +// ExecNode starts a simulation node by exec'ing the current binary and +// running the configured services +type ExecNode struct { + ID discover.NodeID + Dir string + Config *execNodeConfig + Cmd *exec.Cmd + Info *p2p.NodeInfo + + adapter *ExecAdapter + client *rpc.Client + wsAddr string + newCmd func() *exec.Cmd + key *ecdsa.PrivateKey +} + +// Addr returns the node's enode URL +func (n *ExecNode) Addr() []byte { + if n.Info == nil { + return nil + } + return []byte(n.Info.Enode) +} + +// Client returns an rpc.Client which can be used to communicate with the +// underlying services (it is set once the node has started) +func (n *ExecNode) Client() (*rpc.Client, error) { + return n.client, nil +} + +// wsAddrPattern is a regex used to read the WebSocket address from the node's +// log +var wsAddrPattern = regexp.MustCompile(`ws://[\d.:]+`) + +// Start exec's the node passing the ID and service as command line arguments +// and the node config encoded as JSON in the _P2P_NODE_CONFIG environment +// variable +func (n *ExecNode) Start(snapshots map[string][]byte) (err error) { + if n.Cmd != nil { + return errors.New("already started") + } + defer func() { + if err != nil { + log.Error("node failed to start", "err", err) + n.Stop() + } + }() + + // encode a copy of the config containing the snapshot + confCopy := *n.Config + confCopy.Snapshots = snapshots + confCopy.PeerAddrs = make(map[string]string) + for id, node := range n.adapter.nodes { + confCopy.PeerAddrs[id.String()] = node.wsAddr + } + confData, err := json.Marshal(confCopy) + if err != nil { + return fmt.Errorf("error generating node config: %s", err) + } + + // use a pipe for stderr so we can both copy the node's stderr to + // os.Stderr and read the WebSocket address from the logs + stderrR, stderrW := io.Pipe() + stderr := io.MultiWriter(os.Stderr, stderrW) + + // start the node + cmd := n.newCmd() + cmd.Stdout = os.Stdout + cmd.Stderr = stderr + cmd.Env = append(os.Environ(), fmt.Sprintf("_P2P_NODE_CONFIG=%s", confData)) + if err := cmd.Start(); err != nil { + return fmt.Errorf("error starting node: %s", err) + } + n.Cmd = cmd + + // read the WebSocket address from the stderr logs + var wsAddr string + wsAddrC := make(chan string) + go func() { + s := bufio.NewScanner(stderrR) + for s.Scan() { + if strings.Contains(s.Text(), "WebSocket endpoint opened:") { + wsAddrC <- wsAddrPattern.FindString(s.Text()) + } + } + }() + select { + case wsAddr = <-wsAddrC: + if wsAddr == "" { + return errors.New("failed to read WebSocket address from stderr") + } + case <-time.After(10 * time.Second): + return errors.New("timed out waiting for WebSocket address on stderr") + } + + // create the RPC client and load the node info + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + client, err := rpc.DialWebsocket(ctx, wsAddr, "") + if err != nil { + return fmt.Errorf("error dialing rpc websocket: %s", err) + } + var info p2p.NodeInfo + if err := client.CallContext(ctx, &info, "admin_nodeInfo"); err != nil { + return fmt.Errorf("error getting node info: %s", err) + } + n.client = client + n.wsAddr = wsAddr + n.Info = &info + + return nil +} + +// execCommand returns a command which runs the node locally by exec'ing +// the current binary but setting argv[0] to "p2p-node" so that the child +// runs execP2PNode +func (n *ExecNode) execCommand() *exec.Cmd { + return &exec.Cmd{ + Path: reexec.Self(), + Args: []string{"p2p-node", strings.Join(n.Config.Node.Services, ","), n.ID.String()}, + } +} + +// Stop stops the node by first sending SIGTERM and then SIGKILL if the node +// doesn't stop within 5s +func (n *ExecNode) Stop() error { + if n.Cmd == nil { + return nil + } + defer func() { + n.Cmd = nil + }() + + if n.client != nil { + n.client.Close() + n.client = nil + n.wsAddr = "" + n.Info = nil + } + + if err := n.Cmd.Process.Signal(syscall.SIGTERM); err != nil { + return n.Cmd.Process.Kill() + } + waitErr := make(chan error) + go func() { + waitErr <- n.Cmd.Wait() + }() + select { + case err := <-waitErr: + return err + case <-time.After(5 * time.Second): + return n.Cmd.Process.Kill() + } +} + +// NodeInfo returns information about the node +func (n *ExecNode) NodeInfo() *p2p.NodeInfo { + info := &p2p.NodeInfo{ + ID: n.ID.String(), + } + if n.client != nil { + n.client.Call(&info, "admin_nodeInfo") + } + return info +} + +// ServeRPC serves RPC requests over the given connection by dialling the +// node's WebSocket address and joining the two connections +func (n *ExecNode) ServeRPC(clientConn net.Conn) error { + conn, err := websocket.Dial(n.wsAddr, "", "http://localhost") + if err != nil { + return err + } + var wg sync.WaitGroup + wg.Add(2) + join := func(src, dst net.Conn) { + defer wg.Done() + io.Copy(dst, src) + // close the write end of the destination connection + if cw, ok := dst.(interface { + CloseWrite() error + }); ok { + cw.CloseWrite() + } else { + dst.Close() + } + } + go join(conn, clientConn) + go join(clientConn, conn) + wg.Wait() + return nil +} + +// Snapshots creates snapshots of the services by calling the +// simulation_snapshot RPC method +func (n *ExecNode) Snapshots() (map[string][]byte, error) { + if n.client == nil { + return nil, errors.New("RPC not started") + } + var snapshots map[string][]byte + return snapshots, n.client.Call(&snapshots, "simulation_snapshot") +} + +func init() { + // register a reexec function to start a devp2p node when the current + // binary is executed as "p2p-node" + reexec.Register("p2p-node", execP2PNode) +} + +// execNodeConfig is used to serialize the node configuration so it can be +// passed to the child process as a JSON encoded environment variable +type execNodeConfig struct { + Stack node.Config `json:"stack"` + Node *NodeConfig `json:"node"` + Snapshots map[string][]byte `json:"snapshots,omitempty"` + PeerAddrs map[string]string `json:"peer_addrs,omitempty"` +} + +// execP2PNode starts a devp2p node when the current binary is executed with +// argv[0] being "p2p-node", reading the service / ID from argv[1] / argv[2] +// and the node config from the _P2P_NODE_CONFIG environment variable +func execP2PNode() { + glogger := log.NewGlogHandler(log.StreamHandler(os.Stderr, log.LogfmtFormat())) + glogger.Verbosity(log.LvlInfo) + log.Root().SetHandler(glogger) + + // read the services from argv + serviceNames := strings.Split(os.Args[1], ",") + + // decode the config + confEnv := os.Getenv("_P2P_NODE_CONFIG") + if confEnv == "" { + log.Crit("missing _P2P_NODE_CONFIG") + } + var conf execNodeConfig + if err := json.Unmarshal([]byte(confEnv), &conf); err != nil { + log.Crit("error decoding _P2P_NODE_CONFIG", "err", err) + } + conf.Stack.P2P.PrivateKey = conf.Node.PrivateKey + + // use explicit IP address in ListenAddr so that Enode URL is usable + externalIP := func() string { + addrs, err := net.InterfaceAddrs() + if err != nil { + log.Crit("error getting IP address", "err", err) + } + for _, addr := range addrs { + if ip, ok := addr.(*net.IPNet); ok && !ip.IP.IsLoopback() { + return ip.IP.String() + } + } + log.Crit("unable to determine explicit IP address") + return "" + } + if strings.HasPrefix(conf.Stack.P2P.ListenAddr, ":") { + conf.Stack.P2P.ListenAddr = externalIP() + conf.Stack.P2P.ListenAddr + } + if conf.Stack.WSHost == "0.0.0.0" { + conf.Stack.WSHost = externalIP() + } + + // initialize the devp2p stack + stack, err := node.New(&conf.Stack) + if err != nil { + log.Crit("error creating node stack", "err", err) + } + + // register the services, collecting them into a map so we can wrap + // them in a snapshot service + services := make(map[string]node.Service, len(serviceNames)) + for _, name := range serviceNames { + serviceFunc, exists := serviceFuncs[name] + if !exists { + log.Crit("unknown node service", "name", name) + } + constructor := func(nodeCtx *node.ServiceContext) (node.Service, error) { + ctx := &ServiceContext{ + RPCDialer: &wsRPCDialer{addrs: conf.PeerAddrs}, + NodeContext: nodeCtx, + Config: conf.Node, + } + if conf.Snapshots != nil { + ctx.Snapshot = conf.Snapshots[name] + } + service, err := serviceFunc(ctx) + if err != nil { + return nil, err + } + services[name] = service + return service, nil + } + if err := stack.Register(constructor); err != nil { + log.Crit("error starting service", "name", name, "err", err) + } + } + + // register the snapshot service + if err := stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { + return &snapshotService{services}, nil + }); err != nil { + log.Crit("error starting snapshot service", "err", err) + } + + // start the stack + if err := stack.Start(); err != nil { + log.Crit("error stating node stack", "err", err) + } + + // stop the stack if we get a SIGTERM signal + go func() { + sigc := make(chan os.Signal, 1) + signal.Notify(sigc, syscall.SIGTERM) + defer signal.Stop(sigc) + <-sigc + log.Info("Received SIGTERM, shutting down...") + stack.Stop() + }() + + // wait for the stack to exit + stack.Wait() +} + +// snapshotService is a node.Service which wraps a list of services and +// exposes an API to generate a snapshot of those services +type snapshotService struct { + services map[string]node.Service +} + +func (s *snapshotService) APIs() []rpc.API { + return []rpc.API{{ + Namespace: "simulation", + Version: "1.0", + Service: SnapshotAPI{s.services}, + }} +} + +func (s *snapshotService) Protocols() []p2p.Protocol { + return nil +} + +func (s *snapshotService) Start(*p2p.Server) error { + return nil +} + +func (s *snapshotService) Stop() error { + return nil +} + +// SnapshotAPI provides an RPC method to create snapshots of services +type SnapshotAPI struct { + services map[string]node.Service +} + +func (api SnapshotAPI) Snapshot() (map[string][]byte, error) { + snapshots := make(map[string][]byte) + for name, service := range api.services { + if s, ok := service.(interface { + Snapshot() ([]byte, error) + }); ok { + snap, err := s.Snapshot() + if err != nil { + return nil, err + } + snapshots[name] = snap + } + } + return snapshots, nil +} + +type wsRPCDialer struct { + addrs map[string]string +} + +// DialRPC implements the RPCDialer interface by creating a WebSocket RPC +// client of the given node +func (w *wsRPCDialer) DialRPC(id discover.NodeID) (*rpc.Client, error) { + addr, ok := w.addrs[id.String()] + if !ok { + return nil, fmt.Errorf("unknown node: %s", id) + } + return rpc.DialWebsocket(context.Background(), addr, "http://localhost") +} diff --git a/p2p/simulations/adapters/inproc.go b/p2p/simulations/adapters/inproc.go new file mode 100644 index 000000000..c97188def --- /dev/null +++ b/p2p/simulations/adapters/inproc.go @@ -0,0 +1,314 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package adapters + +import ( + "errors" + "fmt" + "math" + "net" + "sync" + + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/node" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/rpc" +) + +// SimAdapter is a NodeAdapter which creates in-memory simulation nodes and +// connects them using in-memory net.Pipe connections +type SimAdapter struct { + mtx sync.RWMutex + nodes map[discover.NodeID]*SimNode + services map[string]ServiceFunc +} + +// NewSimAdapter creates a SimAdapter which is capable of running in-memory +// simulation nodes running any of the given services (the services to run on a +// particular node are passed to the NewNode function in the NodeConfig) +func NewSimAdapter(services map[string]ServiceFunc) *SimAdapter { + return &SimAdapter{ + nodes: make(map[discover.NodeID]*SimNode), + services: services, + } +} + +// Name returns the name of the adapter for logging purposes +func (s *SimAdapter) Name() string { + return "sim-adapter" +} + +// NewNode returns a new SimNode using the given config +func (s *SimAdapter) NewNode(config *NodeConfig) (Node, error) { + s.mtx.Lock() + defer s.mtx.Unlock() + + // check a node with the ID doesn't already exist + id := config.ID + if _, exists := s.nodes[id]; exists { + return nil, fmt.Errorf("node already exists: %s", id) + } + + // check the services are valid + if len(config.Services) == 0 { + return nil, errors.New("node must have at least one service") + } + for _, service := range config.Services { + if _, exists := s.services[service]; !exists { + return nil, fmt.Errorf("unknown node service %q", service) + } + } + + n, err := node.New(&node.Config{ + P2P: p2p.Config{ + PrivateKey: config.PrivateKey, + MaxPeers: math.MaxInt32, + NoDiscovery: true, + Dialer: s, + EnableMsgEvents: true, + }, + NoUSB: true, + }) + if err != nil { + return nil, err + } + + simNode := &SimNode{ + ID: id, + config: config, + node: n, + adapter: s, + running: make(map[string]node.Service), + } + s.nodes[id] = simNode + return simNode, nil +} + +// Dial implements the p2p.NodeDialer interface by connecting to the node using +// an in-memory net.Pipe connection +func (s *SimAdapter) Dial(dest *discover.Node) (conn net.Conn, err error) { + node, ok := s.GetNode(dest.ID) + if !ok { + return nil, fmt.Errorf("unknown node: %s", dest.ID) + } + srv := node.Server() + if srv == nil { + return nil, fmt.Errorf("node not running: %s", dest.ID) + } + pipe1, pipe2 := net.Pipe() + go srv.SetupConn(pipe1, 0, nil) + return pipe2, nil +} + +// DialRPC implements the RPCDialer interface by creating an in-memory RPC +// client of the given node +func (s *SimAdapter) DialRPC(id discover.NodeID) (*rpc.Client, error) { + node, ok := s.GetNode(id) + if !ok { + return nil, fmt.Errorf("unknown node: %s", id) + } + handler, err := node.node.RPCHandler() + if err != nil { + return nil, err + } + return rpc.DialInProc(handler), nil +} + +// GetNode returns the node with the given ID if it exists +func (s *SimAdapter) GetNode(id discover.NodeID) (*SimNode, bool) { + s.mtx.RLock() + defer s.mtx.RUnlock() + node, ok := s.nodes[id] + return node, ok +} + +// SimNode is an in-memory simulation node which connects to other nodes using +// an in-memory net.Pipe connection (see SimAdapter.Dial), running devp2p +// protocols directly over that pipe +type SimNode struct { + lock sync.RWMutex + ID discover.NodeID + config *NodeConfig + adapter *SimAdapter + node *node.Node + running map[string]node.Service + client *rpc.Client + registerOnce sync.Once +} + +// Addr returns the node's discovery address +func (self *SimNode) Addr() []byte { + return []byte(self.Node().String()) +} + +// Node returns a discover.Node representing the SimNode +func (self *SimNode) Node() *discover.Node { + return discover.NewNode(self.ID, net.IP{127, 0, 0, 1}, 30303, 30303) +} + +// Client returns an rpc.Client which can be used to communicate with the +// underlying services (it is set once the node has started) +func (self *SimNode) Client() (*rpc.Client, error) { + self.lock.RLock() + defer self.lock.RUnlock() + if self.client == nil { + return nil, errors.New("node not started") + } + return self.client, nil +} + +// ServeRPC serves RPC requests over the given connection by creating an +// in-memory client to the node's RPC server +func (self *SimNode) ServeRPC(conn net.Conn) error { + handler, err := self.node.RPCHandler() + if err != nil { + return err + } + handler.ServeCodec(rpc.NewJSONCodec(conn), rpc.OptionMethodInvocation|rpc.OptionSubscriptions) + return nil +} + +// Snapshots creates snapshots of the services by calling the +// simulation_snapshot RPC method +func (self *SimNode) Snapshots() (map[string][]byte, error) { + self.lock.RLock() + services := make(map[string]node.Service, len(self.running)) + for name, service := range self.running { + services[name] = service + } + self.lock.RUnlock() + if len(services) == 0 { + return nil, errors.New("no running services") + } + snapshots := make(map[string][]byte) + for name, service := range services { + if s, ok := service.(interface { + Snapshot() ([]byte, error) + }); ok { + snap, err := s.Snapshot() + if err != nil { + return nil, err + } + snapshots[name] = snap + } + } + return snapshots, nil +} + +// Start registers the services and starts the underlying devp2p node +func (self *SimNode) Start(snapshots map[string][]byte) error { + newService := func(name string) func(ctx *node.ServiceContext) (node.Service, error) { + return func(nodeCtx *node.ServiceContext) (node.Service, error) { + ctx := &ServiceContext{ + RPCDialer: self.adapter, + NodeContext: nodeCtx, + Config: self.config, + } + if snapshots != nil { + ctx.Snapshot = snapshots[name] + } + serviceFunc := self.adapter.services[name] + service, err := serviceFunc(ctx) + if err != nil { + return nil, err + } + self.running[name] = service + return service, nil + } + } + + // ensure we only register the services once in the case of the node + // being stopped and then started again + var regErr error + self.registerOnce.Do(func() { + for _, name := range self.config.Services { + if err := self.node.Register(newService(name)); err != nil { + regErr = err + return + } + } + }) + if regErr != nil { + return regErr + } + + if err := self.node.Start(); err != nil { + return err + } + + // create an in-process RPC client + handler, err := self.node.RPCHandler() + if err != nil { + return err + } + + self.lock.Lock() + self.client = rpc.DialInProc(handler) + self.lock.Unlock() + + return nil +} + +// Stop closes the RPC client and stops the underlying devp2p node +func (self *SimNode) Stop() error { + self.lock.Lock() + if self.client != nil { + self.client.Close() + self.client = nil + } + self.lock.Unlock() + return self.node.Stop() +} + +// Services returns a copy of the underlying services +func (self *SimNode) Services() []node.Service { + self.lock.RLock() + defer self.lock.RUnlock() + services := make([]node.Service, 0, len(self.running)) + for _, service := range self.running { + services = append(services, service) + } + return services +} + +// Server returns the underlying p2p.Server +func (self *SimNode) Server() *p2p.Server { + return self.node.Server() +} + +// SubscribeEvents subscribes the given channel to peer events from the +// underlying p2p.Server +func (self *SimNode) SubscribeEvents(ch chan *p2p.PeerEvent) event.Subscription { + srv := self.Server() + if srv == nil { + panic("node not running") + } + return srv.SubscribeEvents(ch) +} + +// NodeInfo returns information about the node +func (self *SimNode) NodeInfo() *p2p.NodeInfo { + server := self.Server() + if server == nil { + return &p2p.NodeInfo{ + ID: self.ID.String(), + Enode: self.Node().String(), + } + } + return server.NodeInfo() +} diff --git a/p2p/simulations/adapters/types.go b/p2p/simulations/adapters/types.go new file mode 100644 index 000000000..ed6cfc504 --- /dev/null +++ b/p2p/simulations/adapters/types.go @@ -0,0 +1,215 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package adapters + +import ( + "crypto/ecdsa" + "encoding/hex" + "encoding/json" + "fmt" + "net" + "os" + + "github.com/docker/docker/pkg/reexec" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/node" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/rpc" +) + +// Node represents a node in a simulation network which is created by a +// NodeAdapter, for example: +// +// * SimNode - An in-memory node +// * ExecNode - A child process node +// * DockerNode - A Docker container node +// +type Node interface { + // Addr returns the node's address (e.g. an Enode URL) + Addr() []byte + + // Client returns the RPC client which is created once the node is + // up and running + Client() (*rpc.Client, error) + + // ServeRPC serves RPC requests over the given connection + ServeRPC(net.Conn) error + + // Start starts the node with the given snapshots + Start(snapshots map[string][]byte) error + + // Stop stops the node + Stop() error + + // NodeInfo returns information about the node + NodeInfo() *p2p.NodeInfo + + // Snapshots creates snapshots of the running services + Snapshots() (map[string][]byte, error) +} + +// NodeAdapter is used to create Nodes in a simulation network +type NodeAdapter interface { + // Name returns the name of the adapter for logging purposes + Name() string + + // NewNode creates a new node with the given configuration + NewNode(config *NodeConfig) (Node, error) +} + +// NodeConfig is the configuration used to start a node in a simulation +// network +type NodeConfig struct { + // ID is the node's ID which is used to identify the node in the + // simulation network + ID discover.NodeID + + // PrivateKey is the node's private key which is used by the devp2p + // stack to encrypt communications + PrivateKey *ecdsa.PrivateKey + + // Name is a human friendly name for the node like "node01" + Name string + + // Services are the names of the services which should be run when + // starting the node (for SimNodes it should be the names of services + // contained in SimAdapter.services, for other nodes it should be + // services registered by calling the RegisterService function) + Services []string +} + +// nodeConfigJSON is used to encode and decode NodeConfig as JSON by encoding +// all fields as strings +type nodeConfigJSON struct { + ID string `json:"id"` + PrivateKey string `json:"private_key"` + Name string `json:"name"` + Services []string `json:"services"` +} + +// MarshalJSON implements the json.Marshaler interface by encoding the config +// fields as strings +func (n *NodeConfig) MarshalJSON() ([]byte, error) { + confJSON := nodeConfigJSON{ + ID: n.ID.String(), + Name: n.Name, + Services: n.Services, + } + if n.PrivateKey != nil { + confJSON.PrivateKey = hex.EncodeToString(crypto.FromECDSA(n.PrivateKey)) + } + return json.Marshal(confJSON) +} + +// UnmarshalJSON implements the json.Unmarshaler interface by decoding the json +// string values into the config fields +func (n *NodeConfig) UnmarshalJSON(data []byte) error { + var confJSON nodeConfigJSON + if err := json.Unmarshal(data, &confJSON); err != nil { + return err + } + + if confJSON.ID != "" { + nodeID, err := discover.HexID(confJSON.ID) + if err != nil { + return err + } + n.ID = nodeID + } + + if confJSON.PrivateKey != "" { + key, err := hex.DecodeString(confJSON.PrivateKey) + if err != nil { + return err + } + privKey, err := crypto.ToECDSA(key) + if err != nil { + return err + } + n.PrivateKey = privKey + } + + n.Name = confJSON.Name + n.Services = confJSON.Services + + return nil +} + +// RandomNodeConfig returns node configuration with a randomly generated ID and +// PrivateKey +func RandomNodeConfig() *NodeConfig { + key, err := crypto.GenerateKey() + if err != nil { + panic("unable to generate key") + } + var id discover.NodeID + pubkey := crypto.FromECDSAPub(&key.PublicKey) + copy(id[:], pubkey[1:]) + return &NodeConfig{ + ID: id, + PrivateKey: key, + } +} + +// ServiceContext is a collection of options and methods which can be utilised +// when starting services +type ServiceContext struct { + RPCDialer + + NodeContext *node.ServiceContext + Config *NodeConfig + Snapshot []byte +} + +// RPCDialer is used when initialising services which need to connect to +// other nodes in the network (for example a simulated Swarm node which needs +// to connect to a Geth node to resolve ENS names) +type RPCDialer interface { + DialRPC(id discover.NodeID) (*rpc.Client, error) +} + +// Services is a collection of services which can be run in a simulation +type Services map[string]ServiceFunc + +// ServiceFunc returns a node.Service which can be used to boot a devp2p node +type ServiceFunc func(ctx *ServiceContext) (node.Service, error) + +// serviceFuncs is a map of registered services which are used to boot devp2p +// nodes +var serviceFuncs = make(Services) + +// RegisterServices registers the given Services which can then be used to +// start devp2p nodes using either the Exec or Docker adapters. +// +// It should be called in an init function so that it has the opportunity to +// execute the services before main() is called. +func RegisterServices(services Services) { + for name, f := range services { + if _, exists := serviceFuncs[name]; exists { + panic(fmt.Sprintf("node service already exists: %q", name)) + } + serviceFuncs[name] = f + } + + // now we have registered the services, run reexec.Init() which will + // potentially start one of the services if the current binary has + // been exec'd with argv[0] set to "p2p-node" + if reexec.Init() { + os.Exit(0) + } +} diff --git a/p2p/simulations/events.go b/p2p/simulations/events.go new file mode 100644 index 000000000..f17958c68 --- /dev/null +++ b/p2p/simulations/events.go @@ -0,0 +1,108 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package simulations + +import ( + "fmt" + "time" +) + +// EventType is the type of event emitted by a simulation network +type EventType string + +const ( + // EventTypeNode is the type of event emitted when a node is either + // created, started or stopped + EventTypeNode EventType = "node" + + // EventTypeConn is the type of event emitted when a connection is + // is either established or dropped between two nodes + EventTypeConn EventType = "conn" + + // EventTypeMsg is the type of event emitted when a p2p message it + // sent between two nodes + EventTypeMsg EventType = "msg" +) + +// Event is an event emitted by a simulation network +type Event struct { + // Type is the type of the event + Type EventType `json:"type"` + + // Time is the time the event happened + Time time.Time `json:"time"` + + // Control indicates whether the event is the result of a controlled + // action in the network + Control bool `json:"control"` + + // Node is set if the type is EventTypeNode + Node *Node `json:"node,omitempty"` + + // Conn is set if the type is EventTypeConn + Conn *Conn `json:"conn,omitempty"` + + // Msg is set if the type is EventTypeMsg + Msg *Msg `json:"msg,omitempty"` +} + +// NewEvent creates a new event for the given object which should be either a +// Node, Conn or Msg. +// +// The object is copied so that the event represents the state of the object +// when NewEvent is called. +func NewEvent(v interface{}) *Event { + event := &Event{Time: time.Now()} + switch v := v.(type) { + case *Node: + event.Type = EventTypeNode + node := *v + event.Node = &node + case *Conn: + event.Type = EventTypeConn + conn := *v + event.Conn = &conn + case *Msg: + event.Type = EventTypeMsg + msg := *v + event.Msg = &msg + default: + panic(fmt.Sprintf("invalid event type: %T", v)) + } + return event +} + +// ControlEvent creates a new control event +func ControlEvent(v interface{}) *Event { + event := NewEvent(v) + event.Control = true + return event +} + +// String returns the string representation of the event +func (e *Event) String() string { + switch e.Type { + case EventTypeNode: + return fmt.Sprintf("<node-event> id: %s up: %t", e.Node.ID().TerminalString(), e.Node.Up) + case EventTypeConn: + return fmt.Sprintf("<conn-event> nodes: %s->%s up: %t", e.Conn.One.TerminalString(), e.Conn.Other.TerminalString(), e.Conn.Up) + case EventTypeMsg: + return fmt.Sprintf("<msg-event> nodes: %s->%s proto: %s, code: %d, received: %t", e.Msg.One.TerminalString(), e.Msg.Other.TerminalString(), e.Msg.Protocol, e.Msg.Code, e.Msg.Received) + default: + return "" + } +} diff --git a/p2p/simulations/examples/README.md b/p2p/simulations/examples/README.md new file mode 100644 index 000000000..822a48dcb --- /dev/null +++ b/p2p/simulations/examples/README.md @@ -0,0 +1,39 @@ +# devp2p simulation examples + +## ping-pong + +`ping-pong.go` implements a simulation network which contains nodes running a +simple "ping-pong" protocol where nodes send a ping message to all their +connected peers every 10s and receive pong messages in return. + +To run the simulation, run `go run ping-pong.go` in one terminal to start the +simulation API and `./ping-pong.sh` in another to start and connect the nodes: + +``` +$ go run ping-pong.go +INFO [08-15|13:53:49] using sim adapter +INFO [08-15|13:53:49] starting simulation server on 0.0.0.0:8888... +``` + +``` +$ ./ping-pong.sh +---> 13:58:12 creating 10 nodes +Created node01 +Started node01 +... +Created node10 +Started node10 +---> 13:58:13 connecting node01 to all other nodes +Connected node01 to node02 +... +Connected node01 to node10 +---> 13:58:14 done +``` + +Use the `--adapter` flag to choose the adapter type: + +``` +$ go run ping-pong.go --adapter exec +INFO [08-15|14:01:14] using exec adapter tmpdir=/var/folders/k6/wpsgfg4n23ddbc6f5cnw5qg00000gn/T/p2p-example992833779 +INFO [08-15|14:01:14] starting simulation server on 0.0.0.0:8888... +``` diff --git a/p2p/simulations/examples/ping-pong.go b/p2p/simulations/examples/ping-pong.go new file mode 100644 index 000000000..6a0ead53a --- /dev/null +++ b/p2p/simulations/examples/ping-pong.go @@ -0,0 +1,184 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package main + +import ( + "flag" + "fmt" + "io/ioutil" + "net/http" + "os" + "sync/atomic" + "time" + + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/node" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/simulations" + "github.com/ethereum/go-ethereum/p2p/simulations/adapters" + "github.com/ethereum/go-ethereum/rpc" +) + +var adapterType = flag.String("adapter", "sim", `node adapter to use (one of "sim", "exec" or "docker")`) + +// main() starts a simulation network which contains nodes running a simple +// ping-pong protocol +func main() { + flag.Parse() + + // set the log level to Trace + log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(false)))) + + // register a single ping-pong service + services := map[string]adapters.ServiceFunc{ + "ping-pong": func(ctx *adapters.ServiceContext) (node.Service, error) { + return newPingPongService(ctx.Config.ID), nil + }, + } + adapters.RegisterServices(services) + + // create the NodeAdapter + var adapter adapters.NodeAdapter + + switch *adapterType { + + case "sim": + log.Info("using sim adapter") + adapter = adapters.NewSimAdapter(services) + + case "exec": + tmpdir, err := ioutil.TempDir("", "p2p-example") + if err != nil { + log.Crit("error creating temp dir", "err", err) + } + defer os.RemoveAll(tmpdir) + log.Info("using exec adapter", "tmpdir", tmpdir) + adapter = adapters.NewExecAdapter(tmpdir) + + case "docker": + log.Info("using docker adapter") + var err error + adapter, err = adapters.NewDockerAdapter() + if err != nil { + log.Crit("error creating docker adapter", "err", err) + } + + default: + log.Crit(fmt.Sprintf("unknown node adapter %q", *adapterType)) + } + + // start the HTTP API + log.Info("starting simulation server on 0.0.0.0:8888...") + network := simulations.NewNetwork(adapter, &simulations.NetworkConfig{ + DefaultService: "ping-pong", + }) + if err := http.ListenAndServe(":8888", simulations.NewServer(network)); err != nil { + log.Crit("error starting simulation server", "err", err) + } +} + +// pingPongService runs a ping-pong protocol between nodes where each node +// sends a ping to all its connected peers every 10s and receives a pong in +// return +type pingPongService struct { + id discover.NodeID + log log.Logger + received int64 +} + +func newPingPongService(id discover.NodeID) *pingPongService { + return &pingPongService{ + id: id, + log: log.New("node.id", id), + } +} + +func (p *pingPongService) Protocols() []p2p.Protocol { + return []p2p.Protocol{{ + Name: "ping-pong", + Version: 1, + Length: 2, + Run: p.Run, + NodeInfo: p.Info, + }} +} + +func (p *pingPongService) APIs() []rpc.API { + return nil +} + +func (p *pingPongService) Start(server *p2p.Server) error { + p.log.Info("ping-pong service starting") + return nil +} + +func (p *pingPongService) Stop() error { + p.log.Info("ping-pong service stopping") + return nil +} + +func (p *pingPongService) Info() interface{} { + return struct { + Received int64 `json:"received"` + }{ + atomic.LoadInt64(&p.received), + } +} + +const ( + pingMsgCode = iota + pongMsgCode +) + +// Run implements the ping-pong protocol which sends ping messages to the peer +// at 10s intervals, and responds to pings with pong messages. +func (p *pingPongService) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error { + log := p.log.New("peer.id", peer.ID()) + + errC := make(chan error) + go func() { + for range time.Tick(10 * time.Second) { + log.Info("sending ping") + if err := p2p.Send(rw, pingMsgCode, "PING"); err != nil { + errC <- err + return + } + } + }() + go func() { + for { + msg, err := rw.ReadMsg() + if err != nil { + errC <- err + return + } + payload, err := ioutil.ReadAll(msg.Payload) + if err != nil { + errC <- err + return + } + log.Info("received message", "msg.code", msg.Code, "msg.payload", string(payload)) + atomic.AddInt64(&p.received, 1) + if msg.Code == pingMsgCode { + log.Info("sending pong") + go p2p.Send(rw, pongMsgCode, "PONG") + } + } + }() + return <-errC +} diff --git a/p2p/simulations/examples/ping-pong.sh b/p2p/simulations/examples/ping-pong.sh new file mode 100755 index 000000000..47936bd9a --- /dev/null +++ b/p2p/simulations/examples/ping-pong.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# +# Boot a ping-pong network simulation using the HTTP API started by ping-pong.go + +set -e + +main() { + if ! which p2psim &>/dev/null; then + fail "missing p2psim binary (you need to build cmd/p2psim and put it in \$PATH)" + fi + + info "creating 10 nodes" + for i in $(seq 1 10); do + p2psim node create --name "$(node_name $i)" + p2psim node start "$(node_name $i)" + done + + info "connecting node01 to all other nodes" + for i in $(seq 2 10); do + p2psim node connect "node01" "$(node_name $i)" + done + + info "done" +} + +node_name() { + local num=$1 + echo "node$(printf '%02d' $num)" +} + +info() { + echo -e "\033[1;32m---> $(date +%H:%M:%S) ${@}\033[0m" +} + +fail() { + echo -e "\033[1;31mERROR: ${@}\033[0m" >&2 + exit 1 +} + +main "$@" diff --git a/p2p/simulations/http.go b/p2p/simulations/http.go new file mode 100644 index 000000000..3fa8b9292 --- /dev/null +++ b/p2p/simulations/http.go @@ -0,0 +1,680 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package simulations + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + "strconv" + "strings" + + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/simulations/adapters" + "github.com/ethereum/go-ethereum/rpc" + "github.com/julienschmidt/httprouter" + "golang.org/x/net/websocket" +) + +// DefaultClient is the default simulation API client which expects the API +// to be running at http://localhost:8888 +var DefaultClient = NewClient("http://localhost:8888") + +// Client is a client for the simulation HTTP API which supports creating +// and managing simulation networks +type Client struct { + URL string + + client *http.Client +} + +// NewClient returns a new simulation API client +func NewClient(url string) *Client { + return &Client{ + URL: url, + client: http.DefaultClient, + } +} + +// GetNetwork returns details of the network +func (c *Client) GetNetwork() (*Network, error) { + network := &Network{} + return network, c.Get("/", network) +} + +// StartNetwork starts all existing nodes in the simulation network +func (c *Client) StartNetwork() error { + return c.Post("/start", nil, nil) +} + +// StopNetwork stops all existing nodes in a simulation network +func (c *Client) StopNetwork() error { + return c.Post("/stop", nil, nil) +} + +// CreateSnapshot creates a network snapshot +func (c *Client) CreateSnapshot() (*Snapshot, error) { + snap := &Snapshot{} + return snap, c.Get("/snapshot", snap) +} + +// LoadSnapshot loads a snapshot into the network +func (c *Client) LoadSnapshot(snap *Snapshot) error { + return c.Post("/snapshot", snap, nil) +} + +// SubscribeOpts is a collection of options to use when subscribing to network +// events +type SubscribeOpts struct { + // Current instructs the server to send events for existing nodes and + // connections first + Current bool + + // Filter instructs the server to only send a subset of message events + Filter string +} + +// SubscribeNetwork subscribes to network events which are sent from the server +// as a server-sent-events stream, optionally receiving events for existing +// nodes and connections and filtering message events +func (c *Client) SubscribeNetwork(events chan *Event, opts SubscribeOpts) (event.Subscription, error) { + url := fmt.Sprintf("%s/events?current=%t&filter=%s", c.URL, opts.Current, opts.Filter) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "text/event-stream") + res, err := c.client.Do(req) + if err != nil { + return nil, err + } + if res.StatusCode != http.StatusOK { + response, _ := ioutil.ReadAll(res.Body) + res.Body.Close() + return nil, fmt.Errorf("unexpected HTTP status: %s: %s", res.Status, response) + } + + // define a producer function to pass to event.Subscription + // which reads server-sent events from res.Body and sends + // them to the events channel + producer := func(stop <-chan struct{}) error { + defer res.Body.Close() + + // read lines from res.Body in a goroutine so that we are + // always reading from the stop channel + lines := make(chan string) + errC := make(chan error, 1) + go func() { + s := bufio.NewScanner(res.Body) + for s.Scan() { + select { + case lines <- s.Text(): + case <-stop: + return + } + } + errC <- s.Err() + }() + + // detect any lines which start with "data:", decode the data + // into an event and send it to the events channel + for { + select { + case line := <-lines: + if !strings.HasPrefix(line, "data:") { + continue + } + data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + event := &Event{} + if err := json.Unmarshal([]byte(data), event); err != nil { + return fmt.Errorf("error decoding SSE event: %s", err) + } + select { + case events <- event: + case <-stop: + return nil + } + case err := <-errC: + return err + case <-stop: + return nil + } + } + } + + return event.NewSubscription(producer), nil +} + +// GetNodes returns all nodes which exist in the network +func (c *Client) GetNodes() ([]*p2p.NodeInfo, error) { + var nodes []*p2p.NodeInfo + return nodes, c.Get("/nodes", &nodes) +} + +// CreateNode creates a node in the network using the given configuration +func (c *Client) CreateNode(config *adapters.NodeConfig) (*p2p.NodeInfo, error) { + node := &p2p.NodeInfo{} + return node, c.Post("/nodes", config, node) +} + +// GetNode returns details of a node +func (c *Client) GetNode(nodeID string) (*p2p.NodeInfo, error) { + node := &p2p.NodeInfo{} + return node, c.Get(fmt.Sprintf("/nodes/%s", nodeID), node) +} + +// StartNode starts a node +func (c *Client) StartNode(nodeID string) error { + return c.Post(fmt.Sprintf("/nodes/%s/start", nodeID), nil, nil) +} + +// StopNode stops a node +func (c *Client) StopNode(nodeID string) error { + return c.Post(fmt.Sprintf("/nodes/%s/stop", nodeID), nil, nil) +} + +// ConnectNode connects a node to a peer node +func (c *Client) ConnectNode(nodeID, peerID string) error { + return c.Post(fmt.Sprintf("/nodes/%s/conn/%s", nodeID, peerID), nil, nil) +} + +// DisconnectNode disconnects a node from a peer node +func (c *Client) DisconnectNode(nodeID, peerID string) error { + return c.Delete(fmt.Sprintf("/nodes/%s/conn/%s", nodeID, peerID)) +} + +// RPCClient returns an RPC client connected to a node +func (c *Client) RPCClient(ctx context.Context, nodeID string) (*rpc.Client, error) { + baseURL := strings.Replace(c.URL, "http", "ws", 1) + return rpc.DialWebsocket(ctx, fmt.Sprintf("%s/nodes/%s/rpc", baseURL, nodeID), "") +} + +// Get performs a HTTP GET request decoding the resulting JSON response +// into "out" +func (c *Client) Get(path string, out interface{}) error { + return c.Send("GET", path, nil, out) +} + +// Post performs a HTTP POST request sending "in" as the JSON body and +// decoding the resulting JSON response into "out" +func (c *Client) Post(path string, in, out interface{}) error { + return c.Send("POST", path, in, out) +} + +// Delete performs a HTTP DELETE request +func (c *Client) Delete(path string) error { + return c.Send("DELETE", path, nil, nil) +} + +// Send performs a HTTP request, sending "in" as the JSON request body and +// decoding the JSON response into "out" +func (c *Client) Send(method, path string, in, out interface{}) error { + var body []byte + if in != nil { + var err error + body, err = json.Marshal(in) + if err != nil { + return err + } + } + req, err := http.NewRequest(method, c.URL+path, bytes.NewReader(body)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + res, err := c.client.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusCreated { + response, _ := ioutil.ReadAll(res.Body) + return fmt.Errorf("unexpected HTTP status: %s: %s", res.Status, response) + } + if out != nil { + if err := json.NewDecoder(res.Body).Decode(out); err != nil { + return err + } + } + return nil +} + +// Server is an HTTP server providing an API to manage a simulation network +type Server struct { + router *httprouter.Router + network *Network +} + +// NewServer returns a new simulation API server +func NewServer(network *Network) *Server { + s := &Server{ + router: httprouter.New(), + network: network, + } + + s.OPTIONS("/", s.Options) + s.GET("/", s.GetNetwork) + s.POST("/start", s.StartNetwork) + s.POST("/stop", s.StopNetwork) + s.GET("/events", s.StreamNetworkEvents) + s.GET("/snapshot", s.CreateSnapshot) + s.POST("/snapshot", s.LoadSnapshot) + s.POST("/nodes", s.CreateNode) + s.GET("/nodes", s.GetNodes) + s.GET("/nodes/:nodeid", s.GetNode) + s.POST("/nodes/:nodeid/start", s.StartNode) + s.POST("/nodes/:nodeid/stop", s.StopNode) + s.POST("/nodes/:nodeid/conn/:peerid", s.ConnectNode) + s.DELETE("/nodes/:nodeid/conn/:peerid", s.DisconnectNode) + s.GET("/nodes/:nodeid/rpc", s.NodeRPC) + + return s +} + +// GetNetwork returns details of the network +func (s *Server) GetNetwork(w http.ResponseWriter, req *http.Request) { + s.JSON(w, http.StatusOK, s.network) +} + +// StartNetwork starts all nodes in the network +func (s *Server) StartNetwork(w http.ResponseWriter, req *http.Request) { + if err := s.network.StartAll(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) +} + +// StopNetwork stops all nodes in the network +func (s *Server) StopNetwork(w http.ResponseWriter, req *http.Request) { + if err := s.network.StopAll(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) +} + +// StreamNetworkEvents streams network events as a server-sent-events stream +func (s *Server) StreamNetworkEvents(w http.ResponseWriter, req *http.Request) { + events := make(chan *Event) + sub := s.network.events.Subscribe(events) + defer sub.Unsubscribe() + + // stop the stream if the client goes away + var clientGone <-chan bool + if cn, ok := w.(http.CloseNotifier); ok { + clientGone = cn.CloseNotify() + } + + // write writes the given event and data to the stream like: + // + // event: <event> + // data: <data> + // + write := func(event, data string) { + fmt.Fprintf(w, "event: %s\n", event) + fmt.Fprintf(w, "data: %s\n\n", data) + if fw, ok := w.(http.Flusher); ok { + fw.Flush() + } + } + writeEvent := func(event *Event) error { + data, err := json.Marshal(event) + if err != nil { + return err + } + write("network", string(data)) + return nil + } + writeErr := func(err error) { + write("error", err.Error()) + } + + // check if filtering has been requested + var filters MsgFilters + if filterParam := req.URL.Query().Get("filter"); filterParam != "" { + var err error + filters, err = NewMsgFilters(filterParam) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + } + + w.Header().Set("Content-Type", "text/event-stream; charset=utf-8") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "\n\n") + if fw, ok := w.(http.Flusher); ok { + fw.Flush() + } + + // optionally send the existing nodes and connections + if req.URL.Query().Get("current") == "true" { + snap, err := s.network.Snapshot() + if err != nil { + writeErr(err) + return + } + for _, node := range snap.Nodes { + event := NewEvent(&node.Node) + if err := writeEvent(event); err != nil { + writeErr(err) + return + } + } + for _, conn := range snap.Conns { + event := NewEvent(&conn) + if err := writeEvent(event); err != nil { + writeErr(err) + return + } + } + } + + for { + select { + case event := <-events: + // only send message events which match the filters + if event.Msg != nil && !filters.Match(event.Msg) { + continue + } + if err := writeEvent(event); err != nil { + writeErr(err) + return + } + case <-clientGone: + return + } + } +} + +// NewMsgFilters constructs a collection of message filters from a URL query +// parameter. +// +// The parameter is expected to be a dash-separated list of individual filters, +// each having the format '<proto>:<codes>', where <proto> is the name of a +// protocol and <codes> is a comma-separated list of message codes. +// +// A message code of '*' or '-1' is considered a wildcard and matches any code. +func NewMsgFilters(filterParam string) (MsgFilters, error) { + filters := make(MsgFilters) + for _, filter := range strings.Split(filterParam, "-") { + protoCodes := strings.SplitN(filter, ":", 2) + if len(protoCodes) != 2 || protoCodes[0] == "" || protoCodes[1] == "" { + return nil, fmt.Errorf("invalid message filter: %s", filter) + } + proto := protoCodes[0] + for _, code := range strings.Split(protoCodes[1], ",") { + if code == "*" || code == "-1" { + filters[MsgFilter{Proto: proto, Code: -1}] = struct{}{} + continue + } + n, err := strconv.ParseUint(code, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid message code: %s", code) + } + filters[MsgFilter{Proto: proto, Code: int64(n)}] = struct{}{} + } + } + return filters, nil +} + +// MsgFilters is a collection of filters which are used to filter message +// events +type MsgFilters map[MsgFilter]struct{} + +// Match checks if the given message matches any of the filters +func (m MsgFilters) Match(msg *Msg) bool { + // check if there is a wildcard filter for the message's protocol + if _, ok := m[MsgFilter{Proto: msg.Protocol, Code: -1}]; ok { + return true + } + + // check if there is a filter for the message's protocol and code + if _, ok := m[MsgFilter{Proto: msg.Protocol, Code: int64(msg.Code)}]; ok { + return true + } + + return false +} + +// MsgFilter is used to filter message events based on protocol and message +// code +type MsgFilter struct { + // Proto is matched against a message's protocol + Proto string + + // Code is matched against a message's code, with -1 matching all codes + Code int64 +} + +// CreateSnapshot creates a network snapshot +func (s *Server) CreateSnapshot(w http.ResponseWriter, req *http.Request) { + snap, err := s.network.Snapshot() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + s.JSON(w, http.StatusOK, snap) +} + +// LoadSnapshot loads a snapshot into the network +func (s *Server) LoadSnapshot(w http.ResponseWriter, req *http.Request) { + snap := &Snapshot{} + if err := json.NewDecoder(req.Body).Decode(snap); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if err := s.network.Load(snap); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + s.JSON(w, http.StatusOK, s.network) +} + +// CreateNode creates a node in the network using the given configuration +func (s *Server) CreateNode(w http.ResponseWriter, req *http.Request) { + config := adapters.RandomNodeConfig() + err := json.NewDecoder(req.Body).Decode(config) + if err != nil && err != io.EOF { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + node, err := s.network.NewNodeWithConfig(config) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + s.JSON(w, http.StatusCreated, node.NodeInfo()) +} + +// GetNodes returns all nodes which exist in the network +func (s *Server) GetNodes(w http.ResponseWriter, req *http.Request) { + nodes := s.network.GetNodes() + + infos := make([]*p2p.NodeInfo, len(nodes)) + for i, node := range nodes { + infos[i] = node.NodeInfo() + } + + s.JSON(w, http.StatusOK, infos) +} + +// GetNode returns details of a node +func (s *Server) GetNode(w http.ResponseWriter, req *http.Request) { + node := req.Context().Value("node").(*Node) + + s.JSON(w, http.StatusOK, node.NodeInfo()) +} + +// StartNode starts a node +func (s *Server) StartNode(w http.ResponseWriter, req *http.Request) { + node := req.Context().Value("node").(*Node) + + if err := s.network.Start(node.ID()); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + s.JSON(w, http.StatusOK, node.NodeInfo()) +} + +// StopNode stops a node +func (s *Server) StopNode(w http.ResponseWriter, req *http.Request) { + node := req.Context().Value("node").(*Node) + + if err := s.network.Stop(node.ID()); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + s.JSON(w, http.StatusOK, node.NodeInfo()) +} + +// ConnectNode connects a node to a peer node +func (s *Server) ConnectNode(w http.ResponseWriter, req *http.Request) { + node := req.Context().Value("node").(*Node) + peer := req.Context().Value("peer").(*Node) + + if err := s.network.Connect(node.ID(), peer.ID()); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + s.JSON(w, http.StatusOK, node.NodeInfo()) +} + +// DisconnectNode disconnects a node from a peer node +func (s *Server) DisconnectNode(w http.ResponseWriter, req *http.Request) { + node := req.Context().Value("node").(*Node) + peer := req.Context().Value("peer").(*Node) + + if err := s.network.Disconnect(node.ID(), peer.ID()); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + s.JSON(w, http.StatusOK, node.NodeInfo()) +} + +// Options responds to the OPTIONS HTTP method by returning a 200 OK response +// with the "Access-Control-Allow-Headers" header set to "Content-Type" +func (s *Server) Options(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + w.WriteHeader(http.StatusOK) +} + +// NodeRPC forwards RPC requests to a node in the network via a WebSocket +// connection +func (s *Server) NodeRPC(w http.ResponseWriter, req *http.Request) { + node := req.Context().Value("node").(*Node) + + handler := func(conn *websocket.Conn) { + node.ServeRPC(conn) + } + + websocket.Server{Handler: handler}.ServeHTTP(w, req) +} + +// ServeHTTP implements the http.Handler interface by delegating to the +// underlying httprouter.Router +func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { + s.router.ServeHTTP(w, req) +} + +// GET registers a handler for GET requests to a particular path +func (s *Server) GET(path string, handle http.HandlerFunc) { + s.router.GET(path, s.wrapHandler(handle)) +} + +// POST registers a handler for POST requests to a particular path +func (s *Server) POST(path string, handle http.HandlerFunc) { + s.router.POST(path, s.wrapHandler(handle)) +} + +// DELETE registers a handler for DELETE requests to a particular path +func (s *Server) DELETE(path string, handle http.HandlerFunc) { + s.router.DELETE(path, s.wrapHandler(handle)) +} + +// OPTIONS registers a handler for OPTIONS requests to a particular path +func (s *Server) OPTIONS(path string, handle http.HandlerFunc) { + s.router.OPTIONS("/*path", s.wrapHandler(handle)) +} + +// JSON sends "data" as a JSON HTTP response +func (s *Server) JSON(w http.ResponseWriter, status int, data interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(data) +} + +// wrapHandler returns a httprouter.Handle which wraps a http.HandlerFunc by +// populating request.Context with any objects from the URL params +func (s *Server) wrapHandler(handler http.HandlerFunc) httprouter.Handle { + return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + + ctx := context.Background() + + if id := params.ByName("nodeid"); id != "" { + var node *Node + if nodeID, err := discover.HexID(id); err == nil { + node = s.network.GetNode(nodeID) + } else { + node = s.network.GetNodeByName(id) + } + if node == nil { + http.NotFound(w, req) + return + } + ctx = context.WithValue(ctx, "node", node) + } + + if id := params.ByName("peerid"); id != "" { + var peer *Node + if peerID, err := discover.HexID(id); err == nil { + peer = s.network.GetNode(peerID) + } else { + peer = s.network.GetNodeByName(id) + } + if peer == nil { + http.NotFound(w, req) + return + } + ctx = context.WithValue(ctx, "peer", peer) + } + + handler(w, req.WithContext(ctx)) + } +} diff --git a/p2p/simulations/http_test.go b/p2p/simulations/http_test.go new file mode 100644 index 000000000..677a8fb14 --- /dev/null +++ b/p2p/simulations/http_test.go @@ -0,0 +1,823 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package simulations + +import ( + "context" + "fmt" + "math/rand" + "net/http/httptest" + "reflect" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/node" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/simulations/adapters" + "github.com/ethereum/go-ethereum/rpc" +) + +// testService implements the node.Service interface and provides protocols +// and APIs which are useful for testing nodes in a simulation network +type testService struct { + id discover.NodeID + + // peerCount is incremented once a peer handshake has been performed + peerCount int64 + + peers map[discover.NodeID]*testPeer + peersMtx sync.Mutex + + // state stores []byte which is used to test creating and loading + // snapshots + state atomic.Value +} + +func newTestService(ctx *adapters.ServiceContext) (node.Service, error) { + svc := &testService{ + id: ctx.Config.ID, + peers: make(map[discover.NodeID]*testPeer), + } + svc.state.Store(ctx.Snapshot) + return svc, nil +} + +type testPeer struct { + testReady chan struct{} + dumReady chan struct{} +} + +func (t *testService) peer(id discover.NodeID) *testPeer { + t.peersMtx.Lock() + defer t.peersMtx.Unlock() + if peer, ok := t.peers[id]; ok { + return peer + } + peer := &testPeer{ + testReady: make(chan struct{}), + dumReady: make(chan struct{}), + } + t.peers[id] = peer + return peer +} + +func (t *testService) Protocols() []p2p.Protocol { + return []p2p.Protocol{ + { + Name: "test", + Version: 1, + Length: 3, + Run: t.RunTest, + }, + { + Name: "dum", + Version: 1, + Length: 1, + Run: t.RunDum, + }, + { + Name: "prb", + Version: 1, + Length: 1, + Run: t.RunPrb, + }, + } +} + +func (t *testService) APIs() []rpc.API { + return []rpc.API{{ + Namespace: "test", + Version: "1.0", + Service: &TestAPI{ + state: &t.state, + peerCount: &t.peerCount, + }, + }} +} + +func (t *testService) Start(server *p2p.Server) error { + return nil +} + +func (t *testService) Stop() error { + return nil +} + +// handshake performs a peer handshake by sending and expecting an empty +// message with the given code +func (t *testService) handshake(rw p2p.MsgReadWriter, code uint64) error { + errc := make(chan error, 2) + go func() { errc <- p2p.Send(rw, code, struct{}{}) }() + go func() { errc <- p2p.ExpectMsg(rw, code, struct{}{}) }() + for i := 0; i < 2; i++ { + if err := <-errc; err != nil { + return err + } + } + return nil +} + +func (t *testService) RunTest(p *p2p.Peer, rw p2p.MsgReadWriter) error { + peer := t.peer(p.ID()) + + // perform three handshakes with three different message codes, + // used to test message sending and filtering + if err := t.handshake(rw, 2); err != nil { + return err + } + if err := t.handshake(rw, 1); err != nil { + return err + } + if err := t.handshake(rw, 0); err != nil { + return err + } + + // close the testReady channel so that other protocols can run + close(peer.testReady) + + // track the peer + atomic.AddInt64(&t.peerCount, 1) + defer atomic.AddInt64(&t.peerCount, -1) + + // block until the peer is dropped + for { + _, err := rw.ReadMsg() + if err != nil { + return err + } + } +} + +func (t *testService) RunDum(p *p2p.Peer, rw p2p.MsgReadWriter) error { + peer := t.peer(p.ID()) + + // wait for the test protocol to perform its handshake + <-peer.testReady + + // perform a handshake + if err := t.handshake(rw, 0); err != nil { + return err + } + + // close the dumReady channel so that other protocols can run + close(peer.dumReady) + + // block until the peer is dropped + for { + _, err := rw.ReadMsg() + if err != nil { + return err + } + } +} +func (t *testService) RunPrb(p *p2p.Peer, rw p2p.MsgReadWriter) error { + peer := t.peer(p.ID()) + + // wait for the dum protocol to perform its handshake + <-peer.dumReady + + // perform a handshake + if err := t.handshake(rw, 0); err != nil { + return err + } + + // block until the peer is dropped + for { + _, err := rw.ReadMsg() + if err != nil { + return err + } + } +} + +func (t *testService) Snapshot() ([]byte, error) { + return t.state.Load().([]byte), nil +} + +// TestAPI provides a test API to: +// * get the peer count +// * get and set an arbitrary state byte slice +// * get and increment a counter +// * subscribe to counter increment events +type TestAPI struct { + state *atomic.Value + peerCount *int64 + counter int64 + feed event.Feed +} + +func (t *TestAPI) PeerCount() int64 { + return atomic.LoadInt64(t.peerCount) +} + +func (t *TestAPI) Get() int64 { + return atomic.LoadInt64(&t.counter) +} + +func (t *TestAPI) Add(delta int64) { + atomic.AddInt64(&t.counter, delta) + t.feed.Send(delta) +} + +func (t *TestAPI) GetState() []byte { + return t.state.Load().([]byte) +} + +func (t *TestAPI) SetState(state []byte) { + t.state.Store(state) +} + +func (t *TestAPI) Events(ctx context.Context) (*rpc.Subscription, error) { + notifier, supported := rpc.NotifierFromContext(ctx) + if !supported { + return nil, rpc.ErrNotificationsUnsupported + } + + rpcSub := notifier.CreateSubscription() + + go func() { + events := make(chan int64) + sub := t.feed.Subscribe(events) + defer sub.Unsubscribe() + + for { + select { + case event := <-events: + notifier.Notify(rpcSub.ID, event) + case <-sub.Err(): + return + case <-rpcSub.Err(): + return + case <-notifier.Closed(): + return + } + } + }() + + return rpcSub, nil +} + +var testServices = adapters.Services{ + "test": newTestService, +} + +func testHTTPServer(t *testing.T) (*Network, *httptest.Server) { + adapter := adapters.NewSimAdapter(testServices) + network := NewNetwork(adapter, &NetworkConfig{ + DefaultService: "test", + }) + return network, httptest.NewServer(NewServer(network)) +} + +// TestHTTPNetwork tests interacting with a simulation network using the HTTP +// API +func TestHTTPNetwork(t *testing.T) { + // start the server + network, s := testHTTPServer(t) + defer s.Close() + + // subscribe to events so we can check them later + client := NewClient(s.URL) + events := make(chan *Event, 100) + var opts SubscribeOpts + sub, err := client.SubscribeNetwork(events, opts) + if err != nil { + t.Fatalf("error subscribing to network events: %s", err) + } + defer sub.Unsubscribe() + + // check we can retrieve details about the network + gotNetwork, err := client.GetNetwork() + if err != nil { + t.Fatalf("error getting network: %s", err) + } + if gotNetwork.ID != network.ID { + t.Fatalf("expected network to have ID %q, got %q", network.ID, gotNetwork.ID) + } + + // start a simulation network + nodeIDs := startTestNetwork(t, client) + + // check we got all the events + x := &expectEvents{t, events, sub} + x.expect( + x.nodeEvent(nodeIDs[0], false), + x.nodeEvent(nodeIDs[1], false), + x.nodeEvent(nodeIDs[0], true), + x.nodeEvent(nodeIDs[1], true), + x.connEvent(nodeIDs[0], nodeIDs[1], false), + x.connEvent(nodeIDs[0], nodeIDs[1], true), + ) + + // reconnect the stream and check we get the current nodes and conns + events = make(chan *Event, 100) + opts.Current = true + sub, err = client.SubscribeNetwork(events, opts) + if err != nil { + t.Fatalf("error subscribing to network events: %s", err) + } + defer sub.Unsubscribe() + x = &expectEvents{t, events, sub} + x.expect( + x.nodeEvent(nodeIDs[0], true), + x.nodeEvent(nodeIDs[1], true), + x.connEvent(nodeIDs[0], nodeIDs[1], true), + ) +} + +func startTestNetwork(t *testing.T, client *Client) []string { + // create two nodes + nodeCount := 2 + nodeIDs := make([]string, nodeCount) + for i := 0; i < nodeCount; i++ { + node, err := client.CreateNode(nil) + if err != nil { + t.Fatalf("error creating node: %s", err) + } + nodeIDs[i] = node.ID + } + + // check both nodes exist + nodes, err := client.GetNodes() + if err != nil { + t.Fatalf("error getting nodes: %s", err) + } + if len(nodes) != nodeCount { + t.Fatalf("expected %d nodes, got %d", nodeCount, len(nodes)) + } + for i, nodeID := range nodeIDs { + if nodes[i].ID != nodeID { + t.Fatalf("expected node %d to have ID %q, got %q", i, nodeID, nodes[i].ID) + } + node, err := client.GetNode(nodeID) + if err != nil { + t.Fatalf("error getting node %d: %s", i, err) + } + if node.ID != nodeID { + t.Fatalf("expected node %d to have ID %q, got %q", i, nodeID, node.ID) + } + } + + // start both nodes + for _, nodeID := range nodeIDs { + if err := client.StartNode(nodeID); err != nil { + t.Fatalf("error starting node %q: %s", nodeID, err) + } + } + + // connect the nodes + for i := 0; i < nodeCount-1; i++ { + peerId := i + 1 + if i == nodeCount-1 { + peerId = 0 + } + if err := client.ConnectNode(nodeIDs[i], nodeIDs[peerId]); err != nil { + t.Fatalf("error connecting nodes: %s", err) + } + } + + return nodeIDs +} + +type expectEvents struct { + *testing.T + + events chan *Event + sub event.Subscription +} + +func (t *expectEvents) nodeEvent(id string, up bool) *Event { + return &Event{ + Type: EventTypeNode, + Node: &Node{ + Config: &adapters.NodeConfig{ + ID: discover.MustHexID(id), + }, + Up: up, + }, + } +} + +func (t *expectEvents) connEvent(one, other string, up bool) *Event { + return &Event{ + Type: EventTypeConn, + Conn: &Conn{ + One: discover.MustHexID(one), + Other: discover.MustHexID(other), + Up: up, + }, + } +} + +func (t *expectEvents) expectMsgs(expected map[MsgFilter]int) { + actual := make(map[MsgFilter]int) + timeout := time.After(10 * time.Second) +loop: + for { + select { + case event := <-t.events: + t.Logf("received %s event: %s", event.Type, event) + + if event.Type != EventTypeMsg || event.Msg.Received { + continue loop + } + if event.Msg == nil { + t.Fatal("expected event.Msg to be set") + } + filter := MsgFilter{ + Proto: event.Msg.Protocol, + Code: int64(event.Msg.Code), + } + actual[filter]++ + if actual[filter] > expected[filter] { + t.Fatalf("received too many msgs for filter: %v", filter) + } + if reflect.DeepEqual(actual, expected) { + return + } + + case err := <-t.sub.Err(): + t.Fatalf("network stream closed unexpectedly: %s", err) + + case <-timeout: + t.Fatal("timed out waiting for expected events") + } + } +} + +func (t *expectEvents) expect(events ...*Event) { + timeout := time.After(10 * time.Second) + i := 0 + for { + select { + case event := <-t.events: + t.Logf("received %s event: %s", event.Type, event) + + expected := events[i] + if event.Type != expected.Type { + t.Fatalf("expected event %d to have type %q, got %q", i, expected.Type, event.Type) + } + + switch expected.Type { + + case EventTypeNode: + if event.Node == nil { + t.Fatal("expected event.Node to be set") + } + if event.Node.ID() != expected.Node.ID() { + t.Fatalf("expected node event %d to have id %q, got %q", i, expected.Node.ID().TerminalString(), event.Node.ID().TerminalString()) + } + if event.Node.Up != expected.Node.Up { + t.Fatalf("expected node event %d to have up=%t, got up=%t", i, expected.Node.Up, event.Node.Up) + } + + case EventTypeConn: + if event.Conn == nil { + t.Fatal("expected event.Conn to be set") + } + if event.Conn.One != expected.Conn.One { + t.Fatalf("expected conn event %d to have one=%q, got one=%q", i, expected.Conn.One.TerminalString(), event.Conn.One.TerminalString()) + } + if event.Conn.Other != expected.Conn.Other { + t.Fatalf("expected conn event %d to have other=%q, got other=%q", i, expected.Conn.Other.TerminalString(), event.Conn.Other.TerminalString()) + } + if event.Conn.Up != expected.Conn.Up { + t.Fatalf("expected conn event %d to have up=%t, got up=%t", i, expected.Conn.Up, event.Conn.Up) + } + + } + + i++ + if i == len(events) { + return + } + + case err := <-t.sub.Err(): + t.Fatalf("network stream closed unexpectedly: %s", err) + + case <-timeout: + t.Fatal("timed out waiting for expected events") + } + } +} + +// TestHTTPNodeRPC tests calling RPC methods on nodes via the HTTP API +func TestHTTPNodeRPC(t *testing.T) { + // start the server + _, s := testHTTPServer(t) + defer s.Close() + + // start a node in the network + client := NewClient(s.URL) + node, err := client.CreateNode(nil) + if err != nil { + t.Fatalf("error creating node: %s", err) + } + if err := client.StartNode(node.ID); err != nil { + t.Fatalf("error starting node: %s", err) + } + + // create two RPC clients + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + rpcClient1, err := client.RPCClient(ctx, node.ID) + if err != nil { + t.Fatalf("error getting node RPC client: %s", err) + } + rpcClient2, err := client.RPCClient(ctx, node.ID) + if err != nil { + t.Fatalf("error getting node RPC client: %s", err) + } + + // subscribe to events using client 1 + events := make(chan int64, 1) + sub, err := rpcClient1.Subscribe(ctx, "test", events, "events") + if err != nil { + t.Fatalf("error subscribing to events: %s", err) + } + defer sub.Unsubscribe() + + // call some RPC methods using client 2 + if err := rpcClient2.CallContext(ctx, nil, "test_add", 10); err != nil { + t.Fatalf("error calling RPC method: %s", err) + } + var result int64 + if err := rpcClient2.CallContext(ctx, &result, "test_get"); err != nil { + t.Fatalf("error calling RPC method: %s", err) + } + if result != 10 { + t.Fatalf("expected result to be 10, got %d", result) + } + + // check we got an event from client 1 + select { + case event := <-events: + if event != 10 { + t.Fatalf("expected event to be 10, got %d", event) + } + case <-ctx.Done(): + t.Fatal(ctx.Err()) + } +} + +// TestHTTPSnapshot tests creating and loading network snapshots +func TestHTTPSnapshot(t *testing.T) { + // start the server + _, s := testHTTPServer(t) + defer s.Close() + + // create a two-node network + client := NewClient(s.URL) + nodeCount := 2 + nodes := make([]*p2p.NodeInfo, nodeCount) + for i := 0; i < nodeCount; i++ { + node, err := client.CreateNode(nil) + if err != nil { + t.Fatalf("error creating node: %s", err) + } + if err := client.StartNode(node.ID); err != nil { + t.Fatalf("error starting node: %s", err) + } + nodes[i] = node + } + if err := client.ConnectNode(nodes[0].ID, nodes[1].ID); err != nil { + t.Fatalf("error connecting nodes: %s", err) + } + + // store some state in the test services + states := make([]string, nodeCount) + for i, node := range nodes { + rpc, err := client.RPCClient(context.Background(), node.ID) + if err != nil { + t.Fatalf("error getting RPC client: %s", err) + } + defer rpc.Close() + state := fmt.Sprintf("%x", rand.Int()) + if err := rpc.Call(nil, "test_setState", []byte(state)); err != nil { + t.Fatalf("error setting service state: %s", err) + } + states[i] = state + } + + // create a snapshot + snap, err := client.CreateSnapshot() + if err != nil { + t.Fatalf("error creating snapshot: %s", err) + } + for i, state := range states { + gotState := snap.Nodes[i].Snapshots["test"] + if string(gotState) != state { + t.Fatalf("expected snapshot state %q, got %q", state, gotState) + } + } + + // create another network + _, s = testHTTPServer(t) + defer s.Close() + client = NewClient(s.URL) + + // subscribe to events so we can check them later + events := make(chan *Event, 100) + var opts SubscribeOpts + sub, err := client.SubscribeNetwork(events, opts) + if err != nil { + t.Fatalf("error subscribing to network events: %s", err) + } + defer sub.Unsubscribe() + + // load the snapshot + if err := client.LoadSnapshot(snap); err != nil { + t.Fatalf("error loading snapshot: %s", err) + } + + // check the nodes and connection exists + net, err := client.GetNetwork() + if err != nil { + t.Fatalf("error getting network: %s", err) + } + if len(net.Nodes) != nodeCount { + t.Fatalf("expected network to have %d nodes, got %d", nodeCount, len(net.Nodes)) + } + for i, node := range nodes { + id := net.Nodes[i].ID().String() + if id != node.ID { + t.Fatalf("expected node %d to have ID %s, got %s", i, node.ID, id) + } + } + if len(net.Conns) != 1 { + t.Fatalf("expected network to have 1 connection, got %d", len(net.Conns)) + } + conn := net.Conns[0] + if conn.One.String() != nodes[0].ID { + t.Fatalf("expected connection to have one=%q, got one=%q", nodes[0].ID, conn.One) + } + if conn.Other.String() != nodes[1].ID { + t.Fatalf("expected connection to have other=%q, got other=%q", nodes[1].ID, conn.Other) + } + + // check the node states were restored + for i, node := range nodes { + rpc, err := client.RPCClient(context.Background(), node.ID) + if err != nil { + t.Fatalf("error getting RPC client: %s", err) + } + defer rpc.Close() + var state []byte + if err := rpc.Call(&state, "test_getState"); err != nil { + t.Fatalf("error getting service state: %s", err) + } + if string(state) != states[i] { + t.Fatalf("expected snapshot state %q, got %q", states[i], state) + } + } + + // check we got all the events + x := &expectEvents{t, events, sub} + x.expect( + x.nodeEvent(nodes[0].ID, false), + x.nodeEvent(nodes[0].ID, true), + x.nodeEvent(nodes[1].ID, false), + x.nodeEvent(nodes[1].ID, true), + x.connEvent(nodes[0].ID, nodes[1].ID, false), + x.connEvent(nodes[0].ID, nodes[1].ID, true), + ) +} + +// TestMsgFilterPassMultiple tests streaming message events using a filter +// with multiple protocols +func TestMsgFilterPassMultiple(t *testing.T) { + // start the server + _, s := testHTTPServer(t) + defer s.Close() + + // subscribe to events with a message filter + client := NewClient(s.URL) + events := make(chan *Event, 10) + opts := SubscribeOpts{ + Filter: "prb:0-test:0", + } + sub, err := client.SubscribeNetwork(events, opts) + if err != nil { + t.Fatalf("error subscribing to network events: %s", err) + } + defer sub.Unsubscribe() + + // start a simulation network + startTestNetwork(t, client) + + // check we got the expected events + x := &expectEvents{t, events, sub} + x.expectMsgs(map[MsgFilter]int{ + {"test", 0}: 2, + {"prb", 0}: 2, + }) +} + +// TestMsgFilterPassWildcard tests streaming message events using a filter +// with a code wildcard +func TestMsgFilterPassWildcard(t *testing.T) { + // start the server + _, s := testHTTPServer(t) + defer s.Close() + + // subscribe to events with a message filter + client := NewClient(s.URL) + events := make(chan *Event, 10) + opts := SubscribeOpts{ + Filter: "prb:0,2-test:*", + } + sub, err := client.SubscribeNetwork(events, opts) + if err != nil { + t.Fatalf("error subscribing to network events: %s", err) + } + defer sub.Unsubscribe() + + // start a simulation network + startTestNetwork(t, client) + + // check we got the expected events + x := &expectEvents{t, events, sub} + x.expectMsgs(map[MsgFilter]int{ + {"test", 2}: 2, + {"test", 1}: 2, + {"test", 0}: 2, + {"prb", 0}: 2, + }) +} + +// TestMsgFilterPassSingle tests streaming message events using a filter +// with a single protocol and code +func TestMsgFilterPassSingle(t *testing.T) { + // start the server + _, s := testHTTPServer(t) + defer s.Close() + + // subscribe to events with a message filter + client := NewClient(s.URL) + events := make(chan *Event, 10) + opts := SubscribeOpts{ + Filter: "dum:0", + } + sub, err := client.SubscribeNetwork(events, opts) + if err != nil { + t.Fatalf("error subscribing to network events: %s", err) + } + defer sub.Unsubscribe() + + // start a simulation network + startTestNetwork(t, client) + + // check we got the expected events + x := &expectEvents{t, events, sub} + x.expectMsgs(map[MsgFilter]int{ + {"dum", 0}: 2, + }) +} + +// TestMsgFilterPassSingle tests streaming message events using an invalid +// filter +func TestMsgFilterFailBadParams(t *testing.T) { + // start the server + _, s := testHTTPServer(t) + defer s.Close() + + client := NewClient(s.URL) + events := make(chan *Event, 10) + opts := SubscribeOpts{ + Filter: "foo:", + } + _, err := client.SubscribeNetwork(events, opts) + if err == nil { + t.Fatalf("expected event subscription to fail but succeeded!") + } + + opts.Filter = "bzz:aa" + _, err = client.SubscribeNetwork(events, opts) + if err == nil { + t.Fatalf("expected event subscription to fail but succeeded!") + } + + opts.Filter = "invalid" + _, err = client.SubscribeNetwork(events, opts) + if err == nil { + t.Fatalf("expected event subscription to fail but succeeded!") + } +} diff --git a/p2p/simulations/network.go b/p2p/simulations/network.go new file mode 100644 index 000000000..06890ffcf --- /dev/null +++ b/p2p/simulations/network.go @@ -0,0 +1,680 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package simulations + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "sync" + + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/simulations/adapters" +) + +// NetworkConfig defines configuration options for starting a Network +type NetworkConfig struct { + ID string `json:"id"` + DefaultService string `json:"default_service,omitempty"` +} + +// Network models a p2p simulation network which consists of a collection of +// simulated nodes and the connections which exist between them. +// +// The Network has a single NodeAdapter which is responsible for actually +// starting nodes and connecting them together. +// +// The Network emits events when nodes are started and stopped, when they are +// connected and disconnected, and also when messages are sent between nodes. +type Network struct { + NetworkConfig + + Nodes []*Node `json:"nodes"` + nodeMap map[discover.NodeID]int + + Conns []*Conn `json:"conns"` + connMap map[string]int + + nodeAdapter adapters.NodeAdapter + events event.Feed + lock sync.RWMutex + quitc chan struct{} +} + +// NewNetwork returns a Network which uses the given NodeAdapter and NetworkConfig +func NewNetwork(nodeAdapter adapters.NodeAdapter, conf *NetworkConfig) *Network { + return &Network{ + NetworkConfig: *conf, + nodeAdapter: nodeAdapter, + nodeMap: make(map[discover.NodeID]int), + connMap: make(map[string]int), + quitc: make(chan struct{}), + } +} + +// Events returns the output event feed of the Network. +func (self *Network) Events() *event.Feed { + return &self.events +} + +// NewNode adds a new node to the network with a random ID +func (self *Network) NewNode() (*Node, error) { + conf := adapters.RandomNodeConfig() + conf.Services = []string{self.DefaultService} + return self.NewNodeWithConfig(conf) +} + +// NewNodeWithConfig adds a new node to the network with the given config, +// returning an error if a node with the same ID or name already exists +func (self *Network) NewNodeWithConfig(conf *adapters.NodeConfig) (*Node, error) { + self.lock.Lock() + defer self.lock.Unlock() + + // create a random ID and PrivateKey if not set + if conf.ID == (discover.NodeID{}) { + c := adapters.RandomNodeConfig() + conf.ID = c.ID + conf.PrivateKey = c.PrivateKey + } + id := conf.ID + + // assign a name to the node if not set + if conf.Name == "" { + conf.Name = fmt.Sprintf("node%02d", len(self.Nodes)+1) + } + + // check the node doesn't already exist + if node := self.getNode(id); node != nil { + return nil, fmt.Errorf("node with ID %q already exists", id) + } + if node := self.getNodeByName(conf.Name); node != nil { + return nil, fmt.Errorf("node with name %q already exists", conf.Name) + } + + // if no services are configured, use the default service + if len(conf.Services) == 0 { + conf.Services = []string{self.DefaultService} + } + + // use the NodeAdapter to create the node + adapterNode, err := self.nodeAdapter.NewNode(conf) + if err != nil { + return nil, err + } + node := &Node{ + Node: adapterNode, + Config: conf, + } + log.Trace(fmt.Sprintf("node %v created", id)) + self.nodeMap[id] = len(self.Nodes) + self.Nodes = append(self.Nodes, node) + + // emit a "control" event + self.events.Send(ControlEvent(node)) + + return node, nil +} + +// Config returns the network configuration +func (self *Network) Config() *NetworkConfig { + return &self.NetworkConfig +} + +// StartAll starts all nodes in the network +func (self *Network) StartAll() error { + for _, node := range self.Nodes { + if node.Up { + continue + } + if err := self.Start(node.ID()); err != nil { + return err + } + } + return nil +} + +// StopAll stops all nodes in the network +func (self *Network) StopAll() error { + for _, node := range self.Nodes { + if !node.Up { + continue + } + if err := self.Stop(node.ID()); err != nil { + return err + } + } + return nil +} + +// Start starts the node with the given ID +func (self *Network) Start(id discover.NodeID) error { + return self.startWithSnapshots(id, nil) +} + +// startWithSnapshots starts the node with the given ID using the give +// snapshots +func (self *Network) startWithSnapshots(id discover.NodeID, snapshots map[string][]byte) error { + node := self.GetNode(id) + if node == nil { + return fmt.Errorf("node %v does not exist", id) + } + if node.Up { + return fmt.Errorf("node %v already up", id) + } + log.Trace(fmt.Sprintf("starting node %v: %v using %v", id, node.Up, self.nodeAdapter.Name())) + if err := node.Start(snapshots); err != nil { + log.Warn(fmt.Sprintf("start up failed: %v", err)) + return err + } + node.Up = true + log.Info(fmt.Sprintf("started node %v: %v", id, node.Up)) + + self.events.Send(NewEvent(node)) + + // subscribe to peer events + client, err := node.Client() + if err != nil { + return fmt.Errorf("error getting rpc client for node %v: %s", id, err) + } + events := make(chan *p2p.PeerEvent) + sub, err := client.Subscribe(context.Background(), "admin", events, "peerEvents") + if err != nil { + return fmt.Errorf("error getting peer events for node %v: %s", id, err) + } + go self.watchPeerEvents(id, events, sub) + return nil +} + +// watchPeerEvents reads peer events from the given channel and emits +// corresponding network events +func (self *Network) watchPeerEvents(id discover.NodeID, events chan *p2p.PeerEvent, sub event.Subscription) { + defer func() { + sub.Unsubscribe() + + // assume the node is now down + self.lock.Lock() + node := self.getNode(id) + node.Up = false + self.lock.Unlock() + self.events.Send(NewEvent(node)) + }() + for { + select { + case event, ok := <-events: + if !ok { + return + } + peer := event.Peer + switch event.Type { + + case p2p.PeerEventTypeAdd: + self.DidConnect(id, peer) + + case p2p.PeerEventTypeDrop: + self.DidDisconnect(id, peer) + + case p2p.PeerEventTypeMsgSend: + self.DidSend(id, peer, event.Protocol, *event.MsgCode) + + case p2p.PeerEventTypeMsgRecv: + self.DidReceive(peer, id, event.Protocol, *event.MsgCode) + + } + + case err := <-sub.Err(): + if err != nil { + log.Error(fmt.Sprintf("error getting peer events for node %v", id), "err", err) + } + return + } + } +} + +// Stop stops the node with the given ID +func (self *Network) Stop(id discover.NodeID) error { + node := self.GetNode(id) + if node == nil { + return fmt.Errorf("node %v does not exist", id) + } + if !node.Up { + return fmt.Errorf("node %v already down", id) + } + if err := node.Stop(); err != nil { + return err + } + node.Up = false + log.Info(fmt.Sprintf("stop node %v: %v", id, node.Up)) + + self.events.Send(ControlEvent(node)) + return nil +} + +// Connect connects two nodes together by calling the "admin_addPeer" RPC +// method on the "one" node so that it connects to the "other" node +func (self *Network) Connect(oneID, otherID discover.NodeID) error { + log.Debug(fmt.Sprintf("connecting %s to %s", oneID, otherID)) + conn, err := self.GetOrCreateConn(oneID, otherID) + if err != nil { + return err + } + if conn.Up { + return fmt.Errorf("%v and %v already connected", oneID, otherID) + } + if err := conn.nodesUp(); err != nil { + return err + } + client, err := conn.one.Client() + if err != nil { + return err + } + self.events.Send(ControlEvent(conn)) + return client.Call(nil, "admin_addPeer", string(conn.other.Addr())) +} + +// Disconnect disconnects two nodes by calling the "admin_removePeer" RPC +// method on the "one" node so that it disconnects from the "other" node +func (self *Network) Disconnect(oneID, otherID discover.NodeID) error { + conn := self.GetConn(oneID, otherID) + if conn == nil { + return fmt.Errorf("connection between %v and %v does not exist", oneID, otherID) + } + if !conn.Up { + return fmt.Errorf("%v and %v already disconnected", oneID, otherID) + } + client, err := conn.one.Client() + if err != nil { + return err + } + self.events.Send(ControlEvent(conn)) + return client.Call(nil, "admin_removePeer", string(conn.other.Addr())) +} + +// DidConnect tracks the fact that the "one" node connected to the "other" node +func (self *Network) DidConnect(one, other discover.NodeID) error { + conn, err := self.GetOrCreateConn(one, other) + if err != nil { + return fmt.Errorf("connection between %v and %v does not exist", one, other) + } + if conn.Up { + return fmt.Errorf("%v and %v already connected", one, other) + } + conn.Up = true + self.events.Send(NewEvent(conn)) + return nil +} + +// DidDisconnect tracks the fact that the "one" node disconnected from the +// "other" node +func (self *Network) DidDisconnect(one, other discover.NodeID) error { + conn, err := self.GetOrCreateConn(one, other) + if err != nil { + return fmt.Errorf("connection between %v and %v does not exist", one, other) + } + if !conn.Up { + return fmt.Errorf("%v and %v already disconnected", one, other) + } + conn.Up = false + self.events.Send(NewEvent(conn)) + return nil +} + +// DidSend tracks the fact that "sender" sent a message to "receiver" +func (self *Network) DidSend(sender, receiver discover.NodeID, proto string, code uint64) error { + msg := &Msg{ + One: sender, + Other: receiver, + Protocol: proto, + Code: code, + Received: false, + } + self.events.Send(NewEvent(msg)) + return nil +} + +// DidReceive tracks the fact that "receiver" received a message from "sender" +func (self *Network) DidReceive(sender, receiver discover.NodeID, proto string, code uint64) error { + msg := &Msg{ + One: sender, + Other: receiver, + Protocol: proto, + Code: code, + Received: true, + } + self.events.Send(NewEvent(msg)) + return nil +} + +// GetNode gets the node with the given ID, returning nil if the node does not +// exist +func (self *Network) GetNode(id discover.NodeID) *Node { + self.lock.Lock() + defer self.lock.Unlock() + return self.getNode(id) +} + +// GetNode gets the node with the given name, returning nil if the node does +// not exist +func (self *Network) GetNodeByName(name string) *Node { + self.lock.Lock() + defer self.lock.Unlock() + return self.getNodeByName(name) +} + +func (self *Network) getNode(id discover.NodeID) *Node { + i, found := self.nodeMap[id] + if !found { + return nil + } + return self.Nodes[i] +} + +func (self *Network) getNodeByName(name string) *Node { + for _, node := range self.Nodes { + if node.Config.Name == name { + return node + } + } + return nil +} + +// GetNodes returns the existing nodes +func (self *Network) GetNodes() []*Node { + self.lock.Lock() + defer self.lock.Unlock() + return self.Nodes +} + +// GetConn returns the connection which exists between "one" and "other" +// regardless of which node initiated the connection +func (self *Network) GetConn(oneID, otherID discover.NodeID) *Conn { + self.lock.Lock() + defer self.lock.Unlock() + return self.getConn(oneID, otherID) +} + +// GetOrCreateConn is like GetConn but creates the connection if it doesn't +// already exist +func (self *Network) GetOrCreateConn(oneID, otherID discover.NodeID) (*Conn, error) { + self.lock.Lock() + defer self.lock.Unlock() + if conn := self.getConn(oneID, otherID); conn != nil { + return conn, nil + } + + one := self.getNode(oneID) + if one == nil { + return nil, fmt.Errorf("node %v does not exist", oneID) + } + other := self.getNode(otherID) + if other == nil { + return nil, fmt.Errorf("node %v does not exist", otherID) + } + conn := &Conn{ + One: oneID, + Other: otherID, + one: one, + other: other, + } + label := ConnLabel(oneID, otherID) + self.connMap[label] = len(self.Conns) + self.Conns = append(self.Conns, conn) + return conn, nil +} + +func (self *Network) getConn(oneID, otherID discover.NodeID) *Conn { + label := ConnLabel(oneID, otherID) + i, found := self.connMap[label] + if !found { + return nil + } + return self.Conns[i] +} + +// Shutdown stops all nodes in the network and closes the quit channel +func (self *Network) Shutdown() { + for _, node := range self.Nodes { + log.Debug(fmt.Sprintf("stopping node %s", node.ID().TerminalString())) + if err := node.Stop(); err != nil { + log.Warn(fmt.Sprintf("error stopping node %s", node.ID().TerminalString()), "err", err) + } + } + close(self.quitc) +} + +// Node is a wrapper around adapters.Node which is used to track the status +// of a node in the network +type Node struct { + adapters.Node `json:"-"` + + // Config if the config used to created the node + Config *adapters.NodeConfig `json:"config"` + + // Up tracks whether or not the node is running + Up bool `json:"up"` +} + +// ID returns the ID of the node +func (self *Node) ID() discover.NodeID { + return self.Config.ID +} + +// String returns a log-friendly string +func (self *Node) String() string { + return fmt.Sprintf("Node %v", self.ID().TerminalString()) +} + +// NodeInfo returns information about the node +func (self *Node) NodeInfo() *p2p.NodeInfo { + // avoid a panic if the node is not started yet + if self.Node == nil { + return nil + } + info := self.Node.NodeInfo() + info.Name = self.Config.Name + return info +} + +// MarshalJSON implements the json.Marshaler interface so that the encoded +// JSON includes the NodeInfo +func (self *Node) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Info *p2p.NodeInfo `json:"info,omitempty"` + Config *adapters.NodeConfig `json:"config,omitempty"` + Up bool `json:"up"` + }{ + Info: self.NodeInfo(), + Config: self.Config, + Up: self.Up, + }) +} + +// Conn represents a connection between two nodes in the network +type Conn struct { + // One is the node which initiated the connection + One discover.NodeID `json:"one"` + + // Other is the node which the connection was made to + Other discover.NodeID `json:"other"` + + // Up tracks whether or not the connection is active + Up bool `json:"up"` + + one *Node + other *Node +} + +// nodesUp returns whether both nodes are currently up +func (self *Conn) nodesUp() error { + if !self.one.Up { + return fmt.Errorf("one %v is not up", self.One) + } + if !self.other.Up { + return fmt.Errorf("other %v is not up", self.Other) + } + return nil +} + +// String returns a log-friendly string +func (self *Conn) String() string { + return fmt.Sprintf("Conn %v->%v", self.One.TerminalString(), self.Other.TerminalString()) +} + +// Msg represents a p2p message sent between two nodes in the network +type Msg struct { + One discover.NodeID `json:"one"` + Other discover.NodeID `json:"other"` + Protocol string `json:"protocol"` + Code uint64 `json:"code"` + Received bool `json:"received"` +} + +// String returns a log-friendly string +func (self *Msg) String() string { + return fmt.Sprintf("Msg(%d) %v->%v", self.Code, self.One.TerminalString(), self.Other.TerminalString()) +} + +// ConnLabel generates a deterministic string which represents a connection +// between two nodes, used to compare if two connections are between the same +// nodes +func ConnLabel(source, target discover.NodeID) string { + var first, second discover.NodeID + if bytes.Compare(source.Bytes(), target.Bytes()) > 0 { + first = target + second = source + } else { + first = source + second = target + } + return fmt.Sprintf("%v-%v", first, second) +} + +// Snapshot represents the state of a network at a single point in time and can +// be used to restore the state of a network +type Snapshot struct { + Nodes []NodeSnapshot `json:"nodes,omitempty"` + Conns []Conn `json:"conns,omitempty"` +} + +// NodeSnapshot represents the state of a node in the network +type NodeSnapshot struct { + Node Node `json:"node,omitempty"` + + // Snapshots is arbitrary data gathered from calling node.Snapshots() + Snapshots map[string][]byte `json:"snapshots,omitempty"` +} + +// Snapshot creates a network snapshot +func (self *Network) Snapshot() (*Snapshot, error) { + self.lock.Lock() + defer self.lock.Unlock() + snap := &Snapshot{ + Nodes: make([]NodeSnapshot, len(self.Nodes)), + Conns: make([]Conn, len(self.Conns)), + } + for i, node := range self.Nodes { + snap.Nodes[i] = NodeSnapshot{Node: *node} + if !node.Up { + continue + } + snapshots, err := node.Snapshots() + if err != nil { + return nil, err + } + snap.Nodes[i].Snapshots = snapshots + } + for i, conn := range self.Conns { + snap.Conns[i] = *conn + } + return snap, nil +} + +// Load loads a network snapshot +func (self *Network) Load(snap *Snapshot) error { + for _, n := range snap.Nodes { + if _, err := self.NewNodeWithConfig(n.Node.Config); err != nil { + return err + } + if !n.Node.Up { + continue + } + if err := self.startWithSnapshots(n.Node.Config.ID, n.Snapshots); err != nil { + return err + } + } + for _, conn := range snap.Conns { + if err := self.Connect(conn.One, conn.Other); err != nil { + return err + } + } + return nil +} + +// Subscribe reads control events from a channel and executes them +func (self *Network) Subscribe(events chan *Event) { + for { + select { + case event, ok := <-events: + if !ok { + return + } + if event.Control { + self.executeControlEvent(event) + } + case <-self.quitc: + return + } + } +} + +func (self *Network) executeControlEvent(event *Event) { + log.Trace("execute control event", "type", event.Type, "event", event) + switch event.Type { + case EventTypeNode: + if err := self.executeNodeEvent(event); err != nil { + log.Error("error executing node event", "event", event, "err", err) + } + case EventTypeConn: + if err := self.executeConnEvent(event); err != nil { + log.Error("error executing conn event", "event", event, "err", err) + } + case EventTypeMsg: + log.Warn("ignoring control msg event") + } +} + +func (self *Network) executeNodeEvent(e *Event) error { + if !e.Node.Up { + return self.Stop(e.Node.ID()) + } + + if _, err := self.NewNodeWithConfig(e.Node.Config); err != nil { + return err + } + return self.Start(e.Node.ID()) +} + +func (self *Network) executeConnEvent(e *Event) error { + if e.Conn.Up { + return self.Connect(e.Conn.One, e.Conn.Other) + } else { + return self.Disconnect(e.Conn.One, e.Conn.Other) + } +} diff --git a/p2p/simulations/network_test.go b/p2p/simulations/network_test.go new file mode 100644 index 000000000..2a062121b --- /dev/null +++ b/p2p/simulations/network_test.go @@ -0,0 +1,159 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package simulations + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/simulations/adapters" +) + +// TestNetworkSimulation creates a multi-node simulation network with each node +// connected in a ring topology, checks that all nodes successfully handshake +// with each other and that a snapshot fully represents the desired topology +func TestNetworkSimulation(t *testing.T) { + // create simulation network with 20 testService nodes + adapter := adapters.NewSimAdapter(adapters.Services{ + "test": newTestService, + }) + network := NewNetwork(adapter, &NetworkConfig{ + DefaultService: "test", + }) + defer network.Shutdown() + nodeCount := 20 + ids := make([]discover.NodeID, nodeCount) + for i := 0; i < nodeCount; i++ { + node, err := network.NewNode() + if err != nil { + t.Fatalf("error creating node: %s", err) + } + if err := network.Start(node.ID()); err != nil { + t.Fatalf("error starting node: %s", err) + } + ids[i] = node.ID() + } + + // perform a check which connects the nodes in a ring (so each node is + // connected to exactly two peers) and then checks that all nodes + // performed two handshakes by checking their peerCount + action := func(_ context.Context) error { + for i, id := range ids { + peerID := ids[(i+1)%len(ids)] + if err := network.Connect(id, peerID); err != nil { + return err + } + } + return nil + } + check := func(ctx context.Context, id discover.NodeID) (bool, error) { + // check we haven't run out of time + select { + case <-ctx.Done(): + return false, ctx.Err() + default: + } + + // get the node + node := network.GetNode(id) + if node == nil { + return false, fmt.Errorf("unknown node: %s", id) + } + + // check it has exactly two peers + client, err := node.Client() + if err != nil { + return false, err + } + var peerCount int64 + if err := client.CallContext(ctx, &peerCount, "test_peerCount"); err != nil { + return false, err + } + switch { + case peerCount < 2: + return false, nil + case peerCount == 2: + return true, nil + default: + return false, fmt.Errorf("unexpected peerCount: %d", peerCount) + } + } + + timeout := 30 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + // trigger a check every 100ms + trigger := make(chan discover.NodeID) + go triggerChecks(ctx, ids, trigger, 100*time.Millisecond) + + result := NewSimulation(network).Run(ctx, &Step{ + Action: action, + Trigger: trigger, + Expect: &Expectation{ + Nodes: ids, + Check: check, + }, + }) + if result.Error != nil { + t.Fatalf("simulation failed: %s", result.Error) + } + + // take a network snapshot and check it contains the correct topology + snap, err := network.Snapshot() + if err != nil { + t.Fatal(err) + } + if len(snap.Nodes) != nodeCount { + t.Fatalf("expected snapshot to contain %d nodes, got %d", nodeCount, len(snap.Nodes)) + } + if len(snap.Conns) != nodeCount { + t.Fatalf("expected snapshot to contain %d connections, got %d", nodeCount, len(snap.Conns)) + } + for i, id := range ids { + conn := snap.Conns[i] + if conn.One != id { + t.Fatalf("expected conn[%d].One to be %s, got %s", i, id, conn.One) + } + peerID := ids[(i+1)%len(ids)] + if conn.Other != peerID { + t.Fatalf("expected conn[%d].Other to be %s, got %s", i, peerID, conn.Other) + } + } +} + +func triggerChecks(ctx context.Context, ids []discover.NodeID, trigger chan discover.NodeID, interval time.Duration) { + tick := time.NewTicker(interval) + defer tick.Stop() + for { + select { + case <-tick.C: + for _, id := range ids { + select { + case trigger <- id: + case <-ctx.Done(): + return + } + } + case <-ctx.Done(): + return + } + } +} diff --git a/p2p/simulations/simulation.go b/p2p/simulations/simulation.go new file mode 100644 index 000000000..28886e924 --- /dev/null +++ b/p2p/simulations/simulation.go @@ -0,0 +1,157 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package simulations + +import ( + "context" + "time" + + "github.com/ethereum/go-ethereum/p2p/discover" +) + +// Simulation provides a framework for running actions in a simulated network +// and then waiting for expectations to be met +type Simulation struct { + network *Network +} + +// NewSimulation returns a new simulation which runs in the given network +func NewSimulation(network *Network) *Simulation { + return &Simulation{ + network: network, + } +} + +// Run performs a step of the simulation by performing the step's action and +// then waiting for the step's expectation to be met +func (s *Simulation) Run(ctx context.Context, step *Step) (result *StepResult) { + result = newStepResult() + + result.StartedAt = time.Now() + defer func() { result.FinishedAt = time.Now() }() + + // watch network events for the duration of the step + stop := s.watchNetwork(result) + defer stop() + + // perform the action + if err := step.Action(ctx); err != nil { + result.Error = err + return + } + + // wait for all node expectations to either pass, error or timeout + nodes := make(map[discover.NodeID]struct{}, len(step.Expect.Nodes)) + for _, id := range step.Expect.Nodes { + nodes[id] = struct{}{} + } + for len(result.Passes) < len(nodes) { + select { + case id := <-step.Trigger: + // skip if we aren't checking the node + if _, ok := nodes[id]; !ok { + continue + } + + // skip if the node has already passed + if _, ok := result.Passes[id]; ok { + continue + } + + // run the node expectation check + pass, err := step.Expect.Check(ctx, id) + if err != nil { + result.Error = err + return + } + if pass { + result.Passes[id] = time.Now() + } + case <-ctx.Done(): + result.Error = ctx.Err() + return + } + } + + return +} + +func (s *Simulation) watchNetwork(result *StepResult) func() { + stop := make(chan struct{}) + done := make(chan struct{}) + events := make(chan *Event) + sub := s.network.Events().Subscribe(events) + go func() { + defer close(done) + defer sub.Unsubscribe() + for { + select { + case event := <-events: + result.NetworkEvents = append(result.NetworkEvents, event) + case <-stop: + return + } + } + }() + return func() { + close(stop) + <-done + } +} + +type Step struct { + // Action is the action to perform for this step + Action func(context.Context) error + + // Trigger is a channel which receives node ids and triggers an + // expectation check for that node + Trigger chan discover.NodeID + + // Expect is the expectation to wait for when performing this step + Expect *Expectation +} + +type Expectation struct { + // Nodes is a list of nodes to check + Nodes []discover.NodeID + + // Check checks whether a given node meets the expectation + Check func(context.Context, discover.NodeID) (bool, error) +} + +func newStepResult() *StepResult { + return &StepResult{ + Passes: make(map[discover.NodeID]time.Time), + } +} + +type StepResult struct { + // Error is the error encountered whilst running the step + Error error + + // StartedAt is the time the step started + StartedAt time.Time + + // FinishedAt is the time the step finished + FinishedAt time.Time + + // Passes are the timestamps of the successful node expectations + Passes map[discover.NodeID]time.Time + + // NetworkEvents are the network events which occurred during the step + NetworkEvents []*Event +} diff --git a/params/config.go b/params/config.go index 5c6f0b6bf..fff168a8c 100644 --- a/params/config.go +++ b/params/config.go @@ -197,7 +197,7 @@ func (c *ChainConfig) GasTable(num *big.Int) GasTable { case c.IsEIP158(num): return GasTableEIP158 case c.IsEIP150(num): - return GasTableHomesteadGasRepriceFork + return GasTableEIP150 default: return GasTableHomestead } diff --git a/params/gas_table.go b/params/gas_table.go index a06053904..4969382b1 100644 --- a/params/gas_table.go +++ b/params/gas_table.go @@ -49,9 +49,7 @@ var ( // GasTableHomestead contain the gas re-prices for // the homestead phase. - // - // TODO rename to GasTableEIP150 - GasTableHomesteadGasRepriceFork = GasTable{ + GasTableEIP150 = GasTable{ ExtcodeSize: 700, ExtcodeCopy: 700, Balance: 400, diff --git a/rpc/client.go b/rpc/client.go index f02366a39..8aa84ec98 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -349,85 +349,49 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error { return err } -// ShhSubscribe calls the "shh_subscribe" method with the given arguments, -// registering a subscription. Server notifications for the subscription are -// sent to the given channel. The element type of the channel must match the -// expected type of content returned by the subscription. -// -// The context argument cancels the RPC request that sets up the subscription but has no -// effect on the subscription after ShhSubscribe has returned. -// -// Slow subscribers will be dropped eventually. Client buffers up to 8000 notifications -// before considering the subscriber dead. The subscription Err channel will receive -// ErrSubscriptionQueueOverflow. Use a sufficiently large buffer on the channel or ensure -// that the channel usually has at least one reader to prevent this issue. -func (c *Client) ShhSubscribe(ctx context.Context, channel interface{}, args ...interface{}) (*ClientSubscription, error) { - // Check type of channel first. - chanVal := reflect.ValueOf(channel) - if chanVal.Kind() != reflect.Chan || chanVal.Type().ChanDir()&reflect.SendDir == 0 { - panic("first argument to ShhSubscribe must be a writable channel") - } - if chanVal.IsNil() { - panic("channel given to ShhSubscribe must not be nil") - } - if c.isHTTP { - return nil, ErrNotificationsUnsupported - } - - msg, err := c.newMessage("shh"+subscribeMethodSuffix, args...) - if err != nil { - return nil, err - } - op := &requestOp{ - ids: []json.RawMessage{msg.ID}, - resp: make(chan *jsonrpcMessage), - sub: newClientSubscription(c, "shh", chanVal), - } +// EthSubscribe registers a subscripion under the "eth" namespace. +func (c *Client) EthSubscribe(ctx context.Context, channel interface{}, args ...interface{}) (*ClientSubscription, error) { + return c.Subscribe(ctx, "eth", channel, args...) +} - // Send the subscription request. - // The arrival and validity of the response is signaled on sub.quit. - if err := c.send(ctx, op, msg); err != nil { - return nil, err - } - if _, err := op.wait(ctx); err != nil { - return nil, err - } - return op.sub, nil +// ShhSubscribe registers a subscripion under the "shh" namespace. +func (c *Client) ShhSubscribe(ctx context.Context, channel interface{}, args ...interface{}) (*ClientSubscription, error) { + return c.Subscribe(ctx, "shh", channel, args...) } -// EthSubscribe calls the "eth_subscribe" method with the given arguments, +// Subscribe calls the "<namespace>_subscribe" method with the given arguments, // registering a subscription. Server notifications for the subscription are // sent to the given channel. The element type of the channel must match the // expected type of content returned by the subscription. // // The context argument cancels the RPC request that sets up the subscription but has no -// effect on the subscription after EthSubscribe has returned. +// effect on the subscription after Subscribe has returned. // // Slow subscribers will be dropped eventually. Client buffers up to 8000 notifications // before considering the subscriber dead. The subscription Err channel will receive // ErrSubscriptionQueueOverflow. Use a sufficiently large buffer on the channel or ensure // that the channel usually has at least one reader to prevent this issue. -func (c *Client) EthSubscribe(ctx context.Context, channel interface{}, args ...interface{}) (*ClientSubscription, error) { +func (c *Client) Subscribe(ctx context.Context, namespace string, channel interface{}, args ...interface{}) (*ClientSubscription, error) { // Check type of channel first. chanVal := reflect.ValueOf(channel) if chanVal.Kind() != reflect.Chan || chanVal.Type().ChanDir()&reflect.SendDir == 0 { - panic("first argument to EthSubscribe must be a writable channel") + panic("first argument to Subscribe must be a writable channel") } if chanVal.IsNil() { - panic("channel given to EthSubscribe must not be nil") + panic("channel given to Subscribe must not be nil") } if c.isHTTP { return nil, ErrNotificationsUnsupported } - msg, err := c.newMessage("eth"+subscribeMethodSuffix, args...) + msg, err := c.newMessage(namespace+subscribeMethodSuffix, args...) if err != nil { return nil, err } op := &requestOp{ ids: []json.RawMessage{msg.ID}, resp: make(chan *jsonrpcMessage), - sub: newClientSubscription(c, "eth", chanVal), + sub: newClientSubscription(c, namespace, chanVal), } // Send the subscription request. diff --git a/rpc/client_test.go b/rpc/client_test.go index 10d74670b..4f354d389 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -251,6 +251,38 @@ func TestClientSubscribe(t *testing.T) { } } +func TestClientSubscribeCustomNamespace(t *testing.T) { + namespace := "custom" + server := newTestServer(namespace, new(NotificationTestService)) + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + nc := make(chan int) + count := 10 + sub, err := client.Subscribe(context.Background(), namespace, nc, "someSubscription", count, 0) + if err != nil { + t.Fatal("can't subscribe:", err) + } + for i := 0; i < count; i++ { + if val := <-nc; val != i { + t.Fatalf("value mismatch: got %d, want %d", val, i) + } + } + + sub.Unsubscribe() + select { + case v := <-nc: + t.Fatal("received value after unsubscribe:", v) + case err := <-sub.Err(): + if err != nil { + t.Fatalf("Err returned a non-nil error after explicit unsubscribe: %q", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("subscription not closed within 1s after unsubscribe") + } +} + // In this test, the connection drops while EthSubscribe is // waiting for a response. func TestClientSubscribeClose(t *testing.T) { diff --git a/swarm/network/depo.go b/swarm/network/depo.go index e76bfa66c..8695bf5d9 100644 --- a/swarm/network/depo.go +++ b/swarm/network/depo.go @@ -29,12 +29,12 @@ import ( // Handler for storage/retrieval related protocol requests // implements the StorageHandler interface used by the bzz protocol type Depo struct { - hashfunc storage.Hasher + hashfunc storage.SwarmHasher localStore storage.ChunkStore netStore storage.ChunkStore } -func NewDepo(hash storage.Hasher, localStore, remoteStore storage.ChunkStore) *Depo { +func NewDepo(hash storage.SwarmHasher, localStore, remoteStore storage.ChunkStore) *Depo { return &Depo{ hashfunc: hash, localStore: localStore, diff --git a/swarm/storage/chunker.go b/swarm/storage/chunker.go index ca85e4333..0454828b9 100644 --- a/swarm/storage/chunker.go +++ b/swarm/storage/chunker.go @@ -20,9 +20,9 @@ import ( "encoding/binary" "errors" "fmt" - "hash" "io" "sync" + "time" ) /* @@ -50,14 +50,6 @@ data_{i} := size(subtree_{i}) || key_{j} || key_{j+1} .... || key_{j+n-1} The underlying hash function is configurable */ -const ( - defaultHash = "SHA3" - // defaultHash = "BMTSHA3" // http://golang.org/pkg/hash/#Hash - // defaultHash = "SHA256" // http://golang.org/pkg/hash/#Hash - defaultBranches int64 = 128 - // hashSize int64 = hasherfunc.New().Size() // hasher knows about its own length in bytes - // chunksize int64 = branches * hashSize // chunk is defined as this -) /* Tree chunker is a concrete implementation of data chunking. @@ -67,25 +59,19 @@ If all is well it is possible to implement this by simply composing readers so t The hashing itself does use extra copies and allocation though, since it does need it. */ -type ChunkerParams struct { - Branches int64 - Hash string -} - -func NewChunkerParams() *ChunkerParams { - return &ChunkerParams{ - Branches: defaultBranches, - Hash: defaultHash, - } -} +var ( + errAppendOppNotSuported = errors.New("Append operation not supported") + errOperationTimedOut = errors.New("operation timed out") +) type TreeChunker struct { branches int64 - hashFunc Hasher + hashFunc SwarmHasher // calculated hashSize int64 // self.hashFunc.New().Size() chunkSize int64 // hashSize* branches - workerCount int + workerCount int64 // the number of worker routines used + workerLock sync.RWMutex // lock for the worker count } func NewTreeChunker(params *ChunkerParams) (self *TreeChunker) { @@ -94,7 +80,8 @@ func NewTreeChunker(params *ChunkerParams) (self *TreeChunker) { self.branches = params.Branches self.hashSize = int64(self.hashFunc().Size()) self.chunkSize = self.hashSize * self.branches - self.workerCount = 1 + self.workerCount = 0 + return } @@ -114,13 +101,31 @@ type hashJob struct { parentWg *sync.WaitGroup } -func (self *TreeChunker) Split(data io.Reader, size int64, chunkC chan *Chunk, swg, wwg *sync.WaitGroup) (Key, error) { +func (self *TreeChunker) incrementWorkerCount() { + self.workerLock.Lock() + defer self.workerLock.Unlock() + self.workerCount += 1 +} + +func (self *TreeChunker) getWorkerCount() int64 { + self.workerLock.RLock() + defer self.workerLock.RUnlock() + return self.workerCount +} +func (self *TreeChunker) decrementWorkerCount() { + self.workerLock.Lock() + defer self.workerLock.Unlock() + self.workerCount -= 1 +} + +func (self *TreeChunker) Split(data io.Reader, size int64, chunkC chan *Chunk, swg, wwg *sync.WaitGroup) (Key, error) { if self.chunkSize <= 0 { panic("chunker must be initialised") } - jobC := make(chan *hashJob, 2*processors) + + jobC := make(chan *hashJob, 2*ChunkProcessors) wg := &sync.WaitGroup{} errC := make(chan error) quitC := make(chan bool) @@ -129,6 +134,8 @@ func (self *TreeChunker) Split(data io.Reader, size int64, chunkC chan *Chunk, s if wwg != nil { wwg.Add(1) } + + self.incrementWorkerCount() go self.hashWorker(jobC, chunkC, errC, quitC, swg, wwg) depth := 0 @@ -157,10 +164,15 @@ func (self *TreeChunker) Split(data io.Reader, size int64, chunkC chan *Chunk, s close(errC) }() - //TODO: add a timeout - if err := <-errC; err != nil { - close(quitC) - return nil, err + + defer close(quitC) + select { + case err := <-errC: + if err != nil { + return nil, err + } + case <-time.NewTimer(splitTimeout).C: + return nil,errOperationTimedOut } return key, nil @@ -168,6 +180,8 @@ func (self *TreeChunker) Split(data io.Reader, size int64, chunkC chan *Chunk, s func (self *TreeChunker) split(depth int, treeSize int64, key Key, data io.Reader, size int64, jobC chan *hashJob, chunkC chan *Chunk, errC chan error, quitC chan bool, parentWg, swg, wwg *sync.WaitGroup) { + // + for depth > 0 && size < treeSize { treeSize /= self.branches depth-- @@ -223,12 +237,15 @@ func (self *TreeChunker) split(depth int, treeSize int64, key Key, data io.Reade // parentWg.Add(1) // go func() { childrenWg.Wait() - if len(jobC) > self.workerCount && self.workerCount < processors { + + worker := self.getWorkerCount() + if int64(len(jobC)) > worker && worker < ChunkProcessors { if wwg != nil { wwg.Add(1) } - self.workerCount++ + self.incrementWorkerCount() go self.hashWorker(jobC, chunkC, errC, quitC, swg, wwg) + } select { case jobC <- &hashJob{key, chunk, size, parentWg}: @@ -237,6 +254,8 @@ func (self *TreeChunker) split(depth int, treeSize int64, key Key, data io.Reade } func (self *TreeChunker) hashWorker(jobC chan *hashJob, chunkC chan *Chunk, errC chan error, quitC chan bool, swg, wwg *sync.WaitGroup) { + defer self.decrementWorkerCount() + hasher := self.hashFunc() if wwg != nil { defer wwg.Done() @@ -249,7 +268,6 @@ func (self *TreeChunker) hashWorker(jobC chan *hashJob, chunkC chan *Chunk, errC return } // now we got the hashes in the chunk, then hash the chunks - hasher.Reset() self.hashChunk(hasher, job, chunkC, swg) case <-quitC: return @@ -260,9 +278,11 @@ func (self *TreeChunker) hashWorker(jobC chan *hashJob, chunkC chan *Chunk, errC // The treeChunkers own Hash hashes together // - the size (of the subtree encoded in the Chunk) // - the Chunk, ie. the contents read from the input reader -func (self *TreeChunker) hashChunk(hasher hash.Hash, job *hashJob, chunkC chan *Chunk, swg *sync.WaitGroup) { - hasher.Write(job.chunk) +func (self *TreeChunker) hashChunk(hasher SwarmHash, job *hashJob, chunkC chan *Chunk, swg *sync.WaitGroup) { + hasher.ResetWithLength(job.chunk[:8]) // 8 bytes of length + hasher.Write(job.chunk[8:]) // minus 8 []byte length h := hasher.Sum(nil) + newChunk := &Chunk{ Key: h, SData: job.chunk, @@ -285,6 +305,10 @@ func (self *TreeChunker) hashChunk(hasher hash.Hash, job *hashJob, chunkC chan * } } +func (self *TreeChunker) Append(key Key, data io.Reader, chunkC chan *Chunk, swg, wwg *sync.WaitGroup) (Key, error) { + return nil, errAppendOppNotSuported +} + // LazyChunkReader implements LazySectionReader type LazyChunkReader struct { key Key // root key @@ -298,7 +322,6 @@ type LazyChunkReader struct { // implements the Joiner interface func (self *TreeChunker) Join(key Key, chunkC chan *Chunk) LazySectionReader { - return &LazyChunkReader{ key: key, chunkC: chunkC, diff --git a/swarm/storage/chunker_test.go b/swarm/storage/chunker_test.go index 426074e59..b41d7dd33 100644 --- a/swarm/storage/chunker_test.go +++ b/swarm/storage/chunker_test.go @@ -20,12 +20,14 @@ import ( "bytes" "crypto/rand" "encoding/binary" + "errors" "fmt" "io" - "runtime" "sync" "testing" "time" + + "github.com/ethereum/go-ethereum/crypto/sha3" ) /* @@ -43,7 +45,7 @@ type chunkerTester struct { t test } -func (self *chunkerTester) Split(chunker Splitter, data io.Reader, size int64, chunkC chan *Chunk, swg *sync.WaitGroup, expectedError error) (key Key) { +func (self *chunkerTester) Split(chunker Splitter, data io.Reader, size int64, chunkC chan *Chunk, swg *sync.WaitGroup, expectedError error) (key Key, err error) { // reset self.chunks = make(map[string]*Chunk) @@ -54,13 +56,13 @@ func (self *chunkerTester) Split(chunker Splitter, data io.Reader, size int64, c quitC := make(chan bool) timeout := time.After(600 * time.Second) if chunkC != nil { - go func() { + go func() error { for { select { case <-timeout: - self.t.Fatalf("Join timeout error") + return errors.New(("Split timeout error")) case <-quitC: - return + return nil case chunk := <-chunkC: // self.chunks = append(self.chunks, chunk) self.chunks[chunk.Key.String()] = chunk @@ -68,22 +70,69 @@ func (self *chunkerTester) Split(chunker Splitter, data io.Reader, size int64, c chunk.wg.Done() } } + } }() } - key, err := chunker.Split(data, size, chunkC, swg, nil) + + key, err = chunker.Split(data, size, chunkC, swg, nil) if err != nil && expectedError == nil { - self.t.Fatalf("Split error: %v", err) - } else if expectedError != nil && (err == nil || err.Error() != expectedError.Error()) { - self.t.Fatalf("Not receiving the correct error! Expected %v, received %v", expectedError, err) + err = errors.New(fmt.Sprintf("Split error: %v", err)) } + if chunkC != nil { if swg != nil { swg.Wait() } close(quitC) } - return + return key, err +} + +func (self *chunkerTester) Append(chunker Splitter, rootKey Key, data io.Reader, chunkC chan *Chunk, swg *sync.WaitGroup, expectedError error) (key Key, err error) { + quitC := make(chan bool) + timeout := time.After(60 * time.Second) + if chunkC != nil { + go func() error { + for { + select { + case <-timeout: + return errors.New(("Append timeout error")) + case <-quitC: + return nil + case chunk := <-chunkC: + if chunk != nil { + stored, success := self.chunks[chunk.Key.String()] + if !success { + // Requesting data + self.chunks[chunk.Key.String()] = chunk + if chunk.wg != nil { + chunk.wg.Done() + } + } else { + // getting data + chunk.SData = stored.SData + chunk.Size = int64(binary.LittleEndian.Uint64(chunk.SData[0:8])) + close(chunk.C) + } + } + } + } + }() + } + + key, err = chunker.Append(rootKey, data, chunkC, swg, nil) + if err != nil && expectedError == nil { + err = errors.New(fmt.Sprintf("Append error: %v", err)) + } + + if chunkC != nil { + if swg != nil { + swg.Wait() + } + close(quitC) + } + return key, err } func (self *chunkerTester) Join(chunker Chunker, key Key, c int, chunkC chan *Chunk, quitC chan bool) LazySectionReader { @@ -93,22 +142,20 @@ func (self *chunkerTester) Join(chunker Chunker, key Key, c int, chunkC chan *Ch timeout := time.After(600 * time.Second) i := 0 - go func() { + go func() error { for { select { case <-timeout: - self.t.Fatalf("Join timeout error") - + return errors.New(("Join timeout error")) case chunk, ok := <-chunkC: if !ok { close(quitC) - return + return nil } // this just mocks the behaviour of a chunk store retrieval stored, success := self.chunks[chunk.Key.String()] if !success { - self.t.Fatalf("not found") - return + return errors.New(("Not found")) } chunk.SData = stored.SData chunk.Size = int64(binary.LittleEndian.Uint64(chunk.SData[0:8])) @@ -136,11 +183,15 @@ func testRandomBrokenData(splitter Splitter, n int, tester *chunkerTester) { chunkC := make(chan *Chunk, 1000) swg := &sync.WaitGroup{} - key := tester.Split(splitter, brokendata, int64(n), chunkC, swg, fmt.Errorf("Broken reader")) + expectedError := fmt.Errorf("Broken reader") + key, err := tester.Split(splitter, brokendata, int64(n), chunkC, swg, expectedError) + if err == nil || err.Error() != expectedError.Error() { + tester.t.Fatalf("Not receiving the correct error! Expected %v, received %v", expectedError, err) + } tester.t.Logf(" Key = %v\n", key) } -func testRandomData(splitter Splitter, n int, tester *chunkerTester) { +func testRandomData(splitter Splitter, n int, tester *chunkerTester) Key { if tester.inputs == nil { tester.inputs = make(map[uint64][]byte) } @@ -156,7 +207,10 @@ func testRandomData(splitter Splitter, n int, tester *chunkerTester) { chunkC := make(chan *Chunk, 1000) swg := &sync.WaitGroup{} - key := tester.Split(splitter, data, int64(n), chunkC, swg, nil) + key, err := tester.Split(splitter, data, int64(n), chunkC, swg, nil) + if err != nil { + tester.t.Fatalf(err.Error()) + } tester.t.Logf(" Key = %v\n", key) chunkC = make(chan *Chunk, 1000) @@ -176,29 +230,145 @@ func testRandomData(splitter Splitter, n int, tester *chunkerTester) { } close(chunkC) <-quitC + + return key +} + +func testRandomDataAppend(splitter Splitter, n, m int, tester *chunkerTester) { + if tester.inputs == nil { + tester.inputs = make(map[uint64][]byte) + } + input, found := tester.inputs[uint64(n)] + var data io.Reader + if !found { + data, input = testDataReaderAndSlice(n) + tester.inputs[uint64(n)] = input + } else { + data = io.LimitReader(bytes.NewReader(input), int64(n)) + } + + chunkC := make(chan *Chunk, 1000) + swg := &sync.WaitGroup{} + + key, err := tester.Split(splitter, data, int64(n), chunkC, swg, nil) + if err != nil { + tester.t.Fatalf(err.Error()) + } + tester.t.Logf(" Key = %v\n", key) + + //create a append data stream + appendInput, found := tester.inputs[uint64(m)] + var appendData io.Reader + if !found { + appendData, appendInput = testDataReaderAndSlice(m) + tester.inputs[uint64(m)] = appendInput + } else { + appendData = io.LimitReader(bytes.NewReader(appendInput), int64(m)) + } + + chunkC = make(chan *Chunk, 1000) + swg = &sync.WaitGroup{} + + newKey, err := tester.Append(splitter, key, appendData, chunkC, swg, nil) + if err != nil { + tester.t.Fatalf(err.Error()) + } + tester.t.Logf(" NewKey = %v\n", newKey) + + chunkC = make(chan *Chunk, 1000) + quitC := make(chan bool) + + chunker := NewTreeChunker(NewChunkerParams()) + reader := tester.Join(chunker, newKey, 0, chunkC, quitC) + newOutput := make([]byte, n+m) + r, err := reader.Read(newOutput) + if r != (n + m) { + tester.t.Fatalf("read error read: %v n = %v err = %v\n", r, n, err) + } + + newInput := append(input, appendInput...) + if !bytes.Equal(newOutput, newInput) { + tester.t.Fatalf("input and output mismatch\n IN: %v\nOUT: %v\n", newInput, newOutput) + } + + close(chunkC) +} + +func TestSha3ForCorrectness(t *testing.T) { + tester := &chunkerTester{t: t} + + size := 4096 + input := make([]byte, size+8) + binary.LittleEndian.PutUint64(input[:8], uint64(size)) + + io.LimitReader(bytes.NewReader(input[8:]), int64(size)) + + rawSha3 := sha3.NewKeccak256() + rawSha3.Reset() + rawSha3.Write(input) + rawSha3Output := rawSha3.Sum(nil) + + sha3FromMakeFunc := MakeHashFunc(SHA3Hash)() + sha3FromMakeFunc.ResetWithLength(input[:8]) + sha3FromMakeFunc.Write(input[8:]) + sha3FromMakeFuncOutput := sha3FromMakeFunc.Sum(nil) + + if len(rawSha3Output) != len(sha3FromMakeFuncOutput) { + tester.t.Fatalf("Original SHA3 and abstracted Sha3 has different length %v:%v\n", len(rawSha3Output), len(sha3FromMakeFuncOutput)) + } + + if !bytes.Equal(rawSha3Output, sha3FromMakeFuncOutput) { + tester.t.Fatalf("Original SHA3 and abstracted Sha3 mismatch %v:%v\n", rawSha3Output, sha3FromMakeFuncOutput) + } + +} + +func TestDataAppend(t *testing.T) { + sizes := []int{1, 1, 1, 4095, 4096, 4097, 1, 1, 1, 123456, 2345678, 2345678} + appendSizes := []int{4095, 4096, 4097, 1, 1, 1, 8191, 8192, 8193, 9000, 3000, 5000} + + tester := &chunkerTester{t: t} + chunker := NewPyramidChunker(NewChunkerParams()) + for i, s := range sizes { + testRandomDataAppend(chunker, s, appendSizes[i], tester) + + } } func TestRandomData(t *testing.T) { - // sizes := []int{123456} - sizes := []int{1, 60, 83, 179, 253, 1024, 4095, 4096, 4097, 8191, 8192, 8193, 123456, 2345678} + sizes := []int{1, 60, 83, 179, 253, 1024, 4095, 4096, 4097, 8191, 8192, 8193, 12287, 12288, 12289, 123456, 2345678} tester := &chunkerTester{t: t} + chunker := NewTreeChunker(NewChunkerParams()) + pyramid := NewPyramidChunker(NewChunkerParams()) for _, s := range sizes { - testRandomData(chunker, s, tester) + treeChunkerKey := testRandomData(chunker, s, tester) + pyramidChunkerKey := testRandomData(pyramid, s, tester) + if treeChunkerKey.String() != pyramidChunkerKey.String() { + tester.t.Fatalf("tree chunker and pyramid chunker key mismatch for size %v\n TC: %v\n PC: %v\n", s, treeChunkerKey.String(), pyramidChunkerKey.String()) + } } - pyramid := NewPyramidChunker(NewChunkerParams()) + + cp := NewChunkerParams() + cp.Hash = BMTHash + chunker = NewTreeChunker(cp) + pyramid = NewPyramidChunker(cp) for _, s := range sizes { - testRandomData(pyramid, s, tester) + treeChunkerKey := testRandomData(chunker, s, tester) + pyramidChunkerKey := testRandomData(pyramid, s, tester) + if treeChunkerKey.String() != pyramidChunkerKey.String() { + tester.t.Fatalf("tree chunker BMT and pyramid chunker BMT key mismatch for size %v \n TC: %v\n PC: %v\n", s, treeChunkerKey.String(), pyramidChunkerKey.String()) + } } + } func TestRandomBrokenData(t *testing.T) { - sizes := []int{1, 60, 83, 179, 253, 1024, 4095, 4096, 4097, 8191, 8192, 8193, 123456, 2345678} + sizes := []int{1, 60, 83, 179, 253, 1024, 4095, 4096, 4097, 8191, 8192, 8193, 12287, 12288, 12289, 123456, 2345678} tester := &chunkerTester{t: t} chunker := NewTreeChunker(NewChunkerParams()) for _, s := range sizes { testRandomBrokenData(chunker, s, tester) - t.Logf("done size: %v", s) } } @@ -220,45 +390,100 @@ func benchmarkJoin(n int, t *testing.B) { chunkC := make(chan *Chunk, 1000) swg := &sync.WaitGroup{} - key := tester.Split(chunker, data, int64(n), chunkC, swg, nil) - // t.StartTimer() + key, err := tester.Split(chunker, data, int64(n), chunkC, swg, nil) + if err != nil { + tester.t.Fatalf(err.Error()) + } chunkC = make(chan *Chunk, 1000) quitC := make(chan bool) reader := tester.Join(chunker, key, i, chunkC, quitC) benchReadAll(reader) close(chunkC) <-quitC - // t.StopTimer() } - stats := new(runtime.MemStats) - runtime.ReadMemStats(stats) - fmt.Println(stats.Sys) } -func benchmarkSplitTree(n int, t *testing.B) { +func benchmarkSplitTreeSHA3(n int, t *testing.B) { t.ReportAllocs() for i := 0; i < t.N; i++ { chunker := NewTreeChunker(NewChunkerParams()) tester := &chunkerTester{t: t} data := testDataReader(n) - tester.Split(chunker, data, int64(n), nil, nil, nil) + _, err := tester.Split(chunker, data, int64(n), nil, nil, nil) + if err != nil { + tester.t.Fatalf(err.Error()) + } } - stats := new(runtime.MemStats) - runtime.ReadMemStats(stats) - fmt.Println(stats.Sys) } -func benchmarkSplitPyramid(n int, t *testing.B) { +func benchmarkSplitTreeBMT(n int, t *testing.B) { + t.ReportAllocs() + for i := 0; i < t.N; i++ { + cp := NewChunkerParams() + cp.Hash = BMTHash + chunker := NewTreeChunker(cp) + tester := &chunkerTester{t: t} + data := testDataReader(n) + _, err := tester.Split(chunker, data, int64(n), nil, nil, nil) + if err != nil { + tester.t.Fatalf(err.Error()) + } + } +} + +func benchmarkSplitPyramidSHA3(n int, t *testing.B) { t.ReportAllocs() for i := 0; i < t.N; i++ { splitter := NewPyramidChunker(NewChunkerParams()) tester := &chunkerTester{t: t} data := testDataReader(n) - tester.Split(splitter, data, int64(n), nil, nil, nil) + _, err := tester.Split(splitter, data, int64(n), nil, nil, nil) + if err != nil { + tester.t.Fatalf(err.Error()) + } + } +} + +func benchmarkSplitPyramidBMT(n int, t *testing.B) { + t.ReportAllocs() + for i := 0; i < t.N; i++ { + cp := NewChunkerParams() + cp.Hash = BMTHash + splitter := NewPyramidChunker(cp) + tester := &chunkerTester{t: t} + data := testDataReader(n) + _, err := tester.Split(splitter, data, int64(n), nil, nil, nil) + if err != nil { + tester.t.Fatalf(err.Error()) + } + } +} + +func benchmarkAppendPyramid(n, m int, t *testing.B) { + t.ReportAllocs() + for i := 0; i < t.N; i++ { + chunker := NewPyramidChunker(NewChunkerParams()) + tester := &chunkerTester{t: t} + data := testDataReader(n) + data1 := testDataReader(m) + + chunkC := make(chan *Chunk, 1000) + swg := &sync.WaitGroup{} + key, err := tester.Split(chunker, data, int64(n), chunkC, swg, nil) + if err != nil { + tester.t.Fatalf(err.Error()) + } + + chunkC = make(chan *Chunk, 1000) + swg = &sync.WaitGroup{} + + _, err = tester.Append(chunker, key, data1, chunkC, swg, nil) + if err != nil { + tester.t.Fatalf(err.Error()) + } + + close(chunkC) } - stats := new(runtime.MemStats) - runtime.ReadMemStats(stats) - fmt.Println(stats.Sys) } func BenchmarkJoin_2(t *testing.B) { benchmarkJoin(100, t) } @@ -269,26 +494,59 @@ func BenchmarkJoin_6(t *testing.B) { benchmarkJoin(1000000, t) } func BenchmarkJoin_7(t *testing.B) { benchmarkJoin(10000000, t) } func BenchmarkJoin_8(t *testing.B) { benchmarkJoin(100000000, t) } -func BenchmarkSplitTree_2(t *testing.B) { benchmarkSplitTree(100, t) } -func BenchmarkSplitTree_2h(t *testing.B) { benchmarkSplitTree(500, t) } -func BenchmarkSplitTree_3(t *testing.B) { benchmarkSplitTree(1000, t) } -func BenchmarkSplitTree_3h(t *testing.B) { benchmarkSplitTree(5000, t) } -func BenchmarkSplitTree_4(t *testing.B) { benchmarkSplitTree(10000, t) } -func BenchmarkSplitTree_4h(t *testing.B) { benchmarkSplitTree(50000, t) } -func BenchmarkSplitTree_5(t *testing.B) { benchmarkSplitTree(100000, t) } -func BenchmarkSplitTree_6(t *testing.B) { benchmarkSplitTree(1000000, t) } -func BenchmarkSplitTree_7(t *testing.B) { benchmarkSplitTree(10000000, t) } -func BenchmarkSplitTree_8(t *testing.B) { benchmarkSplitTree(100000000, t) } - -func BenchmarkSplitPyramid_2(t *testing.B) { benchmarkSplitPyramid(100, t) } -func BenchmarkSplitPyramid_2h(t *testing.B) { benchmarkSplitPyramid(500, t) } -func BenchmarkSplitPyramid_3(t *testing.B) { benchmarkSplitPyramid(1000, t) } -func BenchmarkSplitPyramid_3h(t *testing.B) { benchmarkSplitPyramid(5000, t) } -func BenchmarkSplitPyramid_4(t *testing.B) { benchmarkSplitPyramid(10000, t) } -func BenchmarkSplitPyramid_4h(t *testing.B) { benchmarkSplitPyramid(50000, t) } -func BenchmarkSplitPyramid_5(t *testing.B) { benchmarkSplitPyramid(100000, t) } -func BenchmarkSplitPyramid_6(t *testing.B) { benchmarkSplitPyramid(1000000, t) } -func BenchmarkSplitPyramid_7(t *testing.B) { benchmarkSplitPyramid(10000000, t) } -func BenchmarkSplitPyramid_8(t *testing.B) { benchmarkSplitPyramid(100000000, t) } - -// godep go test -bench ./swarm/storage -cpuprofile cpu.out -memprofile mem.out +func BenchmarkSplitTreeSHA3_2(t *testing.B) { benchmarkSplitTreeSHA3(100, t) } +func BenchmarkSplitTreeSHA3_2h(t *testing.B) { benchmarkSplitTreeSHA3(500, t) } +func BenchmarkSplitTreeSHA3_3(t *testing.B) { benchmarkSplitTreeSHA3(1000, t) } +func BenchmarkSplitTreeSHA3_3h(t *testing.B) { benchmarkSplitTreeSHA3(5000, t) } +func BenchmarkSplitTreeSHA3_4(t *testing.B) { benchmarkSplitTreeSHA3(10000, t) } +func BenchmarkSplitTreeSHA3_4h(t *testing.B) { benchmarkSplitTreeSHA3(50000, t) } +func BenchmarkSplitTreeSHA3_5(t *testing.B) { benchmarkSplitTreeSHA3(100000, t) } +func BenchmarkSplitTreeSHA3_6(t *testing.B) { benchmarkSplitTreeSHA3(1000000, t) } +func BenchmarkSplitTreeSHA3_7(t *testing.B) { benchmarkSplitTreeSHA3(10000000, t) } +func BenchmarkSplitTreeSHA3_8(t *testing.B) { benchmarkSplitTreeSHA3(100000000, t) } + +func BenchmarkSplitTreeBMT_2(t *testing.B) { benchmarkSplitTreeBMT(100, t) } +func BenchmarkSplitTreeBMT_2h(t *testing.B) { benchmarkSplitTreeBMT(500, t) } +func BenchmarkSplitTreeBMT_3(t *testing.B) { benchmarkSplitTreeBMT(1000, t) } +func BenchmarkSplitTreeBMT_3h(t *testing.B) { benchmarkSplitTreeBMT(5000, t) } +func BenchmarkSplitTreeBMT_4(t *testing.B) { benchmarkSplitTreeBMT(10000, t) } +func BenchmarkSplitTreeBMT_4h(t *testing.B) { benchmarkSplitTreeBMT(50000, t) } +func BenchmarkSplitTreeBMT_5(t *testing.B) { benchmarkSplitTreeBMT(100000, t) } +func BenchmarkSplitTreeBMT_6(t *testing.B) { benchmarkSplitTreeBMT(1000000, t) } +func BenchmarkSplitTreeBMT_7(t *testing.B) { benchmarkSplitTreeBMT(10000000, t) } +func BenchmarkSplitTreeBMT_8(t *testing.B) { benchmarkSplitTreeBMT(100000000, t) } + +func BenchmarkSplitPyramidSHA3_2(t *testing.B) { benchmarkSplitPyramidSHA3(100, t) } +func BenchmarkSplitPyramidSHA3_2h(t *testing.B) { benchmarkSplitPyramidSHA3(500, t) } +func BenchmarkSplitPyramidSHA3_3(t *testing.B) { benchmarkSplitPyramidSHA3(1000, t) } +func BenchmarkSplitPyramidSHA3_3h(t *testing.B) { benchmarkSplitPyramidSHA3(5000, t) } +func BenchmarkSplitPyramidSHA3_4(t *testing.B) { benchmarkSplitPyramidSHA3(10000, t) } +func BenchmarkSplitPyramidSHA3_4h(t *testing.B) { benchmarkSplitPyramidSHA3(50000, t) } +func BenchmarkSplitPyramidSHA3_5(t *testing.B) { benchmarkSplitPyramidSHA3(100000, t) } +func BenchmarkSplitPyramidSHA3_6(t *testing.B) { benchmarkSplitPyramidSHA3(1000000, t) } +func BenchmarkSplitPyramidSHA3_7(t *testing.B) { benchmarkSplitPyramidSHA3(10000000, t) } +func BenchmarkSplitPyramidSHA3_8(t *testing.B) { benchmarkSplitPyramidSHA3(100000000, t) } + +func BenchmarkSplitPyramidBMT_2(t *testing.B) { benchmarkSplitPyramidBMT(100, t) } +func BenchmarkSplitPyramidBMT_2h(t *testing.B) { benchmarkSplitPyramidBMT(500, t) } +func BenchmarkSplitPyramidBMT_3(t *testing.B) { benchmarkSplitPyramidBMT(1000, t) } +func BenchmarkSplitPyramidBMT_3h(t *testing.B) { benchmarkSplitPyramidBMT(5000, t) } +func BenchmarkSplitPyramidBMT_4(t *testing.B) { benchmarkSplitPyramidBMT(10000, t) } +func BenchmarkSplitPyramidBMT_4h(t *testing.B) { benchmarkSplitPyramidBMT(50000, t) } +func BenchmarkSplitPyramidBMT_5(t *testing.B) { benchmarkSplitPyramidBMT(100000, t) } +func BenchmarkSplitPyramidBMT_6(t *testing.B) { benchmarkSplitPyramidBMT(1000000, t) } +func BenchmarkSplitPyramidBMT_7(t *testing.B) { benchmarkSplitPyramidBMT(10000000, t) } +func BenchmarkSplitPyramidBMT_8(t *testing.B) { benchmarkSplitPyramidBMT(100000000, t) } + +func BenchmarkAppendPyramid_2(t *testing.B) { benchmarkAppendPyramid(100, 1000, t) } +func BenchmarkAppendPyramid_2h(t *testing.B) { benchmarkAppendPyramid(500, 1000, t) } +func BenchmarkAppendPyramid_3(t *testing.B) { benchmarkAppendPyramid(1000, 1000, t) } +func BenchmarkAppendPyramid_4(t *testing.B) { benchmarkAppendPyramid(10000, 1000, t) } +func BenchmarkAppendPyramid_4h(t *testing.B) { benchmarkAppendPyramid(50000, 1000, t) } +func BenchmarkAppendPyramid_5(t *testing.B) { benchmarkAppendPyramid(1000000, 1000, t) } +func BenchmarkAppendPyramid_6(t *testing.B) { benchmarkAppendPyramid(1000000, 1000, t) } +func BenchmarkAppendPyramid_7(t *testing.B) { benchmarkAppendPyramid(10000000, 1000, t) } +func BenchmarkAppendPyramid_8(t *testing.B) { benchmarkAppendPyramid(100000000, 1000, t) } + +// go test -timeout 20m -cpu 4 -bench=./swarm/storage -run no +// If you dont add the timeout argument above .. the benchmark will timeout and dump diff --git a/swarm/storage/common_test.go b/swarm/storage/common_test.go index 44d1dd1f7..cd4c2ef13 100644 --- a/swarm/storage/common_test.go +++ b/swarm/storage/common_test.go @@ -76,7 +76,7 @@ func testStore(m ChunkStore, l int64, branches int64, t *testing.T) { }() chunker := NewTreeChunker(&ChunkerParams{ Branches: branches, - Hash: defaultHash, + Hash: SHA3Hash, }) swg := &sync.WaitGroup{} key, _ := chunker.Split(rand.Reader, l, chunkC, swg, nil) diff --git a/swarm/storage/dbstore.go b/swarm/storage/dbstore.go index cbeddb8cb..46a5c16cc 100644 --- a/swarm/storage/dbstore.go +++ b/swarm/storage/dbstore.go @@ -72,12 +72,12 @@ type DbStore struct { gcPos, gcStartPos []byte gcArray []*gcItem - hashfunc Hasher + hashfunc SwarmHasher lock sync.Mutex } -func NewDbStore(path string, hash Hasher, capacity uint64, radius int) (s *DbStore, err error) { +func NewDbStore(path string, hash SwarmHasher, capacity uint64, radius int) (s *DbStore, err error) { s = new(DbStore) s.hashfunc = hash diff --git a/swarm/storage/dbstore_test.go b/swarm/storage/dbstore_test.go index ddce7ccfe..dd165b576 100644 --- a/swarm/storage/dbstore_test.go +++ b/swarm/storage/dbstore_test.go @@ -29,7 +29,7 @@ func initDbStore(t *testing.T) *DbStore { if err != nil { t.Fatal(err) } - m, err := NewDbStore(dir, MakeHashFunc(defaultHash), defaultDbCapacity, defaultRadius) + m, err := NewDbStore(dir, MakeHashFunc(SHA3Hash), defaultDbCapacity, defaultRadius) if err != nil { t.Fatal("can't create store:", err) } diff --git a/swarm/storage/localstore.go b/swarm/storage/localstore.go index 58f59d0a2..b442e6cc5 100644 --- a/swarm/storage/localstore.go +++ b/swarm/storage/localstore.go @@ -28,7 +28,7 @@ type LocalStore struct { } // This constructor uses MemStore and DbStore as components -func NewLocalStore(hash Hasher, params *StoreParams) (*LocalStore, error) { +func NewLocalStore(hash SwarmHasher, params *StoreParams) (*LocalStore, error) { dbStore, err := NewDbStore(params.ChunkDbPath, hash, params.DbCapacity, params.Radius) if err != nil { return nil, err diff --git a/swarm/storage/netstore.go b/swarm/storage/netstore.go index 746dd85f6..7b0612edc 100644 --- a/swarm/storage/netstore.go +++ b/swarm/storage/netstore.go @@ -36,7 +36,7 @@ NetStore falls back to a backend (CloudStorage interface) implemented by bzz/network/forwarder. forwarder or IPFS or IPΞS */ type NetStore struct { - hashfunc Hasher + hashfunc SwarmHasher localStore *LocalStore cloud CloudStore } @@ -69,7 +69,7 @@ func NewStoreParams(path string) (self *StoreParams) { // netstore contructor, takes path argument that is used to initialise dbStore, // the persistent (disk) storage component of LocalStore // the second argument is the hive, the connection/logistics manager for the node -func NewNetStore(hash Hasher, lstore *LocalStore, cloud CloudStore, params *StoreParams) *NetStore { +func NewNetStore(hash SwarmHasher, lstore *LocalStore, cloud CloudStore, params *StoreParams) *NetStore { return &NetStore{ hashfunc: hash, localStore: lstore, diff --git a/swarm/storage/pyramid.go b/swarm/storage/pyramid.go index 74e00a497..e3be2a987 100644 --- a/swarm/storage/pyramid.go +++ b/swarm/storage/pyramid.go @@ -18,53 +18,112 @@ package storage import ( "encoding/binary" - "fmt" + "errors" "io" - "math" - "strings" "sync" + "time" +) + +/* + The main idea of a pyramid chunker is to process the input data without knowing the entire size apriori. + For this to be achieved, the chunker tree is built from the ground up until the data is exhausted. + This opens up new aveneus such as easy append and other sort of modifications to the tree therby avoiding + duplication of data chunks. + + + Below is an example of a two level chunks tree. The leaf chunks are called data chunks and all the above + chunks are called tree chunks. The tree chunk above data chunks is level 0 and so on until it reaches + the root tree chunk. + + + + T10 <- Tree chunk lvl1 + | + __________________________|_____________________________ + / | | \ + / | \ \ + __T00__ ___T01__ ___T02__ ___T03__ <- Tree chunks lvl 0 + / / \ / / \ / / \ / / \ + / / \ / / \ / / \ / / \ + D1 D2 ... D128 D1 D2 ... D128 D1 D2 ... D128 D1 D2 ... D128 <- Data Chunks + + + The split function continuously read the data and creates data chunks and send them to storage. + When certain no of data chunks are created (defaultBranches), a signal is sent to create a tree + entry. When the level 0 tree entries reaches certain threshold (defaultBranches), another signal + is sent to a tree entry one level up.. and so on... until only the data is exhausted AND only one + tree entry is present in certain level. The key of tree entry is given out as the rootKey of the file. + +*/ + +var ( + errLoadingTreeRootChunk = errors.New("LoadTree Error: Could not load root chunk") + errLoadingTreeChunk = errors.New("LoadTree Error: Could not load chunk") +) - "github.com/ethereum/go-ethereum/common" +const ( + ChunkProcessors = 8 + DefaultBranches int64 = 128 + splitTimeout = time.Minute * 5 ) const ( - processors = 8 + DataChunk = 0 + TreeChunk = 1 ) -type Tree struct { - Chunks int64 - Levels []map[int64]*Node - Lock sync.RWMutex +type ChunkerParams struct { + Branches int64 + Hash string +} + +func NewChunkerParams() *ChunkerParams { + return &ChunkerParams{ + Branches: DefaultBranches, + Hash: SHA3Hash, + } } -type Node struct { - Pending int64 - Size uint64 - Children []common.Hash - Last bool +// Entry to create a tree node +type TreeEntry struct { + level int + branchCount int64 + subtreeSize uint64 + chunk []byte + key []byte + index int // used in append to indicate the index of existing tree entry + updatePending bool // indicates if the entry is loaded from existing tree } -func (self *Node) String() string { - var children []string - for _, node := range self.Children { - children = append(children, node.Hex()) +func NewTreeEntry(pyramid *PyramidChunker) *TreeEntry { + return &TreeEntry{ + level: 0, + branchCount: 0, + subtreeSize: 0, + chunk: make([]byte, pyramid.chunkSize+8), + key: make([]byte, pyramid.hashSize), + index: 0, + updatePending: false, } - return fmt.Sprintf("pending: %v, size: %v, last :%v, children: %v", self.Pending, self.Size, self.Last, strings.Join(children, ", ")) } -type Task struct { - Index int64 // Index of the chunk being processed - Size uint64 - Data []byte // Binary blob of the chunk - Last bool +// Used by the hash processor to create a data/tree chunk and send to storage +type chunkJob struct { + key Key + chunk []byte + size int64 + parentWg *sync.WaitGroup + chunkType int // used to identify the tree related chunks for debugging + chunkLvl int // leaf-1 is level 0 and goes upwards until it reaches root } type PyramidChunker struct { - hashFunc Hasher + hashFunc SwarmHasher chunkSize int64 hashSize int64 branches int64 - workerCount int + workerCount int64 + workerLock sync.RWMutex } func NewPyramidChunker(params *ChunkerParams) (self *PyramidChunker) { @@ -73,128 +132,506 @@ func NewPyramidChunker(params *ChunkerParams) (self *PyramidChunker) { self.branches = params.Branches self.hashSize = int64(self.hashFunc().Size()) self.chunkSize = self.hashSize * self.branches - self.workerCount = 1 + self.workerCount = 0 return } -func (self *PyramidChunker) Split(data io.Reader, size int64, chunkC chan *Chunk, swg, wwg *sync.WaitGroup) (Key, error) { +func (self *PyramidChunker) Join(key Key, chunkC chan *Chunk) LazySectionReader { + return &LazyChunkReader{ + key: key, + chunkC: chunkC, + chunkSize: self.chunkSize, + branches: self.branches, + hashSize: self.hashSize, + } +} - chunks := (size + self.chunkSize - 1) / self.chunkSize - depth := int(math.Ceil(math.Log(float64(chunks))/math.Log(float64(self.branches)))) + 1 +func (self *PyramidChunker) incrementWorkerCount() { + self.workerLock.Lock() + defer self.workerLock.Unlock() + self.workerCount += 1 +} - results := Tree{ - Chunks: chunks, - Levels: make([]map[int64]*Node, depth), +func (self *PyramidChunker) getWorkerCount() int64 { + self.workerLock.Lock() + defer self.workerLock.Unlock() + return self.workerCount +} + +func (self *PyramidChunker) decrementWorkerCount() { + self.workerLock.Lock() + defer self.workerLock.Unlock() + self.workerCount -= 1 +} + +func (self *PyramidChunker) Split(data io.Reader, size int64, chunkC chan *Chunk, storageWG, processorWG *sync.WaitGroup) (Key, error) { + jobC := make(chan *chunkJob, 2*ChunkProcessors) + wg := &sync.WaitGroup{} + errC := make(chan error) + quitC := make(chan bool) + rootKey := make([]byte, self.hashSize) + chunkLevel := make([][]*TreeEntry, self.branches) + + wg.Add(1) + go self.prepareChunks(false, chunkLevel, data, rootKey, quitC, wg, jobC, processorWG, chunkC, errC, storageWG) + + // closes internal error channel if all subprocesses in the workgroup finished + go func() { + + // waiting for all chunks to finish + wg.Wait() + + // if storage waitgroup is non-nil, we wait for storage to finish too + if storageWG != nil { + storageWG.Wait() + } + //We close errC here because this is passed down to 8 parallel routines underneath. + // if a error happens in one of them.. that particular routine raises error... + // once they all complete successfully, the control comes back and we can safely close this here. + close(errC) + }() + + defer close(quitC) + + select { + case err := <-errC: + if err != nil { + return nil, err + } + case <-time.NewTimer(splitTimeout).C: } - for i := 0; i < depth; i++ { - results.Levels[i] = make(map[int64]*Node) + return rootKey, nil + +} + +func (self *PyramidChunker) Append(key Key, data io.Reader, chunkC chan *Chunk, storageWG, processorWG *sync.WaitGroup) (Key, error) { + quitC := make(chan bool) + rootKey := make([]byte, self.hashSize) + chunkLevel := make([][]*TreeEntry, self.branches) + + // Load the right most unfinished tree chunks in every level + self.loadTree(chunkLevel, key, chunkC, quitC) + + jobC := make(chan *chunkJob, 2*ChunkProcessors) + wg := &sync.WaitGroup{} + errC := make(chan error) + + wg.Add(1) + go self.prepareChunks(true, chunkLevel, data, rootKey, quitC, wg, jobC, processorWG, chunkC, errC, storageWG) + + // closes internal error channel if all subprocesses in the workgroup finished + go func() { + + // waiting for all chunks to finish + wg.Wait() + + // if storage waitgroup is non-nil, we wait for storage to finish too + if storageWG != nil { + storageWG.Wait() + } + close(errC) + }() + + defer close(quitC) + + select { + case err := <-errC: + if err != nil { + return nil, err + } + case <-time.NewTimer(splitTimeout).C: } - // Create a pool of workers to crunch through the file - tasks := make(chan *Task, 2*processors) - pend := new(sync.WaitGroup) - abortC := make(chan bool) - for i := 0; i < processors; i++ { - pend.Add(1) - go self.processor(pend, swg, tasks, chunkC, &results) + return rootKey, nil + +} + +func (self *PyramidChunker) processor(id int64, jobC chan *chunkJob, chunkC chan *Chunk, errC chan error, quitC chan bool, swg, wwg *sync.WaitGroup) { + defer self.decrementWorkerCount() + + hasher := self.hashFunc() + if wwg != nil { + defer wwg.Done() } - // Feed the chunks into the task pool - read := 0 - for index := 0; ; index++ { - buffer := make([]byte, self.chunkSize+8) - n, err := data.Read(buffer[8:]) - read += n - last := int64(read) == size || err == io.ErrUnexpectedEOF || err == io.EOF - if err != nil && !last { - close(abortC) - break - } - binary.LittleEndian.PutUint64(buffer[:8], uint64(n)) - pend.Add(1) + for { select { - case tasks <- &Task{Index: int64(index), Size: uint64(n), Data: buffer[:n+8], Last: last}: - case <-abortC: - return nil, err + + case job, ok := <-jobC: + if !ok { + return + } + self.processChunk(id, hasher, job, chunkC, swg) + case <-quitC: + return } - if last { - break + } +} + +func (self *PyramidChunker) processChunk(id int64, hasher SwarmHash, job *chunkJob, chunkC chan *Chunk, swg *sync.WaitGroup) { + hasher.ResetWithLength(job.chunk[:8]) // 8 bytes of length + hasher.Write(job.chunk[8:]) // minus 8 []byte length + h := hasher.Sum(nil) + + newChunk := &Chunk{ + Key: h, + SData: job.chunk, + Size: job.size, + wg: swg, + } + + // report hash of this chunk one level up (keys corresponds to the proper subslice of the parent chunk) + copy(job.key, h) + + // send off new chunk to storage + if chunkC != nil { + if swg != nil { + swg.Add(1) } } - // Wait for the workers and return - close(tasks) - pend.Wait() + job.parentWg.Done() - key := results.Levels[0][0].Children[0][:] - return key, nil + if chunkC != nil { + chunkC <- newChunk + } } -func (self *PyramidChunker) processor(pend, swg *sync.WaitGroup, tasks chan *Task, chunkC chan *Chunk, results *Tree) { - defer pend.Done() +func (self *PyramidChunker) loadTree(chunkLevel [][]*TreeEntry, key Key, chunkC chan *Chunk, quitC chan bool) error { + // Get the root chunk to get the total size + chunk := retrieve(key, chunkC, quitC) + if chunk == nil { + return errLoadingTreeRootChunk + } - // Start processing leaf chunks ad infinitum - hasher := self.hashFunc() - for task := range tasks { - depth, pow := len(results.Levels)-1, self.branches - size := task.Size - data := task.Data - var node *Node - for depth >= 0 { - // New chunk received, reset the hasher and start processing - hasher.Reset() - if node == nil { // Leaf node, hash the data chunk - hasher.Write(task.Data) - } else { // Internal node, hash the children - size = node.Size - data = make([]byte, hasher.Size()*len(node.Children)+8) - binary.LittleEndian.PutUint64(data[:8], size) - - hasher.Write(data[:8]) - for i, hash := range node.Children { - copy(data[i*hasher.Size()+8:], hash[:]) - hasher.Write(hash[:]) + //if data size is less than a chunk... add a parent with update as pending + if chunk.Size <= self.chunkSize { + newEntry := &TreeEntry{ + level: 0, + branchCount: 1, + subtreeSize: uint64(chunk.Size), + chunk: make([]byte, self.chunkSize+8), + key: make([]byte, self.hashSize), + index: 0, + updatePending: true, + } + copy(newEntry.chunk[8:], chunk.Key) + chunkLevel[0] = append(chunkLevel[0], newEntry) + return nil + } + + var treeSize int64 + var depth int + treeSize = self.chunkSize + for ; treeSize < chunk.Size; treeSize *= self.branches { + depth++ + } + + // Add the root chunk entry + branchCount := int64(len(chunk.SData)-8) / self.hashSize + newEntry := &TreeEntry{ + level: int(depth - 1), + branchCount: branchCount, + subtreeSize: uint64(chunk.Size), + chunk: chunk.SData, + key: key, + index: 0, + updatePending: true, + } + chunkLevel[depth-1] = append(chunkLevel[depth-1], newEntry) + + // Add the rest of the tree + for lvl := (depth - 1); lvl >= 1; lvl-- { + + //TODO(jmozah): instead of loading finished branches and then trim in the end, + //avoid loading them in the first place + for _, ent := range chunkLevel[lvl] { + branchCount = int64(len(ent.chunk)-8) / self.hashSize + for i := int64(0); i < branchCount; i++ { + key := ent.chunk[8+(i*self.hashSize) : 8+((i+1)*self.hashSize)] + newChunk := retrieve(key, chunkC, quitC) + if newChunk == nil { + return errLoadingTreeChunk } - } - hash := hasher.Sum(nil) - last := task.Last || (node != nil) && node.Last - // Insert the subresult into the memoization tree - results.Lock.Lock() - if node = results.Levels[depth][task.Index/pow]; node == nil { - // Figure out the pending tasks - pending := self.branches - if task.Index/pow == results.Chunks/pow { - pending = (results.Chunks + pow/self.branches - 1) / (pow / self.branches) % self.branches + bewBranchCount := int64(len(newChunk.SData)-8) / self.hashSize + newEntry := &TreeEntry{ + level: int(lvl - 1), + branchCount: bewBranchCount, + subtreeSize: uint64(newChunk.Size), + chunk: newChunk.SData, + key: key, + index: 0, + updatePending: true, } - node = &Node{pending, 0, make([]common.Hash, pending), last} - results.Levels[depth][task.Index/pow] = node + chunkLevel[lvl-1] = append(chunkLevel[lvl-1], newEntry) + } - node.Pending-- - i := task.Index / (pow / self.branches) % self.branches - if last { - node.Last = true + + // We need to get only the right most unfinished branch.. so trim all finished branches + if int64(len(chunkLevel[lvl-1])) >= self.branches { + chunkLevel[lvl-1] = nil } - copy(node.Children[i][:], hash) - node.Size += size - left := node.Pending - if chunkC != nil { - if swg != nil { - swg.Add(1) - } + } + } + + return nil +} + +func (self *PyramidChunker) prepareChunks(isAppend bool, chunkLevel [][]*TreeEntry, data io.Reader, rootKey []byte, quitC chan bool, wg *sync.WaitGroup, jobC chan *chunkJob, processorWG *sync.WaitGroup, chunkC chan *Chunk, errC chan error, storageWG *sync.WaitGroup) { + defer wg.Done() + + chunkWG := &sync.WaitGroup{} + totalDataSize := 0 - chunkC <- &Chunk{Key: hash, SData: data, wg: swg} - // TODO: consider selecting on self.quitC to avoid blocking forever on shutdown + // processorWG keeps track of workers spawned for hashing chunks + if processorWG != nil { + processorWG.Add(1) + } + + self.incrementWorkerCount() + go self.processor(self.workerCount, jobC, chunkC, errC, quitC, storageWG, processorWG) + + parent := NewTreeEntry(self) + var unFinishedChunk *Chunk + + if isAppend == true && len(chunkLevel[0]) != 0 { + + lastIndex := len(chunkLevel[0]) - 1 + ent := chunkLevel[0][lastIndex] + + if ent.branchCount < self.branches { + parent = &TreeEntry{ + level: 0, + branchCount: ent.branchCount, + subtreeSize: ent.subtreeSize, + chunk: ent.chunk, + key: ent.key, + index: lastIndex, + updatePending: true, } - if depth+1 < len(results.Levels) { - delete(results.Levels[depth+1], task.Index/(pow/self.branches)) + + lastBranch := parent.branchCount - 1 + lastKey := parent.chunk[8+lastBranch*self.hashSize : 8+(lastBranch+1)*self.hashSize] + + unFinishedChunk = retrieve(lastKey, chunkC, quitC) + if unFinishedChunk.Size < self.chunkSize { + + parent.subtreeSize = parent.subtreeSize - uint64(unFinishedChunk.Size) + parent.branchCount = parent.branchCount - 1 + } else { + unFinishedChunk = nil } + } + } - results.Lock.Unlock() - // If there's more work to be done, leave for others - if left > 0 { + for index := 0; ; index++ { + + var n int + var err error + chunkData := make([]byte, self.chunkSize+8) + if unFinishedChunk != nil { + copy(chunkData, unFinishedChunk.SData) + n, err = data.Read(chunkData[8+unFinishedChunk.Size:]) + n += int(unFinishedChunk.Size) + unFinishedChunk = nil + } else { + n, err = data.Read(chunkData[8:]) + } + + totalDataSize += n + if err != nil { + if err == io.EOF || err == io.ErrUnexpectedEOF { + if parent.branchCount == 1 { + // Data is exactly one chunk.. pick the last chunk key as root + chunkWG.Wait() + lastChunksKey := parent.chunk[8 : 8+self.hashSize] + copy(rootKey, lastChunksKey) + break + } + } else { + close(quitC) break } - // We're the last ones in this batch, merge the children together - depth-- - pow *= self.branches } - pend.Done() + + // Data ended in chunk boundry.. just signal to start bulding tree + if n == 0 { + self.buildTree(isAppend, chunkLevel, parent, chunkWG, jobC, quitC, true, rootKey) + break + } else { + + pkey := self.enqueueDataChunk(chunkData, uint64(n), parent, chunkWG, jobC, quitC) + + // update tree related parent data structures + parent.subtreeSize += uint64(n) + parent.branchCount++ + + // Data got exhausted... signal to send any parent tree related chunks + if int64(n) < self.chunkSize { + + // only one data chunk .. so dont add any parent chunk + if parent.branchCount <= 1 { + chunkWG.Wait() + copy(rootKey, pkey) + break + } + + self.buildTree(isAppend, chunkLevel, parent, chunkWG, jobC, quitC, true, rootKey) + break + } + + if parent.branchCount == self.branches { + self.buildTree(isAppend, chunkLevel, parent, chunkWG, jobC, quitC, false, rootKey) + parent = NewTreeEntry(self) + } + + } + + workers := self.getWorkerCount() + if int64(len(jobC)) > workers && workers < ChunkProcessors { + if processorWG != nil { + processorWG.Add(1) + } + self.incrementWorkerCount() + go self.processor(self.workerCount, jobC, chunkC, errC, quitC, storageWG, processorWG) + } + + } + +} + +func (self *PyramidChunker) buildTree(isAppend bool, chunkLevel [][]*TreeEntry, ent *TreeEntry, chunkWG *sync.WaitGroup, jobC chan *chunkJob, quitC chan bool, last bool, rootKey []byte) { + chunkWG.Wait() + self.enqueueTreeChunk(chunkLevel, ent, chunkWG, jobC, quitC, last) + + compress := false + endLvl := self.branches + for lvl := int64(0); lvl < self.branches; lvl++ { + lvlCount := int64(len(chunkLevel[lvl])) + if lvlCount >= self.branches { + endLvl = lvl + 1 + compress = true + break + } + } + + if compress == false && last == false { + return + } + + // Wait for all the keys to be processed before compressing the tree + chunkWG.Wait() + + for lvl := int64(ent.level); lvl < endLvl; lvl++ { + + lvlCount := int64(len(chunkLevel[lvl])) + if lvlCount == 1 && last == true { + copy(rootKey, chunkLevel[lvl][0].key) + return + } + + for startCount := int64(0); startCount < lvlCount; startCount += self.branches { + + endCount := startCount + self.branches + if endCount > lvlCount { + endCount = lvlCount + } + + var nextLvlCount int64 + var tempEntry *TreeEntry + if len(chunkLevel[lvl+1]) > 0 { + nextLvlCount = int64(len(chunkLevel[lvl+1]) - 1) + tempEntry = chunkLevel[lvl+1][nextLvlCount] + } + if isAppend == true && tempEntry != nil && tempEntry.updatePending == true { + updateEntry := &TreeEntry{ + level: int(lvl + 1), + branchCount: 0, + subtreeSize: 0, + chunk: make([]byte, self.chunkSize+8), + key: make([]byte, self.hashSize), + index: int(nextLvlCount), + updatePending: true, + } + for index := int64(0); index < lvlCount; index++ { + updateEntry.branchCount++ + updateEntry.subtreeSize += chunkLevel[lvl][index].subtreeSize + copy(updateEntry.chunk[8+(index*self.hashSize):8+((index+1)*self.hashSize)], chunkLevel[lvl][index].key[:self.hashSize]) + } + + self.enqueueTreeChunk(chunkLevel, updateEntry, chunkWG, jobC, quitC, last) + + } else { + + noOfBranches := endCount - startCount + newEntry := &TreeEntry{ + level: int(lvl + 1), + branchCount: noOfBranches, + subtreeSize: 0, + chunk: make([]byte, (noOfBranches*self.hashSize)+8), + key: make([]byte, self.hashSize), + index: int(nextLvlCount), + updatePending: false, + } + + index := int64(0) + for i := startCount; i < endCount; i++ { + entry := chunkLevel[lvl][i] + newEntry.subtreeSize += entry.subtreeSize + copy(newEntry.chunk[8+(index*self.hashSize):8+((index+1)*self.hashSize)], entry.key[:self.hashSize]) + index++ + } + + self.enqueueTreeChunk(chunkLevel, newEntry, chunkWG, jobC, quitC, last) + + } + + } + + if isAppend == false { + chunkWG.Wait() + if compress == true { + chunkLevel[lvl] = nil + } + } } + } + +func (self *PyramidChunker) enqueueTreeChunk(chunkLevel [][]*TreeEntry, ent *TreeEntry, chunkWG *sync.WaitGroup, jobC chan *chunkJob, quitC chan bool, last bool) { + if ent != nil { + + // wait for data chunks to get over before processing the tree chunk + if last == true { + chunkWG.Wait() + } + + binary.LittleEndian.PutUint64(ent.chunk[:8], ent.subtreeSize) + ent.key = make([]byte, self.hashSize) + chunkWG.Add(1) + select { + case jobC <- &chunkJob{ent.key, ent.chunk[:ent.branchCount*self.hashSize+8], int64(ent.subtreeSize), chunkWG, TreeChunk, 0}: + case <-quitC: + } + + // Update or append based on weather it is a new entry or being reused + if ent.updatePending == true { + chunkWG.Wait() + chunkLevel[ent.level][ent.index] = ent + } else { + chunkLevel[ent.level] = append(chunkLevel[ent.level], ent) + } + + } +} + +func (self *PyramidChunker) enqueueDataChunk(chunkData []byte, size uint64, parent *TreeEntry, chunkWG *sync.WaitGroup, jobC chan *chunkJob, quitC chan bool) Key { + binary.LittleEndian.PutUint64(chunkData[:8], size) + pkey := parent.chunk[8+parent.branchCount*self.hashSize : 8+(parent.branchCount+1)*self.hashSize] + + chunkWG.Add(1) + select { + case jobC <- &chunkJob{pkey, chunkData[:size+8], int64(size), chunkWG, DataChunk, -1}: + case <-quitC: + } + + return pkey + +}
\ No newline at end of file diff --git a/swarm/storage/swarmhasher.go b/swarm/storage/swarmhasher.go new file mode 100644 index 000000000..38b86373c --- /dev/null +++ b/swarm/storage/swarmhasher.go @@ -0,0 +1,40 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package storage + +import ( + "hash" +) + +const ( + BMTHash = "BMT" + SHA3Hash = "SHA3" // http://golang.org/pkg/hash/#Hash +) + +type SwarmHash interface { + hash.Hash + ResetWithLength([]byte) +} + +type HashWithLength struct { + hash.Hash +} + +func (self *HashWithLength) ResetWithLength(length []byte) { + self.Reset() + self.Write(length) +} diff --git a/swarm/storage/types.go b/swarm/storage/types.go index a9de23c93..d35f1f929 100644 --- a/swarm/storage/types.go +++ b/swarm/storage/types.go @@ -24,12 +24,13 @@ import ( "io" "sync" - // "github.com/ethereum/go-ethereum/bmt" + "github.com/ethereum/go-ethereum/bmt" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto/sha3" ) type Hasher func() hash.Hash +type SwarmHasher func() SwarmHash // Peer is the recorded as Source on the chunk // should probably not be here? but network should wrap chunk object @@ -78,12 +79,18 @@ func IsZeroKey(key Key) bool { var ZeroKey = Key(common.Hash{}.Bytes()) -func MakeHashFunc(hash string) Hasher { +func MakeHashFunc(hash string) SwarmHasher { switch hash { case "SHA256": - return crypto.SHA256.New + return func() SwarmHash { return &HashWithLength{crypto.SHA256.New()} } case "SHA3": - return sha3.NewKeccak256 + return func() SwarmHash { return &HashWithLength{sha3.NewKeccak256()} } + case "BMT": + return func() SwarmHash { + hasher := sha3.NewKeccak256 + pool := bmt.NewTreePool(hasher, bmt.DefaultSegmentCount, bmt.DefaultPoolSize) + return bmt.New(pool) + } } return nil } @@ -192,6 +199,13 @@ type Splitter interface { A closed error signals process completion at which point the key can be considered final if there were no errors. */ Split(io.Reader, int64, chan *Chunk, *sync.WaitGroup, *sync.WaitGroup) (Key, error) + + /* This is the first step in making files mutable (not chunks).. + Append allows adding more data chunks to the end of the already existsing file. + The key for the root chunk is supplied to load the respective tree. + Rest of the parameters behave like Split. + */ + Append(Key, io.Reader, chan *Chunk, *sync.WaitGroup, *sync.WaitGroup) (Key, error) } type Joiner interface { diff --git a/vendor/github.com/julienschmidt/httprouter/LICENSE b/vendor/github.com/julienschmidt/httprouter/LICENSE new file mode 100644 index 000000000..b829abc8a --- /dev/null +++ b/vendor/github.com/julienschmidt/httprouter/LICENSE @@ -0,0 +1,24 @@ +Copyright (c) 2013 Julien Schmidt. All rights reserved. + + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * The names of the contributors may not be used to endorse or promote + products derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL JULIEN SCHMIDT BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file diff --git a/vendor/github.com/julienschmidt/httprouter/README.md b/vendor/github.com/julienschmidt/httprouter/README.md new file mode 100644 index 000000000..92885470b --- /dev/null +++ b/vendor/github.com/julienschmidt/httprouter/README.md @@ -0,0 +1,266 @@ +# HttpRouter [![Build Status](https://travis-ci.org/julienschmidt/httprouter.svg?branch=master)](https://travis-ci.org/julienschmidt/httprouter) [![Coverage Status](https://coveralls.io/repos/github/julienschmidt/httprouter/badge.svg?branch=master)](https://coveralls.io/github/julienschmidt/httprouter?branch=master) [![GoDoc](https://godoc.org/github.com/julienschmidt/httprouter?status.svg)](http://godoc.org/github.com/julienschmidt/httprouter) + +HttpRouter is a lightweight high performance HTTP request router (also called *multiplexer* or just *mux* for short) for [Go](https://golang.org/). + +In contrast to the [default mux](https://golang.org/pkg/net/http/#ServeMux) of Go's `net/http` package, this router supports variables in the routing pattern and matches against the request method. It also scales better. + +The router is optimized for high performance and a small memory footprint. It scales well even with very long paths and a large number of routes. A compressing dynamic trie (radix tree) structure is used for efficient matching. + +## Features + +**Only explicit matches:** With other routers, like [`http.ServeMux`](https://golang.org/pkg/net/http/#ServeMux), a requested URL path could match multiple patterns. Therefore they have some awkward pattern priority rules, like *longest match* or *first registered, first matched*. By design of this router, a request can only match exactly one or no route. As a result, there are also no unintended matches, which makes it great for SEO and improves the user experience. + +**Stop caring about trailing slashes:** Choose the URL style you like, the router automatically redirects the client if a trailing slash is missing or if there is one extra. Of course it only does so, if the new path has a handler. If you don't like it, you can [turn off this behavior](https://godoc.org/github.com/julienschmidt/httprouter#Router.RedirectTrailingSlash). + +**Path auto-correction:** Besides detecting the missing or additional trailing slash at no extra cost, the router can also fix wrong cases and remove superfluous path elements (like `../` or `//`). Is [CAPTAIN CAPS LOCK](http://www.urbandictionary.com/define.php?term=Captain+Caps+Lock) one of your users? HttpRouter can help him by making a case-insensitive look-up and redirecting him to the correct URL. + +**Parameters in your routing pattern:** Stop parsing the requested URL path, just give the path segment a name and the router delivers the dynamic value to you. Because of the design of the router, path parameters are very cheap. + +**Zero Garbage:** The matching and dispatching process generates zero bytes of garbage. In fact, the only heap allocations that are made, is by building the slice of the key-value pairs for path parameters. If the request path contains no parameters, not a single heap allocation is necessary. + +**Best Performance:** [Benchmarks speak for themselves](https://github.com/julienschmidt/go-http-routing-benchmark). See below for technical details of the implementation. + +**No more server crashes:** You can set a [Panic handler](https://godoc.org/github.com/julienschmidt/httprouter#Router.PanicHandler) to deal with panics occurring during handling a HTTP request. The router then recovers and lets the `PanicHandler` log what happened and deliver a nice error page. + +**Perfect for APIs:** The router design encourages to build sensible, hierarchical RESTful APIs. Moreover it has builtin native support for [OPTIONS requests](http://zacstewart.com/2012/04/14/http-options-method.html) and `405 Method Not Allowed` replies. + +Of course you can also set **custom [`NotFound`](https://godoc.org/github.com/julienschmidt/httprouter#Router.NotFound) and [`MethodNotAllowed`](https://godoc.org/github.com/julienschmidt/httprouter#Router.MethodNotAllowed) handlers** and [**serve static files**](https://godoc.org/github.com/julienschmidt/httprouter#Router.ServeFiles). + +## Usage + +This is just a quick introduction, view the [GoDoc](http://godoc.org/github.com/julienschmidt/httprouter) for details. + +Let's start with a trivial example: + +```go +package main + +import ( + "fmt" + "github.com/julienschmidt/httprouter" + "net/http" + "log" +) + +func Index(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + fmt.Fprint(w, "Welcome!\n") +} + +func Hello(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + fmt.Fprintf(w, "hello, %s!\n", ps.ByName("name")) +} + +func main() { + router := httprouter.New() + router.GET("/", Index) + router.GET("/hello/:name", Hello) + + log.Fatal(http.ListenAndServe(":8080", router)) +} +``` + +### Named parameters + +As you can see, `:name` is a *named parameter*. The values are accessible via `httprouter.Params`, which is just a slice of `httprouter.Param`s. You can get the value of a parameter either by its index in the slice, or by using the `ByName(name)` method: `:name` can be retrived by `ByName("name")`. + +Named parameters only match a single path segment: + +``` +Pattern: /user/:user + + /user/gordon match + /user/you match + /user/gordon/profile no match + /user/ no match +``` + +**Note:** Since this router has only explicit matches, you can not register static routes and parameters for the same path segment. For example you can not register the patterns `/user/new` and `/user/:user` for the same request method at the same time. The routing of different request methods is independent from each other. + +### Catch-All parameters + +The second type are *catch-all* parameters and have the form `*name`. Like the name suggests, they match everything. Therefore they must always be at the **end** of the pattern: + +``` +Pattern: /src/*filepath + + /src/ match + /src/somefile.go match + /src/subdir/somefile.go match +``` + +## How does it work? + +The router relies on a tree structure which makes heavy use of *common prefixes*, it is basically a *compact* [*prefix tree*](https://en.wikipedia.org/wiki/Trie) (or just [*Radix tree*](https://en.wikipedia.org/wiki/Radix_tree)). Nodes with a common prefix also share a common parent. Here is a short example what the routing tree for the `GET` request method could look like: + +``` +Priority Path Handle +9 \ *<1> +3 ├s nil +2 |├earch\ *<2> +1 |└upport\ *<3> +2 ├blog\ *<4> +1 | └:post nil +1 | └\ *<5> +2 ├about-us\ *<6> +1 | └team\ *<7> +1 └contact\ *<8> +``` + +Every `*<num>` represents the memory address of a handler function (a pointer). If you follow a path trough the tree from the root to the leaf, you get the complete route path, e.g `\blog\:post\`, where `:post` is just a placeholder ([*parameter*](#named-parameters)) for an actual post name. Unlike hash-maps, a tree structure also allows us to use dynamic parts like the `:post` parameter, since we actually match against the routing patterns instead of just comparing hashes. [As benchmarks show](https://github.com/julienschmidt/go-http-routing-benchmark), this works very well and efficient. + +Since URL paths have a hierarchical structure and make use only of a limited set of characters (byte values), it is very likely that there are a lot of common prefixes. This allows us to easily reduce the routing into ever smaller problems. Moreover the router manages a separate tree for every request method. For one thing it is more space efficient than holding a method->handle map in every single node, for another thing is also allows us to greatly reduce the routing problem before even starting the look-up in the prefix-tree. + +For even better scalability, the child nodes on each tree level are ordered by priority, where the priority is just the number of handles registered in sub nodes (children, grandchildren, and so on..). This helps in two ways: + +1. Nodes which are part of the most routing paths are evaluated first. This helps to make as much routes as possible to be reachable as fast as possible. +2. It is some sort of cost compensation. The longest reachable path (highest cost) can always be evaluated first. The following scheme visualizes the tree structure. Nodes are evaluated from top to bottom and from left to right. + +``` +├------------ +├--------- +├----- +├---- +├-- +├-- +└- +``` + +## Why doesn't this work with `http.Handler`? + +**It does!** The router itself implements the `http.Handler` interface. Moreover the router provides convenient [adapters for `http.Handler`](https://godoc.org/github.com/julienschmidt/httprouter#Router.Handler)s and [`http.HandlerFunc`](https://godoc.org/github.com/julienschmidt/httprouter#Router.HandlerFunc)s which allows them to be used as a [`httprouter.Handle`](https://godoc.org/github.com/julienschmidt/httprouter#Router.Handle) when registering a route. The only disadvantage is, that no parameter values can be retrieved when a `http.Handler` or `http.HandlerFunc` is used, since there is no efficient way to pass the values with the existing function parameters. Therefore [`httprouter.Handle`](https://godoc.org/github.com/julienschmidt/httprouter#Router.Handle) has a third function parameter. + +Just try it out for yourself, the usage of HttpRouter is very straightforward. The package is compact and minimalistic, but also probably one of the easiest routers to set up. + +## Where can I find Middleware *X*? + +This package just provides a very efficient request router with a few extra features. The router is just a [`http.Handler`](https://golang.org/pkg/net/http/#Handler), you can chain any http.Handler compatible middleware before the router, for example the [Gorilla handlers](http://www.gorillatoolkit.org/pkg/handlers). Or you could [just write your own](https://justinas.org/writing-http-middleware-in-go/), it's very easy! + +Alternatively, you could try [a web framework based on HttpRouter](#web-frameworks-based-on-httprouter). + +### Multi-domain / Sub-domains + +Here is a quick example: Does your server serve multiple domains / hosts? +You want to use sub-domains? +Define a router per host! + +```go +// We need an object that implements the http.Handler interface. +// Therefore we need a type for which we implement the ServeHTTP method. +// We just use a map here, in which we map host names (with port) to http.Handlers +type HostSwitch map[string]http.Handler + +// Implement the ServerHTTP method on our new type +func (hs HostSwitch) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Check if a http.Handler is registered for the given host. + // If yes, use it to handle the request. + if handler := hs[r.Host]; handler != nil { + handler.ServeHTTP(w, r) + } else { + // Handle host names for wich no handler is registered + http.Error(w, "Forbidden", 403) // Or Redirect? + } +} + +func main() { + // Initialize a router as usual + router := httprouter.New() + router.GET("/", Index) + router.GET("/hello/:name", Hello) + + // Make a new HostSwitch and insert the router (our http handler) + // for example.com and port 12345 + hs := make(HostSwitch) + hs["example.com:12345"] = router + + // Use the HostSwitch to listen and serve on port 12345 + log.Fatal(http.ListenAndServe(":12345", hs)) +} +``` + +### Basic Authentication + +Another quick example: Basic Authentication (RFC 2617) for handles: + +```go +package main + +import ( + "fmt" + "log" + "net/http" + + "github.com/julienschmidt/httprouter" +) + +func BasicAuth(h httprouter.Handle, requiredUser, requiredPassword string) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + // Get the Basic Authentication credentials + user, password, hasAuth := r.BasicAuth() + + if hasAuth && user == requiredUser && password == requiredPassword { + // Delegate request to the given handle + h(w, r, ps) + } else { + // Request Basic Authentication otherwise + w.Header().Set("WWW-Authenticate", "Basic realm=Restricted") + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + } + } +} + +func Index(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + fmt.Fprint(w, "Not protected!\n") +} + +func Protected(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + fmt.Fprint(w, "Protected!\n") +} + +func main() { + user := "gordon" + pass := "secret!" + + router := httprouter.New() + router.GET("/", Index) + router.GET("/protected/", BasicAuth(Protected, user, pass)) + + log.Fatal(http.ListenAndServe(":8080", router)) +} +``` + +## Chaining with the NotFound handler + +**NOTE: It might be required to set [`Router.HandleMethodNotAllowed`](https://godoc.org/github.com/julienschmidt/httprouter#Router.HandleMethodNotAllowed) to `false` to avoid problems.** + +You can use another [`http.Handler`](https://golang.org/pkg/net/http/#Handler), for example another router, to handle requests which could not be matched by this router by using the [`Router.NotFound`](https://godoc.org/github.com/julienschmidt/httprouter#Router.NotFound) handler. This allows chaining. + +### Static files + +The `NotFound` handler can for example be used to serve static files from the root path `/` (like an `index.html` file along with other assets): + +```go +// Serve static files from the ./public directory +router.NotFound = http.FileServer(http.Dir("public")) +``` + +But this approach sidesteps the strict core rules of this router to avoid routing problems. A cleaner approach is to use a distinct sub-path for serving files, like `/static/*filepath` or `/files/*filepath`. + +## Web Frameworks based on HttpRouter + +If the HttpRouter is a bit too minimalistic for you, you might try one of the following more high-level 3rd-party web frameworks building upon the HttpRouter package: + +* [Ace](https://github.com/plimble/ace): Blazing fast Go Web Framework +* [api2go](https://github.com/manyminds/api2go): A JSON API Implementation for Go +* [Gin](https://github.com/gin-gonic/gin): Features a martini-like API with much better performance +* [Goat](https://github.com/bahlo/goat): A minimalistic REST API server in Go +* [goMiddlewareChain](https://github.com/TobiEiss/goMiddlewareChain): An express.js-like-middleware-chain +* [Hikaru](https://github.com/najeira/hikaru): Supports standalone and Google AppEngine +* [Hitch](https://github.com/nbio/hitch): Hitch ties httprouter, [httpcontext](https://github.com/nbio/httpcontext), and middleware up in a bow +* [httpway](https://github.com/corneldamian/httpway): Simple middleware extension with context for httprouter and a server with gracefully shutdown support +* [kami](https://github.com/guregu/kami): A tiny web framework using x/net/context +* [Medeina](https://github.com/imdario/medeina): Inspired by Ruby's Roda and Cuba +* [Neko](https://github.com/rocwong/neko): A lightweight web application framework for Golang +* [River](https://github.com/abiosoft/river): River is a simple and lightweight REST server +* [Roxanna](https://github.com/iamthemuffinman/Roxanna): An amalgamation of httprouter, better logging, and hot reload +* [siesta](https://github.com/VividCortex/siesta): Composable HTTP handlers with contexts +* [xmux](https://github.com/rs/xmux): xmux is a httprouter fork on top of xhandler (net/context aware) diff --git a/vendor/github.com/julienschmidt/httprouter/path.go b/vendor/github.com/julienschmidt/httprouter/path.go new file mode 100644 index 000000000..486134db3 --- /dev/null +++ b/vendor/github.com/julienschmidt/httprouter/path.go @@ -0,0 +1,123 @@ +// Copyright 2013 Julien Schmidt. All rights reserved. +// Based on the path package, Copyright 2009 The Go Authors. +// Use of this source code is governed by a BSD-style license that can be found +// in the LICENSE file. + +package httprouter + +// CleanPath is the URL version of path.Clean, it returns a canonical URL path +// for p, eliminating . and .. elements. +// +// The following rules are applied iteratively until no further processing can +// be done: +// 1. Replace multiple slashes with a single slash. +// 2. Eliminate each . path name element (the current directory). +// 3. Eliminate each inner .. path name element (the parent directory) +// along with the non-.. element that precedes it. +// 4. Eliminate .. elements that begin a rooted path: +// that is, replace "/.." by "/" at the beginning of a path. +// +// If the result of this process is an empty string, "/" is returned +func CleanPath(p string) string { + // Turn empty string into "/" + if p == "" { + return "/" + } + + n := len(p) + var buf []byte + + // Invariants: + // reading from path; r is index of next byte to process. + // writing to buf; w is index of next byte to write. + + // path must start with '/' + r := 1 + w := 1 + + if p[0] != '/' { + r = 0 + buf = make([]byte, n+1) + buf[0] = '/' + } + + trailing := n > 2 && p[n-1] == '/' + + // A bit more clunky without a 'lazybuf' like the path package, but the loop + // gets completely inlined (bufApp). So in contrast to the path package this + // loop has no expensive function calls (except 1x make) + + for r < n { + switch { + case p[r] == '/': + // empty path element, trailing slash is added after the end + r++ + + case p[r] == '.' && r+1 == n: + trailing = true + r++ + + case p[r] == '.' && p[r+1] == '/': + // . element + r++ + + case p[r] == '.' && p[r+1] == '.' && (r+2 == n || p[r+2] == '/'): + // .. element: remove to last / + r += 2 + + if w > 1 { + // can backtrack + w-- + + if buf == nil { + for w > 1 && p[w] != '/' { + w-- + } + } else { + for w > 1 && buf[w] != '/' { + w-- + } + } + } + + default: + // real path element. + // add slash if needed + if w > 1 { + bufApp(&buf, p, w, '/') + w++ + } + + // copy element + for r < n && p[r] != '/' { + bufApp(&buf, p, w, p[r]) + w++ + r++ + } + } + } + + // re-append trailing slash + if trailing && w > 1 { + bufApp(&buf, p, w, '/') + w++ + } + + if buf == nil { + return p[:w] + } + return string(buf[:w]) +} + +// internal helper to lazily create a buffer if necessary +func bufApp(buf *[]byte, s string, w int, c byte) { + if *buf == nil { + if s[w] == c { + return + } + + *buf = make([]byte, len(s)) + copy(*buf, s[:w]) + } + (*buf)[w] = c +} diff --git a/vendor/github.com/julienschmidt/httprouter/router.go b/vendor/github.com/julienschmidt/httprouter/router.go new file mode 100644 index 000000000..bb1733005 --- /dev/null +++ b/vendor/github.com/julienschmidt/httprouter/router.go @@ -0,0 +1,411 @@ +// Copyright 2013 Julien Schmidt. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be found +// in the LICENSE file. + +// Package httprouter is a trie based high performance HTTP request router. +// +// A trivial example is: +// +// package main +// +// import ( +// "fmt" +// "github.com/julienschmidt/httprouter" +// "net/http" +// "log" +// ) +// +// func Index(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { +// fmt.Fprint(w, "Welcome!\n") +// } +// +// func Hello(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +// fmt.Fprintf(w, "hello, %s!\n", ps.ByName("name")) +// } +// +// func main() { +// router := httprouter.New() +// router.GET("/", Index) +// router.GET("/hello/:name", Hello) +// +// log.Fatal(http.ListenAndServe(":8080", router)) +// } +// +// The router matches incoming requests by the request method and the path. +// If a handle is registered for this path and method, the router delegates the +// request to that function. +// For the methods GET, POST, PUT, PATCH and DELETE shortcut functions exist to +// register handles, for all other methods router.Handle can be used. +// +// The registered path, against which the router matches incoming requests, can +// contain two types of parameters: +// Syntax Type +// :name named parameter +// *name catch-all parameter +// +// Named parameters are dynamic path segments. They match anything until the +// next '/' or the path end: +// Path: /blog/:category/:post +// +// Requests: +// /blog/go/request-routers match: category="go", post="request-routers" +// /blog/go/request-routers/ no match, but the router would redirect +// /blog/go/ no match +// /blog/go/request-routers/comments no match +// +// Catch-all parameters match anything until the path end, including the +// directory index (the '/' before the catch-all). Since they match anything +// until the end, catch-all parameters must always be the final path element. +// Path: /files/*filepath +// +// Requests: +// /files/ match: filepath="/" +// /files/LICENSE match: filepath="/LICENSE" +// /files/templates/article.html match: filepath="/templates/article.html" +// /files no match, but the router would redirect +// +// The value of parameters is saved as a slice of the Param struct, consisting +// each of a key and a value. The slice is passed to the Handle func as a third +// parameter. +// There are two ways to retrieve the value of a parameter: +// // by the name of the parameter +// user := ps.ByName("user") // defined by :user or *user +// +// // by the index of the parameter. This way you can also get the name (key) +// thirdKey := ps[2].Key // the name of the 3rd parameter +// thirdValue := ps[2].Value // the value of the 3rd parameter +package httprouter + +import ( + "net/http" +) + +// Handle is a function that can be registered to a route to handle HTTP +// requests. Like http.HandlerFunc, but has a third parameter for the values of +// wildcards (variables). +type Handle func(http.ResponseWriter, *http.Request, Params) + +// Param is a single URL parameter, consisting of a key and a value. +type Param struct { + Key string + Value string +} + +// Params is a Param-slice, as returned by the router. +// The slice is ordered, the first URL parameter is also the first slice value. +// It is therefore safe to read values by the index. +type Params []Param + +// ByName returns the value of the first Param which key matches the given name. +// If no matching Param is found, an empty string is returned. +func (ps Params) ByName(name string) string { + for i := range ps { + if ps[i].Key == name { + return ps[i].Value + } + } + return "" +} + +// Router is a http.Handler which can be used to dispatch requests to different +// handler functions via configurable routes +type Router struct { + trees map[string]*node + + // Enables automatic redirection if the current route can't be matched but a + // handler for the path with (without) the trailing slash exists. + // For example if /foo/ is requested but a route only exists for /foo, the + // client is redirected to /foo with http status code 301 for GET requests + // and 307 for all other request methods. + RedirectTrailingSlash bool + + // If enabled, the router tries to fix the current request path, if no + // handle is registered for it. + // First superfluous path elements like ../ or // are removed. + // Afterwards the router does a case-insensitive lookup of the cleaned path. + // If a handle can be found for this route, the router makes a redirection + // to the corrected path with status code 301 for GET requests and 307 for + // all other request methods. + // For example /FOO and /..//Foo could be redirected to /foo. + // RedirectTrailingSlash is independent of this option. + RedirectFixedPath bool + + // If enabled, the router checks if another method is allowed for the + // current route, if the current request can not be routed. + // If this is the case, the request is answered with 'Method Not Allowed' + // and HTTP status code 405. + // If no other Method is allowed, the request is delegated to the NotFound + // handler. + HandleMethodNotAllowed bool + + // If enabled, the router automatically replies to OPTIONS requests. + // Custom OPTIONS handlers take priority over automatic replies. + HandleOPTIONS bool + + // Configurable http.Handler which is called when no matching route is + // found. If it is not set, http.NotFound is used. + NotFound http.Handler + + // Configurable http.Handler which is called when a request + // cannot be routed and HandleMethodNotAllowed is true. + // If it is not set, http.Error with http.StatusMethodNotAllowed is used. + // The "Allow" header with allowed request methods is set before the handler + // is called. + MethodNotAllowed http.Handler + + // Function to handle panics recovered from http handlers. + // It should be used to generate a error page and return the http error code + // 500 (Internal Server Error). + // The handler can be used to keep your server from crashing because of + // unrecovered panics. + PanicHandler func(http.ResponseWriter, *http.Request, interface{}) +} + +// Make sure the Router conforms with the http.Handler interface +var _ http.Handler = New() + +// New returns a new initialized Router. +// Path auto-correction, including trailing slashes, is enabled by default. +func New() *Router { + return &Router{ + RedirectTrailingSlash: true, + RedirectFixedPath: true, + HandleMethodNotAllowed: true, + HandleOPTIONS: true, + } +} + +// GET is a shortcut for router.Handle("GET", path, handle) +func (r *Router) GET(path string, handle Handle) { + r.Handle("GET", path, handle) +} + +// HEAD is a shortcut for router.Handle("HEAD", path, handle) +func (r *Router) HEAD(path string, handle Handle) { + r.Handle("HEAD", path, handle) +} + +// OPTIONS is a shortcut for router.Handle("OPTIONS", path, handle) +func (r *Router) OPTIONS(path string, handle Handle) { + r.Handle("OPTIONS", path, handle) +} + +// POST is a shortcut for router.Handle("POST", path, handle) +func (r *Router) POST(path string, handle Handle) { + r.Handle("POST", path, handle) +} + +// PUT is a shortcut for router.Handle("PUT", path, handle) +func (r *Router) PUT(path string, handle Handle) { + r.Handle("PUT", path, handle) +} + +// PATCH is a shortcut for router.Handle("PATCH", path, handle) +func (r *Router) PATCH(path string, handle Handle) { + r.Handle("PATCH", path, handle) +} + +// DELETE is a shortcut for router.Handle("DELETE", path, handle) +func (r *Router) DELETE(path string, handle Handle) { + r.Handle("DELETE", path, handle) +} + +// Handle registers a new request handle with the given path and method. +// +// For GET, POST, PUT, PATCH and DELETE requests the respective shortcut +// functions can be used. +// +// This function is intended for bulk loading and to allow the usage of less +// frequently used, non-standardized or custom methods (e.g. for internal +// communication with a proxy). +func (r *Router) Handle(method, path string, handle Handle) { + if path[0] != '/' { + panic("path must begin with '/' in path '" + path + "'") + } + + if r.trees == nil { + r.trees = make(map[string]*node) + } + + root := r.trees[method] + if root == nil { + root = new(node) + r.trees[method] = root + } + + root.addRoute(path, handle) +} + +// Handler is an adapter which allows the usage of an http.Handler as a +// request handle. +func (r *Router) Handler(method, path string, handler http.Handler) { + r.Handle(method, path, + func(w http.ResponseWriter, req *http.Request, _ Params) { + handler.ServeHTTP(w, req) + }, + ) +} + +// HandlerFunc is an adapter which allows the usage of an http.HandlerFunc as a +// request handle. +func (r *Router) HandlerFunc(method, path string, handler http.HandlerFunc) { + r.Handler(method, path, handler) +} + +// ServeFiles serves files from the given file system root. +// The path must end with "/*filepath", files are then served from the local +// path /defined/root/dir/*filepath. +// For example if root is "/etc" and *filepath is "passwd", the local file +// "/etc/passwd" would be served. +// Internally a http.FileServer is used, therefore http.NotFound is used instead +// of the Router's NotFound handler. +// To use the operating system's file system implementation, +// use http.Dir: +// router.ServeFiles("/src/*filepath", http.Dir("/var/www")) +func (r *Router) ServeFiles(path string, root http.FileSystem) { + if len(path) < 10 || path[len(path)-10:] != "/*filepath" { + panic("path must end with /*filepath in path '" + path + "'") + } + + fileServer := http.FileServer(root) + + r.GET(path, func(w http.ResponseWriter, req *http.Request, ps Params) { + req.URL.Path = ps.ByName("filepath") + fileServer.ServeHTTP(w, req) + }) +} + +func (r *Router) recv(w http.ResponseWriter, req *http.Request) { + if rcv := recover(); rcv != nil { + r.PanicHandler(w, req, rcv) + } +} + +// Lookup allows the manual lookup of a method + path combo. +// This is e.g. useful to build a framework around this router. +// If the path was found, it returns the handle function and the path parameter +// values. Otherwise the third return value indicates whether a redirection to +// the same path with an extra / without the trailing slash should be performed. +func (r *Router) Lookup(method, path string) (Handle, Params, bool) { + if root := r.trees[method]; root != nil { + return root.getValue(path) + } + return nil, nil, false +} + +func (r *Router) allowed(path, reqMethod string) (allow string) { + if path == "*" { // server-wide + for method := range r.trees { + if method == "OPTIONS" { + continue + } + + // add request method to list of allowed methods + if len(allow) == 0 { + allow = method + } else { + allow += ", " + method + } + } + } else { // specific path + for method := range r.trees { + // Skip the requested method - we already tried this one + if method == reqMethod || method == "OPTIONS" { + continue + } + + handle, _, _ := r.trees[method].getValue(path) + if handle != nil { + // add request method to list of allowed methods + if len(allow) == 0 { + allow = method + } else { + allow += ", " + method + } + } + } + } + if len(allow) > 0 { + allow += ", OPTIONS" + } + return +} + +// ServeHTTP makes the router implement the http.Handler interface. +func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if r.PanicHandler != nil { + defer r.recv(w, req) + } + + path := req.URL.Path + + if root := r.trees[req.Method]; root != nil { + if handle, ps, tsr := root.getValue(path); handle != nil { + handle(w, req, ps) + return + } else if req.Method != "CONNECT" && path != "/" { + code := 301 // Permanent redirect, request with GET method + if req.Method != "GET" { + // Temporary redirect, request with same method + // As of Go 1.3, Go does not support status code 308. + code = 307 + } + + if tsr && r.RedirectTrailingSlash { + if len(path) > 1 && path[len(path)-1] == '/' { + req.URL.Path = path[:len(path)-1] + } else { + req.URL.Path = path + "/" + } + http.Redirect(w, req, req.URL.String(), code) + return + } + + // Try to fix the request path + if r.RedirectFixedPath { + fixedPath, found := root.findCaseInsensitivePath( + CleanPath(path), + r.RedirectTrailingSlash, + ) + if found { + req.URL.Path = string(fixedPath) + http.Redirect(w, req, req.URL.String(), code) + return + } + } + } + } + + if req.Method == "OPTIONS" { + // Handle OPTIONS requests + if r.HandleOPTIONS { + if allow := r.allowed(path, req.Method); len(allow) > 0 { + w.Header().Set("Allow", allow) + return + } + } + } else { + // Handle 405 + if r.HandleMethodNotAllowed { + if allow := r.allowed(path, req.Method); len(allow) > 0 { + w.Header().Set("Allow", allow) + if r.MethodNotAllowed != nil { + r.MethodNotAllowed.ServeHTTP(w, req) + } else { + http.Error(w, + http.StatusText(http.StatusMethodNotAllowed), + http.StatusMethodNotAllowed, + ) + } + return + } + } + } + + // Handle 404 + if r.NotFound != nil { + r.NotFound.ServeHTTP(w, req) + } else { + http.NotFound(w, req) + } +} diff --git a/vendor/github.com/julienschmidt/httprouter/tree.go b/vendor/github.com/julienschmidt/httprouter/tree.go new file mode 100644 index 000000000..a8fa98b04 --- /dev/null +++ b/vendor/github.com/julienschmidt/httprouter/tree.go @@ -0,0 +1,656 @@ +// Copyright 2013 Julien Schmidt. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be found +// in the LICENSE file. + +package httprouter + +import ( + "strings" + "unicode" + "unicode/utf8" +) + +func min(a, b int) int { + if a <= b { + return a + } + return b +} + +func countParams(path string) uint8 { + var n uint + for i := 0; i < len(path); i++ { + if path[i] != ':' && path[i] != '*' { + continue + } + n++ + } + if n >= 255 { + return 255 + } + return uint8(n) +} + +type nodeType uint8 + +const ( + static nodeType = iota // default + root + param + catchAll +) + +type node struct { + path string + wildChild bool + nType nodeType + maxParams uint8 + indices string + children []*node + handle Handle + priority uint32 +} + +// increments priority of the given child and reorders if necessary +func (n *node) incrementChildPrio(pos int) int { + n.children[pos].priority++ + prio := n.children[pos].priority + + // adjust position (move to front) + newPos := pos + for newPos > 0 && n.children[newPos-1].priority < prio { + // swap node positions + n.children[newPos-1], n.children[newPos] = n.children[newPos], n.children[newPos-1] + + newPos-- + } + + // build new index char string + if newPos != pos { + n.indices = n.indices[:newPos] + // unchanged prefix, might be empty + n.indices[pos:pos+1] + // the index char we move + n.indices[newPos:pos] + n.indices[pos+1:] // rest without char at 'pos' + } + + return newPos +} + +// addRoute adds a node with the given handle to the path. +// Not concurrency-safe! +func (n *node) addRoute(path string, handle Handle) { + fullPath := path + n.priority++ + numParams := countParams(path) + + // non-empty tree + if len(n.path) > 0 || len(n.children) > 0 { + walk: + for { + // Update maxParams of the current node + if numParams > n.maxParams { + n.maxParams = numParams + } + + // Find the longest common prefix. + // This also implies that the common prefix contains no ':' or '*' + // since the existing key can't contain those chars. + i := 0 + max := min(len(path), len(n.path)) + for i < max && path[i] == n.path[i] { + i++ + } + + // Split edge + if i < len(n.path) { + child := node{ + path: n.path[i:], + wildChild: n.wildChild, + nType: static, + indices: n.indices, + children: n.children, + handle: n.handle, + priority: n.priority - 1, + } + + // Update maxParams (max of all children) + for i := range child.children { + if child.children[i].maxParams > child.maxParams { + child.maxParams = child.children[i].maxParams + } + } + + n.children = []*node{&child} + // []byte for proper unicode char conversion, see #65 + n.indices = string([]byte{n.path[i]}) + n.path = path[:i] + n.handle = nil + n.wildChild = false + } + + // Make new node a child of this node + if i < len(path) { + path = path[i:] + + if n.wildChild { + n = n.children[0] + n.priority++ + + // Update maxParams of the child node + if numParams > n.maxParams { + n.maxParams = numParams + } + numParams-- + + // Check if the wildcard matches + if len(path) >= len(n.path) && n.path == path[:len(n.path)] && + // Check for longer wildcard, e.g. :name and :names + (len(n.path) >= len(path) || path[len(n.path)] == '/') { + continue walk + } else { + // Wildcard conflict + var pathSeg string + if n.nType == catchAll { + pathSeg = path + } else { + pathSeg = strings.SplitN(path, "/", 2)[0] + } + prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path + panic("'" + pathSeg + + "' in new path '" + fullPath + + "' conflicts with existing wildcard '" + n.path + + "' in existing prefix '" + prefix + + "'") + } + } + + c := path[0] + + // slash after param + if n.nType == param && c == '/' && len(n.children) == 1 { + n = n.children[0] + n.priority++ + continue walk + } + + // Check if a child with the next path byte exists + for i := 0; i < len(n.indices); i++ { + if c == n.indices[i] { + i = n.incrementChildPrio(i) + n = n.children[i] + continue walk + } + } + + // Otherwise insert it + if c != ':' && c != '*' { + // []byte for proper unicode char conversion, see #65 + n.indices += string([]byte{c}) + child := &node{ + maxParams: numParams, + } + n.children = append(n.children, child) + n.incrementChildPrio(len(n.indices) - 1) + n = child + } + n.insertChild(numParams, path, fullPath, handle) + return + + } else if i == len(path) { // Make node a (in-path) leaf + if n.handle != nil { + panic("a handle is already registered for path '" + fullPath + "'") + } + n.handle = handle + } + return + } + } else { // Empty tree + n.insertChild(numParams, path, fullPath, handle) + n.nType = root + } +} + +func (n *node) insertChild(numParams uint8, path, fullPath string, handle Handle) { + var offset int // already handled bytes of the path + + // find prefix until first wildcard (beginning with ':'' or '*'') + for i, max := 0, len(path); numParams > 0; i++ { + c := path[i] + if c != ':' && c != '*' { + continue + } + + // find wildcard end (either '/' or path end) + end := i + 1 + for end < max && path[end] != '/' { + switch path[end] { + // the wildcard name must not contain ':' and '*' + case ':', '*': + panic("only one wildcard per path segment is allowed, has: '" + + path[i:] + "' in path '" + fullPath + "'") + default: + end++ + } + } + + // check if this Node existing children which would be + // unreachable if we insert the wildcard here + if len(n.children) > 0 { + panic("wildcard route '" + path[i:end] + + "' conflicts with existing children in path '" + fullPath + "'") + } + + // check if the wildcard has a name + if end-i < 2 { + panic("wildcards must be named with a non-empty name in path '" + fullPath + "'") + } + + if c == ':' { // param + // split path at the beginning of the wildcard + if i > 0 { + n.path = path[offset:i] + offset = i + } + + child := &node{ + nType: param, + maxParams: numParams, + } + n.children = []*node{child} + n.wildChild = true + n = child + n.priority++ + numParams-- + + // if the path doesn't end with the wildcard, then there + // will be another non-wildcard subpath starting with '/' + if end < max { + n.path = path[offset:end] + offset = end + + child := &node{ + maxParams: numParams, + priority: 1, + } + n.children = []*node{child} + n = child + } + + } else { // catchAll + if end != max || numParams > 1 { + panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'") + } + + if len(n.path) > 0 && n.path[len(n.path)-1] == '/' { + panic("catch-all conflicts with existing handle for the path segment root in path '" + fullPath + "'") + } + + // currently fixed width 1 for '/' + i-- + if path[i] != '/' { + panic("no / before catch-all in path '" + fullPath + "'") + } + + n.path = path[offset:i] + + // first node: catchAll node with empty path + child := &node{ + wildChild: true, + nType: catchAll, + maxParams: 1, + } + n.children = []*node{child} + n.indices = string(path[i]) + n = child + n.priority++ + + // second node: node holding the variable + child = &node{ + path: path[i:], + nType: catchAll, + maxParams: 1, + handle: handle, + priority: 1, + } + n.children = []*node{child} + + return + } + } + + // insert remaining path part and handle to the leaf + n.path = path[offset:] + n.handle = handle +} + +// Returns the handle registered with the given path (key). The values of +// wildcards are saved to a map. +// If no handle can be found, a TSR (trailing slash redirect) recommendation is +// made if a handle exists with an extra (without the) trailing slash for the +// given path. +func (n *node) getValue(path string) (handle Handle, p Params, tsr bool) { +walk: // outer loop for walking the tree + for { + if len(path) > len(n.path) { + if path[:len(n.path)] == n.path { + path = path[len(n.path):] + // If this node does not have a wildcard (param or catchAll) + // child, we can just look up the next child node and continue + // to walk down the tree + if !n.wildChild { + c := path[0] + for i := 0; i < len(n.indices); i++ { + if c == n.indices[i] { + n = n.children[i] + continue walk + } + } + + // Nothing found. + // We can recommend to redirect to the same URL without a + // trailing slash if a leaf exists for that path. + tsr = (path == "/" && n.handle != nil) + return + + } + + // handle wildcard child + n = n.children[0] + switch n.nType { + case param: + // find param end (either '/' or path end) + end := 0 + for end < len(path) && path[end] != '/' { + end++ + } + + // save param value + if p == nil { + // lazy allocation + p = make(Params, 0, n.maxParams) + } + i := len(p) + p = p[:i+1] // expand slice within preallocated capacity + p[i].Key = n.path[1:] + p[i].Value = path[:end] + + // we need to go deeper! + if end < len(path) { + if len(n.children) > 0 { + path = path[end:] + n = n.children[0] + continue walk + } + + // ... but we can't + tsr = (len(path) == end+1) + return + } + + if handle = n.handle; handle != nil { + return + } else if len(n.children) == 1 { + // No handle found. Check if a handle for this path + a + // trailing slash exists for TSR recommendation + n = n.children[0] + tsr = (n.path == "/" && n.handle != nil) + } + + return + + case catchAll: + // save param value + if p == nil { + // lazy allocation + p = make(Params, 0, n.maxParams) + } + i := len(p) + p = p[:i+1] // expand slice within preallocated capacity + p[i].Key = n.path[2:] + p[i].Value = path + + handle = n.handle + return + + default: + panic("invalid node type") + } + } + } else if path == n.path { + // We should have reached the node containing the handle. + // Check if this node has a handle registered. + if handle = n.handle; handle != nil { + return + } + + if path == "/" && n.wildChild && n.nType != root { + tsr = true + return + } + + // No handle found. Check if a handle for this path + a + // trailing slash exists for trailing slash recommendation + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == '/' { + n = n.children[i] + tsr = (len(n.path) == 1 && n.handle != nil) || + (n.nType == catchAll && n.children[0].handle != nil) + return + } + } + + return + } + + // Nothing found. We can recommend to redirect to the same URL with an + // extra trailing slash if a leaf exists for that path + tsr = (path == "/") || + (len(n.path) == len(path)+1 && n.path[len(path)] == '/' && + path == n.path[:len(n.path)-1] && n.handle != nil) + return + } +} + +// Makes a case-insensitive lookup of the given path and tries to find a handler. +// It can optionally also fix trailing slashes. +// It returns the case-corrected path and a bool indicating whether the lookup +// was successful. +func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPath []byte, found bool) { + return n.findCaseInsensitivePathRec( + path, + strings.ToLower(path), + make([]byte, 0, len(path)+1), // preallocate enough memory for new path + [4]byte{}, // empty rune buffer + fixTrailingSlash, + ) +} + +// shift bytes in array by n bytes left +func shiftNRuneBytes(rb [4]byte, n int) [4]byte { + switch n { + case 0: + return rb + case 1: + return [4]byte{rb[1], rb[2], rb[3], 0} + case 2: + return [4]byte{rb[2], rb[3]} + case 3: + return [4]byte{rb[3]} + default: + return [4]byte{} + } +} + +// recursive case-insensitive lookup function used by n.findCaseInsensitivePath +func (n *node) findCaseInsensitivePathRec(path, loPath string, ciPath []byte, rb [4]byte, fixTrailingSlash bool) ([]byte, bool) { + loNPath := strings.ToLower(n.path) + +walk: // outer loop for walking the tree + for len(loPath) >= len(loNPath) && (len(loNPath) == 0 || loPath[1:len(loNPath)] == loNPath[1:]) { + // add common path to result + ciPath = append(ciPath, n.path...) + + if path = path[len(n.path):]; len(path) > 0 { + loOld := loPath + loPath = loPath[len(loNPath):] + + // If this node does not have a wildcard (param or catchAll) child, + // we can just look up the next child node and continue to walk down + // the tree + if !n.wildChild { + // skip rune bytes already processed + rb = shiftNRuneBytes(rb, len(loNPath)) + + if rb[0] != 0 { + // old rune not finished + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == rb[0] { + // continue with child node + n = n.children[i] + loNPath = strings.ToLower(n.path) + continue walk + } + } + } else { + // process a new rune + var rv rune + + // find rune start + // runes are up to 4 byte long, + // -4 would definitely be another rune + var off int + for max := min(len(loNPath), 3); off < max; off++ { + if i := len(loNPath) - off; utf8.RuneStart(loOld[i]) { + // read rune from cached lowercase path + rv, _ = utf8.DecodeRuneInString(loOld[i:]) + break + } + } + + // calculate lowercase bytes of current rune + utf8.EncodeRune(rb[:], rv) + // skipp already processed bytes + rb = shiftNRuneBytes(rb, off) + + for i := 0; i < len(n.indices); i++ { + // lowercase matches + if n.indices[i] == rb[0] { + // must use a recursive approach since both the + // uppercase byte and the lowercase byte might exist + // as an index + if out, found := n.children[i].findCaseInsensitivePathRec( + path, loPath, ciPath, rb, fixTrailingSlash, + ); found { + return out, true + } + break + } + } + + // same for uppercase rune, if it differs + if up := unicode.ToUpper(rv); up != rv { + utf8.EncodeRune(rb[:], up) + rb = shiftNRuneBytes(rb, off) + + for i := 0; i < len(n.indices); i++ { + // uppercase matches + if n.indices[i] == rb[0] { + // continue with child node + n = n.children[i] + loNPath = strings.ToLower(n.path) + continue walk + } + } + } + } + + // Nothing found. We can recommend to redirect to the same URL + // without a trailing slash if a leaf exists for that path + return ciPath, (fixTrailingSlash && path == "/" && n.handle != nil) + } + + n = n.children[0] + switch n.nType { + case param: + // find param end (either '/' or path end) + k := 0 + for k < len(path) && path[k] != '/' { + k++ + } + + // add param value to case insensitive path + ciPath = append(ciPath, path[:k]...) + + // we need to go deeper! + if k < len(path) { + if len(n.children) > 0 { + // continue with child node + n = n.children[0] + loNPath = strings.ToLower(n.path) + loPath = loPath[k:] + path = path[k:] + continue + } + + // ... but we can't + if fixTrailingSlash && len(path) == k+1 { + return ciPath, true + } + return ciPath, false + } + + if n.handle != nil { + return ciPath, true + } else if fixTrailingSlash && len(n.children) == 1 { + // No handle found. Check if a handle for this path + a + // trailing slash exists + n = n.children[0] + if n.path == "/" && n.handle != nil { + return append(ciPath, '/'), true + } + } + return ciPath, false + + case catchAll: + return append(ciPath, path...), true + + default: + panic("invalid node type") + } + } else { + // We should have reached the node containing the handle. + // Check if this node has a handle registered. + if n.handle != nil { + return ciPath, true + } + + // No handle found. + // Try to fix the path by adding a trailing slash + if fixTrailingSlash { + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == '/' { + n = n.children[i] + if (len(n.path) == 1 && n.handle != nil) || + (n.nType == catchAll && n.children[0].handle != nil) { + return append(ciPath, '/'), true + } + return ciPath, false + } + } + } + return ciPath, false + } + } + + // Nothing found. + // Try to fix the path by adding / removing a trailing slash + if fixTrailingSlash { + if path == "/" { + return ciPath, true + } + if len(loPath)+1 == len(loNPath) && loNPath[len(loPath)] == '/' && + loPath[1:] == loNPath[1:len(loPath)] && n.handle != nil { + return append(ciPath, n.path...), true + } + } + return ciPath, false +} diff --git a/vendor/vendor.json b/vendor/vendor.json index a9de0ec72..15eda5209 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -183,6 +183,12 @@ "revisionTime": "2016-06-03T03:41:37Z" }, { + "checksumSHA1": "gKyBj05YkfuLFruAyPZ4KV9nFp8=", + "path": "github.com/julienschmidt/httprouter", + "revision": "975b5c4c7c21c0e3d2764200bf2aa8e34657ae6e", + "revisionTime": "2017-04-30T22:20:11Z" + }, + { "checksumSHA1": "UpjhOUZ1+0zNt+iIvdtECSHXmTs=", "path": "github.com/karalabe/hid", "revision": "f00545f9f3748e591590be3732d913c77525b10f", |