OpenJPH
Open-source implementation of JPEG2000 Part-15
ojph_transform_avx512.cpp
1//***************************************************************************/
2// This software is released under the 2-Clause BSD license, included
3// below.
4//
5// Copyright (c) 2019-2024, Aous Naman
6// Copyright (c) 2019-2024, Kakadu Software Pty Ltd, Australia
7// Copyright (c) 2019-2024, The University of New South Wales, Australia
8//
9// Redistribution and use in source and binary forms, with or without
10// modification, are permitted provided that the following conditions are
11// met:
12//
13// 1. Redistributions of source code must retain the above copyright
14// notice, this list of conditions and the following disclaimer.
15//
16// 2. Redistributions in binary form must reproduce the above copyright
17// notice, this list of conditions and the following disclaimer in the
18// documentation and/or other materials provided with the distribution.
19//
20// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
21// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
22// TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
23// PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
26// TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
27// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
28// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
29// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31//***************************************************************************/
32// This file is part of the OpenJPH software implementation.
33// File: ojph_transform_avx512.cpp
34// Author: Aous Naman
35// Date: 13 April 2024
36//***************************************************************************/
37
38#include "ojph_arch.h"
39#if defined(OJPH_ARCH_X86_64)
40
41#include <cstdio>
42
43#include "ojph_defs.h"
44#include "ojph_mem.h"
45#include "ojph_params.h"
47
48#include "ojph_transform.h"
50
51#include <immintrin.h>
52
53namespace ojph {
54 namespace local {
55
57 // We split multiples of 32 followed by multiples of 16, because
58 // we assume byte_alignment == 64
59 static
60 void avx512_deinterleave32(float* dpl, float* dph, float* sp, int width)
61 {
62 __m512i idx1 = _mm512_set_epi32(
63 0x1E, 0x1C, 0x1A, 0x18, 0x16, 0x14, 0x12, 0x10,
64 0x0E, 0x0C, 0x0A, 0x08, 0x06, 0x04, 0x02, 0x00
65 );
66 __m512i idx2 = _mm512_set_epi32(
67 0x1F, 0x1D, 0x1B, 0x19, 0x17, 0x15, 0x13, 0x11,
68 0x0F, 0x0D, 0x0B, 0x09, 0x07, 0x05, 0x03, 0x01
69 );
70 for (; width > 16; width -= 32, sp += 32, dpl += 16, dph += 16)
71 {
72 __m512 a = _mm512_load_ps(sp);
73 __m512 b = _mm512_load_ps(sp + 16);
74 __m512 c = _mm512_permutex2var_ps(a, idx1, b);
75 __m512 d = _mm512_permutex2var_ps(a, idx2, b);
76 _mm512_store_ps(dpl, c);
77 _mm512_store_ps(dph, d);
78 }
79 for (; width > 0; width -= 16, sp += 16, dpl += 8, dph += 8)
80 {
81 __m256 a = _mm256_load_ps(sp);
82 __m256 b = _mm256_load_ps(sp + 8);
83 __m256 c = _mm256_permute2f128_ps(a, b, (2 << 4) | (0));
84 __m256 d = _mm256_permute2f128_ps(a, b, (3 << 4) | (1));
85 __m256 e = _mm256_shuffle_ps(c, d, _MM_SHUFFLE(2, 0, 2, 0));
86 __m256 f = _mm256_shuffle_ps(c, d, _MM_SHUFFLE(3, 1, 3, 1));
87 _mm256_store_ps(dpl, e);
88 _mm256_store_ps(dph, f);
89 }
90 }
91
93 // We split multiples of 32 followed by multiples of 16, because
94 // we assume byte_alignment == 64
95 static
96 void avx512_interleave32(float* dp, float* spl, float* sph, int width)
97 {
98 __m512i idx1 = _mm512_set_epi32(
99 0x17, 0x7, 0x16, 0x6, 0x15, 0x5, 0x14, 0x4,
100 0x13, 0x3, 0x12, 0x2, 0x11, 0x1, 0x10, 0x0
101 );
102 __m512i idx2 = _mm512_set_epi32(
103 0x1F, 0xF, 0x1E, 0xE, 0x1D, 0xD, 0x1C, 0xC,
104 0x1B, 0xB, 0x1A, 0xA, 0x19, 0x9, 0x18, 0x8
105 );
106 for (; width > 16; width -= 32, dp += 32, spl += 16, sph += 16)
107 {
108 __m512 a = _mm512_load_ps(spl);
109 __m512 b = _mm512_load_ps(sph);
110 __m512 c = _mm512_permutex2var_ps(a, idx1, b);
111 __m512 d = _mm512_permutex2var_ps(a, idx2, b);
112 _mm512_store_ps(dp, c);
113 _mm512_store_ps(dp + 16, d);
114 }
115 for (; width > 0; width -= 16, dp += 16, spl += 8, sph += 8)
116 {
117 __m256 a = _mm256_load_ps(spl);
118 __m256 b = _mm256_load_ps(sph);
119 __m256 c = _mm256_unpacklo_ps(a, b);
120 __m256 d = _mm256_unpackhi_ps(a, b);
121 __m256 e = _mm256_permute2f128_ps(c, d, (2 << 4) | (0));
122 __m256 f = _mm256_permute2f128_ps(c, d, (3 << 4) | (1));
123 _mm256_store_ps(dp, e);
124 _mm256_store_ps(dp + 8, f);
125 }
126 }
127
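    // Illustrative only; not part of the original file. A minimal sketch of
    // how the two float helpers above could be exercised together. The
    // alignas(64) arrays stand in for the library's 64-byte-aligned line
    // buffers, and the width is a whole multiple of 16 so both split loops
    // terminate cleanly; round-tripping should reproduce the input samples.
    static void avx512_float_interleave_roundtrip_sketch()
    {
      alignas(64) float line[32], low[16], high[16], rebuilt[32];
      for (int i = 0; i < 32; ++i)
        line[i] = (float)i;
      avx512_deinterleave32(low, high, line, 32);   // low = even, high = odd
      avx512_interleave32(rebuilt, low, high, 32);  // rebuilt == line
    }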
129 // We split multiples of 16 followed by multiples of 8, because
130 // we assume byte_alignment == 64
131 static void avx512_deinterleave64(double* dpl, double* dph, double* sp,
132 int width)
133 {
134 __m512i idx1 = _mm512_set_epi64(
135 0x0E, 0x0C, 0x0A, 0x08, 0x06, 0x04, 0x02, 0x00
136 );
137 __m512i idx2 = _mm512_set_epi64(
138 0x0F, 0x0D, 0x0B, 0x09, 0x07, 0x05, 0x03, 0x01
139 );
140 for (; width > 8; width -= 16, sp += 16, dpl += 8, dph += 8)
141 {
142 __m512d a = _mm512_load_pd(sp);
143 __m512d b = _mm512_load_pd(sp + 8);
144 __m512d c = _mm512_permutex2var_pd(a, idx1, b);
145 __m512d d = _mm512_permutex2var_pd(a, idx2, b);
146 _mm512_store_pd(dpl, c);
147 _mm512_store_pd(dph, d);
148 }
149 for (; width > 0; width -= 8, sp += 8, dpl += 4, dph += 4)
150 {
151 __m256d a = _mm256_load_pd(sp);
152 __m256d b = _mm256_load_pd(sp + 4);
153 __m256d c = _mm256_permute2f128_pd(a, b, (2 << 4) | (0));
154 __m256d d = _mm256_permute2f128_pd(a, b, (3 << 4) | (1));
155 __m256d e = _mm256_shuffle_pd(c, d, 0x0);
156 __m256d f = _mm256_shuffle_pd(c, d, 0xF);
157 _mm256_store_pd(dpl, e);
158 _mm256_store_pd(dph, f);
159 }
160 }
161
163 // We split multiples of 16 followed by multiples of 8, because
164 // we assume byte_alignment == 64
165 static void avx512_interleave64(double* dp, double* spl, double* sph,
166 int width)
167 {
168 __m512i idx1 = _mm512_set_epi64(
169 0xB, 0x3, 0xA, 0x2, 0x9, 0x1, 0x8, 0x0
170 );
171 __m512i idx2 = _mm512_set_epi64(
172 0xF, 0x7, 0xE, 0x6, 0xD, 0x5, 0xC, 0x4
173 );
174 for (; width > 8; width -= 16, dp += 16, spl += 8, sph += 8)
175 {
176 __m512d a = _mm512_load_pd(spl);
177 __m512d b = _mm512_load_pd(sph);
178 __m512d c = _mm512_permutex2var_pd(a, idx1, b);
179 __m512d d = _mm512_permutex2var_pd(a, idx2, b);
180 _mm512_store_pd(dp, c);
181 _mm512_store_pd(dp + 8, d);
182 }
183 for (; width > 0; width -= 8, dp += 8, spl += 4, sph += 4)
184 {
185 __m256d a = _mm256_load_pd(spl);
186 __m256d b = _mm256_load_pd(sph);
187 __m256d c = _mm256_unpacklo_pd(a, b);
188 __m256d d = _mm256_unpackhi_pd(a, b);
189 __m256d e = _mm256_permute2f128_pd(c, d, (2 << 4) | (0));
190 __m256d f = _mm256_permute2f128_pd(c, d, (3 << 4) | (1));
191 _mm256_store_pd(dp, e);
192 _mm256_store_pd(dp + 4, f);
193 }
194 }
195
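    // Illustrative only; not part of the original file. The same round-trip
    // idea for the 64-bit helpers above: a 64-byte-aligned line of 16
    // doubles is split into its even and odd samples and then rebuilt.
    static void avx512_double_interleave_roundtrip_sketch()
    {
      alignas(64) double line[16], low[8], high[8], rebuilt[16];
      for (int i = 0; i < 16; ++i)
        line[i] = (double)i;
      avx512_deinterleave64(low, high, line, 16);   // low = even, high = odd
      avx512_interleave64(rebuilt, low, high, 16);  // rebuilt == line
    }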
197 static inline void avx512_multiply_const(float* p, float f, int width)
198 {
199 __m512 factor = _mm512_set1_ps(f);
200 for (; width > 0; width -= 16, p += 16)
201 {
202 __m512 s = _mm512_load_ps(p);
203 _mm512_store_ps(p, _mm512_mul_ps(factor, s));
204 }
205 }
206
208 void avx512_irv_vert_step(const lifting_step* s, const line_buf* sig,
209 const line_buf* other, const line_buf* aug,
210 ui32 repeat, bool synthesis)
211 {
212 float a = s->irv.Aatk;
213 if (synthesis)
214 a = -a;
215
216 __m512 factor = _mm512_set1_ps(a);
217
218 float* dst = aug->f32;
219 const float* src1 = sig->f32, * src2 = other->f32;
220 int i = (int)repeat;
221 for ( ; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
222 {
223 __m512 s1 = _mm512_load_ps(src1);
224 __m512 s2 = _mm512_load_ps(src2);
225 __m512 d = _mm512_load_ps(dst);
226 d = _mm512_add_ps(d, _mm512_mul_ps(factor, _mm512_add_ps(s1, s2)));
227 _mm512_store_ps(dst, d);
228 }
229 }
230
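    // Illustrative only; not part of the original file. A scalar restatement
    // of what avx512_irv_vert_step computes: each sample of the aug line is
    // updated by Aatk * (sig + other), with the coefficient negated during
    // synthesis so that the same loop undoes the analysis step.
    static void irv_vert_step_scalar_sketch(const lifting_step* s,
                                            const line_buf* sig,
                                            const line_buf* other,
                                            const line_buf* aug,
                                            ui32 repeat, bool synthesis)
    {
      float a = synthesis ? -s->irv.Aatk : s->irv.Aatk;
      float* dst = aug->f32;
      const float* src1 = sig->f32, * src2 = other->f32;
      for (ui32 i = 0; i < repeat; ++i)
        dst[i] += a * (src1[i] + src2[i]);
    }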
232 void avx512_irv_vert_times_K(float K, const line_buf* aug, ui32 repeat)
233 {
234 avx512_multiply_const(aug->f32, K, (int)repeat);
235 }
236
238 void avx512_irv_horz_ana(const param_atk* atk, const line_buf* ldst,
239 const line_buf* hdst, const line_buf* src,
240 ui32 width, bool even)
241 {
242 if (width > 1)
243 {
244 // split src into ldst and hdst
245 {
246 float* dpl = even ? ldst->f32 : hdst->f32;
247 float* dph = even ? hdst->f32 : ldst->f32;
248 float* sp = src->f32;
249 int w = (int)width;
250 avx512_deinterleave32(dpl, dph, sp, w);
251 }
252
253 // the actual horizontal transform
254 float* hp = hdst->f32, * lp = ldst->f32;
255 ui32 l_width = (width + (even ? 1 : 0)) >> 1; // low pass
256 ui32 h_width = (width + (even ? 0 : 1)) >> 1; // high pass
257 ui32 num_steps = atk->get_num_steps();
258 for (ui32 j = num_steps; j > 0; --j)
259 {
260 const lifting_step* s = atk->get_step(j - 1);
261 const float a = s->irv.Aatk;
262
263 // extension
264 lp[-1] = lp[0];
265 lp[l_width] = lp[l_width - 1];
266 // lifting step
267 const float* sp = lp;
268 float* dp = hp;
269 int i = (int)h_width;
270 __m512 f = _mm512_set1_ps(a);
271 if (even)
272 {
273 for (; i > 0; i -= 16, sp += 16, dp += 16)
274 {
275 __m512 m = _mm512_load_ps(sp);
276 __m512 n = _mm512_loadu_ps(sp + 1);
277 __m512 p = _mm512_load_ps(dp);
278 p = _mm512_add_ps(p, _mm512_mul_ps(f, _mm512_add_ps(m, n)));
279 _mm512_store_ps(dp, p);
280 }
281 }
282 else
283 {
284 for (; i > 0; i -= 16, sp += 16, dp += 16)
285 {
286 __m512 m = _mm512_load_ps(sp);
287 __m512 n = _mm512_loadu_ps(sp - 1);
288 __m512 p = _mm512_load_ps(dp);
289 p = _mm512_add_ps(p, _mm512_mul_ps(f, _mm512_add_ps(m, n)));
290 _mm512_store_ps(dp, p);
291 }
292 }
293
294 // swap buffers
295 float* t = lp; lp = hp; hp = t;
296 even = !even;
297 ui32 w = l_width; l_width = h_width; h_width = w;
298 }
299
300 { // multiply by K or 1/K
301 float K = atk->get_K();
302 float K_inv = 1.0f / K;
303 avx512_multiply_const(lp, K_inv, (int)l_width);
304 avx512_multiply_const(hp, K, (int)h_width);
305 }
306 }
307 else {
308 if (even)
309 ldst->f32[0] = src->f32[0];
310 else
311 hdst->f32[0] = src->f32[0] * 2.0f;
312 }
313 }
314
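    // Illustrative only; not part of the original file. The band sizes used
    // by the horizontal transforms above and below, written out as a plain
    // helper: when `even` is true the first sample of the line belongs to
    // the low-pass band, so an odd-length line gives its extra sample to the
    // low-pass band; otherwise the high-pass band gets it.
    static inline void horz_band_widths_sketch(ui32 width, bool even,
                                               ui32& l_width, ui32& h_width)
    {
      l_width = (width + (even ? 1 : 0)) >> 1;  // low-pass sample count
      h_width = (width + (even ? 0 : 1)) >> 1;  // high-pass sample count
    }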
316 void avx512_irv_horz_syn(const param_atk* atk, const line_buf* dst,
317 const line_buf* lsrc, const line_buf* hsrc,
318 ui32 width, bool even)
319 {
320 if (width > 1)
321 {
322 bool ev = even;
323 float* oth = hsrc->f32, * aug = lsrc->f32;
324 ui32 aug_width = (width + (even ? 1 : 0)) >> 1; // low pass
325 ui32 oth_width = (width + (even ? 0 : 1)) >> 1; // high pass
326
327 { // multiply by K or 1/K
328 float K = atk->get_K();
329 float K_inv = 1.0f / K;
330 avx512_multiply_const(aug, K, (int)aug_width);
331 avx512_multiply_const(oth, K_inv, (int)oth_width);
332 }
333
334 // the actual horizontal transform
335 ui32 num_steps = atk->get_num_steps();
336 for (ui32 j = 0; j < num_steps; ++j)
337 {
338 const lifting_step* s = atk->get_step(j);
339 const float a = s->irv.Aatk;
340
341 // extension
342 oth[-1] = oth[0];
343 oth[oth_width] = oth[oth_width - 1];
344 // lifting step
345 const float* sp = oth;
346 float* dp = aug;
347 int i = (int)aug_width;
348 __m512 f = _mm512_set1_ps(a);
349 if (ev)
350 {
351 for (; i > 0; i -= 16, sp += 16, dp += 16)
352 {
353 __m512 m = _mm512_load_ps(sp);
354 __m512 n = _mm512_loadu_ps(sp - 1);
355 __m512 p = _mm512_load_ps(dp);
356 p = _mm512_sub_ps(p, _mm512_mul_ps(f, _mm512_add_ps(m, n)));
357 _mm512_store_ps(dp, p);
358 }
359 }
360 else
361 {
362 for (; i > 0; i -= 16, sp += 16, dp += 16)
363 {
364 __m512 m = _mm512_load_ps(sp);
365 __m512 n = _mm512_loadu_ps(sp + 1);
366 __m512 p = _mm512_load_ps(dp);
367 p = _mm512_sub_ps(p, _mm512_mul_ps(f, _mm512_add_ps(m, n)));
368 _mm512_store_ps(dp, p);
369 }
370 }
371
372 // swap buffers
373 float* t = aug; aug = oth; oth = t;
374 ev = !ev;
375 ui32 w = aug_width; aug_width = oth_width; oth_width = w;
376 }
377
378 // combine both lsrc and hsrc into dst
379 {
380 float* dp = dst->f32;
381 float* spl = even ? lsrc->f32 : hsrc->f32;
382 float* sph = even ? hsrc->f32 : lsrc->f32;
383 int w = (int)width;
384 avx512_interleave32(dp, spl, sph, w);
385 }
386 }
387 else {
388 if (even)
389 dst->f32[0] = lsrc->f32[0];
390 else
391 dst->f32[0] = hsrc->f32[0] * 0.5f;
392 }
393 }
394
395
397 void avx512_rev_vert_step32(const lifting_step* s, const line_buf* sig,
398 const line_buf* other, const line_buf* aug,
399 ui32 repeat, bool synthesis)
400 {
401 const si32 a = s->rev.Aatk;
402 const si32 b = s->rev.Batk;
403 const ui8 e = s->rev.Eatk;
404 __m512i va = _mm512_set1_epi32(a);
405 __m512i vb = _mm512_set1_epi32(b);
406
407 si32* dst = aug->i32;
408 const si32* src1 = sig->i32, * src2 = other->i32;
409 // The general definition of the wavelet in Part 2 is slightly
410 // different from that in Part 1, although they are mathematically
411 // equivalent; here, we identify the simpler Part-1 forms and employ them
412 if (a == 1)
413 { // 5/3 update and any case with a == 1
414 int i = (int)repeat;
415 if (synthesis)
416 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
417 {
418 __m512i s1 = _mm512_load_si512((__m512i*)src1);
419 __m512i s2 = _mm512_load_si512((__m512i*)src2);
420 __m512i d = _mm512_load_si512((__m512i*)dst);
421 __m512i t = _mm512_add_epi32(s1, s2);
422 __m512i v = _mm512_add_epi32(vb, t);
423 __m512i w = _mm512_srai_epi32(v, e);
424 d = _mm512_sub_epi32(d, w);
425 _mm512_store_si512((__m512i*)dst, d);
426 }
427 else
428 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
429 {
430 __m512i s1 = _mm512_load_si512((__m512i*)src1);
431 __m512i s2 = _mm512_load_si512((__m512i*)src2);
432 __m512i d = _mm512_load_si512((__m512i*)dst);
433 __m512i t = _mm512_add_epi32(s1, s2);
434 __m512i v = _mm512_add_epi32(vb, t);
435 __m512i w = _mm512_srai_epi32(v, e);
436 d = _mm512_add_epi32(d, w);
437 _mm512_store_si512((__m512i*)dst, d);
438 }
439 }
440 else if (a == -1 && b == 1 && e == 1)
441 { // 5/3 predict
442 int i = (int)repeat;
443 if (synthesis)
444 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
445 {
446 __m512i s1 = _mm512_load_si512((__m512i*)src1);
447 __m512i s2 = _mm512_load_si512((__m512i*)src2);
448 __m512i d = _mm512_load_si512((__m512i*)dst);
449 __m512i t = _mm512_add_epi32(s1, s2);
450 __m512i w = _mm512_srai_epi32(t, e);
451 d = _mm512_add_epi32(d, w);
452 _mm512_store_si512((__m512i*)dst, d);
453 }
454 else
455 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
456 {
457 __m512i s1 = _mm512_load_si512((__m512i*)src1);
458 __m512i s2 = _mm512_load_si512((__m512i*)src2);
459 __m512i d = _mm512_load_si512((__m512i*)dst);
460 __m512i t = _mm512_add_epi32(s1, s2);
461 __m512i w = _mm512_srai_epi32(t, e);
462 d = _mm512_sub_epi32(d, w);
463 _mm512_store_si512((__m512i*)dst, d);
464 }
465 }
466 else if (a == -1)
467 { // any case with a == -1, which is not 5/3 predict
468 int i = (int)repeat;
469 if (synthesis)
470 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
471 {
472 __m512i s1 = _mm512_load_si512((__m512i*)src1);
473 __m512i s2 = _mm512_load_si512((__m512i*)src2);
474 __m512i d = _mm512_load_si512((__m512i*)dst);
475 __m512i t = _mm512_add_epi32(s1, s2);
476 __m512i v = _mm512_sub_epi32(vb, t);
477 __m512i w = _mm512_srai_epi32(v, e);
478 d = _mm512_sub_epi32(d, w);
479 _mm512_store_si512((__m512i*)dst, d);
480 }
481 else
482 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
483 {
484 __m512i s1 = _mm512_load_si512((__m512i*)src1);
485 __m512i s2 = _mm512_load_si512((__m512i*)src2);
486 __m512i d = _mm512_load_si512((__m512i*)dst);
487 __m512i t = _mm512_add_epi32(s1, s2);
488 __m512i v = _mm512_sub_epi32(vb, t);
489 __m512i w = _mm512_srai_epi32(v, e);
490 d = _mm512_add_epi32(d, w);
491 _mm512_store_si512((__m512i*)dst, d);
492 }
493 }
494 else { // general case
495 int i = (int)repeat;
496 if (synthesis)
497 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
498 {
499 __m512i s1 = _mm512_load_si512((__m512i*)src1);
500 __m512i s2 = _mm512_load_si512((__m512i*)src2);
501 __m512i d = _mm512_load_si512((__m512i*)dst);
502 __m512i t = _mm512_add_epi32(s1, s2);
503 __m512i u = _mm512_mullo_epi32(va, t);
504 __m512i v = _mm512_add_epi32(vb, u);
505 __m512i w = _mm512_srai_epi32(v, e);
506 d = _mm512_sub_epi32(d, w);
507 _mm512_store_si512((__m512i*)dst, d);
508 }
509 else
510 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
511 {
512 __m512i s1 = _mm512_load_si512((__m512i*)src1);
513 __m512i s2 = _mm512_load_si512((__m512i*)src2);
514 __m512i d = _mm512_load_si512((__m512i*)dst);
515 __m512i t = _mm512_add_epi32(s1, s2);
516 __m512i u = _mm512_mullo_epi32(va, t);
517 __m512i v = _mm512_add_epi32(vb, u);
518 __m512i w = _mm512_srai_epi32(v, e);
519 d = _mm512_add_epi32(d, w);
520 _mm512_store_si512((__m512i*)dst, d);
521 }
522 }
523 }
524
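    // Illustrative only; not part of the original file. A scalar reference
    // that all four vectorized branches above reduce to; the a == 1 and
    // a == -1 branches simply drop the multiplication, and the 5/3 predict
    // branch additionally folds the +1 bias into the arithmetic shift.
    static void rev_vert_step32_scalar_sketch(const lifting_step* s,
                                              const line_buf* sig,
                                              const line_buf* other,
                                              const line_buf* aug,
                                              ui32 repeat, bool synthesis)
    {
      const si32 a = s->rev.Aatk;
      const si32 b = s->rev.Batk;
      const ui8 e = s->rev.Eatk;
      si32* dst = aug->i32;
      const si32* src1 = sig->i32, * src2 = other->i32;
      for (ui32 i = 0; i < repeat; ++i)
      {
        si32 w = (si32)((b + a * (src1[i] + src2[i])) >> e);
        dst[i] += synthesis ? -w : w;
      }
    }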
526 void avx512_rev_vert_step64(const lifting_step* s, const line_buf* sig,
527 const line_buf* other, const line_buf* aug,
528 ui32 repeat, bool synthesis)
529 {
530 const si32 a = s->rev.Aatk;
531 const si32 b = s->rev.Batk;
532 const ui8 e = s->rev.Eatk;
533 __m512i vb = _mm512_set1_epi64(b);
534
535 si64* dst = aug->i64;
536 const si64* src1 = sig->i64, * src2 = other->i64;
537 // The general definition of the wavelet in Part 2 is slightly
538 // different from that in Part 1, although they are mathematically
539 // equivalent; here, we identify the simpler Part-1 forms and employ them
540 if (a == 1)
541 { // 5/3 update and any case with a == 1
542 int i = (int)repeat;
543 if (synthesis)
544 for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
545 {
546 __m512i s1 = _mm512_load_si512((__m512i*)src1);
547 __m512i s2 = _mm512_load_si512((__m512i*)src2);
548 __m512i d = _mm512_load_si512((__m512i*)dst);
549 __m512i t = _mm512_add_epi64(s1, s2);
550 __m512i v = _mm512_add_epi64(vb, t);
551 __m512i w = _mm512_srai_epi64(v, e);
552 d = _mm512_sub_epi64(d, w);
553 _mm512_store_si512((__m512i*)dst, d);
554 }
555 else
556 for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
557 {
558 __m512i s1 = _mm512_load_si512((__m512i*)src1);
559 __m512i s2 = _mm512_load_si512((__m512i*)src2);
560 __m512i d = _mm512_load_si512((__m512i*)dst);
561 __m512i t = _mm512_add_epi64(s1, s2);
562 __m512i v = _mm512_add_epi64(vb, t);
563 __m512i w = _mm512_srai_epi64(v, e);
564 d = _mm512_add_epi64(d, w);
565 _mm512_store_si512((__m512i*)dst, d);
566 }
567 }
568 else if (a == -1 && b == 1 && e == 1)
569 { // 5/3 predict
570 int i = (int)repeat;
571 if (synthesis)
572 for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
573 {
574 __m512i s1 = _mm512_load_si512((__m512i*)src1);
575 __m512i s2 = _mm512_load_si512((__m512i*)src2);
576 __m512i d = _mm512_load_si512((__m512i*)dst);
577 __m512i t = _mm512_add_epi64(s1, s2);
578 __m512i w = _mm512_srai_epi64(t, e);
579 d = _mm512_add_epi64(d, w);
580 _mm512_store_si512((__m512i*)dst, d);
581 }
582 else
583 for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
584 {
585 __m512i s1 = _mm512_load_si512((__m512i*)src1);
586 __m512i s2 = _mm512_load_si512((__m512i*)src2);
587 __m512i d = _mm512_load_si512((__m512i*)dst);
588 __m512i t = _mm512_add_epi64(s1, s2);
589 __m512i w = _mm512_srai_epi64(t, e);
590 d = _mm512_sub_epi64(d, w);
591 _mm512_store_si512((__m512i*)dst, d);
592 }
593 }
594 else if (a == -1)
595 { // any case with a == -1, which is not 5/3 predict
596 int i = (int)repeat;
597 if (synthesis)
598 for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
599 {
600 __m512i s1 = _mm512_load_si512((__m512i*)src1);
601 __m512i s2 = _mm512_load_si512((__m512i*)src2);
602 __m512i d = _mm512_load_si512((__m512i*)dst);
603 __m512i t = _mm512_add_epi64(s1, s2);
604 __m512i v = _mm512_sub_epi64(vb, t);
605 __m512i w = _mm512_srai_epi64(v, e);
606 d = _mm512_sub_epi64(d, w);
607 _mm512_store_si512((__m512i*)dst, d);
608 }
609 else
610 for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
611 {
612 __m512i s1 = _mm512_load_si512((__m512i*)src1);
613 __m512i s2 = _mm512_load_si512((__m512i*)src2);
614 __m512i d = _mm512_load_si512((__m512i*)dst);
615 __m512i t = _mm512_add_epi64(s1, s2);
616 __m512i v = _mm512_sub_epi64(vb, t);
617 __m512i w = _mm512_srai_epi64(v, e);
618 d = _mm512_add_epi64(d, w);
619 _mm512_store_si512((__m512i*)dst, d);
620 }
621 }
622 else {
623 // general case
624 // 64-bit multiplication is not supported in AVX512F + AVX512CD;
625 // in particular, _mm512_mullo_epi64 requires AVX512DQ.
626 if (synthesis)
627 for (ui32 i = repeat; i > 0; --i)
628 *dst++ -= (b + a * (*src1++ + *src2++)) >> e;
629 else
630 for (ui32 i = repeat; i > 0; --i)
631 *dst++ += (b + a * (*src1++ + *src2++)) >> e;
632 }
633
634 // This can only be used if you have AVX512DQ
635 // { // general case
636 // __m512i va = _mm512_set1_epi64(a);
637 // int i = (int)repeat;
638 // if (synthesis)
639 // for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
640 // {
641 // __m512i s1 = _mm512_load_si512((__m512i*)src1);
642 // __m512i s2 = _mm512_load_si512((__m512i*)src2);
643 // __m512i d = _mm512_load_si512((__m512i*)dst);
644 // __m512i t = _mm512_add_epi64(s1, s2);
645 // __m512i u = _mm512_mullo_epi64(va, t);
646 // __m512i v = _mm512_add_epi64(vb, u);
647 // __m512i w = _mm512_srai_epi64(v, e);
648 // d = _mm512_sub_epi64(d, w);
649 // _mm512_store_si512((__m512i*)dst, d);
650 // }
651 // else
652 // for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
653 // {
654 // __m512i s1 = _mm512_load_si512((__m512i*)src1);
655 // __m512i s2 = _mm512_load_si512((__m512i*)src2);
656 // __m512i d = _mm512_load_si512((__m512i*)dst);
657 // __m512i t = _mm512_add_epi64(s1, s2);
658 // __m512i u = _mm512_mullo_epi64(va, t);
659 // __m512i v = _mm512_add_epi64(vb, u);
660 // __m512i w = _mm512_srai_epi64(v, e);
661 // d = _mm512_add_epi64(d, w);
662 // _mm512_store_si512((__m512i*)dst, d);
663 // }
664 // }
665 }
666
668 void avx512_rev_vert_step(const lifting_step* s, const line_buf* sig,
669 const line_buf* other, const line_buf* aug,
670 ui32 repeat, bool synthesis)
671 {
672 if (((sig != NULL) && (sig->flags & line_buf::LFT_32BIT)) ||
673 ((aug != NULL) && (aug->flags & line_buf::LFT_32BIT)) ||
674 ((other != NULL) && (other->flags & line_buf::LFT_32BIT)))
675 {
676 assert((sig == NULL || sig->flags & line_buf::LFT_32BIT) &&
677 (other == NULL || other->flags & line_buf::LFT_32BIT) &&
678 (aug == NULL || aug->flags & line_buf::LFT_32BIT));
679 avx512_rev_vert_step32(s, sig, other, aug, repeat, synthesis);
680 }
681 else
682 {
683 assert((sig == NULL || sig->flags & line_buf::LFT_64BIT) &&
684 (other == NULL || other->flags & line_buf::LFT_64BIT) &&
685 (aug == NULL || aug->flags & line_buf::LFT_64BIT));
686 avx512_rev_vert_step64(s, sig, other, aug, repeat, synthesis);
687 }
688 }
689
691 void avx512_rev_horz_ana32(const param_atk* atk, const line_buf* ldst,
692 const line_buf* hdst, const line_buf* src,
693 ui32 width, bool even)
694 {
695 if (width > 1)
696 {
697 // split src into ldst and hdst
698 {
699 float* dpl = even ? ldst->f32 : hdst->f32;
700 float* dph = even ? hdst->f32 : ldst->f32;
701 float* sp = src->f32;
702 int w = (int)width;
703 avx512_deinterleave32(dpl, dph, sp, w);
704 }
705
706 si32* hp = hdst->i32, * lp = ldst->i32;
707 ui32 l_width = (width + (even ? 1 : 0)) >> 1; // low pass
708 ui32 h_width = (width + (even ? 0 : 1)) >> 1; // high pass
709 ui32 num_steps = atk->get_num_steps();
710 for (ui32 j = num_steps; j > 0; --j)
711 {
712 // first lifting step
713 const lifting_step* s = atk->get_step(j - 1);
714 const si32 a = s->rev.Aatk;
715 const si32 b = s->rev.Batk;
716 const ui8 e = s->rev.Eatk;
717 __m512i va = _mm512_set1_epi32(a);
718 __m512i vb = _mm512_set1_epi32(b);
719
720 // extension
721 lp[-1] = lp[0];
722 lp[l_width] = lp[l_width - 1];
723 // lifting step
724 const si32* sp = lp;
725 si32* dp = hp;
726 if (a == 1)
727 { // 5/3 update and any case with a == 1
728 int i = (int)h_width;
729 if (even)
730 {
731 for (; i > 0; i -= 16, sp += 16, dp += 16)
732 {
733 __m512i s1 = _mm512_load_si512((__m512i*)sp);
734 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
735 __m512i d = _mm512_load_si512((__m512i*)dp);
736 __m512i t = _mm512_add_epi32(s1, s2);
737 __m512i v = _mm512_add_epi32(vb, t);
738 __m512i w = _mm512_srai_epi32(v, e);
739 d = _mm512_add_epi32(d, w);
740 _mm512_store_si512((__m512i*)dp, d);
741 }
742 }
743 else
744 {
745 for (; i > 0; i -= 16, sp += 16, dp += 16)
746 {
747 __m512i s1 = _mm512_load_si512((__m512i*)sp);
748 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
749 __m512i d = _mm512_load_si512((__m512i*)dp);
750 __m512i t = _mm512_add_epi32(s1, s2);
751 __m512i v = _mm512_add_epi32(vb, t);
752 __m512i w = _mm512_srai_epi32(v, e);
753 d = _mm512_add_epi32(d, w);
754 _mm512_store_si512((__m512i*)dp, d);
755 }
756 }
757 }
758 else if (a == -1 && b == 1 && e == 1)
759 { // 5/3 predict
760 int i = (int)h_width;
761 if (even)
762 for (; i > 0; i -= 16, sp += 16, dp += 16)
763 {
764 __m512i s1 = _mm512_load_si512((__m512i*)sp);
765 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
766 __m512i d = _mm512_load_si512((__m512i*)dp);
767 __m512i t = _mm512_add_epi32(s1, s2);
768 __m512i w = _mm512_srai_epi32(t, e);
769 d = _mm512_sub_epi32(d, w);
770 _mm512_store_si512((__m512i*)dp, d);
771 }
772 else
773 for (; i > 0; i -= 16, sp += 16, dp += 16)
774 {
775 __m512i s1 = _mm512_load_si512((__m512i*)sp);
776 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
777 __m512i d = _mm512_load_si512((__m512i*)dp);
778 __m512i t = _mm512_add_epi32(s1, s2);
779 __m512i w = _mm512_srai_epi32(t, e);
780 d = _mm512_sub_epi32(d, w);
781 _mm512_store_si512((__m512i*)dp, d);
782 }
783 }
784 else if (a == -1)
785 { // any case with a == -1, which is not 5/3 predict
786 int i = (int)h_width;
787 if (even)
788 for (; i > 0; i -= 16, sp += 16, dp += 16)
789 {
790 __m512i s1 = _mm512_load_si512((__m512i*)sp);
791 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
792 __m512i d = _mm512_load_si512((__m512i*)dp);
793 __m512i t = _mm512_add_epi32(s1, s2);
794 __m512i v = _mm512_sub_epi32(vb, t);
795 __m512i w = _mm512_srai_epi32(v, e);
796 d = _mm512_add_epi32(d, w);
797 _mm512_store_si512((__m512i*)dp, d);
798 }
799 else
800 for (; i > 0; i -= 16, sp += 16, dp += 16)
801 {
802 __m512i s1 = _mm512_load_si512((__m512i*)sp);
803 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
804 __m512i d = _mm512_load_si512((__m512i*)dp);
805 __m512i t = _mm512_add_epi32(s1, s2);
806 __m512i v = _mm512_sub_epi32(vb, t);
807 __m512i w = _mm512_srai_epi32(v, e);
808 d = _mm512_add_epi32(d, w);
809 _mm512_store_si512((__m512i*)dp, d);
810 }
811 }
812 else {
813 // general case
814 int i = (int)h_width;
815 if (even)
816 for (; i > 0; i -= 16, sp += 16, dp += 16)
817 {
818 __m512i s1 = _mm512_load_si512((__m512i*)sp);
819 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
820 __m512i d = _mm512_load_si512((__m512i*)dp);
821 __m512i t = _mm512_add_epi32(s1, s2);
822 __m512i u = _mm512_mullo_epi32(va, t);
823 __m512i v = _mm512_add_epi32(vb, u);
824 __m512i w = _mm512_srai_epi32(v, e);
825 d = _mm512_add_epi32(d, w);
826 _mm512_store_si512((__m512i*)dp, d);
827 }
828 else
829 for (; i > 0; i -= 16, sp += 16, dp += 16)
830 {
831 __m512i s1 = _mm512_load_si512((__m512i*)sp);
832 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
833 __m512i d = _mm512_load_si512((__m512i*)dp);
834 __m512i t = _mm512_add_epi32(s1, s2);
835 __m512i u = _mm512_mullo_epi32(va, t);
836 __m512i v = _mm512_add_epi32(vb, u);
837 __m512i w = _mm512_srai_epi32(v, e);
838 d = _mm512_add_epi32(d, w);
839 _mm512_store_si512((__m512i*)dp, d);
840 }
841 }
842
843 // swap buffers
844 si32* t = lp; lp = hp; hp = t;
845 even = !even;
846 ui32 w = l_width; l_width = h_width; h_width = w;
847 }
848 }
849 else {
850 if (even)
851 ldst->i32[0] = src->i32[0];
852 else
853 hdst->i32[0] = src->i32[0] << 1;
854 }
855 }
856
858 void avx512_rev_horz_ana64(const param_atk* atk, const line_buf* ldst,
859 const line_buf* hdst, const line_buf* src,
860 ui32 width, bool even)
861 {
862 if (width > 1)
863 {
864 // split src into ldst and hdst
865 {
866 double* dpl = (double*)(even ? ldst->p : hdst->p);
867 double* dph = (double*)(even ? hdst->p : ldst->p);
868 double* sp = (double*)(src->p);
869 int w = (int)width;
870 avx512_deinterleave64(dpl, dph, sp, w);
871 }
872
873 si64* hp = hdst->i64, * lp = ldst->i64;
874 ui32 l_width = (width + (even ? 1 : 0)) >> 1; // low pass
875 ui32 h_width = (width + (even ? 0 : 1)) >> 1; // high pass
876 ui32 num_steps = atk->get_num_steps();
877 for (ui32 j = num_steps; j > 0; --j)
878 {
879 // first lifting step
880 const lifting_step* s = atk->get_step(j - 1);
881 const si32 a = s->rev.Aatk;
882 const si32 b = s->rev.Batk;
883 const ui8 e = s->rev.Eatk;
884 __m512i vb = _mm512_set1_epi64(b);
885
886 // extension
887 lp[-1] = lp[0];
888 lp[l_width] = lp[l_width - 1];
889 // lifting step
890 const si64* sp = lp;
891 si64* dp = hp;
892 if (a == 1)
893 { // 5/3 update and any case with a == 1
894 int i = (int)h_width;
895 if (even)
896 {
897 for (; i > 0; i -= 8, sp += 8, dp += 8)
898 {
899 __m512i s1 = _mm512_load_si512((__m512i*)sp);
900 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
901 __m512i d = _mm512_load_si512((__m512i*)dp);
902 __m512i t = _mm512_add_epi64(s1, s2);
903 __m512i v = _mm512_add_epi64(vb, t);
904 __m512i w = _mm512_srai_epi64(v, e);
905 d = _mm512_add_epi64(d, w);
906 _mm512_store_si512((__m512i*)dp, d);
907 }
908 }
909 else
910 {
911 for (; i > 0; i -= 8, sp += 8, dp += 8)
912 {
913 __m512i s1 = _mm512_load_si512((__m512i*)sp);
914 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
915 __m512i d = _mm512_load_si512((__m512i*)dp);
916 __m512i t = _mm512_add_epi64(s1, s2);
917 __m512i v = _mm512_add_epi64(vb, t);
918 __m512i w = _mm512_srai_epi64(v, e);
919 d = _mm512_add_epi64(d, w);
920 _mm512_store_si512((__m512i*)dp, d);
921 }
922 }
923 }
924 else if (a == -1 && b == 1 && e == 1)
925 { // 5/3 predict
926 int i = (int)h_width;
927 if (even)
928 for (; i > 0; i -= 8, sp += 8, dp += 8)
929 {
930 __m512i s1 = _mm512_load_si512((__m512i*)sp);
931 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
932 __m512i d = _mm512_load_si512((__m512i*)dp);
933 __m512i t = _mm512_add_epi64(s1, s2);
934 __m512i w = _mm512_srai_epi64(t, e);
935 d = _mm512_sub_epi64(d, w);
936 _mm512_store_si512((__m512i*)dp, d);
937 }
938 else
939 for (; i > 0; i -= 8, sp += 8, dp += 8)
940 {
941 __m512i s1 = _mm512_load_si512((__m512i*)sp);
942 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
943 __m512i d = _mm512_load_si512((__m512i*)dp);
944 __m512i t = _mm512_add_epi64(s1, s2);
945 __m512i w = _mm512_srai_epi64(t, e);
946 d = _mm512_sub_epi64(d, w);
947 _mm512_store_si512((__m512i*)dp, d);
948 }
949 }
950 else if (a == -1)
951 { // any case with a == -1, which is not 5/3 predict
952 int i = (int)h_width;
953 if (even)
954 for (; i > 0; i -= 8, sp += 8, dp += 8)
955 {
956 __m512i s1 = _mm512_load_si512((__m512i*)sp);
957 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
958 __m512i d = _mm512_load_si512((__m512i*)dp);
959 __m512i t = _mm512_add_epi64(s1, s2);
960 __m512i v = _mm512_sub_epi64(vb, t);
961 __m512i w = _mm512_srai_epi64(v, e);
962 d = _mm512_add_epi64(d, w);
963 _mm512_store_si512((__m512i*)dp, d);
964 }
965 else
966 for (; i > 0; i -= 8, sp += 8, dp += 8)
967 {
968 __m512i s1 = _mm512_load_si512((__m512i*)sp);
969 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
970 __m512i d = _mm512_load_si512((__m512i*)dp);
971 __m512i t = _mm512_add_epi64(s1, s2);
972 __m512i v = _mm512_sub_epi64(vb, t);
973 __m512i w = _mm512_srai_epi64(v, e);
974 d = _mm512_add_epi64(d, w);
975 _mm512_store_si512((__m512i*)dp, d);
976 }
977 }
978 else
979 {
980 // general case
981 // 64-bit multiplication is not supported in AVX512F + AVX512CD;
982 // in particular, _mm512_mullo_epi64 requires AVX512DQ.
983 if (even)
984 for (ui32 i = h_width; i > 0; --i, sp++, dp++)
985 *dp += (b + a * (sp[0] + sp[1])) >> e;
986 else
987 for (ui32 i = h_width; i > 0; --i, sp++, dp++)
988 *dp += (b + a * (sp[-1] + sp[0])) >> e;
989 }
990
991 // This can only be used if you have AVX512DQ
992 // {
993 // // general case
994 // __m512i va = _mm512_set1_epi64(a);
995 // int i = (int)h_width;
996 // if (even)
997 // for (; i > 0; i -= 8, sp += 8, dp += 8)
998 // {
999 // __m512i s1 = _mm512_load_si512((__m512i*)sp);
1000 // __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1001 // __m512i d = _mm512_load_si512((__m512i*)dp);
1002 // __m512i t = _mm512_add_epi64(s1, s2);
1003 // __m512i u = _mm512_mullo_epi64(va, t);
1004 // __m512i v = _mm512_add_epi64(vb, u);
1005 // __m512i w = _mm512_srai_epi64(v, e);
1006 // d = _mm512_add_epi64(d, w);
1007 // _mm512_store_si512((__m512i*)dp, d);
1008 // }
1009 // else
1010 // for (; i > 0; i -= 8, sp += 8, dp += 8)
1011 // {
1012 // __m512i s1 = _mm512_load_si512((__m512i*)sp);
1013 // __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1014 // __m512i d = _mm512_load_si512((__m512i*)dp);
1015 // __m512i t = _mm512_add_epi64(s1, s2);
1016 // __m512i u = _mm512_mullo_epi64(va, t);
1017 // __m512i v = _mm512_add_epi64(vb, u);
1018 // __m512i w = _mm512_srai_epi64(v, e);
1019 // d = _mm512_add_epi64(d, w);
1020 // _mm512_store_si512((__m512i*)dp, d);
1021 // }
1022 // }
1023
1024 // swap buffers
1025 si64* t = lp; lp = hp; hp = t;
1026 even = !even;
1027 ui32 w = l_width; l_width = h_width; h_width = w;
1028 }
1029 }
1030 else {
1031 if (even)
1032 ldst->i64[0] = src->i64[0];
1033 else
1034 hdst->i64[0] = src->i64[0] << 1;
1035 }
1036 }
1037
1039 void avx512_rev_horz_ana(const param_atk* atk, const line_buf* ldst,
1040 const line_buf* hdst, const line_buf* src,
1041 ui32 width, bool even)
1042 {
1043 if (src->flags & line_buf::LFT_32BIT)
1044 {
1045 assert((ldst == NULL || ldst->flags & line_buf::LFT_32BIT) &&
1046 (hdst == NULL || hdst->flags & line_buf::LFT_32BIT));
1047 avx512_rev_horz_ana32(atk, ldst, hdst, src, width, even);
1048 }
1049 else
1050 {
1051 assert((ldst == NULL || ldst->flags & line_buf::LFT_64BIT) &&
1052 (hdst == NULL || hdst->flags & line_buf::LFT_64BIT) &&
1053 (src == NULL || src->flags & line_buf::LFT_64BIT));
1054 avx512_rev_horz_ana64(atk, ldst, hdst, src, width, even);
1055 }
1056 }
1057
1059 void avx512_rev_horz_syn32(const param_atk* atk, const line_buf* dst,
1060 const line_buf* lsrc, const line_buf* hsrc,
1061 ui32 width, bool even)
1062 {
1063 if (width > 1)
1064 {
1065 bool ev = even;
1066 si32* oth = hsrc->i32, * aug = lsrc->i32;
1067 ui32 aug_width = (width + (even ? 1 : 0)) >> 1; // low pass
1068 ui32 oth_width = (width + (even ? 0 : 1)) >> 1; // high pass
1069 ui32 num_steps = atk->get_num_steps();
1070 for (ui32 j = 0; j < num_steps; ++j)
1071 {
1072 const lifting_step* s = atk->get_step(j);
1073 const si32 a = s->rev.Aatk;
1074 const si32 b = s->rev.Batk;
1075 const ui8 e = s->rev.Eatk;
1076 __m512i va = _mm512_set1_epi32(a);
1077 __m512i vb = _mm512_set1_epi32(b);
1078
1079 // extension
1080 oth[-1] = oth[0];
1081 oth[oth_width] = oth[oth_width - 1];
1082 // lifting step
1083 const si32* sp = oth;
1084 si32* dp = aug;
1085 if (a == 1)
1086 { // 5/3 update and any case with a == 1
1087 int i = (int)aug_width;
1088 if (ev)
1089 {
1090 for (; i > 0; i -= 16, sp += 16, dp += 16)
1091 {
1092 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1093 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1094 __m512i d = _mm512_load_si512((__m512i*)dp);
1095 __m512i t = _mm512_add_epi32(s1, s2);
1096 __m512i v = _mm512_add_epi32(vb, t);
1097 __m512i w = _mm512_srai_epi32(v, e);
1098 d = _mm512_sub_epi32(d, w);
1099 _mm512_store_si512((__m512i*)dp, d);
1100 }
1101 }
1102 else
1103 {
1104 for (; i > 0; i -= 16, sp += 16, dp += 16)
1105 {
1106 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1107 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1108 __m512i d = _mm512_load_si512((__m512i*)dp);
1109 __m512i t = _mm512_add_epi32(s1, s2);
1110 __m512i v = _mm512_add_epi32(vb, t);
1111 __m512i w = _mm512_srai_epi32(v, e);
1112 d = _mm512_sub_epi32(d, w);
1113 _mm512_store_si512((__m512i*)dp, d);
1114 }
1115 }
1116 }
1117 else if (a == -1 && b == 1 && e == 1)
1118 { // 5/3 predict
1119 int i = (int)aug_width;
1120 if (ev)
1121 for (; i > 0; i -= 16, sp += 16, dp += 16)
1122 {
1123 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1124 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1125 __m512i d = _mm512_load_si512((__m512i*)dp);
1126 __m512i t = _mm512_add_epi32(s1, s2);
1127 __m512i w = _mm512_srai_epi32(t, e);
1128 d = _mm512_add_epi32(d, w);
1129 _mm512_store_si512((__m512i*)dp, d);
1130 }
1131 else
1132 for (; i > 0; i -= 16, sp += 16, dp += 16)
1133 {
1134 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1135 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1136 __m512i d = _mm512_load_si512((__m512i*)dp);
1137 __m512i t = _mm512_add_epi32(s1, s2);
1138 __m512i w = _mm512_srai_epi32(t, e);
1139 d = _mm512_add_epi32(d, w);
1140 _mm512_store_si512((__m512i*)dp, d);
1141 }
1142 }
1143 else if (a == -1)
1144 { // any case with a == -1, which is not 5/3 predict
1145 int i = (int)aug_width;
1146 if (ev)
1147 for (; i > 0; i -= 16, sp += 16, dp += 16)
1148 {
1149 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1150 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1151 __m512i d = _mm512_load_si512((__m512i*)dp);
1152 __m512i t = _mm512_add_epi32(s1, s2);
1153 __m512i v = _mm512_sub_epi32(vb, t);
1154 __m512i w = _mm512_srai_epi32(v, e);
1155 d = _mm512_sub_epi32(d, w);
1156 _mm512_store_si512((__m512i*)dp, d);
1157 }
1158 else
1159 for (; i > 0; i -= 16, sp += 16, dp += 16)
1160 {
1161 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1162 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1163 __m512i d = _mm512_load_si512((__m512i*)dp);
1164 __m512i t = _mm512_add_epi32(s1, s2);
1165 __m512i v = _mm512_sub_epi32(vb, t);
1166 __m512i w = _mm512_srai_epi32(v, e);
1167 d = _mm512_sub_epi32(d, w);
1168 _mm512_store_si512((__m512i*)dp, d);
1169 }
1170 }
1171 else {
1172 // general case
1173 int i = (int)aug_width;
1174 if (ev)
1175 for (; i > 0; i -= 16, sp += 16, dp += 16)
1176 {
1177 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1178 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1179 __m512i d = _mm512_load_si512((__m512i*)dp);
1180 __m512i t = _mm512_add_epi32(s1, s2);
1181 __m512i u = _mm512_mullo_epi32(va, t);
1182 __m512i v = _mm512_add_epi32(vb, u);
1183 __m512i w = _mm512_srai_epi32(v, e);
1184 d = _mm512_sub_epi32(d, w);
1185 _mm512_store_si512((__m512i*)dp, d);
1186 }
1187 else
1188 for (; i > 0; i -= 16, sp += 16, dp += 16)
1189 {
1190 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1191 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1192 __m512i d = _mm512_load_si512((__m512i*)dp);
1193 __m512i t = _mm512_add_epi32(s1, s2);
1194 __m512i u = _mm512_mullo_epi32(va, t);
1195 __m512i v = _mm512_add_epi32(vb, u);
1196 __m512i w = _mm512_srai_epi32(v, e);
1197 d = _mm512_sub_epi32(d, w);
1198 _mm512_store_si512((__m512i*)dp, d);
1199 }
1200 }
1201
1202 // swap buffers
1203 si32* t = aug; aug = oth; oth = t;
1204 ev = !ev;
1205 ui32 w = aug_width; aug_width = oth_width; oth_width = w;
1206 }
1207
1208 // combine both lsrc and hsrc into dst
1209 {
1210 float* dp = dst->f32;
1211 float* spl = even ? lsrc->f32 : hsrc->f32;
1212 float* sph = even ? hsrc->f32 : lsrc->f32;
1213 int w = (int)width;
1214 avx512_interleave32(dp, spl, sph, w);
1215 }
1216 }
1217 else {
1218 if (even)
1219 dst->i32[0] = lsrc->i32[0];
1220 else
1221 dst->i32[0] = hsrc->i32[0] >> 1;
1222 }
1223 }
1224
1226 void avx512_rev_horz_syn64(const param_atk* atk, const line_buf* dst,
1227 const line_buf* lsrc, const line_buf* hsrc,
1228 ui32 width, bool even)
1229 {
1230 if (width > 1)
1231 {
1232 bool ev = even;
1233 si64* oth = hsrc->i64, * aug = lsrc->i64;
1234 ui32 aug_width = (width + (even ? 1 : 0)) >> 1; // low pass
1235 ui32 oth_width = (width + (even ? 0 : 1)) >> 1; // high pass
1236 ui32 num_steps = atk->get_num_steps();
1237 for (ui32 j = 0; j < num_steps; ++j)
1238 {
1239 const lifting_step* s = atk->get_step(j);
1240 const si32 a = s->rev.Aatk;
1241 const si32 b = s->rev.Batk;
1242 const ui8 e = s->rev.Eatk;
1243 __m512i vb = _mm512_set1_epi64(b);
1244
1245 // extension
1246 oth[-1] = oth[0];
1247 oth[oth_width] = oth[oth_width - 1];
1248 // lifting step
1249 const si64* sp = oth;
1250 si64* dp = aug;
1251 if (a == 1)
1252 { // 5/3 update and any case with a == 1
1253 int i = (int)aug_width;
1254 if (ev)
1255 {
1256 for (; i > 0; i -= 8, sp += 8, dp += 8)
1257 {
1258 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1259 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1260 __m512i d = _mm512_load_si512((__m512i*)dp);
1261 __m512i t = _mm512_add_epi64(s1, s2);
1262 __m512i v = _mm512_add_epi64(vb, t);
1263 __m512i w = _mm512_srai_epi64(v, e);
1264 d = _mm512_sub_epi64(d, w);
1265 _mm512_store_si512((__m512i*)dp, d);
1266 }
1267 }
1268 else
1269 {
1270 for (; i > 0; i -= 8, sp += 8, dp += 8)
1271 {
1272 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1273 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1274 __m512i d = _mm512_load_si512((__m512i*)dp);
1275 __m512i t = _mm512_add_epi64(s1, s2);
1276 __m512i v = _mm512_add_epi64(vb, t);
1277 __m512i w = _mm512_srai_epi64(v, e);
1278 d = _mm512_sub_epi64(d, w);
1279 _mm512_store_si512((__m512i*)dp, d);
1280 }
1281 }
1282 }
1283 else if (a == -1 && b == 1 && e == 1)
1284 { // 5/3 predict
1285 int i = (int)aug_width;
1286 if (ev)
1287 for (; i > 0; i -= 8, sp += 8, dp += 8)
1288 {
1289 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1290 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1291 __m512i d = _mm512_load_si512((__m512i*)dp);
1292 __m512i t = _mm512_add_epi64(s1, s2);
1293 __m512i w = _mm512_srai_epi64(t, e);
1294 d = _mm512_add_epi64(d, w);
1295 _mm512_store_si512((__m512i*)dp, d);
1296 }
1297 else
1298 for (; i > 0; i -= 8, sp += 8, dp += 8)
1299 {
1300 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1301 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1302 __m512i d = _mm512_load_si512((__m512i*)dp);
1303 __m512i t = _mm512_add_epi64(s1, s2);
1304 __m512i w = _mm512_srai_epi64(t, e);
1305 d = _mm512_add_epi64(d, w);
1306 _mm512_store_si512((__m512i*)dp, d);
1307 }
1308 }
1309 else if (a == -1)
1310 { // any case with a == -1, which is not 5/3 predict
1311 int i = (int)aug_width;
1312 if (ev)
1313 for (; i > 0; i -= 8, sp += 8, dp += 8)
1314 {
1315 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1316 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1317 __m512i d = _mm512_load_si512((__m512i*)dp);
1318 __m512i t = _mm512_add_epi64(s1, s2);
1319 __m512i v = _mm512_sub_epi64(vb, t);
1320 __m512i w = _mm512_srai_epi64(v, e);
1321 d = _mm512_sub_epi64(d, w);
1322 _mm512_store_si512((__m512i*)dp, d);
1323 }
1324 else
1325 for (; i > 0; i -= 8, sp += 8, dp += 8)
1326 {
1327 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1328 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1329 __m512i d = _mm512_load_si512((__m512i*)dp);
1330 __m512i t = _mm512_add_epi64(s1, s2);
1331 __m512i v = _mm512_sub_epi64(vb, t);
1332 __m512i w = _mm512_srai_epi64(v, e);
1333 d = _mm512_sub_epi64(d, w);
1334 _mm512_store_si512((__m512i*)dp, d);
1335 }
1336 }
1337 else
1338 {
1339 // general case
1340 // 64-bit multiplication is not supported in AVX512F + AVX512CD;
1341 // in particular, _mm512_mullo_epi64 requires AVX512DQ.
1342 if (ev)
1343 for (ui32 i = aug_width; i > 0; --i, sp++, dp++)
1344 *dp -= (b + a * (sp[-1] + sp[0])) >> e;
1345 else
1346 for (ui32 i = aug_width; i > 0; --i, sp++, dp++)
1347 *dp -= (b + a * (sp[0] + sp[1])) >> e;
1348 }
1349
1350 // This can only be used if you have AVX512DQ
1351 // {
1352 // // general case
1353 // __m512i va = _mm512_set1_epi64(a);
1354 // int i = (int)aug_width;
1355 // if (ev)
1356 // for (; i > 0; i -= 8, sp += 8, dp += 8)
1357 // {
1358 // __m512i s1 = _mm512_load_si512((__m512i*)sp);
1359 // __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1360 // __m512i d = _mm512_load_si512((__m512i*)dp);
1361 // __m512i t = _mm512_add_epi64(s1, s2);
1362 // __m512i u = _mm512_mullo_epi64(va, t);
1363 // __m512i v = _mm512_add_epi64(vb, u);
1364 // __m512i w = _mm512_srai_epi64(v, e);
1365 // d = _mm512_sub_epi64(d, w);
1366 // _mm512_store_si512((__m512i*)dp, d);
1367 // }
1368 // else
1369 // for (; i > 0; i -= 8, sp += 8, dp += 8)
1370 // {
1371 // __m512i s1 = _mm512_load_si512((__m512i*)sp);
1372 // __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1373 // __m512i d = _mm512_load_si512((__m512i*)dp);
1374 // __m512i t = _mm512_add_epi64(s1, s2);
1375 // __m512i u = _mm512_mullo_epi64(va, t);
1376 // __m512i v = _mm512_add_epi64(vb, u);
1377 // __m512i w = _mm512_srai_epi64(v, e);
1378 // d = _mm512_sub_epi64(d, w);
1379 // _mm512_store_si512((__m512i*)dp, d);
1380 // }
1381 // }
1382
1383 // swap buffers
1384 si64* t = aug; aug = oth; oth = t;
1385 ev = !ev;
1386 ui32 w = aug_width; aug_width = oth_width; oth_width = w;
1387 }
1388
1389 // combine both lsrc and hsrc into dst
1390 {
1391 double* dp = (double*)(dst->p);
1392 double* spl = (double*)(even ? lsrc->p : hsrc->p);
1393 double* sph = (double*)(even ? hsrc->p : lsrc->p);
1394 int w = (int)width;
1395 avx512_interleave64(dp, spl, sph, w);
1396 }
1397 }
1398 else {
1399 if (even)
1400 dst->i64[0] = lsrc->i64[0];
1401 else
1402 dst->i64[0] = hsrc->i64[0] >> 1;
1403 }
1404 }
1405
1407 void avx512_rev_horz_syn(const param_atk* atk, const line_buf* dst,
1408 const line_buf* lsrc, const line_buf* hsrc,
1409 ui32 width, bool even)
1410 {
1411 if (dst->flags & line_buf::LFT_32BIT)
1412 {
1413 assert((lsrc == NULL || lsrc->flags & line_buf::LFT_32BIT) &&
1414 (hsrc == NULL || hsrc->flags & line_buf::LFT_32BIT));
1415 avx512_rev_horz_syn32(atk, dst, lsrc, hsrc, width, even);
1416 }
1417 else
1418 {
1419 assert((dst == NULL || dst->flags & line_buf::LFT_64BIT) &&
1420 (lsrc == NULL || lsrc->flags & line_buf::LFT_64BIT) &&
1421 (hsrc == NULL || hsrc->flags & line_buf::LFT_64BIT));
1422 avx512_rev_horz_syn64(atk, dst, lsrc, hsrc, width, even);
1423 }
1424 }
1425
1426 } // !local
1427} // !ojph
1428
1429#endif