[FFmpeg-devel] [PATCH v2] mdct15: add inverse transform postrotation SIMD

James Almer jamrial at gmail.com
Sun Jul 30 04:30:44 EEST 2017


On 7/29/2017 9:48 PM, Rostislav Pehlivanov wrote:
> Speeds up decoding by 8% in total in the avx2 case.
> 
> 20ms frames:
> Before   (c):  17774 decicycles in postrotate,  262065 runs,     79 skips
> After (sse3):   9624 decicycles in postrotate,  262113 runs,     31 skips
> After (avx2):   7169 decicycles in postrotate,  262104 runs,     40 skips
> 
> 10ms frames:
> Before   (c):   9058 decicycles in postrotate,  524209 runs,     79 skips
> After (sse3):   4964 decicycles in postrotate,  524236 runs,     52 skips
> After (avx2):   3915 decicycles in postrotate,  524236 runs,     52 skips
> 
> 5ms frames:
> Before   (c):   4764 decicycles in postrotate, 1048466 runs,    110 skips
> After (sse3):   2670 decicycles in postrotate, 1048507 runs,     69 skips
> After (avx2):   2161 decicycles in postrotate, 1048515 runs,     61 skips
> 
> 2.5ms frames:
> Before   (c):   2608 decicycles in postrotate, 2097030 runs,    122 skips
> After (sse3):   1507 decicycles in postrotate, 2097089 runs,     63 skips
> After (avx2):   1377 decicycles in postrotate, 2097097 runs,     55 skips
> 
> The SIMD code needs to overwrite the start of some buffers as well as
> the end of them, hence the OVERALLOC/OVERFREEP macros.
> 
> Signed-off-by: Rostislav Pehlivanov <atomnuker at gmail.com>
> ---
>  libavcodec/mdct15.c          | 74 ++++++++++++++++++++++++-----------
>  libavcodec/mdct15.h          |  3 ++
>  libavcodec/x86/mdct15.asm    | 93 +++++++++++++++++++++++++++++++++++++++++++-
>  libavcodec/x86/mdct15_init.c |  9 +++++
>  4 files changed, 155 insertions(+), 24 deletions(-)
> 
> diff --git a/libavcodec/mdct15.c b/libavcodec/mdct15.c
> index d68372c344..9838082c7e 100644
> --- a/libavcodec/mdct15.c
> +++ b/libavcodec/mdct15.c
> @@ -28,6 +28,7 @@
>  #include <math.h>
>  #include <stddef.h>
>  
> +#include "avcodec.h"
>  #include "config.h"
>  
>  #include "libavutil/attributes.h"
> @@ -40,6 +41,25 @@
>  
>  #define CMUL3(c, a, b) CMUL((c).re, (c).im, (a).re, (a).im, (b).re, (b).im)
>  
> +#define OVERALLOC(val, len, size)                                           \
> +    {                                                                       \
> +        const int pad = AV_INPUT_BUFFER_PADDING_SIZE/size;                  \
> +        (val) = NULL;                                                       \
> +        uint8_t *temp = av_mallocz_array(len + pad, size);                  \
> +        if (temp)                                                           \
> +            (val) = (void *)(temp + AV_INPUT_BUFFER_PADDING_SIZE);          \
> +    }
> +
> +#define OVERFREEP(val)                                                      \
> +    {                                                                       \
> +        uint8_t *temp = (uint8_t *)(val);                                   \
> +        if (temp) {                                                         \
> +            temp -= AV_INPUT_BUFFER_PADDING_SIZE;                           \
> +            av_free(temp);                                                  \
> +        }                                                                   \
> +        val = NULL;                                                         \
> +    }
> +
>  av_cold void ff_mdct15_uninit(MDCT15Context **ps)
>  {
>      MDCT15Context *s = *ps;
> @@ -50,9 +70,9 @@ av_cold void ff_mdct15_uninit(MDCT15Context **ps)
>      ff_fft_end(&s->ptwo_fft);
>  
>      av_freep(&s->pfa_prereindex);
> -    av_freep(&s->pfa_postreindex);
> -    av_freep(&s->twiddle_exptab);
> -    av_freep(&s->tmp);
> +    OVERFREEP(s->pfa_postreindex);
> +    OVERFREEP(s->twiddle_exptab);
> +    OVERFREEP(s->tmp);
>  
>      av_freep(ps);
>  }
> @@ -65,11 +85,11 @@ static inline int init_pfa_reindex_tabs(MDCT15Context *s)
>      const int inv_1 = l_ptwo << ((4 - b_ptwo) & 3); /* (2^b_ptwo)^-1 mod 15 */
>      const int inv_2 = 0xeeeeeeef & ((1U << b_ptwo) - 1); /* 15^-1 mod 2^b_ptwo */
>  
> -    s->pfa_prereindex = av_malloc(15 * l_ptwo * sizeof(*s->pfa_prereindex));
> +    s->pfa_prereindex = av_malloc_array(15 * l_ptwo, sizeof(*s->pfa_prereindex));
>      if (!s->pfa_prereindex)
>          return 1;
>  
> -    s->pfa_postreindex = av_malloc(15 * l_ptwo * sizeof(*s->pfa_postreindex));
> +    OVERALLOC(s->pfa_postreindex, 15 * l_ptwo, sizeof(*s->pfa_postreindex));
>      if (!s->pfa_postreindex)
>          return 1;
>  
> @@ -203,6 +223,21 @@ static void mdct15(MDCT15Context *s, float *dst, const float *src, ptrdiff_t str
>      }
>  }
>  
> +static void postrotate_c(FFTComplex *out, FFTComplex *in, FFTComplex *exp,
> +                         int *lut, ptrdiff_t len8)
> +{
> +    int i;
> +
> +    /* Reindex again, apply twiddles and output */
> +    for (i = 0; i < len8; i++) {
> +        const int i0 = len8 + i, i1 = len8 - i - 1;
> +        const int s0 = lut[i0], s1 = lut[i1];
> +
> +        CMUL(out[i1].re, out[i0].im, in[s1].im, in[s1].re, exp[i1].im, exp[i1].re);
> +        CMUL(out[i0].re, out[i1].im, in[s0].im, in[s0].re, exp[i0].im, exp[i0].re);
> +    }
> +}
> +
>  static void imdct15_half(MDCT15Context *s, float *dst, const float *src,
>                           ptrdiff_t stride)
>  {
> @@ -226,15 +261,7 @@ static void imdct15_half(MDCT15Context *s, float *dst, const float *src,
>          s->ptwo_fft.fft_calc(&s->ptwo_fft, s->tmp + l_ptwo*i);
>  
>      /* Reindex again, apply twiddles and output */
> -    for (i = 0; i < len8; i++) {
> -        const int i0 = len8 + i, i1 = len8 - i - 1;
> -        const int s0 = s->pfa_postreindex[i0], s1 = s->pfa_postreindex[i1];
> -
> -        CMUL(z[i1].re, z[i0].im, s->tmp[s1].im, s->tmp[s1].re,
> -             s->twiddle_exptab[i1].im, s->twiddle_exptab[i1].re);
> -        CMUL(z[i0].re, z[i1].im, s->tmp[s0].im, s->tmp[s0].re,
> -             s->twiddle_exptab[i0].im, s->twiddle_exptab[i0].re);
> -    }
> +    s->postreindex(z, s->tmp, s->twiddle_exptab, s->pfa_postreindex, len8);
>  }
>  
>  av_cold int ff_mdct15_init(MDCT15Context **ps, int inverse, int N, double scale)
> @@ -253,13 +280,14 @@ av_cold int ff_mdct15_init(MDCT15Context **ps, int inverse, int N, double scale)
>      if (!s)
>          return AVERROR(ENOMEM);
>  
> -    s->fft_n      = N - 1;
> -    s->len4       = len2 / 2;
> -    s->len2       = len2;
> -    s->inverse    = inverse;
> -    s->fft15      = fft15_c;
> -    s->mdct       = mdct15;
> -    s->imdct_half = imdct15_half;
> +    s->fft_n       = N - 1;
> +    s->len4        = len2 / 2;
> +    s->len2        = len2;
> +    s->inverse     = inverse;
> +    s->fft15       = fft15_c;
> +    s->mdct        = mdct15;
> +    s->imdct_half  = imdct15_half;
> +    s->postreindex = postrotate_c;
>  
>      if (ff_fft_init(&s->ptwo_fft, N - 1, s->inverse) < 0)
>          goto fail;
> @@ -267,11 +295,11 @@ av_cold int ff_mdct15_init(MDCT15Context **ps, int inverse, int N, double scale)
>      if (init_pfa_reindex_tabs(s))
>          goto fail;
>  
> -    s->tmp  = av_malloc_array(len, 2 * sizeof(*s->tmp));
> +    OVERALLOC(s->tmp, 2*len, sizeof(*s->tmp));
>      if (!s->tmp)
>          goto fail;
>  
> -    s->twiddle_exptab  = av_malloc_array(s->len4, sizeof(*s->twiddle_exptab));
> +    OVERALLOC(s->twiddle_exptab, s->len4, sizeof(*s->twiddle_exptab));
>      if (!s->twiddle_exptab)
>          goto fail;
>  
> diff --git a/libavcodec/mdct15.h b/libavcodec/mdct15.h
> index 1c2149d436..42e60f3e10 100644
> --- a/libavcodec/mdct15.h
> +++ b/libavcodec/mdct15.h
> @@ -42,6 +42,9 @@ typedef struct MDCT15Context {
>      /* 15-point FFT */
>      void (*fft15)(FFTComplex *out, FFTComplex *in, FFTComplex *exptab, ptrdiff_t stride);
>  
> +    /* PFA postrotate and exptab */
> +    void (*postreindex)(FFTComplex *out, FFTComplex *in, FFTComplex *exp, int *lut, ptrdiff_t len8);
> +
>      /* Calculate a full 2N -> N MDCT */
>      void (*mdct)(struct MDCT15Context *s, float *dst, const float *src, ptrdiff_t stride);
>  
> diff --git a/libavcodec/x86/mdct15.asm b/libavcodec/x86/mdct15.asm
> index f8b895944d..b42adb4aa9 100644
> --- a/libavcodec/x86/mdct15.asm
> +++ b/libavcodec/x86/mdct15.asm
> @@ -24,7 +24,11 @@
>  
>  %if ARCH_X86_64
>  
> -SECTION_RODATA
> +SECTION_RODATA 32
> +
> +perm_neg: dd 2, 5, 3, 4, 6, 1, 7, 0
> +perm_pos: dd 0, 7, 1, 6, 4, 3, 5, 2
> +sign_adjust_r: times 4 dd 0x80000000, 0x00000000
>  
>  sign_adjust_5: dd 0x00000000, 0x80000000, 0x80000000, 0x00000000
>  
> @@ -138,4 +142,91 @@ cglobal fft15, 4, 6, 14, out, in, exptab, stride, stride3, stride5
>  
>      RET
>  
> +%macro LUT_LOAD_4D 3
> +    mov      r7d, [lutq + %3q*4 +  0]
> +    movsd  xmm%1, [inq +  r7q*8]
> +    mov      r7d, [lutq + %3q*4 +  4]
> +    movhps xmm%1, [inq +  r7q*8]
> +%if cpuflag(avx2)
> +    mov      r7d, [lutq + %3q*4 +  8]
> +    movsd     %2, [inq +  r7q*8]
> +    mov      r7d, [lutq + %3q*4 + 12]
> +    movhps    %2, [inq +  r7q*8]
> +    vinsertf128 %1, %1, %2, 1
> +%endif
> +%endmacro
> +
> +%macro POSTROTATE_FN 0
> +;**********************************************************************************************************
> +;void ff_mdct15_postreindex(FFTComplex *out, FFTComplex *in, FFTComplex *exp, uint32_t *lut, int64_t len8);
> +;**********************************************************************************************************

Nit: Move this comment block above the LUT_LOAD_4D macro, so it's clear where
all the postreindex code starts.
Also, you forgot to replace uint32_t and int64_t here with int and ptrdiff_t,
so the comment matches the C prototype.

> +cglobal mdct15_postreindex, 5, 8, 12, out, in, exp, lut, len8, offset_p, offset_n
> +%if cpuflag(avx2)
> +    %define INCREMENT 4
> +%else
> +    %define INCREMENT 2

You could make INCREMENT a POSTROTATE_FN macro argument instead.

> +%endif
> +
> +    mova m7, [perm_pos]
> +    mova m8, [perm_neg]
> +    mova m9, [sign_adjust_r]

Change these three to movaps, since initializing the functions with sse3
and avx2 makes mova/movu aliases of movdqa/movdqu.

> +
> +    mov offset_pq, len8q
> +    lea offset_nq, [len8q - INCREMENT]
> +
> +    shl len8q, 1
> +
> +    movups m10, [outq - mmsize]         ; backup from start - mmsize to start
> +    movups m11, [outq + len8q*8]        ; backup from end to end + mmsize
> +
> +.loop:
> +    movups m0, [expq + offset_pq*8]     ; exp[p0].re, exp[p0].im, exp[p1].re, exp[p1].im, exp[p2].re, exp[p2].im, exp[p3].re, exp[p3].im
> +    movups m1, [expq + offset_nq*8]     ; exp[n3].re, exp[n3].im, exp[n2].re, exp[n2].im, exp[n1].re, exp[n1].im, exp[n0].re, exp[n0].im
> +
> +    LUT_LOAD_4D m3, xmm4, offset_p      ; in[p0].re, in[p0].im, in[p1].re, in[p1].im, in[p2].re, in[p2].im, in[p3].re, in[p3].im
> +    LUT_LOAD_4D m4, xmm5, offset_n      ; in[n3].re, in[n3].im, in[n2].re, in[n2].im, in[n1].re, in[n1].im, in[n0].re, in[n0].im

Nit: use xm4 and xm5 instead of xmm4 and xmm5.

> +
> +    mulps m5, m3, m0                    ; in[p].reim * exp[p].reim
> +    mulps m6, m4, m1                    ; in[n].reim * exp[n].reim
> +
> +    xorps m5, m9                        ; in[p].re *= -1, in[p].im *= 1
> +    xorps m6, m9                        ; in[n].re *= -1, in[n].im *= 1
> +
> +    shufps m3, m3, m3, q2301            ; in[p].imre
> +    shufps m4, m4, m4, q2301            ; in[n].imre
> +
> +    mulps m3, m0                        ; in[p].imre * exp[p].reim
> +    mulps m4, m1                        ; in[n].imre * exp[n].reim
> +
> +    haddps m5, m4                       ; out[p0].re, out[p1].re, out[p3].im, out[p2].im, out[p2].re, out[p3].re, out[p1].im, out[p0].im
> +    haddps m3, m6                       ; out[n0].im, out[n1].im, out[n3].re, out[n2].re, out[n2].im, out[n3].im, out[n1].re, out[n0].re
> +
> +%if cpuflag(avx2)
> +    vpermps m5, m7, m5                  ; out[p0].re, out[p0].im, out[p1].re, out[p1].im, out[p2].re, out[p2].im, out[p3].re, out[p3].im
> +    vpermps m3, m8, m3                  ; out[n3].im, out[n3].re, out[n2].im, out[n2].re, out[n1].im, out[n1].re, out[n0].im, out[n0].re
> +%else
> +    shufps m5, m5, m5, q2130
> +    shufps m3, m3, m3, q0312
> +%endif
> +
> +    movups [outq + offset_pq*8], m5
> +    movups [outq + offset_nq*8], m3
> +
> +    sub offset_nq, INCREMENT
> +    add offset_pq, INCREMENT
> +
> +    cmp offset_pq, len8q
> +    jl .loop
> +
> +    movups [outq - mmsize],  m10
> +    movups [outq + len8q*8], m11
> +
> +    RET
> +%endmacro
> +
> +INIT_XMM sse3
> +POSTROTATE_FN
> +INIT_YMM avx2
> +POSTROTATE_FN

Wrap the two avx2 lines in an %if HAVE_AVX2_EXTERNAL check, or this will
fail to assemble with Yasm 1.1.0 and older.

> +
>  %endif
> diff --git a/libavcodec/x86/mdct15_init.c b/libavcodec/x86/mdct15_init.c
> index ba3d94c2ec..ec4ff42bb6 100644
> --- a/libavcodec/x86/mdct15_init.c
> +++ b/libavcodec/x86/mdct15_init.c
> @@ -25,6 +25,9 @@
>  #include "libavutil/x86/cpu.h"
>  #include "libavcodec/mdct15.h"
>  
> +void ff_mdct15_postreindex_sse3(FFTComplex *out, FFTComplex *in, FFTComplex *exp, int *lut, ptrdiff_t len8);
> +void ff_mdct15_postreindex_avx2(FFTComplex *out, FFTComplex *in, FFTComplex *exp, int *lut, ptrdiff_t len8);
> +
>  void ff_fft15_avx(FFTComplex *out, FFTComplex *in, FFTComplex *exptab, ptrdiff_t stride);
>  
>  static void perm_twiddles(MDCT15Context *s)
> @@ -90,6 +93,12 @@ av_cold void ff_mdct15_init_x86(MDCT15Context *s)
>          adjust_twiddles = 1;
>      }
>  
> +    if (ARCH_X86_64 && EXTERNAL_SSE3(cpu_flags))
> +        s->postreindex = ff_mdct15_postreindex_sse3;

The SSE3 check should go before the AVX one.

> +
> +    if (ARCH_X86_64 && EXTERNAL_AVX2(cpu_flags))

Use EXTERNAL_AVX2_FAST(cpu_flags) here instead of EXTERNAL_AVX2(cpu_flags).

> +        s->postreindex = ff_mdct15_postreindex_avx2;
> +
>      if (adjust_twiddles)
>          perm_twiddles(s);
>  }

Maybe poke Hendrik for his opinion, but it seems to work, so LGTM.


More information about the ffmpeg-devel mailing list