diff --git a/go.mod b/go.mod index 7672bd8..e994506 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,5 @@ go 1.13 require ( github.com/stretchr/testify v1.8.1 - golang.org/x/sys v0.0.0-20210819135213-f52c844e1c1c + golang.org/x/sys v0.8.0 ) diff --git a/go.sum b/go.sum index 384ed71..4419050 100644 --- a/go.sum +++ b/go.sum @@ -11,8 +11,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -golang.org/x/sys v0.0.0-20210819135213-f52c844e1c1c h1:Lyn7+CqXIiC+LOR9aHD6jDK+hPcmAuCfuXztd1v4w1Q= -golang.org/x/sys v0.0.0-20210819135213-f52c844e1c1c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/sys.go b/sys.go index 033d3c4..dfff596 100644 --- a/sys.go +++ b/sys.go @@ -1,21 +1,23 @@ +//go:build windows // +build windows package wincred import ( "reflect" + "syscall" "unsafe" - syscall "golang.org/x/sys/windows" + "golang.org/x/sys/windows" ) var ( - modadvapi32 = syscall.NewLazyDLL("advapi32.dll") - procCredRead proc = modadvapi32.NewProc("CredReadW") + modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") + procCredRead = modadvapi32.NewProc("CredReadW") procCredWrite proc = modadvapi32.NewProc("CredWriteW") procCredDelete proc = modadvapi32.NewProc("CredDeleteW") procCredFree proc = modadvapi32.NewProc("CredFree") - procCredEnumerate proc = modadvapi32.NewProc("CredEnumerateW") + procCredEnumerate = modadvapi32.NewProc("CredEnumerateW") ) // Interface for syscall.Proc: helps testing @@ -29,7 +31,7 @@ type sysCREDENTIAL struct { Type uint32 TargetName *uint16 Comment *uint16 - LastWritten syscall.Filetime + LastWritten windows.Filetime CredentialBlobSize uint32 CredentialBlob uintptr Persist uint32 @@ -59,15 +61,16 @@ const ( sysCRED_TYPE_DOMAIN_EXTENDED sysCRED_TYPE = 0x6 // https://docs.microsoft.com/en-us/windows/desktop/Debug/system-error-codes - sysERROR_NOT_FOUND = syscall.Errno(1168) - sysERROR_INVALID_PARAMETER = syscall.Errno(87) + sysERROR_NOT_FOUND = windows.Errno(1168) + sysERROR_INVALID_PARAMETER = windows.Errno(87) ) // https://docs.microsoft.com/en-us/windows/desktop/api/wincred/nf-wincred-credreadw func sysCredRead(targetName string, typ sysCRED_TYPE) (*Credential, error) { var pcred *sysCREDENTIAL - targetNamePtr, _ := syscall.UTF16PtrFromString(targetName) - ret, _, err := procCredRead.Call( + targetNamePtr, _ := windows.UTF16PtrFromString(targetName) + ret, _, err := syscall.SyscallN( + procCredRead.Addr(), uintptr(unsafe.Pointer(targetNamePtr)), uintptr(typ), 0, @@ -98,7 +101,7 @@ func sysCredWrite(cred *Credential, typ sysCRED_TYPE) error { // https://docs.microsoft.com/en-us/windows/desktop/api/wincred/nf-wincred-creddeletew func sysCredDelete(cred *Credential, typ sysCRED_TYPE) error { - targetNamePtr, _ := syscall.UTF16PtrFromString(cred.TargetName) + targetNamePtr, _ := windows.UTF16PtrFromString(cred.TargetName) ret, _, err := procCredDelete.Call( uintptr(unsafe.Pointer(targetNamePtr)), uintptr(typ), @@ -117,9 +120,10 @@ func sysCredEnumerate(filter string, all bool) ([]*Credential, error) { var pcreds uintptr var filterPtr *uint16 if !all { - filterPtr, _ = syscall.UTF16PtrFromString(filter) + filterPtr, _ = windows.UTF16PtrFromString(filter) } - ret, _, err := procCredEnumerate.Call( + ret, _, err := syscall.SyscallN( + procCredEnumerate.Addr(), uintptr(unsafe.Pointer(filterPtr)), 0, uintptr(unsafe.Pointer(&count)), diff --git a/sys_test.go b/sys_test.go index d7b8e15..c803e2c 100644 --- a/sys_test.go +++ b/sys_test.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package wincred @@ -5,7 +6,6 @@ package wincred import ( "errors" "testing" - "unsafe" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -32,80 +32,6 @@ func (t *mockProc) Call(a ...uintptr) (r1, r2 uintptr, lastErr error) { return uintptr(args.Int(0)), uintptr(args.Int(1)), args.Error(2) } -func TestSysCredRead_MockFailure(t *testing.T) { - // The test error - testError := errors.New("test error") - // Mock `CreadRead`: returns failure state and the error - mockCredRead := new(mockProc) - mockCredRead.On("Call", mock.AnythingOfType("[]uintptr")).Return(0, 0, testError) - mockCredRead.Setup(&procCredRead) - defer mockCredRead.TearDown() - // Mock `CredFree`: Must not be called - mockCredFree := new(mockProc) - mockCredFree.On("Call", mock.AnythingOfType("[]uintptr")).Return(0, 0, nil) - mockCredFree.Setup(&procCredFree) - defer mockCredFree.TearDown() - - // Test it: - var res *Credential - var err error - assert.NotPanics(t, func() { res, err = sysCredRead("foo", sysCRED_TYPE_GENERIC) }) - assert.Nil(t, res) - assert.NotNil(t, err) - assert.Equal(t, "test error", err.Error()) - mockCredRead.AssertNumberOfCalls(t, "Call", 1) - mockCredFree.AssertNumberOfCalls(t, "Call", 0) -} - -func TestSysCredRead_Mock(t *testing.T) { - // prepare some test data - cred := new(Credential) - cred.TargetName = "Foo" - cred.Comment = "Bar" - cred.CredentialBlob = []byte{1, 2, 3} - credSys := sysFromCredential(cred) - t.Log(credSys) // Workaround to keep the object alive - - // Mock `CreadRead`: returns success and sets the pointer to the prepared sysCred struct - mockCredRead := new(mockProc) - mockCredRead. - On("Call", mock.AnythingOfType("[]uintptr")). - Return(1, 0, nil). - Run(func(args mock.Arguments) { - arg := args.Get(0).([]uintptr) - assert.Equal(t, 4, len(arg)) - *(**sysCREDENTIAL)(unsafe.Pointer(arg[3])) = credSys - }) - mockCredRead.Setup(&procCredRead) - defer mockCredRead.TearDown() - - // Mock `CredFree`: Must be called as well with the correct pointer - mockCredFree := new(mockProc) - mockCredFree. - On("Call", mock.AnythingOfType("[]uintptr")). - Return(0, 0, nil). - Run(func(args mock.Arguments) { - arg := args.Get(0).([]uintptr) - assert.Equal(t, 1, len(arg)) - assert.Equal(t, uintptr(unsafe.Pointer(credSys)), arg[0]) - }) - mockCredFree.Setup(&procCredFree) - defer mockCredFree.TearDown() - - // Test it: - var res *Credential - var err error - assert.NotPanics(t, func() { res, err = sysCredRead("Foo", sysCRED_TYPE_GENERIC) }) - mockCredRead.AssertNumberOfCalls(t, "Call", 1) - mockCredFree.AssertNumberOfCalls(t, "Call", 1) - assert.NotNil(t, res) - assert.Nil(t, err) - assert.Equal(t, "Foo", res.TargetName) - assert.Equal(t, "Bar", res.Comment) - assert.Equal(t, []byte{1, 2, 3}, res.CredentialBlob) - assert.NotEqual(t, &cred, &res) -} - func TestSysCredWrite_MockFailure(t *testing.T) { // Mock `CreadWrite`: returns failure state and the error mockCredWrite := new(mockProc) @@ -163,80 +89,3 @@ func TestSysCredDelete_Mock(t *testing.T) { assert.Nil(t, err) mockCredDelete.AssertNumberOfCalls(t, "Call", 1) } - -func TestSysCredEnumerate_MockFailure(t *testing.T) { - // The test error - testError := errors.New("test error") - // Mock `CreadEnumerate`: returns failure state and the error - mockCredEnumerate := new(mockProc) - mockCredEnumerate.On("Call", mock.AnythingOfType("[]uintptr")).Return(0, 0, testError) - mockCredEnumerate.Setup(&procCredEnumerate) - defer mockCredEnumerate.TearDown() - // Mock `CredFree`: Must not be called - mockCredFree := new(mockProc) - mockCredFree.On("Call", mock.AnythingOfType("[]uintptr")).Return(0, 0, nil) - mockCredFree.Setup(&procCredFree) - defer mockCredFree.TearDown() - - // Test it: - var res []*Credential - var err error - assert.NotPanics(t, func() { res, err = sysCredEnumerate("", true) }) - assert.Nil(t, res) - assert.NotNil(t, err) - assert.Equal(t, "test error", err.Error()) - mockCredEnumerate.AssertNumberOfCalls(t, "Call", 1) - mockCredFree.AssertNumberOfCalls(t, "Call", 0) -} - -func TestSysCredEnumerate_Mock(t *testing.T) { - // prepare some test data - creds := []*Credential{new(Credential), new(Credential)} - creds[0].TargetName = "Foo" - creds[1].TargetName = "Bar" - credsSys := [](*sysCREDENTIAL){ - sysFromCredential(creds[0]), - sysFromCredential(creds[1]), - } - t.Log(credsSys[0]) // Workaround to keep the object alive - t.Log(credsSys[1]) // Workaround to keep the object alive - - // Mock `CreadEnumerate`: returns success and sets the pointer to the prepared sysCreds array - mockCredEnumerate := new(mockProc) - mockCredEnumerate. - On("Call", mock.AnythingOfType("[]uintptr")). - Return(1, 0, nil). - Run(func(args mock.Arguments) { - arg := args.Get(0).([]uintptr) - assert.Equal(t, 4, len(arg)) - *(*int)(unsafe.Pointer(arg[2])) = len(credsSys) - *(*[]*sysCREDENTIAL)(unsafe.Pointer(arg[3])) = credsSys - }) - mockCredEnumerate.Setup(&procCredEnumerate) - defer mockCredEnumerate.TearDown() - - // Mock `CredFree`: Must be called as well with the correct pointer - mockCredFree := new(mockProc) - mockCredFree. - On("Call", mock.AnythingOfType("[]uintptr")). - Return(0, 0, nil). - Run(func(args mock.Arguments) { - arg := args.Get(0).([]uintptr) - assert.Equal(t, 1, len(arg)) - assert.Equal(t, uintptr(unsafe.Pointer(&credsSys[0])), arg[0]) - }) - mockCredFree.Setup(&procCredFree) - defer mockCredFree.TearDown() - - // Test it: - var res []*Credential - var err error - assert.NotPanics(t, func() { res, err = sysCredEnumerate("", true) }) - mockCredEnumerate.AssertNumberOfCalls(t, "Call", 1) - mockCredFree.AssertNumberOfCalls(t, "Call", 1) - assert.NotNil(t, res) - assert.Nil(t, err) - assert.Equal(t, 2, len(res)) - assert.Equal(t, "Foo", res[0].TargetName) - assert.Equal(t, "Bar", res[1].TargetName) -}