Home | History | Annotate | Line # | Download | only in unfdpass
unfdpass.c revision 1.9
      1 /*	$NetBSD: unfdpass.c,v 1.9 2008/02/29 16:28:12 ad Exp $	*/
      2 
      3 /*-
      4  * Copyright (c) 1998 The NetBSD Foundation, Inc.
      5  * All rights reserved.
      6  *
      7  * This code is derived from software contributed to The NetBSD Foundation
      8  * by Jason R. Thorpe of the Numerical Aerospace Simulation Facility,
      9  * NASA Ames Research Center.
     10  *
     11  * Redistribution and use in source and binary forms, with or without
     12  * modification, are permitted provided that the following conditions
     13  * are met:
     14  * 1. Redistributions of source code must retain the above copyright
     15  *    notice, this list of conditions and the following disclaimer.
     16  * 2. Redistributions in binary form must reproduce the above copyright
     17  *    notice, this list of conditions and the following disclaimer in the
     18  *    documentation and/or other materials provided with the distribution.
     19  * 3. All advertising materials mentioning features or use of this software
     20  *    must display the following acknowledgement:
     21  *	This product includes software developed by the NetBSD
     22  *	Foundation, Inc. and its contributors.
     23  * 4. Neither the name of The NetBSD Foundation nor the names of its
     24  *    contributors may be used to endorse or promote products derived
     25  *    from this software without specific prior written permission.
     26  *
     27  * THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS
     28  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
     29  * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
     30  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS
     31  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
     32  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
     33  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
     34  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
     35  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
     36  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
     37  * POSSIBILITY OF SUCH DAMAGE.
     38  */
     39 
     40 /*
     41  * Test passing of file descriptors and credentials over Unix domain sockets.
     42  */
     43 
#include <sys/param.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <sys/uio.h>
#include <sys/un.h>
#include <sys/wait.h>

#include <err.h>
#include <errno.h>
#include <fcntl.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
     59 
/*
 * Rendezvous path for the sender/receiver connection.  It is relative,
 * so it is created in (and unlinked from) the current directory.
 */
#define SOCK_NAME "test-sock"

/* Receiver (main) and sender (child) halves of the test, plus helpers. */
void usage(char *progname);
void catch_sigchld(int sig);
void child(void);
int main(int argc, char *argv[]);
     66 
/* Size of the buffer used to write and read back each test file. */
#define	FILE_SIZE	128

/*
 * Length of the optional in-band data segment sent with the control
 * message.  A negative value (-1 here) disables the iovec entirely; see
 * the "#if MSG_SIZE >= 0" sections.  Parenthesized so it also expands
 * safely inside ordinary expressions, not just in #if.
 */
#define	MSG_SIZE	(-1)

/* Number of descriptors carried by the single SCM_RIGHTS message. */
#define	NFILES		24

/* Payload sizes of the descriptor and credential control messages. */
#define	FDCM_DATASIZE	(sizeof(int) * NFILES)
/* presumably room for a struct sockcred with NGROUPS groups -- confirm */
#define	CRCM_DATASIZE	(SOCKCREDSIZE(NGROUPS))

/* Receive-side control buffer: room for both cmsgs, with padding. */
#define	MESSAGE_SIZE	(CMSG_SPACE(FDCM_DATASIZE) +			\
			 CMSG_SPACE(CRCM_DATASIZE))

/* main() points its iovec at a FILE_SIZE buffer with length MSG_SIZE. */
#if MSG_SIZE > FILE_SIZE
#error "MSG_SIZE must not exceed FILE_SIZE"
#endif
     76 
/*
 * Option flags, bumped by getopt() in main().  All but make_pretzel are
 * only ever tested for non-zero; make_pretzel is the number of
 * socketpair rounds child() performs.
 */
int make_pretzel = 0;	/* -p: wrap descriptors in socketpairs first */
int pass_sock = 0;	/* -S: also pass the sender's connected socket */
int pass_root_dir = 0;	/* -D: the passed directory is "/" not "." */
int pass_dir = 0;	/* -d, -D: pass a directory descriptor */
int exit_later = 0;	/* -E: receiver exits after accept + sleep */
int exit_early = 0;	/* -e: receiver exits right after fork */
int chroot_rcvr = 0;	/* -r: receiver does chroot(".") before accept */
     84 
     85 /* ARGSUSED */
     86 int
     87 main(argc, argv)
     88 	int argc;
     89 	char *argv[];
     90 {
     91 #if MSG_SIZE >= 0
     92 	struct iovec iov;
     93 #endif
     94 	char *progname=argv[0];
     95 	struct msghdr msg;
     96 	int listensock, sock, fd, i;
     97 	char fname[16], buf[FILE_SIZE];
     98 	struct cmsghdr *cmp;
     99 	void *message;
    100 	int *files = NULL;
    101 	struct sockcred *sc = NULL;
    102 	struct sockaddr_un sun, csun;
    103 	socklen_t csunlen;
    104 	pid_t pid;
    105 	int ch;
    106 
    107 	message = malloc(CMSG_SPACE(MESSAGE_SIZE));
    108 	if (message == NULL)
    109 		err(1, "unable to malloc message buffer");
    110 	memset(message, 0, CMSG_SPACE(MESSAGE_SIZE));
    111 
    112 	while ((ch = getopt(argc, argv, "DESdepr")) != -1) {
    113 		switch(ch) {
    114 
    115 		case 'e':
    116 			exit_early++; /* test early GC */
    117 			break;
    118 
    119 		case 'E':
    120 			exit_later++; /* test later GC */
    121 			break;
    122 
    123 		case 'd':
    124 			pass_dir++;
    125 			break;
    126 
    127 		case 'D':
    128 			pass_dir++;
    129 			pass_root_dir++;
    130 			break;
    131 
    132 		case 'S':
    133 			pass_sock++;
    134 			break;
    135 
    136 		case 'r':
    137 			chroot_rcvr++;
    138 			break;
    139 
    140 		case 'p':
    141 			make_pretzel++;
    142 			break;
    143 
    144 		case '?':
    145 		default:
    146 			usage(progname);
    147 		}
    148 	}
    149 
    150 
    151 	/*
    152 	 * Create the test files.
    153 	 */
    154 	for (i = 0; i < NFILES; i++) {
    155 		(void) sprintf(fname, "file%d", i + 1);
    156 		if ((fd = open(fname, O_WRONLY|O_CREAT|O_TRUNC, 0666)) == -1)
    157 			err(1, "open %s", fname);
    158 		(void) sprintf(buf, "This is file %d.\n", i + 1);
    159 		if (write(fd, buf, strlen(buf)) != strlen(buf))
    160 			err(1, "write %s", fname);
    161 		(void) close(fd);
    162 	}
    163 
    164 	/*
    165 	 * Create the listen socket.
    166 	 */
    167 	if ((listensock = socket(PF_LOCAL, SOCK_STREAM, 0)) == -1)
    168 		err(1, "socket");
    169 
    170 	(void) unlink(SOCK_NAME);
    171 	(void) memset(&sun, 0, sizeof(sun));
    172 	sun.sun_family = AF_LOCAL;
    173 	(void) strcpy(sun.sun_path, SOCK_NAME);
    174 	sun.sun_len = SUN_LEN(&sun);
    175 
    176 	i = 1;
    177 	if (setsockopt(listensock, 0, LOCAL_CREDS, &i, sizeof(i)) == -1)
    178 		err(1, "setsockopt");
    179 
    180 	if (bind(listensock, (struct sockaddr *)&sun, sizeof(sun)) == -1)
    181 		err(1, "bind");
    182 
    183 	if (listen(listensock, 1) == -1)
    184 		err(1, "listen");
    185 
    186 	/*
    187 	 * Create the sender.
    188 	 */
    189 	(void) signal(SIGCHLD, catch_sigchld);
    190 	pid = fork();
    191 	switch (pid) {
    192 	case -1:
    193 		err(1, "fork");
    194 		/* NOTREACHED */
    195 
    196 	case 0:
    197 		child();
    198 		/* NOTREACHED */
    199 	}
    200 
    201 	if (exit_early)
    202 		exit(0);
    203 
    204 	if (chroot_rcvr &&
    205 	    ((chroot(".") < 0)))
    206 		err(1, "chroot");
    207 
    208 	/*
    209 	 * Wait for the sender to connect.
    210 	 */
    211 	csunlen = sizeof(csun);
    212 	if ((sock = accept(listensock, (struct sockaddr *)&csun,
    213 	    &csunlen)) == -1)
    214 		err(1, "accept");
    215 
    216 	/*
    217 	 * Give sender a chance to run.  We will get going again
    218 	 * once the SIGCHLD arrives.
    219 	 */
    220 	(void) sleep(10);
    221 
    222 	if (exit_later)
    223 		exit(0);
    224 
    225 	/*
    226 	 * Grab the descriptors and credentials passed to us.
    227 	 */
    228 
    229 	/* Expect 2 messages; descriptors and creds. */
    230 	do {
    231 		(void) memset(&msg, 0, sizeof(msg));
    232 		msg.msg_control = message;
    233 		msg.msg_controllen = MESSAGE_SIZE;
    234 #if MSG_SIZE >= 0
    235 		iov.iov_base = buf;
    236 		iov.iov_len = MSG_SIZE;
    237 		msg.msg_iov = &iov;
    238 		msg.msg_iovlen = 1;
    239 #endif
    240 
    241 		if (recvmsg(sock, &msg, 0) == -1)
    242 			err(1, "recvmsg");
    243 
    244 		(void) close(sock);
    245 		sock = -1;
    246 
    247 		if (msg.msg_controllen == 0)
    248 			errx(1, "no control messages received");
    249 
    250 		if (msg.msg_flags & MSG_CTRUNC)
    251 			errx(1, "lost control message data");
    252 
    253 		for (cmp = CMSG_FIRSTHDR(&msg); cmp != NULL;
    254 		     cmp = CMSG_NXTHDR(&msg, cmp)) {
    255 			if (cmp->cmsg_level != SOL_SOCKET)
    256 				errx(1, "bad control message level %d",
    257 				    cmp->cmsg_level);
    258 
    259 			switch (cmp->cmsg_type) {
    260 			case SCM_RIGHTS:
    261 				if (cmp->cmsg_len != CMSG_LEN(FDCM_DATASIZE))
    262 					errx(1, "bad fd control message "
    263 					    "length %d", cmp->cmsg_len);
    264 
    265 				files = (int *)CMSG_DATA(cmp);
    266 				break;
    267 
    268 			case SCM_CREDS:
    269 				if (cmp->cmsg_len < CMSG_LEN(SOCKCREDSIZE(1)))
    270 					errx(1, "bad cred control message "
    271 					    "length %d", cmp->cmsg_len);
    272 
    273 				sc = (struct sockcred *)CMSG_DATA(cmp);
    274 				break;
    275 
    276 			default:
    277 				errx(1, "unexpected control message");
    278 				/* NOTREACHED */
    279 			}
    280 		}
    281 
    282 		/*
    283 		 * Read the files and print their contents.
    284 		 */
    285 		if (files == NULL)
    286 			warnx("didn't get fd control message");
    287 		else {
    288 			for (i = 0; i < NFILES; i++) {
    289 				struct stat st;
    290 				(void) memset(buf, 0, sizeof(buf));
    291 				fstat(files[i], &st);
    292 				if (S_ISDIR(st.st_mode)) {
    293 					printf("file %d is a directory\n", i+1);
    294 				} else if (S_ISSOCK(st.st_mode)) {
    295 					printf("file %d is a socket\n", i+1);
    296 					sock = files[i];
    297 				} else {
    298 					int c;
    299 					c = read (files[i], buf, sizeof(buf));
    300 					if (c < 0)
    301 						err(1, "read file %d", i + 1);
    302 					else if (c == 0)
    303 						printf("[eof on %d]\n", i + 1);
    304 					else
    305 						printf("%s", buf);
    306 				}
    307 			}
    308 		}
    309 		/*
    310 		 * Double-check credentials.
    311 		 */
    312 		if (sc == NULL)
    313 			warnx("didn't get cred control message");
    314 		else {
    315 			if (sc->sc_uid == getuid() &&
    316 			    sc->sc_euid == geteuid() &&
    317 			    sc->sc_gid == getgid() &&
    318 			    sc->sc_egid == getegid())
    319 				printf("Credentials match.\n");
    320 			else
    321 				printf("Credentials do NOT match.\n");
    322 		}
    323 	} while (sock != -1);
    324 
    325 	/*
    326 	 * All done!
    327 	 */
    328 	exit(0);
    329 }
    330 
/*
 * Print a usage line listing every option main()'s getopt() accepts
 * ("DESdepr"; the old text omitted -p) and exit with status 1.
 */
void
usage(char *progname)
{
	fprintf(stderr, "usage: %s [-DESdepr]\n", progname);
	exit(1);
}
    338 
/*
 * SIGCHLD handler: reap the exited child with wait().
 *
 * errno is saved and restored because the handler runs asynchronously,
 * e.g. between a failing call in main() and the err() that reports it.
 * wait() may also set errno, for instance to ECHILD.
 */
void
catch_sigchld(int sig)
{
	int saved_errno = errno;
	int status;

	(void) sig;
	(void) wait(&status);
	errno = saved_errno;
}
    347 
    348 void
    349 child()
    350 {
    351 #if MSG_SIZE >= 0
    352 	struct iovec iov;
    353 #endif
    354 	struct msghdr msg;
    355 	char fname[16];
    356 	struct cmsghdr *cmp;
    357 	void *fdcm;
    358 	int i, fd, sock, nfd, *files;
    359 	struct sockaddr_un sun;
    360 	int spair[2];
    361 
    362 	fdcm = malloc(CMSG_SPACE(FDCM_DATASIZE));
    363 	if (fdcm == NULL)
    364 		err(1, "unable to malloc fd control message");
    365 	memset(fdcm, 0, CMSG_SPACE(FDCM_DATASIZE));
    366 
    367 	cmp = fdcm;
    368 	files = (int *)CMSG_DATA(fdcm);
    369 
    370 	/*
    371 	 * Create socket and connect to the receiver.
    372 	 */
    373 	if ((sock = socket(PF_LOCAL, SOCK_STREAM, 0)) == -1)
    374 		errx(1, "child socket");
    375 
    376 	(void) memset(&sun, 0, sizeof(sun));
    377 	sun.sun_family = AF_LOCAL;
    378 	(void) strcpy(sun.sun_path, SOCK_NAME);
    379 	sun.sun_len = SUN_LEN(&sun);
    380 
    381 	if (connect(sock, (struct sockaddr *)&sun, sizeof(sun)) == -1)
    382 		err(1, "child connect");
    383 
    384 	nfd = NFILES;
    385 	i = 0;
    386 
    387 	if (pass_sock) {
    388 		files[i++] = sock;
    389 	}
    390 
    391 	if (pass_dir)
    392 		nfd--;
    393 
    394 	/*
    395 	 * Open the files again, and pass them to the child
    396 	 * over the socket.
    397 	 */
    398 
    399 	for (; i < nfd; i++) {
    400 		(void) sprintf(fname, "file%d", i + 1);
    401 		if ((fd = open(fname, O_RDONLY, 0666)) == -1)
    402 			err(1, "child open %s", fname);
    403 		files[i] = fd;
    404 	}
    405 
    406 	if (pass_dir) {
    407 		char *dirname = pass_root_dir ? "/" : ".";
    408 
    409 
    410 		if ((fd = open(dirname, O_RDONLY, 0)) == -1) {
    411 			err(1, "child open directory %s", dirname);
    412 		}
    413 		files[i] = fd;
    414 	}
    415 
    416 	(void) memset(&msg, 0, sizeof(msg));
    417 	msg.msg_control = fdcm;
    418 	msg.msg_controllen = CMSG_LEN(FDCM_DATASIZE);
    419 #if MSG_SIZE >= 0
    420 	iov.iov_base = buf;
    421 	iov.iov_len = MSG_SIZE;
    422 	msg.msg_iov = &iov;
    423 	msg.msg_iovlen = 1;
    424 #endif
    425 
    426 	cmp = CMSG_FIRSTHDR(&msg);
    427 	cmp->cmsg_len = CMSG_LEN(FDCM_DATASIZE);
    428 	cmp->cmsg_level = SOL_SOCKET;
    429 	cmp->cmsg_type = SCM_RIGHTS;
    430 
    431 	while (make_pretzel > 0) {
    432 		if (socketpair(PF_LOCAL, SOCK_STREAM, 0, spair) < 0)
    433 			err(1, "socketpair");
    434 
    435 		printf("send pretzel\n");
    436 		if (sendmsg(spair[0], &msg, 0) < 0)
    437 			err(1, "child prezel sendmsg");
    438 
    439 		close(files[0]);
    440 		close(files[1]);
    441 		files[0] = spair[0];
    442 		files[1] = spair[1];
    443 		make_pretzel--;
    444 	}
    445 
    446 	if (sendmsg(sock, &msg, 0) == -1)
    447 		err(1, "child sendmsg");
    448 
    449 	/*
    450 	 * All done!
    451 	 */
    452 	exit(0);
    453 }
    454