aboutsummaryrefslogtreecommitdiffstats
path: root/crypto/key_store_plain.go
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/key_store_plain.go')
-rw-r--r--crypto/key_store_plain.go45
1 files changed, 42 insertions, 3 deletions
diff --git a/crypto/key_store_plain.go b/crypto/key_store_plain.go
index d785fdf68..c13c5e7a4 100644
--- a/crypto/key_store_plain.go
+++ b/crypto/key_store_plain.go
@@ -43,6 +43,7 @@ type KeyStore interface {
GetKeyAddresses() ([]common.Address, error) // get all addresses
StoreKey(*Key, string) error // store key optionally using auth string
DeleteKey(common.Address, string) error // delete key by addr and auth string
+ Cleanup(keyAddr common.Address) (err error)
}
type keyStorePlain struct {
@@ -86,6 +87,10 @@ func (ks keyStorePlain) GetKeyAddresses() (addresses []common.Address, err error
return getKeyAddresses(ks.keysDirPath)
}
+func (ks keyStorePlain) Cleanup(keyAddr common.Address) (err error) {
+ return cleanup(ks.keysDirPath, keyAddr)
+}
+
func (ks keyStorePlain) StoreKey(key *Key, auth string) (err error) {
keyJSON, err := json.Marshal(key)
if err != nil {
@@ -100,10 +105,14 @@ func (ks keyStorePlain) DeleteKey(keyAddr common.Address, auth string) (err erro
}
func deleteKey(keysDirPath string, keyAddr common.Address) (err error) {
- var keyFilePath string
- keyFilePath, err = getKeyFilePath(keysDirPath, keyAddr)
+ var path string
+ path, err = getKeyFilePath(keysDirPath, keyAddr)
if err == nil {
- err = os.Remove(keyFilePath)
+ addrHex := hex.EncodeToString(keyAddr[:])
+ if path == filepath.Join(keysDirPath, addrHex, addrHex) {
+ path = filepath.Join(keysDirPath, addrHex)
+ }
+ err = os.RemoveAll(path)
}
return
}
@@ -122,6 +131,36 @@ func getKeyFilePath(keysDirPath string, keyAddr common.Address) (keyFilePath str
return
}
+func cleanup(keysDirPath string, keyAddr common.Address) (err error) {
+ fileInfos, err := ioutil.ReadDir(keysDirPath)
+ if err != nil {
+ return
+ }
+ var paths []string
+ account := hex.EncodeToString(keyAddr[:])
+ for _, fileInfo := range fileInfos {
+ path := filepath.Join(keysDirPath, fileInfo.Name())
+ if len(path) >= 40 {
+ addr := path[len(path)-40 : len(path)]
+ if addr == account {
+ if path == filepath.Join(keysDirPath, addr, addr) {
+ path = filepath.Join(keysDirPath, addr)
+ }
+ paths = append(paths, path)
+ }
+ }
+ }
+ if len(paths) > 1 {
+ for i := 0; err == nil && i < len(paths)-1; i++ {
+ err = os.RemoveAll(paths[i])
+ if err != nil {
+ break
+ }
+ }
+ }
+ return
+}
+
func getKeyFile(keysDirPath string, keyAddr common.Address) (fileContent []byte, err error) {
var keyFilePath string
keyFilePath, err = getKeyFilePath(keysDirPath, keyAddr)