Home | History | Annotate | Line # | Download | only in m4
      1      1.1  mrg `/* Implementation of the MATMUL intrinsic
      2  1.1.1.4  mrg    Copyright (C) 2002-2024 Free Software Foundation, Inc.
      3      1.1  mrg    Contributed by Paul Brook <paul (a] nowt.org>
      4      1.1  mrg 
      5      1.1  mrg This file is part of the GNU Fortran runtime library (libgfortran).
      6      1.1  mrg 
      7      1.1  mrg Libgfortran is free software; you can redistribute it and/or
      8      1.1  mrg modify it under the terms of the GNU General Public
      9      1.1  mrg License as published by the Free Software Foundation; either
     10      1.1  mrg version 3 of the License, or (at your option) any later version.
     11      1.1  mrg 
     12      1.1  mrg Libgfortran is distributed in the hope that it will be useful,
     13      1.1  mrg but WITHOUT ANY WARRANTY; without even the implied warranty of
     14      1.1  mrg MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     15      1.1  mrg GNU General Public License for more details.
     16      1.1  mrg 
     17      1.1  mrg Under Section 7 of GPL version 3, you are granted additional
     18      1.1  mrg permissions described in the GCC Runtime Library Exception, version
     19      1.1  mrg 3.1, as published by the Free Software Foundation.
     20      1.1  mrg 
     21      1.1  mrg You should have received a copy of the GNU General Public License and
     22      1.1  mrg a copy of the GCC Runtime Library Exception along with this program;
     23      1.1  mrg see the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
     24      1.1  mrg <http://www.gnu.org/licenses/>.  */
     25      1.1  mrg 
     26      1.1  mrg #include "libgfortran.h"
     27      1.1  mrg #include <assert.h>'
     28      1.1  mrg 
     29      1.1  mrg include(iparm.m4)dnl
     30      1.1  mrg 
     31      1.1  mrg `#if defined (HAVE_'rtype_name`)
     32      1.1  mrg 
     33      1.1  mrg /* Dimensions: retarray(x,y) a(x, count) b(count,y).
     34      1.1  mrg    Either a or b can be rank 1.  In this case x or y is 1.
                         Both arguments are declared as gfc_array_l1 descriptors; the
                         definition below nevertheless accepts other logical element
                         sizes.  export_proto presumably marks the symbol as part of
                         the exported library interface - confirm in libgfortran.h.  */
     35      1.1  mrg 
     36      1.1  mrg extern void matmul_'rtype_code` ('rtype` * const restrict, 
     37      1.1  mrg 	gfc_array_l1 * const restrict, gfc_array_l1 * const restrict);
     38      1.1  mrg export_proto(matmul_'rtype_code`);
     39      1.1  mrg 
                      /* Logical MATMUL.  Element (x,y) of the result is set to 1 when
                         some n has both a(x,n) and b(n,y) nonzero, and to 0 otherwise.
                         At least one of a and b must be rank 2 (asserted on entry).
                         If retarray has no storage yet (base_addr == NULL), its shape is
                         derived from a and b, storage comes from xmallocarray and the
                         offset is reset to 0.  Otherwise, when
                         compile_options.bounds_check is set, the extents of retarray are
                         compared with those implied by a and b and any mismatch is
                         reported through runtime_error.
                         a and b may hold logical elements of size 1, 2, 4 or 8 (and 16
                         when HAVE_GFC_LOGICAL_16 is defined); any other size triggers
                         internal_error.  Each argument pointer is passed through
                         GFOR_POINTER_TO_L1 and then read one GFC_LOGICAL_1 at a time;
                         presumably that macro selects the byte carrying the truth value
                         for the given kind - confirm in libgfortran.h.  */
     40      1.1  mrg void
     41      1.1  mrg matmul_'rtype_code` ('rtype` * const restrict retarray, 
     42      1.1  mrg 	gfc_array_l1 * const restrict a, gfc_array_l1 * const restrict b)
     43      1.1  mrg {
     44      1.1  mrg   const GFC_LOGICAL_1 * restrict abase;
     45      1.1  mrg   const GFC_LOGICAL_1 * restrict bbase;
     46      1.1  mrg   'rtype_name` * restrict dest;
     47      1.1  mrg   index_type rxstride;
     48      1.1  mrg   index_type rystride;
     49      1.1  mrg   index_type xcount;
     50      1.1  mrg   index_type ycount;
     51      1.1  mrg   index_type xstride;
     52      1.1  mrg   index_type ystride;
     53      1.1  mrg   index_type x;
     54      1.1  mrg   index_type y;
     55      1.1  mrg   int a_kind;
     56      1.1  mrg   int b_kind;
     57      1.1  mrg 
     58      1.1  mrg   const GFC_LOGICAL_1 * restrict pa;
     59      1.1  mrg   const GFC_LOGICAL_1 * restrict pb;
     60      1.1  mrg   index_type astride;
     61      1.1  mrg   index_type bstride;
     62      1.1  mrg   index_type count;
     63      1.1  mrg   index_type n;
     64      1.1  mrg 
     65      1.1  mrg   assert (GFC_DESCRIPTOR_RANK (a) == 2
     66      1.1  mrg           || GFC_DESCRIPTOR_RANK (b) == 2);
     67      1.1  mrg 
                        /* Allocate the result if the caller did not: unit stride in dim[0],
                           and for a rank-2 result a dim[1] stride equal to the dim[0] extent.  */
     68      1.1  mrg   if (retarray->base_addr == NULL)
     69      1.1  mrg     {
     70      1.1  mrg       if (GFC_DESCRIPTOR_RANK (a) == 1)
     71      1.1  mrg         {
     72      1.1  mrg 	  GFC_DIMENSION_SET(retarray->dim[0], 0,
     73      1.1  mrg 	                    GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
     74      1.1  mrg         }
     75      1.1  mrg       else if (GFC_DESCRIPTOR_RANK (b) == 1)
     76      1.1  mrg         {
     77      1.1  mrg 	  GFC_DIMENSION_SET(retarray->dim[0], 0,
     78      1.1  mrg 	                    GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
     79      1.1  mrg         }
     80      1.1  mrg       else
     81      1.1  mrg         {
     82      1.1  mrg 	  GFC_DIMENSION_SET(retarray->dim[0], 0,
     83      1.1  mrg 	                    GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
     84      1.1  mrg 
     85      1.1  mrg           GFC_DIMENSION_SET(retarray->dim[1], 0,
     86      1.1  mrg 	                    GFC_DESCRIPTOR_EXTENT(b,1) - 1,
     87      1.1  mrg 			    GFC_DESCRIPTOR_EXTENT(retarray,0));
     88      1.1  mrg         }
     89      1.1  mrg           
     90      1.1  mrg       retarray->base_addr
     91      1.1  mrg 	= xmallocarray (size0 ((array_t *) retarray), sizeof ('rtype_name`));
     92      1.1  mrg       retarray->offset = 0;
     93      1.1  mrg     }
                          /* Caller-supplied result: with bounds checking enabled, verify its
                             extents.  The shared dimension count is not checked here.  */
     94      1.1  mrg     else if (unlikely (compile_options.bounds_check))
     95      1.1  mrg       {
     96      1.1  mrg 	index_type ret_extent, arg_extent;
     97      1.1  mrg 
     98      1.1  mrg 	if (GFC_DESCRIPTOR_RANK (a) == 1)
     99      1.1  mrg 	  {
    100      1.1  mrg 	    arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
    101      1.1  mrg 	    ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
    102      1.1  mrg 	    if (arg_extent != ret_extent)
    103      1.1  mrg 	      runtime_error ("Incorrect extent in return array in"
    104      1.1  mrg 			     " MATMUL intrinsic: is %ld, should be %ld",
    105      1.1  mrg 			     (long int) ret_extent, (long int) arg_extent);
    106      1.1  mrg 	  }
    107      1.1  mrg 	else if (GFC_DESCRIPTOR_RANK (b) == 1)
    108      1.1  mrg 	  {
    109      1.1  mrg 	    arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
    110      1.1  mrg 	    ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
    111      1.1  mrg 	    if (arg_extent != ret_extent)
    112      1.1  mrg 	      runtime_error ("Incorrect extent in return array in"
    113      1.1  mrg 			     " MATMUL intrinsic: is %ld, should be %ld",
    114      1.1  mrg 			     (long int) ret_extent, (long int) arg_extent);	    
    115      1.1  mrg 	  }
    116      1.1  mrg 	else
    117      1.1  mrg 	  {
    118      1.1  mrg 	    arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
    119      1.1  mrg 	    ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
    120      1.1  mrg 	    if (arg_extent != ret_extent)
    121      1.1  mrg 	      runtime_error ("Incorrect extent in return array in"
    122      1.1  mrg 			     " MATMUL intrinsic for dimension 1:"
    123      1.1  mrg 			     " is %ld, should be %ld",
    124      1.1  mrg 			     (long int) ret_extent, (long int) arg_extent);
    125      1.1  mrg 
    126      1.1  mrg 	    arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
    127      1.1  mrg 	    ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
    128      1.1  mrg 	    if (arg_extent != ret_extent)
    129      1.1  mrg 	      runtime_error ("Incorrect extent in return array in"
    130      1.1  mrg 			     " MATMUL intrinsic for dimension 2:"
    131      1.1  mrg 			     " is %ld, should be %ld",
    132      1.1  mrg 			     (long int) ret_extent, (long int) arg_extent);
    133      1.1  mrg 	  }
    134      1.1  mrg       }
    135      1.1  mrg 
                        /* Accept any supported logical kind for a and b; see the note on
                           GFOR_POINTER_TO_L1 at the top of this function.  */
    136      1.1  mrg   abase = a->base_addr;
    137      1.1  mrg   a_kind = GFC_DESCRIPTOR_SIZE (a);
    138      1.1  mrg 
    139      1.1  mrg   if (a_kind == 1 || a_kind == 2 || a_kind == 4 || a_kind == 8
    140      1.1  mrg #ifdef HAVE_GFC_LOGICAL_16
    141      1.1  mrg      || a_kind == 16
    142      1.1  mrg #endif
    143      1.1  mrg      )
    144      1.1  mrg     abase = GFOR_POINTER_TO_L1 (abase, a_kind);
    145      1.1  mrg   else
    146      1.1  mrg     internal_error (NULL, "Funny sized logical array");
    147      1.1  mrg 
    148      1.1  mrg   bbase = b->base_addr;
    149      1.1  mrg   b_kind = GFC_DESCRIPTOR_SIZE (b);
    150      1.1  mrg 
    151      1.1  mrg   if (b_kind == 1 || b_kind == 2 || b_kind == 4 || b_kind == 8
    152      1.1  mrg #ifdef HAVE_GFC_LOGICAL_16
    153      1.1  mrg      || b_kind == 16
    154      1.1  mrg #endif
    155      1.1  mrg      )
    156      1.1  mrg     bbase = GFOR_POINTER_TO_L1 (bbase, b_kind);
    157      1.1  mrg   else
    158      1.1  mrg     internal_error (NULL, "Funny sized logical array");
    159      1.1  mrg 
    160      1.1  mrg   dest = retarray->base_addr;
                        /* The sinclude that follows inserts matmul_asm_<code>.m4 here when
                           that file exists and nothing otherwise (GNU m4 sinclude semantics);
                           presumably a hook for a target-specific implementation.  */
    161      1.1  mrg '
    162      1.1  mrg sinclude(`matmul_asm_'rtype_code`.m4')dnl
    163      1.1  mrg `
                        /* Result strides are added to dest, a pointer to the result type,
                           so they count elements.  A rank-1 result uses one stride for both
                           loops; the rank-1 argument handling below zeroes the unused one.  */
    164      1.1  mrg   if (GFC_DESCRIPTOR_RANK (retarray) == 1)
    165      1.1  mrg     {
    166      1.1  mrg       rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
    167      1.1  mrg       rystride = rxstride;
    168      1.1  mrg     }
    169      1.1  mrg   else
    170      1.1  mrg     {
    171      1.1  mrg       rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
    172      1.1  mrg       rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
    173      1.1  mrg     }
    174      1.1  mrg 
    175      1.1  mrg   /* If we have rank 1 parameters, zero the absent stride (for both the
    176      1.1  mrg      argument and the result), and set the size to one.
                           The a and b strides come from GFC_DESCRIPTOR_STRIDE_BYTES and are
                           added directly to the GFC_LOGICAL_1 pointers abase, bbase, pa, pb.  */
    177      1.1  mrg   if (GFC_DESCRIPTOR_RANK (a) == 1)
    178      1.1  mrg     {
    179      1.1  mrg       astride = GFC_DESCRIPTOR_STRIDE_BYTES(a,0);
    180      1.1  mrg       count = GFC_DESCRIPTOR_EXTENT(a,0);
    181      1.1  mrg       xstride = 0;
    182      1.1  mrg       rxstride = 0;
    183      1.1  mrg       xcount = 1;
    184      1.1  mrg     }
    185      1.1  mrg   else
    186      1.1  mrg     {
    187      1.1  mrg       astride = GFC_DESCRIPTOR_STRIDE_BYTES(a,1);
    188      1.1  mrg       count = GFC_DESCRIPTOR_EXTENT(a,1);
    189      1.1  mrg       xstride = GFC_DESCRIPTOR_STRIDE_BYTES(a,0);
    190      1.1  mrg       xcount = GFC_DESCRIPTOR_EXTENT(a,0);
    191      1.1  mrg     }
                        /* b is walked along its first dimension (bstride); its extent must
                           equal count, which is only asserted.  NOTE(review): a runtime_error
                           here, as for the result extents above, would give a clearer
                           diagnostic than an assertion failure.  */
    192      1.1  mrg   if (GFC_DESCRIPTOR_RANK (b) == 1)
    193      1.1  mrg     {
    194      1.1  mrg       bstride = GFC_DESCRIPTOR_STRIDE_BYTES(b,0);
    195      1.1  mrg       assert(count == GFC_DESCRIPTOR_EXTENT(b,0));
    196      1.1  mrg       ystride = 0;
    197      1.1  mrg       rystride = 0;
    198      1.1  mrg       ycount = 1;
    199      1.1  mrg     }
    200      1.1  mrg   else
    201      1.1  mrg     {
    202      1.1  mrg       bstride = GFC_DESCRIPTOR_STRIDE_BYTES(b,0);
    203      1.1  mrg       assert(count == GFC_DESCRIPTOR_EXTENT(b,0));
    204      1.1  mrg       ystride = GFC_DESCRIPTOR_STRIDE_BYTES(b,1);
    205      1.1  mrg       ycount = GFC_DESCRIPTOR_EXTENT(b,1);
    206      1.1  mrg     }
    207      1.1  mrg 
                        /* Outer loop over result columns (y), inner loop over rows (x).  */
    208      1.1  mrg   for (y = 0; y < ycount; y++)
    209      1.1  mrg     {
    210      1.1  mrg       for (x = 0; x < xcount; x++)
    211      1.1  mrg         {
    212      1.1  mrg           /* Compute this element: it becomes 1 as soon as some n has
    213      1.1  mrg              both *pa (a(x,n)) and *pb (b(n,y)) nonzero, and stays 0
    214      1.1  mrg              otherwise; the loop stops at the first such n.  */
    215      1.1  mrg           pa = abase;
    216      1.1  mrg           pb = bbase;
    217      1.1  mrg           *dest = 0;
    218      1.1  mrg 
    219      1.1  mrg           for (n = 0; n < count; n++)
    220      1.1  mrg             {
    221      1.1  mrg               if (*pa && *pb)
    222      1.1  mrg                 {
    223      1.1  mrg                   *dest = 1;
    224      1.1  mrg                   break;
    225      1.1  mrg                 }
    226      1.1  mrg               pa += astride;
    227      1.1  mrg               pb += bstride;
    228      1.1  mrg             }
    229      1.1  mrg 
    230      1.1  mrg           dest += rxstride;
    231      1.1  mrg           abase += xstride;
    232      1.1  mrg         }
                            /* Rewind abase to the first row of a, step bbase to the next
                               column of b, and move dest from one row stride past the last
                               element of this result column to the start of the next one.  */
    233      1.1  mrg       abase -= xstride * xcount;
    234      1.1  mrg       bbase += ystride;
    235      1.1  mrg       dest += rystride - (rxstride * xcount);
    236      1.1  mrg     }
    237      1.1  mrg }
    238      1.1  mrg 
    239      1.1  mrg #endif
    240      1.1  mrg '
    241