matmull.m4 revision 1.1.1.4 1 1.1 mrg `/* Implementation of the MATMUL intrinsic
2 1.1.1.4 mrg Copyright (C) 2002-2024 Free Software Foundation, Inc.
3 1.1 mrg Contributed by Paul Brook <paul (a] nowt.org>
4 1.1 mrg
5 1.1 mrg This file is part of the GNU Fortran runtime library (libgfortran).
6 1.1 mrg
7 1.1 mrg Libgfortran is free software; you can redistribute it and/or
8 1.1 mrg modify it under the terms of the GNU General Public
9 1.1 mrg License as published by the Free Software Foundation; either
10 1.1 mrg version 3 of the License, or (at your option) any later version.
11 1.1 mrg
12 1.1 mrg Libgfortran is distributed in the hope that it will be useful,
13 1.1 mrg but WITHOUT ANY WARRANTY; without even the implied warranty of
14 1.1 mrg MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 1.1 mrg GNU General Public License for more details.
16 1.1 mrg
17 1.1 mrg Under Section 7 of GPL version 3, you are granted additional
18 1.1 mrg permissions described in the GCC Runtime Library Exception, version
19 1.1 mrg 3.1, as published by the Free Software Foundation.
20 1.1 mrg
21 1.1 mrg You should have received a copy of the GNU General Public License and
22 1.1 mrg a copy of the GCC Runtime Library Exception along with this program;
23 1.1 mrg see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24 1.1 mrg <http://www.gnu.org/licenses/>. */
25 1.1 mrg
26 1.1 mrg #include "libgfortran.h"
27 1.1 mrg #include <assert.h>'
28 1.1 mrg
29 1.1 mrg include(iparm.m4)dnl
30 1.1 mrg
31 1.1 mrg `#if defined (HAVE_'rtype_name`)
32 1.1 mrg
33 1.1 mrg /* Dimensions: retarray(x,y) a(x, count) b(count,y).
34 1.1 mrg Either a or b (but not both) can be rank 1; x is 1 when a is rank 1, and y is 1 when b is rank 1. */
35 1.1 mrg
36 1.1 mrg extern void matmul_'rtype_code` ('rtype` * const restrict,
37 1.1 mrg gfc_array_l1 * const restrict, gfc_array_l1 * const restrict);
38 1.1 mrg export_proto(matmul_'rtype_code`);
39 1.1 mrg
/* Logical MATMUL: matrix product of two LOGICAL arrays A and B.
   Each result element (x,y) is set to 1 if there is some n for which both
   a(x,n) and b(n,y) are nonzero, and to 0 otherwise -- in particular it is
   0 when the inner extent is 0.  For a rank-1 argument, x or y is fixed
   at 1.  This matches ANY (a(x,:) .AND. b(:,y)).

   retarray  Result descriptor.  If its base_addr is NULL, its bounds are
             set here from A and B and storage is obtained with
             xmallocarray (offset 0).  Otherwise the existing storage is
             written in place.  Its extents are checked against A and B
             only when compile_options.bounds_check is set.
   a, b      Argument descriptors.  At least one must be rank 2 (asserted).
             The element size (logical kind) must be 1, 2, 4, 8, or 16
             where HAVE_GFC_LOGICAL_16 is defined.  Any other size raises
             internal_error.

   NOTE(review): the inner extents of A and B are compared only with
   assert(), so a mismatch is not diagnosed when NDEBUG is defined, even
   when bounds checking is enabled.  */
40 1.1 mrg void
41 1.1 mrg matmul_'rtype_code` ('rtype` * const restrict retarray,
42 1.1 mrg gfc_array_l1 * const restrict a, gfc_array_l1 * const restrict b)
43 1.1 mrg {
44 1.1 mrg const GFC_LOGICAL_1 * restrict abase;
45 1.1 mrg const GFC_LOGICAL_1 * restrict bbase;
46 1.1 mrg 'rtype_name` * restrict dest;
47 1.1 mrg index_type rxstride;
48 1.1 mrg index_type rystride;
49 1.1 mrg index_type xcount;
50 1.1 mrg index_type ycount;
51 1.1 mrg index_type xstride;
52 1.1 mrg index_type ystride;
53 1.1 mrg index_type x;
54 1.1 mrg index_type y;
55 1.1 mrg int a_kind;
56 1.1 mrg int b_kind;
57 1.1 mrg
58 1.1 mrg const GFC_LOGICAL_1 * restrict pa;
59 1.1 mrg const GFC_LOGICAL_1 * restrict pb;
60 1.1 mrg index_type astride;
61 1.1 mrg index_type bstride;
62 1.1 mrg index_type count;
63 1.1 mrg index_type n;
64 1.1 mrg
/* The case where both arguments are rank 1 is not handled here.  */
65 1.1 mrg assert (GFC_DESCRIPTOR_RANK (a) == 2
66 1.1 mrg || GFC_DESCRIPTOR_RANK (b) == 2);
67 1.1 mrg
/* Result not yet allocated: derive its bounds from A and B.  For a rank-2
   result the dimension-2 stride equals the dimension-1 extent, so the
   result is laid out contiguously, column after column.  */
68 1.1 mrg if (retarray->base_addr == NULL)
69 1.1 mrg {
70 1.1 mrg if (GFC_DESCRIPTOR_RANK (a) == 1)
71 1.1 mrg {
72 1.1 mrg GFC_DIMENSION_SET(retarray->dim[0], 0,
73 1.1 mrg GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
74 1.1 mrg }
75 1.1 mrg else if (GFC_DESCRIPTOR_RANK (b) == 1)
76 1.1 mrg {
77 1.1 mrg GFC_DIMENSION_SET(retarray->dim[0], 0,
78 1.1 mrg GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
79 1.1 mrg }
80 1.1 mrg else
81 1.1 mrg {
82 1.1 mrg GFC_DIMENSION_SET(retarray->dim[0], 0,
83 1.1 mrg GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
84 1.1 mrg
85 1.1 mrg GFC_DIMENSION_SET(retarray->dim[1], 0,
86 1.1 mrg GFC_DESCRIPTOR_EXTENT(b,1) - 1,
87 1.1 mrg GFC_DESCRIPTOR_EXTENT(retarray,0));
88 1.1 mrg }
89 1.1 mrg
90 1.1 mrg retarray->base_addr
91 1.1 mrg = xmallocarray (size0 ((array_t *) retarray), sizeof ('rtype_name`));
92 1.1 mrg retarray->offset = 0;
93 1.1 mrg }
/* Result supplied by the caller: when bounds checking is enabled, verify
   that its extents match the outer extents of A and B.  */
94 1.1 mrg else if (unlikely (compile_options.bounds_check))
95 1.1 mrg {
96 1.1 mrg index_type ret_extent, arg_extent;
97 1.1 mrg
98 1.1 mrg if (GFC_DESCRIPTOR_RANK (a) == 1)
99 1.1 mrg {
100 1.1 mrg arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
101 1.1 mrg ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
102 1.1 mrg if (arg_extent != ret_extent)
103 1.1 mrg runtime_error ("Incorrect extent in return array in"
104 1.1 mrg " MATMUL intrinsic: is %ld, should be %ld",
105 1.1 mrg (long int) ret_extent, (long int) arg_extent);
106 1.1 mrg }
107 1.1 mrg else if (GFC_DESCRIPTOR_RANK (b) == 1)
108 1.1 mrg {
109 1.1 mrg arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
110 1.1 mrg ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
111 1.1 mrg if (arg_extent != ret_extent)
112 1.1 mrg runtime_error ("Incorrect extent in return array in"
113 1.1 mrg " MATMUL intrinsic: is %ld, should be %ld",
114 1.1 mrg (long int) ret_extent, (long int) arg_extent);
115 1.1 mrg }
116 1.1 mrg else
117 1.1 mrg {
118 1.1 mrg arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
119 1.1 mrg ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
120 1.1 mrg if (arg_extent != ret_extent)
121 1.1 mrg runtime_error ("Incorrect extent in return array in"
122 1.1 mrg " MATMUL intrinsic for dimension 1:"
123 1.1 mrg " is %ld, should be %ld",
124 1.1 mrg (long int) ret_extent, (long int) arg_extent);
125 1.1 mrg
126 1.1 mrg arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
127 1.1 mrg ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
128 1.1 mrg if (arg_extent != ret_extent)
129 1.1 mrg runtime_error ("Incorrect extent in return array in"
130 1.1 mrg " MATMUL intrinsic for dimension 2:"
131 1.1 mrg " is %ld, should be %ld",
132 1.1 mrg (long int) ret_extent, (long int) arg_extent);
133 1.1 mrg }
134 1.1 mrg }
135 1.1 mrg
/* Validate the element size of each argument and convert its base pointer
   to a GFC_LOGICAL_1 pointer.  All element reads below are single bytes.
   NOTE(review): GFOR_POINTER_TO_L1 presumably offsets the pointer so that
   this single-byte read sees the byte carrying the truth value for the
   given kind (endianness) -- confirm against libgfortran.h.  */
136 1.1 mrg abase = a->base_addr;
137 1.1 mrg a_kind = GFC_DESCRIPTOR_SIZE (a);
138 1.1 mrg
139 1.1 mrg if (a_kind == 1 || a_kind == 2 || a_kind == 4 || a_kind == 8
140 1.1 mrg #ifdef HAVE_GFC_LOGICAL_16
141 1.1 mrg || a_kind == 16
142 1.1 mrg #endif
143 1.1 mrg )
144 1.1 mrg abase = GFOR_POINTER_TO_L1 (abase, a_kind);
145 1.1 mrg else
146 1.1 mrg internal_error (NULL, "Funny sized logical array");
147 1.1 mrg
148 1.1 mrg bbase = b->base_addr;
149 1.1 mrg b_kind = GFC_DESCRIPTOR_SIZE (b);
150 1.1 mrg
151 1.1 mrg if (b_kind == 1 || b_kind == 2 || b_kind == 4 || b_kind == 8
152 1.1 mrg #ifdef HAVE_GFC_LOGICAL_16
153 1.1 mrg || b_kind == 16
154 1.1 mrg #endif
155 1.1 mrg )
156 1.1 mrg bbase = GFOR_POINTER_TO_L1 (bbase, b_kind);
157 1.1 mrg else
158 1.1 mrg internal_error (NULL, "Funny sized logical array");
159 1.1 mrg
/* An optional target-specific m4 fragment, selected by the result type
   code, is spliced in right after this statement via sinclude (if the
   file exists).  */
160 1.1 mrg dest = retarray->base_addr;
161 1.1 mrg '
162 1.1 mrg sinclude(`matmul_asm_'rtype_code`.m4')dnl
163 1.1 mrg `
/* Result strides, counted in elements of the result type
   (GFC_DESCRIPTOR_STRIDE, unlike the _BYTES variant used for A and B).  */
164 1.1 mrg if (GFC_DESCRIPTOR_RANK (retarray) == 1)
165 1.1 mrg {
166 1.1 mrg rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
167 1.1 mrg rystride = rxstride;
168 1.1 mrg }
169 1.1 mrg else
170 1.1 mrg {
171 1.1 mrg rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
172 1.1 mrg rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
173 1.1 mrg }
174 1.1 mrg
/* Argument strides come from GFC_DESCRIPTOR_STRIDE_BYTES and are applied
   to GFC_LOGICAL_1 pointers.  astride and bstride step along the inner
   (count) dimension; xstride and ystride step to the next row of A and
   the next column of B.  */
175 1.1 mrg /* If we have rank 1 parameters, zero the absent stride, and set the size to
176 1.1 mrg one. */
177 1.1 mrg if (GFC_DESCRIPTOR_RANK (a) == 1)
178 1.1 mrg {
179 1.1 mrg astride = GFC_DESCRIPTOR_STRIDE_BYTES(a,0);
180 1.1 mrg count = GFC_DESCRIPTOR_EXTENT(a,0);
181 1.1 mrg xstride = 0;
182 1.1 mrg rxstride = 0;
183 1.1 mrg xcount = 1;
184 1.1 mrg }
185 1.1 mrg else
186 1.1 mrg {
187 1.1 mrg astride = GFC_DESCRIPTOR_STRIDE_BYTES(a,1);
188 1.1 mrg count = GFC_DESCRIPTOR_EXTENT(a,1);
189 1.1 mrg xstride = GFC_DESCRIPTOR_STRIDE_BYTES(a,0);
190 1.1 mrg xcount = GFC_DESCRIPTOR_EXTENT(a,0);
191 1.1 mrg }
/* The inner-extent agreement between A and B is checked only by assert
   (compiled out under NDEBUG).  */
192 1.1 mrg if (GFC_DESCRIPTOR_RANK (b) == 1)
193 1.1 mrg {
194 1.1 mrg bstride = GFC_DESCRIPTOR_STRIDE_BYTES(b,0);
195 1.1 mrg assert(count == GFC_DESCRIPTOR_EXTENT(b,0));
196 1.1 mrg ystride = 0;
197 1.1 mrg rystride = 0;
198 1.1 mrg ycount = 1;
199 1.1 mrg }
200 1.1 mrg else
201 1.1 mrg {
202 1.1 mrg bstride = GFC_DESCRIPTOR_STRIDE_BYTES(b,0);
203 1.1 mrg assert(count == GFC_DESCRIPTOR_EXTENT(b,0));
204 1.1 mrg ystride = GFC_DESCRIPTOR_STRIDE_BYTES(b,1);
205 1.1 mrg ycount = GFC_DESCRIPTOR_EXTENT(b,1);
206 1.1 mrg }
207 1.1 mrg
208 1.1 mrg for (y = 0; y < ycount; y++)
209 1.1 mrg {
210 1.1 mrg for (x = 0; x < xcount; x++)
211 1.1 mrg {
212 1.1 mrg /* Logical product for this element: the result is 1 if a(x,n) and
213 1.1 mrg b(n,y) are both nonzero for some n, else 0.  Scanning stops at the
214 1.1 mrg first such n. */
215 1.1 mrg pa = abase;
216 1.1 mrg pb = bbase;
217 1.1 mrg *dest = 0;
218 1.1 mrg
219 1.1 mrg for (n = 0; n < count; n++)
220 1.1 mrg {
221 1.1 mrg if (*pa && *pb)
222 1.1 mrg {
223 1.1 mrg *dest = 1;
224 1.1 mrg break;
225 1.1 mrg }
226 1.1 mrg pa += astride;
227 1.1 mrg pb += bstride;
228 1.1 mrg }
229 1.1 mrg
230 1.1 mrg dest += rxstride;
231 1.1 mrg abase += xstride;
232 1.1 mrg }
/* Column y done.  Rewind abase to the first row of A, advance bbase to the
   next column of B, and move dest from the end of this result column to
   the start of the next one.  */
233 1.1 mrg abase -= xstride * xcount;
234 1.1 mrg bbase += ystride;
235 1.1 mrg dest += rystride - (rxstride * xcount);
236 1.1 mrg }
237 1.1 mrg }
238 1.1 mrg
239 1.1 mrg #endif
240 1.1 mrg '
241