From 6f004c46d53958e2fe0e8aad5584a46f9f4bce6c Mon Sep 17 00:00:00 2001
From: Martin Holst Swende <martin@swende.se>
Date: Wed, 19 Sep 2018 18:08:38 +0200
Subject: accounts/keystore: double-check keystore file after creation (#17348)

---
 accounts/keystore/key.go                 | 18 +++++++++++++-----
 accounts/keystore/keystore.go            |  2 +-
 accounts/keystore/keystore_passphrase.go | 27 +++++++++++++++++++++++++--
 accounts/keystore/keystore_plain_test.go |  4 ++--
 4 files changed, 41 insertions(+), 10 deletions(-)

diff --git a/accounts/keystore/key.go b/accounts/keystore/key.go
index 211fa863d..9e3e4856c 100644
--- a/accounts/keystore/key.go
+++ b/accounts/keystore/key.go
@@ -179,26 +179,34 @@ func storeNewKey(ks keyStore, rand io.Reader, auth string) (*Key, accounts.Accou
 	return key, a, err
 }
 
-func writeKeyFile(file string, content []byte) error {
+func writeTemporaryKeyFile(file string, content []byte) (string, error) {
 	// Create the keystore directory with appropriate permissions
 	// in case it is not present yet.
 	const dirPerm = 0700
 	if err := os.MkdirAll(filepath.Dir(file), dirPerm); err != nil {
-		return err
+		return "", err
 	}
 	// Atomic write: create a temporary hidden file first
 	// then move it into place. TempFile assigns mode 0600.
 	f, err := ioutil.TempFile(filepath.Dir(file), "."+filepath.Base(file)+".tmp")
 	if err != nil {
-		return err
+		return "", err
 	}
 	if _, err := f.Write(content); err != nil {
 		f.Close()
 		os.Remove(f.Name())
-		return err
+		return "", err
 	}
 	f.Close()
-	return os.Rename(f.Name(), file)
+	return f.Name(), nil
+}
+
+func writeKeyFile(file string, content []byte) error {
+	name, err := writeTemporaryKeyFile(file, content)
+	if err != nil {
+		return err
+	}
+	return os.Rename(name, file)
 }
 
 // keyFileName implements the naming convention for keyfiles:
diff --git a/accounts/keystore/keystore.go b/accounts/keystore/keystore.go
index 6b04acd05..2918047cc 100644
--- a/accounts/keystore/keystore.go
+++ b/accounts/keystore/keystore.go
@@ -78,7 +78,7 @@ type unlocked struct {
 // NewKeyStore creates a keystore for the given directory.
 func NewKeyStore(keydir string, scryptN, scryptP int) *KeyStore {
 	keydir, _ = filepath.Abs(keydir)
-	ks := &KeyStore{storage: &keyStorePassphrase{keydir, scryptN, scryptP}}
+	ks := &KeyStore{storage: &keyStorePassphrase{keydir, scryptN, scryptP, false}}
 	ks.init(keydir)
 	return ks
 }
diff --git a/accounts/keystore/keystore_passphrase.go b/accounts/keystore/keystore_passphrase.go
index 59738abe1..5aa3a6bbd 100644
--- a/accounts/keystore/keystore_passphrase.go
+++ b/accounts/keystore/keystore_passphrase.go
@@ -35,6 +35,7 @@ import (
 	"fmt"
 	"io"
 	"io/ioutil"
+	"os"
 	"path/filepath"
 
 	"github.com/ethereum/go-ethereum/common"
@@ -72,6 +73,10 @@ type keyStorePassphrase struct {
 	keysDirPath string
 	scryptN     int
 	scryptP     int
+	// skipKeyFileVerification disables the security-feature which does
+	// reads and decrypts any newly created keyfiles. This should be 'false' in all
+	// cases except tests -- setting this to 'true' is not recommended.
+	skipKeyFileVerification bool
 }
 
 func (ks keyStorePassphrase) GetKey(addr common.Address, filename, auth string) (*Key, error) {
@@ -93,7 +98,7 @@ func (ks keyStorePassphrase) GetKey(addr common.Address, filename, auth string)
 
 // StoreKey generates a key, encrypts with 'auth' and stores in the given directory
 func StoreKey(dir, auth string, scryptN, scryptP int) (common.Address, error) {
-	_, a, err := storeNewKey(&keyStorePassphrase{dir, scryptN, scryptP}, rand.Reader, auth)
+	_, a, err := storeNewKey(&keyStorePassphrase{dir, scryptN, scryptP, false}, rand.Reader, auth)
 	return a.Address, err
 }
 
@@ -102,7 +107,25 @@ func (ks keyStorePassphrase) StoreKey(filename string, key *Key, auth string) er
 	if err != nil {
 		return err
 	}
-	return writeKeyFile(filename, keyjson)
+	// Write into temporary file
+	tmpName, err := writeTemporaryKeyFile(filename, keyjson)
+	if err != nil {
+		return err
+	}
+	if !ks.skipKeyFileVerification {
+		// Verify that we can decrypt the file with the given password.
+		_, err = ks.GetKey(key.Address, tmpName, auth)
+		if err != nil {
+			msg := "An error was encountered when saving and verifying the keystore file. \n" +
+				"This indicates that the keystore is corrupted. \n" +
+				"The corrupted file is stored at \n%v\n" +
+				"Please file a ticket at:\n\n" +
+				"https://github.com/ethereum/go-ethereum/issues." +
+				"The error was : %s"
+			return fmt.Errorf(msg, tmpName, err)
+		}
+	}
+	return os.Rename(tmpName, filename)
 }
 
 func (ks keyStorePassphrase) JoinPath(filename string) string {
diff --git a/accounts/keystore/keystore_plain_test.go b/accounts/keystore/keystore_plain_test.go
index a1c3bc4b6..32852a0ad 100644
--- a/accounts/keystore/keystore_plain_test.go
+++ b/accounts/keystore/keystore_plain_test.go
@@ -37,7 +37,7 @@ func tmpKeyStoreIface(t *testing.T, encrypted bool) (dir string, ks keyStore) {
 		t.Fatal(err)
 	}
 	if encrypted {
-		ks = &keyStorePassphrase{d, veryLightScryptN, veryLightScryptP}
+		ks = &keyStorePassphrase{d, veryLightScryptN, veryLightScryptP, true}
 	} else {
 		ks = &keyStorePlain{d}
 	}
@@ -191,7 +191,7 @@ func TestV1_1(t *testing.T) {
 
 func TestV1_2(t *testing.T) {
 	t.Parallel()
-	ks := &keyStorePassphrase{"testdata/v1", LightScryptN, LightScryptP}
+	ks := &keyStorePassphrase{"testdata/v1", LightScryptN, LightScryptP, true}
 	addr := common.HexToAddress("cb61d5a9c4896fb9658090b597ef0e7be6f7b67e")
 	file := "testdata/v1/cb61d5a9c4896fb9658090b597ef0e7be6f7b67e/cb61d5a9c4896fb9658090b597ef0e7be6f7b67e"
 	k, err := ks.GetKey(addr, file, "g")
-- 
cgit v1.2.3