This is xnu-12377.1.9. See this file in:
/*
 * Copyright (c) 2022 Apple Computer, Inc. All rights reserved.
 *
 * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
 *
 * This file contains Original Code and/or Modifications of Original Code
 * as defined in and that are subject to the Apple Public Source License
 * Version 2.0 (the 'License'). You may not use this file except in
 * compliance with the License. The rights granted to you under the License
 * may not be used to create, or enable the creation or redistribution of,
 * unlawful or unlicensed copies of an Apple operating system, or to
 * circumvent, violate, or enable the circumvention or violation of, any
 * terms of an Apple operating system software license agreement.
 *
 * Please obtain a copy of the License at
 * http://www.opensource.apple.com/apsl/ and read it before using this file.
 *
 * The Original Code and all software distributed under the License are
 * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
 * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
 * Please see the License for the specific language governing rights and
 * limitations under the License.
 *
 * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
 */

#include <mach/thread_act.h>
#include <stdint.h>
#include <stdlib.h>
#include <sys/sysctl.h>

#include "arm_matrix.h"

const static unsigned int SME_Z_VECTORS = 32;
const static unsigned int SME_P_VECTORS = 16;

static unsigned int
sme_version(void)
{
	static unsigned int ret = 0;
	static bool already_read = false;

	if (!already_read) {
		size_t size = sizeof(unsigned int);
		unsigned int feat_sme, feat_sme2;
		sysctlbyname("hw.optional.arm.FEAT_SME", &feat_sme, &size, NULL, 0);
		sysctlbyname("hw.optional.arm.FEAT_SME2", &feat_sme2, &size, NULL, 0);

		if (feat_sme2) {
			ret = 2;
		} else if (feat_sme) {
			ret = 1;
		} else {
			ret = 0;
		}

		already_read = true;
	}

	return ret;
}

static uint16_t
arm_sme_svl_b(void)
{
	uint64_t ret = 0;
	asm volatile (
                "rdsvl	%[ret], #1"
                : [ret] "=r"(ret)
        );
	return (uint16_t)ret;
}

static size_t
sme_za_size(void)
{
	return arm_sme_svl_b() * arm_sme_svl_b();
}

static size_t
sme_z_size(void)
{
	return arm_sme_svl_b() * SME_Z_VECTORS;
}

static size_t
sme_p_size(void)
{
	return arm_sme_svl_b() * SME_P_VECTORS / 8;
}

static size_t
sme_zt0_size(void)
{
	if (sme_version() >= 2) {
		return 64;
	} else {
		return 0;
	}
}

static size_t
sme_tpidr2_size(void)
{
	return sizeof(uint64_t);
}

static inline uint8_t *
sme_za(void *addr)
{
	return addr;
}

static inline const uint8_t *
const_sme_za(const void *addr)
{
	return addr;
}

static inline uint8_t *
sme_z(void *addr)
{
	return sme_za(addr) + sme_za_size();
}

static inline const uint8_t *
const_sme_z(const void *addr)
{
	return const_sme_za(addr) + sme_za_size();
}

static inline uint8_t *
sme_p(void *addr)
{
	return sme_z(addr) + sme_z_size();
}

static inline const uint8_t *
const_sme_p(const void *addr)
{
	return const_sme_z(addr) + sme_z_size();
}

static inline uint8_t *
sme_zt0(void *addr)
{
	return sme_p(addr) + sme_p_size();
}

static inline const uint8_t *
const_sme_zt0(const void *addr)
{
	return const_sme_p(addr) + sme_p_size();
}

static size_t
sme_data_size(void)
{
	return sme_za_size() + sme_z_size() + sme_p_size() + sme_zt0_size() + sme_tpidr2_size();
}

static inline void
set_sme_tpidr2_el0(void *addr, uint64_t val)
{
	uint64_t *ptr = (uint64_t *)(sme_zt0(addr) + sme_zt0_size());
	*ptr = val;
}

static inline uint64_t
get_sme_tpidr2_el0(const void *addr)
{
	const uint64_t *ptr = (const uint64_t *)(const_sme_zt0(addr) + sme_zt0_size());
	return *ptr;
}

static void *
sme_alloc_data(void)
{
	return malloc(sme_data_size());
}

static bool
sme_is_available(void)
{
	return sme_version() > 0;
}

static void
sme_start(void)
{
	asm volatile ("smstart");
}

static void
sme_stop(void)
{
	asm volatile ("smstop");
}

static void
sme_load_one_vector(const void *addr)
{
	asm volatile (
                "mov    w12, #0"                "\n"
                "ldr    za[w12, #0], [%[addr]]" "\n"
                :
                : [addr] "r"(addr)
                : "w12"
        );
}

static void
sme_load_data(const void *addr)
{
	const uint8_t *za = const_sme_za(addr);
	const uint8_t *z = const_sme_z(addr);
	const uint8_t *p = const_sme_p(addr);
	uint16_t svl_b = arm_sme_svl_b();

	for (register uint16_t i asm("w12") = 0; i < svl_b; i += 16) {
		asm volatile (
                        "ldr    za[%w[i],  #0], [%[addr],  #0, mul vl]"   "\n"
                        "ldr    za[%w[i],  #1], [%[addr],  #1, mul vl]"   "\n"
                        "ldr    za[%w[i],  #2], [%[addr],  #2, mul vl]"   "\n"
                        "ldr    za[%w[i],  #3], [%[addr],  #3, mul vl]"   "\n"
                        "ldr    za[%w[i],  #4], [%[addr],  #4, mul vl]"   "\n"
                        "ldr    za[%w[i],  #5], [%[addr],  #5, mul vl]"   "\n"
                        "ldr    za[%w[i],  #6], [%[addr],  #6, mul vl]"   "\n"
                        "ldr    za[%w[i],  #7], [%[addr],  #7, mul vl]"   "\n"
                        "ldr    za[%w[i],  #8], [%[addr],  #8, mul vl]"   "\n"
                        "ldr    za[%w[i],  #9], [%[addr],  #9, mul vl]"   "\n"
                        "ldr    za[%w[i], #10], [%[addr], #10, mul vl]"   "\n"
                        "ldr    za[%w[i], #11], [%[addr], #11, mul vl]"   "\n"
                        "ldr    za[%w[i], #12], [%[addr], #12, mul vl]"   "\n"
                        "ldr    za[%w[i], #13], [%[addr], #13, mul vl]"   "\n"
                        "ldr    za[%w[i], #14], [%[addr], #14, mul vl]"   "\n"
                        "ldr    za[%w[i], #15], [%[addr], #15, mul vl]"   "\n"
                        :
                        : [i] "r"(i),
                          [addr] "r"(za + (i * svl_b))
                );
	}

	asm volatile (
                "ldr    z0, [%[z],   #0, mul vl]"        "\n"
                "ldr    z1, [%[z],   #1, mul vl]"        "\n"
                "ldr    z2, [%[z],   #2, mul vl]"        "\n"
                "ldr    z3, [%[z],   #3, mul vl]"        "\n"
                "ldr    z4, [%[z],   #4, mul vl]"        "\n"
                "ldr    z5, [%[z],   #5, mul vl]"        "\n"
                "ldr    z6, [%[z],   #6, mul vl]"        "\n"
                "ldr    z7, [%[z],   #7, mul vl]"        "\n"
                "ldr    z8, [%[z],   #8, mul vl]"        "\n"
                "ldr    z9, [%[z],   #9, mul vl]"        "\n"
                "ldr   z10, [%[z],  #10, mul vl]"        "\n"
                "ldr   z11, [%[z],  #11, mul vl]"        "\n"
                "ldr   z12, [%[z],  #12, mul vl]"        "\n"
                "ldr   z13, [%[z],  #13, mul vl]"        "\n"
                "ldr   z14, [%[z],  #14, mul vl]"        "\n"
                "ldr   z15, [%[z],  #15, mul vl]"        "\n"
                "ldr   z16, [%[z],  #16, mul vl]"        "\n"
                "ldr   z17, [%[z],  #17, mul vl]"        "\n"
                "ldr   z18, [%[z],  #18, mul vl]"        "\n"
                "ldr   z19, [%[z],  #19, mul vl]"        "\n"
                "ldr   z20, [%[z],  #20, mul vl]"        "\n"
                "ldr   z21, [%[z],  #21, mul vl]"        "\n"
                "ldr   z22, [%[z],  #22, mul vl]"        "\n"
                "ldr   z23, [%[z],  #23, mul vl]"        "\n"
                "ldr   z24, [%[z],  #24, mul vl]"        "\n"
                "ldr   z25, [%[z],  #25, mul vl]"        "\n"
                "ldr   z26, [%[z],  #26, mul vl]"        "\n"
                "ldr   z27, [%[z],  #27, mul vl]"        "\n"
                "ldr   z28, [%[z],  #28, mul vl]"        "\n"
                "ldr   z29, [%[z],  #29, mul vl]"        "\n"
                "ldr   z30, [%[z],  #30, mul vl]"        "\n"
                "ldr   z31, [%[z],  #31, mul vl]"        "\n"
                :
                : [z] "r"(z)
        );

	asm volatile (
                "ldr     p0, [%[p],  #0, mul vl]"        "\n"
                "ldr     p1, [%[p],  #1, mul vl]"        "\n"
                "ldr     p2, [%[p],  #2, mul vl]"        "\n"
                "ldr     p3, [%[p],  #3, mul vl]"        "\n"
                "ldr     p4, [%[p],  #4, mul vl]"        "\n"
                "ldr     p5, [%[p],  #5, mul vl]"        "\n"
                "ldr     p6, [%[p],  #6, mul vl]"        "\n"
                "ldr     p7, [%[p],  #7, mul vl]"        "\n"
                "ldr     p8, [%[p],  #8, mul vl]"        "\n"
                "ldr     p9, [%[p],  #9, mul vl]"        "\n"
                "ldr    p10, [%[p], #10, mul vl]"        "\n"
                "ldr    p11, [%[p], #11, mul vl]"        "\n"
                "ldr    p12, [%[p], #12, mul vl]"        "\n"
                "ldr    p13, [%[p], #13, mul vl]"        "\n"
                "ldr    p14, [%[p], #14, mul vl]"        "\n"
                "ldr    p15, [%[p], #15, mul vl]"        "\n"
                :
                : [p] "r"(p)
        );

	if (sme_zt0_size()) {
		const uint8_t *zt0 = const_sme_zt0(addr);
		asm volatile (
                        "ldr	zt0, [%[zt0]]"
                        :
                        : [zt0] "r"(zt0)
                );
	}

	__builtin_arm_wsr64("TPIDR2_EL0", get_sme_tpidr2_el0(addr));
}

static void
sme_store_data(void *addr)
{
	uint8_t *za = sme_za(addr);
	uint8_t *z = sme_z(addr);
	uint8_t *p = sme_p(addr);
	uint16_t svl_b = arm_sme_svl_b();

	for (register uint16_t i asm("w12") = 0; i < svl_b; i += 16) {
		asm volatile (
                        "str    za[%w[i],  #0], [%[addr],  #0, mul vl]"   "\n"
                        "str    za[%w[i],  #1], [%[addr],  #1, mul vl]"   "\n"
                        "str    za[%w[i],  #2], [%[addr],  #2, mul vl]"   "\n"
                        "str    za[%w[i],  #3], [%[addr],  #3, mul vl]"   "\n"
                        "str    za[%w[i],  #4], [%[addr],  #4, mul vl]"   "\n"
                        "str    za[%w[i],  #5], [%[addr],  #5, mul vl]"   "\n"
                        "str    za[%w[i],  #6], [%[addr],  #6, mul vl]"   "\n"
                        "str    za[%w[i],  #7], [%[addr],  #7, mul vl]"   "\n"
                        "str    za[%w[i],  #8], [%[addr],  #8, mul vl]"   "\n"
                        "str    za[%w[i],  #9], [%[addr],  #9, mul vl]"   "\n"
                        "str    za[%w[i], #10], [%[addr], #10, mul vl]"   "\n"
                        "str    za[%w[i], #11], [%[addr], #11, mul vl]"   "\n"
                        "str    za[%w[i], #12], [%[addr], #12, mul vl]"   "\n"
                        "str    za[%w[i], #13], [%[addr], #13, mul vl]"   "\n"
                        "str    za[%w[i], #14], [%[addr], #14, mul vl]"   "\n"
                        "str    za[%w[i], #15], [%[addr], #15, mul vl]"   "\n"
                        :
                        : [i] "r"(i),
                          [addr] "r"(za + (i * svl_b))
                );
	}

	asm volatile (
                "str    z0, [%[z],   #0, mul vl]"        "\n"
                "str    z1, [%[z],   #1, mul vl]"        "\n"
                "str    z2, [%[z],   #2, mul vl]"        "\n"
                "str    z3, [%[z],   #3, mul vl]"        "\n"
                "str    z4, [%[z],   #4, mul vl]"        "\n"
                "str    z5, [%[z],   #5, mul vl]"        "\n"
                "str    z6, [%[z],   #6, mul vl]"        "\n"
                "str    z7, [%[z],   #7, mul vl]"        "\n"
                "str    z8, [%[z],   #8, mul vl]"        "\n"
                "str    z9, [%[z],   #9, mul vl]"        "\n"
                "str   z10, [%[z],  #10, mul vl]"        "\n"
                "str   z11, [%[z],  #11, mul vl]"        "\n"
                "str   z12, [%[z],  #12, mul vl]"        "\n"
                "str   z13, [%[z],  #13, mul vl]"        "\n"
                "str   z14, [%[z],  #14, mul vl]"        "\n"
                "str   z15, [%[z],  #15, mul vl]"        "\n"
                "str   z16, [%[z],  #16, mul vl]"        "\n"
                "str   z17, [%[z],  #17, mul vl]"        "\n"
                "str   z18, [%[z],  #18, mul vl]"        "\n"
                "str   z19, [%[z],  #19, mul vl]"        "\n"
                "str   z20, [%[z],  #20, mul vl]"        "\n"
                "str   z21, [%[z],  #21, mul vl]"        "\n"
                "str   z22, [%[z],  #22, mul vl]"        "\n"
                "str   z23, [%[z],  #23, mul vl]"        "\n"
                "str   z24, [%[z],  #24, mul vl]"        "\n"
                "str   z25, [%[z],  #25, mul vl]"        "\n"
                "str   z26, [%[z],  #26, mul vl]"        "\n"
                "str   z27, [%[z],  #27, mul vl]"        "\n"
                "str   z28, [%[z],  #28, mul vl]"        "\n"
                "str   z29, [%[z],  #29, mul vl]"        "\n"
                "str   z30, [%[z],  #30, mul vl]"        "\n"
                "str   z31, [%[z],  #31, mul vl]"        "\n"
                :
                : [z] "r"(z)
        );

	asm volatile (
                "str     p0, [%[p],  #0, mul vl]"        "\n"
                "str     p1, [%[p],  #1, mul vl]"        "\n"
                "str     p2, [%[p],  #2, mul vl]"        "\n"
                "str     p3, [%[p],  #3, mul vl]"        "\n"
                "str     p4, [%[p],  #4, mul vl]"        "\n"
                "str     p5, [%[p],  #5, mul vl]"        "\n"
                "str     p6, [%[p],  #6, mul vl]"        "\n"
                "str     p7, [%[p],  #7, mul vl]"        "\n"
                "str     p8, [%[p],  #8, mul vl]"        "\n"
                "str     p9, [%[p],  #9, mul vl]"        "\n"
                "str    p10, [%[p], #10, mul vl]"        "\n"
                "str    p11, [%[p], #11, mul vl]"        "\n"
                "str    p12, [%[p], #12, mul vl]"        "\n"
                "str    p13, [%[p], #13, mul vl]"        "\n"
                "str    p14, [%[p], #14, mul vl]"        "\n"
                "str    p15, [%[p], #15, mul vl]"        "\n"
                :
                : [p] "r"(p)
        );

	if (sme_zt0_size()) {
		uint8_t *zt0 = sme_zt0(addr);
		asm volatile (
                        "str	zt0, [%[zt0]]"
                        :
                        : [zt0] "r"(zt0)
                );
	}

	set_sme_tpidr2_el0(addr, __builtin_arm_rsr64("TPIDR2_EL0"));
}

static kern_return_t
sme_thread_get_state(thread_act_t thread, void *addr)
{
	uint8_t *za = sme_za(addr);
	uint8_t *z = sme_z(addr);
	uint8_t *p = sme_p(addr);
	uint16_t svl_b = arm_sme_svl_b();

	arm_sme_state_t sme_state;
	mach_msg_type_number_t sme_count = ARM_SME_STATE_COUNT;
	kern_return_t err = thread_get_state(thread, ARM_SME_STATE, (thread_state_t)&sme_state, &sme_count);
	if (err) {
		return err;
	}
	set_sme_tpidr2_el0(addr, sme_state.__tpidr2_el0);

	arm_sme_za_state_t za_state;
	mach_msg_type_number_t za_count = ARM_SME_ZA_STATE_COUNT;
	err = thread_get_state(thread, ARM_SME_ZA_STATE1, (thread_state_t)&za_state, &za_count);
	if (err) {
		return err;
	}

	arm_sve_z_state_t z_state1, z_state2;
	mach_msg_type_number_t z_streaming_count = ARM_SVE_Z_STATE_COUNT;
	err = thread_get_state(thread, ARM_SVE_Z_STATE1, (thread_state_t)&z_state1, &z_streaming_count);
	if (err) {
		return err;
	}
	err = thread_get_state(thread, ARM_SVE_Z_STATE2, (thread_state_t)&z_state2, &z_streaming_count);
	if (err) {
		return err;
	}

	arm_sve_p_state_t p_state;
	mach_msg_type_number_t p_streaming_count = ARM_SVE_P_STATE_COUNT;
	err = thread_get_state(thread, ARM_SVE_P_STATE, (thread_state_t)&p_state, &p_streaming_count);
	if (err) {
		return err;
	}

	memcpy(za, za_state.__za, svl_b * svl_b);

	size_t z_elem_size = svl_b;
	for (int i = 0; i < 16; i++) {
		memcpy(z, z_state1.__z[i], z_elem_size);
		z += z_elem_size;
	}
	for (int i = 0; i < 16; i++) {
		memcpy(z, z_state2.__z[i], z_elem_size);
		z += z_elem_size;
	}

	size_t p_elem_size = svl_b / 8;
	for (int i = 0; i < 16; i++) {
		memcpy(p, p_state.__p[i], p_elem_size);
		p += p_elem_size;
	}

	if (sme_zt0_size()) {
		uint8_t *zt0 = sme_zt0(addr);

		arm_sme2_state_t sme2_state;
		mach_msg_type_number_t sme2_count = ARM_SME2_STATE_COUNT;
		err = thread_get_state(thread, ARM_SME2_STATE, (thread_state_t)&sme2_state, &sme2_count);
		if (err) {
			return err;
		}

		memcpy(zt0, sme2_state.__zt0, sizeof(sme2_state.__zt0));
	}

	return KERN_SUCCESS;
}

static kern_return_t
sme_thread_set_state(thread_act_t thread, const void *addr)
{
	const uint8_t *za = const_sme_za(addr);
	const uint8_t *z = const_sme_z(addr);
	const uint8_t *p = const_sme_p(addr);
	uint16_t svl_b = arm_sme_svl_b();

	arm_sme_state_t sme_state;
	sme_state.__svcr = 0x3;
	sme_state.__svl_b = svl_b;
	sme_state.__tpidr2_el0 = get_sme_tpidr2_el0(addr);

	arm_sme_za_state_t za_state;
	memcpy(za_state.__za, za, svl_b * svl_b);

	arm_sve_z_state_t z_state1, z_state2;
	size_t z_elem_size = svl_b;
	for (int i = 0; i < 16; i++) {
		memcpy(z_state1.__z[i], z, z_elem_size);
		z += z_elem_size;
	}
	for (int i = 0; i < 16; i++) {
		memcpy(z_state2.__z[i], z, z_elem_size);
		z += z_elem_size;
	}

	arm_sve_p_state_t p_state;
	size_t p_elem_size = svl_b / 8;
	for (int i = 0; i < 16; i++) {
		memcpy(p_state.__p[i], p, p_elem_size);
		p += p_elem_size;
	}

	kern_return_t err = thread_set_state(thread, ARM_SME_STATE, (thread_state_t)&sme_state, ARM_SME_STATE_COUNT);
	if (err) {
		return err;
	}

	err = thread_set_state(thread, ARM_SVE_Z_STATE1, (thread_state_t)&z_state1, ARM_SVE_Z_STATE_COUNT);
	if (err) {
		return err;
	}

	err = thread_set_state(thread, ARM_SVE_Z_STATE2, (thread_state_t)&z_state2, ARM_SVE_Z_STATE_COUNT);
	if (err) {
		return err;
	}

	err = thread_set_state(thread, ARM_SVE_P_STATE, (thread_state_t)&p_state, ARM_SVE_P_STATE_COUNT);
	if (err) {
		return err;
	}

	err = thread_set_state(thread, ARM_SME_ZA_STATE1, (thread_state_t)&za_state, ARM_SME_ZA_STATE_COUNT);
	if (err) {
		return err;
	}

	if (sme_zt0_size()) {
		const uint8_t *zt0 = const_sme_zt0(addr);

		arm_sme2_state_t sme2_state;
		memcpy(sme2_state.__zt0, zt0, sizeof(sme2_state.__zt0));

		err = thread_set_state(thread, ARM_SME2_STATE, (thread_state_t)&sme2_state, ARM_SME2_STATE_COUNT);
		if (err) {
			return err;
		}
	}

	return KERN_SUCCESS;
}

const struct arm_matrix_operations sme_operations = {
	.name = "SME",

	.data_size = sme_data_size,
	.alloc_data = sme_alloc_data,

	.is_available = sme_is_available,
	.start = sme_start,
	.stop = sme_stop,

	.load_one_vector = sme_load_one_vector,
	.load_data = sme_load_data,
	.store_data = sme_store_data,

	.thread_get_state = sme_thread_get_state,
	.thread_set_state = sme_thread_set_state,
};