This is xnu-12377.1.9. See this file in:
/*
 * Copyright (c) 2024 Apple Inc. All rights reserved.
 *
 * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
 *
 * This file contains Original Code and/or Modifications of Original Code
 * as defined in and that are subject to the Apple Public Source License
 * Version 2.0 (the 'License'). You may not use this file except in
 * compliance with the License. The rights granted to you under the License
 * may not be used to create, or enable the creation or redistribution of,
 * unlawful or unlicensed copies of an Apple operating system, or to
 * circumvent, violate, or enable the circumvention or violation of, any
 * terms of an Apple operating system software license agreement.
 *
 * Please obtain a copy of the License at
 * http://www.opensource.apple.com/apsl/ and read it before using this file.
 *
 * The Original Code and all software distributed under the License are
 * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
 * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
 * Please see the License for the specific language governing rights and
 * limitations under the License.
 *
 * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
 */

#include <vsock_helpers.h>

static int
_vsock_new_socket(uint16_t protocol)
{
	int sock = socket(AF_VSOCK, SOCK_STREAM, protocol);
	if (sock < 0 && errno == ENODEV) {
		T_SKIP("no vsock transport available");
	}
	return sock;
}

int
vsock_new_socket(void)
{
	int sock = _vsock_new_socket(VSOCK_PROTO_STANDARD);
	T_ASSERT_GT(sock, 0, "create new vsock socket");
	return sock;
}

int
vsock_private_new_socket(void)
{
	int sock = _vsock_new_socket(VSOCK_PROTO_PRIVATE);
	T_ASSERT_GT(sock, 0, "create new private vsock socket");
	return sock;
}

uint32_t
vsock_get_local_cid(int socket)
{
	uint32_t cid = 0;
	int result = ioctl(socket, IOCTL_VM_SOCKETS_GET_LOCAL_CID, &cid);
	T_ASSERT_POSIX_SUCCESS(result, "vsock ioctl cid successful");
	T_ASSERT_GT(cid, VMADDR_CID_HOST, "cid is set");
	T_ASSERT_NE(cid, VMADDR_CID_ANY, "cid is valid");

	return cid;
}

static int
_vsock_bind(uint32_t cid, uint32_t port, struct sockaddr_vm * addr, int *socket, uint16_t protocol)
{
	int sock = _vsock_new_socket(protocol);
	if (sock < 0) {
		return sock;
	}
	T_ASSERT_GT(sock, 0, "create new vsock socket");
	*socket = sock;

	bzero(addr, sizeof(*addr));
	addr->svm_port = port;
	addr->svm_cid = cid;

	return bind(*socket, (struct sockaddr *) addr, sizeof(*addr));
}

int
vsock_bind(uint32_t cid, uint32_t port, struct sockaddr_vm * addr, int *socket)
{
	return _vsock_bind(cid, port, addr, socket, VSOCK_PROTO_STANDARD);
}

int
vsock_private_bind(uint32_t cid, uint32_t port, struct sockaddr_vm * addr, int *socket)
{
	return _vsock_bind(cid, port, addr, socket, VSOCK_PROTO_PRIVATE);
}

int
vsock_listen(uint32_t cid, uint32_t port, struct sockaddr_vm * addr, int backlog, int *socket)
{
	int result = vsock_bind(cid, port, addr, socket);
	T_ASSERT_POSIX_SUCCESS(result, "vsock bind");
	return listen(*socket, backlog);
}

int
vsock_private_listen(uint32_t cid, uint32_t port, struct sockaddr_vm * addr, int backlog, int *socket)
{
	int result = vsock_private_bind(cid, port, addr, socket);
	T_ASSERT_POSIX_SUCCESS(result, "private vsock bind");
	return listen(*socket, backlog);
}

static int
_vsock_connect(uint32_t cid, uint32_t port, int *socket, uint16_t protocol)
{
	int sock = _vsock_new_socket(protocol);
	if (sock < 0) {
		return sock;
	}
	T_ASSERT_GT(sock, 0, "create new vsock socket");
	*socket = sock;

	struct sockaddr_vm addr = (struct sockaddr_vm) {
		.svm_cid = cid,
		.svm_port = port,
	};
	return connect(*socket, (struct sockaddr *)&addr, sizeof(addr));
}

int
vsock_connect(uint32_t cid, uint32_t port, int *socket)
{
	return _vsock_connect(cid, port, socket, VSOCK_PROTO_STANDARD);
}

int
vsock_private_connect(uint32_t cid, uint32_t port, int *socket)
{
	return _vsock_connect(cid, port, socket, VSOCK_PROTO_PRIVATE);
}

struct sockaddr_vm
vsock_getsockname(int socket)
{
	struct sockaddr_vm addr;
	socklen_t length = sizeof(addr);
	int result = getsockname(socket, (struct sockaddr *)&addr, &length);
	T_ASSERT_POSIX_SUCCESS(result, "vsock getsockname");
	T_ASSERT_EQ_INT((int) sizeof(addr), length, "correct address length");
	T_ASSERT_GT(addr.svm_port, 0, "bound to non-zero local port");
	return addr;
}

void
vsock_close(int socket)
{
	int result = close(socket);
	T_ASSERT_POSIX_SUCCESS(result, "vsock close");
}

void
vsock_connect_peers(uint32_t cid, uint32_t port, int backlog, int *socketA, int *socketB)
{
	// Listen.
	struct sockaddr_vm addr;
	int listen_socket;
	int result = vsock_listen(cid, port, &addr, backlog, &listen_socket);
	T_ASSERT_POSIX_SUCCESS(result, "vsock listen");

	const uint32_t connection_cid = vsock_get_local_cid(listen_socket);

	// Connect.
	int connect_socket;
	result = vsock_connect(connection_cid, addr.svm_port, &connect_socket);
	T_ASSERT_POSIX_SUCCESS(result, "vsock connect");

	// Accept.
	struct sockaddr_vm accepted_addr;
	socklen_t addrlen = sizeof(accepted_addr);
	int accepted_socket = accept(listen_socket, (struct sockaddr *)&accepted_addr, &addrlen);
	T_ASSERT_GT(accepted_socket, 0, "accepted socket");
	T_ASSERT_EQ_INT((int) sizeof(accepted_addr), addrlen, "correct address length");
	T_ASSERT_EQ_INT(connection_cid, accepted_addr.svm_cid, "same cid");
	T_ASSERT_NE_INT(VMADDR_CID_ANY, accepted_addr.svm_port, "some valid port");
	T_ASSERT_NE_INT(0, accepted_addr.svm_port, "some non-zero port");

	*socketA = connect_socket;
	*socketB = accepted_socket;
}

void
vsock_send(int socket, char *msg)
{
	T_ASSERT_NOTNULL(msg, "send message is not null");
	ssize_t sent_bytes = send(socket, msg, strlen(msg), 0);
	T_ASSERT_EQ_LONG(strlen(msg), (unsigned long)sent_bytes, "sent all bytes");
}

void
vsock_disable_sigpipe(int socket)
{
	int on = 1;
	int result = setsockopt(socket, SOL_SOCKET, SO_NOSIGPIPE, &on, sizeof(on));
	T_ASSERT_POSIX_SUCCESS(result, "vsock disable SIGPIPE");
}

bool
vsock_address_exists(struct xvsockpgen *buffer, struct sockaddr_vm addr)
{
	struct xvsockpgen *xvg = buffer;
	struct xvsockpgen *oxvg = buffer;

	bool found = false;
	for (xvg = (struct xvsockpgen *)((char *)xvg + xvg->xvg_len);
	    xvg->xvg_len > sizeof(struct xvsockpgen);
	    xvg = (struct xvsockpgen *)((char *)xvg + xvg->xvg_len)) {
		struct xvsockpcb *xpcb = (struct xvsockpcb *)xvg;

		/* Ignore PCBs which were freed during copyout. */
		if (xpcb->xvp_gencnt > oxvg->xvg_gen) {
			continue;
		}

		if (xpcb->xvp_local_cid == addr.svm_cid && xpcb->xvp_remote_cid == VMADDR_CID_ANY &&
		    xpcb->xvp_local_port == addr.svm_port && xpcb->xvp_remote_port == VMADDR_PORT_ANY) {
			found = true;
			break;
		}
	}

	T_ASSERT_NE(xvg, oxvg, "first and last xvsockpgen were returned");

	return found;
}

uint32_t
vsock_get_available_port(void)
{
	int socket;
	struct sockaddr_vm addr;
	int result = vsock_bind(VMADDR_CID_ANY, VMADDR_PORT_ANY, &addr, &socket);
	T_ASSERT_POSIX_SUCCESS(result, "vsock bind to any port");

	const struct sockaddr_vm bound_addr = vsock_getsockname(socket);
	const uint32_t port = bound_addr.svm_port;
	T_ASSERT_NE_INT(port, VMADDR_PORT_ANY, "port is specified");

	vsock_close(socket);

	return port;
}

int
vsock_bind_family(sa_family_t family)
{
	int socket = vsock_new_socket();
	const uint32_t port = vsock_get_available_port();

	struct sockaddr_vm addr = (struct sockaddr_vm) {
		.svm_family = family,
		.svm_cid = VMADDR_CID_ANY,
		.svm_port = port,
	};

	return bind(socket, (struct sockaddr *) &addr, sizeof(addr));
}