diff --git a/src/commands/auth.test.ts b/src/commands/auth.test.ts index 14f4fa4..3fe5469 100644 --- a/src/commands/auth.test.ts +++ b/src/commands/auth.test.ts @@ -1,14 +1,21 @@ +import { Api } from '@neondatabase/api-client'; import axios from 'axios'; -import { vi, beforeAll, describe, afterAll, expect } from 'vitest'; +import { + existsSync, + mkdtempSync, + readFileSync, + rmSync, + writeFileSync, +} from 'node:fs'; import { AddressInfo } from 'node:net'; -import { mkdtempSync, rmSync, readFileSync, writeFileSync } from 'node:fs'; -import { join } from 'path'; import { TokenSet } from 'openid-client'; -import { Api } from '@neondatabase/api-client'; +import { join } from 'path'; +import { afterAll, beforeAll, beforeEach, describe, expect, vi } from 'vitest'; -import { startOauthServer } from '../test_utils/oauth_server'; import { OAuth2Server } from 'oauth2-mock-server'; +import * as authModule from '../auth'; import { test } from '../test_utils/fixtures'; +import { startOauthServer } from '../test_utils/oauth_server'; import { authFlow, ensureAuth } from './auth'; vi.mock('open', () => ({ default: vi.fn((url: string) => axios.get(url)) })); @@ -52,59 +59,198 @@ describe('ensureAuth', () => { let configDir = ''; let oauthServer: OAuth2Server; let mockApiClient: Api; + let authSpy: any; + let refreshTokenSpy: any; beforeAll(async () => { configDir = mkdtempSync('test-config'); oauthServer = await startOauthServer(); mockApiClient = {} as Api; + authSpy = vi.spyOn(authModule, 'auth'); + refreshTokenSpy = vi.spyOn(authModule, 'refreshToken'); }); afterAll(async () => { rmSync(configDir, { recursive: true }); await oauthServer.stop(); + vi.restoreAllMocks(); + }); + + beforeEach(() => { + authSpy.mockClear(); + refreshTokenSpy.mockClear(); + }); + + const setupTestProps = (server: any) => ({ + _: ['some-command'], + configDir, + oauthHost: `http://localhost:${oauthServer.address().port}`, + clientId: 'test-client-id', + forceAuth: true, + apiKey: '', + apiHost: `http://localhost:${(server.address() as AddressInfo).port}`, + help: false, + apiClient: mockApiClient, }); test('should start new auth flow when refresh token fails', async ({ runMockServer, }) => { - // Mock refresh token to fail - vi.mock('../auth.ts', async (importOriginal) => { - const actual = await importOriginal(); - return { - ...actual, - refreshToken: vi.fn(() => - Promise.reject(new Error('AUTH_REFRESH_FAILED')), - ), - }; - }); + refreshTokenSpy.mockImplementationOnce(() => + Promise.reject(new Error('AUTH_REFRESH_FAILED')), + ); + + authSpy.mockImplementationOnce(() => + Promise.resolve( + new TokenSet({ + access_token: 'new-auth-token', + refresh_token: 'new-refresh-token', + expires_at: Math.floor(Date.now() / 1000) + 3600, + }), + ), + ); const server = await runMockServer('main'); - // Setup expired token const expiredTokenSet = new TokenSet({ access_token: 'expired-token', refresh_token: 'refresh-token', expires_at: Math.floor(Date.now() / 1000) - 3600, // 1 hour ago }); + writeFileSync( + join(configDir, 'credentials.json'), + JSON.stringify(expiredTokenSet), + { mode: 0o700 }, + ); + + const props = setupTestProps(server); + await ensureAuth(props); + + expect(refreshTokenSpy).toHaveBeenCalledTimes(1); + expect(authSpy).toHaveBeenCalledTimes(1); + expect(props.apiKey).toBe('new-auth-token'); + }); + + test('should trigger auth flow when credentials.json does not exist', async ({ + runMockServer, + }) => { + const server = await runMockServer('main'); + + // Ensure the credentials file does not exist const credentialsPath = join(configDir, 'credentials.json'); - writeFileSync(credentialsPath, JSON.stringify(expiredTokenSet), { - mode: 0o700, - }); + if (existsSync(credentialsPath)) { + rmSync(credentialsPath); + } - const props = { - _: ['some-command'], - configDir, - oauthHost: `http://localhost:${oauthServer.address().port}`, - clientId: 'test-client-id', - forceAuth: true, - apiKey: '', - apiHost: `http://localhost:${(server.address() as AddressInfo).port}`, - help: false, - apiClient: mockApiClient, - }; + const props = setupTestProps(server); + await ensureAuth(props); + expect(authSpy).toHaveBeenCalledTimes(1); + expect(refreshTokenSpy).not.toHaveBeenCalled(); + expect(props.apiKey).toEqual(expect.any(String)); + }); + + test('should trigger auth flow when credentials.json is invalid', async ({ + runMockServer, + }) => { + const server = await runMockServer('main'); + + // Write an empty credentials file + writeFileSync(join(configDir, 'credentials.json'), '', { mode: 0o700 }); + + const props = setupTestProps(server); await ensureAuth(props); - expect(props.apiKey).not.toBe('expired-token'); + + expect(authSpy).toHaveBeenCalledTimes(1); + expect(refreshTokenSpy).not.toHaveBeenCalled(); expect(props.apiKey).toEqual(expect.any(String)); }); + + test('should try refresh when token is missing access_token but has refresh_token', async ({ + runMockServer, + }) => { + const server = await runMockServer('main'); + const tokenWithoutAccess = new TokenSet({ + refresh_token: 'refresh-token', + }); + + writeFileSync( + join(configDir, 'credentials.json'), + JSON.stringify(tokenWithoutAccess), + { mode: 0o700 }, + ); + + refreshTokenSpy.mockImplementationOnce(() => + Promise.resolve( + new TokenSet({ + access_token: 'refreshed-token', + refresh_token: 'new-refresh-token', + expires_at: Math.floor(Date.now() / 1000) + 3600, + }), + ), + ); + + const props = setupTestProps(server); + await ensureAuth(props); + + expect(refreshTokenSpy).toHaveBeenCalledTimes(1); + expect(authSpy).not.toHaveBeenCalled(); + expect(props.apiKey).toBe('refreshed-token'); + }); + + test('should use existing valid token', async ({ runMockServer }) => { + const server = await runMockServer('main'); + const validTokenSet = new TokenSet({ + access_token: 'valid-token', + refresh_token: 'refresh-token', + expires_at: Math.floor(Date.now() / 1000) + 3600, // 1 hour from now + }); + + writeFileSync( + join(configDir, 'credentials.json'), + JSON.stringify(validTokenSet), + { mode: 0o700 }, + ); + + const props = setupTestProps(server); + await ensureAuth(props); + + expect(authSpy).not.toHaveBeenCalled(); + expect(refreshTokenSpy).not.toHaveBeenCalled(); + expect(props.apiKey).toBe('valid-token'); + }); + + test('should successfully refresh expired token', async ({ + runMockServer, + }) => { + refreshTokenSpy.mockImplementationOnce(() => + Promise.resolve( + new TokenSet({ + access_token: 'new-token', + refresh_token: 'new-refresh-token', + expires_at: Math.floor(Date.now() / 1000) + 3600, + }), + ), + ); + + const server = await runMockServer('main'); + const expiredTokenSet = new TokenSet({ + access_token: 'expired-token', + refresh_token: 'refresh-token', + expires_at: Math.floor(Date.now() / 1000) - 3600, // 1 hour ago + }); + + writeFileSync( + join(configDir, 'credentials.json'), + JSON.stringify(expiredTokenSet), + { mode: 0o700 }, + ); + + const props = setupTestProps(server); + await ensureAuth(props); + + expect(refreshTokenSpy).toHaveBeenCalledTimes(1); + expect(authSpy).not.toHaveBeenCalled(); + expect(props.apiKey).toBe('new-token'); + }); }); diff --git a/src/commands/auth.ts b/src/commands/auth.ts index ebbe297..19674e9 100644 --- a/src/commands/auth.ts +++ b/src/commands/auth.ts @@ -1,16 +1,16 @@ +import { existsSync, readFileSync, writeFileSync } from 'node:fs'; import { join } from 'node:path'; import { createHash } from 'node:crypto'; -import { writeFileSync, existsSync, readFileSync } from 'node:fs'; import { TokenSet } from 'openid-client'; import yargs from 'yargs'; import { Api } from '@neondatabase/api-client'; -import { auth, refreshToken } from '../auth.js'; -import { log } from '../log.js'; import { getApiClient } from '../api.js'; -import { isCi } from '../env.js'; +import { auth, refreshToken } from '../auth.js'; import { CREDENTIALS_FILE } from '../config.js'; +import { isCi } from '../env.js'; +import { log } from '../log.js'; type AuthProps = { _: (string | number)[]; @@ -80,6 +80,78 @@ const preserveCredentials = async ( log.debug('Credentials MD5 hash: %s', md5hash(contents)); }; +type TokenSetContents = { + user_id: string; +} & TokenSet; + +const isCompleteTokenSet = ( + tokenSet: TokenSet, +): tokenSet is Required => { + return !!( + tokenSet.access_token && + tokenSet.refresh_token && + tokenSet.expires_at + ); +}; + +const handleExistingToken = async ( + tokenSet: TokenSet, + props: AuthProps, + credentialsPath: string, +): Promise<{ apiKey: string; apiClient: Api } | null> => { + // Use existing access_token, if present and valid + if (!!tokenSet.access_token && !tokenSet.expired()) { + const apiClient = getApiClient({ + apiKey: tokenSet.access_token, + apiHost: props.apiHost, + }); + + return { apiKey: tokenSet.access_token, apiClient }; + } + + // Either access_token is missing or its expired. Refresh the token + log.debug( + tokenSet.expired() + ? 'Token is expired, attempting refresh' + : 'Token is missing access_token, attempting refresh', + ); + + if (!tokenSet.refresh_token) { + log.debug('TokenSet is missing refresh_token, starting authentication'); + return null; + } + + try { + const refreshedTokenSet = await refreshToken( + { + oauthHost: props.oauthHost, + clientId: props.clientId, + }, + tokenSet, + ); + + if (!isCompleteTokenSet(refreshedTokenSet)) { + log.debug('Refreshed token is invalid or missing access_token'); + return null; + } + + const apiKey = refreshedTokenSet.access_token; + const apiClient = getApiClient({ + apiKey, + apiHost: props.apiHost, + }); + + await preserveCredentials(credentialsPath, refreshedTokenSet, apiClient); + log.debug('Token refresh successful'); + + return { apiKey, apiClient }; + } catch (err: unknown) { + const typedErr = err instanceof Error ? err : new Error('Unknown error'); + log.debug('Failed to refresh token: %s', typedErr.message); + throw new Error('AUTH_REFRESH_FAILED'); + } +}; + export const ensureAuth = async ( props: AuthProps & { apiKey: string; @@ -87,12 +159,15 @@ export const ensureAuth = async ( help: boolean; }, ) => { + // Skip auth for help command or no command if (props._.length === 0 || props.help) { return; } + + // Use existing API key or handle auth command if (props.apiKey || props._[0] === 'auth') { if (props.apiKey) { - log.debug('using an API key to authorize requests'); + log.debug('Using an API key to authorize requests'); } props.apiClient = getApiClient({ apiKey: props.apiKey, @@ -100,71 +175,54 @@ export const ensureAuth = async ( }); return; } + const credentialsPath = join(props.configDir, CREDENTIALS_FILE); + + // Handle case when credentials file exists if (existsSync(credentialsPath)) { log.debug('Trying to read credentials from %s', credentialsPath); try { const contents = readFileSync(credentialsPath, 'utf8'); log.debug('Credentials MD5 hash: %s', md5hash(contents)); - const tokenSet = new TokenSet(JSON.parse(contents)); - if (tokenSet.expired()) { - log.debug('Using refresh token to update access token'); - let refreshedTokenSet; - try { - refreshedTokenSet = await refreshToken( - { - oauthHost: props.oauthHost, - clientId: props.clientId, - }, - tokenSet, - ); - } catch (err: unknown) { - const typedErr = err && err instanceof Error ? err : undefined; - log.error('Failed to refresh token\n%s', typedErr?.message); - log.info('Starting auth flow'); - throw new Error('AUTH_REFRESH_FAILED'); - } - - props.apiKey = refreshedTokenSet.access_token || 'UNKNOWN'; - props.apiClient = getApiClient({ - apiKey: props.apiKey, - apiHost: props.apiHost, - }); - await preserveCredentials( - credentialsPath, - refreshedTokenSet, - props.apiClient, - ); + const tokenSetContents: TokenSetContents = JSON.parse(contents); + const tokenSet = new TokenSet(tokenSetContents); + + // Try to use existing token or refresh it + const result = await handleExistingToken( + tokenSet, + props, + credentialsPath, + ); + if (result) { + props.apiKey = result.apiKey; + props.apiClient = result.apiClient; return; } - const token = tokenSet.access_token || 'UNKNOWN'; - - props.apiKey = token; - props.apiClient = getApiClient({ - apiKey: props.apiKey, - apiHost: props.apiHost, - }); - return; - } catch (e) { + } catch (err) { if ( - (e instanceof Error && e.message.includes('AUTH_REFRESH_FAILED')) || - (e as { code: string }).code === 'ENOENT' + !(err instanceof Error && err.message === 'AUTH_REFRESH_FAILED') && + (err as { code: string }).code !== 'ENOENT' && + !(err instanceof SyntaxError) ) { - props.apiKey = await authFlow(props); - } else { - // throw for any other errors - throw e; + // Throw for any errors except auth refresh failure, missing file, or invalid credentials file + throw err; } + + // Fall through to new auth flow for auth failures + log.debug('Ensure auth failed, starting authentication', err); } } else { log.debug( 'Credentials file %s does not exist, starting authentication', credentialsPath, ); - props.apiKey = await authFlow(props); } + + // Start new auth flow if no valid token exists or refresh failed + const apiKey = await authFlow(props); + props.apiKey = apiKey; props.apiClient = getApiClient({ - apiKey: props.apiKey, + apiKey, apiHost: props.apiHost, }); };