From d6867c5a23de527c9cef0477d69a4c070a0d4f35 Mon Sep 17 00:00:00 2001 From: Ivan Andreas Uthus Date: Mon, 8 May 2023 23:43:19 +0200 Subject: [PATCH] Added support for private/public key pairs with key identifiers. This can be used instead of using a certificate (virksomhetssertifikat). Updated tests. --- ...inportenClientConfigurationFactoryTests.cs | 44 ++++++++++++++++- .../MaskinportenClientFixture.cs | 49 ++++++++++++++++++- .../MaskinportenClientTests.cs | 26 ++++++++++ .../TestHelper.cs | 39 ++++++++++++--- .../Jwt/JwtRequestTokenGenerator.cs | 37 +++++++++++--- .../MaskinportenClient.cs | 12 ++++- .../MaskinportenClientConfiguration.cs | 27 +++++++++- .../MaskinportenClientConfigurationFactory.cs | 33 ++++++++++--- 8 files changed, 241 insertions(+), 26 deletions(-) diff --git a/KS.Fiks.Maskinporten.Client.Tests/MaskinportenClientConfigurationFactoryTests.cs b/KS.Fiks.Maskinporten.Client.Tests/MaskinportenClientConfigurationFactoryTests.cs index 99a8805..f73cb8e 100644 --- a/KS.Fiks.Maskinporten.Client.Tests/MaskinportenClientConfigurationFactoryTests.cs +++ b/KS.Fiks.Maskinporten.Client.Tests/MaskinportenClientConfigurationFactoryTests.cs @@ -17,7 +17,28 @@ public void CreateVer2Configuration() maskinportenClientConfiguration.Issuer.Should().Be(issuer); maskinportenClientConfiguration.Certificate.Should().Be(certificate); } - + + [Fact] + public void CreateVer2ConfigurationWithKeyPair() + { + const string issuer = "issuer"; + const string keyIdentifier = "some-kid"; + var privateKey = TestHelper.PrivateKey; + var publicKey = TestHelper.PublicKey; + var maskinportenClientConfiguration = MaskinportenClientConfigurationFactory.CreateVer2Configuration( + issuer, + privateKey: privateKey, + publicKey: publicKey, + keyIdentifier: keyIdentifier); + maskinportenClientConfiguration.TokenEndpoint.Should() + .Be(MaskinportenClientConfigurationFactory.VER2_TOKEN_ENDPOINT); + maskinportenClientConfiguration.Audience.Should().Be(MaskinportenClientConfigurationFactory.VER2_AUDIENCE); + maskinportenClientConfiguration.Issuer.Should().Be(issuer); + maskinportenClientConfiguration.PrivateKey.Should().Be(privateKey); + maskinportenClientConfiguration.PublicKey.Should().Be(publicKey); + maskinportenClientConfiguration.KeyIdentifier.Should().Be(keyIdentifier); + } + [Fact] public void CreateProdConfiguration() { @@ -30,5 +51,26 @@ public void CreateProdConfiguration() maskinportenClientConfiguration.Issuer.Should().Be(issuer); maskinportenClientConfiguration.Certificate.Should().Be(certificate); } + + [Fact] + public void CreateProdConfigurationWithKeyPair() + { + const string issuer = "issuer"; + const string keyIdentifier = "some-kid"; + var privateKey = TestHelper.PrivateKey; + var publicKey = TestHelper.PublicKey; + var maskinportenClientConfiguration = MaskinportenClientConfigurationFactory.CreateProdConfiguration( + issuer, + privateKey: privateKey, + publicKey: publicKey, + keyIdentifier: keyIdentifier); + maskinportenClientConfiguration.TokenEndpoint.Should() + .Be(MaskinportenClientConfigurationFactory.PROD_TOKEN_ENDPOINT); + maskinportenClientConfiguration.Audience.Should().Be(MaskinportenClientConfigurationFactory.PROD_AUDIENCE); + maskinportenClientConfiguration.Issuer.Should().Be(issuer); + maskinportenClientConfiguration.PrivateKey.Should().Be(privateKey); + maskinportenClientConfiguration.PublicKey.Should().Be(publicKey); + maskinportenClientConfiguration.KeyIdentifier.Should().Be(keyIdentifier); + } } } \ No newline at end of file diff --git a/KS.Fiks.Maskinporten.Client.Tests/MaskinportenClientFixture.cs b/KS.Fiks.Maskinporten.Client.Tests/MaskinportenClientFixture.cs index b69c7db..945cb40 100644 --- a/KS.Fiks.Maskinporten.Client.Tests/MaskinportenClientFixture.cs +++ b/KS.Fiks.Maskinporten.Client.Tests/MaskinportenClientFixture.cs @@ -2,8 +2,11 @@ using System.Collections.Generic; using System.Net; using System.Net.Http; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; +using JWT.Algorithms; using Moq; using Moq.Protected; using Newtonsoft.Json.Linq; @@ -20,6 +23,9 @@ public class MaskinportenClientFixture private string _audience = "testAudience"; private string _issuer = "testIssuer"; private string? _consumerOrg = null; + private RSA? _privateKey = null; + private RSA? _publicKey = null; + private string? _keyIdentifier = null; public MaskinportenClientFixture() { @@ -53,6 +59,14 @@ public MaskinportenClientFixture WithNumberOfSecondsLeftBeforeExpire(int number) return this; } + public MaskinportenClientFixture WithKeyPair(RSA publicKey, RSA privateKey, string? keyIdentifier = null) + { + _publicKey = publicKey; + _privateKey = privateKey; + _keyIdentifier = keyIdentifier; + return this; + } + public MaskinportenClientFixture WithAudience(string audience) { _audience = audience; @@ -98,15 +112,30 @@ private void SetDefaultValues() private void SetDefaultProperties() { + X509Certificate2? cert = null; + + if (UseCertificate()) + { + cert = _useIncorrectCertificate ? TestHelper.CertificateOtherThanUsedForDecode : TestHelper.Certificate; + } + Configuration = new MaskinportenClientConfiguration( _audience, _tokenEndpoint, _issuer, _numberOfSecondsLeftBeforeExpire, - _useIncorrectCertificate? TestHelper.CertificateOtherThanUsedForDecode : TestHelper.Certificate, + cert, + _publicKey, + _privateKey, + _keyIdentifier, _consumerOrg); } + private bool UseCertificate() + { + return _privateKey == null && _publicKey == null; + } + private void SetResponse() { var responseMessage = new HttpResponseMessage() @@ -128,6 +157,7 @@ private void SetResponse() private string GenerateJsonResponse() { + const string KeyIdentifier = "some-key"; dynamic response = new JObject(); response.Add("expires_in", _expirationTime); @@ -142,7 +172,22 @@ private string GenerateJsonResponse() {"client_orgno", "987654321"}, {"jti", "3Yi-C4E7wAYmCB1Qxaa44VSlmyyGtmrzQQCRN7p4xCY="} }; - var encodedToken = TestHelper.EncodeJwt("mqT5A3LOSIHbpKrscb3EHGrr-WIFRfLdaqZ_5J9GR9s", tokenResponse); + + string encodedToken; + + if (UseCertificate()) + { + encodedToken = TestHelper.EncodeJwt(KeyIdentifier, tokenResponse); + } + else + { + encodedToken = TestHelper.EncodeJwt( + KeyIdentifier, + tokenResponse, + new RSAlgorithmFactory(_publicKey, _privateKey), + new RS256Algorithm(_publicKey, _privateKey)); + } + response.Add("access_token", encodedToken); return response.ToString(); } diff --git a/KS.Fiks.Maskinporten.Client.Tests/MaskinportenClientTests.cs b/KS.Fiks.Maskinporten.Client.Tests/MaskinportenClientTests.cs index 7822afb..e520432 100644 --- a/KS.Fiks.Maskinporten.Client.Tests/MaskinportenClientTests.cs +++ b/KS.Fiks.Maskinporten.Client.Tests/MaskinportenClientTests.cs @@ -5,6 +5,7 @@ using System.Threading; using System.Threading.Tasks; using FluentAssertions; +using JWT.Algorithms; using JWT.Exceptions; using Moq; using Moq.Protected; @@ -30,6 +31,31 @@ public async Task ReturnsAccessToken() accessToken.Should().BeOfType(); } + [Fact] + public async Task ReturnsAccessTokenUsingKeyPair() + { + const string ExpectedAudience = "someAudience"; + var sut = _fixture + .WithAudience(ExpectedAudience) + .WithKeyPair(TestHelper.PublicKey, TestHelper.PrivateKey) + .CreateSut(); + + var accessToken = await sut.GetAccessToken(_fixture.DefaultScopes).ConfigureAwait(false); + accessToken.Should().BeOfType(); + + // This verifies that the audience field is encrypted by the sut av possible to decrypt using key pair. + _fixture.HttpMessageHandleMock.Protected().Verify( + "SendAsync", + Times.Exactly(1), + ItExpr.Is(req => + TestHelper.DeserializedFieldInJwt( + req, + "assertion", + "aud", + new RSAlgorithmFactory(TestHelper.PublicKey, TestHelper.PrivateKey)) == ExpectedAudience), + ItExpr.IsAny()); + } + [Fact] public async Task ReturnsAccesstokenWithNonemptyFields() { diff --git a/KS.Fiks.Maskinporten.Client.Tests/TestHelper.cs b/KS.Fiks.Maskinporten.Client.Tests/TestHelper.cs index 64fac1c..33bb1f3 100644 --- a/KS.Fiks.Maskinporten.Client.Tests/TestHelper.cs +++ b/KS.Fiks.Maskinporten.Client.Tests/TestHelper.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using System.Net.Http; +using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using System.Web; using JWT; @@ -21,6 +22,8 @@ public static class TestHelper private static readonly JwtValidator _validator = new JwtValidator(_serializer, new UtcDateTimeProvider()); private static readonly JwtBase64UrlEncoder _urlEncoder = new JwtBase64UrlEncoder(); private static readonly Logger log = LogManager.GetCurrentClassLogger(); + private static readonly RSA _publicKey = RSA.Create(); + private static readonly RSA _privateKey = RSA.Create(); public static Dictionary RequestContentAsDictionary(HttpRequestMessage request) { @@ -39,6 +42,10 @@ public static Dictionary RequestContentAsDictionary(HttpRequestM "bob-virksomhetssertifikat.p12", "PASSWORD"); + public static RSA PublicKey => _publicKey; + + public static RSA PrivateKey => _privateKey; + public static bool RequestContentIsJwt(HttpRequestMessage request, string jwtFieldName) { var content = RequestContentAsDictionary(request); @@ -46,7 +53,7 @@ public static bool RequestContentIsJwt(HttpRequestMessage request, string jwtFie try { - var decodedJwt = GetDeserializedJwt(serializedJwt); + var decodedJwt = GetDeserializedJwt(serializedJwt, _factory); return decodedJwt?.Length > 0; } catch (Exception ex) @@ -58,15 +65,25 @@ public static bool RequestContentIsJwt(HttpRequestMessage request, string jwtFie public static string DeserializedFieldInJwt(HttpRequestMessage request, string jwtFieldName, string field) { - return DeserializedFieldInJwt(request, jwtFieldName, field); + return DeserializedFieldInJwt(request, jwtFieldName, field, _factory); + } + + public static string DeserializedFieldInJwt(HttpRequestMessage request, string jwtFieldName, string field, IAlgorithmFactory factory) + { + return DeserializedFieldInJwt(request, jwtFieldName, field, factory); } public static string EncodeJwt(string keyId, Dictionary claims) + { + return EncodeJwt(keyId, claims, _factory, new RS256Algorithm(Certificate)); + } + + public static string EncodeJwt(string keyId, Dictionary claims, IAlgorithmFactory factory, IJwtAlgorithm algorithm) { var builder = new JwtBuilder() - .WithAlgorithmFactory(_factory) - .WithAlgorithm(new RS256Algorithm(Certificate)) + .WithAlgorithmFactory(factory) + .WithAlgorithm(algorithm) .WithJsonSerializer(_serializer) .WithValidator(_validator) .WithSecret("passord") @@ -80,19 +97,25 @@ public static string EncodeJwt(string keyId, Dictionary claims) return builder.Encode(); } - public static T DeserializedFieldInJwt(HttpRequestMessage request, string jwtFieldName, string field) + public static T DeserializedFieldInJwt(HttpRequestMessage request, string jwtFieldName, string field, IAlgorithmFactory factory) { var content = RequestContentAsDictionary(request); var serializedJwt = content[jwtFieldName]; - var deserializedJwt = GetDeserializedJwt(serializedJwt); + var deserializedJwt = GetDeserializedJwt(serializedJwt, factory); var jwtAsDictionary = JsonConvert.DeserializeObject>(deserializedJwt); return (T)jwtAsDictionary[field]; } - private static string GetDeserializedJwt(string serializedJwt) + public static T DeserializedFieldInJwt(HttpRequestMessage request, string jwtFieldName, string field) { - var decoder = new JwtDecoder(_serializer, _validator, _urlEncoder, _factory); + return DeserializedFieldInJwt(request, jwtFieldName, field, _factory); + } + + private static string GetDeserializedJwt(string serializedJwt, IAlgorithmFactory factory) + { + var decoder = new JwtDecoder(_serializer, _validator, _urlEncoder, factory); + return decoder.Decode(serializedJwt, "MustBeNonNullButValueDoesNotMatterForRS256", true); } } diff --git a/KS.Fiks.Maskinporten.Client/Jwt/JwtRequestTokenGenerator.cs b/KS.Fiks.Maskinporten.Client/Jwt/JwtRequestTokenGenerator.cs index 0414053..f94fd01 100644 --- a/KS.Fiks.Maskinporten.Client/Jwt/JwtRequestTokenGenerator.cs +++ b/KS.Fiks.Maskinporten.Client/Jwt/JwtRequestTokenGenerator.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using JWT; using JWT.Algorithms; @@ -16,6 +17,9 @@ public class JwtRequestTokenGenerator : IJwtRequestTokenGenerator private readonly JwtEncoder encoder; private readonly X509Certificate2 certificate; + private readonly string keyIdentifier; + private readonly RSA privateKey; + private readonly RSA publicKey; private static IDictionary CreateJwtPayload(TokenRequest tokenRequest, MaskinportenClientConfiguration configuration) { @@ -45,8 +49,21 @@ private static IDictionary CreateJwtPayload(TokenRequest tokenRe public JwtRequestTokenGenerator(X509Certificate2 certificate) { this.certificate = certificate; + this.privateKey = certificate.GetRSAPrivateKey(); + this.publicKey = certificate.GetRSAPublicKey(); this.encoder = new JwtEncoder( - new RS256Algorithm(this.certificate.GetRSAPublicKey(), this.certificate.GetRSAPrivateKey()), + new RS256Algorithm(publicKey, privateKey), + new JsonNetSerializer(), + new JwtBase64UrlEncoder()); + } + + public JwtRequestTokenGenerator(RSA publicKey, RSA privateKey, string keyIdentifier = null) + { + this.keyIdentifier = keyIdentifier; + this.privateKey = privateKey; + this.publicKey = publicKey; + this.encoder = new JwtEncoder( + new RS256Algorithm(publicKey, privateKey), new JsonNetSerializer(), new JwtBase64UrlEncoder()); } @@ -62,13 +79,19 @@ public string CreateEncodedJwt(TokenRequest tokenRequest, MaskinportenClientConf private IDictionary CreateJwtHeader() { - return new Dictionary + Dictionary jwtHeaderValues = new Dictionary(); + + if (certificate != null) + { + jwtHeaderValues.Add("x5c", new List() { Convert.ToBase64String(this.certificate.Export(X509ContentType.Cert)) }); + } + + if (!string.IsNullOrWhiteSpace(keyIdentifier)) { - { - "x5c", - new List() { Convert.ToBase64String(this.certificate.Export(X509ContentType.Cert)) } - } - }; + jwtHeaderValues.Add("kid", keyIdentifier); + } + + return jwtHeaderValues; } } } \ No newline at end of file diff --git a/KS.Fiks.Maskinporten.Client/MaskinportenClient.cs b/KS.Fiks.Maskinporten.Client/MaskinportenClient.cs index 1df8809..36a16e4 100644 --- a/KS.Fiks.Maskinporten.Client/MaskinportenClient.cs +++ b/KS.Fiks.Maskinporten.Client/MaskinportenClient.cs @@ -29,7 +29,17 @@ public MaskinportenClient( _configuration = configuration; _httpClient = httpClient ?? new HttpClient(); _tokenCache = new TokenCache(); - _tokenGenerator = new JwtRequestTokenGenerator(_configuration.Certificate); + if (_configuration.Certificate != null) + { + _tokenGenerator = new JwtRequestTokenGenerator(_configuration.Certificate); + } + else + { + _tokenGenerator = new JwtRequestTokenGenerator( + _configuration.PublicKey, + _configuration.PrivateKey, + _configuration.KeyIdentifier); + } } public async Task GetAccessToken(IEnumerable scopes) diff --git a/KS.Fiks.Maskinporten.Client/MaskinportenClientConfiguration.cs b/KS.Fiks.Maskinporten.Client/MaskinportenClientConfiguration.cs index afaf37b..55a3ab5 100644 --- a/KS.Fiks.Maskinporten.Client/MaskinportenClientConfiguration.cs +++ b/KS.Fiks.Maskinporten.Client/MaskinportenClientConfiguration.cs @@ -1,3 +1,5 @@ +using System; +using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; namespace Ks.Fiks.Maskinporten.Client @@ -9,14 +11,25 @@ public MaskinportenClientConfiguration( string tokenEndpoint, string issuer, int numberOfSecondsLeftBeforeExpire, - X509Certificate2 certificate, + X509Certificate2 certificate = null, + RSA privateKey = null, + RSA publicKey = null, + string keyIdentifier = null, string consumerOrg = null) { + if (certificate == null && (privateKey == null || publicKey == null)) + { + throw new ArgumentException("Either certificate or private and public key must be set!"); + } + Audience = audience; TokenEndpoint = tokenEndpoint; Issuer = issuer; NumberOfSecondsLeftBeforeExpire = numberOfSecondsLeftBeforeExpire; Certificate = certificate; + PrivateKey = privateKey; + PublicKey = publicKey; + KeyIdentifier = keyIdentifier; ConsumerOrg = consumerOrg; } @@ -31,5 +44,17 @@ public MaskinportenClientConfiguration( public int NumberOfSecondsLeftBeforeExpire { get; } public X509Certificate2 Certificate { get; } + + public RSA PublicKey { get; } + + public RSA PrivateKey { get; } + + /// + /// Gets an optional identifier for the key given by the + /// and key pair. Can be used if several keys are set up for your integration. + /// + /// An optional identifier for the key given by the and + /// key pair. Can be used if several keys are set up for your integration. + public string KeyIdentifier { get; } } } \ No newline at end of file diff --git a/KS.Fiks.Maskinporten.Client/MaskinportenClientConfigurationFactory.cs b/KS.Fiks.Maskinporten.Client/MaskinportenClientConfigurationFactory.cs index 349aa4b..2656080 100644 --- a/KS.Fiks.Maskinporten.Client/MaskinportenClientConfigurationFactory.cs +++ b/KS.Fiks.Maskinporten.Client/MaskinportenClientConfigurationFactory.cs @@ -1,3 +1,4 @@ +using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; namespace Ks.Fiks.Maskinporten.Client @@ -12,23 +13,43 @@ public class MaskinportenClientConfigurationFactory public static MaskinportenClientConfiguration CreateVer2Configuration( string issuer, - X509Certificate2 certificate, + X509Certificate2 certificate = null, + RSA privateKey = null, + RSA publicKey = null, + string keyIdentifier = null, int numberOfSecondsLeftBeforeExpire = DEFAULT_NUMBER_SECONDS_LEFT, string consumerOrg = null) { - return new MaskinportenClientConfiguration(VER2_AUDIENCE, - VER2_TOKEN_ENDPOINT, issuer, numberOfSecondsLeftBeforeExpire, certificate, + return new MaskinportenClientConfiguration( + VER2_AUDIENCE, + VER2_TOKEN_ENDPOINT, + issuer, + numberOfSecondsLeftBeforeExpire, + certificate, + privateKey, + publicKey, + keyIdentifier, consumerOrg); } public static MaskinportenClientConfiguration CreateProdConfiguration( string issuer, - X509Certificate2 certificate, + X509Certificate2 certificate = null, + RSA privateKey = null, + RSA publicKey = null, + string keyIdentifier = null, int numberOfSecondsLeftBeforeExpire = DEFAULT_NUMBER_SECONDS_LEFT, string consumerOrg = null) { - return new MaskinportenClientConfiguration(PROD_AUDIENCE, - PROD_TOKEN_ENDPOINT, issuer, numberOfSecondsLeftBeforeExpire, certificate, + return new MaskinportenClientConfiguration( + PROD_AUDIENCE, + PROD_TOKEN_ENDPOINT, + issuer, + numberOfSecondsLeftBeforeExpire, + certificate, + privateKey, + publicKey, + keyIdentifier, consumerOrg); } }