HIP: Heterogeneous-computing Interface for Portability
Loading...
Searching...
No Matches
macro_based_grid_launch.hpp
1/*
2Copyright (c) 2015 - 2021 Advanced Micro Devices, Inc. All rights reserved.
3
4Permission is hereby granted, free of charge, to any person obtaining a copy
5of this software and associated documentation files (the "Software"), to deal
6in the Software without restriction, including without limitation the rights
7to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8copies of the Software, and to permit persons to whom the Software is
9furnished to do so, subject to the following conditions:
10
11The above copyright notice and this permission notice shall be included in
12all copies or substantial portions of the Software.
13
14THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20THE SOFTWARE.
21*/
22
23#pragma once
24
25#include "concepts.hpp"
26#include "helpers.hpp"
27
28#include "hc.hpp"
29#include "hip/hip_ext.h"
30#include "hip_runtime.h"
31
32#include <functional>
33#include <iostream>
34#include <stdexcept>
35#include <type_traits>
36#include <utility>
37
38namespace hip_impl {
39namespace {
40struct New_grid_launch_tag {};
41struct Old_grid_launch_tag {};
42
43template <typename C, typename D>
44class RAII_guard {
45 D dtor_;
46
47 public:
48 RAII_guard() = default;
49
50 RAII_guard(const C& ctor, D dtor) : dtor_{std::move(dtor)} { ctor(); }
51
52 RAII_guard(const RAII_guard&) = default;
53 RAII_guard(RAII_guard&&) = default;
54
55 RAII_guard& operator=(const RAII_guard&) = default;
56 RAII_guard& operator=(RAII_guard&&) = default;
57
58 ~RAII_guard() { dtor_(); }
59};
60
61template <typename C, typename D>
62RAII_guard<C, D> make_RAII_guard(const C& ctor, D dtor) {
63 return RAII_guard<C, D>{ctor, std::move(dtor)};
64}
65
66template <FunctionalProcedure F, typename... Ts>
67using is_new_grid_launch_t = typename std::conditional<is_callable<F(Ts...)>{}, New_grid_launch_tag,
68 Old_grid_launch_tag>::type;
69} // namespace
70
71// TODO: - dispatch rank should be derived from the domain dimensions passed
72// in, and not always assumed to be 3;
73
74template <FunctionalProcedure K, typename... Ts>
75requires(Domain<K> ==
76 {Ts...}) inline void grid_launch_hip_impl_(New_grid_launch_tag, dim3 num_blocks,
77 dim3 dim_blocks, int group_mem_bytes,
78 const hc::accelerator_view& acc_v, K k) {
79 const auto d =
80 hc::extent<3>{num_blocks.z * dim_blocks.z, num_blocks.y * dim_blocks.y,
81 num_blocks.x * dim_blocks.x}
82 .tile_with_dynamic(dim_blocks.z, dim_blocks.y, dim_blocks.x, group_mem_bytes);
83
84 try {
85 hc::parallel_for_each(acc_v, d, k);
86 } catch (std::exception& ex) {
87 std::cerr << "Failed in " << __func__ << ", with exception: " << ex.what() << std::endl;
88 hip_throw(ex);
89 }
90}
91
92// TODO: these are workarounds, they should be removed.
93
94hc::accelerator_view lock_stream_hip_(hipStream_t&, void*&);
95void print_prelaunch_trace_(const char*, dim3, dim3, int, hipStream_t);
96void unlock_stream_hip_(hipStream_t, void*, const char*, hc::accelerator_view*);
97
98template <FunctionalProcedure K, typename... Ts>
99requires(Domain<K> == {Ts...}) inline void grid_launch_hip_impl_(New_grid_launch_tag,
100 dim3 num_blocks, dim3 dim_blocks,
101 int group_mem_bytes,
102 hipStream_t stream,
103 const char* kernel_name, K k) {
104 void* lck_stream = nullptr;
105 auto acc_v = lock_stream_hip_(stream, lck_stream);
106 auto stream_guard =
107 make_RAII_guard(std::bind(print_prelaunch_trace_, kernel_name, num_blocks, dim_blocks,
108 group_mem_bytes, stream),
109 std::bind(unlock_stream_hip_, stream, lck_stream, kernel_name, &acc_v));
110
111 try {
112 grid_launch_hip_impl_(New_grid_launch_tag{}, std::move(num_blocks), std::move(dim_blocks),
113 group_mem_bytes, acc_v, std::move(k));
114 } catch (std::exception& ex) {
115 std::cerr << "Failed in " << __func__ << ", with exception: " << ex.what() << std::endl;
116 hip_throw(ex);
117 }
118}
119
120template <FunctionalProcedure K, typename... Ts>
121requires(Domain<K> ==
122 {hipLaunchParm, Ts...}) inline void grid_launch_hip_impl_(Old_grid_launch_tag,
123 dim3 num_blocks, dim3 dim_blocks,
124 int group_mem_bytes,
125 hipStream_t stream, K k) {
126 grid_launch_hip_impl_(New_grid_launch_tag{}, std::move(num_blocks), std::move(dim_blocks),
127 group_mem_bytes, std::move(stream), std::move(k));
128}
129
130template <FunctionalProcedure K, typename... Ts>
131requires(Domain<K> == {hipLaunchParm, Ts...}) inline void grid_launch_hip_impl_(
132 Old_grid_launch_tag, dim3 num_blocks, dim3 dim_blocks, int group_mem_bytes, hipStream_t stream,
133 const char* kernel_name, K k) {
134 grid_launch_hip_impl_(New_grid_launch_tag{}, std::move(num_blocks), std::move(dim_blocks),
135 group_mem_bytes, std::move(stream), kernel_name, std::move(k));
136}
137
138template <FunctionalProcedure K, typename... Ts>
139requires(Domain<K> == {Ts...}) inline std::enable_if_t<
140 !std::is_function<K>::value> grid_launch_hip_(dim3 num_blocks, dim3 dim_blocks,
141 int group_mem_bytes, hipStream_t stream,
142 const char* kernel_name, K k) {
143 grid_launch_hip_impl_(is_new_grid_launch_t<K, Ts...>{}, std::move(num_blocks),
144 std::move(dim_blocks), group_mem_bytes, std::move(stream), kernel_name,
145 std::move(k));
146}
147
148template <FunctionalProcedure K, typename... Ts>
149requires(Domain<K> == {Ts...}) inline std::enable_if_t<
150 !std::is_function<K>::value> grid_launch_hip_(dim3 num_blocks, dim3 dim_blocks,
151 int group_mem_bytes, hipStream_t stream, K k) {
152 grid_launch_hip_impl_(is_new_grid_launch_t<K, Ts...>{}, std::move(num_blocks),
153 std::move(dim_blocks), group_mem_bytes, std::move(stream), std::move(k));
154}
155
156// TODO: these are temporary and purposefully noisy and disruptive.
157#define make_kernel_name_hip(k, n) \
158 HIP_kernel_functor_name_begin##_##k##_##HIP_kernel_functor_name_end##_##n
159
160#define make_kernel_functor_hip_30(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, \
161 p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, \
162 p22, p23, p24, p25, p26, p27) \
163 struct make_kernel_name_hip(function_name, 28) { \
164 std::decay_t<decltype(p0)> _p0_; \
165 std::decay_t<decltype(p1)> _p1_; \
166 std::decay_t<decltype(p2)> _p2_; \
167 std::decay_t<decltype(p3)> _p3_; \
168 std::decay_t<decltype(p4)> _p4_; \
169 std::decay_t<decltype(p5)> _p5_; \
170 std::decay_t<decltype(p6)> _p6_; \
171 std::decay_t<decltype(p7)> _p7_; \
172 std::decay_t<decltype(p8)> _p8_; \
173 std::decay_t<decltype(p9)> _p9_; \
174 std::decay_t<decltype(p10)> _p10_; \
175 std::decay_t<decltype(p11)> _p11_; \
176 std::decay_t<decltype(p12)> _p12_; \
177 std::decay_t<decltype(p13)> _p13_; \
178 std::decay_t<decltype(p14)> _p14_; \
179 std::decay_t<decltype(p15)> _p15_; \
180 std::decay_t<decltype(p16)> _p16_; \
181 std::decay_t<decltype(p17)> _p17_; \
182 std::decay_t<decltype(p18)> _p18_; \
183 std::decay_t<decltype(p19)> _p19_; \
184 std::decay_t<decltype(p20)> _p20_; \
185 std::decay_t<decltype(p21)> _p21_; \
186 std::decay_t<decltype(p22)> _p22_; \
187 std::decay_t<decltype(p23)> _p23_; \
188 std::decay_t<decltype(p24)> _p24_; \
189 std::decay_t<decltype(p25)> _p25_; \
190 std::decay_t<decltype(p26)> _p26_; \
191 std::decay_t<decltype(p27)> _p27_; \
192 void operator()(const hc::tiled_index<3>&) const [[hc]] { \
193 kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, _p10_, _p11_, \
194 _p12_, _p13_, _p14_, _p15_, _p16_, _p17_, _p18_, _p19_, _p20_, _p21_, \
195 _p22_, _p23_, _p24_, _p25_, _p26_, _p27_); \
196 } \
197 }
198#define make_kernel_functor_hip_29(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, \
199 p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, \
200 p22, p23, p24, p25, p26) \
201 struct make_kernel_name_hip(function_name, 27) { \
202 std::decay_t<decltype(p0)> _p0_; \
203 std::decay_t<decltype(p1)> _p1_; \
204 std::decay_t<decltype(p2)> _p2_; \
205 std::decay_t<decltype(p3)> _p3_; \
206 std::decay_t<decltype(p4)> _p4_; \
207 std::decay_t<decltype(p5)> _p5_; \
208 std::decay_t<decltype(p6)> _p6_; \
209 std::decay_t<decltype(p7)> _p7_; \
210 std::decay_t<decltype(p8)> _p8_; \
211 std::decay_t<decltype(p9)> _p9_; \
212 std::decay_t<decltype(p10)> _p10_; \
213 std::decay_t<decltype(p11)> _p11_; \
214 std::decay_t<decltype(p12)> _p12_; \
215 std::decay_t<decltype(p13)> _p13_; \
216 std::decay_t<decltype(p14)> _p14_; \
217 std::decay_t<decltype(p15)> _p15_; \
218 std::decay_t<decltype(p16)> _p16_; \
219 std::decay_t<decltype(p17)> _p17_; \
220 std::decay_t<decltype(p18)> _p18_; \
221 std::decay_t<decltype(p19)> _p19_; \
222 std::decay_t<decltype(p20)> _p20_; \
223 std::decay_t<decltype(p21)> _p21_; \
224 std::decay_t<decltype(p22)> _p22_; \
225 std::decay_t<decltype(p23)> _p23_; \
226 std::decay_t<decltype(p24)> _p24_; \
227 std::decay_t<decltype(p25)> _p25_; \
228 std::decay_t<decltype(p26)> _p26_; \
229 void operator()(const hc::tiled_index<3>&) const [[hc]] { \
230 kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, _p10_, _p11_, \
231 _p12_, _p13_, _p14_, _p15_, _p16_, _p17_, _p18_, _p19_, _p20_, _p21_, \
232 _p22_, _p23_, _p24_, _p25_, _p26_); \
233 } \
234 }
235#define make_kernel_functor_hip_28(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, \
236 p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, \
237 p22, p23, p24, p25) \
238 struct make_kernel_name_hip(function_name, 26) { \
239 std::decay_t<decltype(p0)> _p0_; \
240 std::decay_t<decltype(p1)> _p1_; \
241 std::decay_t<decltype(p2)> _p2_; \
242 std::decay_t<decltype(p3)> _p3_; \
243 std::decay_t<decltype(p4)> _p4_; \
244 std::decay_t<decltype(p5)> _p5_; \
245 std::decay_t<decltype(p6)> _p6_; \
246 std::decay_t<decltype(p7)> _p7_; \
247 std::decay_t<decltype(p8)> _p8_; \
248 std::decay_t<decltype(p9)> _p9_; \
249 std::decay_t<decltype(p10)> _p10_; \
250 std::decay_t<decltype(p11)> _p11_; \
251 std::decay_t<decltype(p12)> _p12_; \
252 std::decay_t<decltype(p13)> _p13_; \
253 std::decay_t<decltype(p14)> _p14_; \
254 std::decay_t<decltype(p15)> _p15_; \
255 std::decay_t<decltype(p16)> _p16_; \
256 std::decay_t<decltype(p17)> _p17_; \
257 std::decay_t<decltype(p18)> _p18_; \
258 std::decay_t<decltype(p19)> _p19_; \
259 std::decay_t<decltype(p20)> _p20_; \
260 std::decay_t<decltype(p21)> _p21_; \
261 std::decay_t<decltype(p22)> _p22_; \
262 std::decay_t<decltype(p23)> _p23_; \
263 std::decay_t<decltype(p24)> _p24_; \
264 std::decay_t<decltype(p25)> _p25_; \
265 void operator()(const hc::tiled_index<3>&) const [[hc]] { \
266 kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, _p10_, _p11_, \
267 _p12_, _p13_, _p14_, _p15_, _p16_, _p17_, _p18_, _p19_, _p20_, _p21_, \
268 _p22_, _p23_, _p24_, _p25_); \
269 } \
270 }
271#define make_kernel_functor_hip_27(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, \
272 p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, \
273 p22, p23, p24) \
274 struct make_kernel_name_hip(function_name, 25) { \
275 std::decay_t<decltype(p0)> _p0_; \
276 std::decay_t<decltype(p1)> _p1_; \
277 std::decay_t<decltype(p2)> _p2_; \
278 std::decay_t<decltype(p3)> _p3_; \
279 std::decay_t<decltype(p4)> _p4_; \
280 std::decay_t<decltype(p5)> _p5_; \
281 std::decay_t<decltype(p6)> _p6_; \
282 std::decay_t<decltype(p7)> _p7_; \
283 std::decay_t<decltype(p8)> _p8_; \
284 std::decay_t<decltype(p9)> _p9_; \
285 std::decay_t<decltype(p10)> _p10_; \
286 std::decay_t<decltype(p11)> _p11_; \
287 std::decay_t<decltype(p12)> _p12_; \
288 std::decay_t<decltype(p13)> _p13_; \
289 std::decay_t<decltype(p14)> _p14_; \
290 std::decay_t<decltype(p15)> _p15_; \
291 std::decay_t<decltype(p16)> _p16_; \
292 std::decay_t<decltype(p17)> _p17_; \
293 std::decay_t<decltype(p18)> _p18_; \
294 std::decay_t<decltype(p19)> _p19_; \
295 std::decay_t<decltype(p20)> _p20_; \
296 std::decay_t<decltype(p21)> _p21_; \
297 std::decay_t<decltype(p22)> _p22_; \
298 std::decay_t<decltype(p23)> _p23_; \
299 std::decay_t<decltype(p24)> _p24_; \
300 void operator()(const hc::tiled_index<3>&) const [[hc]] { \
301 kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, _p10_, _p11_, \
302 _p12_, _p13_, _p14_, _p15_, _p16_, _p17_, _p18_, _p19_, _p20_, _p21_, \
303 _p22_, _p23_, _p24_); \
304 } \
305 }
306#define make_kernel_functor_hip_26(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, \
307 p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, \
308 p22, p23) \
309 struct make_kernel_name_hip(function_name, 24) { \
310 std::decay_t<decltype(p0)> _p0_; \
311 std::decay_t<decltype(p1)> _p1_; \
312 std::decay_t<decltype(p2)> _p2_; \
313 std::decay_t<decltype(p3)> _p3_; \
314 std::decay_t<decltype(p4)> _p4_; \
315 std::decay_t<decltype(p5)> _p5_; \
316 std::decay_t<decltype(p6)> _p6_; \
317 std::decay_t<decltype(p7)> _p7_; \
318 std::decay_t<decltype(p8)> _p8_; \
319 std::decay_t<decltype(p9)> _p9_; \
320 std::decay_t<decltype(p10)> _p10_; \
321 std::decay_t<decltype(p11)> _p11_; \
322 std::decay_t<decltype(p12)> _p12_; \
323 std::decay_t<decltype(p13)> _p13_; \
324 std::decay_t<decltype(p14)> _p14_; \
325 std::decay_t<decltype(p15)> _p15_; \
326 std::decay_t<decltype(p16)> _p16_; \
327 std::decay_t<decltype(p17)> _p17_; \
328 std::decay_t<decltype(p18)> _p18_; \
329 std::decay_t<decltype(p19)> _p19_; \
330 std::decay_t<decltype(p20)> _p20_; \
331 std::decay_t<decltype(p21)> _p21_; \
332 std::decay_t<decltype(p22)> _p22_; \
333 std::decay_t<decltype(p23)> _p23_; \
334 void operator()(const hc::tiled_index<3>&) const [[hc]] { \
335 kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, _p10_, _p11_, \
336 _p12_, _p13_, _p14_, _p15_, _p16_, _p17_, _p18_, _p19_, _p20_, _p21_, \
337 _p22_, _p23_); \
338 } \
339 }
340#define make_kernel_functor_hip_25(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, \
341 p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, \
342 p22) \
343 struct make_kernel_name_hip(function_name, 23) { \
344 std::decay_t<decltype(p0)> _p0_; \
345 std::decay_t<decltype(p1)> _p1_; \
346 std::decay_t<decltype(p2)> _p2_; \
347 std::decay_t<decltype(p3)> _p3_; \
348 std::decay_t<decltype(p4)> _p4_; \
349 std::decay_t<decltype(p5)> _p5_; \
350 std::decay_t<decltype(p6)> _p6_; \
351 std::decay_t<decltype(p7)> _p7_; \
352 std::decay_t<decltype(p8)> _p8_; \
353 std::decay_t<decltype(p9)> _p9_; \
354 std::decay_t<decltype(p10)> _p10_; \
355 std::decay_t<decltype(p11)> _p11_; \
356 std::decay_t<decltype(p12)> _p12_; \
357 std::decay_t<decltype(p13)> _p13_; \
358 std::decay_t<decltype(p14)> _p14_; \
359 std::decay_t<decltype(p15)> _p15_; \
360 std::decay_t<decltype(p16)> _p16_; \
361 std::decay_t<decltype(p17)> _p17_; \
362 std::decay_t<decltype(p18)> _p18_; \
363 std::decay_t<decltype(p19)> _p19_; \
364 std::decay_t<decltype(p20)> _p20_; \
365 std::decay_t<decltype(p21)> _p21_; \
366 std::decay_t<decltype(p22)> _p22_; \
367 __attribute__((used, flatten)) void operator()(const hc::tiled_index<3>&) const [[hc]] { \
368 kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, _p10_, _p11_, \
369 _p12_, _p13_, _p14_, _p15_, _p16_, _p17_, _p18_, _p19_, _p20_, _p21_, \
370 _p22_); \
371 } \
372 }
373#define make_kernel_functor_hip_24(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, \
374 p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21) \
375 struct make_kernel_name_hip(function_name, 22) { \
376 std::decay_t<decltype(p0)> _p0_; \
377 std::decay_t<decltype(p1)> _p1_; \
378 std::decay_t<decltype(p2)> _p2_; \
379 std::decay_t<decltype(p3)> _p3_; \
380 std::decay_t<decltype(p4)> _p4_; \
381 std::decay_t<decltype(p5)> _p5_; \
382 std::decay_t<decltype(p6)> _p6_; \
383 std::decay_t<decltype(p7)> _p7_; \
384 std::decay_t<decltype(p8)> _p8_; \
385 std::decay_t<decltype(p9)> _p9_; \
386 std::decay_t<decltype(p10)> _p10_; \
387 std::decay_t<decltype(p11)> _p11_; \
388 std::decay_t<decltype(p12)> _p12_; \
389 std::decay_t<decltype(p13)> _p13_; \
390 std::decay_t<decltype(p14)> _p14_; \
391 std::decay_t<decltype(p15)> _p15_; \
392 std::decay_t<decltype(p16)> _p16_; \
393 std::decay_t<decltype(p17)> _p17_; \
394 std::decay_t<decltype(p18)> _p18_; \
395 std::decay_t<decltype(p19)> _p19_; \
396 std::decay_t<decltype(p20)> _p20_; \
397 std::decay_t<decltype(p21)> _p21_; \
398 void operator()(const hc::tiled_index<3>&) const [[hc]] { \
399 kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, _p10_, _p11_, \
400 _p12_, _p13_, _p14_, _p15_, _p16_, _p17_, _p18_, _p19_, _p20_, _p21_); \
401 } \
402 }
403#define make_kernel_functor_hip_23(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, \
404 p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20) \
405 struct make_kernel_name_hip(function_name, 21) { \
406 std::decay_t<decltype(p0)> _p0_; \
407 std::decay_t<decltype(p1)> _p1_; \
408 std::decay_t<decltype(p2)> _p2_; \
409 std::decay_t<decltype(p3)> _p3_; \
410 std::decay_t<decltype(p4)> _p4_; \
411 std::decay_t<decltype(p5)> _p5_; \
412 std::decay_t<decltype(p6)> _p6_; \
413 std::decay_t<decltype(p7)> _p7_; \
414 std::decay_t<decltype(p8)> _p8_; \
415 std::decay_t<decltype(p9)> _p9_; \
416 std::decay_t<decltype(p10)> _p10_; \
417 std::decay_t<decltype(p11)> _p11_; \
418 std::decay_t<decltype(p12)> _p12_; \
419 std::decay_t<decltype(p13)> _p13_; \
420 std::decay_t<decltype(p14)> _p14_; \
421 std::decay_t<decltype(p15)> _p15_; \
422 std::decay_t<decltype(p16)> _p16_; \
423 std::decay_t<decltype(p17)> _p17_; \
424 std::decay_t<decltype(p18)> _p18_; \
425 std::decay_t<decltype(p19)> _p19_; \
426 std::decay_t<decltype(p20)> _p20_; \
427 void operator()(const hc::tiled_index<3>&) const [[hc]] { \
428 kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, _p10_, _p11_, \
429 _p12_, _p13_, _p14_, _p15_, _p16_, _p17_, _p18_, _p19_, _p20_); \
430 } \
431 }
432#define make_kernel_functor_hip_22(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, \
433 p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19) \
434 struct make_kernel_name_hip(function_name, 20) { \
435 std::decay_t<decltype(p0)> _p0_; \
436 std::decay_t<decltype(p1)> _p1_; \
437 std::decay_t<decltype(p2)> _p2_; \
438 std::decay_t<decltype(p3)> _p3_; \
439 std::decay_t<decltype(p4)> _p4_; \
440 std::decay_t<decltype(p5)> _p5_; \
441 std::decay_t<decltype(p6)> _p6_; \
442 std::decay_t<decltype(p7)> _p7_; \
443 std::decay_t<decltype(p8)> _p8_; \
444 std::decay_t<decltype(p9)> _p9_; \
445 std::decay_t<decltype(p10)> _p10_; \
446 std::decay_t<decltype(p11)> _p11_; \
447 std::decay_t<decltype(p12)> _p12_; \
448 std::decay_t<decltype(p13)> _p13_; \
449 std::decay_t<decltype(p14)> _p14_; \
450 std::decay_t<decltype(p15)> _p15_; \
451 std::decay_t<decltype(p16)> _p16_; \
452 std::decay_t<decltype(p17)> _p17_; \
453 std::decay_t<decltype(p18)> _p18_; \
454 std::decay_t<decltype(p19)> _p19_; \
455 void operator()(const hc::tiled_index<3>&) const [[hc]] { \
456 kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, _p10_, _p11_, \
457 _p12_, _p13_, _p14_, _p15_, _p16_, _p17_, _p18_, _p19_); \
458 } \
459 }
460#define make_kernel_functor_hip_21(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, \
461 p9, p10, p11, p12, p13, p14, p15, p16, p17, p18) \
462 struct make_kernel_name_hip(function_name, 19) { \
463 std::decay_t<decltype(p0)> _p0_; \
464 std::decay_t<decltype(p1)> _p1_; \
465 std::decay_t<decltype(p2)> _p2_; \
466 std::decay_t<decltype(p3)> _p3_; \
467 std::decay_t<decltype(p4)> _p4_; \
468 std::decay_t<decltype(p5)> _p5_; \
469 std::decay_t<decltype(p6)> _p6_; \
470 std::decay_t<decltype(p7)> _p7_; \
471 std::decay_t<decltype(p8)> _p8_; \
472 std::decay_t<decltype(p9)> _p9_; \
473 std::decay_t<decltype(p10)> _p10_; \
474 std::decay_t<decltype(p11)> _p11_; \
475 std::decay_t<decltype(p12)> _p12_; \
476 std::decay_t<decltype(p13)> _p13_; \
477 std::decay_t<decltype(p14)> _p14_; \
478 std::decay_t<decltype(p15)> _p15_; \
479 std::decay_t<decltype(p16)> _p16_; \
480 std::decay_t<decltype(p17)> _p17_; \
481 std::decay_t<decltype(p18)> _p18_; \
482 void operator()(const hc::tiled_index<3>&) const [[hc]] { \
483 kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, _p10_, _p11_, \
484 _p12_, _p13_, _p14_, _p15_, _p16_, _p17_, _p18_); \
485 } \
486 }
487#define make_kernel_functor_hip_20(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, \
488 p9, p10, p11, p12, p13, p14, p15, p16, p17) \
489 struct make_kernel_name_hip(function_name, 18) { \
490 std::decay_t<decltype(p0)> _p0_; \
491 std::decay_t<decltype(p1)> _p1_; \
492 std::decay_t<decltype(p2)> _p2_; \
493 std::decay_t<decltype(p3)> _p3_; \
494 std::decay_t<decltype(p4)> _p4_; \
495 std::decay_t<decltype(p5)> _p5_; \
496 std::decay_t<decltype(p6)> _p6_; \
497 std::decay_t<decltype(p7)> _p7_; \
498 std::decay_t<decltype(p8)> _p8_; \
499 std::decay_t<decltype(p9)> _p9_; \
500 std::decay_t<decltype(p10)> _p10_; \
501 std::decay_t<decltype(p11)> _p11_; \
502 std::decay_t<decltype(p12)> _p12_; \
503 std::decay_t<decltype(p13)> _p13_; \
504 std::decay_t<decltype(p14)> _p14_; \
505 std::decay_t<decltype(p15)> _p15_; \
506 std::decay_t<decltype(p16)> _p16_; \
507 std::decay_t<decltype(p17)> _p17_; \
508 void operator()(const hc::tiled_index<3>&) const [[hc]] { \
509 kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, _p10_, _p11_, \
510 _p12_, _p13_, _p14_, _p15_, _p16_, _p17_); \
511 } \
512 }
513#define make_kernel_functor_hip_19(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, \
514 p9, p10, p11, p12, p13, p14, p15, p16) \
515 struct make_kernel_name_hip(function_name, 17) { \
516 std::decay_t<decltype(p0)> _p0_; \
517 std::decay_t<decltype(p1)> _p1_; \
518 std::decay_t<decltype(p2)> _p2_; \
519 std::decay_t<decltype(p3)> _p3_; \
520 std::decay_t<decltype(p4)> _p4_; \
521 std::decay_t<decltype(p5)> _p5_; \
522 std::decay_t<decltype(p6)> _p6_; \
523 std::decay_t<decltype(p7)> _p7_; \
524 std::decay_t<decltype(p8)> _p8_; \
525 std::decay_t<decltype(p9)> _p9_; \
526 std::decay_t<decltype(p10)> _p10_; \
527 std::decay_t<decltype(p11)> _p11_; \
528 std::decay_t<decltype(p12)> _p12_; \
529 std::decay_t<decltype(p13)> _p13_; \
530 std::decay_t<decltype(p14)> _p14_; \
531 std::decay_t<decltype(p15)> _p15_; \
532 std::decay_t<decltype(p16)> _p16_; \
533 void operator()(const hc::tiled_index<3>&) const [[hc]] { \
534 kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, _p10_, _p11_, \
535 _p12_, _p13_, _p14_, _p15_, _p16_); \
536 } \
537 }
538#define make_kernel_functor_hip_18(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, \
539 p9, p10, p11, p12, p13, p14, p15) \
540 struct make_kernel_name_hip(function_name, 16) { \
541 std::decay_t<decltype(p0)> _p0_; \
542 std::decay_t<decltype(p1)> _p1_; \
543 std::decay_t<decltype(p2)> _p2_; \
544 std::decay_t<decltype(p3)> _p3_; \
545 std::decay_t<decltype(p4)> _p4_; \
546 std::decay_t<decltype(p5)> _p5_; \
547 std::decay_t<decltype(p6)> _p6_; \
548 std::decay_t<decltype(p7)> _p7_; \
549 std::decay_t<decltype(p8)> _p8_; \
550 std::decay_t<decltype(p9)> _p9_; \
551 std::decay_t<decltype(p10)> _p10_; \
552 std::decay_t<decltype(p11)> _p11_; \
553 std::decay_t<decltype(p12)> _p12_; \
554 std::decay_t<decltype(p13)> _p13_; \
555 std::decay_t<decltype(p14)> _p14_; \
556 std::decay_t<decltype(p15)> _p15_; \
557 void operator()(const hc::tiled_index<3>&) const [[hc]] { \
558 kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, _p10_, _p11_, \
559 _p12_, _p13_, _p14_, _p15_); \
560 } \
561 }
562#define make_kernel_functor_hip_17(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, \
563 p9, p10, p11, p12, p13, p14) \
564 struct make_kernel_name_hip(function_name, 15) { \
565 std::decay_t<decltype(p0)> _p0_; \
566 std::decay_t<decltype(p1)> _p1_; \
567 std::decay_t<decltype(p2)> _p2_; \
568 std::decay_t<decltype(p3)> _p3_; \
569 std::decay_t<decltype(p4)> _p4_; \
570 std::decay_t<decltype(p5)> _p5_; \
571 std::decay_t<decltype(p6)> _p6_; \
572 std::decay_t<decltype(p7)> _p7_; \
573 std::decay_t<decltype(p8)> _p8_; \
574 std::decay_t<decltype(p9)> _p9_; \
575 std::decay_t<decltype(p10)> _p10_; \
576 std::decay_t<decltype(p11)> _p11_; \
577 std::decay_t<decltype(p12)> _p12_; \
578 std::decay_t<decltype(p13)> _p13_; \
579 std::decay_t<decltype(p14)> _p14_; \
580 void operator()(const hc::tiled_index<3>&) const [[hc]] { \
581 kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, _p10_, _p11_, \
582 _p12_, _p13_, _p14_); \
583 } \
584 }
585#define make_kernel_functor_hip_16(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, \
586 p9, p10, p11, p12, p13) \
587 struct make_kernel_name_hip(function_name, 14) { \
588 std::decay_t<decltype(p0)> _p0_; \
589 std::decay_t<decltype(p1)> _p1_; \
590 std::decay_t<decltype(p2)> _p2_; \
591 std::decay_t<decltype(p3)> _p3_; \
592 std::decay_t<decltype(p4)> _p4_; \
593 std::decay_t<decltype(p5)> _p5_; \
594 std::decay_t<decltype(p6)> _p6_; \
595 std::decay_t<decltype(p7)> _p7_; \
596 std::decay_t<decltype(p8)> _p8_; \
597 std::decay_t<decltype(p9)> _p9_; \
598 std::decay_t<decltype(p10)> _p10_; \
599 std::decay_t<decltype(p11)> _p11_; \
600 std::decay_t<decltype(p12)> _p12_; \
601 std::decay_t<decltype(p13)> _p13_; \
602 void operator()(const hc::tiled_index<3>&) const [[hc]] { \
603 kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, _p10_, _p11_, \
604 _p12_, _p13_); \
605 } \
606 }
607#define make_kernel_functor_hip_15(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, \
608 p9, p10, p11, p12) \
609 struct make_kernel_name_hip(function_name, 13) { \
610 std::decay_t<decltype(p0)> _p0_; \
611 std::decay_t<decltype(p1)> _p1_; \
612 std::decay_t<decltype(p2)> _p2_; \
613 std::decay_t<decltype(p3)> _p3_; \
614 std::decay_t<decltype(p4)> _p4_; \
615 std::decay_t<decltype(p5)> _p5_; \
616 std::decay_t<decltype(p6)> _p6_; \
617 std::decay_t<decltype(p7)> _p7_; \
618 std::decay_t<decltype(p8)> _p8_; \
619 std::decay_t<decltype(p9)> _p9_; \
620 std::decay_t<decltype(p10)> _p10_; \
621 std::decay_t<decltype(p11)> _p11_; \
622 std::decay_t<decltype(p12)> _p12_; \
623 void operator()(const hc::tiled_index<3>&) const [[hc]] { \
624 kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, _p10_, _p11_, \
625 _p12_); \
626 } \
627 }
628#define make_kernel_functor_hip_14(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, \
629 p9, p10, p11) \
630 struct make_kernel_name_hip(function_name, 12) { \
631 std::decay_t<decltype(p0)> _p0_; \
632 std::decay_t<decltype(p1)> _p1_; \
633 std::decay_t<decltype(p2)> _p2_; \
634 std::decay_t<decltype(p3)> _p3_; \
635 std::decay_t<decltype(p4)> _p4_; \
636 std::decay_t<decltype(p5)> _p5_; \
637 std::decay_t<decltype(p6)> _p6_; \
638 std::decay_t<decltype(p7)> _p7_; \
639 std::decay_t<decltype(p8)> _p8_; \
640 std::decay_t<decltype(p9)> _p9_; \
641 std::decay_t<decltype(p10)> _p10_; \
642 std::decay_t<decltype(p11)> _p11_; \
643 void operator()(const hc::tiled_index<3>&) const [[hc]] { \
644 kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, _p10_, _p11_); \
645 } \
646 }
647#define make_kernel_functor_hip_13(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, \
648 p9, p10) \
649 struct make_kernel_name_hip(function_name, 11) { \
650 std::decay_t<decltype(p0)> _p0_; \
651 std::decay_t<decltype(p1)> _p1_; \
652 std::decay_t<decltype(p2)> _p2_; \
653 std::decay_t<decltype(p3)> _p3_; \
654 std::decay_t<decltype(p4)> _p4_; \
655 std::decay_t<decltype(p5)> _p5_; \
656 std::decay_t<decltype(p6)> _p6_; \
657 std::decay_t<decltype(p7)> _p7_; \
658 std::decay_t<decltype(p8)> _p8_; \
659 std::decay_t<decltype(p9)> _p9_; \
660 std::decay_t<decltype(p10)> _p10_; \
661 void operator()(const hc::tiled_index<3>&) const [[hc]] { \
662 kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_, _p10_); \
663 } \
664 }
665#define make_kernel_functor_hip_12(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8, \
666 p9) \
667 struct make_kernel_name_hip(function_name, 10) { \
668 std::decay_t<decltype(p0)> _p0_; \
669 std::decay_t<decltype(p1)> _p1_; \
670 std::decay_t<decltype(p2)> _p2_; \
671 std::decay_t<decltype(p3)> _p3_; \
672 std::decay_t<decltype(p4)> _p4_; \
673 std::decay_t<decltype(p5)> _p5_; \
674 std::decay_t<decltype(p6)> _p6_; \
675 std::decay_t<decltype(p7)> _p7_; \
676 std::decay_t<decltype(p8)> _p8_; \
677 std::decay_t<decltype(p9)> _p9_; \
678 void operator()(const hc::tiled_index<3>&) const \
679 [[hc]] { kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_, _p9_); } \
680 }
681#define make_kernel_functor_hip_11(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7, p8) \
682 struct make_kernel_name_hip(function_name, 9) { \
683 std::decay_t<decltype(p0)> _p0_; \
684 std::decay_t<decltype(p1)> _p1_; \
685 std::decay_t<decltype(p2)> _p2_; \
686 std::decay_t<decltype(p3)> _p3_; \
687 std::decay_t<decltype(p4)> _p4_; \
688 std::decay_t<decltype(p5)> _p5_; \
689 std::decay_t<decltype(p6)> _p6_; \
690 std::decay_t<decltype(p7)> _p7_; \
691 std::decay_t<decltype(p8)> _p8_; \
692 void operator()(const hc::tiled_index<3>&) const \
693 [[hc]] { kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_, _p8_); } \
694 }
695#define make_kernel_functor_hip_10(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6, p7) \
696 struct make_kernel_name_hip(function_name, 8) { \
697 std::decay_t<decltype(p0)> _p0_; \
698 std::decay_t<decltype(p1)> _p1_; \
699 std::decay_t<decltype(p2)> _p2_; \
700 std::decay_t<decltype(p3)> _p3_; \
701 std::decay_t<decltype(p4)> _p4_; \
702 std::decay_t<decltype(p5)> _p5_; \
703 std::decay_t<decltype(p6)> _p6_; \
704 std::decay_t<decltype(p7)> _p7_; \
705 void operator()(const hc::tiled_index<3>&) const \
706 [[hc]] { kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_, _p7_); } \
707 }
708#define make_kernel_functor_hip_9(function_name, kernel_name, p0, p1, p2, p3, p4, p5, p6) \
709 struct make_kernel_name_hip(function_name, 7) { \
710 std::decay_t<decltype(p0)> _p0_; \
711 std::decay_t<decltype(p1)> _p1_; \
712 std::decay_t<decltype(p2)> _p2_; \
713 std::decay_t<decltype(p3)> _p3_; \
714 std::decay_t<decltype(p4)> _p4_; \
715 std::decay_t<decltype(p5)> _p5_; \
716 std::decay_t<decltype(p6)> _p6_; \
717 void operator()(const hc::tiled_index<3>&) const \
718 [[hc]] { kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_, _p6_); } \
719 }
720#define make_kernel_functor_hip_8(function_name, kernel_name, p0, p1, p2, p3, p4, p5) \
721 struct make_kernel_name_hip(function_name, 6) { \
722 std::decay_t<decltype(p0)> _p0_; \
723 std::decay_t<decltype(p1)> _p1_; \
724 std::decay_t<decltype(p2)> _p2_; \
725 std::decay_t<decltype(p3)> _p3_; \
726 std::decay_t<decltype(p4)> _p4_; \
727 std::decay_t<decltype(p5)> _p5_; \
728 void operator()(const hc::tiled_index<3>&) const \
729 [[hc]] { kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_, _p5_); } \
730 }
731#define make_kernel_functor_hip_7(function_name, kernel_name, p0, p1, p2, p3, p4) \
732 struct make_kernel_name_hip(function_name, 5) { \
733 std::decay_t<decltype(p0)> _p0_; \
734 std::decay_t<decltype(p1)> _p1_; \
735 std::decay_t<decltype(p2)> _p2_; \
736 std::decay_t<decltype(p3)> _p3_; \
737 std::decay_t<decltype(p4)> _p4_; \
738 void operator()(const hc::tiled_index<3>&) const \
739 [[hc]] { kernel_name(_p0_, _p1_, _p2_, _p3_, _p4_); } \
740 }
741#define make_kernel_functor_hip_6(function_name, kernel_name, p0, p1, p2, p3) \
742 struct make_kernel_name_hip(function_name, 4) { \
743 std::decay_t<decltype(p0)> _p0_; \
744 std::decay_t<decltype(p1)> _p1_; \
745 std::decay_t<decltype(p2)> _p2_; \
746 std::decay_t<decltype(p3)> _p3_; \
747 void operator()(const hc::tiled_index<3>&) const \
748 [[hc]] { kernel_name(_p0_, _p1_, _p2_, _p3_); } \
749 }
750#define make_kernel_functor_hip_5(function_name, kernel_name, p0, p1, p2) \
751 struct make_kernel_name_hip(function_name, 3) { \
752 std::decay_t<decltype(p0)> _p0_; \
753 std::decay_t<decltype(p1)> _p1_; \
754 std::decay_t<decltype(p2)> _p2_; \
755 void operator()(const hc::tiled_index<3>&) const [[hc]] { kernel_name(_p0_, _p1_, _p2_); } \
756 }
757#define make_kernel_functor_hip_4(function_name, kernel_name, p0, p1) \
758 struct make_kernel_name_hip(function_name, 2) { \
759 std::decay_t<decltype(p0)> _p0_; \
760 std::decay_t<decltype(p1)> _p1_; \
761 void operator()(const hc::tiled_index<3>&) const [[hc]] { kernel_name(_p0_, _p1_); } \
762 }
763#define fofo(f, n) kernel_prefix_hip##f##kernel_suffix_hip##n
764#define make_kernel_functor_hip_3(function_name, kernel_name, p0) \
765 struct make_kernel_name_hip(function_name, 1) { \
766 std::decay_t<decltype(p0)> _p0_; \
767 void operator()(const hc::tiled_index<3>&) const [[hc]] { kernel_name(_p0_); } \
768 }
769#define make_kernel_functor_hip_2(function_name, kernel_name) \
770 struct make_kernel_name_hip(function_name, 0) { \
771 void operator()(const hc::tiled_index<3>&)[[hc]] { return kernel_name(hipLaunchParm{}); } \
772 }
773#define make_kernel_functor_hip_1(...)
774#define make_kernel_functor_hip_0(...)
775#define make_kernel_functor_hip_(...) overload_macro_hip_(make_kernel_functor_hip_, __VA_ARGS__)
776
777
778#define hipLaunchNamedKernelGGL(function_name, kernel_name, num_blocks, dim_blocks, \
779 group_mem_bytes, stream, ...) \
780 do { \
781 make_kernel_functor_hip_(function_name, kernel_name, __VA_ARGS__) \
782 hip_kernel_functor_impl_{__VA_ARGS__}; \
783 hip_impl::grid_launch_hip_(num_blocks, dim_blocks, group_mem_bytes, stream, #kernel_name, \
784 hip_kernel_functor_impl_); \
785 } while (0)
786
787#define hipLaunchKernelGGL(kernel_name, num_blocks, dim_blocks, group_mem_bytes, stream, ...) \
788 do { \
789 hipLaunchNamedKernelGGL(unnamed, kernel_name, num_blocks, dim_blocks, group_mem_bytes, \
790 stream, ##__VA_ARGS__); \
791 } while (0)
792
793#define hipLaunchKernel(kernel_name, num_blocks, dim_blocks, group_mem_bytes, stream, ...) \
794 do { \
795 hipLaunchKernelGGL(kernel_name, num_blocks, dim_blocks, group_mem_bytes, stream, \
796 hipLaunchParm{}, ##__VA_ARGS__); \
797 } while (0)
798} // namespace hip_impl