diff --git a/src/com/amazon/corretto/crypto/provider/EvpKeyFactory.java b/src/com/amazon/corretto/crypto/provider/EvpKeyFactory.java index 8505e7a2..ae2413e8 100644 --- a/src/com/amazon/corretto/crypto/provider/EvpKeyFactory.java +++ b/src/com/amazon/corretto/crypto/provider/EvpKeyFactory.java @@ -27,6 +27,8 @@ import java.security.spec.X509EncodedKeySpec; abstract class EvpKeyFactory extends KeyFactorySpi { + private static final String PKCS8_FORMAT = "PKCS#8"; + private static final String X509_FORMAT = "X.509"; private final EvpKeyType type; private final AmazonCorrettoCryptoProvider provider; @@ -91,11 +93,11 @@ protected PublicKey engineGeneratePublic(KeySpec keySpec) throws InvalidKeySpecE protected T engineGetKeySpec(Key key, Class keySpec) throws InvalidKeySpecException { if (keySpec.isAssignableFrom(PKCS8EncodedKeySpec.class) - && key.getFormat().equalsIgnoreCase("PKCS#8")) { + && PKCS8_FORMAT.equalsIgnoreCase(key.getFormat())) { return keySpec.cast(new PKCS8EncodedKeySpec(requireNonNullEncoding(key))); } if (keySpec.isAssignableFrom(X509EncodedKeySpec.class) - && key.getFormat().equalsIgnoreCase("X.509")) { + && X509_FORMAT.equalsIgnoreCase(key.getFormat())) { return keySpec.cast(new X509EncodedKeySpec(requireNonNullEncoding(key))); } @@ -110,10 +112,10 @@ protected Key engineTranslateKey(Key key) throws InvalidKeyException { try { final EvpKey result; - if (key.getFormat().equalsIgnoreCase("PKCS#8")) { + if (PKCS8_FORMAT.equalsIgnoreCase(key.getFormat())) { result = (EvpKey) engineGeneratePrivate(new PKCS8EncodedKeySpec(requireNonNullEncoding(key))); - } else if (key.getFormat().equalsIgnoreCase("X.509")) { + } else if (X509_FORMAT.equalsIgnoreCase(key.getFormat())) { result = (EvpKey) engineGeneratePublic(new X509EncodedKeySpec(requireNonNullEncoding(key))); } else { throw new InvalidKeyException("Cannot convert key of format " + key.getFormat()); diff --git a/tst/com/amazon/corretto/crypto/provider/test/EvpKeyFactoryTest.java b/tst/com/amazon/corretto/crypto/provider/test/EvpKeyFactoryTest.java index b9423915..4c6ccaeb 100644 --- a/tst/com/amazon/corretto/crypto/provider/test/EvpKeyFactoryTest.java +++ b/tst/com/amazon/corretto/crypto/provider/test/EvpKeyFactoryTest.java @@ -249,6 +249,45 @@ public void testX509Encoding(final KeyPair keyPair, final String testName) throw new X509EncodedKeySpec(Arrays.copyOf(validSpec, validSpec.length - 1)))); } + @ParameterizedTest(name = "{1}") + @MethodSource("allPairs") + public void nullAlgorithm(final KeyPair keyPair, final String testName) throws Exception { + final KeyFactory nativeFactory = + KeyFactory.getInstance(keyPair.getPublic().getAlgorithm(), NATIVE_PROVIDER); + final Key nullPublicKey = new NullDataKey(keyPair.getPublic(), true, false, false); + final Key nullPrivateKey = new NullDataKey(keyPair.getPrivate(), true, false, false); + + assertThrows(InvalidKeyException.class, () -> nativeFactory.translateKey(nullPublicKey)); + + assertThrows(InvalidKeyException.class, () -> nativeFactory.translateKey(nullPrivateKey)); + } + + @ParameterizedTest(name = "{1}") + @MethodSource("allPairs") + public void nullFormat(final KeyPair keyPair, final String testName) throws Exception { + final KeyFactory nativeFactory = + KeyFactory.getInstance(keyPair.getPublic().getAlgorithm(), NATIVE_PROVIDER); + final Key nullPublicKey = new NullDataKey(keyPair.getPublic(), false, true, false); + final Key nullPrivateKey = new NullDataKey(keyPair.getPrivate(), false, true, false); + + assertThrows(InvalidKeyException.class, () -> nativeFactory.translateKey(nullPublicKey)); + + assertThrows(InvalidKeyException.class, () -> nativeFactory.translateKey(nullPrivateKey)); + } + + @ParameterizedTest(name = "{1}") + @MethodSource("allPairs") + public void nullEncoding(final KeyPair keyPair, final String testName) throws Exception { + final KeyFactory nativeFactory = + KeyFactory.getInstance(keyPair.getPublic().getAlgorithm(), NATIVE_PROVIDER); + final Key nullPublicKey = new NullDataKey(keyPair.getPublic(), false, false, true); + final Key nullPrivateKey = new NullDataKey(keyPair.getPrivate(), false, false, true); + + assertThrows(InvalidKeyException.class, () -> nativeFactory.translateKey(nullPublicKey)); + + assertThrows(InvalidKeyException.class, () -> nativeFactory.translateKey(nullPrivateKey)); + } + @ParameterizedTest(name = "{1}") @MethodSource("allPairs") public void testPKCS8Encoding(final KeyPair keyPair, final String testName) throws Exception { @@ -670,4 +709,38 @@ private static class Samples { this.jceSample = jceSample; } } + + public static class NullDataKey implements Key { + private static final long serialVersionUID = 1; + private final Key delegate; + private final boolean nullAlgorithm; + private final boolean nullFormat; + private final boolean nullEncoded; + + public NullDataKey( + final Key delegate, + final boolean nullAlgorithm, + final boolean nullFormat, + final boolean nullEncoded) { + this.delegate = delegate; + this.nullAlgorithm = nullAlgorithm; + this.nullFormat = nullFormat; + this.nullEncoded = nullEncoded; + } + + @Override + public String getAlgorithm() { + return nullAlgorithm ? null : delegate.getAlgorithm(); + } + + @Override + public String getFormat() { + return nullFormat ? null : delegate.getFormat(); + } + + @Override + public byte[] getEncoded() { + return nullEncoded ? null : delegate.getEncoded(); + } + } }