diff options
47 files changed, 2484 insertions, 478 deletions
diff --git a/cmd/puppeth/wizard_netstats.go b/cmd/puppeth/wizard_netstats.go index 89b38e262..99ca11bb1 100644 --- a/cmd/puppeth/wizard_netstats.go +++ b/cmd/puppeth/wizard_netstats.go @@ -82,7 +82,6 @@ func (w *wizard) gatherStats(server string, pubkey []byte, client *sshClient) *s logger.Info("Starting remote server health-check") stat := &serverStat{ - address: client.address, services: make(map[string]map[string]string), } if client == nil { @@ -94,6 +93,8 @@ func (w *wizard) gatherStats(server string, pubkey []byte, client *sshClient) *s } client = conn } + stat.address = client.address + // Client connected one way or another, run health-checks logger.Debug("Checking for nginx availability") if infos, err := checkNginx(client, w.network); err != nil { @@ -214,6 +215,9 @@ func (stats serverStats) render() { if len(stat.address) > len(separator[1]) { separator[1] = strings.Repeat("-", len(stat.address)) } + if len(stat.failure) > len(separator[1]) { + separator[1] = strings.Repeat("-", len(stat.failure)) + } for service, configs := range stat.services { if len(service) > len(separator[2]) { separator[2] = strings.Repeat("-", len(service)) @@ -250,7 +254,11 @@ func (stats serverStats) render() { sort.Strings(services) if len(services) == 0 { - table.Append([]string{server, stats[server].address, "", "", ""}) + if stats[server].failure != "" { + table.Append([]string{server, stats[server].failure, "", "", ""}) + } else { + table.Append([]string{server, stats[server].address, "", "", ""}) + } } for j, service := range services { // Add an empty line between all services diff --git a/cmd/swarm/access.go b/cmd/swarm/access.go new file mode 100644 index 000000000..12cfbfc1a --- /dev/null +++ b/cmd/swarm/access.go @@ -0,0 +1,219 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>. +package main + +import ( + "crypto/rand" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "strings" + + "github.com/ethereum/go-ethereum/cmd/utils" + "github.com/ethereum/go-ethereum/swarm/api" + "github.com/ethereum/go-ethereum/swarm/api/client" + "gopkg.in/urfave/cli.v1" +) + +var salt = make([]byte, 32) + +func init() { + if _, err := io.ReadFull(rand.Reader, salt); err != nil { + panic("reading from crypto/rand failed: " + err.Error()) + } +} + +func accessNewPass(ctx *cli.Context) { + args := ctx.Args() + if len(args) != 1 { + utils.Fatalf("Expected 1 argument - the ref") + } + + var ( + ae *api.AccessEntry + accessKey []byte + err error + ref = args[0] + password = getPassPhrase("", 0, makePasswordList(ctx)) + dryRun = ctx.Bool(SwarmDryRunFlag.Name) + ) + accessKey, ae, err = api.DoPasswordNew(ctx, password, salt) + if err != nil { + utils.Fatalf("error getting session key: %v", err) + } + m, err := api.GenerateAccessControlManifest(ctx, ref, accessKey, ae) + if dryRun { + err = printManifests(m, nil) + if err != nil { + utils.Fatalf("had an error printing the manifests: %v", err) + } + } else { + utils.Fatalf("uploading manifests") + err = uploadManifests(ctx, m, nil) + if err != nil { + utils.Fatalf("had an error uploading the manifests: %v", err) + } + } +} + +func accessNewPK(ctx *cli.Context) { + args := ctx.Args() + if len(args) != 1 { + utils.Fatalf("Expected 1 argument - the ref") + } + + var ( + ae *api.AccessEntry + sessionKey []byte + err error + ref = args[0] + privateKey = getPrivKey(ctx) + granteePublicKey = ctx.String(SwarmAccessGrantKeyFlag.Name) + dryRun = ctx.Bool(SwarmDryRunFlag.Name) + ) + sessionKey, ae, err = api.DoPKNew(ctx, privateKey, granteePublicKey, salt) + if err != nil { + utils.Fatalf("error getting session key: %v", err) + } + m, err := api.GenerateAccessControlManifest(ctx, ref, sessionKey, ae) + if dryRun { + err = printManifests(m, nil) + if err != nil { + utils.Fatalf("had an error printing the manifests: %v", err) + } + } else { + err = uploadManifests(ctx, m, nil) + if err != nil { + utils.Fatalf("had an error uploading the manifests: %v", err) + } + } +} + +func accessNewACT(ctx *cli.Context) { + args := ctx.Args() + if len(args) != 1 { + utils.Fatalf("Expected 1 argument - the ref") + } + + var ( + ae *api.AccessEntry + actManifest *api.Manifest + accessKey []byte + err error + ref = args[0] + grantees = []string{} + actFilename = ctx.String(SwarmAccessGrantKeysFlag.Name) + privateKey = getPrivKey(ctx) + dryRun = ctx.Bool(SwarmDryRunFlag.Name) + ) + + bytes, err := ioutil.ReadFile(actFilename) + if err != nil { + utils.Fatalf("had an error reading the grantee public key list") + } + grantees = strings.Split(string(bytes), "\n") + accessKey, ae, actManifest, err = api.DoACTNew(ctx, privateKey, salt, grantees) + if err != nil { + utils.Fatalf("error generating ACT manifest: %v", err) + } + + if err != nil { + utils.Fatalf("error getting session key: %v", err) + } + m, err := api.GenerateAccessControlManifest(ctx, ref, accessKey, ae) + if err != nil { + utils.Fatalf("error generating root access manifest: %v", err) + } + + if dryRun { + err = printManifests(m, actManifest) + if err != nil { + utils.Fatalf("had an error printing the manifests: %v", err) + } + } else { + err = uploadManifests(ctx, m, actManifest) + if err != nil { + utils.Fatalf("had an error uploading the manifests: %v", err) + } + } +} + +func printManifests(rootAccessManifest, actManifest *api.Manifest) error { + js, err := json.Marshal(rootAccessManifest) + if err != nil { + return err + } + fmt.Println(string(js)) + + if actManifest != nil { + js, err := json.Marshal(actManifest) + if err != nil { + return err + } + fmt.Println(string(js)) + } + return nil +} + +func uploadManifests(ctx *cli.Context, rootAccessManifest, actManifest *api.Manifest) error { + bzzapi := strings.TrimRight(ctx.GlobalString(SwarmApiFlag.Name), "/") + client := client.NewClient(bzzapi) + + var ( + key string + err error + ) + if actManifest != nil { + key, err = client.UploadManifest(actManifest, false) + if err != nil { + return err + } + + rootAccessManifest.Entries[0].Access.Act = key + } + key, err = client.UploadManifest(rootAccessManifest, false) + if err != nil { + return err + } + fmt.Println(key) + return nil +} + +// makePasswordList reads password lines from the file specified by the global --password flag +// and also by the same subcommand --password flag. +// This function ia a fork of utils.MakePasswordList to lookup cli context for subcommand. +// Function ctx.SetGlobal is not setting the global flag value that can be accessed +// by ctx.GlobalString using the current version of cli package. +func makePasswordList(ctx *cli.Context) []string { + path := ctx.GlobalString(utils.PasswordFileFlag.Name) + if path == "" { + path = ctx.String(utils.PasswordFileFlag.Name) + if path == "" { + return nil + } + } + text, err := ioutil.ReadFile(path) + if err != nil { + utils.Fatalf("Failed to read password file: %v", err) + } + lines := strings.Split(string(text), "\n") + // Sanitise DOS line endings. + for i := range lines { + lines[i] = strings.TrimRight(lines[i], "\r") + } + return lines +} diff --git a/cmd/swarm/access_test.go b/cmd/swarm/access_test.go new file mode 100644 index 000000000..163eb2b4d --- /dev/null +++ b/cmd/swarm/access_test.go @@ -0,0 +1,581 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>. +package main + +import ( + "bytes" + "crypto/rand" + "encoding/hex" + "encoding/json" + "io" + "io/ioutil" + gorand "math/rand" + "net/http" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/swarm/api" + swarm "github.com/ethereum/go-ethereum/swarm/api/client" +) + +// TestAccessPassword tests for the correct creation of an ACT manifest protected by a password. +// The test creates bogus content, uploads it encrypted, then creates the wrapping manifest with the Access entry +// The parties participating - node (publisher), uploads to second node then disappears. Content which was uploaded +// is then fetched through 2nd node. since the tested code is not key-aware - we can just +// fetch from the 2nd node using HTTP BasicAuth +func TestAccessPassword(t *testing.T) { + cluster := newTestCluster(t, 1) + defer cluster.Shutdown() + proxyNode := cluster.Nodes[0] + + // create a tmp file + tmp, err := ioutil.TempDir("", "swarm-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmp) + + // write data to file + data := "notsorandomdata" + dataFilename := filepath.Join(tmp, "data.txt") + + err = ioutil.WriteFile(dataFilename, []byte(data), 0666) + if err != nil { + t.Fatal(err) + } + + hashRegexp := `[a-f\d]{128}` + + // upload the file with 'swarm up' and expect a hash + up := runSwarm(t, + "--bzzapi", + proxyNode.URL, //it doesn't matter through which node we upload content + "up", + "--encrypt", + dataFilename) + _, matches := up.ExpectRegexp(hashRegexp) + up.ExpectExit() + + if len(matches) < 1 { + t.Fatal("no matches found") + } + + ref := matches[0] + + password := "smth" + passwordFilename := filepath.Join(tmp, "password.txt") + + err = ioutil.WriteFile(passwordFilename, []byte(password), 0666) + if err != nil { + t.Fatal(err) + } + + up = runSwarm(t, + "access", + "new", + "pass", + "--dry-run", + "--password", + passwordFilename, + ref, + ) + + _, matches = up.ExpectRegexp(".+") + up.ExpectExit() + + if len(matches) == 0 { + t.Fatalf("stdout not matched") + } + + var m api.Manifest + + err = json.Unmarshal([]byte(matches[0]), &m) + if err != nil { + t.Fatalf("unmarshal manifest: %v", err) + } + + if len(m.Entries) != 1 { + t.Fatalf("expected one manifest entry, got %v", len(m.Entries)) + } + + e := m.Entries[0] + + ct := "application/bzz-manifest+json" + if e.ContentType != ct { + t.Errorf("expected %q content type, got %q", ct, e.ContentType) + } + + if e.Access == nil { + t.Fatal("manifest access is nil") + } + + a := e.Access + + if a.Type != "pass" { + t.Errorf(`got access type %q, expected "pass"`, a.Type) + } + if len(a.Salt) < 32 { + t.Errorf(`got salt with length %v, expected not less the 32 bytes`, len(a.Salt)) + } + if a.KdfParams == nil { + t.Fatal("manifest access kdf params is nil") + } + + client := swarm.NewClient(cluster.Nodes[0].URL) + + hash, err := client.UploadManifest(&m, false) + if err != nil { + t.Fatal(err) + } + + httpClient := &http.Client{} + + url := cluster.Nodes[0].URL + "/" + "bzz:/" + hash + response, err := httpClient.Get(url) + if err != nil { + t.Fatal(err) + } + if response.StatusCode != http.StatusUnauthorized { + t.Fatal("should be a 401") + } + authHeader := response.Header.Get("WWW-Authenticate") + if authHeader == "" { + t.Fatal("should be something here") + } + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + t.Fatal(err) + } + req.SetBasicAuth("", password) + + response, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + t.Errorf("expected status %v, got %v", http.StatusOK, response.StatusCode) + } + d, err := ioutil.ReadAll(response.Body) + if err != nil { + t.Fatal(err) + } + if string(d) != data { + t.Errorf("expected decrypted data %q, got %q", data, string(d)) + } + + wrongPasswordFilename := filepath.Join(tmp, "password-wrong.txt") + + err = ioutil.WriteFile(wrongPasswordFilename, []byte("just wr0ng"), 0666) + if err != nil { + t.Fatal(err) + } + + //download file with 'swarm down' with wrong password + up = runSwarm(t, + "--bzzapi", + proxyNode.URL, + "down", + "bzz:/"+hash, + tmp, + "--password", + wrongPasswordFilename) + + _, matches = up.ExpectRegexp("unauthorized") + if len(matches) != 1 && matches[0] != "unauthorized" { + t.Fatal(`"unauthorized" not found in output"`) + } + up.ExpectExit() +} + +// TestAccessPK tests for the correct creation of an ACT manifest between two parties (publisher and grantee). +// The test creates bogus content, uploads it encrypted, then creates the wrapping manifest with the Access entry +// The parties participating - node (publisher), uploads to second node (which is also the grantee) then disappears. +// Content which was uploaded is then fetched through the grantee's http proxy. Since the tested code is private-key aware, +// the test will fail if the proxy's given private key is not granted on the ACT. +func TestAccessPK(t *testing.T) { + // Setup Swarm and upload a test file to it + cluster := newTestCluster(t, 1) + defer cluster.Shutdown() + + // create a tmp file + tmp, err := ioutil.TempFile("", "swarm-test") + if err != nil { + t.Fatal(err) + } + defer tmp.Close() + defer os.Remove(tmp.Name()) + + // write data to file + data := "notsorandomdata" + _, err = io.WriteString(tmp, data) + if err != nil { + t.Fatal(err) + } + + hashRegexp := `[a-f\d]{128}` + + // upload the file with 'swarm up' and expect a hash + up := runSwarm(t, + "--bzzapi", + cluster.Nodes[0].URL, + "up", + "--encrypt", + tmp.Name()) + _, matches := up.ExpectRegexp(hashRegexp) + up.ExpectExit() + + if len(matches) < 1 { + t.Fatal("no matches found") + } + + ref := matches[0] + + pk := cluster.Nodes[0].PrivateKey + granteePubKey := crypto.CompressPubkey(&pk.PublicKey) + + publisherDir, err := ioutil.TempDir("", "swarm-account-dir-temp") + if err != nil { + t.Fatal(err) + } + + passFile, err := ioutil.TempFile("", "swarm-test") + if err != nil { + t.Fatal(err) + } + defer passFile.Close() + defer os.Remove(passFile.Name()) + _, err = io.WriteString(passFile, testPassphrase) + if err != nil { + t.Fatal(err) + } + _, publisherAccount := getTestAccount(t, publisherDir) + up = runSwarm(t, + "--bzzaccount", + publisherAccount.Address.String(), + "--password", + passFile.Name(), + "--datadir", + publisherDir, + "--bzzapi", + cluster.Nodes[0].URL, + "access", + "new", + "pk", + "--dry-run", + "--grant-key", + hex.EncodeToString(granteePubKey), + ref, + ) + + _, matches = up.ExpectRegexp(".+") + up.ExpectExit() + + if len(matches) == 0 { + t.Fatalf("stdout not matched") + } + + var m api.Manifest + + err = json.Unmarshal([]byte(matches[0]), &m) + if err != nil { + t.Fatalf("unmarshal manifest: %v", err) + } + + if len(m.Entries) != 1 { + t.Fatalf("expected one manifest entry, got %v", len(m.Entries)) + } + + e := m.Entries[0] + + ct := "application/bzz-manifest+json" + if e.ContentType != ct { + t.Errorf("expected %q content type, got %q", ct, e.ContentType) + } + + if e.Access == nil { + t.Fatal("manifest access is nil") + } + + a := e.Access + + if a.Type != "pk" { + t.Errorf(`got access type %q, expected "pk"`, a.Type) + } + if len(a.Salt) < 32 { + t.Errorf(`got salt with length %v, expected not less the 32 bytes`, len(a.Salt)) + } + if a.KdfParams != nil { + t.Fatal("manifest access kdf params should be nil") + } + + client := swarm.NewClient(cluster.Nodes[0].URL) + + hash, err := client.UploadManifest(&m, false) + if err != nil { + t.Fatal(err) + } + + httpClient := &http.Client{} + + url := cluster.Nodes[0].URL + "/" + "bzz:/" + hash + response, err := httpClient.Get(url) + if err != nil { + t.Fatal(err) + } + if response.StatusCode != http.StatusOK { + t.Fatal("should be a 200") + } + d, err := ioutil.ReadAll(response.Body) + if err != nil { + t.Fatal(err) + } + if string(d) != data { + t.Errorf("expected decrypted data %q, got %q", data, string(d)) + } +} + +// TestAccessACT tests the e2e creation, uploading and downloading of an ACT type access control +// the test fires up a 3 node cluster, then randomly picks 2 nodes which will be acting as grantees to the data +// set. the third node should fail decoding the reference as it will not be granted access. the publisher uploads through +// one of the nodes then disappears. +func TestAccessACT(t *testing.T) { + // Setup Swarm and upload a test file to it + cluster := newTestCluster(t, 3) + defer cluster.Shutdown() + + var uploadThroughNode = cluster.Nodes[0] + client := swarm.NewClient(uploadThroughNode.URL) + + r1 := gorand.New(gorand.NewSource(time.Now().UnixNano())) + nodeToSkip := r1.Intn(3) // a number between 0 and 2 (node indices in `cluster`) + // create a tmp file + tmp, err := ioutil.TempFile("", "swarm-test") + if err != nil { + t.Fatal(err) + } + defer tmp.Close() + defer os.Remove(tmp.Name()) + + // write data to file + data := "notsorandomdata" + _, err = io.WriteString(tmp, data) + if err != nil { + t.Fatal(err) + } + + hashRegexp := `[a-f\d]{128}` + + // upload the file with 'swarm up' and expect a hash + up := runSwarm(t, + "--bzzapi", + cluster.Nodes[0].URL, + "up", + "--encrypt", + tmp.Name()) + _, matches := up.ExpectRegexp(hashRegexp) + up.ExpectExit() + + if len(matches) < 1 { + t.Fatal("no matches found") + } + + ref := matches[0] + grantees := []string{} + for i, v := range cluster.Nodes { + if i == nodeToSkip { + continue + } + pk := v.PrivateKey + granteePubKey := crypto.CompressPubkey(&pk.PublicKey) + grantees = append(grantees, hex.EncodeToString(granteePubKey)) + } + + granteesPubkeyListFile, err := ioutil.TempFile("", "grantees-pubkey-list.csv") + if err != nil { + t.Fatal(err) + } + + _, err = granteesPubkeyListFile.WriteString(strings.Join(grantees, "\n")) + if err != nil { + t.Fatal(err) + } + + defer granteesPubkeyListFile.Close() + defer os.Remove(granteesPubkeyListFile.Name()) + + publisherDir, err := ioutil.TempDir("", "swarm-account-dir-temp") + if err != nil { + t.Fatal(err) + } + + passFile, err := ioutil.TempFile("", "swarm-test") + if err != nil { + t.Fatal(err) + } + defer passFile.Close() + defer os.Remove(passFile.Name()) + _, err = io.WriteString(passFile, testPassphrase) + if err != nil { + t.Fatal(err) + } + + _, publisherAccount := getTestAccount(t, publisherDir) + up = runSwarm(t, + "--bzzaccount", + publisherAccount.Address.String(), + "--password", + passFile.Name(), + "--datadir", + publisherDir, + "--bzzapi", + cluster.Nodes[0].URL, + "access", + "new", + "act", + "--grant-keys", + granteesPubkeyListFile.Name(), + ref, + ) + + _, matches = up.ExpectRegexp(`[a-f\d]{64}`) + up.ExpectExit() + + if len(matches) == 0 { + t.Fatalf("stdout not matched") + } + hash := matches[0] + m, _, err := client.DownloadManifest(hash) + if err != nil { + t.Fatalf("unmarshal manifest: %v", err) + } + + if len(m.Entries) != 1 { + t.Fatalf("expected one manifest entry, got %v", len(m.Entries)) + } + + e := m.Entries[0] + + ct := "application/bzz-manifest+json" + if e.ContentType != ct { + t.Errorf("expected %q content type, got %q", ct, e.ContentType) + } + + if e.Access == nil { + t.Fatal("manifest access is nil") + } + + a := e.Access + + if a.Type != "act" { + t.Fatalf(`got access type %q, expected "act"`, a.Type) + } + if len(a.Salt) < 32 { + t.Fatalf(`got salt with length %v, expected not less the 32 bytes`, len(a.Salt)) + } + if a.KdfParams != nil { + t.Fatal("manifest access kdf params should be nil") + } + + httpClient := &http.Client{} + + // all nodes except the skipped node should be able to decrypt the content + for i, node := range cluster.Nodes { + log.Debug("trying to fetch from node", "node index", i) + + url := node.URL + "/" + "bzz:/" + hash + response, err := httpClient.Get(url) + if err != nil { + t.Fatal(err) + } + log.Debug("got response from node", "response code", response.StatusCode) + + if i == nodeToSkip { + log.Debug("reached node to skip", "status code", response.StatusCode) + + if response.StatusCode != http.StatusUnauthorized { + t.Fatalf("should be a 401") + } + + continue + } + + if response.StatusCode != http.StatusOK { + t.Fatal("should be a 200") + } + d, err := ioutil.ReadAll(response.Body) + if err != nil { + t.Fatal(err) + } + if string(d) != data { + t.Errorf("expected decrypted data %q, got %q", data, string(d)) + } + } +} + +// TestKeypairSanity is a sanity test for the crypto scheme for ACT. it asserts the correct shared secret according to +// the specs at https://github.com/ethersphere/swarm-docs/blob/eb857afda906c6e7bb90d37f3f334ccce5eef230/act.md +func TestKeypairSanity(t *testing.T) { + salt := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, salt); err != nil { + t.Fatalf("reading from crypto/rand failed: %v", err.Error()) + } + sharedSecret := "a85586744a1ddd56a7ed9f33fa24f40dd745b3a941be296a0d60e329dbdb896d" + + for i, v := range []struct { + publisherPriv string + granteePub string + }{ + { + publisherPriv: "ec5541555f3bc6376788425e9d1a62f55a82901683fd7062c5eddcc373a73459", + granteePub: "0226f213613e843a413ad35b40f193910d26eb35f00154afcde9ded57479a6224a", + }, + { + publisherPriv: "70c7a73011aa56584a0009ab874794ee7e5652fd0c6911cd02f8b6267dd82d2d", + granteePub: "02e6f8d5e28faaa899744972bb847b6eb805a160494690c9ee7197ae9f619181db", + }, + } { + b, _ := hex.DecodeString(v.granteePub) + granteePub, _ := crypto.DecompressPubkey(b) + publisherPrivate, _ := crypto.HexToECDSA(v.publisherPriv) + + ssKey, err := api.NewSessionKeyPK(publisherPrivate, granteePub, salt) + if err != nil { + t.Fatal(err) + } + + hasher := sha3.NewKeccak256() + hasher.Write(salt) + shared, err := hex.DecodeString(sharedSecret) + if err != nil { + t.Fatal(err) + } + hasher.Write(shared) + sum := hasher.Sum(nil) + + if !bytes.Equal(ssKey, sum) { + t.Fatalf("%d: got a session key mismatch", i) + } + } +} diff --git a/cmd/swarm/config.go b/cmd/swarm/config.go index cda8c41c3..1183f8bc8 100644 --- a/cmd/swarm/config.go +++ b/cmd/swarm/config.go @@ -78,6 +78,7 @@ const ( SWARM_ENV_STORE_PATH = "SWARM_STORE_PATH" SWARM_ENV_STORE_CAPACITY = "SWARM_STORE_CAPACITY" SWARM_ENV_STORE_CACHE_CAPACITY = "SWARM_STORE_CACHE_CAPACITY" + SWARM_ACCESS_PASSWORD = "SWARM_ACCESS_PASSWORD" GETH_ENV_DATADIR = "GETH_DATADIR" ) diff --git a/cmd/swarm/download.go b/cmd/swarm/download.go index c2418f744..91bc2c93a 100644 --- a/cmd/swarm/download.go +++ b/cmd/swarm/download.go @@ -68,18 +68,36 @@ func download(ctx *cli.Context) { utils.Fatalf("could not parse uri argument: %v", err) } - // assume behaviour according to --recursive switch - if isRecursive { - if err := client.DownloadDirectory(uri.Addr, uri.Path, dest); err != nil { - utils.Fatalf("encoutered an error while downloading directory: %v", err) - } - } else { - // we are downloading a file - log.Debug(fmt.Sprintf("downloading file/path from a manifest. hash: %s, path:%s", uri.Addr, uri.Path)) + dl := func(credentials string) error { + // assume behaviour according to --recursive switch + if isRecursive { + if err := client.DownloadDirectory(uri.Addr, uri.Path, dest, credentials); err != nil { + if err == swarm.ErrUnauthorized { + return err + } + return fmt.Errorf("directory %s: %v", uri.Path, err) + } + } else { + // we are downloading a file + log.Debug("downloading file/path from a manifest", "uri.Addr", uri.Addr, "uri.Path", uri.Path) - err := client.DownloadFile(uri.Addr, uri.Path, dest) - if err != nil { - utils.Fatalf("could not download %s from given address: %s. error: %v", uri.Path, uri.Addr, err) + err := client.DownloadFile(uri.Addr, uri.Path, dest, credentials) + if err != nil { + if err == swarm.ErrUnauthorized { + return err + } + return fmt.Errorf("file %s from address: %s: %v", uri.Path, uri.Addr, err) + } } + return nil + } + if passwords := makePasswordList(ctx); passwords != nil { + password := getPassPhrase(fmt.Sprintf("Downloading %s is restricted", uri), 0, passwords) + err = dl(password) + } else { + err = dl("") + } + if err != nil { + utils.Fatalf("download: %v", err) } } diff --git a/cmd/swarm/list.go b/cmd/swarm/list.go index 57b5517c6..01b3f4ab6 100644 --- a/cmd/swarm/list.go +++ b/cmd/swarm/list.go @@ -44,7 +44,7 @@ func list(ctx *cli.Context) { bzzapi := strings.TrimRight(ctx.GlobalString(SwarmApiFlag.Name), "/") client := swarm.NewClient(bzzapi) - list, err := client.List(manifest, prefix) + list, err := client.List(manifest, prefix, "") if err != nil { utils.Fatalf("Failed to generate file and directory list: %s", err) } diff --git a/cmd/swarm/main.go b/cmd/swarm/main.go index ac09ae998..76be60cb6 100644 --- a/cmd/swarm/main.go +++ b/cmd/swarm/main.go @@ -155,6 +155,14 @@ var ( Name: "defaultpath", Usage: "path to file served for empty url path (none)", } + SwarmAccessGrantKeyFlag = cli.StringFlag{ + Name: "grant-key", + Usage: "grants a given public key access to an ACT", + } + SwarmAccessGrantKeysFlag = cli.StringFlag{ + Name: "grant-keys", + Usage: "grants a given list of public keys in the following file (separated by line breaks) access to an ACT", + } SwarmUpFromStdinFlag = cli.BoolFlag{ Name: "stdin", Usage: "reads data to be uploaded from stdin", @@ -167,6 +175,15 @@ var ( Name: "encrypt", Usage: "use encrypted upload", } + SwarmAccessPasswordFlag = cli.StringFlag{ + Name: "password", + Usage: "Password", + EnvVar: SWARM_ACCESS_PASSWORD, + } + SwarmDryRunFlag = cli.BoolFlag{ + Name: "dry-run", + Usage: "dry-run", + } CorsStringFlag = cli.StringFlag{ Name: "corsdomain", Usage: "Domain on which to send Access-Control-Allow-Origin header (multiple domains can be supplied separated by a ',')", @@ -254,6 +271,61 @@ func init() { }, { CustomHelpTemplate: helpTemplate, + Name: "access", + Usage: "encrypts a reference and embeds it into a root manifest", + ArgsUsage: "<ref>", + Description: "encrypts a reference and embeds it into a root manifest", + Subcommands: []cli.Command{ + { + CustomHelpTemplate: helpTemplate, + Name: "new", + Usage: "encrypts a reference and embeds it into a root manifest", + ArgsUsage: "<ref>", + Description: "encrypts a reference and embeds it into a root access manifest and prints the resulting manifest", + Subcommands: []cli.Command{ + { + Action: accessNewPass, + CustomHelpTemplate: helpTemplate, + Flags: []cli.Flag{ + utils.PasswordFileFlag, + SwarmDryRunFlag, + }, + Name: "pass", + Usage: "encrypts a reference with a password and embeds it into a root manifest", + ArgsUsage: "<ref>", + Description: "encrypts a reference and embeds it into a root access manifest and prints the resulting manifest", + }, + { + Action: accessNewPK, + CustomHelpTemplate: helpTemplate, + Flags: []cli.Flag{ + utils.PasswordFileFlag, + SwarmDryRunFlag, + SwarmAccessGrantKeyFlag, + }, + Name: "pk", + Usage: "encrypts a reference with the node's private key and a given grantee's public key and embeds it into a root manifest", + ArgsUsage: "<ref>", + Description: "encrypts a reference and embeds it into a root access manifest and prints the resulting manifest", + }, + { + Action: accessNewACT, + CustomHelpTemplate: helpTemplate, + Flags: []cli.Flag{ + SwarmAccessGrantKeysFlag, + SwarmDryRunFlag, + }, + Name: "act", + Usage: "encrypts a reference with the node's private key and a given grantee's public key and embeds it into a root manifest", + ArgsUsage: "<ref>", + Description: "encrypts a reference and embeds it into a root access manifest and prints the resulting manifest", + }, + }, + }, + }, + }, + { + CustomHelpTemplate: helpTemplate, Name: "resource", Usage: "(Advanced) Create and update Mutable Resources", ArgsUsage: "<create|update|info>", @@ -304,16 +376,13 @@ func init() { Description: "Prints the swarm hash of file or directory", }, { - Action: download, - Name: "down", - Flags: []cli.Flag{SwarmRecursiveFlag}, - Usage: "downloads a swarm manifest or a file inside a manifest", - ArgsUsage: " <uri> [<dir>]", - Description: ` -Downloads a swarm bzz uri to the given dir. When no dir is provided, working directory is assumed. --recursive flag is expected when downloading a manifest with multiple entries. -`, + Action: download, + Name: "down", + Flags: []cli.Flag{SwarmRecursiveFlag, SwarmAccessPasswordFlag}, + Usage: "downloads a swarm manifest or a file inside a manifest", + ArgsUsage: " <uri> [<dir>]", + Description: `Downloads a swarm bzz uri to the given dir. When no dir is provided, working directory is assumed. --recursive flag is expected when downloading a manifest with multiple entries.`, }, - { Name: "manifest", CustomHelpTemplate: helpTemplate, @@ -413,16 +482,14 @@ pv(1) tool to get a progress bar: Name: "import", Usage: "import chunks from a tar archive into a local chunk database (use - to read from stdin)", ArgsUsage: "<chunkdb> <file>", - Description: ` -Import chunks from a tar archive into a local chunk database (use - to read from stdin). + Description: `Import chunks from a tar archive into a local chunk database (use - to read from stdin). swarm db import ~/.ethereum/swarm/bzz-KEY/chunks chunks.tar The import may be quite large, consider piping the input through the Unix pv(1) tool to get a progress bar: - pv chunks.tar | swarm db import ~/.ethereum/swarm/bzz-KEY/chunks - -`, + pv chunks.tar | swarm db import ~/.ethereum/swarm/bzz-KEY/chunks -`, }, { Action: dbClean, @@ -535,6 +602,7 @@ func version(ctx *cli.Context) error { func bzzd(ctx *cli.Context) error { //build a valid bzzapi.Config from all available sources: //default config, file config, command line and env vars + bzzconfig, err := buildConfig(ctx) if err != nil { utils.Fatalf("unable to configure swarm: %v", err) @@ -557,6 +625,7 @@ func bzzd(ctx *cli.Context) error { if err != nil { utils.Fatalf("can't create node: %v", err) } + //a few steps need to be done after the config phase is completed, //due to overriding behavior initSwarmNode(bzzconfig, stack, ctx) diff --git a/cmd/swarm/run_test.go b/cmd/swarm/run_test.go index 90d3c98ba..3e766dc10 100644 --- a/cmd/swarm/run_test.go +++ b/cmd/swarm/run_test.go @@ -18,10 +18,12 @@ package main import ( "context" + "crypto/ecdsa" "fmt" "io/ioutil" "net" "os" + "path" "path/filepath" "runtime" "sync" @@ -175,14 +177,15 @@ func (c *testCluster) Cleanup() { } type testNode struct { - Name string - Addr string - URL string - Enode string - Dir string - IpcPath string - Client *rpc.Client - Cmd *cmdtest.TestCmd + Name string + Addr string + URL string + Enode string + Dir string + IpcPath string + PrivateKey *ecdsa.PrivateKey + Client *rpc.Client + Cmd *cmdtest.TestCmd } const testPassphrase = "swarm-test-passphrase" @@ -289,7 +292,11 @@ func existingTestNode(t *testing.T, dir string, bzzaccount string) *testNode { func newTestNode(t *testing.T, dir string) *testNode { conf, account := getTestAccount(t, dir) - node := &testNode{Dir: dir} + ks := keystore.NewKeyStore(path.Join(dir, "keystore"), 1<<18, 1) + + pk := decryptStoreAccount(ks, account.Address.Hex(), []string{testPassphrase}) + + node := &testNode{Dir: dir, PrivateKey: pk} // assign ports ports, err := getAvailableTCPPorts(2) diff --git a/core/chain_indexer.go b/core/chain_indexer.go index 0b927116d..11a7c96fa 100644 --- a/core/chain_indexer.go +++ b/core/chain_indexer.go @@ -17,6 +17,7 @@ package core import ( + "context" "encoding/binary" "fmt" "sync" @@ -37,11 +38,11 @@ import ( type ChainIndexerBackend interface { // Reset initiates the processing of a new chain segment, potentially terminating // any partially completed operations (in case of a reorg). - Reset(section uint64, prevHead common.Hash) error + Reset(ctx context.Context, section uint64, prevHead common.Hash) error // Process crunches through the next header in the chain segment. The caller // will ensure a sequential order of headers. - Process(header *types.Header) + Process(ctx context.Context, header *types.Header) error // Commit finalizes the section metadata and stores it into the database. Commit() error @@ -71,9 +72,11 @@ type ChainIndexer struct { backend ChainIndexerBackend // Background processor generating the index data content children []*ChainIndexer // Child indexers to cascade chain updates to - active uint32 // Flag whether the event loop was started - update chan struct{} // Notification channel that headers should be processed - quit chan chan error // Quit channel to tear down running goroutines + active uint32 // Flag whether the event loop was started + update chan struct{} // Notification channel that headers should be processed + quit chan chan error // Quit channel to tear down running goroutines + ctx context.Context + ctxCancel func() sectionSize uint64 // Number of blocks in a single chain segment to process confirmsReq uint64 // Number of confirmations before processing a completed segment @@ -105,6 +108,8 @@ func NewChainIndexer(chainDb, indexDb ethdb.Database, backend ChainIndexerBacken } // Initialize database dependent fields and start the updater c.loadValidSections() + c.ctx, c.ctxCancel = context.WithCancel(context.Background()) + go c.updateLoop() return c @@ -138,6 +143,8 @@ func (c *ChainIndexer) Start(chain ChainIndexerChain) { func (c *ChainIndexer) Close() error { var errs []error + c.ctxCancel() + // Tear down the primary update loop errc := make(chan error) c.quit <- errc @@ -297,6 +304,12 @@ func (c *ChainIndexer) updateLoop() { c.lock.Unlock() newHead, err := c.processSection(section, oldHead) if err != nil { + select { + case <-c.ctx.Done(): + <-c.quit <- nil + return + default: + } c.log.Error("Section processing failed", "error", err) } c.lock.Lock() @@ -344,7 +357,7 @@ func (c *ChainIndexer) processSection(section uint64, lastHead common.Hash) (com // Reset and partial processing - if err := c.backend.Reset(section, lastHead); err != nil { + if err := c.backend.Reset(c.ctx, section, lastHead); err != nil { c.setValidSections(0) return common.Hash{}, err } @@ -360,11 +373,12 @@ func (c *ChainIndexer) processSection(section uint64, lastHead common.Hash) (com } else if header.ParentHash != lastHead { return common.Hash{}, fmt.Errorf("chain reorged during section processing") } - c.backend.Process(header) + if err := c.backend.Process(c.ctx, header); err != nil { + return common.Hash{}, err + } lastHead = header.Hash() } if err := c.backend.Commit(); err != nil { - c.log.Error("Section commit failed", "error", err) return common.Hash{}, err } return lastHead, nil diff --git a/core/chain_indexer_test.go b/core/chain_indexer_test.go index 550caf556..a029dec62 100644 --- a/core/chain_indexer_test.go +++ b/core/chain_indexer_test.go @@ -17,6 +17,7 @@ package core import ( + "context" "fmt" "math/big" "math/rand" @@ -210,13 +211,13 @@ func (b *testChainIndexBackend) reorg(headNum uint64) uint64 { return b.stored * b.indexer.sectionSize } -func (b *testChainIndexBackend) Reset(section uint64, prevHead common.Hash) error { +func (b *testChainIndexBackend) Reset(ctx context.Context, section uint64, prevHead common.Hash) error { b.section = section b.headerCnt = 0 return nil } -func (b *testChainIndexBackend) Process(header *types.Header) { +func (b *testChainIndexBackend) Process(ctx context.Context, header *types.Header) error { b.headerCnt++ if b.headerCnt > b.indexer.sectionSize { b.t.Error("Processing too many headers") @@ -227,6 +228,7 @@ func (b *testChainIndexBackend) Process(header *types.Header) { b.t.Fatal("Unexpected call to Process") case b.processCh <- header.Number.Uint64(): } + return nil } func (b *testChainIndexBackend) Commit() error { diff --git a/core/events.go b/core/events.go index 8d200f2a2..710bdb589 100644 --- a/core/events.go +++ b/core/events.go @@ -29,9 +29,6 @@ type PendingLogsEvent struct { Logs []*types.Log } -// PendingStateEvent is posted pre mining and notifies of pending state changes. -type PendingStateEvent struct{} - // NewMinedBlockEvent is posted when a block has been imported. type NewMinedBlockEvent struct{ Block *types.Block } diff --git a/eth/backend.go b/eth/backend.go index 865534b19..6549cb8a3 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -130,7 +130,7 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) { gasPrice: config.GasPrice, etherbase: config.Etherbase, bloomRequests: make(chan chan *bloombits.Retrieval), - bloomIndexer: NewBloomIndexer(chainDb, params.BloomBitsBlocks), + bloomIndexer: NewBloomIndexer(chainDb, params.BloomBitsBlocks, bloomConfirms), } log.Info("Initialising Ethereum protocol", "versions", ProtocolVersions, "network", config.NetworkId) diff --git a/eth/bloombits.go b/eth/bloombits.go index 954239d14..eb18565e2 100644 --- a/eth/bloombits.go +++ b/eth/bloombits.go @@ -17,6 +17,7 @@ package eth import ( + "context" "time" "github.com/ethereum/go-ethereum/common" @@ -92,30 +93,28 @@ const ( // BloomIndexer implements a core.ChainIndexer, building up a rotated bloom bits index // for the Ethereum header bloom filters, permitting blazing fast filtering. type BloomIndexer struct { - size uint64 // section size to generate bloombits for - - db ethdb.Database // database instance to write index data and metadata into - gen *bloombits.Generator // generator to rotate the bloom bits crating the bloom index - - section uint64 // Section is the section number being processed currently - head common.Hash // Head is the hash of the last header processed + size uint64 // section size to generate bloombits for + db ethdb.Database // database instance to write index data and metadata into + gen *bloombits.Generator // generator to rotate the bloom bits crating the bloom index + section uint64 // Section is the section number being processed currently + head common.Hash // Head is the hash of the last header processed } // NewBloomIndexer returns a chain indexer that generates bloom bits data for the // canonical chain for fast logs filtering. -func NewBloomIndexer(db ethdb.Database, size uint64) *core.ChainIndexer { +func NewBloomIndexer(db ethdb.Database, size, confReq uint64) *core.ChainIndexer { backend := &BloomIndexer{ db: db, size: size, } table := ethdb.NewTable(db, string(rawdb.BloomBitsIndexPrefix)) - return core.NewChainIndexer(db, table, backend, size, bloomConfirms, bloomThrottling, "bloombits") + return core.NewChainIndexer(db, table, backend, size, confReq, bloomThrottling, "bloombits") } // Reset implements core.ChainIndexerBackend, starting a new bloombits index // section. -func (b *BloomIndexer) Reset(section uint64, lastSectionHead common.Hash) error { +func (b *BloomIndexer) Reset(ctx context.Context, section uint64, lastSectionHead common.Hash) error { gen, err := bloombits.NewGenerator(uint(b.size)) b.gen, b.section, b.head = gen, section, common.Hash{} return err @@ -123,16 +122,16 @@ func (b *BloomIndexer) Reset(section uint64, lastSectionHead common.Hash) error // Process implements core.ChainIndexerBackend, adding a new header's bloom into // the index. -func (b *BloomIndexer) Process(header *types.Header) { +func (b *BloomIndexer) Process(ctx context.Context, header *types.Header) error { b.gen.AddBloom(uint(header.Number.Uint64()-b.section*b.size), header.Bloom) b.head = header.Hash() + return nil } // Commit implements core.ChainIndexerBackend, finalizing the bloom section and // writing it out into the database. func (b *BloomIndexer) Commit() error { batch := b.db.NewBatch() - for i := 0; i < types.BloomBitLength; i++ { bits, err := b.gen.Bitset(uint(i)) if err != nil { diff --git a/les/backend.go b/les/backend.go index 178bc1e0e..9b8cc1828 100644 --- a/les/backend.go +++ b/les/backend.go @@ -95,29 +95,35 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { quitSync := make(chan struct{}) leth := &LightEthereum{ - config: config, - chainConfig: chainConfig, - chainDb: chainDb, - eventMux: ctx.EventMux, - peers: peers, - reqDist: newRequestDistributor(peers, quitSync), - accountManager: ctx.AccountManager, - engine: eth.CreateConsensusEngine(ctx, chainConfig, &config.Ethash, nil, chainDb), - shutdownChan: make(chan bool), - networkId: config.NetworkId, - bloomRequests: make(chan chan *bloombits.Retrieval), - bloomIndexer: eth.NewBloomIndexer(chainDb, light.BloomTrieFrequency), - chtIndexer: light.NewChtIndexer(chainDb, true), - bloomTrieIndexer: light.NewBloomTrieIndexer(chainDb, true), + config: config, + chainConfig: chainConfig, + chainDb: chainDb, + eventMux: ctx.EventMux, + peers: peers, + reqDist: newRequestDistributor(peers, quitSync), + accountManager: ctx.AccountManager, + engine: eth.CreateConsensusEngine(ctx, chainConfig, &config.Ethash, nil, chainDb), + shutdownChan: make(chan bool), + networkId: config.NetworkId, + bloomRequests: make(chan chan *bloombits.Retrieval), + bloomIndexer: eth.NewBloomIndexer(chainDb, light.BloomTrieFrequency, light.HelperTrieConfirmations), } leth.relay = NewLesTxRelay(peers, leth.reqDist) leth.serverPool = newServerPool(chainDb, quitSync, &leth.wg) leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool) - leth.odr = NewLesOdr(chainDb, leth.chtIndexer, leth.bloomTrieIndexer, leth.bloomIndexer, leth.retriever) + leth.odr = NewLesOdr(chainDb, leth.retriever) + leth.chtIndexer = light.NewChtIndexer(chainDb, true, leth.odr) + leth.bloomTrieIndexer = light.NewBloomTrieIndexer(chainDb, true, leth.odr) + leth.odr.SetIndexers(leth.chtIndexer, leth.bloomTrieIndexer, leth.bloomIndexer) + // Note: NewLightChain adds the trusted checkpoint so it needs an ODR with + // indexers already set but not started yet if leth.blockchain, err = light.NewLightChain(leth.odr, leth.chainConfig, leth.engine); err != nil { return nil, err } + // Note: AddChildIndexer starts the update process for the child + leth.bloomIndexer.AddChildIndexer(leth.bloomTrieIndexer) + leth.chtIndexer.Start(leth.blockchain) leth.bloomIndexer.Start(leth.blockchain) // Rewind the chain in case of an incompatible config upgrade. if compat, ok := genesisErr.(*params.ConfigCompatError); ok { @@ -242,9 +248,6 @@ func (s *LightEthereum) Stop() error { if s.chtIndexer != nil { s.chtIndexer.Close() } - if s.bloomTrieIndexer != nil { - s.bloomTrieIndexer.Close() - } s.blockchain.Stop() s.protocolManager.Stop() s.txPool.Stop() diff --git a/les/distributor.go b/les/distributor.go index 159fa4c73..d3f6b21d1 100644 --- a/les/distributor.go +++ b/les/distributor.go @@ -20,14 +20,10 @@ package les import ( "container/list" - "errors" "sync" "time" ) -// ErrNoPeers is returned if no peers capable of serving a queued request are available -var ErrNoPeers = errors.New("no suitable peers available") - // requestDistributor implements a mechanism that distributes requests to // suitable peers, obeying flow control rules and prioritizing them in creation // order (even when a resend is necessary). diff --git a/les/handler.go b/les/handler.go index 91a235bf0..ccb4a8844 100644 --- a/les/handler.go +++ b/les/handler.go @@ -1206,11 +1206,12 @@ func (pm *ProtocolManager) txStatus(hashes []common.Hash) []txStatus { // NodeInfo represents a short summary of the Ethereum sub-protocol metadata // known about the host peer. type NodeInfo struct { - Network uint64 `json:"network"` // Ethereum network ID (1=Frontier, 2=Morden, Ropsten=3, Rinkeby=4) - Difficulty *big.Int `json:"difficulty"` // Total difficulty of the host's blockchain - Genesis common.Hash `json:"genesis"` // SHA3 hash of the host's genesis block - Config *params.ChainConfig `json:"config"` // Chain configuration for the fork rules - Head common.Hash `json:"head"` // SHA3 hash of the host's best owned block + Network uint64 `json:"network"` // Ethereum network ID (1=Frontier, 2=Morden, Ropsten=3, Rinkeby=4) + Difficulty *big.Int `json:"difficulty"` // Total difficulty of the host's blockchain + Genesis common.Hash `json:"genesis"` // SHA3 hash of the host's genesis block + Config *params.ChainConfig `json:"config"` // Chain configuration for the fork rules + Head common.Hash `json:"head"` // SHA3 hash of the host's best owned block + CHT light.TrustedCheckpoint `json:"cht"` // Trused CHT checkpoint for fast catchup } // NodeInfo retrieves some protocol metadata about the running host node. @@ -1218,12 +1219,31 @@ func (self *ProtocolManager) NodeInfo() *NodeInfo { head := self.blockchain.CurrentHeader() hash := head.Hash() + var cht light.TrustedCheckpoint + + sections, _, sectionHead := self.odr.ChtIndexer().Sections() + sections2, _, sectionHead2 := self.odr.BloomTrieIndexer().Sections() + if sections2 < sections { + sections = sections2 + sectionHead = sectionHead2 + } + if sections > 0 { + sectionIndex := sections - 1 + cht = light.TrustedCheckpoint{ + SectionIdx: sectionIndex, + SectionHead: sectionHead, + CHTRoot: light.GetChtRoot(self.chainDb, sectionIndex, sectionHead), + BloomRoot: light.GetBloomTrieRoot(self.chainDb, sectionIndex, sectionHead), + } + } + return &NodeInfo{ Network: self.networkId, Difficulty: self.blockchain.GetTd(hash, head.Number.Uint64()), Genesis: self.blockchain.Genesis().Hash(), Config: self.blockchain.Config(), Head: hash, + CHT: cht, } } @@ -1258,7 +1278,7 @@ func (pc *peerConnection) RequestHeadersByHash(origin common.Hash, amount int, s } _, ok := <-pc.manager.reqDist.queue(rq) if !ok { - return ErrNoPeers + return light.ErrNoPeers } return nil } @@ -1282,7 +1302,7 @@ func (pc *peerConnection) RequestHeadersByNumber(origin uint64, amount int, skip } _, ok := <-pc.manager.reqDist.queue(rq) if !ok { - return ErrNoPeers + return light.ErrNoPeers } return nil } diff --git a/les/helper_test.go b/les/helper_test.go index 8fd01a39e..50c97e06e 100644 --- a/les/helper_test.go +++ b/les/helper_test.go @@ -156,12 +156,12 @@ func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *cor } else { blockchain, _ := core.NewBlockChain(db, nil, gspec.Config, engine, vm.Config{}) - chtIndexer := light.NewChtIndexer(db, false) + chtIndexer := light.NewChtIndexer(db, false, nil) chtIndexer.Start(blockchain) - bbtIndexer := light.NewBloomTrieIndexer(db, false) + bbtIndexer := light.NewBloomTrieIndexer(db, false, nil) - bloomIndexer := eth.NewBloomIndexer(db, params.BloomBitsBlocks) + bloomIndexer := eth.NewBloomIndexer(db, params.BloomBitsBlocks, light.HelperTrieProcessConfirmations) bloomIndexer.AddChildIndexer(bbtIndexer) bloomIndexer.Start(blockchain) diff --git a/les/odr.go b/les/odr.go index f8412aaad..2ad28d5d9 100644 --- a/les/odr.go +++ b/les/odr.go @@ -33,14 +33,11 @@ type LesOdr struct { stop chan struct{} } -func NewLesOdr(db ethdb.Database, chtIndexer, bloomTrieIndexer, bloomIndexer *core.ChainIndexer, retriever *retrieveManager) *LesOdr { +func NewLesOdr(db ethdb.Database, retriever *retrieveManager) *LesOdr { return &LesOdr{ - db: db, - chtIndexer: chtIndexer, - bloomTrieIndexer: bloomTrieIndexer, - bloomIndexer: bloomIndexer, - retriever: retriever, - stop: make(chan struct{}), + db: db, + retriever: retriever, + stop: make(chan struct{}), } } @@ -54,6 +51,13 @@ func (odr *LesOdr) Database() ethdb.Database { return odr.db } +// SetIndexers adds the necessary chain indexers to the ODR backend +func (odr *LesOdr) SetIndexers(chtIndexer, bloomTrieIndexer, bloomIndexer *core.ChainIndexer) { + odr.chtIndexer = chtIndexer + odr.bloomTrieIndexer = bloomTrieIndexer + odr.bloomIndexer = bloomIndexer +} + // ChtIndexer returns the CHT chain indexer func (odr *LesOdr) ChtIndexer() *core.ChainIndexer { return odr.chtIndexer diff --git a/les/odr_test.go b/les/odr_test.go index 983f7262b..c7c25cbe4 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -167,7 +167,8 @@ func testOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) { rm := newRetrieveManager(peers, dist, nil) db := ethdb.NewMemDatabase() ldb := ethdb.NewMemDatabase() - odr := NewLesOdr(ldb, light.NewChtIndexer(db, true), light.NewBloomTrieIndexer(db, true), eth.NewBloomIndexer(db, light.BloomTrieFrequency), rm) + odr := NewLesOdr(ldb, rm) + odr.SetIndexers(light.NewChtIndexer(db, true, nil), light.NewBloomTrieIndexer(db, true, nil), eth.NewBloomIndexer(db, light.BloomTrieFrequency, light.HelperTrieConfirmations)) pm := newTestProtocolManagerMust(t, false, 4, testChainGen, nil, nil, db) lpm := newTestProtocolManagerMust(t, true, 0, nil, peers, odr, ldb) _, err1, lpeer, err2 := newTestPeerPair("peer", protocol, pm, lpm) diff --git a/les/request_test.go b/les/request_test.go index ba2f603d8..db576798b 100644 --- a/les/request_test.go +++ b/les/request_test.go @@ -89,7 +89,8 @@ func testAccess(t *testing.T, protocol int, fn accessTestFn) { rm := newRetrieveManager(peers, dist, nil) db := ethdb.NewMemDatabase() ldb := ethdb.NewMemDatabase() - odr := NewLesOdr(ldb, light.NewChtIndexer(db, true), light.NewBloomTrieIndexer(db, true), eth.NewBloomIndexer(db, light.BloomTrieFrequency), rm) + odr := NewLesOdr(ldb, rm) + odr.SetIndexers(light.NewChtIndexer(db, true, nil), light.NewBloomTrieIndexer(db, true, nil), eth.NewBloomIndexer(db, light.BloomTrieFrequency, light.HelperTrieConfirmations)) pm := newTestProtocolManagerMust(t, false, 4, testChainGen, nil, nil, db) lpm := newTestProtocolManagerMust(t, true, 0, nil, peers, odr, ldb) diff --git a/les/retrieve.go b/les/retrieve.go index a9037a38e..8ae36d82c 100644 --- a/les/retrieve.go +++ b/les/retrieve.go @@ -27,6 +27,7 @@ import ( "time" "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/light" ) var ( @@ -207,7 +208,7 @@ func (r *sentReq) stateRequesting() reqStateFn { return r.stateNoMorePeers } // nothing to wait for, no more peers to ask, return with error - r.stop(ErrNoPeers) + r.stop(light.ErrNoPeers) // no need to go to stopped state because waiting() already returned false return nil } diff --git a/les/server.go b/les/server.go index fca6124c9..a934fbf26 100644 --- a/les/server.go +++ b/les/server.go @@ -67,8 +67,8 @@ func NewLesServer(eth *eth.Ethereum, config *eth.Config) (*LesServer, error) { protocolManager: pm, quitSync: quitSync, lesTopics: lesTopics, - chtIndexer: light.NewChtIndexer(eth.ChainDb(), false), - bloomTrieIndexer: light.NewBloomTrieIndexer(eth.ChainDb(), false), + chtIndexer: light.NewChtIndexer(eth.ChainDb(), false, nil), + bloomTrieIndexer: light.NewBloomTrieIndexer(eth.ChainDb(), false, nil), } logger := log.New() diff --git a/light/lightchain.go b/light/lightchain.go index 30b9bd89a..b7e629e88 100644 --- a/light/lightchain.go +++ b/light/lightchain.go @@ -116,19 +116,19 @@ func NewLightChain(odr OdrBackend, config *params.ChainConfig, engine consensus. } // addTrustedCheckpoint adds a trusted checkpoint to the blockchain -func (self *LightChain) addTrustedCheckpoint(cp trustedCheckpoint) { +func (self *LightChain) addTrustedCheckpoint(cp TrustedCheckpoint) { if self.odr.ChtIndexer() != nil { - StoreChtRoot(self.chainDb, cp.sectionIdx, cp.sectionHead, cp.chtRoot) - self.odr.ChtIndexer().AddKnownSectionHead(cp.sectionIdx, cp.sectionHead) + StoreChtRoot(self.chainDb, cp.SectionIdx, cp.SectionHead, cp.CHTRoot) + self.odr.ChtIndexer().AddKnownSectionHead(cp.SectionIdx, cp.SectionHead) } if self.odr.BloomTrieIndexer() != nil { - StoreBloomTrieRoot(self.chainDb, cp.sectionIdx, cp.sectionHead, cp.bloomTrieRoot) - self.odr.BloomTrieIndexer().AddKnownSectionHead(cp.sectionIdx, cp.sectionHead) + StoreBloomTrieRoot(self.chainDb, cp.SectionIdx, cp.SectionHead, cp.BloomRoot) + self.odr.BloomTrieIndexer().AddKnownSectionHead(cp.SectionIdx, cp.SectionHead) } if self.odr.BloomIndexer() != nil { - self.odr.BloomIndexer().AddKnownSectionHead(cp.sectionIdx, cp.sectionHead) + self.odr.BloomIndexer().AddKnownSectionHead(cp.SectionIdx, cp.SectionHead) } - log.Info("Added trusted checkpoint", "chain", cp.name, "block", (cp.sectionIdx+1)*CHTFrequencyClient-1, "hash", cp.sectionHead) + log.Info("Added trusted checkpoint", "chain", cp.name, "block", (cp.SectionIdx+1)*CHTFrequencyClient-1, "hash", cp.SectionHead) } func (self *LightChain) getProcInterrupt() bool { diff --git a/light/odr.go b/light/odr.go index 8f1e50b81..83c64055a 100644 --- a/light/odr.go +++ b/light/odr.go @@ -20,6 +20,7 @@ package light import ( "context" + "errors" "math/big" "github.com/ethereum/go-ethereum/common" @@ -33,6 +34,9 @@ import ( // service is not required. var NoOdr = context.Background() +// ErrNoPeers is returned if no peers capable of serving a queued request are available +var ErrNoPeers = errors.New("no suitable peers available") + // OdrBackend is an interface to a backend service that handles ODR retrievals type type OdrBackend interface { Database() ethdb.Database diff --git a/light/postprocess.go b/light/postprocess.go index 2090a9d04..0b25e1d88 100644 --- a/light/postprocess.go +++ b/light/postprocess.go @@ -17,8 +17,10 @@ package light import ( + "context" "encoding/binary" "errors" + "fmt" "math/big" "time" @@ -47,35 +49,35 @@ const ( HelperTrieProcessConfirmations = 256 // number of confirmations before a HelperTrie is generated ) -// trustedCheckpoint represents a set of post-processed trie roots (CHT and BloomTrie) associated with +// TrustedCheckpoint represents a set of post-processed trie roots (CHT and BloomTrie) associated with // the appropriate section index and head hash. It is used to start light syncing from this checkpoint // and avoid downloading the entire header chain while still being able to securely access old headers/logs. -type trustedCheckpoint struct { - name string - sectionIdx uint64 - sectionHead, chtRoot, bloomTrieRoot common.Hash +type TrustedCheckpoint struct { + name string + SectionIdx uint64 + SectionHead, CHTRoot, BloomRoot common.Hash } var ( - mainnetCheckpoint = trustedCheckpoint{ - name: "mainnet", - sectionIdx: 179, - sectionHead: common.HexToHash("ae778e455492db1183e566fa0c67f954d256fdd08618f6d5a393b0e24576d0ea"), - chtRoot: common.HexToHash("646b338f9ca74d936225338916be53710ec84020b89946004a8605f04c817f16"), - bloomTrieRoot: common.HexToHash("d0f978f5dbc86e5bf931d8dd5b2ecbebbda6dc78f8896af6a27b46a3ced0ac25"), + mainnetCheckpoint = TrustedCheckpoint{ + name: "mainnet", + SectionIdx: 179, + SectionHead: common.HexToHash("ae778e455492db1183e566fa0c67f954d256fdd08618f6d5a393b0e24576d0ea"), + CHTRoot: common.HexToHash("646b338f9ca74d936225338916be53710ec84020b89946004a8605f04c817f16"), + BloomRoot: common.HexToHash("d0f978f5dbc86e5bf931d8dd5b2ecbebbda6dc78f8896af6a27b46a3ced0ac25"), } - ropstenCheckpoint = trustedCheckpoint{ - name: "ropsten", - sectionIdx: 107, - sectionHead: common.HexToHash("e1988f95399debf45b873e065e5cd61b416ef2e2e5deec5a6f87c3127086e1ce"), - chtRoot: common.HexToHash("15cba18e4de0ab1e95e202625199ba30147aec8b0b70384b66ebea31ba6a18e0"), - bloomTrieRoot: common.HexToHash("e00fa6389b2e597d9df52172cd8e936879eed0fca4fa59db99e2c8ed682562f2"), + ropstenCheckpoint = TrustedCheckpoint{ + name: "ropsten", + SectionIdx: 107, + SectionHead: common.HexToHash("e1988f95399debf45b873e065e5cd61b416ef2e2e5deec5a6f87c3127086e1ce"), + CHTRoot: common.HexToHash("15cba18e4de0ab1e95e202625199ba30147aec8b0b70384b66ebea31ba6a18e0"), + BloomRoot: common.HexToHash("e00fa6389b2e597d9df52172cd8e936879eed0fca4fa59db99e2c8ed682562f2"), } ) // trustedCheckpoints associates each known checkpoint with the genesis hash of the chain it belongs to -var trustedCheckpoints = map[common.Hash]trustedCheckpoint{ +var trustedCheckpoints = map[common.Hash]TrustedCheckpoint{ params.MainnetGenesisHash: mainnetCheckpoint, params.TestnetGenesisHash: ropstenCheckpoint, } @@ -119,7 +121,8 @@ func StoreChtRoot(db ethdb.Database, sectionIdx uint64, sectionHead, root common // ChtIndexerBackend implements core.ChainIndexerBackend type ChtIndexerBackend struct { - diskdb ethdb.Database + diskdb, trieTable ethdb.Database + odr OdrBackend triedb *trie.Database section, sectionSize uint64 lastHash common.Hash @@ -127,7 +130,7 @@ type ChtIndexerBackend struct { } // NewBloomTrieIndexer creates a BloomTrie chain indexer -func NewChtIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer { +func NewChtIndexer(db ethdb.Database, clientMode bool, odr OdrBackend) *core.ChainIndexer { var sectionSize, confirmReq uint64 if clientMode { sectionSize = CHTFrequencyClient @@ -137,28 +140,64 @@ func NewChtIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer { confirmReq = HelperTrieProcessConfirmations } idb := ethdb.NewTable(db, "chtIndex-") + trieTable := ethdb.NewTable(db, ChtTablePrefix) backend := &ChtIndexerBackend{ diskdb: db, - triedb: trie.NewDatabase(ethdb.NewTable(db, ChtTablePrefix)), + odr: odr, + trieTable: trieTable, + triedb: trie.NewDatabase(trieTable), sectionSize: sectionSize, } return core.NewChainIndexer(db, idb, backend, sectionSize, confirmReq, time.Millisecond*100, "cht") } +// fetchMissingNodes tries to retrieve the last entry of the latest trusted CHT from the +// ODR backend in order to be able to add new entries and calculate subsequent root hashes +func (c *ChtIndexerBackend) fetchMissingNodes(ctx context.Context, section uint64, root common.Hash) error { + batch := c.trieTable.NewBatch() + r := &ChtRequest{ChtRoot: root, ChtNum: section - 1, BlockNum: section*c.sectionSize - 1} + for { + err := c.odr.Retrieve(ctx, r) + switch err { + case nil: + r.Proof.Store(batch) + return batch.Write() + case ErrNoPeers: + // if there are no peers to serve, retry later + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(time.Second * 10): + // stay in the loop and try again + } + default: + return err + } + } +} + // Reset implements core.ChainIndexerBackend -func (c *ChtIndexerBackend) Reset(section uint64, lastSectionHead common.Hash) error { +func (c *ChtIndexerBackend) Reset(ctx context.Context, section uint64, lastSectionHead common.Hash) error { var root common.Hash if section > 0 { root = GetChtRoot(c.diskdb, section-1, lastSectionHead) } var err error c.trie, err = trie.New(root, c.triedb) + + if err != nil && c.odr != nil { + err = c.fetchMissingNodes(ctx, section, root) + if err == nil { + c.trie, err = trie.New(root, c.triedb) + } + } + c.section = section return err } // Process implements core.ChainIndexerBackend -func (c *ChtIndexerBackend) Process(header *types.Header) { +func (c *ChtIndexerBackend) Process(ctx context.Context, header *types.Header) error { hash, num := header.Hash(), header.Number.Uint64() c.lastHash = hash @@ -170,6 +209,7 @@ func (c *ChtIndexerBackend) Process(header *types.Header) { binary.BigEndian.PutUint64(encNumber[:], num) data, _ := rlp.EncodeToBytes(ChtNode{hash, td}) c.trie.Update(encNumber[:], data) + return nil } // Commit implements core.ChainIndexerBackend @@ -181,16 +221,15 @@ func (c *ChtIndexerBackend) Commit() error { c.triedb.Commit(root, false) if ((c.section+1)*c.sectionSize)%CHTFrequencyClient == 0 { - log.Info("Storing CHT", "section", c.section*c.sectionSize/CHTFrequencyClient, "head", c.lastHash, "root", root) + log.Info("Storing CHT", "section", c.section*c.sectionSize/CHTFrequencyClient, "head", fmt.Sprintf("%064x", c.lastHash), "root", fmt.Sprintf("%064x", root)) } StoreChtRoot(c.diskdb, c.section, c.lastHash, root) return nil } const ( - BloomTrieFrequency = 32768 - ethBloomBitsSection = 4096 - ethBloomBitsConfirmations = 256 + BloomTrieFrequency = 32768 + ethBloomBitsSection = 4096 ) var ( @@ -215,7 +254,8 @@ func StoreBloomTrieRoot(db ethdb.Database, sectionIdx uint64, sectionHead, root // BloomTrieIndexerBackend implements core.ChainIndexerBackend type BloomTrieIndexerBackend struct { - diskdb ethdb.Database + diskdb, trieTable ethdb.Database + odr OdrBackend triedb *trie.Database section, parentSectionSize, bloomTrieRatio uint64 trie *trie.Trie @@ -223,44 +263,98 @@ type BloomTrieIndexerBackend struct { } // NewBloomTrieIndexer creates a BloomTrie chain indexer -func NewBloomTrieIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer { +func NewBloomTrieIndexer(db ethdb.Database, clientMode bool, odr OdrBackend) *core.ChainIndexer { + trieTable := ethdb.NewTable(db, BloomTrieTablePrefix) backend := &BloomTrieIndexerBackend{ - diskdb: db, - triedb: trie.NewDatabase(ethdb.NewTable(db, BloomTrieTablePrefix)), + diskdb: db, + odr: odr, + trieTable: trieTable, + triedb: trie.NewDatabase(trieTable), } idb := ethdb.NewTable(db, "bltIndex-") - var confirmReq uint64 if clientMode { backend.parentSectionSize = BloomTrieFrequency - confirmReq = HelperTrieConfirmations } else { backend.parentSectionSize = ethBloomBitsSection - confirmReq = HelperTrieProcessConfirmations } backend.bloomTrieRatio = BloomTrieFrequency / backend.parentSectionSize backend.sectionHeads = make([]common.Hash, backend.bloomTrieRatio) - return core.NewChainIndexer(db, idb, backend, BloomTrieFrequency, confirmReq-ethBloomBitsConfirmations, time.Millisecond*100, "bloomtrie") + return core.NewChainIndexer(db, idb, backend, BloomTrieFrequency, 0, time.Millisecond*100, "bloomtrie") +} + +// fetchMissingNodes tries to retrieve the last entries of the latest trusted bloom trie from the +// ODR backend in order to be able to add new entries and calculate subsequent root hashes +func (b *BloomTrieIndexerBackend) fetchMissingNodes(ctx context.Context, section uint64, root common.Hash) error { + indexCh := make(chan uint, types.BloomBitLength) + type res struct { + nodes *NodeSet + err error + } + resCh := make(chan res, types.BloomBitLength) + for i := 0; i < 20; i++ { + go func() { + for bitIndex := range indexCh { + r := &BloomRequest{BloomTrieRoot: root, BloomTrieNum: section - 1, BitIdx: bitIndex, SectionIdxList: []uint64{section - 1}} + for { + if err := b.odr.Retrieve(ctx, r); err == ErrNoPeers { + // if there are no peers to serve, retry later + select { + case <-ctx.Done(): + resCh <- res{nil, ctx.Err()} + return + case <-time.After(time.Second * 10): + // stay in the loop and try again + } + } else { + resCh <- res{r.Proofs, err} + break + } + } + } + }() + } + + for i := uint(0); i < types.BloomBitLength; i++ { + indexCh <- i + } + close(indexCh) + batch := b.trieTable.NewBatch() + for i := uint(0); i < types.BloomBitLength; i++ { + res := <-resCh + if res.err != nil { + return res.err + } + res.nodes.Store(batch) + } + return batch.Write() } // Reset implements core.ChainIndexerBackend -func (b *BloomTrieIndexerBackend) Reset(section uint64, lastSectionHead common.Hash) error { +func (b *BloomTrieIndexerBackend) Reset(ctx context.Context, section uint64, lastSectionHead common.Hash) error { var root common.Hash if section > 0 { root = GetBloomTrieRoot(b.diskdb, section-1, lastSectionHead) } var err error b.trie, err = trie.New(root, b.triedb) + if err != nil && b.odr != nil { + err = b.fetchMissingNodes(ctx, section, root) + if err == nil { + b.trie, err = trie.New(root, b.triedb) + } + } b.section = section return err } // Process implements core.ChainIndexerBackend -func (b *BloomTrieIndexerBackend) Process(header *types.Header) { +func (b *BloomTrieIndexerBackend) Process(ctx context.Context, header *types.Header) error { num := header.Number.Uint64() - b.section*BloomTrieFrequency if (num+1)%b.parentSectionSize == 0 { b.sectionHeads[num/b.parentSectionSize] = header.Hash() } + return nil } // Commit implements core.ChainIndexerBackend @@ -300,7 +394,7 @@ func (b *BloomTrieIndexerBackend) Commit() error { b.triedb.Commit(root, false) sectionHead := b.sectionHeads[b.bloomTrieRatio-1] - log.Info("Storing bloom trie", "section", b.section, "head", sectionHead, "root", root, "compression", float64(compSize)/float64(decompSize)) + log.Info("Storing bloom trie", "section", b.section, "head", fmt.Sprintf("%064x", sectionHead), "root", fmt.Sprintf("%064x", root), "compression", float64(compSize)/float64(decompSize)) StoreBloomTrieRoot(b.diskdb, b.section, sectionHead, root) return nil diff --git a/miner/worker.go b/miner/worker.go index 81a63c29a..e7e279645 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -40,19 +40,27 @@ import ( const ( // resultQueueSize is the size of channel listening to sealing result. resultQueueSize = 10 + // txChanSize is the size of channel listening to NewTxsEvent. // The number is referenced from the size of tx pool. txChanSize = 4096 + // chainHeadChanSize is the size of channel listening to ChainHeadEvent. chainHeadChanSize = 10 + // chainSideChanSize is the size of channel listening to ChainSideEvent. chainSideChanSize = 10 - miningLogAtDepth = 5 + + // miningLogAtDepth is the number of confirmations before logging successful mining. + miningLogAtDepth = 5 + + // blockRecommitInterval is the time interval to recreate the mining block with + // any newly arrived transactions. + blockRecommitInterval = 3 * time.Second ) -// Env is the worker's current environment and holds all of the current state information. -type Env struct { - config *params.ChainConfig +// environment is the worker's current environment and holds all of the current state information. +type environment struct { signer types.Signer state *state.StateDB // apply state changes here @@ -67,105 +75,6 @@ type Env struct { receipts []*types.Receipt } -func (env *Env) commitTransaction(tx *types.Transaction, bc *core.BlockChain, coinbase common.Address, gp *core.GasPool) (error, []*types.Log) { - snap := env.state.Snapshot() - - receipt, _, err := core.ApplyTransaction(env.config, bc, &coinbase, gp, env.state, env.header, tx, &env.header.GasUsed, vm.Config{}) - if err != nil { - env.state.RevertToSnapshot(snap) - return err, nil - } - env.txs = append(env.txs, tx) - env.receipts = append(env.receipts, receipt) - - return nil, receipt.Logs -} - -func (env *Env) commitTransactions(mux *event.TypeMux, txs *types.TransactionsByPriceAndNonce, bc *core.BlockChain, coinbase common.Address) { - if env.gasPool == nil { - env.gasPool = new(core.GasPool).AddGas(env.header.GasLimit) - } - - var coalescedLogs []*types.Log - - for { - // If we don't have enough gas for any further transactions then we're done - if env.gasPool.Gas() < params.TxGas { - log.Trace("Not enough gas for further transactions", "have", env.gasPool, "want", params.TxGas) - break - } - // Retrieve the next transaction and abort if all done - tx := txs.Peek() - if tx == nil { - break - } - // Error may be ignored here. The error has already been checked - // during transaction acceptance is the transaction pool. - // - // We use the eip155 signer regardless of the current hf. - from, _ := types.Sender(env.signer, tx) - // Check whether the tx is replay protected. If we're not in the EIP155 hf - // phase, start ignoring the sender until we do. - if tx.Protected() && !env.config.IsEIP155(env.header.Number) { - log.Trace("Ignoring reply protected transaction", "hash", tx.Hash(), "eip155", env.config.EIP155Block) - - txs.Pop() - continue - } - // Start executing the transaction - env.state.Prepare(tx.Hash(), common.Hash{}, env.tcount) - - err, logs := env.commitTransaction(tx, bc, coinbase, env.gasPool) - switch err { - case core.ErrGasLimitReached: - // Pop the current out-of-gas transaction without shifting in the next from the account - log.Trace("Gas limit exceeded for current block", "sender", from) - txs.Pop() - - case core.ErrNonceTooLow: - // New head notification data race between the transaction pool and miner, shift - log.Trace("Skipping transaction with low nonce", "sender", from, "nonce", tx.Nonce()) - txs.Shift() - - case core.ErrNonceTooHigh: - // Reorg notification data race between the transaction pool and miner, skip account = - log.Trace("Skipping account with hight nonce", "sender", from, "nonce", tx.Nonce()) - txs.Pop() - - case nil: - // Everything ok, collect the logs and shift in the next transaction from the same account - coalescedLogs = append(coalescedLogs, logs...) - env.tcount++ - txs.Shift() - - default: - // Strange error, discard the transaction and get the next in line (note, the - // nonce-too-high clause will prevent us from executing in vain). - log.Debug("Transaction failed, account skipped", "hash", tx.Hash(), "err", err) - txs.Shift() - } - } - - if len(coalescedLogs) > 0 || env.tcount > 0 { - // make a copy, the state caches the logs and these logs get "upgraded" from pending to mined - // logs by filling in the block hash when the block was mined by the local miner. This can - // cause a race condition if a log was "upgraded" before the PendingLogsEvent is processed. - cpy := make([]*types.Log, len(coalescedLogs)) - for i, l := range coalescedLogs { - cpy[i] = new(types.Log) - *cpy[i] = *l - } - go func(logs []*types.Log, tcount int) { - if len(logs) > 0 { - mux.Post(core.PendingLogsEvent{Logs: logs}) - } - if tcount > 0 { - mux.Post(core.PendingStateEvent{}) - } - }(cpy, env.tcount) - } -} - // task contains all information for consensus engine sealing and result submitting. type task struct { receipts []*types.Receipt @@ -174,6 +83,17 @@ type task struct { createdAt time.Time } +const ( + commitInterruptNone int32 = iota + commitInterruptNewHead + commitInterruptResubmit +) + +type newWorkReq struct { + interrupt *int32 + noempty bool +} + // worker is the main object which takes care of submitting new work to consensus engine // and gathering the sealing result. type worker struct { @@ -192,12 +112,13 @@ type worker struct { chainSideSub event.Subscription // Channels - newWork chan struct{} - taskCh chan *task - resultCh chan *task - exitCh chan struct{} + newWorkCh chan *newWorkReq + taskCh chan *task + resultCh chan *task + startCh chan struct{} + exitCh chan struct{} - current *Env // An environment for current running cycle. + current *environment // An environment for current running cycle. possibleUncles map[common.Hash]*types.Block // A set of side blocks as the possible uncle blocks. unconfirmed *unconfirmedBlocks // A set of locally mined blocks pending canonicalness confirmations. @@ -213,8 +134,9 @@ type worker struct { running int32 // The indicator whether the consensus engine is running or not. // Test hooks - newTaskHook func(*task) // Method to call upon receiving a new sealing task - fullTaskInterval func() // Method to call before pushing the full sealing task + newTaskHook func(*task) // Method to call upon receiving a new sealing task + skipSealHook func(*task) bool // Method to decide whether skipping the sealing. + fullTaskHook func() // Method to call before pushing the full sealing task } func newWorker(config *params.ChainConfig, engine consensus.Engine, eth Backend, mux *event.TypeMux) *worker { @@ -229,10 +151,11 @@ func newWorker(config *params.ChainConfig, engine consensus.Engine, eth Backend, txsCh: make(chan core.NewTxsEvent, txChanSize), chainHeadCh: make(chan core.ChainHeadEvent, chainHeadChanSize), chainSideCh: make(chan core.ChainSideEvent, chainSideChanSize), - newWork: make(chan struct{}, 1), + newWorkCh: make(chan *newWorkReq), taskCh: make(chan *task), resultCh: make(chan *task, resultQueueSize), exitCh: make(chan struct{}), + startCh: make(chan struct{}, 1), } // Subscribe NewTxsEvent for tx pool worker.txsSub = eth.TxPool().SubscribeNewTxsEvent(worker.txsCh) @@ -241,11 +164,13 @@ func newWorker(config *params.ChainConfig, engine consensus.Engine, eth Backend, worker.chainSideSub = eth.BlockChain().SubscribeChainSideEvent(worker.chainSideCh) go worker.mainLoop() + go worker.newWorkLoop() go worker.resultLoop() go worker.taskLoop() // Submit first work to initialize pending state. - worker.newWork <- struct{}{} + worker.startCh <- struct{}{} + return worker } @@ -285,7 +210,7 @@ func (w *worker) pendingBlock() *types.Block { // start sets the running status as 1 and triggers new work submitting. func (w *worker) start() { atomic.StoreInt32(&w.running, 1) - w.newWork <- struct{}{} + w.startCh <- struct{}{} } // stop sets the running status as 0. @@ -312,6 +237,44 @@ func (w *worker) close() { } } +// newWorkLoop is a standalone goroutine to submit new mining work upon received events. +func (w *worker) newWorkLoop() { + var interrupt *int32 + + timer := time.NewTimer(0) + <-timer.C // discard the initial tick + + // recommit aborts in-flight transaction execution with given signal and resubmits a new one. + recommit := func(noempty bool, s int32) { + if interrupt != nil { + atomic.StoreInt32(interrupt, s) + } + interrupt = new(int32) + w.newWorkCh <- &newWorkReq{interrupt: interrupt, noempty: noempty} + timer.Reset(blockRecommitInterval) + } + + for { + select { + case <-w.startCh: + recommit(false, commitInterruptNewHead) + + case <-w.chainHeadCh: + recommit(false, commitInterruptNewHead) + + case <-timer.C: + // If mining is running resubmit a new work cycle periodically to pull in + // higher priced transactions. Disable this overhead for pending blocks. + if w.isRunning() && (w.config.Clique == nil || w.config.Clique.Period > 0) { + recommit(true, commitInterruptResubmit) + } + + case <-w.exitCh: + return + } + } +} + // mainLoop is a standalone goroutine to regenerate the sealing task based on the received event. func (w *worker) mainLoop() { defer w.txsSub.Unsubscribe() @@ -320,17 +283,36 @@ func (w *worker) mainLoop() { for { select { - case <-w.newWork: - // Submit a work when the worker is created or started. - w.commitNewWork() - - case <-w.chainHeadCh: - // Resubmit a work for new cycle once worker receives chain head event. - w.commitNewWork() + case req := <-w.newWorkCh: + w.commitNewWork(req.interrupt, req.noempty) case ev := <-w.chainSideCh: + if _, exist := w.possibleUncles[ev.Block.Hash()]; exist { + continue + } // Add side block to possible uncle block set. w.possibleUncles[ev.Block.Hash()] = ev.Block + // If our mining block contains less than 2 uncle blocks, + // add the new uncle block if valid and regenerate a mining block. + if w.isRunning() && w.current != nil && w.current.uncles.Cardinality() < 2 { + start := time.Now() + if err := w.commitUncle(w.current, ev.Block.Header()); err == nil { + var uncles []*types.Header + w.current.uncles.Each(func(item interface{}) bool { + hash, ok := item.(common.Hash) + if !ok { + return false + } + uncle, exist := w.possibleUncles[hash] + if !exist { + return false + } + uncles = append(uncles, uncle.Header()) + return true + }) + w.commit(uncles, nil, true, start) + } + } case ev := <-w.txsCh: // Apply transactions to the pending state if we're not mining. @@ -339,9 +321,9 @@ func (w *worker) mainLoop() { // already included in the current mining block. These transactions will // be automatically eliminated. if !w.isRunning() && w.current != nil { - w.mu.Lock() + w.mu.RLock() coinbase := w.coinbase - w.mu.Unlock() + w.mu.RUnlock() txs := make(map[common.Address]types.Transactions) for _, tx := range ev.Txs { @@ -349,12 +331,12 @@ func (w *worker) mainLoop() { txs[acc] = append(txs[acc], tx) } txset := types.NewTransactionsByPriceAndNonce(w.current.signer, txs) - w.current.commitTransactions(w.mux, txset, w.chain, coinbase) + w.commitTransactions(txset, coinbase, nil) w.updateSnapshot() } else { // If we're mining, but nothing is being processed, wake on new transactions if w.config.Clique != nil && w.config.Clique.Period == 0 { - w.commitNewWork() + w.commitNewWork(nil, false) } } @@ -378,6 +360,10 @@ func (w *worker) seal(t *task, stop <-chan struct{}) { res *task ) + if w.skipSealHook != nil && w.skipSealHook(t) { + return + } + if t.block, err = w.engine.Seal(w.chain, t.block, stop); t.block != nil { log.Info("Successfully sealed new block", "number", t.block.Number(), "hash", t.block.Hash(), "elapsed", common.PrettyDuration(time.Since(t.createdAt))) @@ -479,8 +465,7 @@ func (w *worker) makeCurrent(parent *types.Block, header *types.Header) error { if err != nil { return err } - env := &Env{ - config: w.config, + env := &environment{ signer: types.NewEIP155Signer(w.config.ChainID), state: state, ancestors: mapset.NewSet(), @@ -505,7 +490,7 @@ func (w *worker) makeCurrent(parent *types.Block, header *types.Header) error { } // commitUncle adds the given block to uncle block set, returns error if failed to add. -func (w *worker) commitUncle(env *Env, uncle *types.Header) error { +func (w *worker) commitUncle(env *environment, uncle *types.Header) error { hash := uncle.Hash() if env.uncles.Contains(hash) { return fmt.Errorf("uncle not unique") @@ -550,8 +535,120 @@ func (w *worker) updateSnapshot() { w.snapshotState = w.current.state.Copy() } +func (w *worker) commitTransaction(tx *types.Transaction, coinbase common.Address) ([]*types.Log, error) { + snap := w.current.state.Snapshot() + + receipt, _, err := core.ApplyTransaction(w.config, w.chain, &coinbase, w.current.gasPool, w.current.state, w.current.header, tx, &w.current.header.GasUsed, vm.Config{}) + if err != nil { + w.current.state.RevertToSnapshot(snap) + return nil, err + } + w.current.txs = append(w.current.txs, tx) + w.current.receipts = append(w.current.receipts, receipt) + + return receipt.Logs, nil +} + +func (w *worker) commitTransactions(txs *types.TransactionsByPriceAndNonce, coinbase common.Address, interrupt *int32) bool { + // Short circuit if current is nil + if w.current == nil { + return true + } + + if w.current.gasPool == nil { + w.current.gasPool = new(core.GasPool).AddGas(w.current.header.GasLimit) + } + + var coalescedLogs []*types.Log + + for { + // In the following three cases, we will interrupt the execution of the transaction. + // (1) new head block event arrival, the interrupt signal is 1 + // (2) worker start or restart, the interrupt signal is 1 + // (3) worker recreate the mining block with any newly arrived transactions, the interrupt signal is 2. + // For the first two cases, the semi-finished work will be discarded. + // For the third case, the semi-finished work will be submitted to the consensus engine. + // TODO(rjl493456442) give feedback to newWorkLoop to adjust resubmit interval if it is too short. + if interrupt != nil && atomic.LoadInt32(interrupt) != commitInterruptNone { + return atomic.LoadInt32(interrupt) == commitInterruptNewHead + } + // If we don't have enough gas for any further transactions then we're done + if w.current.gasPool.Gas() < params.TxGas { + log.Trace("Not enough gas for further transactions", "have", w.current.gasPool, "want", params.TxGas) + break + } + // Retrieve the next transaction and abort if all done + tx := txs.Peek() + if tx == nil { + break + } + // Error may be ignored here. The error has already been checked + // during transaction acceptance is the transaction pool. + // + // We use the eip155 signer regardless of the current hf. + from, _ := types.Sender(w.current.signer, tx) + // Check whether the tx is replay protected. If we're not in the EIP155 hf + // phase, start ignoring the sender until we do. + if tx.Protected() && !w.config.IsEIP155(w.current.header.Number) { + log.Trace("Ignoring reply protected transaction", "hash", tx.Hash(), "eip155", w.config.EIP155Block) + + txs.Pop() + continue + } + // Start executing the transaction + w.current.state.Prepare(tx.Hash(), common.Hash{}, w.current.tcount) + + logs, err := w.commitTransaction(tx, coinbase) + switch err { + case core.ErrGasLimitReached: + // Pop the current out-of-gas transaction without shifting in the next from the account + log.Trace("Gas limit exceeded for current block", "sender", from) + txs.Pop() + + case core.ErrNonceTooLow: + // New head notification data race between the transaction pool and miner, shift + log.Trace("Skipping transaction with low nonce", "sender", from, "nonce", tx.Nonce()) + txs.Shift() + + case core.ErrNonceTooHigh: + // Reorg notification data race between the transaction pool and miner, skip account = + log.Trace("Skipping account with hight nonce", "sender", from, "nonce", tx.Nonce()) + txs.Pop() + + case nil: + // Everything ok, collect the logs and shift in the next transaction from the same account + coalescedLogs = append(coalescedLogs, logs...) + w.current.tcount++ + txs.Shift() + + default: + // Strange error, discard the transaction and get the next in line (note, the + // nonce-too-high clause will prevent us from executing in vain). + log.Debug("Transaction failed, account skipped", "hash", tx.Hash(), "err", err) + txs.Shift() + } + } + + if !w.isRunning() && len(coalescedLogs) > 0 { + // We don't push the pendingLogsEvent while we are mining. The reason is that + // when we are mining, the worker will regenerate a mining block every 3 seconds. + // In order to avoid pushing the repeated pendingLog, we disable the pending log pushing. + + // make a copy, the state caches the logs and these logs get "upgraded" from pending to mined + // logs by filling in the block hash when the block was mined by the local miner. This can + // cause a race condition if a log was "upgraded" before the PendingLogsEvent is processed. + cpy := make([]*types.Log, len(coalescedLogs)) + for i, l := range coalescedLogs { + cpy[i] = new(types.Log) + *cpy[i] = *l + } + go w.mux.Post(core.PendingLogsEvent{Logs: cpy}) + } + return false +} + // commitNewWork generates several new sealing tasks based on the parent block. -func (w *worker) commitNewWork() { +func (w *worker) commitNewWork(interrupt *int32, noempty bool) { w.mu.RLock() defer w.mu.RUnlock() @@ -637,29 +734,10 @@ func (w *worker) commitNewWork() { delete(w.possibleUncles, hash) } - var ( - emptyBlock, fullBlock *types.Block - emptyState, fullState *state.StateDB - ) - - // Create an empty block based on temporary copied state for sealing in advance without waiting block - // execution finished. - emptyState = env.state.Copy() - if emptyBlock, err = w.engine.Finalize(w.chain, header, emptyState, nil, uncles, nil); err != nil { - log.Error("Failed to finalize block for temporary sealing", "err", err) - } else { - // Push empty work in advance without applying pending transaction. - // The reason is transactions execution can cost a lot and sealer need to - // take advantage of this part time. - if w.isRunning() { - select { - case w.taskCh <- &task{receipts: nil, state: emptyState, block: emptyBlock, createdAt: time.Now()}: - log.Info("Commit new empty mining work", "number", emptyBlock.Number(), "uncles", len(uncles)) - case <-w.exitCh: - log.Info("Worker has exited") - return - } - } + if !noempty { + // Create an empty block based on temporary copied state for sealing in advance without waiting block + // execution finished. + w.commit(uncles, nil, false, tstart) } // Fill the block with all available pending transactions. @@ -674,33 +752,50 @@ func (w *worker) commitNewWork() { return } txs := types.NewTransactionsByPriceAndNonce(w.current.signer, pending) - env.commitTransactions(w.mux, txs, w.chain, w.coinbase) - - // Create the full block to seal with the consensus engine - fullState = env.state.Copy() - if fullBlock, err = w.engine.Finalize(w.chain, header, fullState, env.txs, uncles, env.receipts); err != nil { - log.Error("Failed to finalize block for sealing", "err", err) + if w.commitTransactions(txs, w.coinbase, interrupt) { return } + + w.commit(uncles, w.fullTaskHook, true, tstart) +} + +// commit runs any post-transaction state modifications, assembles the final block +// and commits new work if consensus engine is running. +func (w *worker) commit(uncles []*types.Header, interval func(), update bool, start time.Time) error { // Deep copy receipts here to avoid interaction between different tasks. - cpy := make([]*types.Receipt, len(env.receipts)) - for i, l := range env.receipts { - cpy[i] = new(types.Receipt) - *cpy[i] = *l + receipts := make([]*types.Receipt, len(w.current.receipts)) + for i, l := range w.current.receipts { + receipts[i] = new(types.Receipt) + *receipts[i] = *l + } + s := w.current.state.Copy() + block, err := w.engine.Finalize(w.chain, w.current.header, s, w.current.txs, uncles, w.current.receipts) + if err != nil { + return err } - // We only care about logging if we're actually mining. if w.isRunning() { - if w.fullTaskInterval != nil { - w.fullTaskInterval() + if interval != nil { + interval() } - select { - case w.taskCh <- &task{receipts: cpy, state: fullState, block: fullBlock, createdAt: time.Now()}: - w.unconfirmed.Shift(fullBlock.NumberU64() - 1) - log.Info("Commit new full mining work", "number", fullBlock.Number(), "txs", env.tcount, "uncles", len(uncles), "elapsed", common.PrettyDuration(time.Since(tstart))) + case w.taskCh <- &task{receipts: receipts, state: s, block: block, createdAt: time.Now()}: + w.unconfirmed.Shift(block.NumberU64() - 1) + + feesWei := new(big.Int) + for _, tx := range block.Transactions() { + feesWei.Add(feesWei, new(big.Int).Mul(new(big.Int).SetUint64(tx.Gas()), tx.GasPrice())) + } + feesEth := new(big.Float).Quo(new(big.Float).SetInt(feesWei), new(big.Float).SetInt(big.NewInt(params.Ether))) + + log.Info("Commit new mining work", "number", block.Number(), "uncles", len(uncles), "txs", w.current.tcount, + "gas", block.GasUsed(), "fees", feesEth, "elapsed", common.PrettyDuration(time.Since(start))) + case <-w.exitCh: log.Info("Worker has exited") } } - w.updateSnapshot() + if update { + w.updateSnapshot() + } + return nil } diff --git a/miner/worker_test.go b/miner/worker_test.go index 5823a608e..34bb7f5f3 100644 --- a/miner/worker_test.go +++ b/miner/worker_test.go @@ -59,7 +59,7 @@ func init() { ethashChainConfig = params.TestChainConfig cliqueChainConfig = params.TestChainConfig cliqueChainConfig.Clique = ¶ms.CliqueConfig{ - Period: 1, + Period: 10, Epoch: 30000, } tx1, _ := types.SignTx(types.NewTransaction(0, acc1Addr, big.NewInt(1000), params.TxGas, nil, nil), types.HomesteadSigner{}, testBankKey) @@ -74,6 +74,7 @@ type testWorkerBackend struct { txPool *core.TxPool chain *core.BlockChain testTxFeed event.Feed + uncleBlock *types.Block } func newTestWorkerBackend(t *testing.T, chainConfig *params.ChainConfig, engine consensus.Engine) *testWorkerBackend { @@ -93,15 +94,19 @@ func newTestWorkerBackend(t *testing.T, chainConfig *params.ChainConfig, engine default: t.Fatal("unexpect consensus engine type") } - gspec.MustCommit(db) + genesis := gspec.MustCommit(db) chain, _ := core.NewBlockChain(db, nil, gspec.Config, engine, vm.Config{}) txpool := core.NewTxPool(testTxPoolConfig, chainConfig, chain) + blocks, _ := core.GenerateChain(chainConfig, genesis, engine, db, 1, func(i int, gen *core.BlockGen) { + gen.SetCoinbase(acc1Addr) + }) return &testWorkerBackend{ - db: db, - chain: chain, - txPool: txpool, + db: db, + chain: chain, + txPool: txpool, + uncleBlock: blocks[0], } } @@ -188,7 +193,7 @@ func testEmptyWork(t *testing.T, chainConfig *params.ChainConfig, engine consens taskCh <- struct{}{} } } - w.fullTaskInterval = func() { + w.fullTaskHook = func() { time.Sleep(100 * time.Millisecond) } @@ -202,11 +207,131 @@ func testEmptyWork(t *testing.T, chainConfig *params.ChainConfig, engine consens w.start() for i := 0; i < 2; i += 1 { - to := time.NewTimer(time.Second) select { case <-taskCh: - case <-to.C: + case <-time.NewTimer(time.Second).C: t.Error("new task timeout") } } } + +func TestStreamUncleBlock(t *testing.T) { + ethash := ethash.NewFaker() + defer ethash.Close() + + w, b := newTestWorker(t, ethashChainConfig, ethash) + defer w.close() + + var taskCh = make(chan struct{}) + + taskIndex := 0 + w.newTaskHook = func(task *task) { + if task.block.NumberU64() == 1 { + if taskIndex == 2 { + has := task.block.Header().UncleHash + want := types.CalcUncleHash([]*types.Header{b.uncleBlock.Header()}) + if has != want { + t.Errorf("uncle hash mismatch, has %s, want %s", has.Hex(), want.Hex()) + } + } + taskCh <- struct{}{} + taskIndex += 1 + } + } + w.skipSealHook = func(task *task) bool { + return true + } + w.fullTaskHook = func() { + time.Sleep(100 * time.Millisecond) + } + + // Ensure worker has finished initialization + for { + b := w.pendingBlock() + if b != nil && b.NumberU64() == 1 { + break + } + } + + w.start() + // Ignore the first two works + for i := 0; i < 2; i += 1 { + select { + case <-taskCh: + case <-time.NewTimer(time.Second).C: + t.Error("new task timeout") + } + } + b.PostChainEvents([]interface{}{core.ChainSideEvent{Block: b.uncleBlock}}) + + select { + case <-taskCh: + case <-time.NewTimer(time.Second).C: + t.Error("new task timeout") + } +} + +func TestRegenerateMiningBlockEthash(t *testing.T) { + testRegenerateMiningBlock(t, ethashChainConfig, ethash.NewFaker()) +} + +func TestRegenerateMiningBlockClique(t *testing.T) { + testRegenerateMiningBlock(t, cliqueChainConfig, clique.New(cliqueChainConfig.Clique, ethdb.NewMemDatabase())) +} + +func testRegenerateMiningBlock(t *testing.T, chainConfig *params.ChainConfig, engine consensus.Engine) { + defer engine.Close() + + w, b := newTestWorker(t, chainConfig, engine) + defer w.close() + + var taskCh = make(chan struct{}) + + taskIndex := 0 + w.newTaskHook = func(task *task) { + if task.block.NumberU64() == 1 { + if taskIndex == 2 { + receiptLen, balance := 2, big.NewInt(2000) + if len(task.receipts) != receiptLen { + t.Errorf("receipt number mismatch has %d, want %d", len(task.receipts), receiptLen) + } + if task.state.GetBalance(acc1Addr).Cmp(balance) != 0 { + t.Errorf("account balance mismatch has %d, want %d", task.state.GetBalance(acc1Addr), balance) + } + } + taskCh <- struct{}{} + taskIndex += 1 + } + } + w.skipSealHook = func(task *task) bool { + return true + } + w.fullTaskHook = func() { + time.Sleep(100 * time.Millisecond) + } + // Ensure worker has finished initialization + for { + b := w.pendingBlock() + if b != nil && b.NumberU64() == 1 { + break + } + } + + w.start() + // Ignore the first two works + for i := 0; i < 2; i += 1 { + select { + case <-taskCh: + case <-time.NewTimer(time.Second).C: + t.Error("new task timeout") + } + } + b.txPool.AddLocals(newTxs) + time.Sleep(3 * time.Second) + + select { + case <-taskCh: + case <-time.NewTimer(time.Second).C: + t.Error("new task timeout") + } +} diff --git a/swarm/api/act.go b/swarm/api/act.go new file mode 100644 index 000000000..b1a594783 --- /dev/null +++ b/swarm/api/act.go @@ -0,0 +1,468 @@ +package api + +import ( + "context" + "crypto/ecdsa" + "crypto/rand" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "strings" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/crypto/ecies" + "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/swarm/log" + "github.com/ethereum/go-ethereum/swarm/sctx" + "github.com/ethereum/go-ethereum/swarm/storage" + "golang.org/x/crypto/scrypt" + cli "gopkg.in/urfave/cli.v1" +) + +var ( + ErrDecrypt = errors.New("cant decrypt - forbidden") + ErrUnknownAccessType = errors.New("unknown access type (or not implemented)") + ErrDecryptDomainForbidden = errors.New("decryption request domain forbidden - can only decrypt on localhost") + AllowedDecryptDomains = []string{ + "localhost", + "127.0.0.1", + } +) + +const EMPTY_CREDENTIALS = "" + +type AccessEntry struct { + Type AccessType + Publisher string + Salt []byte + Act string + KdfParams *KdfParams +} + +type DecryptFunc func(*ManifestEntry) error + +func (a *AccessEntry) MarshalJSON() (out []byte, err error) { + + return json.Marshal(struct { + Type AccessType `json:"type,omitempty"` + Publisher string `json:"publisher,omitempty"` + Salt string `json:"salt,omitempty"` + Act string `json:"act,omitempty"` + KdfParams *KdfParams `json:"kdf_params,omitempty"` + }{ + Type: a.Type, + Publisher: a.Publisher, + Salt: hex.EncodeToString(a.Salt), + Act: a.Act, + KdfParams: a.KdfParams, + }) + +} + +func (a *AccessEntry) UnmarshalJSON(value []byte) error { + v := struct { + Type AccessType `json:"type,omitempty"` + Publisher string `json:"publisher,omitempty"` + Salt string `json:"salt,omitempty"` + Act string `json:"act,omitempty"` + KdfParams *KdfParams `json:"kdf_params,omitempty"` + }{} + + err := json.Unmarshal(value, &v) + if err != nil { + return err + } + a.Act = v.Act + a.KdfParams = v.KdfParams + a.Publisher = v.Publisher + a.Salt, err = hex.DecodeString(v.Salt) + if err != nil { + return err + } + if len(a.Salt) != 32 { + return errors.New("salt should be 32 bytes long") + } + a.Type = v.Type + return nil +} + +type KdfParams struct { + N int `json:"n"` + P int `json:"p"` + R int `json:"r"` +} + +type AccessType string + +const AccessTypePass = AccessType("pass") +const AccessTypePK = AccessType("pk") +const AccessTypeACT = AccessType("act") + +func NewAccessEntryPassword(salt []byte, kdfParams *KdfParams) (*AccessEntry, error) { + if len(salt) != 32 { + return nil, fmt.Errorf("salt should be 32 bytes long") + } + return &AccessEntry{ + Type: AccessTypePass, + Salt: salt, + KdfParams: kdfParams, + }, nil +} + +func NewAccessEntryPK(publisher string, salt []byte) (*AccessEntry, error) { + if len(publisher) != 66 { + return nil, fmt.Errorf("publisher should be 66 characters long, got %d", len(publisher)) + } + if len(salt) != 32 { + return nil, fmt.Errorf("salt should be 32 bytes long") + } + return &AccessEntry{ + Type: AccessTypePK, + Publisher: publisher, + Salt: salt, + }, nil +} + +func NewAccessEntryACT(publisher string, salt []byte, act string) (*AccessEntry, error) { + if len(salt) != 32 { + return nil, fmt.Errorf("salt should be 32 bytes long") + } + if len(publisher) != 66 { + return nil, fmt.Errorf("publisher should be 66 characters long") + } + + return &AccessEntry{ + Type: AccessTypeACT, + Publisher: publisher, + Salt: salt, + Act: act, + }, nil +} + +func NOOPDecrypt(*ManifestEntry) error { + return nil +} + +var DefaultKdfParams = NewKdfParams(262144, 1, 8) + +func NewKdfParams(n, p, r int) *KdfParams { + + return &KdfParams{ + N: n, + P: p, + R: r, + } +} + +// NewSessionKeyPassword creates a session key based on a shared secret (password) and the given salt +// and kdf parameters in the access entry +func NewSessionKeyPassword(password string, accessEntry *AccessEntry) ([]byte, error) { + if accessEntry.Type != AccessTypePass { + return nil, errors.New("incorrect access entry type") + } + return scrypt.Key( + []byte(password), + accessEntry.Salt, + accessEntry.KdfParams.N, + accessEntry.KdfParams.R, + accessEntry.KdfParams.P, + 32, + ) +} + +// NewSessionKeyPK creates a new ACT Session Key using an ECDH shared secret for the given key pair and the given salt value +func NewSessionKeyPK(private *ecdsa.PrivateKey, public *ecdsa.PublicKey, salt []byte) ([]byte, error) { + granteePubEcies := ecies.ImportECDSAPublic(public) + privateKey := ecies.ImportECDSA(private) + + bytes, err := privateKey.GenerateShared(granteePubEcies, 16, 16) + if err != nil { + return nil, err + } + bytes = append(salt, bytes...) + sessionKey := crypto.Keccak256(bytes) + return sessionKey, nil +} + +func (a *API) NodeSessionKey(privateKey *ecdsa.PrivateKey, publicKey *ecdsa.PublicKey, salt []byte) ([]byte, error) { + return NewSessionKeyPK(privateKey, publicKey, salt) +} +func (a *API) doDecrypt(ctx context.Context, credentials string, pk *ecdsa.PrivateKey) DecryptFunc { + return func(m *ManifestEntry) error { + if m.Access == nil { + return nil + } + + allowed := false + requestDomain := sctx.GetHost(ctx) + for _, v := range AllowedDecryptDomains { + if strings.Contains(requestDomain, v) { + allowed = true + } + } + + if !allowed { + return ErrDecryptDomainForbidden + } + + switch m.Access.Type { + case "pass": + if credentials != "" { + key, err := NewSessionKeyPassword(credentials, m.Access) + if err != nil { + return err + } + + ref, err := hex.DecodeString(m.Hash) + if err != nil { + return err + } + + enc := NewRefEncryption(len(ref) - 8) + decodedRef, err := enc.Decrypt(ref, key) + if err != nil { + return ErrDecrypt + } + + m.Hash = hex.EncodeToString(decodedRef) + m.Access = nil + return nil + } + return ErrDecrypt + case "pk": + publisherBytes, err := hex.DecodeString(m.Access.Publisher) + if err != nil { + return ErrDecrypt + } + publisher, err := crypto.DecompressPubkey(publisherBytes) + if err != nil { + return ErrDecrypt + } + key, err := a.NodeSessionKey(pk, publisher, m.Access.Salt) + if err != nil { + return ErrDecrypt + } + ref, err := hex.DecodeString(m.Hash) + if err != nil { + return err + } + + enc := NewRefEncryption(len(ref) - 8) + decodedRef, err := enc.Decrypt(ref, key) + if err != nil { + return ErrDecrypt + } + + m.Hash = hex.EncodeToString(decodedRef) + m.Access = nil + return nil + case "act": + publisherBytes, err := hex.DecodeString(m.Access.Publisher) + if err != nil { + return ErrDecrypt + } + publisher, err := crypto.DecompressPubkey(publisherBytes) + if err != nil { + return ErrDecrypt + } + + sessionKey, err := a.NodeSessionKey(pk, publisher, m.Access.Salt) + if err != nil { + return ErrDecrypt + } + + hasher := sha3.NewKeccak256() + hasher.Write(append(sessionKey, 0)) + lookupKey := hasher.Sum(nil) + + hasher.Reset() + + hasher.Write(append(sessionKey, 1)) + accessKeyDecryptionKey := hasher.Sum(nil) + + lk := hex.EncodeToString(lookupKey) + list, err := a.GetManifestList(ctx, NOOPDecrypt, storage.Address(common.Hex2Bytes(m.Access.Act)), lk) + + found := "" + for _, v := range list.Entries { + if v.Path == lk { + found = v.Hash + } + } + + if found == "" { + return ErrDecrypt + } + + v, err := hex.DecodeString(found) + if err != nil { + return err + } + enc := NewRefEncryption(len(v) - 8) + decodedRef, err := enc.Decrypt(v, accessKeyDecryptionKey) + if err != nil { + return ErrDecrypt + } + + ref, err := hex.DecodeString(m.Hash) + if err != nil { + return err + } + + enc = NewRefEncryption(len(ref) - 8) + decodedMainRef, err := enc.Decrypt(ref, decodedRef) + if err != nil { + return ErrDecrypt + } + m.Hash = hex.EncodeToString(decodedMainRef) + m.Access = nil + return nil + } + return ErrUnknownAccessType + } +} + +func GenerateAccessControlManifest(ctx *cli.Context, ref string, accessKey []byte, ae *AccessEntry) (*Manifest, error) { + refBytes, err := hex.DecodeString(ref) + if err != nil { + return nil, err + } + // encrypt ref with accessKey + enc := NewRefEncryption(len(refBytes)) + encrypted, err := enc.Encrypt(refBytes, accessKey) + if err != nil { + return nil, err + } + + m := &Manifest{ + Entries: []ManifestEntry{ + { + Hash: hex.EncodeToString(encrypted), + ContentType: ManifestType, + ModTime: time.Now(), + Access: ae, + }, + }, + } + + return m, nil +} + +func DoPKNew(ctx *cli.Context, privateKey *ecdsa.PrivateKey, granteePublicKey string, salt []byte) (sessionKey []byte, ae *AccessEntry, err error) { + if granteePublicKey == "" { + return nil, nil, errors.New("need a grantee Public Key") + } + b, err := hex.DecodeString(granteePublicKey) + if err != nil { + log.Error("error decoding grantee public key", "err", err) + return nil, nil, err + } + + granteePub, err := crypto.DecompressPubkey(b) + if err != nil { + log.Error("error decompressing grantee public key", "err", err) + return nil, nil, err + } + + sessionKey, err = NewSessionKeyPK(privateKey, granteePub, salt) + if err != nil { + log.Error("error getting session key", "err", err) + return nil, nil, err + } + + ae, err = NewAccessEntryPK(hex.EncodeToString(crypto.CompressPubkey(&privateKey.PublicKey)), salt) + if err != nil { + log.Error("error generating access entry", "err", err) + return nil, nil, err + } + + return sessionKey, ae, nil +} + +func DoACTNew(ctx *cli.Context, privateKey *ecdsa.PrivateKey, salt []byte, grantees []string) (accessKey []byte, ae *AccessEntry, actManifest *Manifest, err error) { + if len(grantees) == 0 { + return nil, nil, nil, errors.New("did not get any grantee public keys") + } + + publisherPub := hex.EncodeToString(crypto.CompressPubkey(&privateKey.PublicKey)) + grantees = append(grantees, publisherPub) + + accessKey = make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, salt); err != nil { + panic("reading from crypto/rand failed: " + err.Error()) + } + if _, err := io.ReadFull(rand.Reader, accessKey); err != nil { + panic("reading from crypto/rand failed: " + err.Error()) + } + + lookupPathEncryptedAccessKeyMap := make(map[string]string) + i := 0 + for _, v := range grantees { + i++ + if v == "" { + return nil, nil, nil, errors.New("need a grantee Public Key") + } + b, err := hex.DecodeString(v) + if err != nil { + log.Error("error decoding grantee public key", "err", err) + return nil, nil, nil, err + } + + granteePub, err := crypto.DecompressPubkey(b) + if err != nil { + log.Error("error decompressing grantee public key", "err", err) + return nil, nil, nil, err + } + sessionKey, err := NewSessionKeyPK(privateKey, granteePub, salt) + + hasher := sha3.NewKeccak256() + hasher.Write(append(sessionKey, 0)) + lookupKey := hasher.Sum(nil) + + hasher.Reset() + hasher.Write(append(sessionKey, 1)) + + accessKeyEncryptionKey := hasher.Sum(nil) + + enc := NewRefEncryption(len(accessKey)) + encryptedAccessKey, err := enc.Encrypt(accessKey, accessKeyEncryptionKey) + + lookupPathEncryptedAccessKeyMap[hex.EncodeToString(lookupKey)] = hex.EncodeToString(encryptedAccessKey) + } + + m := &Manifest{ + Entries: []ManifestEntry{}, + } + + for k, v := range lookupPathEncryptedAccessKeyMap { + m.Entries = append(m.Entries, ManifestEntry{ + Path: k, + Hash: v, + ContentType: "text/plain", + }) + } + + ae, err = NewAccessEntryACT(hex.EncodeToString(crypto.CompressPubkey(&privateKey.PublicKey)), salt, "") + if err != nil { + return nil, nil, nil, err + } + + return accessKey, ae, m, nil +} + +func DoPasswordNew(ctx *cli.Context, password string, salt []byte) (sessionKey []byte, ae *AccessEntry, err error) { + ae, err = NewAccessEntryPassword(salt, DefaultKdfParams) + if err != nil { + return nil, nil, err + } + + sessionKey, err = NewSessionKeyPassword(password, ae) + if err != nil { + return nil, nil, err + } + return sessionKey, ae, nil +} diff --git a/swarm/api/api.go b/swarm/api/api.go index 99d971b10..adf469cfa 100644 --- a/swarm/api/api.go +++ b/swarm/api/api.go @@ -19,6 +19,9 @@ package api import ( "archive/tar" "context" + "crypto/ecdsa" + "encoding/hex" + "errors" "fmt" "io" "math/big" @@ -44,6 +47,10 @@ import ( ) var ( + ErrNotFound = errors.New("not found") +) + +var ( apiResolveCount = metrics.NewRegisteredCounter("api.resolve.count", nil) apiResolveFail = metrics.NewRegisteredCounter("api.resolve.fail", nil) apiPutCount = metrics.NewRegisteredCounter("api.put.count", nil) @@ -227,14 +234,18 @@ type API struct { resource *mru.Handler fileStore *storage.FileStore dns Resolver + Decryptor func(context.Context, string) DecryptFunc } // NewAPI the api constructor initialises a new API instance. -func NewAPI(fileStore *storage.FileStore, dns Resolver, resourceHandler *mru.Handler) (self *API) { +func NewAPI(fileStore *storage.FileStore, dns Resolver, resourceHandler *mru.Handler, pk *ecdsa.PrivateKey) (self *API) { self = &API{ fileStore: fileStore, dns: dns, resource: resourceHandler, + Decryptor: func(ctx context.Context, credentials string) DecryptFunc { + return self.doDecrypt(ctx, credentials, pk) + }, } return } @@ -260,8 +271,30 @@ func (a *API) Store(ctx context.Context, data io.Reader, size int64, toEncrypt b // ErrResolve is returned when an URI cannot be resolved from ENS. type ErrResolve error +// Resolve a name into a content-addressed hash +// where address could be an ENS name, or a content addressed hash +func (a *API) Resolve(ctx context.Context, address string) (storage.Address, error) { + // if DNS is not configured, return an error + if a.dns == nil { + if hashMatcher.MatchString(address) { + return common.Hex2Bytes(address), nil + } + apiResolveFail.Inc(1) + return nil, fmt.Errorf("no DNS to resolve name: %q", address) + } + // try and resolve the address + resolved, err := a.dns.Resolve(address) + if err != nil { + if hashMatcher.MatchString(address) { + return common.Hex2Bytes(address), nil + } + return nil, err + } + return resolved[:], nil +} + // Resolve resolves a URI to an Address using the MultiResolver. -func (a *API) Resolve(ctx context.Context, uri *URI) (storage.Address, error) { +func (a *API) ResolveURI(ctx context.Context, uri *URI, credentials string) (storage.Address, error) { apiResolveCount.Inc(1) log.Trace("resolving", "uri", uri.Addr) @@ -280,28 +313,44 @@ func (a *API) Resolve(ctx context.Context, uri *URI) (storage.Address, error) { return key, nil } - // if DNS is not configured, check if the address is a hash - if a.dns == nil { - key := uri.Address() - if key == nil { - apiResolveFail.Inc(1) - return nil, fmt.Errorf("no DNS to resolve name: %q", uri.Addr) - } - return key, nil + addr, err := a.Resolve(ctx, uri.Addr) + if err != nil { + return nil, err } - // try and resolve the address - resolved, err := a.dns.Resolve(uri.Addr) - if err == nil { - return resolved[:], nil + if uri.Path == "" { + return addr, nil } - - key := uri.Address() - if key == nil { - apiResolveFail.Inc(1) + walker, err := a.NewManifestWalker(ctx, addr, a.Decryptor(ctx, credentials), nil) + if err != nil { return nil, err } - return key, nil + var entry *ManifestEntry + walker.Walk(func(e *ManifestEntry) error { + // if the entry matches the path, set entry and stop + // the walk + if e.Path == uri.Path { + entry = e + // return an error to cancel the walk + return errors.New("found") + } + // ignore non-manifest files + if e.ContentType != ManifestType { + return nil + } + // if the manifest's path is a prefix of the + // requested path, recurse into it by returning + // nil and continuing the walk + if strings.HasPrefix(uri.Path, e.Path) { + return nil + } + return ErrSkipManifest + }) + if entry == nil { + return nil, errors.New("not found") + } + addr = storage.Address(common.Hex2Bytes(entry.Hash)) + return addr, nil } // Put provides singleton manifest creation on top of FileStore store @@ -332,10 +381,10 @@ func (a *API) Put(ctx context.Context, content string, contentType string, toEnc // Get uses iterative manifest retrieval and prefix matching // to resolve basePath to content using FileStore retrieve // it returns a section reader, mimeType, status, the key of the actual content and an error -func (a *API) Get(ctx context.Context, manifestAddr storage.Address, path string) (reader storage.LazySectionReader, mimeType string, status int, contentAddr storage.Address, err error) { +func (a *API) Get(ctx context.Context, decrypt DecryptFunc, manifestAddr storage.Address, path string) (reader storage.LazySectionReader, mimeType string, status int, contentAddr storage.Address, err error) { log.Debug("api.get", "key", manifestAddr, "path", path) apiGetCount.Inc(1) - trie, err := loadManifest(ctx, a.fileStore, manifestAddr, nil) + trie, err := loadManifest(ctx, a.fileStore, manifestAddr, nil, decrypt) if err != nil { apiGetNotFound.Inc(1) status = http.StatusNotFound @@ -347,6 +396,16 @@ func (a *API) Get(ctx context.Context, manifestAddr storage.Address, path string if entry != nil { log.Debug("trie got entry", "key", manifestAddr, "path", path, "entry.Hash", entry.Hash) + + if entry.ContentType == ManifestType { + log.Debug("entry is manifest", "key", manifestAddr, "new key", entry.Hash) + adr, err := hex.DecodeString(entry.Hash) + if err != nil { + return nil, "", 0, nil, err + } + return a.Get(ctx, decrypt, adr, entry.Path) + } + // we need to do some extra work if this is a mutable resource manifest if entry.ContentType == ResourceContentType { @@ -398,7 +457,7 @@ func (a *API) Get(ctx context.Context, manifestAddr storage.Address, path string log.Trace("resource is multihash", "key", manifestAddr) // get the manifest the multihash digest points to - trie, err := loadManifest(ctx, a.fileStore, manifestAddr, nil) + trie, err := loadManifest(ctx, a.fileStore, manifestAddr, nil, decrypt) if err != nil { apiGetNotFound.Inc(1) status = http.StatusNotFound @@ -451,7 +510,7 @@ func (a *API) Delete(ctx context.Context, addr string, path string) (storage.Add apiDeleteFail.Inc(1) return nil, err } - key, err := a.Resolve(ctx, uri) + key, err := a.ResolveURI(ctx, uri, EMPTY_CREDENTIALS) if err != nil { return nil, err @@ -470,13 +529,13 @@ func (a *API) Delete(ctx context.Context, addr string, path string) (storage.Add // GetDirectoryTar fetches a requested directory as a tarstream // it returns an io.Reader and an error. Do not forget to Close() the returned ReadCloser -func (a *API) GetDirectoryTar(ctx context.Context, uri *URI) (io.ReadCloser, error) { +func (a *API) GetDirectoryTar(ctx context.Context, decrypt DecryptFunc, uri *URI) (io.ReadCloser, error) { apiGetTarCount.Inc(1) - addr, err := a.Resolve(ctx, uri) + addr, err := a.Resolve(ctx, uri.Addr) if err != nil { return nil, err } - walker, err := a.NewManifestWalker(ctx, addr, nil) + walker, err := a.NewManifestWalker(ctx, addr, decrypt, nil) if err != nil { apiGetTarFail.Inc(1) return nil, err @@ -542,9 +601,9 @@ func (a *API) GetDirectoryTar(ctx context.Context, uri *URI) (io.ReadCloser, err // GetManifestList lists the manifest entries for the specified address and prefix // and returns it as a ManifestList -func (a *API) GetManifestList(ctx context.Context, addr storage.Address, prefix string) (list ManifestList, err error) { +func (a *API) GetManifestList(ctx context.Context, decryptor DecryptFunc, addr storage.Address, prefix string) (list ManifestList, err error) { apiManifestListCount.Inc(1) - walker, err := a.NewManifestWalker(ctx, addr, nil) + walker, err := a.NewManifestWalker(ctx, addr, decryptor, nil) if err != nil { apiManifestListFail.Inc(1) return ManifestList{}, err @@ -631,7 +690,7 @@ func (a *API) UpdateManifest(ctx context.Context, addr storage.Address, update f func (a *API) Modify(ctx context.Context, addr storage.Address, path, contentHash, contentType string) (storage.Address, error) { apiModifyCount.Inc(1) quitC := make(chan bool) - trie, err := loadManifest(ctx, a.fileStore, addr, quitC) + trie, err := loadManifest(ctx, a.fileStore, addr, quitC, NOOPDecrypt) if err != nil { apiModifyFail.Inc(1) return nil, err @@ -663,7 +722,7 @@ func (a *API) AddFile(ctx context.Context, mhash, path, fname string, content [] apiAddFileFail.Inc(1) return nil, "", err } - mkey, err := a.Resolve(ctx, uri) + mkey, err := a.ResolveURI(ctx, uri, EMPTY_CREDENTIALS) if err != nil { apiAddFileFail.Inc(1) return nil, "", err @@ -770,7 +829,7 @@ func (a *API) RemoveFile(ctx context.Context, mhash string, path string, fname s apiRmFileFail.Inc(1) return "", err } - mkey, err := a.Resolve(ctx, uri) + mkey, err := a.ResolveURI(ctx, uri, EMPTY_CREDENTIALS) if err != nil { apiRmFileFail.Inc(1) return "", err @@ -837,7 +896,7 @@ func (a *API) AppendFile(ctx context.Context, mhash, path, fname string, existin apiAppendFileFail.Inc(1) return nil, "", err } - mkey, err := a.Resolve(ctx, uri) + mkey, err := a.ResolveURI(ctx, uri, EMPTY_CREDENTIALS) if err != nil { apiAppendFileFail.Inc(1) return nil, "", err @@ -891,13 +950,13 @@ func (a *API) BuildDirectoryTree(ctx context.Context, mhash string, nameresolver if err != nil { return nil, nil, err } - addr, err = a.Resolve(ctx, uri) + addr, err = a.Resolve(ctx, uri.Addr) if err != nil { return nil, nil, err } quitC := make(chan bool) - rootTrie, err := loadManifest(ctx, a.fileStore, addr, quitC) + rootTrie, err := loadManifest(ctx, a.fileStore, addr, quitC, NOOPDecrypt) if err != nil { return nil, nil, fmt.Errorf("can't load manifest %v: %v", addr.String(), err) } @@ -955,7 +1014,7 @@ func (a *API) ResourceHashSize() int { // ResolveResourceManifest retrieves the Mutable Resource manifest for the given address, and returns the address of the metadata chunk. func (a *API) ResolveResourceManifest(ctx context.Context, addr storage.Address) (storage.Address, error) { - trie, err := loadManifest(ctx, a.fileStore, addr, nil) + trie, err := loadManifest(ctx, a.fileStore, addr, nil, NOOPDecrypt) if err != nil { return nil, fmt.Errorf("cannot load resource manifest: %v", err) } diff --git a/swarm/api/api_test.go b/swarm/api/api_test.go index 78fab9508..a65bf07e2 100644 --- a/swarm/api/api_test.go +++ b/swarm/api/api_test.go @@ -19,6 +19,7 @@ package api import ( "context" "errors" + "flag" "fmt" "io" "io/ioutil" @@ -28,10 +29,17 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/swarm/log" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/swarm/sctx" "github.com/ethereum/go-ethereum/swarm/storage" ) +func init() { + loglevel := flag.Int("loglevel", 2, "loglevel") + flag.Parse() + log.Root().SetHandler(log.CallerFileHandler(log.LvlFilterHandler(log.Lvl(*loglevel), log.StreamHandler(os.Stderr, log.TerminalFormat(true))))) +} + func testAPI(t *testing.T, f func(*API, bool)) { datadir, err := ioutil.TempDir("", "bzz-test") if err != nil { @@ -42,7 +50,7 @@ func testAPI(t *testing.T, f func(*API, bool)) { if err != nil { return } - api := NewAPI(fileStore, nil, nil) + api := NewAPI(fileStore, nil, nil, nil) f(api, false) f(api, true) } @@ -85,7 +93,7 @@ func expResponse(content string, mimeType string, status int) *Response { func testGet(t *testing.T, api *API, bzzhash, path string) *testResponse { addr := storage.Address(common.Hex2Bytes(bzzhash)) - reader, mimeType, status, _, err := api.Get(context.TODO(), addr, path) + reader, mimeType, status, _, err := api.Get(context.TODO(), NOOPDecrypt, addr, path) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -229,7 +237,7 @@ func TestAPIResolve(t *testing.T) { if x.immutable { uri.Scheme = "bzz-immutable" } - res, err := api.Resolve(context.TODO(), uri) + res, err := api.ResolveURI(context.TODO(), uri, "") if err == nil { if x.expectErr != nil { t.Fatalf("expected error %q, got result %q", x.expectErr, res) @@ -373,3 +381,55 @@ func TestMultiResolver(t *testing.T) { }) } } + +func TestDecryptOriginForbidden(t *testing.T) { + ctx := context.TODO() + ctx = sctx.SetHost(ctx, "swarm-gateways.net") + + me := &ManifestEntry{ + Access: &AccessEntry{Type: AccessTypePass}, + } + + api := NewAPI(nil, nil, nil, nil) + + f := api.Decryptor(ctx, "") + err := f(me) + if err != ErrDecryptDomainForbidden { + t.Fatalf("should fail with ErrDecryptDomainForbidden, got %v", err) + } +} + +func TestDecryptOrigin(t *testing.T) { + for _, v := range []struct { + host string + expectError error + }{ + { + host: "localhost", + expectError: ErrDecrypt, + }, + { + host: "127.0.0.1", + expectError: ErrDecrypt, + }, + { + host: "swarm-gateways.net", + expectError: ErrDecryptDomainForbidden, + }, + } { + ctx := context.TODO() + ctx = sctx.SetHost(ctx, v.host) + + me := &ManifestEntry{ + Access: &AccessEntry{Type: AccessTypePass}, + } + + api := NewAPI(nil, nil, nil, nil) + + f := api.Decryptor(ctx, "") + err := f(me) + if err != v.expectError { + t.Fatalf("should fail with %v, got %v", v.expectError, err) + } + } +} diff --git a/swarm/api/client/client.go b/swarm/api/client/client.go index 8a9efe360..3d06e9e1c 100644 --- a/swarm/api/client/client.go +++ b/swarm/api/client/client.go @@ -43,6 +43,10 @@ var ( DefaultClient = NewClient(DefaultGateway) ) +var ( + ErrUnauthorized = errors.New("unauthorized") +) + func NewClient(gateway string) *Client { return &Client{ Gateway: gateway, @@ -188,7 +192,7 @@ func (c *Client) UploadDirectory(dir, defaultPath, manifest string, toEncrypt bo // DownloadDirectory downloads the files contained in a swarm manifest under // the given path into a local directory (existing files will be overwritten) -func (c *Client) DownloadDirectory(hash, path, destDir string) error { +func (c *Client) DownloadDirectory(hash, path, destDir, credentials string) error { stat, err := os.Stat(destDir) if err != nil { return err @@ -201,13 +205,20 @@ func (c *Client) DownloadDirectory(hash, path, destDir string) error { if err != nil { return err } + if credentials != "" { + req.SetBasicAuth("", credentials) + } req.Header.Set("Accept", "application/x-tar") res, err := http.DefaultClient.Do(req) if err != nil { return err } defer res.Body.Close() - if res.StatusCode != http.StatusOK { + switch res.StatusCode { + case http.StatusOK: + case http.StatusUnauthorized: + return ErrUnauthorized + default: return fmt.Errorf("unexpected HTTP status: %s", res.Status) } tr := tar.NewReader(res.Body) @@ -248,7 +259,7 @@ func (c *Client) DownloadDirectory(hash, path, destDir string) error { // DownloadFile downloads a single file into the destination directory // if the manifest entry does not specify a file name - it will fallback // to the hash of the file as a filename -func (c *Client) DownloadFile(hash, path, dest string) error { +func (c *Client) DownloadFile(hash, path, dest, credentials string) error { hasDestinationFilename := false if stat, err := os.Stat(dest); err == nil { hasDestinationFilename = !stat.IsDir() @@ -261,9 +272,9 @@ func (c *Client) DownloadFile(hash, path, dest string) error { } } - manifestList, err := c.List(hash, path) + manifestList, err := c.List(hash, path, credentials) if err != nil { - return fmt.Errorf("could not list manifest: %v", err) + return err } switch len(manifestList.Entries) { @@ -280,13 +291,19 @@ func (c *Client) DownloadFile(hash, path, dest string) error { if err != nil { return err } + if credentials != "" { + req.SetBasicAuth("", credentials) + } res, err := http.DefaultClient.Do(req) if err != nil { return err } defer res.Body.Close() - - if res.StatusCode != http.StatusOK { + switch res.StatusCode { + case http.StatusOK: + case http.StatusUnauthorized: + return ErrUnauthorized + default: return fmt.Errorf("unexpected HTTP status: expected 200 OK, got %d", res.StatusCode) } filename := "" @@ -367,13 +384,24 @@ func (c *Client) DownloadManifest(hash string) (*api.Manifest, bool, error) { // - a prefix of "dir1/" would return [dir1/dir2/, dir1/file3.txt] // // where entries ending with "/" are common prefixes. -func (c *Client) List(hash, prefix string) (*api.ManifestList, error) { - res, err := http.DefaultClient.Get(c.Gateway + "/bzz-list:/" + hash + "/" + prefix) +func (c *Client) List(hash, prefix, credentials string) (*api.ManifestList, error) { + req, err := http.NewRequest(http.MethodGet, c.Gateway+"/bzz-list:/"+hash+"/"+prefix, nil) + if err != nil { + return nil, err + } + if credentials != "" { + req.SetBasicAuth("", credentials) + } + res, err := http.DefaultClient.Do(req) if err != nil { return nil, err } defer res.Body.Close() - if res.StatusCode != http.StatusOK { + switch res.StatusCode { + case http.StatusOK: + case http.StatusUnauthorized: + return nil, ErrUnauthorized + default: return nil, fmt.Errorf("unexpected HTTP status: %s", res.Status) } var list api.ManifestList diff --git a/swarm/api/client/client_test.go b/swarm/api/client/client_test.go index ae82a91d7..2212f5c4c 100644 --- a/swarm/api/client/client_test.go +++ b/swarm/api/client/client_test.go @@ -228,7 +228,7 @@ func TestClientUploadDownloadDirectory(t *testing.T) { t.Fatal(err) } defer os.RemoveAll(tmp) - if err := client.DownloadDirectory(hash, "", tmp); err != nil { + if err := client.DownloadDirectory(hash, "", tmp, ""); err != nil { t.Fatal(err) } for _, file := range testDirFiles { @@ -265,7 +265,7 @@ func testClientFileList(toEncrypt bool, t *testing.T) { } ls := func(prefix string) []string { - list, err := client.List(hash, prefix) + list, err := client.List(hash, prefix, "") if err != nil { t.Fatal(err) } diff --git a/swarm/api/encrypt.go b/swarm/api/encrypt.go new file mode 100644 index 000000000..9a2e36914 --- /dev/null +++ b/swarm/api/encrypt.go @@ -0,0 +1,76 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package api + +import ( + "encoding/binary" + "errors" + + "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/swarm/storage/encryption" +) + +type RefEncryption struct { + spanEncryption encryption.Encryption + dataEncryption encryption.Encryption + span []byte +} + +func NewRefEncryption(refSize int) *RefEncryption { + span := make([]byte, 8) + binary.LittleEndian.PutUint64(span, uint64(refSize)) + return &RefEncryption{ + spanEncryption: encryption.New(0, uint32(refSize/32), sha3.NewKeccak256), + dataEncryption: encryption.New(refSize, 0, sha3.NewKeccak256), + span: span, + } +} + +func (re *RefEncryption) Encrypt(ref []byte, key []byte) ([]byte, error) { + encryptedSpan, err := re.spanEncryption.Encrypt(re.span, key) + if err != nil { + return nil, err + } + encryptedData, err := re.dataEncryption.Encrypt(ref, key) + if err != nil { + return nil, err + } + encryptedRef := make([]byte, len(ref)+8) + copy(encryptedRef[:8], encryptedSpan) + copy(encryptedRef[8:], encryptedData) + + return encryptedRef, nil +} + +func (re *RefEncryption) Decrypt(ref []byte, key []byte) ([]byte, error) { + decryptedSpan, err := re.spanEncryption.Decrypt(ref[:8], key) + if err != nil { + return nil, err + } + + size := binary.LittleEndian.Uint64(decryptedSpan) + if size != uint64(len(ref)-8) { + return nil, errors.New("invalid span in encrypted reference") + } + + decryptedRef, err := re.dataEncryption.Decrypt(ref[8:], key) + if err != nil { + return nil, err + } + + return decryptedRef, nil +} diff --git a/swarm/api/filesystem.go b/swarm/api/filesystem.go index aacd26699..8251ebc4d 100644 --- a/swarm/api/filesystem.go +++ b/swarm/api/filesystem.go @@ -191,7 +191,7 @@ func (fs *FileSystem) Download(bzzpath, localpath string) error { if err != nil { return err } - addr, err := fs.api.Resolve(context.TODO(), uri) + addr, err := fs.api.Resolve(context.TODO(), uri.Addr) if err != nil { return err } @@ -202,7 +202,7 @@ func (fs *FileSystem) Download(bzzpath, localpath string) error { } quitC := make(chan bool) - trie, err := loadManifest(context.TODO(), fs.api.fileStore, addr, quitC) + trie, err := loadManifest(context.TODO(), fs.api.fileStore, addr, quitC, NOOPDecrypt) if err != nil { log.Warn(fmt.Sprintf("fs.Download: loadManifestTrie error: %v", err)) return err diff --git a/swarm/api/filesystem_test.go b/swarm/api/filesystem_test.go index 84a2989d6..fe7527b1f 100644 --- a/swarm/api/filesystem_test.go +++ b/swarm/api/filesystem_test.go @@ -64,7 +64,7 @@ func TestApiDirUpload0(t *testing.T) { checkResponse(t, resp, exp) addr := storage.Address(common.Hex2Bytes(bzzhash)) - _, _, _, _, err = api.Get(context.TODO(), addr, "") + _, _, _, _, err = api.Get(context.TODO(), NOOPDecrypt, addr, "") if err == nil { t.Fatalf("expected error: %v", err) } @@ -143,7 +143,7 @@ func TestApiDirUploadModify(t *testing.T) { exp = expResponse(content, "text/css", 0) checkResponse(t, resp, exp) - _, _, _, _, err = api.Get(context.TODO(), addr, "") + _, _, _, _, err = api.Get(context.TODO(), nil, addr, "") if err == nil { t.Errorf("expected error: %v", err) } diff --git a/swarm/api/http/middleware.go b/swarm/api/http/middleware.go index c0d8d1a40..3b2dcc7d5 100644 --- a/swarm/api/http/middleware.go +++ b/swarm/api/http/middleware.go @@ -9,6 +9,7 @@ import ( "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/swarm/api" "github.com/ethereum/go-ethereum/swarm/log" + "github.com/ethereum/go-ethereum/swarm/sctx" "github.com/ethereum/go-ethereum/swarm/spancontext" "github.com/pborman/uuid" ) @@ -35,6 +36,15 @@ func SetRequestID(h http.Handler) http.Handler { }) } +func SetRequestHost(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r = r.WithContext(sctx.SetHost(r.Context(), r.Host)) + log.Info("setting request host", "ruid", GetRUID(r.Context()), "host", sctx.GetHost(r.Context())) + + h.ServeHTTP(w, r) + }) +} + func ParseURI(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { uri, err := api.Parse(strings.TrimLeft(r.URL.Path, "/")) @@ -87,7 +97,7 @@ func RecoverPanic(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { if err := recover(); err != nil { - log.Error("panic recovery!", "stack trace", debug.Stack(), "url", r.URL.String(), "headers", r.Header) + log.Error("panic recovery!", "stack trace", string(debug.Stack()), "url", r.URL.String(), "headers", r.Header) } }() h.ServeHTTP(w, r) diff --git a/swarm/api/http/response.go b/swarm/api/http/response.go index f050e706a..c9fb9d285 100644 --- a/swarm/api/http/response.go +++ b/swarm/api/http/response.go @@ -79,7 +79,7 @@ func RespondTemplate(w http.ResponseWriter, r *http.Request, templateName, msg s } func RespondError(w http.ResponseWriter, r *http.Request, msg string, code int) { - log.Debug("RespondError", "ruid", GetRUID(r.Context()), "uri", GetURI(r.Context())) + log.Debug("RespondError", "ruid", GetRUID(r.Context()), "uri", GetURI(r.Context()), "code", code) RespondTemplate(w, r, "error", msg, code) } diff --git a/swarm/api/http/server.go b/swarm/api/http/server.go index 5a5c42adc..b5ea0c23d 100644 --- a/swarm/api/http/server.go +++ b/swarm/api/http/server.go @@ -23,7 +23,6 @@ import ( "bufio" "bytes" "encoding/json" - "errors" "fmt" "io" "io/ioutil" @@ -97,6 +96,7 @@ func NewServer(api *api.API, corsString string) *Server { defaultMiddlewares := []Adapter{ RecoverPanic, SetRequestID, + SetRequestHost, InitLoggingResponseWriter, ParseURI, InstrumentOpenTracing, @@ -169,6 +169,7 @@ func NewServer(api *api.API, corsString string) *Server { } func (s *Server) ListenAndServe(addr string) error { + s.listenAddr = addr return http.ListenAndServe(addr, s) } @@ -178,16 +179,24 @@ func (s *Server) ListenAndServe(addr string) error { // https://github.com/atom/electron/blob/master/docs/api/protocol.md type Server struct { http.Handler - api *api.API + api *api.API + listenAddr string } func (s *Server) HandleBzzGet(w http.ResponseWriter, r *http.Request) { - log.Debug("handleBzzGet", "ruid", GetRUID(r.Context())) + log.Debug("handleBzzGet", "ruid", GetRUID(r.Context()), "uri", r.RequestURI) if r.Header.Get("Accept") == "application/x-tar" { uri := GetURI(r.Context()) - reader, err := s.api.GetDirectoryTar(r.Context(), uri) + _, credentials, _ := r.BasicAuth() + reader, err := s.api.GetDirectoryTar(r.Context(), s.api.Decryptor(r.Context(), credentials), uri) if err != nil { + if isDecryptError(err) { + w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=%q", uri.Address().String())) + RespondError(w, r, err.Error(), http.StatusUnauthorized) + return + } RespondError(w, r, fmt.Sprintf("Had an error building the tarball: %v", err), http.StatusInternalServerError) + return } defer reader.Close() @@ -287,7 +296,7 @@ func (s *Server) HandlePostFiles(w http.ResponseWriter, r *http.Request) { var addr storage.Address if uri.Addr != "" && uri.Addr != "encrypt" { - addr, err = s.api.Resolve(r.Context(), uri) + addr, err = s.api.Resolve(r.Context(), uri.Addr) if err != nil { postFilesFail.Inc(1) RespondError(w, r, fmt.Sprintf("cannot resolve %s: %s", uri.Addr, err), http.StatusInternalServerError) @@ -563,7 +572,7 @@ func (s *Server) HandleGetResource(w http.ResponseWriter, r *http.Request) { // resolve the content key. manifestAddr := uri.Address() if manifestAddr == nil { - manifestAddr, err = s.api.Resolve(r.Context(), uri) + manifestAddr, err = s.api.Resolve(r.Context(), uri.Addr) if err != nil { getFail.Inc(1) RespondError(w, r, fmt.Sprintf("cannot resolve %s: %s", uri.Addr, err), http.StatusNotFound) @@ -682,62 +691,21 @@ func (s *Server) HandleGet(w http.ResponseWriter, r *http.Request) { uri := GetURI(r.Context()) log.Debug("handle.get", "ruid", ruid, "uri", uri) getCount.Inc(1) + _, pass, _ := r.BasicAuth() - var err error - addr := uri.Address() - if addr == nil { - addr, err = s.api.Resolve(r.Context(), uri) - if err != nil { - getFail.Inc(1) - RespondError(w, r, fmt.Sprintf("cannot resolve %s: %s", uri.Addr, err), http.StatusNotFound) - return - } - } else { - w.Header().Set("Cache-Control", "max-age=2147483648, immutable") // url was of type bzz://<hex key>/path, so we are sure it is immutable. + addr, err := s.api.ResolveURI(r.Context(), uri, pass) + if err != nil { + getFail.Inc(1) + RespondError(w, r, fmt.Sprintf("cannot resolve %s: %s", uri.Addr, err), http.StatusNotFound) + return } + w.Header().Set("Cache-Control", "max-age=2147483648, immutable") // url was of type bzz://<hex key>/path, so we are sure it is immutable. log.Debug("handle.get: resolved", "ruid", ruid, "key", addr) // if path is set, interpret <key> as a manifest and return the // raw entry at the given path - if uri.Path != "" { - walker, err := s.api.NewManifestWalker(r.Context(), addr, nil) - if err != nil { - getFail.Inc(1) - RespondError(w, r, fmt.Sprintf("%s is not a manifest", addr), http.StatusBadRequest) - return - } - var entry *api.ManifestEntry - walker.Walk(func(e *api.ManifestEntry) error { - // if the entry matches the path, set entry and stop - // the walk - if e.Path == uri.Path { - entry = e - // return an error to cancel the walk - return errors.New("found") - } - - // ignore non-manifest files - if e.ContentType != api.ManifestType { - return nil - } - - // if the manifest's path is a prefix of the - // requested path, recurse into it by returning - // nil and continuing the walk - if strings.HasPrefix(uri.Path, e.Path) { - return nil - } - return api.ErrSkipManifest - }) - if entry == nil { - getFail.Inc(1) - RespondError(w, r, fmt.Sprintf("manifest entry could not be loaded"), http.StatusNotFound) - return - } - addr = storage.Address(common.Hex2Bytes(entry.Hash)) - } etag := common.Bytes2Hex(addr) noneMatchEtag := r.Header.Get("If-None-Match") w.Header().Set("ETag", fmt.Sprintf("%q", etag)) // set etag to manifest key or raw entry key. @@ -781,6 +749,7 @@ func (s *Server) HandleGet(w http.ResponseWriter, r *http.Request) { func (s *Server) HandleGetList(w http.ResponseWriter, r *http.Request) { ruid := GetRUID(r.Context()) uri := GetURI(r.Context()) + _, credentials, _ := r.BasicAuth() log.Debug("handle.get.list", "ruid", ruid, "uri", uri) getListCount.Inc(1) @@ -790,7 +759,7 @@ func (s *Server) HandleGetList(w http.ResponseWriter, r *http.Request) { return } - addr, err := s.api.Resolve(r.Context(), uri) + addr, err := s.api.Resolve(r.Context(), uri.Addr) if err != nil { getListFail.Inc(1) RespondError(w, r, fmt.Sprintf("cannot resolve %s: %s", uri.Addr, err), http.StatusNotFound) @@ -798,9 +767,14 @@ func (s *Server) HandleGetList(w http.ResponseWriter, r *http.Request) { } log.Debug("handle.get.list: resolved", "ruid", ruid, "key", addr) - list, err := s.api.GetManifestList(r.Context(), addr, uri.Path) + list, err := s.api.GetManifestList(r.Context(), s.api.Decryptor(r.Context(), credentials), addr, uri.Path) if err != nil { getListFail.Inc(1) + if isDecryptError(err) { + w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=%q", addr.String())) + RespondError(w, r, err.Error(), http.StatusUnauthorized) + return + } RespondError(w, r, err.Error(), http.StatusInternalServerError) return } @@ -833,7 +807,8 @@ func (s *Server) HandleGetList(w http.ResponseWriter, r *http.Request) { func (s *Server) HandleGetFile(w http.ResponseWriter, r *http.Request) { ruid := GetRUID(r.Context()) uri := GetURI(r.Context()) - log.Debug("handle.get.file", "ruid", ruid) + _, credentials, _ := r.BasicAuth() + log.Debug("handle.get.file", "ruid", ruid, "uri", r.RequestURI) getFileCount.Inc(1) // ensure the root path has a trailing slash so that relative URLs work @@ -845,7 +820,7 @@ func (s *Server) HandleGetFile(w http.ResponseWriter, r *http.Request) { manifestAddr := uri.Address() if manifestAddr == nil { - manifestAddr, err = s.api.Resolve(r.Context(), uri) + manifestAddr, err = s.api.ResolveURI(r.Context(), uri, credentials) if err != nil { getFileFail.Inc(1) RespondError(w, r, fmt.Sprintf("cannot resolve %s: %s", uri.Addr, err), http.StatusNotFound) @@ -856,7 +831,8 @@ func (s *Server) HandleGetFile(w http.ResponseWriter, r *http.Request) { } log.Debug("handle.get.file: resolved", "ruid", ruid, "key", manifestAddr) - reader, contentType, status, contentKey, err := s.api.Get(r.Context(), manifestAddr, uri.Path) + + reader, contentType, status, contentKey, err := s.api.Get(r.Context(), s.api.Decryptor(r.Context(), credentials), manifestAddr, uri.Path) etag := common.Bytes2Hex(contentKey) noneMatchEtag := r.Header.Get("If-None-Match") @@ -869,6 +845,12 @@ func (s *Server) HandleGetFile(w http.ResponseWriter, r *http.Request) { } if err != nil { + if isDecryptError(err) { + w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=%q", manifestAddr)) + RespondError(w, r, err.Error(), http.StatusUnauthorized) + return + } + switch status { case http.StatusNotFound: getFileNotFound.Inc(1) @@ -883,9 +865,14 @@ func (s *Server) HandleGetFile(w http.ResponseWriter, r *http.Request) { //the request results in ambiguous files //e.g. /read with readme.md and readinglist.txt available in manifest if status == http.StatusMultipleChoices { - list, err := s.api.GetManifestList(r.Context(), manifestAddr, uri.Path) + list, err := s.api.GetManifestList(r.Context(), s.api.Decryptor(r.Context(), credentials), manifestAddr, uri.Path) if err != nil { getFileFail.Inc(1) + if isDecryptError(err) { + w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=%q", manifestAddr)) + RespondError(w, r, err.Error(), http.StatusUnauthorized) + return + } RespondError(w, r, err.Error(), http.StatusInternalServerError) return } @@ -951,3 +938,7 @@ func (lrw *loggingResponseWriter) WriteHeader(code int) { lrw.statusCode = code lrw.ResponseWriter.WriteHeader(code) } + +func isDecryptError(err error) bool { + return strings.Contains(err.Error(), api.ErrDecrypt.Error()) +} diff --git a/swarm/api/manifest.go b/swarm/api/manifest.go index 2a163dd39..a1329a800 100644 --- a/swarm/api/manifest.go +++ b/swarm/api/manifest.go @@ -46,13 +46,14 @@ type Manifest struct { // ManifestEntry represents an entry in a swarm manifest type ManifestEntry struct { - Hash string `json:"hash,omitempty"` - Path string `json:"path,omitempty"` - ContentType string `json:"contentType,omitempty"` - Mode int64 `json:"mode,omitempty"` - Size int64 `json:"size,omitempty"` - ModTime time.Time `json:"mod_time,omitempty"` - Status int `json:"status,omitempty"` + Hash string `json:"hash,omitempty"` + Path string `json:"path,omitempty"` + ContentType string `json:"contentType,omitempty"` + Mode int64 `json:"mode,omitempty"` + Size int64 `json:"size,omitempty"` + ModTime time.Time `json:"mod_time,omitempty"` + Status int `json:"status,omitempty"` + Access *AccessEntry `json:"access,omitempty"` } // ManifestList represents the result of listing files in a manifest @@ -98,7 +99,7 @@ type ManifestWriter struct { } func (a *API) NewManifestWriter(ctx context.Context, addr storage.Address, quitC chan bool) (*ManifestWriter, error) { - trie, err := loadManifest(ctx, a.fileStore, addr, quitC) + trie, err := loadManifest(ctx, a.fileStore, addr, quitC, NOOPDecrypt) if err != nil { return nil, fmt.Errorf("error loading manifest %s: %s", addr, err) } @@ -141,8 +142,8 @@ type ManifestWalker struct { quitC chan bool } -func (a *API) NewManifestWalker(ctx context.Context, addr storage.Address, quitC chan bool) (*ManifestWalker, error) { - trie, err := loadManifest(ctx, a.fileStore, addr, quitC) +func (a *API) NewManifestWalker(ctx context.Context, addr storage.Address, decrypt DecryptFunc, quitC chan bool) (*ManifestWalker, error) { + trie, err := loadManifest(ctx, a.fileStore, addr, quitC, decrypt) if err != nil { return nil, fmt.Errorf("error loading manifest %s: %s", addr, err) } @@ -194,6 +195,7 @@ type manifestTrie struct { entries [257]*manifestTrieEntry // indexed by first character of basePath, entries[256] is the empty basePath entry ref storage.Address // if ref != nil, it is stored encrypted bool + decrypt DecryptFunc } func newManifestTrieEntry(entry *ManifestEntry, subtrie *manifestTrie) *manifestTrieEntry { @@ -209,15 +211,15 @@ type manifestTrieEntry struct { subtrie *manifestTrie } -func loadManifest(ctx context.Context, fileStore *storage.FileStore, hash storage.Address, quitC chan bool) (trie *manifestTrie, err error) { // non-recursive, subtrees are downloaded on-demand +func loadManifest(ctx context.Context, fileStore *storage.FileStore, hash storage.Address, quitC chan bool, decrypt DecryptFunc) (trie *manifestTrie, err error) { // non-recursive, subtrees are downloaded on-demand log.Trace("manifest lookup", "key", hash) // retrieve manifest via FileStore manifestReader, isEncrypted := fileStore.Retrieve(ctx, hash) log.Trace("reader retrieved", "key", hash) - return readManifest(manifestReader, hash, fileStore, isEncrypted, quitC) + return readManifest(manifestReader, hash, fileStore, isEncrypted, quitC, decrypt) } -func readManifest(mr storage.LazySectionReader, hash storage.Address, fileStore *storage.FileStore, isEncrypted bool, quitC chan bool) (trie *manifestTrie, err error) { // non-recursive, subtrees are downloaded on-demand +func readManifest(mr storage.LazySectionReader, hash storage.Address, fileStore *storage.FileStore, isEncrypted bool, quitC chan bool, decrypt DecryptFunc) (trie *manifestTrie, err error) { // non-recursive, subtrees are downloaded on-demand // TODO check size for oversized manifests size, err := mr.Size(mr.Context(), quitC) @@ -258,26 +260,41 @@ func readManifest(mr storage.LazySectionReader, hash storage.Address, fileStore trie = &manifestTrie{ fileStore: fileStore, encrypted: isEncrypted, + decrypt: decrypt, } for _, entry := range man.Entries { - trie.addEntry(entry, quitC) + err = trie.addEntry(entry, quitC) + if err != nil { + return + } } return } -func (mt *manifestTrie) addEntry(entry *manifestTrieEntry, quitC chan bool) { +func (mt *manifestTrie) addEntry(entry *manifestTrieEntry, quitC chan bool) error { mt.ref = nil // trie modified, hash needs to be re-calculated on demand + if entry.ManifestEntry.Access != nil { + if mt.decrypt == nil { + return errors.New("dont have decryptor") + } + + err := mt.decrypt(&entry.ManifestEntry) + if err != nil { + return err + } + } + if len(entry.Path) == 0 { mt.entries[256] = entry - return + return nil } b := entry.Path[0] oldentry := mt.entries[b] if (oldentry == nil) || (oldentry.Path == entry.Path && oldentry.ContentType != ManifestType) { mt.entries[b] = entry - return + return nil } cpl := 0 @@ -287,12 +304,12 @@ func (mt *manifestTrie) addEntry(entry *manifestTrieEntry, quitC chan bool) { if (oldentry.ContentType == ManifestType) && (cpl == len(oldentry.Path)) { if mt.loadSubTrie(oldentry, quitC) != nil { - return + return nil } entry.Path = entry.Path[cpl:] oldentry.subtrie.addEntry(entry, quitC) oldentry.Hash = "" - return + return nil } commonPrefix := entry.Path[:cpl] @@ -310,6 +327,7 @@ func (mt *manifestTrie) addEntry(entry *manifestTrieEntry, quitC chan bool) { Path: commonPrefix, ContentType: ManifestType, }, subtrie) + return nil } func (mt *manifestTrie) getCountLast() (cnt int, entry *manifestTrieEntry) { @@ -398,9 +416,20 @@ func (mt *manifestTrie) recalcAndStore() error { } func (mt *manifestTrie) loadSubTrie(entry *manifestTrieEntry, quitC chan bool) (err error) { + if entry.ManifestEntry.Access != nil { + if mt.decrypt == nil { + return errors.New("dont have decryptor") + } + + err := mt.decrypt(&entry.ManifestEntry) + if err != nil { + return err + } + } + if entry.subtrie == nil { hash := common.Hex2Bytes(entry.Hash) - entry.subtrie, err = loadManifest(context.TODO(), mt.fileStore, hash, quitC) + entry.subtrie, err = loadManifest(context.TODO(), mt.fileStore, hash, quitC, mt.decrypt) entry.Hash = "" // might not match, should be recalculated } return diff --git a/swarm/api/manifest_test.go b/swarm/api/manifest_test.go index d65f023f8..1c8e53c43 100644 --- a/swarm/api/manifest_test.go +++ b/swarm/api/manifest_test.go @@ -44,7 +44,7 @@ func testGetEntry(t *testing.T, path, match string, multiple bool, paths ...stri quitC := make(chan bool) fileStore := storage.NewFileStore(nil, storage.NewFileStoreParams()) ref := make([]byte, fileStore.HashSize()) - trie, err := readManifest(manifest(paths...), ref, fileStore, false, quitC) + trie, err := readManifest(manifest(paths...), ref, fileStore, false, quitC, NOOPDecrypt) if err != nil { t.Errorf("unexpected error making manifest: %v", err) } @@ -101,7 +101,7 @@ func TestExactMatch(t *testing.T) { mf := manifest("shouldBeExactMatch.css", "shouldBeExactMatch.css.map") fileStore := storage.NewFileStore(nil, storage.NewFileStoreParams()) ref := make([]byte, fileStore.HashSize()) - trie, err := readManifest(mf, ref, fileStore, false, quitC) + trie, err := readManifest(mf, ref, fileStore, false, quitC, nil) if err != nil { t.Errorf("unexpected error making manifest: %v", err) } @@ -134,7 +134,7 @@ func TestAddFileWithManifestPath(t *testing.T) { } fileStore := storage.NewFileStore(nil, storage.NewFileStoreParams()) ref := make([]byte, fileStore.HashSize()) - trie, err := readManifest(reader, ref, fileStore, false, nil) + trie, err := readManifest(reader, ref, fileStore, false, nil, NOOPDecrypt) if err != nil { t.Fatal(err) } @@ -161,7 +161,7 @@ func TestReadManifestOverSizeLimit(t *testing.T) { reader := &storage.LazyTestSectionReader{ SectionReader: io.NewSectionReader(bytes.NewReader(manifest), 0, int64(len(manifest))), } - _, err := readManifest(reader, storage.Address{}, nil, false, nil) + _, err := readManifest(reader, storage.Address{}, nil, false, nil, NOOPDecrypt) if err == nil { t.Fatal("got no error from readManifest") } diff --git a/swarm/api/storage.go b/swarm/api/storage.go index 3b52301a0..8a48fe5bc 100644 --- a/swarm/api/storage.go +++ b/swarm/api/storage.go @@ -63,11 +63,11 @@ func (s *Storage) Get(ctx context.Context, bzzpath string) (*Response, error) { if err != nil { return nil, err } - addr, err := s.api.Resolve(ctx, uri) + addr, err := s.api.Resolve(ctx, uri.Addr) if err != nil { return nil, err } - reader, mimeType, status, _, err := s.api.Get(ctx, addr, uri.Path) + reader, mimeType, status, _, err := s.api.Get(ctx, nil, addr, uri.Path) if err != nil { return nil, err } @@ -93,7 +93,7 @@ func (s *Storage) Modify(ctx context.Context, rootHash, path, contentHash, conte if err != nil { return "", err } - addr, err := s.api.Resolve(ctx, uri) + addr, err := s.api.Resolve(ctx, uri.Addr) if err != nil { return "", err } diff --git a/swarm/api/uri.go b/swarm/api/uri.go index 14965e0d9..808517088 100644 --- a/swarm/api/uri.go +++ b/swarm/api/uri.go @@ -53,6 +53,19 @@ type URI struct { Path string } +func (u *URI) MarshalJSON() (out []byte, err error) { + return []byte(`"` + u.String() + `"`), nil +} + +func (u *URI) UnmarshalJSON(value []byte) error { + uri, err := Parse(string(value)) + if err != nil { + return err + } + *u = *uri + return nil +} + // Parse parses rawuri into a URI struct, where rawuri is expected to have one // of the following formats: // diff --git a/swarm/fuse/swarmfs_test.go b/swarm/fuse/swarmfs_test.go index d579d15a0..6efeb78d9 100644 --- a/swarm/fuse/swarmfs_test.go +++ b/swarm/fuse/swarmfs_test.go @@ -1650,7 +1650,7 @@ func TestFUSE(t *testing.T) { if err != nil { t.Fatal(err) } - ta := &testAPI{api: api.NewAPI(fileStore, nil, nil)} + ta := &testAPI{api: api.NewAPI(fileStore, nil, nil, nil)} //run a short suite of tests //approx time: 28s diff --git a/swarm/network_test.go b/swarm/network_test.go index d2a030933..176c635d8 100644 --- a/swarm/network_test.go +++ b/swarm/network_test.go @@ -445,7 +445,7 @@ func retrieve( log.Debug("api get: check file", "node", id.String(), "key", f.addr.String(), "total files found", atomic.LoadUint64(totalFoundCount)) - r, _, _, _, err := swarm.api.Get(context.TODO(), f.addr, "/") + r, _, _, _, err := swarm.api.Get(context.TODO(), api.NOOPDecrypt, f.addr, "/") if err != nil { errc <- fmt.Errorf("api get: node %s, key %s, kademlia %s: %v", id, f.addr, swarm.bzz.Hive, err) return diff --git a/swarm/sctx/sctx.go b/swarm/sctx/sctx.go index 8619f6e19..bed2b1145 100644 --- a/swarm/sctx/sctx.go +++ b/swarm/sctx/sctx.go @@ -1,7 +1,22 @@ package sctx +import "context" + type ContextKey int const ( HTTPRequestIDKey ContextKey = iota + requestHostKey ) + +func SetHost(ctx context.Context, domain string) context.Context { + return context.WithValue(ctx, requestHostKey, domain) +} + +func GetHost(ctx context.Context) string { + v, ok := ctx.Value(requestHostKey).(string) + if ok { + return v + } + return "" +} diff --git a/swarm/swarm.go b/swarm/swarm.go index f731ff33d..a895bdfa5 100644 --- a/swarm/swarm.go +++ b/swarm/swarm.go @@ -85,14 +85,12 @@ type Swarm struct { type SwarmAPI struct { Api *api.API Backend chequebook.Backend - PrvKey *ecdsa.PrivateKey } func (self *Swarm) API() *SwarmAPI { return &SwarmAPI{ Api: self.api, Backend: self.backend, - PrvKey: self.privateKey, } } @@ -217,7 +215,7 @@ func NewSwarm(config *api.Config, mockStore *mock.NodeStore) (self *Swarm, err e pss.SetHandshakeController(self.ps, pss.NewHandshakeParams()) } - self.api = api.NewAPI(self.fileStore, self.dns, resourceHandler) + self.api = api.NewAPI(self.fileStore, self.dns, resourceHandler, self.privateKey) // Manifests for Smart Hosting log.Debug(fmt.Sprintf("-> Web3 virtual server API")) diff --git a/swarm/testutil/http.go b/swarm/testutil/http.go index 238f78308..7fd60fcc3 100644 --- a/swarm/testutil/http.go +++ b/swarm/testutil/http.go @@ -77,7 +77,7 @@ func NewTestSwarmServer(t *testing.T, serverFunc func(*api.API) TestServer) *Tes t.Fatal(err) } - a := api.NewAPI(fileStore, nil, rh.Handler) + a := api.NewAPI(fileStore, nil, rh.Handler, nil) srv := httptest.NewServer(serverFunc(a)) return &TestSwarmServer{ Server: srv, |