Android KeyStore key import

Source code reference (the CTS test this code is adapted from):
https://android.googlesource.com/platform/cts/+/master/tests/tests/keystore/src/android/keystore/cts/ImportWrappedKeyTest.java

Additional references (Keymaster/KeyMint tag, enum, and constant definitions):
https://android.googlesource.com/platform/frameworks/base/+/master/core/java/android/security/keymaster/KeymasterDefs.java
https://android.googlesource.com/platform/hardware/interfaces/+/refs/heads/main/security/keymint/aidl/aidl_api/android.hardware.security.keymint/1/android/hardware/security/keymint/

The test code is as follows:

import android.security.keystore.KeyGenParameterSpec;
import android.security.keystore.KeyProperties;
import android.security.keystore.SecureKeyImportUnavailableException;
import android.security.keystore.StrongBoxUnavailableException;
import android.security.keystore.WrappedKeyEntry;
import android.util.Log;


import org.bouncycastle.asn1.ASN1Encoding;
import org.bouncycastle.asn1.DEREncodableVector;
import org.bouncycastle.asn1.DERInteger;
import org.bouncycastle.asn1.DERNull;
import org.bouncycastle.asn1.DEROctetString;
import org.bouncycastle.asn1.DERSequence;
import org.bouncycastle.asn1.DERSet;
import org.bouncycastle.asn1.DERTaggedObject;
import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.encoders.Hex;

import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.Signature;
import java.security.spec.AlgorithmParameterSpec;
import java.security.spec.MGF1ParameterSpec;

import javax.crypto.Cipher;
import javax.crypto.KeyGenerator;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.OAEPParameterSpec;
import javax.crypto.spec.PSource;
import javax.crypto.spec.SecretKeySpec;

public class TestWrapKey {<!-- -->
    private static final String ALIAS = "my key";
    private static final String WRAPPING_KEY_ALIAS = "my_favorite_wrapping_key";
    public static final String TAG = "TestWrapKey";
    public static final int KM_TAG_PURPOSE = 536870913;
    public static final int KM_TAG_ALGORITHM = 268435458;
    public static final int KM_TAG_KEY_SIZE = 805306371;
    public static final int KM_TAG_DIGEST = 536870917;
    public static final int KM_TAG_NO_AUTH_REQUIRED = 1879048695;
    public static final int KM_PURPOSE_SIGN = 2;
    public static final int KM_PURPOSE_VERIFY = 3;
    public static final int KM_DIGEST_SHA_2_224 = 3;
    public static final int KM_DIGEST_SHA_2_256 = 4;
    public static final int KM_DIGEST_SHA_2_384 = 5;
    public static final int KM_DIGEST_SHA_2_512 = 6;
    public static final int KM_KEY_FORMAT_RAW = 3;

    public static final int KM_PURPOSE_ENCRYPT = 0;
    public static final int KM_PURPOSE_DECRYPT = 1;
    public static final int KM_MODE_ECB = 1;
    public static final int KM_MODE_CBC = 2;
    public static final int KM_TAG_BLOCK_MODE = 536870916;
    public static final int KM_PAD_PKCS7 = 64;
    public static final int KM_PAD_NONE = 1;
    public static final int KM_TAG_PADDING = 536870918;
    public static final int KM_ALGORITHM_AES = 32;
    private static final int GCM_TAG_SIZE = 128;
    private static final int WRAPPED_FORMAT_VERSION = 0;
    public static final int KM_KEY_FORMAT_PKCS8 = 1;
    public static final int KM_ALGORITHM_RSA = 1;
    public static final int KM_PAD_RSA_OAEP = 2;
    public static final int KM_PAD_RSA_PSS = 3;
    public static final int KM_PAD_RSA_PKCS1_1_5_ENCRYPT = 4;
    public static final int KM_PAD_RSA_PKCS1_1_5_SIGN = 5;
    public static final int KM_DIGEST_NONE = 0;
    public static final int KM_DIGEST_MD5 = 1;
    public static final int KM_DIGEST_SHA1 = 2;


    SecureRandom random = new SecureRandom();



    private void showMessage(String message) {<!-- -->
        Log.d(TAG, message);
    }

    // Build a secret key pair
    private KeyPair genKeyPair(String alias, boolean isStrongBoxBacked) throws Exception {<!-- -->
        KeyPairGenerator kpg =
                KeyPairGenerator.getInstance(KeyProperties.KEY_ALGORITHM_RSA, "AndroidKeyStore");
        kpg.initialize(
                new KeyGenParameterSpec.Builder(alias, KeyProperties.PURPOSE_WRAP_KEY)
                        .setDigests(KeyProperties.DIGEST_SHA256)
                        .setEncryptionPaddings(KeyProperties.ENCRYPTION_PADDING_RSA_OAEP)
                        .setBlockModes(KeyProperties.BLOCK_MODE_ECB)
                        .setIsStrongBoxBacked(isStrongBoxBacked)
                        .build());
        return kpg.generateKeyPair();
    }

    private int removeTagType(int tag) {<!-- -->
        int kmTagTypeMask = 0x0FFFFFFF;
        return tag & amp; kmTagTypeMask;
    }




    private DERSequence makeRsaAuthList(int size) {<!-- -->
        DEREncodableVector allPurposes = new DEREncodableVector();
        allPurposes.add(new DERInteger(KM_PURPOSE_ENCRYPT));
        allPurposes.add(new DERInteger(KM_PURPOSE_DECRYPT));
        allPurposes.add(new DERInteger(KM_PURPOSE_SIGN));
        allPurposes.add(new DERInteger(KM_PURPOSE_VERIFY));
        DERSet purposeSet = new DERSet(allPurposes);
        DERTaggedObject purpose =
                new DERTaggedObject(true, removeTagType(KM_TAG_PURPOSE), purposeSet);
        DERTaggedObject algorithm =
                new DERTaggedObject(true, removeTagType(KM_TAG_ALGORITHM),
                        new DERInteger(KM_ALGORITHM_RSA));
        DERTaggedObject keySize =
                new DERTaggedObject(true, removeTagType(KM_TAG_KEY_SIZE), new DERInteger(size));
        DEREncodableVector allBlockModes = new DEREncodableVector();
        allBlockModes.add(new DERInteger(KM_MODE_ECB));
        allBlockModes.add(new DERInteger(KM_MODE_CBC));
        DERSet blockModeSet = new DERSet(allBlockModes);
        DERTaggedObject blockMode =
                new DERTaggedObject(true, removeTagType(KM_TAG_BLOCK_MODE), blockModeSet);
        DEREncodableVector allDigests = new DEREncodableVector();
        allDigests.add(new DERInteger(KM_DIGEST_NONE));
        allDigests.add(new DERInteger(KM_DIGEST_MD5));
        allDigests.add(new DERInteger(KM_DIGEST_SHA1));
        allDigests.add(new DERInteger(KM_DIGEST_SHA_2_224));
        allDigests.add(new DERInteger(KM_DIGEST_SHA_2_256));
        allDigests.add(new DERInteger(KM_DIGEST_SHA_2_384));
        allDigests.add(new DERInteger(KM_DIGEST_SHA_2_512));
        DERSet digestSet = new DERSet(allDigests);
        DERTaggedObject digest =
                new DERTaggedObject(true, removeTagType(KM_TAG_DIGEST), digestSet);
        DEREncodableVector allPaddings = new DEREncodableVector();
        allPaddings.add(new DERInteger(KM_PAD_PKCS7));
        allPaddings.add(new DERInteger(KM_PAD_NONE));
        allPaddings.add(new DERInteger(KM_PAD_RSA_OAEP));
        allPaddings.add(new DERInteger(KM_PAD_RSA_PSS));
        allPaddings.add(new DERInteger(KM_PAD_RSA_PKCS1_1_5_ENCRYPT));
        allPaddings.add(new DERInteger(KM_PAD_RSA_PKCS1_1_5_SIGN));
        DERSet paddingSet = new DERSet(allPaddings);
        DERTaggedObject padding =
                new DERTaggedObject(true, removeTagType(KM_TAG_PADDING), paddingSet);
        DERTaggedObject noAuthRequired =
                new DERTaggedObject(true, removeTagType(KM_TAG_NO_AUTH_REQUIRED), DERNull.INSTANCE);
        //Build sequence
        DEREncodableVector allItems = new DEREncodableVector();
        allItems.add(purpose);
        allItems.add(algorithm);
        allItems.add(keySize);
        allItems.add(blockMode);
        allItems.add(digest);
        allItems.add(padding);
        allItems.add(noAuthRequired);
        return new DERSequence(allItems);
    }



    private DERSequence makeAesAuthList(int size) {<!-- -->
        return makeSymKeyAuthList(size, KM_ALGORITHM_AES);
    }

    private DERSequence makeSymKeyAuthList(int size, int algo) {<!-- -->
        DEREncodableVector allPurposes = new DEREncodableVector();
        allPurposes.add(new DERInteger(KM_PURPOSE_ENCRYPT));
        allPurposes.add(new DERInteger(KM_PURPOSE_DECRYPT));
        DERSet purposeSet = new DERSet(allPurposes);
        DERTaggedObject purpose =
                new DERTaggedObject(true, removeTagType(KM_TAG_PURPOSE), purposeSet);
        DERTaggedObject algorithm =
                new DERTaggedObject(true, removeTagType(KM_TAG_ALGORITHM), new DERInteger(algo));
        DERTaggedObject keySize =
                new DERTaggedObject(true, removeTagType(KM_TAG_KEY_SIZE), new DERInteger(size));
        DEREncodableVector allBlockModes = new DEREncodableVector();
        allBlockModes.add(new DERInteger(KM_MODE_ECB));
        allBlockModes.add(new DERInteger(KM_MODE_CBC));
        DERSet blockModeSet = new DERSet(allBlockModes);
        DERTaggedObject blockMode =
                new DERTaggedObject(true, removeTagType(KM_TAG_BLOCK_MODE), blockModeSet);
        DEREncodableVector allPaddings = new DEREncodableVector();
        allPaddings.add(new DERInteger(KM_PAD_PKCS7));
        allPaddings.add(new DERInteger(KM_PAD_NONE));
        DERSet paddingSet = new DERSet(allPaddings);
        DERTaggedObject padding =
                new DERTaggedObject(true, removeTagType(KM_TAG_PADDING), paddingSet);
        DERTaggedObject noAuthRequired =
                new DERTaggedObject(true, removeTagType(KM_TAG_NO_AUTH_REQUIRED), DERNull.INSTANCE);
        //Build sequence
        DEREncodableVector allItems = new DEREncodableVector();
        allItems.add(purpose);
        allItems.add(algorithm);
        allItems.add(keySize);
        allItems.add(blockMode);
        allItems.add(padding);
        allItems.add(noAuthRequired);
        return new DERSequence(allItems);
    }

    public byte[] wrapKey(PublicKey publicKey, byte[] keyMaterial, byte[] mask,
                          int keyFormat, DERSequence authorizationList) throws Exception {<!-- -->
        return wrapKey(publicKey, keyMaterial, mask, keyFormat, authorizationList, true);
    }

    public byte[] wrapKey(PublicKey publicKey, byte[] keyMaterial, byte[] mask,
                          int keyFormat, DERSequence authorizationList, boolean correctWrappingRequired)
            throws Exception {<!-- -->
        // Build description
        DEREncodableVector descriptionItems = new DEREncodableVector();
        descriptionItems.add(new DERInteger(keyFormat));
        descriptionItems.add(authorizationList);
        DERSequence wrappedKeyDescription = new DERSequence(descriptionItems);
        // Generate 12 byte initialization vector
        byte[] iv = new byte[12];
        random.nextBytes(iv);
        // Generate 256 bit AES key. This is the ephemeral key used to encrypt the secure key.
        byte[] aesKeyBytes = new byte[32];
        random.nextBytes(aesKeyBytes);
        //Encrypt ephemeral keys
        OAEPParameterSpec spec = new OAEPParameterSpec("SHA-256", "MGF1", MGF1ParameterSpec.SHA1, PSource.PSpecified.DEFAULT);
        Cipher pkCipher = Cipher.getInstance("RSA/ECB/OAEPPadding");
        if (correctWrappingRequired) {<!-- -->
            pkCipher.init(Cipher.ENCRYPT_MODE, publicKey, spec);
        } else {<!-- -->
            // Use incorrect OAEPParameters while initializing cipher. By default, main digest and
            // MGF1 digest are SHA-1 here.
            pkCipher.init(Cipher.ENCRYPT_MODE, publicKey);
        }
        byte[] encryptedEphemeralKeys = pkCipher.doFinal(aesKeyBytes);
        // Encrypt secure key
        Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
        SecretKeySpec secretKeySpec = new SecretKeySpec(aesKeyBytes, "AES");
        GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(GCM_TAG_SIZE, iv);
        cipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, gcmParameterSpec);
        byte[] aad = wrappedKeyDescription.getEncoded();
        cipher.updateAAD(aad);
        byte[] encryptedSecureKey = cipher.doFinal(keyMaterial);
        // Get GCM tag. Java puts the tag at the end of the ciphertext data :(
        int len = encryptedSecureKey.length;
        int tagSize = (GCM_TAG_SIZE / 8);
        byte[] tag = Arrays.copyOfRange(encryptedSecureKey, len - tagSize, len);
        // Remove GCM tag from end of output
        encryptedSecureKey = Arrays.copyOfRange(encryptedSecureKey, 0, len - tagSize);
        // Build ASN.1 DER encoded sequence WrappedKeyWrapper
        DEREncodableVector items = new DEREncodableVector();
        items.add(new DERInteger(WRAPPED_FORMAT_VERSION));
        items.add(new DEROctetString(encryptedEphemeralKeys));
        items.add(new DEROctetString(iv));
        items.add(wrappedKeyDescription);
        items.add(new DEROctetString(encryptedSecureKey));
        items.add(new DEROctetString(tag));
        return new DERSequence(items).getEncoded(ASN1Encoding.DER);
    }

    public void importWrappedKey(byte[] wrappedKey, String wrappingKeyAlias) throws Exception {<!-- -->
        KeyStore keyStore = KeyStore.getInstance("AndroidKeyStore");
        keyStore.load(null, null);
        AlgorithmParameterSpec spec = new KeyGenParameterSpec.Builder(wrappingKeyAlias,
                KeyProperties.PURPOSE_WRAP_KEY)
                .setDigests(KeyProperties.DIGEST_SHA256)
                .build();
        KeyStore.Entry wrappedKeyEntry = new WrappedKeyEntry(wrappedKey, wrappingKeyAlias,
                "RSA/ECB/OAEPPadding", spec);
        keyStore.setEntry(ALIAS, wrappedKeyEntry, null);
    }
    public void importWrappedKey(byte[] wrappedKey) throws Exception {<!-- -->
        importWrappedKey(wrappedKey, WRAPPING_KEY_ALIAS);
    }

    public void testRSA() throws Exception{<!-- -->
        KeyPairGenerator kpg = KeyPairGenerator.getInstance("RSA");
        // Both TEE and Strongbox must support 2048-bit keys.
        int keySize = 2048;
        kpg.initialize(keySize);
        KeyPair kp = kpg.generateKeyPair();
        PublicKey publicKey = kp.getPublic();
        showMessage("publicKey:" + Hex.toHexString(publicKey.getEncoded()));
        PrivateKey privateKey = kp.getPrivate();
        showMessage("privateKey:" + Hex.toHexString(privateKey.getEncoded()));

        byte[] keyMaterial = privateKey.getEncoded();
        byte[] mask = new byte[32]; // Zero mask
        try {<!-- -->
            importWrappedKey(wrapKey(
                    genKeyPair(WRAPPING_KEY_ALIAS, false).getPublic(),
                    keyMaterial,
                    mask,
                    KM_KEY_FORMAT_PKCS8,
                    makeRsaAuthList(keySize)));
        } catch (SecureKeyImportUnavailableException e) {<!-- -->
            e.printStackTrace();
        }
        //Use Key
        KeyStore keyStore = KeyStore.getInstance("AndroidKeyStore");
        keyStore.load(null, null);
        if(!keyStore.containsAlias(ALIAS)){<!-- -->
            showMessage("fail to import RSA key");
            return;
        }
        String plaintext = "hello, world";
        Key importedKey = keyStore.getKey(ALIAS, null);

        // Encrypt with KS private key, then decrypt with local public key.
        Cipher c = Cipher.getInstance("RSA/ECB/PKCS1Padding");
        c.init(Cipher.ENCRYPT_MODE, importedKey);
        byte[] encrypted = c.doFinal(plaintext.getBytes());
        c.init(Cipher.DECRYPT_MODE, publicKey);
        showMessage("is same:" + (new String(c.doFinal(encrypted))).equals(plaintext));

        // Encrypt with local public key, then decrypt with KS private key.
        c.init(Cipher.ENCRYPT_MODE, publicKey);
        encrypted = c.doFinal(plaintext.getBytes());
        c.init(Cipher.DECRYPT_MODE, importedKey);
        showMessage("is same:" + (new String(c.doFinal(encrypted))).equals(plaintext));

        // Sign with KS private key, then verify with local public key.
        Signature s = Signature.getInstance("SHA256withRSA");
        s.initSign((PrivateKey) importedKey);
        s.update(plaintext.getBytes());
        byte[] signature = s.sign();
        s.initVerify(publicKey);
        s.update(plaintext.getBytes());
        showMessage("result:" + (s.verify(signature)));
    }



    public void testAES() throws Exception {<!-- -->
        // Construct the original AES key
        KeyGenerator kg = KeyGenerator.getInstance("AES");
        kg.init(256);
        Key swKey = kg.generateKey();
        byte[] keyMaterial = swKey.getEncoded();
        showMessage("keyMaterialHex:" + Hex.toHexString(keyMaterial));
        // Build the wrapped RSA key
        KeyPair keyPair = genKeyPair(WRAPPING_KEY_ALIAS, true);
        PublicKey aPublic = keyPair.getPublic();
        byte[] aPublicBytes = aPublic.getEncoded();
        showMessage("aPublic:" + Hex.toHexString(aPublicBytes));
        byte[] mask = new byte[32]; // Zero mask
        DERSequence asn1Encodables = makeAesAuthList(keyMaterial.length * 8);
        byte[] bytes = wrapKey(aPublic, keyMaterial, mask, KM_KEY_FORMAT_RAW, asn1Encodables);
        // import
        importWrappedKey(bytes);

        //Use Key
        KeyStore keyStore = KeyStore.getInstance("AndroidKeyStore");
        keyStore.load(null, null);
        if(!keyStore.containsAlias(ALIAS)){<!-- -->
           showMessage("fail to import AES key");
           return;
        }
        Key importedKey = keyStore.getKey(ALIAS, null);
        String plaintext = "hello, world";
        Cipher c = Cipher.getInstance("AES/ECB/PKCS7Padding");
        c.init(Cipher.ENCRYPT_MODE, importedKey);
        byte[] encrypted = c.doFinal(plaintext.getBytes());

        // Decrypt using key imported into keystore.
        c = Cipher.getInstance("AES/ECB/PKCS7Padding");
        c.init(Cipher.DECRYPT_MODE, importedKey);
        showMessage("is same:" + (new String(c.doFinal(encrypted)).equals(plaintext)));


        // Decrypt using local software copy of the key.
        c = Cipher.getInstance("AES/ECB/PKCS7Padding");
        c.init(Cipher.DECRYPT_MODE, swKey);
        showMessage("is same:" + (new String(c.doFinal(encrypted)).equals(plaintext)));
    }
}

Dependencies (BouncyCastle jars placed in `libs/`, declared in the module's build.gradle):

implementation files('libs/bcpkix-jdk15to18-1.64.jar')
implementation files('libs/bcprov-jdk15to18-1.64.jar')