Home | History | Annotate | Line # | Download | only in unfdpass
      1 /*	$NetBSD: unfdpass.c,v 1.12 2021/08/08 20:54:48 nia Exp $	*/
      2 
      3 /*-
      4  * Copyright (c) 1998 The NetBSD Foundation, Inc.
      5  * All rights reserved.
      6  *
      7  * This code is derived from software contributed to The NetBSD Foundation
      8  * by Jason R. Thorpe of the Numerical Aerospace Simulation Facility,
      9  * NASA Ames Research Center.
     10  *
     11  * Redistribution and use in source and binary forms, with or without
     12  * modification, are permitted provided that the following conditions
     13  * are met:
     14  * 1. Redistributions of source code must retain the above copyright
     15  *    notice, this list of conditions and the following disclaimer.
     16  * 2. Redistributions in binary form must reproduce the above copyright
     17  *    notice, this list of conditions and the following disclaimer in the
     18  *    documentation and/or other materials provided with the distribution.
     19  *
     20  * THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS
     21  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
     22  * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
     23  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS
     24  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
     25  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
     26  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
     27  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
     28  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
     29  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
     30  * POSSIBILITY OF SUCH DAMAGE.
     31  */
     32 
     33 /*
     34  * Test passing of file descriptors and credentials over Unix domain sockets.
     35  */
     36 
     37 #include <sys/param.h>
     38 #include <sys/socket.h>
     39 #include <sys/time.h>
     40 #include <sys/wait.h>
     41 #include <sys/un.h>
     42 #include <sys/uio.h>
     43 #include <sys/stat.h>
     44 
     45 #include <err.h>
     46 #include <errno.h>
     47 #include <fcntl.h>
     48 #include <signal.h>
     49 #include <stdio.h>
     50 #include <string.h>
     51 #include <stdlib.h>
     52 #include <unistd.h>
     53 
#define	SOCK_NAME	"test-sock"	/* rendezvous path, relative to cwd */

/*
 * Helpers used only within this program.  They get internal linkage here;
 * the later definitions, which omit "static", inherit it.
 */
static void	child(void);
static void	catch_sigchld(int);
static void	usage(char *progname);

#define	FILE_SIZE	128	/* buffer size for test-file contents */
#define	MSG_SIZE	(-1)	/* iovec payload size; < 0: no iovec attached */
#define	NFILES		24	/* descriptors sent in the SCM_RIGHTS message */

/* Payload sizes of the two control messages the receiver expects. */
#define	FDCM_DATASIZE	(sizeof(int) * NFILES)
#define	CRCM_DATASIZE	(SOCKCREDSIZE(NGROUPS))

/* Control buffer large enough for one message of each kind. */
#define	MESSAGE_SIZE	(CMSG_SPACE(FDCM_DATASIZE) +			\
			 CMSG_SPACE(CRCM_DATASIZE))

/* Command-line flags, set in main() before fork() and read by both sides. */
static int chroot_rcvr = 0;	/* -r: receiver chroot(".")s before accept() */
static int pass_dir = 0;	/* -d/-D: last descriptor sent is a directory */
static int pass_root_dir = 0;	/* -D: that directory is "/" rather than "." */
static int exit_early = 0;	/* -e: receiver exits right after fork() */
static int exit_later = 0;	/* -E: receiver exits after accept() + sleep */
static int pass_sock = 0;	/* -S: first descriptor sent is the socket itself */
static int make_pretzel = 0;	/* -p (repeatable): socketpair rounds in child() */
     78 
     79 /* ARGSUSED */
     80 int
     81 main(argc, argv)
     82 	int argc;
     83 	char *argv[];
     84 {
     85 #if MSG_SIZE >= 0
     86 	struct iovec iov;
     87 #endif
     88 	char *progname=argv[0];
     89 	struct msghdr msg;
     90 	int listensock, sock, fd, i;
     91 	char fname[16], buf[FILE_SIZE];
     92 	struct cmsghdr *cmp;
     93 	void *message;
     94 	int *files = NULL;
     95 	struct sockcred *sc = NULL;
     96 	struct sockaddr_un sun, csun;
     97 	socklen_t csunlen;
     98 	pid_t pid;
     99 	int ch;
    100 
    101 	message = malloc(CMSG_SPACE(MESSAGE_SIZE));
    102 	if (message == NULL)
    103 		err(1, "unable to malloc message buffer");
    104 	memset(message, 0, CMSG_SPACE(MESSAGE_SIZE));
    105 
    106 	while ((ch = getopt(argc, argv, "DESdepr")) != -1) {
    107 		switch(ch) {
    108 
    109 		case 'e':
    110 			exit_early++; /* test early GC */
    111 			break;
    112 
    113 		case 'E':
    114 			exit_later++; /* test later GC */
    115 			break;
    116 
    117 		case 'd':
    118 			pass_dir++;
    119 			break;
    120 
    121 		case 'D':
    122 			pass_dir++;
    123 			pass_root_dir++;
    124 			break;
    125 
    126 		case 'S':
    127 			pass_sock++;
    128 			break;
    129 
    130 		case 'r':
    131 			chroot_rcvr++;
    132 			break;
    133 
    134 		case 'p':
    135 			make_pretzel++;
    136 			break;
    137 
    138 		case '?':
    139 		default:
    140 			usage(progname);
    141 		}
    142 	}
    143 
    144 
    145 	/*
    146 	 * Create the test files.
    147 	 */
    148 	for (i = 0; i < NFILES; i++) {
    149 		(void) sprintf(fname, "file%d", i + 1);
    150 		if ((fd = open(fname, O_WRONLY|O_CREAT|O_TRUNC, 0666)) == -1)
    151 			err(1, "open %s", fname);
    152 		(void) sprintf(buf, "This is file %d.\n", i + 1);
    153 		if (write(fd, buf, strlen(buf)) != strlen(buf))
    154 			err(1, "write %s", fname);
    155 		(void) close(fd);
    156 	}
    157 
    158 	/*
    159 	 * Create the listen socket.
    160 	 */
    161 	if ((listensock = socket(PF_LOCAL, SOCK_STREAM, 0)) == -1)
    162 		err(1, "socket");
    163 
    164 	(void) unlink(SOCK_NAME);
    165 	(void) memset(&sun, 0, sizeof(sun));
    166 	sun.sun_family = AF_LOCAL;
    167 	(void) strcpy(sun.sun_path, SOCK_NAME);
    168 	sun.sun_len = SUN_LEN(&sun);
    169 
    170 	i = 1;
    171 	if (setsockopt(listensock, SOL_LOCAL, LOCAL_CREDS, &i, sizeof(i)) == -1)
    172 		err(1, "setsockopt");
    173 
    174 	if (bind(listensock, (struct sockaddr *)&sun, sizeof(sun)) == -1)
    175 		err(1, "bind");
    176 
    177 	if (listen(listensock, 1) == -1)
    178 		err(1, "listen");
    179 
    180 	/*
    181 	 * Create the sender.
    182 	 */
    183 	(void) signal(SIGCHLD, catch_sigchld);
    184 	pid = fork();
    185 	switch (pid) {
    186 	case -1:
    187 		err(1, "fork");
    188 		/* NOTREACHED */
    189 
    190 	case 0:
    191 		child();
    192 		/* NOTREACHED */
    193 	}
    194 
    195 	if (exit_early)
    196 		exit(0);
    197 
    198 	if (chroot_rcvr &&
    199 	    ((chroot(".") < 0)))
    200 		err(1, "chroot");
    201 
    202 	/*
    203 	 * Wait for the sender to connect.
    204 	 */
    205 	csunlen = sizeof(csun);
    206 	if ((sock = accept(listensock, (struct sockaddr *)&csun,
    207 	    &csunlen)) == -1)
    208 		err(1, "accept");
    209 
    210 	/*
    211 	 * Give sender a chance to run.  We will get going again
    212 	 * once the SIGCHLD arrives.
    213 	 */
    214 	(void) sleep(10);
    215 
    216 	if (exit_later)
    217 		exit(0);
    218 
    219 	/*
    220 	 * Grab the descriptors and credentials passed to us.
    221 	 */
    222 
    223 	/* Expect 2 messages; descriptors and creds. */
    224 	do {
    225 		(void) memset(&msg, 0, sizeof(msg));
    226 		msg.msg_control = message;
    227 		msg.msg_controllen = MESSAGE_SIZE;
    228 #if MSG_SIZE >= 0
    229 		iov.iov_base = buf;
    230 		iov.iov_len = MSG_SIZE;
    231 		msg.msg_iov = &iov;
    232 		msg.msg_iovlen = 1;
    233 #endif
    234 
    235 		if (recvmsg(sock, &msg, 0) == -1)
    236 			err(1, "recvmsg");
    237 
    238 		(void) close(sock);
    239 		sock = -1;
    240 
    241 		if (msg.msg_controllen == 0)
    242 			errx(1, "no control messages received");
    243 
    244 		if (msg.msg_flags & MSG_CTRUNC)
    245 			errx(1, "lost control message data");
    246 
    247 		for (cmp = CMSG_FIRSTHDR(&msg); cmp != NULL;
    248 		     cmp = CMSG_NXTHDR(&msg, cmp)) {
    249 			if (cmp->cmsg_level != SOL_SOCKET)
    250 				errx(1, "bad control message level %d",
    251 				    cmp->cmsg_level);
    252 
    253 			switch (cmp->cmsg_type) {
    254 			case SCM_RIGHTS:
    255 				if (cmp->cmsg_len != CMSG_LEN(FDCM_DATASIZE))
    256 					errx(1, "bad fd control message "
    257 					    "length %d", cmp->cmsg_len);
    258 
    259 				files = (int *)CMSG_DATA(cmp);
    260 				break;
    261 
    262 			case SCM_CREDS:
    263 				if (cmp->cmsg_len < CMSG_LEN(SOCKCREDSIZE(1)))
    264 					errx(1, "bad cred control message "
    265 					    "length %d", cmp->cmsg_len);
    266 
    267 				sc = (struct sockcred *)CMSG_DATA(cmp);
    268 				break;
    269 
    270 			default:
    271 				errx(1, "unexpected control message");
    272 				/* NOTREACHED */
    273 			}
    274 		}
    275 
    276 		/*
    277 		 * Read the files and print their contents.
    278 		 */
    279 		if (files == NULL)
    280 			warnx("didn't get fd control message");
    281 		else {
    282 			for (i = 0; i < NFILES; i++) {
    283 				struct stat st;
    284 				(void) memset(buf, 0, sizeof(buf));
    285 				fstat(files[i], &st);
    286 				if (S_ISDIR(st.st_mode)) {
    287 					printf("file %d is a directory\n", i+1);
    288 				} else if (S_ISSOCK(st.st_mode)) {
    289 					printf("file %d is a socket\n", i+1);
    290 					sock = files[i];
    291 				} else {
    292 					int c;
    293 					c = read (files[i], buf, sizeof(buf));
    294 					if (c < 0)
    295 						err(1, "read file %d", i + 1);
    296 					else if (c == 0)
    297 						printf("[eof on %d]\n", i + 1);
    298 					else
    299 						printf("%s", buf);
    300 				}
    301 			}
    302 		}
    303 		/*
    304 		 * Double-check credentials.
    305 		 */
    306 		if (sc == NULL)
    307 			warnx("didn't get cred control message");
    308 		else {
    309 			if (sc->sc_uid == getuid() &&
    310 			    sc->sc_euid == geteuid() &&
    311 			    sc->sc_gid == getgid() &&
    312 			    sc->sc_egid == getegid())
    313 				printf("Credentials match.\n");
    314 			else
    315 				printf("Credentials do NOT match.\n");
    316 		}
    317 	} while (sock != -1);
    318 
    319 	/*
    320 	 * All done!
    321 	 */
    322 	exit(0);
    323 }
    324 
/*
 * Print a usage message naming every option accepted by main()'s
 * getopt() string ("DESdepr") and exit with status 1.
 */
void
usage(char *progname)
{
	(void)fprintf(stderr, "usage: %s [-DESdepr]\n", progname);
	exit(1);
}
    332 
    333 void
    334 catch_sigchld(sig)
    335 	int sig;
    336 {
    337 	int status;
    338 
    339 	(void) wait(&status);
    340 }
    341 
    342 void
    343 child()
    344 {
    345 #if MSG_SIZE >= 0
    346 	struct iovec iov;
    347 #endif
    348 	struct msghdr msg;
    349 	char fname[16];
    350 	struct cmsghdr *cmp;
    351 	void *fdcm;
    352 	int i, fd, sock, nfd, *files;
    353 	struct sockaddr_un sun;
    354 	int spair[2];
    355 
    356 	fdcm = malloc(CMSG_SPACE(FDCM_DATASIZE));
    357 	if (fdcm == NULL)
    358 		err(1, "unable to malloc fd control message");
    359 	memset(fdcm, 0, CMSG_SPACE(FDCM_DATASIZE));
    360 
    361 	cmp = fdcm;
    362 	files = (int *)CMSG_DATA(fdcm);
    363 
    364 	/*
    365 	 * Create socket and connect to the receiver.
    366 	 */
    367 	if ((sock = socket(PF_LOCAL, SOCK_STREAM, 0)) == -1)
    368 		errx(1, "child socket");
    369 
    370 	(void) memset(&sun, 0, sizeof(sun));
    371 	sun.sun_family = AF_LOCAL;
    372 	(void) strcpy(sun.sun_path, SOCK_NAME);
    373 	sun.sun_len = SUN_LEN(&sun);
    374 
    375 	if (connect(sock, (struct sockaddr *)&sun, sizeof(sun)) == -1)
    376 		err(1, "child connect");
    377 
    378 	nfd = NFILES;
    379 	i = 0;
    380 
    381 	if (pass_sock) {
    382 		files[i++] = sock;
    383 	}
    384 
    385 	if (pass_dir)
    386 		nfd--;
    387 
    388 	/*
    389 	 * Open the files again, and pass them to the child
    390 	 * over the socket.
    391 	 */
    392 
    393 	for (; i < nfd; i++) {
    394 		(void) sprintf(fname, "file%d", i + 1);
    395 		if ((fd = open(fname, O_RDONLY, 0666)) == -1)
    396 			err(1, "child open %s", fname);
    397 		files[i] = fd;
    398 	}
    399 
    400 	if (pass_dir) {
    401 		char *dirname = pass_root_dir ? "/" : ".";
    402 
    403 
    404 		if ((fd = open(dirname, O_RDONLY, 0)) == -1) {
    405 			err(1, "child open directory %s", dirname);
    406 		}
    407 		files[i] = fd;
    408 	}
    409 
    410 	(void) memset(&msg, 0, sizeof(msg));
    411 	msg.msg_control = fdcm;
    412 	msg.msg_controllen = CMSG_LEN(FDCM_DATASIZE);
    413 #if MSG_SIZE >= 0
    414 	iov.iov_base = buf;
    415 	iov.iov_len = MSG_SIZE;
    416 	msg.msg_iov = &iov;
    417 	msg.msg_iovlen = 1;
    418 #endif
    419 
    420 	cmp = CMSG_FIRSTHDR(&msg);
    421 	cmp->cmsg_len = CMSG_LEN(FDCM_DATASIZE);
    422 	cmp->cmsg_level = SOL_SOCKET;
    423 	cmp->cmsg_type = SCM_RIGHTS;
    424 
    425 	while (make_pretzel > 0) {
    426 		if (socketpair(PF_LOCAL, SOCK_STREAM, 0, spair) < 0)
    427 			err(1, "socketpair");
    428 
    429 		printf("send pretzel\n");
    430 		if (sendmsg(spair[0], &msg, 0) < 0)
    431 			err(1, "child prezel sendmsg");
    432 
    433 		close(files[0]);
    434 		close(files[1]);
    435 		files[0] = spair[0];
    436 		files[1] = spair[1];
    437 		make_pretzel--;
    438 	}
    439 
    440 	if (sendmsg(sock, &msg, 0) == -1)
    441 		err(1, "child sendmsg");
    442 
    443 	/*
    444 	 * All done!
    445 	 */
    446 	exit(0);
    447 }
    448