Skip to content

Commit

Permalink
Merge pull request #128 from ivanut/AddSupportForKeyPairs
Browse files Browse the repository at this point in the history
Added support for private/public key pairs with key identifiers
  • Loading branch information
exoen authored May 31, 2023
2 parents 722805f + d6867c5 commit c61fe4c
Show file tree
Hide file tree
Showing 8 changed files with 241 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,28 @@ public void CreateVer2Configuration()
maskinportenClientConfiguration.Issuer.Should().Be(issuer);
maskinportenClientConfiguration.Certificate.Should().Be(certificate);
}


[Fact]
public void CreateVer2ConfigurationWithKeyPair()
{
const string issuer = "issuer";
const string keyIdentifier = "some-kid";
var privateKey = TestHelper.PrivateKey;
var publicKey = TestHelper.PublicKey;
var maskinportenClientConfiguration = MaskinportenClientConfigurationFactory.CreateVer2Configuration(
issuer,
privateKey: privateKey,
publicKey: publicKey,
keyIdentifier: keyIdentifier);
maskinportenClientConfiguration.TokenEndpoint.Should()
.Be(MaskinportenClientConfigurationFactory.VER2_TOKEN_ENDPOINT);
maskinportenClientConfiguration.Audience.Should().Be(MaskinportenClientConfigurationFactory.VER2_AUDIENCE);
maskinportenClientConfiguration.Issuer.Should().Be(issuer);
maskinportenClientConfiguration.PrivateKey.Should().Be(privateKey);
maskinportenClientConfiguration.PublicKey.Should().Be(publicKey);
maskinportenClientConfiguration.KeyIdentifier.Should().Be(keyIdentifier);
}

[Fact]
public void CreateProdConfiguration()
{
Expand All @@ -30,5 +51,26 @@ public void CreateProdConfiguration()
maskinportenClientConfiguration.Issuer.Should().Be(issuer);
maskinportenClientConfiguration.Certificate.Should().Be(certificate);
}

[Fact]
public void CreateProdConfigurationWithKeyPair()
{
const string issuer = "issuer";
const string keyIdentifier = "some-kid";
var privateKey = TestHelper.PrivateKey;
var publicKey = TestHelper.PublicKey;
var maskinportenClientConfiguration = MaskinportenClientConfigurationFactory.CreateProdConfiguration(
issuer,
privateKey: privateKey,
publicKey: publicKey,
keyIdentifier: keyIdentifier);
maskinportenClientConfiguration.TokenEndpoint.Should()
.Be(MaskinportenClientConfigurationFactory.PROD_TOKEN_ENDPOINT);
maskinportenClientConfiguration.Audience.Should().Be(MaskinportenClientConfigurationFactory.PROD_AUDIENCE);
maskinportenClientConfiguration.Issuer.Should().Be(issuer);
maskinportenClientConfiguration.PrivateKey.Should().Be(privateKey);
maskinportenClientConfiguration.PublicKey.Should().Be(publicKey);
maskinportenClientConfiguration.KeyIdentifier.Should().Be(keyIdentifier);
}
}
}
49 changes: 47 additions & 2 deletions KS.Fiks.Maskinporten.Client.Tests/MaskinportenClientFixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
using System.Collections.Generic;
using System.Net;
using System.Net.Http;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using JWT.Algorithms;
using Moq;
using Moq.Protected;
using Newtonsoft.Json.Linq;
Expand All @@ -20,6 +23,9 @@ public class MaskinportenClientFixture
private string _audience = "testAudience";
private string _issuer = "testIssuer";
private string? _consumerOrg = null;
private RSA? _privateKey = null;
private RSA? _publicKey = null;
private string? _keyIdentifier = null;

public MaskinportenClientFixture()
{
Expand Down Expand Up @@ -53,6 +59,14 @@ public MaskinportenClientFixture WithNumberOfSecondsLeftBeforeExpire(int number)
return this;
}

public MaskinportenClientFixture WithKeyPair(RSA publicKey, RSA privateKey, string? keyIdentifier = null)
{
_publicKey = publicKey;
_privateKey = privateKey;
_keyIdentifier = keyIdentifier;
return this;
}

public MaskinportenClientFixture WithAudience(string audience)
{
_audience = audience;
Expand Down Expand Up @@ -98,15 +112,30 @@ private void SetDefaultValues()

private void SetDefaultProperties()
{
X509Certificate2? cert = null;

if (UseCertificate())
{
cert = _useIncorrectCertificate ? TestHelper.CertificateOtherThanUsedForDecode : TestHelper.Certificate;
}

Configuration = new MaskinportenClientConfiguration(
_audience,
_tokenEndpoint,
_issuer,
_numberOfSecondsLeftBeforeExpire,
_useIncorrectCertificate? TestHelper.CertificateOtherThanUsedForDecode : TestHelper.Certificate,
cert,
_publicKey,
_privateKey,
_keyIdentifier,
_consumerOrg);
}

private bool UseCertificate()
{
return _privateKey == null && _publicKey == null;
}

private void SetResponse()
{
var responseMessage = new HttpResponseMessage()
Expand All @@ -128,6 +157,7 @@ private void SetResponse()

private string GenerateJsonResponse()
{
const string KeyIdentifier = "some-key";
dynamic response = new JObject();
response.Add("expires_in", _expirationTime);

Expand All @@ -142,7 +172,22 @@ private string GenerateJsonResponse()
{"client_orgno", "987654321"},
{"jti", "3Yi-C4E7wAYmCB1Qxaa44VSlmyyGtmrzQQCRN7p4xCY="}
};
var encodedToken = TestHelper.EncodeJwt("mqT5A3LOSIHbpKrscb3EHGrr-WIFRfLdaqZ_5J9GR9s", tokenResponse);

string encodedToken;

if (UseCertificate())
{
encodedToken = TestHelper.EncodeJwt(KeyIdentifier, tokenResponse);
}
else
{
encodedToken = TestHelper.EncodeJwt(
KeyIdentifier,
tokenResponse,
new RSAlgorithmFactory(_publicKey, _privateKey),
new RS256Algorithm(_publicKey, _privateKey));
}

response.Add("access_token", encodedToken);
return response.ToString();
}
Expand Down
26 changes: 26 additions & 0 deletions KS.Fiks.Maskinporten.Client.Tests/MaskinportenClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Threading;
using System.Threading.Tasks;
using FluentAssertions;
using JWT.Algorithms;
using JWT.Exceptions;
using Moq;
using Moq.Protected;
Expand All @@ -30,6 +31,31 @@ public async Task ReturnsAccessToken()
accessToken.Should().BeOfType<MaskinportenToken>();
}

[Fact]
public async Task ReturnsAccessTokenUsingKeyPair()
{
const string ExpectedAudience = "someAudience";
var sut = _fixture
.WithAudience(ExpectedAudience)
.WithKeyPair(TestHelper.PublicKey, TestHelper.PrivateKey)
.CreateSut();

var accessToken = await sut.GetAccessToken(_fixture.DefaultScopes).ConfigureAwait(false);
accessToken.Should().BeOfType<MaskinportenToken>();

// This verifies that the audience field is encrypted by the sut av possible to decrypt using key pair.
_fixture.HttpMessageHandleMock.Protected().Verify(
"SendAsync",
Times.Exactly(1),
ItExpr.Is<HttpRequestMessage>(req =>
TestHelper.DeserializedFieldInJwt(
req,
"assertion",
"aud",
new RSAlgorithmFactory(TestHelper.PublicKey, TestHelper.PrivateKey)) == ExpectedAudience),
ItExpr.IsAny<CancellationToken>());
}

[Fact]
public async Task ReturnsAccesstokenWithNonemptyFields()
{
Expand Down
39 changes: 31 additions & 8 deletions KS.Fiks.Maskinporten.Client.Tests/TestHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Web;
using JWT;
Expand All @@ -21,6 +22,8 @@ public static class TestHelper
private static readonly JwtValidator _validator = new JwtValidator(_serializer, new UtcDateTimeProvider());
private static readonly JwtBase64UrlEncoder _urlEncoder = new JwtBase64UrlEncoder();
private static readonly Logger log = LogManager.GetCurrentClassLogger();
private static readonly RSA _publicKey = RSA.Create();
private static readonly RSA _privateKey = RSA.Create();

public static Dictionary<string, string> RequestContentAsDictionary(HttpRequestMessage request)
{
Expand All @@ -39,14 +42,18 @@ public static Dictionary<string, string> RequestContentAsDictionary(HttpRequestM
"bob-virksomhetssertifikat.p12",
"PASSWORD");

public static RSA PublicKey => _publicKey;

public static RSA PrivateKey => _privateKey;

public static bool RequestContentIsJwt(HttpRequestMessage request, string jwtFieldName)
{
var content = RequestContentAsDictionary(request);
var serializedJwt = content[jwtFieldName];

try
{
var decodedJwt = GetDeserializedJwt(serializedJwt);
var decodedJwt = GetDeserializedJwt(serializedJwt, _factory);
return decodedJwt?.Length > 0;
}
catch (Exception ex)
Expand All @@ -58,15 +65,25 @@ public static bool RequestContentIsJwt(HttpRequestMessage request, string jwtFie

public static string DeserializedFieldInJwt(HttpRequestMessage request, string jwtFieldName, string field)
{
return DeserializedFieldInJwt<string>(request, jwtFieldName, field);
return DeserializedFieldInJwt(request, jwtFieldName, field, _factory);
}

public static string DeserializedFieldInJwt(HttpRequestMessage request, string jwtFieldName, string field, IAlgorithmFactory factory)
{
return DeserializedFieldInJwt<string>(request, jwtFieldName, field, factory);
}

public static string EncodeJwt(string keyId, Dictionary<string, object> claims)
{
return EncodeJwt(keyId, claims, _factory, new RS256Algorithm(Certificate));
}

public static string EncodeJwt(string keyId, Dictionary<string, object> claims, IAlgorithmFactory factory, IJwtAlgorithm algorithm)
{

var builder = new JwtBuilder()
.WithAlgorithmFactory(_factory)
.WithAlgorithm(new RS256Algorithm(Certificate))
.WithAlgorithmFactory(factory)
.WithAlgorithm(algorithm)
.WithJsonSerializer(_serializer)
.WithValidator(_validator)
.WithSecret("passord")
Expand All @@ -80,19 +97,25 @@ public static string EncodeJwt(string keyId, Dictionary<string, object> claims)
return builder.Encode();
}

public static T DeserializedFieldInJwt<T>(HttpRequestMessage request, string jwtFieldName, string field)
public static T DeserializedFieldInJwt<T>(HttpRequestMessage request, string jwtFieldName, string field, IAlgorithmFactory factory)
{
var content = RequestContentAsDictionary(request);
var serializedJwt = content[jwtFieldName];
var deserializedJwt = GetDeserializedJwt(serializedJwt);
var deserializedJwt = GetDeserializedJwt(serializedJwt, factory);

var jwtAsDictionary = JsonConvert.DeserializeObject<Dictionary<string, object>>(deserializedJwt);
return (T)jwtAsDictionary[field];
}

private static string GetDeserializedJwt(string serializedJwt)
public static T DeserializedFieldInJwt<T>(HttpRequestMessage request, string jwtFieldName, string field)
{
var decoder = new JwtDecoder(_serializer, _validator, _urlEncoder, _factory);
return DeserializedFieldInJwt<T>(request, jwtFieldName, field, _factory);
}

private static string GetDeserializedJwt(string serializedJwt, IAlgorithmFactory factory)
{
var decoder = new JwtDecoder(_serializer, _validator, _urlEncoder, factory);

return decoder.Decode(serializedJwt, "MustBeNonNullButValueDoesNotMatterForRS256", true);
}
}
Expand Down
37 changes: 30 additions & 7 deletions KS.Fiks.Maskinporten.Client/Jwt/JwtRequestTokenGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using JWT;
using JWT.Algorithms;
Expand All @@ -16,6 +17,9 @@ public class JwtRequestTokenGenerator : IJwtRequestTokenGenerator

private readonly JwtEncoder encoder;
private readonly X509Certificate2 certificate;
private readonly string keyIdentifier;
private readonly RSA privateKey;
private readonly RSA publicKey;

private static IDictionary<string, object> CreateJwtPayload(TokenRequest tokenRequest, MaskinportenClientConfiguration configuration)
{
Expand Down Expand Up @@ -45,8 +49,21 @@ private static IDictionary<string, object> CreateJwtPayload(TokenRequest tokenRe
public JwtRequestTokenGenerator(X509Certificate2 certificate)
{
this.certificate = certificate;
this.privateKey = certificate.GetRSAPrivateKey();
this.publicKey = certificate.GetRSAPublicKey();
this.encoder = new JwtEncoder(
new RS256Algorithm(this.certificate.GetRSAPublicKey(), this.certificate.GetRSAPrivateKey()),
new RS256Algorithm(publicKey, privateKey),
new JsonNetSerializer(),
new JwtBase64UrlEncoder());
}

public JwtRequestTokenGenerator(RSA publicKey, RSA privateKey, string keyIdentifier = null)
{
this.keyIdentifier = keyIdentifier;
this.privateKey = privateKey;
this.publicKey = publicKey;
this.encoder = new JwtEncoder(
new RS256Algorithm(publicKey, privateKey),
new JsonNetSerializer(),
new JwtBase64UrlEncoder());
}
Expand All @@ -62,13 +79,19 @@ public string CreateEncodedJwt(TokenRequest tokenRequest, MaskinportenClientConf

private IDictionary<string, object> CreateJwtHeader()
{
return new Dictionary<string, object>
Dictionary<string, object> jwtHeaderValues = new Dictionary<string, object>();

if (certificate != null)
{
jwtHeaderValues.Add("x5c", new List<string>() { Convert.ToBase64String(this.certificate.Export(X509ContentType.Cert)) });
}

if (!string.IsNullOrWhiteSpace(keyIdentifier))
{
{
"x5c",
new List<string>() { Convert.ToBase64String(this.certificate.Export(X509ContentType.Cert)) }
}
};
jwtHeaderValues.Add("kid", keyIdentifier);
}

return jwtHeaderValues;
}
}
}
12 changes: 11 additions & 1 deletion KS.Fiks.Maskinporten.Client/MaskinportenClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,17 @@ public MaskinportenClient(
_configuration = configuration;
_httpClient = httpClient ?? new HttpClient();
_tokenCache = new TokenCache();
_tokenGenerator = new JwtRequestTokenGenerator(_configuration.Certificate);
if (_configuration.Certificate != null)
{
_tokenGenerator = new JwtRequestTokenGenerator(_configuration.Certificate);
}
else
{
_tokenGenerator = new JwtRequestTokenGenerator(
_configuration.PublicKey,
_configuration.PrivateKey,
_configuration.KeyIdentifier);
}
}

public async Task<MaskinportenToken> GetAccessToken(IEnumerable<string> scopes)
Expand Down
Loading

0 comments on commit c61fe4c

Please sign in to comment.