`/* Implementation of the MATMUL intrinsic
   Copyright (C) 2002-2024 Free Software Foundation, Inc.
   Contributed by Paul Brook <paul@nowt.org>

This file is part of the GNU Fortran runtime library (libgfortran).

Libgfortran is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.

Libgfortran is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

Under Section 7 of GPL version 3, you are granted additional
permissions described in the GCC Runtime Library Exception, version
3.1, as published by the Free Software Foundation.

You should have received a copy of the GNU General Public License and
a copy of the GCC Runtime Library Exception along with this program;
see the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
<http://www.gnu.org/licenses/>.  */

#include "libgfortran.h"
#include <assert.h>'

include(iparm.m4)dnl

`#if defined (HAVE_'rtype_name`)

/* Dimensions: retarray(x,y) a(x, count) b(count,y).
   Either a or b can be rank 1.  In this case x or y is 1.  */

extern void matmul_'rtype_code` ('rtype` * const restrict,
	gfc_array_l1 * const restrict, gfc_array_l1 * const restrict);
export_proto(matmul_'rtype_code`);

/* MATMUL for LOGICAL arguments.

   Each result element (x,y) is set to 1 if there is some n for which
   both a(x,n) and b(n,y) are true, and to 0 otherwise.

   RETARRAY: result descriptor.  If it has no storage yet, its bounds
     are derived from A and B and storage is allocated here; otherwise,
     when bounds checking is enabled, its extents are checked against
     A and B.
   A, B: logical arrays of any element size accepted below; at least
     one of them must be rank 2.  Their inner extents must agree, or a
     runtime error is raised.  */

void
matmul_'rtype_code` ('rtype` * const restrict retarray,
	gfc_array_l1 * const restrict a, gfc_array_l1 * const restrict b)
{
  const GFC_LOGICAL_1 * restrict abase;
  const GFC_LOGICAL_1 * restrict bbase;
  'rtype_name` * restrict dest;
  index_type rxstride;
  index_type rystride;
  index_type xcount;
  index_type ycount;
  index_type xstride;
  index_type ystride;
  index_type x;
  index_type y;
  int a_kind;
  int b_kind;

  const GFC_LOGICAL_1 * restrict pa;
  const GFC_LOGICAL_1 * restrict pb;
  index_type astride;
  index_type bstride;
  index_type count;
  index_type n;

  assert (GFC_DESCRIPTOR_RANK (a) == 2
          || GFC_DESCRIPTOR_RANK (b) == 2);

  if (retarray->base_addr == NULL)
    {
      /* Result not allocated yet: derive its shape from A and B.  */
      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
	  GFC_DIMENSION_SET(retarray->dim[0], 0,
	                    GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
	  GFC_DIMENSION_SET(retarray->dim[0], 0,
	                    GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
        }
      else
        {
	  GFC_DIMENSION_SET(retarray->dim[0], 0,
	                    GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);

          GFC_DIMENSION_SET(retarray->dim[1], 0,
	                    GFC_DESCRIPTOR_EXTENT(b,1) - 1,
			    GFC_DESCRIPTOR_EXTENT(retarray,0));
        }

      retarray->base_addr
	= xmallocarray (size0 ((array_t *) retarray), sizeof ('rtype_name`));
      retarray->offset = 0;
    }
    else if (unlikely (compile_options.bounds_check))
      {
	/* Caller-supplied result: check its extents against A and B.  */
	index_type ret_extent, arg_extent;

	if (GFC_DESCRIPTOR_RANK (a) == 1)
	  {
	    arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
	    ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
	    if (arg_extent != ret_extent)
	      runtime_error ("Incorrect extent in return array in"
			     " MATMUL intrinsic: is %ld, should be %ld",
			     (long int) ret_extent, (long int) arg_extent);
	  }
	else if (GFC_DESCRIPTOR_RANK (b) == 1)
	  {
	    arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
	    ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
	    if (arg_extent != ret_extent)
	      runtime_error ("Incorrect extent in return array in"
			     " MATMUL intrinsic: is %ld, should be %ld",
			     (long int) ret_extent, (long int) arg_extent);
	  }
	else
	  {
	    arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
	    ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
	    if (arg_extent != ret_extent)
	      runtime_error ("Incorrect extent in return array in"
			     " MATMUL intrinsic for dimension 1:"
			     " is %ld, should be %ld",
			     (long int) ret_extent, (long int) arg_extent);

	    arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
	    ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
	    if (arg_extent != ret_extent)
	      runtime_error ("Incorrect extent in return array in"
			     " MATMUL intrinsic for dimension 2:"
			     " is %ld, should be %ld",
			     (long int) ret_extent, (long int) arg_extent);
	  }
      }

  /* Only element sizes 1, 2, 4, 8 (and 16 when HAVE_GFC_LOGICAL_16 is
     defined) are accepted.  NOTE(review): GFOR_POINTER_TO_L1 is defined
     outside this file; it appears to turn the base address into a
     GFC_LOGICAL_1 pointer from which one byte per element is tested
     below, whatever the element size -- confirm in libgfortran.h.  */
  abase = a->base_addr;
  a_kind = GFC_DESCRIPTOR_SIZE (a);

  if (a_kind == 1 || a_kind == 2 || a_kind == 4 || a_kind == 8
#ifdef HAVE_GFC_LOGICAL_16
     || a_kind == 16
#endif
     )
    abase = GFOR_POINTER_TO_L1 (abase, a_kind);
  else
    internal_error (NULL, "Funny sized logical array");

  bbase = b->base_addr;
  b_kind = GFC_DESCRIPTOR_SIZE (b);

  if (b_kind == 1 || b_kind == 2 || b_kind == 4 || b_kind == 8
#ifdef HAVE_GFC_LOGICAL_16
     || b_kind == 16
#endif
     )
    bbase = GFOR_POINTER_TO_L1 (bbase, b_kind);
  else
    internal_error (NULL, "Funny sized logical array");

  dest = retarray->base_addr;
'
sinclude(`matmul_asm_'rtype_code`.m4')dnl
`
  /* Result strides are in elements, because dest points at the result
     type.  The A and B strides below are in bytes, because pa and pb
     are GFC_LOGICAL_1 pointers.  */
  if (GFC_DESCRIPTOR_RANK (retarray) == 1)
    {
      rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
      rystride = rxstride;
    }
  else
    {
      rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
      rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
    }

  /* If we have rank 1 parameters, zero the absent stride, and set the size to
     one.  */
  if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      astride = GFC_DESCRIPTOR_STRIDE_BYTES(a,0);
      count = GFC_DESCRIPTOR_EXTENT(a,0);
      xstride = 0;
      rxstride = 0;
      xcount = 1;
    }
  else
    {
      astride = GFC_DESCRIPTOR_STRIDE_BYTES(a,1);
      count = GFC_DESCRIPTOR_EXTENT(a,1);
      xstride = GFC_DESCRIPTOR_STRIDE_BYTES(a,0);
      xcount = GFC_DESCRIPTOR_EXTENT(a,0);
    }

  /* The inner dimensions must conform.  This used to be an assert,
     which vanishes under NDEBUG and would then let the inner loop read
     past the end of B.  Two empty extents are treated as conforming.  */
  if (count != GFC_DESCRIPTOR_EXTENT(b,0)
      && (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0))
    runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
		   "in dimension 1: is %ld, should be %ld",
		   (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);

  /* B is always walked along its first dimension in the inner loop.  */
  bstride = GFC_DESCRIPTOR_STRIDE_BYTES(b,0);
  if (GFC_DESCRIPTOR_RANK (b) == 1)
    {
      ystride = 0;
      rystride = 0;
      ycount = 1;
    }
  else
    {
      ystride = GFC_DESCRIPTOR_STRIDE_BYTES(b,1);
      ycount = GFC_DESCRIPTOR_EXTENT(b,1);
    }

  for (y = 0; y < ycount; y++)
    {
      for (x = 0; x < xcount; x++)
        {
          /* Compute this element as ANY (a(x,:) .AND. b(:,y)): start
	     from false and stop at the first n for which both a(x,n)
	     and b(n,y) are true.  */
          pa = abase;
          pb = bbase;
          *dest = 0;

          for (n = 0; n < count; n++)
            {
              if (*pa && *pb)
                {
                  *dest = 1;
                  break;
                }
              pa += astride;
              pb += bstride;
            }

          dest += rxstride;
          abase += xstride;
        }
      /* Rewind A to its first row, advance B to the next column, and
	 move dest to the start of the next result column.  */
      abase -= xstride * xcount;
      bbase += ystride;
      dest += rystride - (rxstride * xcount);
    }
}

#endif
'