1//
2// Copyright 2012-2016 Francisco Jerez
3// Copyright 2012-2016 Advanced Micro Devices, Inc.
4// Copyright 2014-2016 Jan Vesely
5// Copyright 2014-2015 Serge Martin
6// Copyright 2015 Zoltan Gilian
7//
8// Permission is hereby granted, free of charge, to any person obtaining a
9// copy of this software and associated documentation files (the "Software"),
10// to deal in the Software without restriction, including without limitation
11// the rights to use, copy, modify, merge, publish, distribute, sublicense,
12// and/or sell copies of the Software, and to permit persons to whom the
13// Software is furnished to do so, subject to the following conditions:
14//
15// The above copyright notice and this permission notice shall be included in
16// all copies or substantial portions of the Software.
17//
18// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
21// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
22// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
23// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
24// OTHER DEALINGS IN THE SOFTWARE.
25
26#include <sstream>
27
28#include <llvm/ADT/ArrayRef.h>
29#include <llvm/IR/DiagnosticPrinter.h>
30#include <llvm/IR/DiagnosticInfo.h>
31#include <llvm/IR/LLVMContext.h>
32#include <llvm/IR/Type.h>
33#include <llvm/Support/raw_ostream.h>
34#include <llvm/Bitcode/BitcodeWriter.h>
35#include <llvm/Bitcode/BitcodeReader.h>
36#include <llvm-c/Core.h>
37#include <llvm-c/Target.h>
38#include <LLVMSPIRVLib/LLVMSPIRVLib.h>
39
40#include <clang/CodeGen/CodeGenAction.h>
41#include <clang/Lex/PreprocessorOptions.h>
42#include <clang/Frontend/CompilerInstance.h>
43#include <clang/Frontend/TextDiagnosticBuffer.h>
44#include <clang/Frontend/TextDiagnosticPrinter.h>
45#include <clang/Basic/TargetInfo.h>
46
47#include <spirv-tools/libspirv.hpp>
48#include <spirv-tools/linker.hpp>
49#include <spirv-tools/optimizer.hpp>
50
51#include "util/macros.h"
52#include "glsl_types.h"
53
54#include "spirv.h"
55
56#ifdef USE_STATIC_OPENCL_C_H
57#include "opencl-c.h.h"
58#include "opencl-c-base.h.h"
59#endif
60
61#include "clc_helpers.h"
62
/* Use the highest version of SPIRV supported by SPIRV-Tools.
 * This is the target environment SPIRVKernelParser passes to
 * spvContextCreate(). */
constexpr spv_target_env spirv_target = SPV_ENV_UNIVERSAL_1_5;

/* Sentinel (numeric value 0) returned by
 * spirv_version_to_llvm_spirv_translator_version() for a clc_spirv_version
 * it has no mapping for; llvm_mod_to_spirv() compares against it to reject
 * the request. */
constexpr SPIRV::VersionNumber invalid_spirv_trans_version = static_cast<SPIRV::VersionNumber>(0);
67
68using ::llvm::Function;
69using ::llvm::LLVMContext;
70using ::llvm::Module;
71using ::llvm::raw_string_ostream;
72
73static void
74llvm_log_handler(const ::llvm::DiagnosticInfo &di, void *data) {
75   raw_string_ostream os { *reinterpret_cast<std::string *>(data) };
76   ::llvm::DiagnosticPrinterRawOStream printer { os };
77   di.print(printer);
78}
79
80class SPIRVKernelArg {
81public:
82   SPIRVKernelArg(uint32_t id, uint32_t typeId) : id(id), typeId(typeId),
83                                                  addrQualifier(CLC_KERNEL_ARG_ADDRESS_PRIVATE),
84                                                  accessQualifier(0),
85                                                  typeQualifier(0) { }
86   ~SPIRVKernelArg() { }
87
88   uint32_t id;
89   uint32_t typeId;
90   std::string name;
91   std::string typeName;
92   enum clc_kernel_arg_address_qualifier addrQualifier;
93   unsigned accessQualifier;
94   unsigned typeQualifier;
95};
96
97class SPIRVKernelInfo {
98public:
99   SPIRVKernelInfo(uint32_t fid, const char *nm) : funcId(fid), name(nm), vecHint(0) { }
100   ~SPIRVKernelInfo() { }
101
102   uint32_t funcId;
103   std::string name;
104   std::vector<SPIRVKernelArg> args;
105   unsigned vecHint;
106};
107
108class SPIRVKernelParser {
109public:
110   SPIRVKernelParser() : curKernel(NULL)
111   {
112      ctx = spvContextCreate(spirv_target);
113   }
114
115   ~SPIRVKernelParser()
116   {
117     spvContextDestroy(ctx);
118   }
119
120   void parseEntryPoint(const spv_parsed_instruction_t *ins)
121   {
122      assert(ins->num_operands >= 3);
123
124      const spv_parsed_operand_t *op = &ins->operands[1];
125
126      assert(op->type == SPV_OPERAND_TYPE_ID);
127
128      uint32_t funcId = ins->words[op->offset];
129
130      for (auto &iter : kernels) {
131         if (funcId == iter.funcId)
132            return;
133      }
134
135      op = &ins->operands[2];
136      assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
137      const char *name = reinterpret_cast<const char *>(ins->words + op->offset);
138
139      kernels.push_back(SPIRVKernelInfo(funcId, name));
140   }
141
142   void parseFunction(const spv_parsed_instruction_t *ins)
143   {
144      assert(ins->num_operands == 4);
145
146      const spv_parsed_operand_t *op = &ins->operands[1];
147
148      assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
149
150      uint32_t funcId = ins->words[op->offset];
151
152      for (auto &kernel : kernels) {
153         if (funcId == kernel.funcId && !kernel.args.size()) {
154            curKernel = &kernel;
155	    return;
156         }
157      }
158   }
159
160   void parseFunctionParam(const spv_parsed_instruction_t *ins)
161   {
162      const spv_parsed_operand_t *op;
163      uint32_t id, typeId;
164
165      if (!curKernel)
166         return;
167
168      assert(ins->num_operands == 2);
169      op = &ins->operands[0];
170      assert(op->type == SPV_OPERAND_TYPE_TYPE_ID);
171      typeId = ins->words[op->offset];
172      op = &ins->operands[1];
173      assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
174      id = ins->words[op->offset];
175      curKernel->args.push_back(SPIRVKernelArg(id, typeId));
176   }
177
178   void parseName(const spv_parsed_instruction_t *ins)
179   {
180      const spv_parsed_operand_t *op;
181      const char *name;
182      uint32_t id;
183
184      assert(ins->num_operands == 2);
185
186      op = &ins->operands[0];
187      assert(op->type == SPV_OPERAND_TYPE_ID);
188      id = ins->words[op->offset];
189      op = &ins->operands[1];
190      assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
191      name = reinterpret_cast<const char *>(ins->words + op->offset);
192
193      for (auto &kernel : kernels) {
194         for (auto &arg : kernel.args) {
195            if (arg.id == id && arg.name.empty()) {
196              arg.name = name;
197              break;
198	    }
199         }
200      }
201   }
202
203   void parseTypePointer(const spv_parsed_instruction_t *ins)
204   {
205      enum clc_kernel_arg_address_qualifier addrQualifier;
206      uint32_t typeId, storageClass;
207      const spv_parsed_operand_t *op;
208
209      assert(ins->num_operands == 3);
210
211      op = &ins->operands[0];
212      assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
213      typeId = ins->words[op->offset];
214
215      op = &ins->operands[1];
216      assert(op->type == SPV_OPERAND_TYPE_STORAGE_CLASS);
217      storageClass = ins->words[op->offset];
218      switch (storageClass) {
219      case SpvStorageClassCrossWorkgroup:
220         addrQualifier = CLC_KERNEL_ARG_ADDRESS_GLOBAL;
221         break;
222      case SpvStorageClassWorkgroup:
223         addrQualifier = CLC_KERNEL_ARG_ADDRESS_LOCAL;
224         break;
225      case SpvStorageClassUniformConstant:
226         addrQualifier = CLC_KERNEL_ARG_ADDRESS_CONSTANT;
227         break;
228      default:
229         addrQualifier = CLC_KERNEL_ARG_ADDRESS_PRIVATE;
230         break;
231      }
232
233      for (auto &kernel : kernels) {
234	 for (auto &arg : kernel.args) {
235            if (arg.typeId == typeId)
236               arg.addrQualifier = addrQualifier;
237         }
238      }
239   }
240
241   void parseOpString(const spv_parsed_instruction_t *ins)
242   {
243      const spv_parsed_operand_t *op;
244      std::string str;
245
246      assert(ins->num_operands == 2);
247
248      op = &ins->operands[1];
249      assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
250      str = reinterpret_cast<const char *>(ins->words + op->offset);
251
252      if (str.find("kernel_arg_type.") != 0)
253         return;
254
255      size_t start = sizeof("kernel_arg_type.") - 1;
256
257      for (auto &kernel : kernels) {
258         size_t pos;
259
260	 pos = str.find(kernel.name, start);
261         if (pos == std::string::npos ||
262             pos != start || str[start + kernel.name.size()] != '.')
263            continue;
264
265	 pos = start + kernel.name.size();
266         if (str[pos++] != '.')
267            continue;
268
269         for (auto &arg : kernel.args) {
270            if (arg.name.empty())
271               break;
272
273            size_t typeEnd = str.find(',', pos);
274	    if (typeEnd == std::string::npos)
275               break;
276
277            arg.typeName = str.substr(pos, typeEnd - pos);
278            pos = typeEnd + 1;
279         }
280      }
281   }
282
283   void applyDecoration(uint32_t id, const spv_parsed_instruction_t *ins)
284   {
285      auto iter = decorationGroups.find(id);
286      if (iter != decorationGroups.end()) {
287         for (uint32_t entry : iter->second)
288            applyDecoration(entry, ins);
289         return;
290      }
291
292      const spv_parsed_operand_t *op;
293      uint32_t decoration;
294
295      assert(ins->num_operands >= 2);
296
297      op = &ins->operands[1];
298      assert(op->type == SPV_OPERAND_TYPE_DECORATION);
299      decoration = ins->words[op->offset];
300
301      if (decoration == SpvDecorationSpecId) {
302         uint32_t spec_id = ins->words[ins->operands[2].offset];
303         for (auto &c : specConstants) {
304            if (c.second.id == spec_id) {
305               assert(c.first == id);
306               return;
307            }
308         }
309         specConstants.emplace_back(id, clc_parsed_spec_constant{ spec_id });
310         return;
311      }
312
313      for (auto &kernel : kernels) {
314         for (auto &arg : kernel.args) {
315            if (arg.id == id) {
316               switch (decoration) {
317               case SpvDecorationVolatile:
318                  arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_VOLATILE;
319                  break;
320               case SpvDecorationConstant:
321                  arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
322                  break;
323               case SpvDecorationRestrict:
324                  arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_RESTRICT;
325                  break;
326               case SpvDecorationFuncParamAttr:
327                  op = &ins->operands[2];
328                  assert(op->type == SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE);
329                  switch (ins->words[op->offset]) {
330                  case SpvFunctionParameterAttributeNoAlias:
331                     arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_RESTRICT;
332                     break;
333                  case SpvFunctionParameterAttributeNoWrite:
334                     arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
335                     break;
336                  }
337                  break;
338               }
339            }
340
341         }
342      }
343   }
344
345   void parseOpDecorate(const spv_parsed_instruction_t *ins)
346   {
347      const spv_parsed_operand_t *op;
348      uint32_t id;
349
350      assert(ins->num_operands >= 2);
351
352      op = &ins->operands[0];
353      assert(op->type == SPV_OPERAND_TYPE_ID);
354      id = ins->words[op->offset];
355
356      applyDecoration(id, ins);
357   }
358
359   void parseOpGroupDecorate(const spv_parsed_instruction_t *ins)
360   {
361      assert(ins->num_operands >= 2);
362
363      const spv_parsed_operand_t *op = &ins->operands[0];
364      assert(op->type == SPV_OPERAND_TYPE_ID);
365      uint32_t groupId = ins->words[op->offset];
366
367      auto lowerBound = decorationGroups.lower_bound(groupId);
368      if (lowerBound != decorationGroups.end() &&
369          lowerBound->first == groupId)
370         // Group already filled out
371         return;
372
373      auto iter = decorationGroups.emplace_hint(lowerBound, groupId, std::vector<uint32_t>{});
374      auto& vec = iter->second;
375      vec.reserve(ins->num_operands - 1);
376      for (uint32_t i = 1; i < ins->num_operands; ++i) {
377         op = &ins->operands[i];
378         assert(op->type == SPV_OPERAND_TYPE_ID);
379         vec.push_back(ins->words[op->offset]);
380      }
381   }
382
383   void parseOpTypeImage(const spv_parsed_instruction_t *ins)
384   {
385      const spv_parsed_operand_t *op;
386      uint32_t typeId;
387      unsigned accessQualifier = CLC_KERNEL_ARG_ACCESS_READ;
388
389      op = &ins->operands[0];
390      assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
391      typeId = ins->words[op->offset];
392
393      if (ins->num_operands >= 9) {
394         op = &ins->operands[8];
395         assert(op->type == SPV_OPERAND_TYPE_ACCESS_QUALIFIER);
396         switch (ins->words[op->offset]) {
397         case SpvAccessQualifierReadOnly:
398            accessQualifier = CLC_KERNEL_ARG_ACCESS_READ;
399            break;
400         case SpvAccessQualifierWriteOnly:
401            accessQualifier = CLC_KERNEL_ARG_ACCESS_WRITE;
402            break;
403         case SpvAccessQualifierReadWrite:
404            accessQualifier = CLC_KERNEL_ARG_ACCESS_WRITE |
405               CLC_KERNEL_ARG_ACCESS_READ;
406            break;
407         }
408      }
409
410      for (auto &kernel : kernels) {
411	 for (auto &arg : kernel.args) {
412            if (arg.typeId == typeId) {
413               arg.accessQualifier = accessQualifier;
414               arg.addrQualifier = CLC_KERNEL_ARG_ADDRESS_GLOBAL;
415            }
416         }
417      }
418   }
419
420   void parseExecutionMode(const spv_parsed_instruction_t *ins)
421   {
422      uint32_t executionMode = ins->words[ins->operands[1].offset];
423      if (executionMode != SpvExecutionModeVecTypeHint)
424         return;
425
426      uint32_t funcId = ins->words[ins->operands[0].offset];
427      uint32_t vecHint = ins->words[ins->operands[2].offset];
428      for (auto& kernel : kernels) {
429         if (kernel.funcId == funcId)
430            kernel.vecHint = vecHint;
431      }
432   }
433
434   void parseLiteralType(const spv_parsed_instruction_t *ins)
435   {
436      uint32_t typeId = ins->words[ins->operands[0].offset];
437      auto& literalType = literalTypes[typeId];
438      switch (ins->opcode) {
439      case SpvOpTypeBool:
440         literalType = CLC_SPEC_CONSTANT_BOOL;
441         break;
442      case SpvOpTypeFloat: {
443         uint32_t sizeInBits = ins->words[ins->operands[1].offset];
444         switch (sizeInBits) {
445         case 32:
446            literalType = CLC_SPEC_CONSTANT_FLOAT;
447            break;
448         case 64:
449            literalType = CLC_SPEC_CONSTANT_DOUBLE;
450            break;
451         case 16:
452            /* Can't be used for a spec constant */
453            break;
454         default:
455            unreachable("Unexpected float bit size");
456         }
457         break;
458      }
459      case SpvOpTypeInt: {
460         uint32_t sizeInBits = ins->words[ins->operands[1].offset];
461         bool isSigned = ins->words[ins->operands[2].offset];
462         if (isSigned) {
463            switch (sizeInBits) {
464            case 8:
465               literalType = CLC_SPEC_CONSTANT_INT8;
466               break;
467            case 16:
468               literalType = CLC_SPEC_CONSTANT_INT16;
469               break;
470            case 32:
471               literalType = CLC_SPEC_CONSTANT_INT32;
472               break;
473            case 64:
474               literalType = CLC_SPEC_CONSTANT_INT64;
475               break;
476            default:
477               unreachable("Unexpected int bit size");
478            }
479         } else {
480            switch (sizeInBits) {
481            case 8:
482               literalType = CLC_SPEC_CONSTANT_UINT8;
483               break;
484            case 16:
485               literalType = CLC_SPEC_CONSTANT_UINT16;
486               break;
487            case 32:
488               literalType = CLC_SPEC_CONSTANT_UINT32;
489               break;
490            case 64:
491               literalType = CLC_SPEC_CONSTANT_UINT64;
492               break;
493            default:
494               unreachable("Unexpected uint bit size");
495            }
496         }
497         break;
498      }
499      default:
500         unreachable("Unexpected type opcode");
501      }
502   }
503
504   void parseSpecConstant(const spv_parsed_instruction_t *ins)
505   {
506      uint32_t id = ins->result_id;
507      for (auto& c : specConstants) {
508         if (c.first == id) {
509            auto& data = c.second;
510            switch (ins->opcode) {
511            case SpvOpSpecConstant: {
512               uint32_t typeId = ins->words[ins->operands[0].offset];
513
514               // This better be an integer or float type
515               auto typeIter = literalTypes.find(typeId);
516               assert(typeIter != literalTypes.end());
517
518               data.type = typeIter->second;
519               break;
520            }
521            case SpvOpSpecConstantFalse:
522            case SpvOpSpecConstantTrue:
523               data.type = CLC_SPEC_CONSTANT_BOOL;
524               break;
525            default:
526               unreachable("Composites and Ops are not directly specializable.");
527            }
528         }
529      }
530   }
531
532   static spv_result_t
533   parseInstruction(void *data, const spv_parsed_instruction_t *ins)
534   {
535      SPIRVKernelParser *parser = reinterpret_cast<SPIRVKernelParser *>(data);
536
537      switch (ins->opcode) {
538      case SpvOpName:
539         parser->parseName(ins);
540         break;
541      case SpvOpEntryPoint:
542         parser->parseEntryPoint(ins);
543         break;
544      case SpvOpFunction:
545         parser->parseFunction(ins);
546         break;
547      case SpvOpFunctionParameter:
548         parser->parseFunctionParam(ins);
549         break;
550      case SpvOpFunctionEnd:
551      case SpvOpLabel:
552         parser->curKernel = NULL;
553         break;
554      case SpvOpTypePointer:
555         parser->parseTypePointer(ins);
556         break;
557      case SpvOpTypeImage:
558         parser->parseOpTypeImage(ins);
559         break;
560      case SpvOpString:
561         parser->parseOpString(ins);
562         break;
563      case SpvOpDecorate:
564         parser->parseOpDecorate(ins);
565         break;
566      case SpvOpGroupDecorate:
567         parser->parseOpGroupDecorate(ins);
568         break;
569      case SpvOpExecutionMode:
570         parser->parseExecutionMode(ins);
571         break;
572      case SpvOpTypeBool:
573      case SpvOpTypeInt:
574      case SpvOpTypeFloat:
575         parser->parseLiteralType(ins);
576         break;
577      case SpvOpSpecConstant:
578      case SpvOpSpecConstantFalse:
579      case SpvOpSpecConstantTrue:
580         parser->parseSpecConstant(ins);
581         break;
582      default:
583         break;
584      }
585
586      return SPV_SUCCESS;
587   }
588
589   bool parsingComplete()
590   {
591      for (auto &kernel : kernels) {
592         if (kernel.name.empty())
593            return false;
594
595         for (auto &arg : kernel.args) {
596            if (arg.name.empty() || arg.typeName.empty())
597               return false;
598         }
599      }
600
601      return true;
602   }
603
604   bool parseBinary(const struct clc_binary &spvbin, const struct clc_logger *logger)
605   {
606      /* 3 passes should be enough to retrieve all kernel information:
607       * 1st pass: all entry point name and number of args
608       * 2nd pass: argument names and type names
609       * 3rd pass: pointer type names
610       */
611      for (unsigned pass = 0; pass < 3; pass++) {
612         spv_diagnostic diagnostic = NULL;
613         auto result = spvBinaryParse(ctx, reinterpret_cast<void *>(this),
614                                      static_cast<uint32_t*>(spvbin.data), spvbin.size / 4,
615                                      NULL, parseInstruction, &diagnostic);
616
617         if (result != SPV_SUCCESS) {
618            if (diagnostic && logger)
619               logger->error(logger->priv, diagnostic->error);
620            return false;
621         }
622
623         if (parsingComplete())
624            return true;
625      }
626
627      assert(0);
628      return false;
629   }
630
   /* Kernels registered by parseEntryPoint(), one per entry-point function id. */
   std::vector<SPIRVKernelInfo> kernels;
   /* (decorated id, spec-constant info) pairs registered from SpecId
    * decorations; the type is filled in later by parseSpecConstant(). */
   std::vector<std::pair<uint32_t, clc_parsed_spec_constant>> specConstants;
   /* Bool/int/float type id -> spec-constant type, filled by parseLiteralType(). */
   std::map<uint32_t, enum clc_spec_constant_type> literalTypes;
   /* Decoration group id -> ids listed by the OpGroupDecorate that uses it. */
   std::map<uint32_t, std::vector<uint32_t>> decorationGroups;
   /* Kernel currently receiving OpFunctionParameter instructions, or NULL.
    * Points into kernels. */
   SPIRVKernelInfo *curKernel;
   /* SPIRV-Tools context: created in the constructor, destroyed in the destructor. */
   spv_context ctx;
637};
638
639bool
640clc_spirv_get_kernels_info(const struct clc_binary *spvbin,
641                           const struct clc_kernel_info **out_kernels,
642                           unsigned *num_kernels,
643                           const struct clc_parsed_spec_constant **out_spec_constants,
644                           unsigned *num_spec_constants,
645                           const struct clc_logger *logger)
646{
647   struct clc_kernel_info *kernels;
648   struct clc_parsed_spec_constant *spec_constants = NULL;
649
650   SPIRVKernelParser parser;
651
652   if (!parser.parseBinary(*spvbin, logger))
653      return false;
654
655   *num_kernels = parser.kernels.size();
656   *num_spec_constants = parser.specConstants.size();
657   if (!*num_kernels)
658      return false;
659
660   kernels = reinterpret_cast<struct clc_kernel_info *>(calloc(*num_kernels,
661                                                               sizeof(*kernels)));
662   assert(kernels);
663   for (unsigned i = 0; i < parser.kernels.size(); i++) {
664      kernels[i].name = strdup(parser.kernels[i].name.c_str());
665      kernels[i].num_args = parser.kernels[i].args.size();
666      kernels[i].vec_hint_size = parser.kernels[i].vecHint >> 16;
667      kernels[i].vec_hint_type = (enum clc_vec_hint_type)(parser.kernels[i].vecHint & 0xFFFF);
668      if (!kernels[i].num_args)
669         continue;
670
671      struct clc_kernel_arg *args;
672
673      args = reinterpret_cast<struct clc_kernel_arg *>(calloc(kernels[i].num_args,
674                                                       sizeof(*kernels->args)));
675      kernels[i].args = args;
676      assert(args);
677      for (unsigned j = 0; j < kernels[i].num_args; j++) {
678         if (!parser.kernels[i].args[j].name.empty())
679            args[j].name = strdup(parser.kernels[i].args[j].name.c_str());
680         args[j].type_name = strdup(parser.kernels[i].args[j].typeName.c_str());
681         args[j].address_qualifier = parser.kernels[i].args[j].addrQualifier;
682         args[j].type_qualifier = parser.kernels[i].args[j].typeQualifier;
683         args[j].access_qualifier = parser.kernels[i].args[j].accessQualifier;
684      }
685   }
686
687   if (*num_spec_constants) {
688      spec_constants = reinterpret_cast<struct clc_parsed_spec_constant *>(calloc(*num_spec_constants,
689                                                                                  sizeof(*spec_constants)));
690      assert(spec_constants);
691
692      for (unsigned i = 0; i < parser.specConstants.size(); ++i) {
693         spec_constants[i] = parser.specConstants[i].second;
694      }
695   }
696
697   *out_kernels = kernels;
698   *out_spec_constants = spec_constants;
699
700   return true;
701}
702
703void
704clc_free_kernels_info(const struct clc_kernel_info *kernels,
705                      unsigned num_kernels)
706{
707   if (!kernels)
708      return;
709
710   for (unsigned i = 0; i < num_kernels; i++) {
711      if (kernels[i].args) {
712         for (unsigned j = 0; j < kernels[i].num_args; j++) {
713            free((void *)kernels[i].args[j].name);
714            free((void *)kernels[i].args[j].type_name);
715         }
716      }
717      free((void *)kernels[i].name);
718   }
719
720   free((void *)kernels);
721}
722
/* Runs clang on the OpenCL C source described by args and returns the
 * resulting LLVM module together with the LLVMContext it was created in.
 *
 * args:   source name/contents, extra clang arguments and optional in-memory
 *         headers.
 * logger: receives the accumulated diagnostics through clc_error() on failure.
 *
 * Returns a pair of empty pointers when the clang invocation cannot be
 * created, when option parsing reported errors, or when the compile action
 * fails. */
static std::pair<std::unique_ptr<::llvm::Module>, std::unique_ptr<LLVMContext>>
clc_compile_to_llvm_module(const struct clc_compile_args *args,
                           const struct clc_logger *logger)
{
   LLVMInitializeAllTargets();
   LLVMInitializeAllTargetInfos();
   LLVMInitializeAllTargetMCs();
   LLVMInitializeAllAsmPrinters();

   // LLVM diagnostics and both clang diagnostic printers below write into
   // this one string.
   std::string log;
   std::unique_ptr<LLVMContext> llvm_ctx { new LLVMContext };
   llvm_ctx->setDiagnosticHandlerCallBack(llvm_log_handler, &log);

   std::unique_ptr<clang::CompilerInstance> c { new clang::CompilerInstance };
   // Diagnostics engine used only while parsing the command-line options.
   clang::DiagnosticsEngine diag { new clang::DiagnosticIDs,
         new clang::DiagnosticOptions,
         new clang::TextDiagnosticPrinter(*new raw_string_ostream(log),
                                          &c->getDiagnosticOpts(), true)};

   std::vector<const char *> clang_opts = {
      args->source.name,
      "-triple", "spir64-unknown-unknown",
      // By default, clang prefers to use modules to pull in the default headers,
      // which doesn't work with our technique of embedding the headers in our binary
      "-finclude-default-header",
      // Add a default CL compiler version. Clang will pick the last one specified
      // on the command line, so the app can override this one.
      "-cl-std=cl1.2",
      // The LLVM-SPIRV-Translator doesn't support memset with variable size
      "-fno-builtin-memset",
      // LLVM's optimizations can produce code that the translator can't translate
      "-O0",
      // Ensure inline functions are actually emitted
      "-fgnu89-inline"
   };
   // We assume there's appropriate defines for __OPENCL_VERSION__ and __IMAGE_SUPPORT__
   // being provided by the caller here.
   // Caller-supplied arguments go after the defaults above.
   clang_opts.insert(clang_opts.end(), args->args, args->args + args->num_args);

   // CreateFromArgs takes the option list as a single range from LLVM 10 on
   // and as a begin/end pointer pair before that.
   if (!clang::CompilerInvocation::CreateFromArgs(c->getInvocation(),
#if LLVM_VERSION_MAJOR >= 10
                                                  clang_opts,
#else
                                                  clang_opts.data(),
                                                  clang_opts.data() + clang_opts.size(),
#endif
                                                  diag)) {
      clc_error(logger, "%sCouldn't create Clang invocation.\n", log.c_str());
      return {};
   }

   if (diag.hasErrorOccurred()) {
      clc_error(logger, "%sErrors occurred during Clang invocation.\n",
                log.c_str());
      return {};
   }

   // This is a workaround for a Clang bug which causes the number
   // of warnings and errors to be printed to stderr.
   // http://www.llvm.org/bugs/show_bug.cgi?id=19735
   c->getDiagnosticOpts().ShowCarets = false;

   // Diagnostics produced during the compilation itself also go into log.
   c->createDiagnostics(new clang::TextDiagnosticPrinter(
                           *new raw_string_ostream(log),
                           &c->getDiagnosticOpts(), true));

   c->setTarget(clang::TargetInfo::CreateTargetInfo(
                   c->getDiagnostics(), c->getInvocation().TargetOpts));

   // Stop at LLVM IR; no object code is generated.
   c->getFrontendOpts().ProgramAction = clang::frontend::EmitLLVMOnly;

#ifdef USE_STATIC_OPENCL_C_H
   // The OpenCL C headers are embedded in this binary: disable the regular
   // include directories and serve the headers from memory instead.
   c->getHeaderSearchOpts().UseBuiltinIncludes = false;
   c->getHeaderSearchOpts().UseStandardSystemIncludes = false;

   // Add opencl-c generic search path
   {
      ::llvm::SmallString<128> system_header_path;
      ::llvm::sys::path::system_temp_directory(true, system_header_path);
      ::llvm::sys::path::append(system_header_path, "openclon12");
      c->getHeaderSearchOpts().AddPath(system_header_path.str(),
                                       clang::frontend::Angled,
                                       false, false);

      // Remap <search path>/opencl-c.h to the embedded copy (the trailing
      // array element is excluded from the buffer).
      ::llvm::sys::path::append(system_header_path, "opencl-c.h");
      c->getPreprocessorOpts().addRemappedFile(system_header_path.str(),
         ::llvm::MemoryBuffer::getMemBuffer(llvm::StringRef(opencl_c_source, ARRAY_SIZE(opencl_c_source) - 1)).release());

      // Same for opencl-c-base.h, in the same directory.
      ::llvm::sys::path::remove_filename(system_header_path);
      ::llvm::sys::path::append(system_header_path, "opencl-c-base.h");
      c->getPreprocessorOpts().addRemappedFile(system_header_path.str(),
         ::llvm::MemoryBuffer::getMemBuffer(llvm::StringRef(opencl_c_base_source, ARRAY_SIZE(opencl_c_base_source) - 1)).release());
   }
#else
   // Use the headers shipped in clang's resource directory.
   c->getHeaderSearchOpts().UseBuiltinIncludes = true;
   c->getHeaderSearchOpts().UseStandardSystemIncludes = true;
   c->getHeaderSearchOpts().ResourceDir = CLANG_RESOURCE_DIR;

   // Add opencl-c generic search path
   c->getHeaderSearchOpts().AddPath(CLANG_RESOURCE_DIR,
                                    clang::frontend::Angled,
                                    false, false);
   // Add opencl include
   c->getPreprocessorOpts().Includes.push_back("opencl-c.h");
#endif

   // Caller-provided headers are exposed as remapped in-memory files below a
   // quoted-include search path, so no file is written for them here.
   if (args->num_headers) {
      ::llvm::SmallString<128> tmp_header_path;
      ::llvm::sys::path::system_temp_directory(true, tmp_header_path);
      ::llvm::sys::path::append(tmp_header_path, "openclon12");

      c->getHeaderSearchOpts().AddPath(tmp_header_path.str(),
                                       clang::frontend::Quoted,
                                       false, false);

      for (size_t i = 0; i < args->num_headers; i++) {
         auto path_copy = tmp_header_path;
         ::llvm::sys::path::append(path_copy, ::llvm::sys::path::convert_to_slash(args->headers[i].name));
         c->getPreprocessorOpts().addRemappedFile(path_copy.str(),
            ::llvm::MemoryBuffer::getMemBufferCopy(args->headers[i].value).release());
      }
   }

   // The main source is likewise served from memory under its given name.
   c->getPreprocessorOpts().addRemappedFile(
           args->source.name,
           ::llvm::MemoryBuffer::getMemBufferCopy(std::string(args->source.value)).release());

   // Compile the code
   clang::EmitLLVMOnlyAction act(llvm_ctx.get());
   if (!c->ExecuteAction(act)) {
      clc_error(logger, "%sError executing LLVM compilation action.\n",
                log.c_str());
      return {};
   }

   // Hand the module and the context it lives in to the caller together.
   return { act.takeModule(), std::move(llvm_ctx) };
}
860
861static SPIRV::VersionNumber
862spirv_version_to_llvm_spirv_translator_version(enum clc_spirv_version version)
863{
864   switch (version) {
865   case CLC_SPIRV_VERSION_MAX: return SPIRV::VersionNumber::MaximumVersion;
866   case CLC_SPIRV_VERSION_1_0: return SPIRV::VersionNumber::SPIRV_1_0;
867   case CLC_SPIRV_VERSION_1_1: return SPIRV::VersionNumber::SPIRV_1_1;
868   case CLC_SPIRV_VERSION_1_2: return SPIRV::VersionNumber::SPIRV_1_2;
869   case CLC_SPIRV_VERSION_1_3: return SPIRV::VersionNumber::SPIRV_1_3;
870#ifdef HAS_SPIRV_1_4
871   case CLC_SPIRV_VERSION_1_4: return SPIRV::VersionNumber::SPIRV_1_4;
872#endif
873   default:      return invalid_spirv_trans_version;
874   }
875}
876
877static int
878llvm_mod_to_spirv(std::unique_ptr<::llvm::Module> mod,
879                  std::unique_ptr<LLVMContext> context,
880                  const struct clc_compile_args *args,
881                  const struct clc_logger *logger,
882                  struct clc_binary *out_spirv)
883{
884   std::string log;
885
886   SPIRV::VersionNumber version =
887      spirv_version_to_llvm_spirv_translator_version(args->spirv_version);
888   if (version == invalid_spirv_trans_version) {
889      clc_error(logger, "Invalid/unsupported SPIRV specified.\n");
890      return -1;
891   }
892
893   const char *const *extensions = NULL;
894   if (args)
895      extensions = args->allowed_spirv_extensions;
896   if (!extensions) {
897      /* The SPIR-V parser doesn't handle all extensions */
898      static const char *default_extensions[] = {
899         "SPV_EXT_shader_atomic_float_add",
900         "SPV_EXT_shader_atomic_float_min_max",
901         "SPV_KHR_float_controls",
902         NULL,
903      };
904      extensions = default_extensions;
905   }
906
907   SPIRV::TranslatorOpts::ExtensionsStatusMap ext_map;
908   for (int i = 0; extensions[i]; i++) {
909#define EXT(X) \
910      if (strcmp(#X, extensions[i]) == 0) \
911         ext_map.insert(std::make_pair(SPIRV::ExtensionID::X, true));
912#include "LLVMSPIRVLib/LLVMSPIRVExtensions.inc"
913#undef EXT
914   }
915   SPIRV::TranslatorOpts spirv_opts = SPIRV::TranslatorOpts(version, ext_map);
916
917#if LLVM_VERSION_MAJOR >= 13
918   /* This was the default in 12.0 and older, but currently we'll fail to parse without this */
919   spirv_opts.setPreserveOCLKernelArgTypeMetadataThroughString(true);
920#endif
921
922   std::ostringstream spv_stream;
923   if (!::llvm::writeSpirv(mod.get(), spirv_opts, spv_stream, log)) {
924      clc_error(logger, "%sTranslation from LLVM IR to SPIR-V failed.\n",
925                log.c_str());
926      return -1;
927   }
928
929   const std::string spv_out = spv_stream.str();
930   out_spirv->size = spv_out.size();
931   out_spirv->data = malloc(out_spirv->size);
932   memcpy(out_spirv->data, spv_out.data(), out_spirv->size);
933
934   return 0;
935}
936
937int
938clc_c_to_spir(const struct clc_compile_args *args,
939              const struct clc_logger *logger,
940              struct clc_binary *out_spir)
941{
942   auto pair = clc_compile_to_llvm_module(args, logger);
943   if (!pair.first)
944      return -1;
945
946   ::llvm::SmallVector<char, 0> buffer;
947   ::llvm::BitcodeWriter writer(buffer);
948   writer.writeModule(*pair.first);
949
950   out_spir->size = buffer.size_in_bytes();
951   out_spir->data = malloc(out_spir->size);
952   memcpy(out_spir->data, buffer.data(), out_spir->size);
953
954   return 0;
955}
956
957int
958clc_c_to_spirv(const struct clc_compile_args *args,
959               const struct clc_logger *logger,
960               struct clc_binary *out_spirv)
961{
962   auto pair = clc_compile_to_llvm_module(args, logger);
963   if (!pair.first)
964      return -1;
965   return llvm_mod_to_spirv(std::move(pair.first), std::move(pair.second), args, logger, out_spirv);
966}
967
968int
969clc_spir_to_spirv(const struct clc_binary *in_spir,
970                  const struct clc_logger *logger,
971                  struct clc_binary *out_spirv)
972{
973   LLVMInitializeAllTargets();
974   LLVMInitializeAllTargetInfos();
975   LLVMInitializeAllTargetMCs();
976   LLVMInitializeAllAsmPrinters();
977
978   std::unique_ptr<LLVMContext> llvm_ctx{ new LLVMContext };
979   ::llvm::StringRef spir_ref(static_cast<const char*>(in_spir->data), in_spir->size);
980   auto mod = ::llvm::parseBitcodeFile(::llvm::MemoryBufferRef(spir_ref, "<spir>"), *llvm_ctx);
981   if (!mod)
982      return -1;
983
984   return llvm_mod_to_spirv(std::move(mod.get()), std::move(llvm_ctx), NULL, logger, out_spirv);
985}
986
987class SPIRVMessageConsumer {
988public:
989   SPIRVMessageConsumer(const struct clc_logger *logger): logger(logger) {}
990
991   void operator()(spv_message_level_t level, const char *src,
992                   const spv_position_t &pos, const char *msg)
993   {
994      switch(level) {
995      case SPV_MSG_FATAL:
996      case SPV_MSG_INTERNAL_ERROR:
997      case SPV_MSG_ERROR:
998         clc_error(logger, "(file=%s,line=%ld,column=%ld,index=%ld): %s\n",
999                   src, pos.line, pos.column, pos.index, msg);
1000         break;
1001
1002      case SPV_MSG_WARNING:
1003         clc_warning(logger, "(file=%s,line=%ld,column=%ld,index=%ld): %s\n",
1004                     src, pos.line, pos.column, pos.index, msg);
1005         break;
1006
1007      default:
1008         break;
1009      }
1010   }
1011
1012private:
1013   const struct clc_logger *logger;
1014};
1015
1016int
1017clc_link_spirv_binaries(const struct clc_linker_args *args,
1018                        const struct clc_logger *logger,
1019                        struct clc_binary *out_spirv)
1020{
1021   std::vector<std::vector<uint32_t>> binaries;
1022
1023   for (unsigned i = 0; i < args->num_in_objs; i++) {
1024      const uint32_t *data = static_cast<const uint32_t *>(args->in_objs[i]->data);
1025      std::vector<uint32_t> bin(data, data + (args->in_objs[i]->size / 4));
1026      binaries.push_back(bin);
1027   }
1028
1029   SPIRVMessageConsumer msgconsumer(logger);
1030   spvtools::Context context(spirv_target);
1031   context.SetMessageConsumer(msgconsumer);
1032   spvtools::LinkerOptions options;
1033   options.SetAllowPartialLinkage(args->create_library);
1034   options.SetCreateLibrary(args->create_library);
1035   std::vector<uint32_t> linkingResult;
1036   spv_result_t status = spvtools::Link(context, binaries, &linkingResult, options);
1037   if (status != SPV_SUCCESS) {
1038      return -1;
1039   }
1040
1041   out_spirv->size = linkingResult.size() * 4;
1042   out_spirv->data = static_cast<uint32_t *>(malloc(out_spirv->size));
1043   memcpy(out_spirv->data, linkingResult.data(), out_spirv->size);
1044
1045   return 0;
1046}
1047
1048int
1049clc_spirv_specialize(const struct clc_binary *in_spirv,
1050                     const struct clc_parsed_spirv *parsed_data,
1051                     const struct clc_spirv_specialization_consts *consts,
1052                     struct clc_binary *out_spirv)
1053{
1054   std::unordered_map<uint32_t, std::vector<uint32_t>> spec_const_map;
1055   for (unsigned i = 0; i < consts->num_specializations; ++i) {
1056      unsigned id = consts->specializations[i].id;
1057      auto parsed_spec_const = std::find_if(parsed_data->spec_constants,
1058         parsed_data->spec_constants + parsed_data->num_spec_constants,
1059         [id](const clc_parsed_spec_constant &c) { return c.id == id; });
1060      assert(parsed_spec_const != parsed_data->spec_constants + parsed_data->num_spec_constants);
1061
1062      std::vector<uint32_t> words;
1063      switch (parsed_spec_const->type) {
1064      case CLC_SPEC_CONSTANT_BOOL:
1065         words.push_back(consts->specializations[i].value.b);
1066         break;
1067      case CLC_SPEC_CONSTANT_INT32:
1068      case CLC_SPEC_CONSTANT_UINT32:
1069      case CLC_SPEC_CONSTANT_FLOAT:
1070         words.push_back(consts->specializations[i].value.u32);
1071         break;
1072      case CLC_SPEC_CONSTANT_INT16:
1073         words.push_back((uint32_t)(int32_t)consts->specializations[i].value.i16);
1074         break;
1075      case CLC_SPEC_CONSTANT_INT8:
1076         words.push_back((uint32_t)(int32_t)consts->specializations[i].value.i8);
1077         break;
1078      case CLC_SPEC_CONSTANT_UINT16:
1079         words.push_back((uint32_t)consts->specializations[i].value.u16);
1080         break;
1081      case CLC_SPEC_CONSTANT_UINT8:
1082         words.push_back((uint32_t)consts->specializations[i].value.u8);
1083         break;
1084      case CLC_SPEC_CONSTANT_DOUBLE:
1085      case CLC_SPEC_CONSTANT_INT64:
1086      case CLC_SPEC_CONSTANT_UINT64:
1087         words.resize(2);
1088         memcpy(words.data(), &consts->specializations[i].value.u64, 8);
1089         break;
1090      case CLC_SPEC_CONSTANT_UNKNOWN:
1091         assert(0);
1092         break;
1093      }
1094
1095      ASSERTED auto ret = spec_const_map.emplace(id, std::move(words));
1096      assert(ret.second);
1097   }
1098
1099   spvtools::Optimizer opt(spirv_target);
1100   opt.RegisterPass(spvtools::CreateSetSpecConstantDefaultValuePass(std::move(spec_const_map)));
1101
1102   std::vector<uint32_t> result;
1103   if (!opt.Run(static_cast<const uint32_t*>(in_spirv->data), in_spirv->size / 4, &result))
1104      return false;
1105
1106   out_spirv->size = result.size() * 4;
1107   out_spirv->data = malloc(out_spirv->size);
1108   memcpy(out_spirv->data, result.data(), out_spirv->size);
1109   return true;
1110}
1111
1112void
1113clc_dump_spirv(const struct clc_binary *spvbin, FILE *f)
1114{
1115   spvtools::SpirvTools tools(spirv_target);
1116   const uint32_t *data = static_cast<const uint32_t *>(spvbin->data);
1117   std::vector<uint32_t> bin(data, data + (spvbin->size / 4));
1118   std::string out;
1119   tools.Disassemble(bin, &out,
1120                     SPV_BINARY_TO_TEXT_OPTION_INDENT |
1121                     SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
1122   fwrite(out.c_str(), out.size(), 1, f);
1123}
1124
1125void
1126clc_free_spir_binary(struct clc_binary *spir)
1127{
1128   free(spir->data);
1129}
1130
1131void
1132clc_free_spirv_binary(struct clc_binary *spvbin)
1133{
1134   free(spvbin->data);
1135}
1136