Skip to content

Commit 7eb0b0a

Browse files
authored
fix(wallet): set Argon2 derived bytes for AES IV (#1703)
1 parent 6479897 commit 7eb0b0a

12 files changed

+265
-58
lines changed

wallet/encrypter/encrypter.go

+50-21
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ type argon2dParameters struct {
2121
iterations uint32
2222
memory uint32
2323
parallelism uint8
24+
keyLen uint32
2425
}
2526

2627
type Option func(p *argon2dParameters)
@@ -47,11 +48,20 @@ const (
4748
nameParamIterations = "iterations"
4849
nameParamMemory = "memory"
4950
nameParamParallelism = "parallelism"
51+
nameParamKeyLen = "keylen"
5052

5153
nameFuncNope = ""
5254
nameFuncArgon2ID = "ARGON2ID"
5355
nameFuncAES256CTR = "AES_256_CTR"
56+
nameFuncAES256CBC = "AES_256_CBC"
5457
nameFuncMACv1 = "MACV1"
58+
59+
// Parameter Choice
60+
// https://www.rfc-editor.org/rfc/rfc9106.html#section-4
61+
defaultIterations = 3
62+
defaultMemory = 65536 // 2 ^ 16
63+
defaultParallelism = 4
64+
defaultKeyLen = 48
5565
)
5666

5767
// ErrNotSupported describes an error in which the encrypted method is no
@@ -75,17 +85,17 @@ func NopeEncrypter() Encrypter {
7585
}
7686
}
7787

78-
// DefaultEncrypter creates a default encrypter instance.
88+
// DefaultEncrypter creates a new encrypter instance.
89+
// If no option sets it uses the default parameters.
7990
//
8091
// The default encrypter uses Argon2ID as password hasher and AES_256_CTR as
8192
// encryption algorithm.
8293
func DefaultEncrypter(opts ...Option) Encrypter {
83-
// Parameter Choice
84-
// https://www.rfc-editor.org/rfc/rfc9106.html#section-4
8594
argon2dParameters := &argon2dParameters{
86-
iterations: uint32(3),
87-
memory: uint32(65536), // 2 ^ 16
88-
parallelism: uint8(4),
95+
iterations: defaultIterations,
96+
memory: defaultMemory,
97+
parallelism: defaultParallelism,
98+
keyLen: defaultKeyLen,
8999
}
90100
for _, opt := range opts {
91101
opt(argon2dParameters)
@@ -98,6 +108,7 @@ func DefaultEncrypter(opts ...Option) Encrypter {
98108
encParams.SetUint32(nameParamIterations, argon2dParameters.iterations)
99109
encParams.SetUint32(nameParamMemory, argon2dParameters.memory)
100110
encParams.SetUint8(nameParamParallelism, argon2dParameters.parallelism)
111+
encParams.SetUint32(nameParamKeyLen, argon2dParameters.keyLen)
101112

102113
return Encrypter{
103114
Method: method,
@@ -140,22 +151,23 @@ func (e *Encrypter) Encrypt(message, password string) (string, error) {
140151
return "", err
141152
}
142153

143-
iterations := e.Params.GetUint32(nameParamIterations)
144-
memory := e.Params.GetUint32(nameParamMemory)
145-
parallelism := e.Params.GetUint8(nameParamParallelism)
154+
iterations := e.Params.GetUint32(nameParamIterations, defaultIterations)
155+
memory := e.Params.GetUint32(nameParamMemory, defaultMemory)
156+
parallelism := e.Params.GetUint8(nameParamParallelism, defaultParallelism)
157+
keyLen := e.Params.GetUint32(nameParamKeyLen, defaultKeyLen)
146158

147159
// Argon2 currently has three modes:
148160
// - data-dependent Argon2d,
149161
// - data-independent Argon2i,
150162
// - a mix of the two, Argon2id.
151-
cipherKey := argon2.IDKey([]byte(password), salt, iterations, memory, parallelism, 32)
163+
derivedBytes := argon2.IDKey([]byte(password), salt, iterations, memory, parallelism, keyLen)
152164

153165
// Encrypter method
154166
switch funcs[1] {
155167
case nameFuncAES256CTR:
156-
// Using salt for Initialization Vector (IV)
157-
iv := salt
158-
cipher := aesCrypt([]byte(message), iv, cipherKey)
168+
cipherKey := derivedBytes[:32]
169+
iv := derivedBytes[32:]
170+
cipher := aesCTRCrypt([]byte(message), iv, cipherKey)
159171

160172
// MAC method
161173
switch funcs[2] {
@@ -215,18 +227,35 @@ func (e *Encrypter) Decrypt(cipherText, password string) (string, error) {
215227
case nameFuncArgon2ID:
216228
salt := data[0:16]
217229

218-
iterations := e.Params.GetUint32(nameParamIterations)
219-
memory := e.Params.GetUint32(nameParamMemory)
220-
parallelism := e.Params.GetUint8(nameParamParallelism)
230+
iterations := e.Params.GetUint32(nameParamIterations, defaultIterations)
231+
memory := e.Params.GetUint32(nameParamMemory, defaultMemory)
232+
parallelism := e.Params.GetUint8(nameParamParallelism, defaultParallelism)
233+
keyLen := e.Params.GetUint32(nameParamKeyLen, defaultKeyLen)
221234

222-
cipherKey := argon2.IDKey([]byte(password), salt, iterations, memory, parallelism, 32)
235+
derivedByte := argon2.IDKey([]byte(password), salt, iterations, memory, parallelism, keyLen)
223236

224237
// Encrypter method
225238
switch funcs[1] {
226239
case nameFuncAES256CTR:
227-
iv := salt
240+
var initVec, cipherKey []byte
241+
242+
switch keyLen {
243+
case 0:
244+
// This case supports legacy encryption methods where the same salt is reused as the IV.
245+
cipherKey = derivedByte
246+
initVec = salt
247+
248+
case 48:
249+
// The first 32 bytes are used as the encryption key, and the last 16 bytes are used as the IV.
250+
cipherKey = derivedByte[:32]
251+
initVec = derivedByte[32:]
252+
253+
default:
254+
return "", ErrInvalidParam
255+
}
256+
228257
enc := data[16 : len(data)-4]
229-
text = string(aesCrypt(enc, iv, cipherKey))
258+
text = string(aesCTRCrypt(enc, initVec, cipherKey))
230259

231260
// MAC method
232261
switch funcs[2] {
@@ -249,9 +278,9 @@ func (e *Encrypter) Decrypt(cipherText, password string) (string, error) {
249278
return text, nil
250279
}
251280

252-
// aesCrypt encrypts/decrypts a message using AES-256-CTR and
281+
// aesCTRCrypt encrypts/decrypts a message using AES-256-CTR and
253282
// returns the encoded/decoded bytes.
254-
func aesCrypt(message, initVec, cipherKey []byte) []byte {
283+
func aesCTRCrypt(message, initVec, cipherKey []byte) []byte {
255284
// Generate the cipher message
256285
cipherMsg := make([]byte, len(message))
257286
aesCipher, err := aes.NewCipher(cipherKey)

wallet/encrypter/encrypter_test.go

+30-1
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,11 @@ func TestDefaultEncrypter(t *testing.T) {
3737
assert.Equal(t, "3", enc.Params["iterations"])
3838
assert.Equal(t, "4", enc.Params["memory"])
3939
assert.Equal(t, "5", enc.Params["parallelism"])
40+
assert.Equal(t, "48", enc.Params["keylen"])
4041
assert.True(t, enc.IsEncrypted())
4142
}
4243

43-
func TestEncrypter(t *testing.T) {
44+
func TestEncrypterV2(t *testing.T) {
4445
enc := &Encrypter{
4546
Method: "ARGON2ID-AES_256_CTR-MACV1",
4647
Params: params{
@@ -66,3 +67,31 @@ func TestEncrypter(t *testing.T) {
6667
_, err = enc.Decrypt(cipher, "invalid-password")
6768
assert.ErrorIs(t, err, ErrInvalidPassword)
6869
}
70+
71+
func TestEncrypterV3(t *testing.T) {
72+
enc := &Encrypter{
73+
Method: "ARGON2ID-AES_256_CTR-MACV1",
74+
Params: params{
75+
nameParamIterations: "1",
76+
nameParamMemory: "1",
77+
nameParamParallelism: "1",
78+
nameParamKeyLen: "48",
79+
},
80+
}
81+
82+
msg := "foo"
83+
84+
_, err := enc.Encrypt(msg, "")
85+
assert.ErrorIs(t, err, ErrInvalidPassword)
86+
87+
password := "cowboy"
88+
cipher, err := enc.Encrypt(msg, password)
89+
assert.NoError(t, err)
90+
91+
dec, err := enc.Decrypt(cipher, password)
92+
assert.NoError(t, err)
93+
assert.Equal(t, msg, dec)
94+
95+
_, err = enc.Decrypt(cipher, "invalid-password")
96+
assert.ErrorIs(t, err, ErrInvalidPassword)
97+
}

wallet/encrypter/error.go

+3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ import (
77
// ErrInvalidPassword describes an error in which the password is invalid.
88
var ErrInvalidPassword = errors.New("invalid password")
99

10+
// ErrInvalidParam describes an error in which the encryption parameter is invalid.
11+
var ErrInvalidParam = errors.New("invalid param")
12+
1013
// ErrInvalidCipher describes an error in which the cipher message is invalid.
1114
var ErrInvalidCipher = errors.New("invalid cipher message")
1215

wallet/encrypter/params.go

+8-6
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,19 @@ func (p params) SetString(key, val string) {
3131
p[key] = val
3232
}
3333

34-
func (p params) GetUint8(key string) uint8 {
35-
return uint8(p.GetUint64(key))
34+
func (p params) GetUint8(key string, defaultValue uint64) uint8 {
35+
return uint8(p.GetUint64(key, defaultValue))
3636
}
3737

38-
func (p params) GetUint32(key string) uint32 {
39-
return uint32(p.GetUint64(key))
38+
func (p params) GetUint32(key string, defaultValue uint64) uint32 {
39+
return uint32(p.GetUint64(key, defaultValue))
4040
}
4141

42-
func (p params) GetUint64(key string) uint64 {
42+
func (p params) GetUint64(key string, defaultValue uint64) uint64 {
4343
val, err := strconv.ParseUint(p[key], 10, 64)
44-
exitOnErr(err)
44+
if err != nil {
45+
return defaultValue
46+
}
4547

4648
return val
4749
}

wallet/encrypter/params_test.go

+18-9
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ func TestParamsUint8(t *testing.T) {
1818
p := params{}
1919
for _, tt := range tests {
2020
p.SetUint8(tt.key, tt.val)
21-
assert.Equal(t, tt.val, p.GetUint8(tt.key))
21+
assert.Equal(t, tt.val, p.GetUint8(tt.key, 0))
2222
}
2323
}
2424

@@ -34,7 +34,7 @@ func TestParamsUint32(t *testing.T) {
3434
p := params{}
3535
for _, tt := range tests {
3636
p.SetUint32(tt.key, tt.val)
37-
assert.Equal(t, tt.val, p.GetUint32(tt.key))
37+
assert.Equal(t, tt.val, p.GetUint32(tt.key, 0))
3838
}
3939
}
4040

@@ -50,24 +50,33 @@ func TestParamsUint64(t *testing.T) {
5050
p := params{}
5151
for _, tt := range tests {
5252
p.SetUint64(tt.key, tt.val)
53-
assert.Equal(t, tt.val, p.GetUint64(tt.key))
53+
assert.Equal(t, tt.val, p.GetUint64(tt.key, 0))
5454
}
5555
}
5656

57+
func TestParamsDefaultValue(t *testing.T) {
58+
p := params{}
59+
assert.Equal(t, uint64(24), p.GetUint64("not-exist", 24))
60+
assert.Equal(t, uint32(24), p.GetUint32("not-exist", 24))
61+
assert.Equal(t, uint8(24), p.GetUint8("not-exist", 24))
62+
}
63+
5764
func TestParamsBytes(t *testing.T) {
5865
tests := []struct {
59-
key string
60-
val []byte
66+
key string
67+
val []byte
68+
base64 string
6169
}{
62-
{"k1", []byte{0, 0}},
63-
{"k2", []byte{0xff, 0xff}},
64-
{"k2", []byte{}},
70+
{"k1", []byte{0, 0}, "AAA="},
71+
{"k2", []byte{0xff, 0xff}, "//8="},
72+
{"k2", []byte{}, ""},
6573
}
6674

6775
p := params{}
6876
for _, tt := range tests {
6977
p.SetBytes(tt.key, tt.val)
7078
assert.Equal(t, tt.val, p.GetBytes(tt.key))
79+
assert.Equal(t, tt.base64, p.GetString(tt.key))
7180
}
7281
}
7382

@@ -78,7 +87,7 @@ func TestParamsString(t *testing.T) {
7887
}{
7988
{"k1", "foo"},
8089
{"k2", "bar"},
81-
{"k3", "bar"},
90+
{"k3", ""},
8291
}
8392

8493
p := params{}

wallet/store.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@ import (
1818
)
1919

2020
const (
21-
Version1 = 1 // initial version
22-
Version2 = 2 // supporting Ed25519
21+
Version1 = 1 // Initial version
22+
Version2 = 2 // Supporting Ed25519
23+
Version3 = 3 // USe AEC-256-CBC for default encryption
2324

24-
VersionLatest = Version2
25+
VersionLatest = Version3
2526
)
2627

2728
type Store struct {
@@ -84,8 +85,7 @@ func (s *Store) UpgradeWallet(walletPath string) error {
8485
return err
8586
}
8687

87-
case Version2:
88-
// Current version
88+
case Version2, Version3:
8989
return nil
9090

9191
default:

wallet/store_test.go

+24-13
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,40 @@ import (
1010

1111
func TestUpgradeWallet(t *testing.T) {
1212
// password is: "password"
13-
data, err := util.ReadFile("./testdata/wallet_version_1")
14-
require.NoError(t, err)
13+
tests := []struct {
14+
walletPath string
15+
upgradedVersion int
16+
}{
17+
{"./testdata/wallet_version_1", 2},
18+
{"./testdata/wallet_version_2", 2},
19+
{"./testdata/wallet_version_3", 3},
20+
}
21+
22+
for _, tt := range tests {
23+
data, err := util.ReadFile(tt.walletPath)
24+
require.NoError(t, err)
1525

16-
tempPath := util.TempFilePath()
17-
err = util.WriteFile(tempPath, data)
18-
require.NoError(t, err)
26+
tempPath := util.TempFilePath()
27+
err = util.WriteFile(tempPath, data)
28+
require.NoError(t, err)
1929

20-
wlt, err := Open(tempPath, true)
21-
require.NoError(t, err)
30+
wlt, err := Open(tempPath, true)
31+
require.NoError(t, err)
2232

23-
assert.Equal(t, 4, wlt.AddressCount())
24-
assert.Equal(t, VersionLatest, wlt.store.Version)
33+
assert.Equal(t, 4, wlt.AddressCount())
34+
assert.Equal(t, tt.upgradedVersion, wlt.store.Version)
2535

26-
infos := wlt.AddressInfos()
27-
for _, info := range infos {
28-
assert.NotEmpty(t, info.PublicKey)
36+
infos := wlt.AddressInfos()
37+
for _, info := range infos {
38+
assert.NotEmpty(t, info.PublicKey)
39+
}
2940
}
3041
}
3142

3243
func TestUnsupportedWallet(t *testing.T) {
3344
_, err := Open("./testdata/unsupported_wallet", true)
3445
require.ErrorIs(t, err, UnsupportedVersionError{
35-
WalletVersion: 3,
46+
WalletVersion: 4,
3647
SupportedVersion: VersionLatest,
3748
})
3849
}

wallet/testdata/unsupported_wallet

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
"version": 3
2+
"version": 4
33
}

0 commit comments

Comments
 (0)