[FFmpeg-cvslog] avfilter/vf_nnedi: rewrite and cleanup code

Paul B Mahol git at videolan.org
Mon Jan 18 15:14:10 EET 2021


ffmpeg | branch: master | Paul B Mahol <onemda at gmail.com> | Sun Jan 17 17:39:28 2021 +0100| [117bf7394f7d5c47104bd30d141466decd01dda1] | committer: Paul B Mahol

avfilter/vf_nnedi: rewrite and cleanup code

Also add slice threading support.
Also add support for >8 bit depth formats.
Also add support for commands.
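
A quick illustration of the slice threading mentioned above: the new
filter_slice() (see the vf_nnedi.c diff below) hands each job an even number
of rows, so interpolated/copied field-line pairs never straddle a job
boundary. A minimal standalone sketch of that partition arithmetic, using the
same names as the patch (illustrative only, not part of the commit):

#include <stdio.h>

/* Same split as in filter_slice(): jobs always start and end on even rows. */
static void slice_bounds(int height, int jobnr, int nb_jobs,
                         int *slice_start, int *slice_end)
{
    *slice_start = 2 * ((height / 2 *  jobnr)      / nb_jobs);
    *slice_end   = 2 * ((height / 2 * (jobnr + 1)) / nb_jobs);
}

int main(void)
{
    for (int job = 0; job < 4; job++) {
        int start, end;

        slice_bounds(1080, job, 4, &start, &end);
        printf("job %d: rows [%d, %d)\n", job, start, end);
    }
    return 0;
}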

> http://git.videolan.org/gitweb.cgi/ffmpeg.git/?a=commit;h=117bf7394f7d5c47104bd30d141466decd01dda1
---

 doc/filters.texi       |   12 +-
 libavfilter/vf_nnedi.c | 1535 ++++++++++++++++++++++++------------------------
 2 files changed, 775 insertions(+), 772 deletions(-)
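
The vf_nnedi.c rewrite below introduces separate byte and word I/O callbacks
(read_bytes()/write_bytes() and read_words()/write_words()) so the network can
operate on formats deeper than 8 bits; word samples are scaled into float on
input and rescaled and clipped to the target depth on output. The scale values
are set up in config_input(), which is not part of this excerpt, so the
1/(2^depth - 1) normalization in this standalone sketch is an assumption for
illustration only:

#include <math.h>
#include <stdint.h>
#include <stdio.h>

/* Hedged sketch of the word path: normalize on read, rescale + clip on write.
 * Rounding is added here for the demo; the patch clips with av_clip_uintp2_c(). */
static float read_word(uint16_t sample, float in_scale)
{
    return sample * in_scale;
}

static uint16_t write_word(float value, float out_scale, int depth)
{
    long v   = lrintf(value * out_scale);
    long max = (1L << depth) - 1;

    return (uint16_t)(v < 0 ? 0 : v > max ? max : v);
}

int main(void)
{
    const int depth = 10;                             /* e.g. yuv420p10 */
    const float in_scale  = 1.0f / ((1 << depth) - 1);
    const float out_scale = (float)((1 << depth) - 1);
    const uint16_t px = 512;

    printf("%d -> %f -> %d\n", px, read_word(px, in_scale),
           write_word(read_word(px, in_scale), out_scale, depth));
    return 0;
}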

diff --git a/doc/filters.texi b/doc/filters.texi
index aa1389e407..3ce6699d7c 100644
--- a/doc/filters.texi
+++ b/doc/filters.texi
@@ -14649,9 +14649,9 @@ Set which set of weights to use in the predictor.
 Can be one of the following:
 
 @table @samp
-@item a
+@item a, abs
 weights trained to minimize absolute error
-@item s
+@item s, mse
 weights trained to minimize squared error
 @end table
 
@@ -14673,14 +14673,16 @@ Can be one of the following:
 @item none
 @item original
 @item new
+@item new2
+@item new3
 @end table
 
 Default is @code{new}.
-
-@item fapprox
-Set various debugging flags.
 @end table
 
+@subsection Commands
+This filter supports the same @ref{commands} as options, excluding the @var{weights} option.
+
 @section noformat
 
 Force libavfilter not to use any of the specified pixel formats for the
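
About the Commands subsection added above: in the vf_nnedi.c diff that
follows, the affected options are tagged with AV_OPT_FLAG_RUNTIME_PARAM (the
new RFLAGS macro). The AVFilter definition itself is outside this excerpt; the
usual way a filter exposes such options at runtime is the stock forwarding
handler, sketched here as an assumption rather than a quote from the commit:

/* Hedged sketch: ff_filter_process_command() (declared in libavfilter's
 * internal.h, which vf_nnedi.c already includes) forwards a command/argument
 * pair to av_opt_set() on the filter's private context, so every option
 * carrying AV_OPT_FLAG_RUNTIME_PARAM becomes adjustable mid-stream. */
AVFilter ff_vf_nnedi = {
    .name            = "nnedi",
    /* ... description, priv_size, pads and callbacks as in the existing
     * definition, which is not shown in this excerpt ... */
    .process_command = ff_filter_process_command,
};

Users can then drive these options while a graph is running, e.g. through the
sendcmd filter or avfilter_graph_send_command().
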
diff --git a/libavfilter/vf_nnedi.c b/libavfilter/vf_nnedi.c
index 33ff503d92..7f209cb68c 100644
--- a/libavfilter/vf_nnedi.c
+++ b/libavfilter/vf_nnedi.c
@@ -24,6 +24,7 @@
 #include "libavutil/common.h"
 #include "libavutil/float_dsp.h"
 #include "libavutil/imgutils.h"
+#include "libavutil/mem_internal.h"
 #include "libavutil/opt.h"
 #include "libavutil/pixdesc.h"
 #include "avfilter.h"
@@ -31,21 +32,45 @@
 #include "internal.h"
 #include "video.h"
 
-typedef struct FrameData {
-    uint8_t *paddedp[3];
-    int padded_stride[3];
-    int padded_width[3];
-    int padded_height[3];
-
-    uint8_t *dstp[3];
-    int dst_stride[3];
-
-    int field[3];
-
-    int32_t *lcount[3];
-    float *input;
-    float *temp;
-} FrameData;
+static const size_t NNEDI_WEIGHTS_SIZE = 13574928;
+static const uint8_t NNEDI_XDIM[] = { 8, 16, 32, 48, 8, 16, 32 };
+static const uint8_t NNEDI_YDIM[] = { 6, 6, 6, 6, 4, 4, 4 };
+static const uint16_t NNEDI_NNS[] = { 16, 32, 64, 128, 256 };
+
+static const unsigned NNEDI_DIMS0 = 49 * 4 + 5 * 4 + 9 * 4;
+static const unsigned NNEDI_DIMS0_NEW = 4 * 65 + 4 * 5;
+
+typedef struct PrescreenerOldCoefficients {
+    DECLARE_ALIGNED(32, float, kernel_l0)[4][14 * 4];
+    float bias_l0[4];
+
+    DECLARE_ALIGNED(32, float, kernel_l1)[4][4];
+    float bias_l1[4];
+
+    DECLARE_ALIGNED(32, float, kernel_l2)[4][8];
+    float bias_l2[4];
+} PrescreenerOldCoefficients;
+
+typedef struct PrescreenerNewCoefficients {
+    DECLARE_ALIGNED(32, float, kernel_l0)[4][16 * 4];
+    float bias_l0[4];
+
+    DECLARE_ALIGNED(32, float, kernel_l1)[4][4];
+    float bias_l1[4];
+} PrescreenerNewCoefficients;
+
+typedef struct PredictorCoefficients {
+    int xdim, ydim, nns;
+    float *data;
+    float *softmax_q1;
+    float *elliott_q1;
+    float *softmax_bias_q1;
+    float *elliott_bias_q1;
+    float *softmax_q2;
+    float *elliott_q2;
+    float *softmax_bias_q2;
+    float *elliott_bias_q2;
+} PredictorCoefficients;
 
 typedef struct NNEDIContext {
     const AVClass *class;
@@ -59,16 +84,21 @@ typedef struct NNEDIContext {
     int64_t cur_pts;
 
     AVFloatDSPContext *fdsp;
+    int depth;
     int nb_planes;
+    int nb_threads;
     int linesize[4];
+    int planewidth[4];
     int planeheight[4];
+    int field_n;
+
+    PrescreenerOldCoefficients prescreener_old;
+    PrescreenerNewCoefficients prescreener_new[3];
+    PredictorCoefficients coeffs[2][5][7];
 
-    float *weights0;
-    float *weights1[2];
-    int asize;
-    int nns;
-    int xdia;
-    int ydia;
+    float half;
+    float in_scale;
+    float out_scale;
 
     // Parameters
     int deint;
@@ -79,104 +109,84 @@ typedef struct NNEDIContext {
     int qual;
     int etype;
     int pscrn;
-    int fapprox;
-
-    int max_value;
-
-    void (*copy_pad)(const AVFrame *, FrameData *, struct NNEDIContext *, int);
-    void (*evalfunc_0)(struct NNEDIContext *, FrameData *);
-    void (*evalfunc_1)(struct NNEDIContext *, FrameData *);
-
-    // Functions used in evalfunc_0
-    void (*readpixels)(const uint8_t *, const int, float *);
-    void (*compute_network0)(struct NNEDIContext *s, const float *, const float *, uint8_t *);
-    int32_t (*process_line0)(const uint8_t *, int, uint8_t *, const uint8_t *, const int, const int, const int);
-
-    // Functions used in evalfunc_1
-    void (*extract)(const uint8_t *, const int, const int, const int, float *, float *);
-    void (*dot_prod)(struct NNEDIContext *, const float *, const float *, float *, const int, const int, const float *);
-    void (*expfunc)(float *, const int);
-    void (*wae5)(const float *, const int, float *);
 
-    FrameData frame_data;
+    int input_size;
+    uint8_t *prescreen_buf;
+    float *input_buf;
+    float *output_buf;
+
+    void (*read)(const uint8_t *src, float *dst,
+                 int src_stride, int dst_stride,
+                 int width, int height, float scale);
+    void (*write)(const float *src, uint8_t *dst,
+                  int src_stride, int dst_stride,
+                  int width, int height, int depth, float scale);
+    void (*prescreen[2])(AVFilterContext *ctx,
+                         const void *src, ptrdiff_t src_stride,
+                         uint8_t *prescreen, int N, void *data);
 } NNEDIContext;
 
 #define OFFSET(x) offsetof(NNEDIContext, x)
+#define RFLAGS AV_OPT_FLAG_VIDEO_PARAM|AV_OPT_FLAG_FILTERING_PARAM|AV_OPT_FLAG_RUNTIME_PARAM
 #define FLAGS AV_OPT_FLAG_VIDEO_PARAM|AV_OPT_FLAG_FILTERING_PARAM
 
 static const AVOption nnedi_options[] = {
     {"weights",  "set weights file", OFFSET(weights_file),  AV_OPT_TYPE_STRING, {.str="nnedi3_weights.bin"}, 0, 0, FLAGS },
-    {"deint",         "set which frames to deinterlace", OFFSET(deint),         AV_OPT_TYPE_INT, {.i64=0}, 0, 1, FLAGS, "deint" },
-        {"all",        "deinterlace all frames",                       0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, FLAGS, "deint" },
-        {"interlaced", "only deinterlace frames marked as interlaced", 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, FLAGS, "deint" },
-    {"field",  "set mode of operation", OFFSET(field),         AV_OPT_TYPE_INT, {.i64=-1}, -2, 3, FLAGS, "field" },
-        {"af", "use frame flags, both fields",  0, AV_OPT_TYPE_CONST, {.i64=-2}, 0, 0, FLAGS, "field" },
-        {"a",  "use frame flags, single field", 0, AV_OPT_TYPE_CONST, {.i64=-1}, 0, 0, FLAGS, "field" },
-        {"t",  "use top field only",            0, AV_OPT_TYPE_CONST, {.i64=0},  0, 0, FLAGS, "field" },
-        {"b",  "use bottom field only",         0, AV_OPT_TYPE_CONST, {.i64=1},  0, 0, FLAGS, "field" },
-        {"tf", "use both fields, top first",    0, AV_OPT_TYPE_CONST, {.i64=2}, 0, 0, FLAGS, "field" },
-        {"bf", "use both fields, bottom first", 0, AV_OPT_TYPE_CONST, {.i64=3}, 0, 0, FLAGS, "field" },
-    {"planes", "set which planes to process", OFFSET(process_plane), AV_OPT_TYPE_INT, {.i64=7}, 0, 7, FLAGS },
-    {"nsize",  "set size of local neighborhood around each pixel, used by the predictor neural network", OFFSET(nsize), AV_OPT_TYPE_INT, {.i64=6}, 0, 6, FLAGS, "nsize" },
-        {"s8x6",     NULL, 0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, FLAGS, "nsize" },
-        {"s16x6",    NULL, 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, FLAGS, "nsize" },
-        {"s32x6",    NULL, 0, AV_OPT_TYPE_CONST, {.i64=2}, 0, 0, FLAGS, "nsize" },
-        {"s48x6",    NULL, 0, AV_OPT_TYPE_CONST, {.i64=3}, 0, 0, FLAGS, "nsize" },
-        {"s8x4",     NULL, 0, AV_OPT_TYPE_CONST, {.i64=4}, 0, 0, FLAGS, "nsize" },
-        {"s16x4",    NULL, 0, AV_OPT_TYPE_CONST, {.i64=5}, 0, 0, FLAGS, "nsize" },
-        {"s32x4",    NULL, 0, AV_OPT_TYPE_CONST, {.i64=6}, 0, 0, FLAGS, "nsize" },
-    {"nns",    "set number of neurons in predictor neural network", OFFSET(nnsparam), AV_OPT_TYPE_INT, {.i64=1}, 0, 4, FLAGS, "nns" },
-        {"n16",       NULL, 0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, FLAGS, "nns" },
-        {"n32",       NULL, 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, FLAGS, "nns" },
-        {"n64",       NULL, 0, AV_OPT_TYPE_CONST, {.i64=2}, 0, 0, FLAGS, "nns" },
-        {"n128",      NULL, 0, AV_OPT_TYPE_CONST, {.i64=3}, 0, 0, FLAGS, "nns" },
-        {"n256",      NULL, 0, AV_OPT_TYPE_CONST, {.i64=4}, 0, 0, FLAGS, "nns" },
-    {"qual",  "set quality", OFFSET(qual), AV_OPT_TYPE_INT, {.i64=1}, 1, 2, FLAGS, "qual" },
-        {"fast", NULL, 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, FLAGS, "qual" },
-        {"slow", NULL, 0, AV_OPT_TYPE_CONST, {.i64=2}, 0, 0, FLAGS, "qual" },
-    {"etype", "set which set of weights to use in the predictor", OFFSET(etype), AV_OPT_TYPE_INT, {.i64=0}, 0, 1, FLAGS, "etype" },
-        {"a",  "weights trained to minimize absolute error", 0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, FLAGS, "etype" },
-        {"s",  "weights trained to minimize squared error",  0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, FLAGS, "etype" },
-    {"pscrn", "set prescreening", OFFSET(pscrn), AV_OPT_TYPE_INT, {.i64=2}, 0, 2, FLAGS, "pscrn" },
-        {"none",      NULL, 0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, FLAGS, "pscrn" },
-        {"original",  NULL, 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, FLAGS, "pscrn" },
-        {"new",       NULL, 0, AV_OPT_TYPE_CONST, {.i64=2}, 0, 0, FLAGS, "pscrn" },
-    {"fapprox",       NULL, OFFSET(fapprox),       AV_OPT_TYPE_INT, {.i64=0}, 0, 3, FLAGS },
+    {"deint",         "set which frames to deinterlace", OFFSET(deint),         AV_OPT_TYPE_INT, {.i64=0}, 0, 1, RFLAGS, "deint" },
+        {"all",        "deinterlace all frames",                       0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, RFLAGS, "deint" },
+        {"interlaced", "only deinterlace frames marked as interlaced", 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, RFLAGS, "deint" },
+    {"field",  "set mode of operation", OFFSET(field),         AV_OPT_TYPE_INT, {.i64=-1}, -2, 3, RFLAGS, "field" },
+        {"af", "use frame flags, both fields",  0, AV_OPT_TYPE_CONST, {.i64=-2}, 0, 0, RFLAGS, "field" },
+        {"a",  "use frame flags, single field", 0, AV_OPT_TYPE_CONST, {.i64=-1}, 0, 0, RFLAGS, "field" },
+        {"t",  "use top field only",            0, AV_OPT_TYPE_CONST, {.i64=0},  0, 0, RFLAGS, "field" },
+        {"b",  "use bottom field only",         0, AV_OPT_TYPE_CONST, {.i64=1},  0, 0, RFLAGS, "field" },
+        {"tf", "use both fields, top first",    0, AV_OPT_TYPE_CONST, {.i64=2},  0, 0, RFLAGS, "field" },
+        {"bf", "use both fields, bottom first", 0, AV_OPT_TYPE_CONST, {.i64=3},  0, 0, RFLAGS, "field" },
+    {"planes", "set which planes to process", OFFSET(process_plane), AV_OPT_TYPE_INT, {.i64=7}, 0, 15, RFLAGS },
+    {"nsize",  "set size of local neighborhood around each pixel, used by the predictor neural network", OFFSET(nsize), AV_OPT_TYPE_INT, {.i64=6}, 0, 6, RFLAGS, "nsize" },
+        {"s8x6",     NULL, 0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, RFLAGS, "nsize" },
+        {"s16x6",    NULL, 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, RFLAGS, "nsize" },
+        {"s32x6",    NULL, 0, AV_OPT_TYPE_CONST, {.i64=2}, 0, 0, RFLAGS, "nsize" },
+        {"s48x6",    NULL, 0, AV_OPT_TYPE_CONST, {.i64=3}, 0, 0, RFLAGS, "nsize" },
+        {"s8x4",     NULL, 0, AV_OPT_TYPE_CONST, {.i64=4}, 0, 0, RFLAGS, "nsize" },
+        {"s16x4",    NULL, 0, AV_OPT_TYPE_CONST, {.i64=5}, 0, 0, RFLAGS, "nsize" },
+        {"s32x4",    NULL, 0, AV_OPT_TYPE_CONST, {.i64=6}, 0, 0, RFLAGS, "nsize" },
+    {"nns",    "set number of neurons in predictor neural network", OFFSET(nnsparam), AV_OPT_TYPE_INT, {.i64=1}, 0, 4, RFLAGS, "nns" },
+        {"n16",       NULL, 0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, RFLAGS, "nns" },
+        {"n32",       NULL, 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, RFLAGS, "nns" },
+        {"n64",       NULL, 0, AV_OPT_TYPE_CONST, {.i64=2}, 0, 0, RFLAGS, "nns" },
+        {"n128",      NULL, 0, AV_OPT_TYPE_CONST, {.i64=3}, 0, 0, RFLAGS, "nns" },
+        {"n256",      NULL, 0, AV_OPT_TYPE_CONST, {.i64=4}, 0, 0, RFLAGS, "nns" },
+    {"qual",  "set quality", OFFSET(qual), AV_OPT_TYPE_INT, {.i64=1}, 1, 2, RFLAGS, "qual" },
+        {"fast", NULL, 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, RFLAGS, "qual" },
+        {"slow", NULL, 0, AV_OPT_TYPE_CONST, {.i64=2}, 0, 0, RFLAGS, "qual" },
+    {"etype", "set which set of weights to use in the predictor", OFFSET(etype), AV_OPT_TYPE_INT, {.i64=0}, 0, 1, RFLAGS, "etype" },
+        {"a",  "weights trained to minimize absolute error", 0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, RFLAGS, "etype" },
+        {"abs","weights trained to minimize absolute error", 0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, RFLAGS, "etype" },
+        {"s",  "weights trained to minimize squared error",  0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, RFLAGS, "etype" },
+        {"mse","weights trained to minimize squared error",  0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, RFLAGS, "etype" },
+    {"pscrn", "set prescreening", OFFSET(pscrn), AV_OPT_TYPE_INT, {.i64=2}, 0, 4, RFLAGS, "pscrn" },
+        {"none",      NULL, 0, AV_OPT_TYPE_CONST, {.i64=0}, 0, 0, RFLAGS, "pscrn" },
+        {"original",  NULL, 0, AV_OPT_TYPE_CONST, {.i64=1}, 0, 0, RFLAGS, "pscrn" },
+        {"new",       NULL, 0, AV_OPT_TYPE_CONST, {.i64=2}, 0, 0, RFLAGS, "pscrn" },
+        {"new2",      NULL, 0, AV_OPT_TYPE_CONST, {.i64=3}, 0, 0, RFLAGS, "pscrn" },
+        {"new3",      NULL, 0, AV_OPT_TYPE_CONST, {.i64=4}, 0, 0, RFLAGS, "pscrn" },
     { NULL }
 };
 
 AVFILTER_DEFINE_CLASS(nnedi);
 
-static int config_input(AVFilterLink *inlink)
-{
-    AVFilterContext *ctx = inlink->dst;
-    NNEDIContext *s = ctx->priv;
-    const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(inlink->format);
-    int ret;
-
-    s->nb_planes = av_pix_fmt_count_planes(inlink->format);
-    if ((ret = av_image_fill_linesizes(s->linesize, inlink->format, inlink->w)) < 0)
-        return ret;
-
-    s->planeheight[1] = s->planeheight[2] = AV_CEIL_RSHIFT(inlink->h, desc->log2_chroma_h);
-    s->planeheight[0] = s->planeheight[3] = inlink->h;
-
-    return 0;
-}
-
 static int config_output(AVFilterLink *outlink)
 {
     AVFilterContext *ctx = outlink->src;
-    NNEDIContext *s = ctx->priv;
 
     outlink->time_base.num = ctx->inputs[0]->time_base.num;
     outlink->time_base.den = ctx->inputs[0]->time_base.den * 2;
     outlink->w             = ctx->inputs[0]->w;
     outlink->h             = ctx->inputs[0]->h;
 
-    if (s->field > 1 || s->field == -2)
-        outlink->frame_rate = av_mul_q(ctx->inputs[0]->frame_rate,
-                                       (AVRational){2, 1});
+    outlink->frame_rate = av_mul_q(ctx->inputs[0]->frame_rate,
+                                   (AVRational){2, 1});
 
     return 0;
 }
@@ -184,14 +194,28 @@ static int config_output(AVFilterLink *outlink)
 static int query_formats(AVFilterContext *ctx)
 {
     static const enum AVPixelFormat pix_fmts[] = {
+        AV_PIX_FMT_GRAY8,
+        AV_PIX_FMT_GRAY9, AV_PIX_FMT_GRAY10, AV_PIX_FMT_GRAY12, AV_PIX_FMT_GRAY14, AV_PIX_FMT_GRAY16,
         AV_PIX_FMT_YUV410P, AV_PIX_FMT_YUV411P,
         AV_PIX_FMT_YUV420P, AV_PIX_FMT_YUV422P,
         AV_PIX_FMT_YUV440P, AV_PIX_FMT_YUV444P,
         AV_PIX_FMT_YUVJ444P, AV_PIX_FMT_YUVJ440P,
         AV_PIX_FMT_YUVJ422P, AV_PIX_FMT_YUVJ420P,
         AV_PIX_FMT_YUVJ411P,
-        AV_PIX_FMT_GBRP,
-        AV_PIX_FMT_GRAY8,
+        AV_PIX_FMT_YUVA420P, AV_PIX_FMT_YUVA422P, AV_PIX_FMT_YUVA444P,
+        AV_PIX_FMT_GBRP, AV_PIX_FMT_GBRAP,
+        AV_PIX_FMT_YUV420P9, AV_PIX_FMT_YUV422P9, AV_PIX_FMT_YUV444P9,
+        AV_PIX_FMT_YUV420P10, AV_PIX_FMT_YUV422P10, AV_PIX_FMT_YUV444P10,
+        AV_PIX_FMT_YUV440P10,
+        AV_PIX_FMT_YUV420P12, AV_PIX_FMT_YUV422P12, AV_PIX_FMT_YUV444P12,
+        AV_PIX_FMT_YUV440P12,
+        AV_PIX_FMT_YUV420P14, AV_PIX_FMT_YUV422P14, AV_PIX_FMT_YUV444P14,
+        AV_PIX_FMT_YUV420P16, AV_PIX_FMT_YUV422P16, AV_PIX_FMT_YUV444P16,
+        AV_PIX_FMT_GBRP9, AV_PIX_FMT_GBRP10, AV_PIX_FMT_GBRP12, AV_PIX_FMT_GBRP14, AV_PIX_FMT_GBRP16,
+        AV_PIX_FMT_YUVA444P9, AV_PIX_FMT_YUVA444P10, AV_PIX_FMT_YUVA444P12, AV_PIX_FMT_YUVA444P16,
+        AV_PIX_FMT_YUVA422P9, AV_PIX_FMT_YUVA422P10, AV_PIX_FMT_YUVA422P12, AV_PIX_FMT_YUVA422P16,
+        AV_PIX_FMT_YUVA420P9, AV_PIX_FMT_YUVA420P10, AV_PIX_FMT_YUVA420P16,
+        AV_PIX_FMT_GBRAP10,   AV_PIX_FMT_GBRAP12,    AV_PIX_FMT_GBRAP16,
         AV_PIX_FMT_NONE
     };
 
@@ -201,592 +225,480 @@ static int query_formats(AVFilterContext *ctx)
     return ff_set_common_formats(ctx, fmts_list);
 }
 
-static void copy_pad(const AVFrame *src, FrameData *frame_data, NNEDIContext *s, int fn)
+static float dot_dsp(NNEDIContext *s, const float *kernel, const float *input,
+                     unsigned n, float scale, float bias)
 {
-    const int off = 1 - fn;
-    int plane, y, x;
+    float sum;
 
-    for (plane = 0; plane < s->nb_planes; plane++) {
-        const uint8_t *srcp = (const uint8_t *)src->data[plane];
-        uint8_t *dstp = (uint8_t *)frame_data->paddedp[plane];
+    sum = s->fdsp->scalarproduct_float(kernel, input, n);
 
-        const int src_stride = src->linesize[plane];
-        const int dst_stride = frame_data->padded_stride[plane];
+    return sum * scale + bias;
+}
 
-        const int src_height = s->planeheight[plane];
-        const int dst_height = frame_data->padded_height[plane];
+static float dot_product(const float *kernel, const float *input,
+                         unsigned n, float scale, float bias)
+{
+    float sum = 0.0f;
 
-        const int src_width = s->linesize[plane];
-        const int dst_width = frame_data->padded_width[plane];
+    for (int i = 0; i < n; i++)
+        sum += kernel[i] * input[i];
 
-        int c = 4;
+    return sum * scale + bias;
+}
 
-        if (!(s->process_plane & (1 << plane)))
-            continue;
+static float elliott(float x)
+{
+    return x / (1.0f + fabsf(x));
+}
 
-        // Copy.
-        for (y = off; y < src_height; y += 2)
-            memcpy(dstp + 32 + (6 + y) * dst_stride,
-                   srcp + y * src_stride,
-                   src_width * sizeof(uint8_t));
+static void transform_elliott(float *input, int size)
+{
+    for (int i = 0; i < size; i++)
+        input[i] = elliott(input[i]);
+}
 
-        // And pad.
-        dstp += (6 + off) * dst_stride;
-        for (y = 6 + off; y < dst_height - 6; y += 2) {
-            int c = 2;
+static void process_old(AVFilterContext *ctx,
+                        const void *src, ptrdiff_t src_stride,
+                        uint8_t *prescreen, int N,
+                        void *data)
+{
+    NNEDIContext *s = ctx->priv;
+    PrescreenerOldCoefficients *m_data = data;
+    const float *src_p = src;
 
-            for (x = 0; x < 32; x++)
-                dstp[x] = dstp[64 - x];
+    // Adjust source pointer to point to top-left of filter window.
+    const float *window = src_p - 2 * src_stride - 5;
 
-            for (x = dst_width - 32; x < dst_width; x++, c += 2)
-                dstp[x] = dstp[x - c];
+    for (int j = 0; j < N; j++) {
+        LOCAL_ALIGNED_32(float, input, [48]);
+        float state[12];
 
-            dstp += dst_stride * 2;
-        }
+        for (int i = 0; i < 4; i++)
+            memcpy(input + i * 12, window + i * src_stride + j, 12 * sizeof(float));
 
-        dstp = (uint8_t *)frame_data->paddedp[plane];
-        for (y = off; y < 6; y += 2)
-            memcpy(dstp + y * dst_stride,
-                   dstp + (12 + 2 * off - y) * dst_stride,
-                   dst_width * sizeof(uint8_t));
+        // Layer 0.
+        for (int n = 0; n < 4; n++)
+            state[n] = dot_dsp(s, m_data->kernel_l0[n], input, 48, 1.0f, m_data->bias_l0[n]);
+        transform_elliott(state + 1, 3);
 
-        for (y = dst_height - 6 + off; y < dst_height; y += 2, c += 4)
-            memcpy(dstp + y * dst_stride,
-                   dstp + (y - c) * dst_stride,
-                   dst_width * sizeof(uint8_t));
-    }
-}
+        // Layer 1.
+        for (int n = 0; n < 4; n++)
+            state[n + 4] = dot_product(m_data->kernel_l1[n], state, 4, 1.0f, m_data->bias_l1[n]);
+        transform_elliott(state + 4, 3);
 
-static void elliott(float *data, const int n)
-{
-    int i;
+        // Layer 2.
+        for (int n = 0; n < 4; n++)
+            state[n + 8] = dot_product(m_data->kernel_l2[n], state, 8, 1.0f, m_data->bias_l2[n]);
 
-    for (i = 0; i < n; i++)
-        data[i] = data[i] / (1.0f + FFABS(data[i]));
+        prescreen[j] = FFMAX(state[10], state[11]) <= FFMAX(state[8], state[9]) ? 255 : 0;
+    }
 }
 
-static void dot_prod(NNEDIContext *s, const float *data, const float *weights, float *vals, const int n, const int len, const float *scale)
+static void process_new(AVFilterContext *ctx,
+                        const void *src, ptrdiff_t src_stride,
+                        uint8_t *prescreen, int N,
+                        void *data)
 {
-    int i;
+    NNEDIContext *s = ctx->priv;
+    PrescreenerNewCoefficients *m_data = data;
+    const float *src_p = src;
 
-    for (i = 0; i < n; i++) {
-        float sum;
+    // Adjust source pointer to point to top-left of filter window.
+    const float *window = src_p - 2 * src_stride - 6;
 
-        sum = s->fdsp->scalarproduct_float(data, &weights[i * len], len);
+    for (int j = 0; j < N; j += 4) {
+        LOCAL_ALIGNED_32(float, input, [64]);
+        float state[8];
 
-        vals[i] = sum * scale[0] + weights[n * len + i];
-    }
-}
+        for (int i = 0; i < 4; i++)
+            memcpy(input + i * 16, window + i * src_stride + j, 16 * sizeof(float));
 
-static void dot_prods(NNEDIContext *s, const float *dataf, const float *weightsf, float *vals, const int n, const int len, const float *scale)
-{
-    const int16_t *data = (int16_t *)dataf;
-    const int16_t *weights = (int16_t *)weightsf;
-    const float *wf = (float *)&weights[n * len];
-    int i, j;
+        for (int n = 0; n < 4; n++)
+            state[n] = dot_dsp(s, m_data->kernel_l0[n], input, 64, 1.0f, m_data->bias_l0[n]);
+        transform_elliott(state, 4);
 
-    for (i = 0; i < n; i++) {
-        int sum = 0, off = ((i >> 2) << 3) + (i & 3);
-        for (j = 0; j < len; j++)
-            sum += data[j] * weights[i * len + j];
+        for (int n = 0; n < 4; n++)
+            state[n + 4] = dot_product(m_data->kernel_l1[n], state, 4, 1.0f, m_data->bias_l1[n]);
 
-        vals[i] = sum * wf[off] * scale[0] + wf[off + 4];
+        for (int n = 0; n < 4; n++)
+            prescreen[j + n] = state[n + 4] > 0.f;
     }
 }
 
-static void compute_network0(NNEDIContext *s, const float *input, const float *weights, uint8_t *d)
+static size_t filter_offset(unsigned nn, PredictorCoefficients *model)
 {
-    float t, temp[12], scale = 1.0f;
-
-    dot_prod(s, input, weights, temp, 4, 48, &scale);
-    t = temp[0];
-    elliott(temp, 4);
-    temp[0] = t;
-    dot_prod(s, temp, weights + 4 * 49, temp + 4, 4, 4, &scale);
-    elliott(temp + 4, 4);
-    dot_prod(s, temp, weights + 4 * 49 + 4 * 5, temp + 8, 4, 8, &scale);
-    if (FFMAX(temp[10], temp[11]) <= FFMAX(temp[8], temp[9]))
-        d[0] = 1;
-    else
-        d[0] = 0;
+    return nn * model->xdim * model->ydim;
 }
 
-static void compute_network0_i16(NNEDIContext *s, const float *inputf, const float *weightsf, uint8_t *d)
+static const float *softmax_q1_filter(unsigned nn, PredictorCoefficients *model)
 {
-    const float *wf = weightsf + 2 * 48;
-    float t, temp[12], scale = 1.0f;
-
-    dot_prods(s, inputf, weightsf, temp, 4, 48, &scale);
-    t = temp[0];
-    elliott(temp, 4);
-    temp[0] = t;
-    dot_prod(s, temp, wf + 8, temp + 4, 4, 4, &scale);
-    elliott(temp + 4, 4);
-    dot_prod(s, temp, wf + 8 + 4 * 5, temp + 8, 4, 8, &scale);
-    if (FFMAX(temp[10], temp[11]) <= FFMAX(temp[8], temp[9]))
-        d[0] = 1;
-    else
-        d[0] = 0;
+    return model->softmax_q1 + filter_offset(nn, model);
 }
 
-static void pixel2float48(const uint8_t *t8, const int pitch, float *p)
+static const float *elliott_q1_filter(unsigned nn, PredictorCoefficients *model)
 {
-    const uint8_t *t = (const uint8_t *)t8;
-    int y, x;
-
-    for (y = 0; y < 4; y++)
-        for (x = 0; x < 12; x++)
-            p[y * 12 + x] = t[y * pitch * 2 + x];
+    return model->elliott_q1 + filter_offset(nn, model);
 }
 
-static void byte2word48(const uint8_t *t, const int pitch, float *pf)
+static const float *softmax_q2_filter(unsigned nn, PredictorCoefficients *model)
 {
-    int16_t *p = (int16_t *)pf;
-    int y, x;
-
-    for (y = 0; y < 4; y++)
-        for (x = 0; x < 12; x++)
-            p[y * 12 + x] = t[y * pitch * 2 + x];
+    return model->softmax_q2 + filter_offset(nn, model);
 }
 
-static int32_t process_line0(const uint8_t *tempu, int width, uint8_t *dstp8, const uint8_t *src3p8, const int src_pitch, const int max_value, const int chroma)
+static const float *elliott_q2_filter(unsigned nn, PredictorCoefficients *model)
 {
-    uint8_t *dstp = (uint8_t *)dstp8;
-    const uint8_t *src3p = (const uint8_t *)src3p8;
-    int minimum = 0;
-    int maximum = max_value - 1; // Technically the -1 is only needed for 8 and 16 bit input.
-    int count = 0, x;
-    for (x = 0; x < width; x++) {
-        if (tempu[x]) {
-            int tmp = 19 * (src3p[x + src_pitch * 2] + src3p[x + src_pitch * 4]) - 3 * (src3p[x] + src3p[x + src_pitch * 6]);
-            tmp /= 32;
-            dstp[x] = FFMAX(FFMIN(tmp, maximum), minimum);
-        } else {
-            dstp[x] = 255;
-            count++;
-        }
-    }
-    return count;
-}
-
-// new prescreener functions
-static void byte2word64(const uint8_t *t, const int pitch, float *p)
-{
-    int16_t *ps = (int16_t *)p;
-    int y, x;
-
-    for (y = 0; y < 4; y++)
-        for (x = 0; x < 16; x++)
-            ps[y * 16 + x] = t[y * pitch * 2 + x];
+    return model->elliott_q2 + filter_offset(nn, model);
 }
 
-static void compute_network0new(NNEDIContext *s, const float *datai, const float *weights, uint8_t *d)
+static void gather_input(const float *src, ptrdiff_t src_stride,
+                         float *buf, float mstd[4],
+                         PredictorCoefficients *model)
 {
-    int16_t *data = (int16_t *)datai;
-    int16_t *ws = (int16_t *)weights;
-    float *wf = (float *)&ws[4 * 64];
-    float vals[8];
-    int mask, i, j;
-
-    for (i = 0; i < 4; i++) {
-        int sum = 0;
-        float t;
-
-        for (j = 0; j < 64; j++)
-            sum += data[j] * ws[(i << 3) + ((j >> 3) << 5) + (j & 7)];
-        t = sum * wf[i] + wf[4 + i];
-        vals[i] = t / (1.0f + FFABS(t));
-    }
-
-    for (i = 0; i < 4; i++) {
-        float sum = 0.0f;
-
-        for (j = 0; j < 4; j++)
-            sum += vals[j] * wf[8 + i + (j << 2)];
-        vals[4 + i] = sum + wf[8 + 16 + i];
-    }
-
-    mask = 0;
-    for (i = 0; i < 4; i++) {
-        if (vals[4 + i] > 0.0f)
-            mask |= (0x1 << (i << 3));
-    }
-
-    ((int *)d)[0] = mask;
-}
-
-static void evalfunc_0(NNEDIContext *s, FrameData *frame_data)
-{
-    float *input = frame_data->input;
-    const float *weights0 = s->weights0;
-    float *temp = frame_data->temp;
-    uint8_t *tempu = (uint8_t *)temp;
-    int plane, x, y;
-
-    // And now the actual work.
-    for (plane = 0; plane < s->nb_planes; plane++) {
-        const uint8_t *srcp = (const uint8_t *)frame_data->paddedp[plane];
-        const int src_stride = frame_data->padded_stride[plane] / sizeof(uint8_t);
-
-        const int width = frame_data->padded_width[plane];
-        const int height = frame_data->padded_height[plane];
-
-        uint8_t *dstp = (uint8_t *)frame_data->dstp[plane];
-        const int dst_stride = frame_data->dst_stride[plane] / sizeof(uint8_t);
-        const uint8_t *src3p;
-        int ystart, ystop;
-        int32_t *lcount;
-
-        if (!(s->process_plane & (1 << plane)))
-            continue;
-
-        for (y = 1 - frame_data->field[plane]; y < height - 12; y += 2) {
-            memcpy(dstp + y * dst_stride,
-                   srcp + 32 + (6 + y) * src_stride,
-                   (width - 64) * sizeof(uint8_t));
+    float sum = 0;
+    float sum_sq = 0;
+    float tmp;
 
-        }
+    for (int i = 0; i < model->ydim; i++) {
+        for (int j = 0; j < model->xdim; j++) {
+            float val = src[i * src_stride + j];
 
-        ystart = 6 + frame_data->field[plane];
-        ystop = height - 6;
-        srcp += ystart * src_stride;
-        dstp += (ystart - 6) * dst_stride - 32;
-        src3p = srcp - src_stride * 3;
-        lcount = frame_data->lcount[plane] - 6;
-
-        if (s->pscrn == 1) { // original
-            for (y = ystart; y < ystop; y += 2) {
-                for (x = 32; x < width - 32; x++) {
-                    s->readpixels((const uint8_t *)(src3p + x - 5), src_stride, input);
-                    s->compute_network0(s, input, weights0, tempu+x);
-                }
-                lcount[y] += s->process_line0(tempu + 32, width - 64, (uint8_t *)(dstp + 32), (const uint8_t *)(src3p + 32), src_stride, s->max_value, plane);
-                src3p += src_stride * 2;
-                dstp += dst_stride * 2;
-            }
-        } else if (s->pscrn > 1) { // new
-            for (y = ystart; y < ystop; y += 2) {
-                for (x = 32; x < width - 32; x += 4) {
-                    s->readpixels((const uint8_t *)(src3p + x - 6), src_stride, input);
-                    s->compute_network0(s, input, weights0, tempu + x);
-                }
-                lcount[y] += s->process_line0(tempu + 32, width - 64, (uint8_t *)(dstp + 32), (const uint8_t *)(src3p + 32), src_stride, s->max_value, plane);
-                src3p += src_stride * 2;
-                dstp += dst_stride * 2;
-            }
-        } else { // no prescreening
-            for (y = ystart; y < ystop; y += 2) {
-                memset(dstp + 32, 255, (width - 64) * sizeof(uint8_t));
-                lcount[y] += width - 64;
-                dstp += dst_stride * 2;
-            }
+            buf[i * model->xdim + j] = val;
+            sum += val;
+            sum_sq += val * val;
         }
     }
-}
 
-static void extract_m8(const uint8_t *srcp8, const int stride, const int xdia, const int ydia, float *mstd, float *input)
-{
-    // uint8_t or uint16_t or float
-    const uint8_t *srcp = (const uint8_t *)srcp8;
-    float scale;
-    double tmp;
-
-    // int32_t or int64_t or double
-    int64_t sum = 0, sumsq = 0;
-    int y, x;
-
-    for (y = 0; y < ydia; y++) {
-        const uint8_t *srcpT = srcp + y * stride * 2;
-
-        for (x = 0; x < xdia; x++) {
-            sum += srcpT[x];
-            sumsq += (uint32_t)srcpT[x] * (uint32_t)srcpT[x];
-            input[x] = srcpT[x];
-        }
-        input += xdia;
-    }
-    scale = 1.0f / (xdia * ydia);
-    mstd[0] = sum * scale;
-    tmp = (double)sumsq * scale - (double)mstd[0] * mstd[0];
-    mstd[3] = 0.0f;
-    if (tmp <= FLT_EPSILON)
-        mstd[1] = mstd[2] = 0.0f;
-    else {
-        mstd[1] = sqrt(tmp);
+    mstd[0] = sum / (model->xdim * model->ydim);
+    mstd[3] = 0.f;
+
+    tmp = sum_sq / (model->xdim * model->ydim) - mstd[0] * mstd[0];
+    if (tmp < FLT_EPSILON) {
+        mstd[1] = 0.0f;
+        mstd[2] = 0.0f;
+    } else {
+        mstd[1] = sqrtf(tmp);
         mstd[2] = 1.0f / mstd[1];
     }
 }
 
-static void extract_m8_i16(const uint8_t *srcp, const int stride, const int xdia, const int ydia, float *mstd, float *inputf)
+static float softmax_exp(float x)
 {
-    int16_t *input = (int16_t *)inputf;
-    float scale;
-    int sum = 0, sumsq = 0;
-    int y, x;
-
-    for (y = 0; y < ydia; y++) {
-        const uint8_t *srcpT = srcp + y * stride * 2;
-        for (x = 0; x < xdia; x++) {
-            sum += srcpT[x];
-            sumsq += srcpT[x] * srcpT[x];
-            input[x] = srcpT[x];
-        }
-        input += xdia;
-    }
-    scale = 1.0f / (float)(xdia * ydia);
-    mstd[0] = sum * scale;
-    mstd[1] = sumsq * scale - mstd[0] * mstd[0];
-    mstd[3] = 0.0f;
-    if (mstd[1] <= FLT_EPSILON)
-        mstd[1] = mstd[2] = 0.0f;
-    else {
-        mstd[1] = sqrt(mstd[1]);
-        mstd[2] = 1.0f / mstd[1];
-    }
+    return expf(av_clipf(x, -80.f, 80.f));
 }
 
-
-static const float exp_lo = -80.0f;
-static const float exp_hi = +80.0f;
-
-static void e2_m16(float *s, const int n)
+static void transform_softmax_exp(float *input, int size)
 {
-    int i;
-
-    for (i = 0; i < n; i++)
-        s[i] = exp(av_clipf(s[i], exp_lo, exp_hi));
+    for (int i = 0; i < size; i++)
+        input[i] = softmax_exp(input[i]);
 }
 
-const float min_weight_sum = 1e-10f;
-
-static void weighted_avg_elliott_mul5_m16(const float *w, const int n, float *mstd)
+static void wae5(const float *softmax, const float *el,
+                 unsigned n, float mstd[4])
 {
     float vsum = 0.0f, wsum = 0.0f;
-    int i;
 
-    for (i = 0; i < n; i++) {
-        vsum += w[i] * (w[n + i] / (1.0f + FFABS(w[n + i])));
-        wsum += w[i];
+    for (int i = 0; i < n; i++) {
+        vsum += softmax[i] * elliott(el[i]);
+        wsum += softmax[i];
     }
-    if (wsum > min_weight_sum)
-        mstd[3] += ((5.0f * vsum) / wsum) * mstd[1] + mstd[0];
+
+    if (wsum > 1e-10f)
+        mstd[3] += (5.0f * vsum) / wsum * mstd[1] + mstd[0];
     else
         mstd[3] += mstd[0];
 }
 
-
-static void evalfunc_1(NNEDIContext *s, FrameData *frame_data)
+static void predictor(AVFilterContext *ctx,
+                      const void *src, ptrdiff_t src_stride, void *dst,
+                      const uint8_t *prescreen, int N,
+                      void *data, int use_q2)
 {
-    float *input = frame_data->input;
-    float *temp = frame_data->temp;
-    float **weights1 = s->weights1;
-    const int qual = s->qual;
-    const int asize = s->asize;
-    const int nns = s->nns;
-    const int xdia = s->xdia;
-    const int xdiad2m1 = (xdia / 2) - 1;
-    const int ydia = s->ydia;
-    const float scale = 1.0f / (float)qual;
-    int plane, y, x, i;
-
-    for (plane = 0; plane < s->nb_planes; plane++) {
-        const uint8_t *srcp = (const uint8_t *)frame_data->paddedp[plane];
-        const int src_stride = frame_data->padded_stride[plane] / sizeof(uint8_t);
-
-        const int width = frame_data->padded_width[plane];
-        const int height = frame_data->padded_height[plane];
-
-        uint8_t *dstp = (uint8_t *)frame_data->dstp[plane];
-        const int dst_stride = frame_data->dst_stride[plane] / sizeof(uint8_t);
-
-        const int ystart = frame_data->field[plane];
-        const int ystop = height - 12;
-        const uint8_t *srcpp;
-
-        if (!(s->process_plane & (1 << plane)))
+    NNEDIContext *s = ctx->priv;
+    PredictorCoefficients *model = data;
+    const float *src_p = src;
+    float *dst_p = dst;
+
+    // Adjust source pointer to point to top-left of filter window.
+    const float *window = src_p - (model->ydim / 2) * src_stride - (model->xdim / 2 - 1);
+    unsigned filter_size = model->xdim * model->ydim;
+    unsigned nns = model->nns;
+
+    for (int i = 0; i < N; i++) {
+        LOCAL_ALIGNED_32(float, input, [48 * 6]);
+        float activation[256 * 2];
+        float mstd[4];
+        float scale;
+
+        if (prescreen[i])
             continue;
 
-        srcp += (ystart + 6) * src_stride;
-        dstp += ystart * dst_stride - 32;
-        srcpp = srcp - (ydia - 1) * src_stride - xdiad2m1;
+        gather_input(window + i, src_stride, input, mstd, model);
+        scale = mstd[2];
 
-        for (y = ystart; y < ystop; y += 2) {
-            for (x = 32; x < width - 32; x++) {
-                float mstd[4];
+        for (int nn = 0; nn < nns; nn++)
+            activation[nn] = dot_dsp(s, softmax_q1_filter(nn, model), input, filter_size, scale, model->softmax_bias_q1[nn]);
 
-                if (dstp[x] != 255)
-                    continue;
+        for (int nn = 0; nn < nns; nn++)
+            activation[model->nns + nn] = dot_dsp(s, elliott_q1_filter(nn, model), input, filter_size, scale, model->elliott_bias_q1[nn]);
 
-                s->extract((const uint8_t *)(srcpp + x), src_stride, xdia, ydia, mstd, input);
-                for (i = 0; i < qual; i++) {
-                    s->dot_prod(s, input, weights1[i], temp, nns * 2, asize, mstd + 2);
-                    s->expfunc(temp, nns);
-                    s->wae5(temp, nns, mstd);
-                }
+        transform_softmax_exp(activation, nns);
+        wae5(activation, activation + nns, nns, mstd);
 
-                dstp[x] = FFMIN(FFMAX((int)(mstd[3] * scale + 0.5f), 0), s->max_value);
-            }
-            srcpp += src_stride * 2;
-            dstp += dst_stride * 2;
+        if (use_q2) {
+            for (int nn = 0; nn < nns; nn++)
+                activation[nn] = dot_dsp(s, softmax_q2_filter(nn, model), input, filter_size, scale, model->softmax_bias_q2[nn]);
+
+            for (int nn = 0; nn < nns; nn++)
+                activation[nns + nn] = dot_dsp(s, elliott_q2_filter(nn, model), input, filter_size, scale, model->elliott_bias_q2[nn]);
+
+            transform_softmax_exp(activation, nns);
+            wae5(activation, activation + nns, nns, mstd);
         }
+
+        dst_p[i] = mstd[3] / (use_q2 ? 2 : 1);
     }
 }
 
-#define NUM_NSIZE 7
-#define NUM_NNS 5
-
-static int roundds(const double f)
+static void read_bytes(const uint8_t *src, float *dst,
+                       int src_stride, int dst_stride,
+                       int width, int height, float scale)
 {
-    if (f - floor(f) >= 0.5)
-        return FFMIN((int)ceil(f), 32767);
-    return FFMAX((int)floor(f), -32768);
+    for (int y = 0; y < height; y++) {
+        for (int x = 0; x < 32; x++)
+            dst[-x - 1] = src[x];
+
+        for (int x = 0; x < width; x++)
+            dst[x] = src[x];
+
+        for (int x = 0; x < 32; x++)
+            dst[width + x] = src[width - x - 1];
+
+        dst += dst_stride;
+        src += src_stride;
+    }
 }
 
-static void select_functions(NNEDIContext *s)
+static void read_words(const uint8_t *srcp, float *dst,
+                       int src_stride, int dst_stride,
+                       int width, int height, float scale)
 {
-    s->copy_pad = copy_pad;
-    s->evalfunc_0 = evalfunc_0;
-    s->evalfunc_1 = evalfunc_1;
+    const uint16_t *src = (const uint16_t *)srcp;
 
-    // evalfunc_0
-    s->process_line0 = process_line0;
+    src_stride /= 2;
 
-    if (s->pscrn < 2) { // original prescreener
-        if (s->fapprox & 1) { // int16 dot products
-            s->readpixels = byte2word48;
-            s->compute_network0 = compute_network0_i16;
-        } else {
-            s->readpixels = pixel2float48;
-            s->compute_network0 = compute_network0;
-        }
-    } else { // new prescreener
-        // only int16 dot products
-        s->readpixels = byte2word64;
-        s->compute_network0 = compute_network0new;
-    }
+    for (int y = 0; y < height; y++) {
+        for (int x = 0; x < 32; x++)
+            dst[-x - 1] = src[x] * scale;
 
-    // evalfunc_1
-    s->wae5 = weighted_avg_elliott_mul5_m16;
+        for (int x = 0; x < width; x++)
+            dst[x] = src[x] * scale;
 
-    if (s->fapprox & 2) { // use int16 dot products
-        s->extract = extract_m8_i16;
-        s->dot_prod = dot_prods;
-    } else { // use float dot products
-        s->extract = extract_m8;
-        s->dot_prod = dot_prod;
-    }
+        for (int x = 0; x < 32; x++)
+            dst[width + x] = src[width - x - 1] * scale;
 
-    s->expfunc = e2_m16;
+        dst += dst_stride;
+        src += src_stride;
+    }
 }
 
-static int modnpf(const int m, const int n)
+static void write_bytes(const float *src, uint8_t *dst,
+                        int src_stride, int dst_stride,
+                        int width, int height, int depth,
+                        float scale)
 {
-    if ((m % n) == 0)
-        return m;
-    return m + n - (m % n);
+    for (int y = 0; y < height; y++) {
+        for (int x = 0; x < width; x++)
+            dst[x] = av_clip_uint8(src[x]);
+
+        dst += dst_stride;
+        src += src_stride;
+    }
 }
 
-static int get_frame(AVFilterContext *ctx, int is_second)
+static void write_words(const float *src, uint8_t *dstp,
+                        int src_stride, int dst_stride,
+                        int width, int height, int depth,
+                        float scale)
 {
-    NNEDIContext *s = ctx->priv;
-    AVFilterLink *outlink = ctx->outputs[0];
-    AVFrame *src = s->src;
-    FrameData *frame_data;
-    int effective_field = s->field;
-    size_t temp_size;
-    int field_n;
-    int plane;
+    uint16_t *dst = (uint16_t *)dstp;
 
-    if (effective_field > 1)
-        effective_field -= 2;
-    else if (effective_field < 0)
-        effective_field += 2;
+    dst_stride /= 2;
 
-    if (s->field < 0 && src->interlaced_frame && src->top_field_first == 0)
-        effective_field = 0;
-    else if (s->field < 0 && src->interlaced_frame && src->top_field_first == 1)
-        effective_field = 1;
-    else
-        effective_field = !effective_field;
+    for (int y = 0; y < height; y++) {
+        for (int x = 0; x < width; x++)
+            dst[x] = av_clip_uintp2_c(src[x] * scale, depth);
 
-    if (s->field > 1 || s->field == -2) {
-        if (is_second) {
-            field_n = (effective_field == 0);
-        } else {
-            field_n = (effective_field == 1);
-        }
-    } else {
-        field_n = effective_field;
+        dst += dst_stride;
+        src += src_stride;
     }
+}
 
-    s->dst = ff_get_video_buffer(outlink, outlink->w, outlink->h);
-    if (!s->dst)
-        return AVERROR(ENOMEM);
-    av_frame_copy_props(s->dst, src);
-    s->dst->interlaced_frame = 0;
+static void interpolation(const void *src, ptrdiff_t src_stride,
+                          void *dst, const uint8_t *prescreen, unsigned n)
+{
+    const float *src_p = src;
+    float *dst_p = dst;
+    const float *window = src_p - 2 * src_stride;
 
-    frame_data = &s->frame_data;
+    for (int i = 0; i < n; i++) {
+        float accum = 0.0f;
 
-    for (plane = 0; plane < s->nb_planes; plane++) {
-        int dst_height = s->planeheight[plane];
-        int dst_width = s->linesize[plane];
+        if (!prescreen[i])
+            continue;
 
-        const int min_alignment = 16;
-        const int min_pad = 10;
+        accum += (-3.0f / 32.0f) * window[0 * src_stride + i];
+        accum += (19.0f / 32.0f) * window[1 * src_stride + i];
+        accum += (19.0f / 32.0f) * window[2 * src_stride + i];
+        accum += (-3.0f / 32.0f) * window[3 * src_stride + i];
 
-        if (!(s->process_plane & (1 << plane))) {
-            av_image_copy_plane(s->dst->data[plane], s->dst->linesize[plane],
-                                src->data[plane], src->linesize[plane],
-                                s->linesize[plane],
-                                s->planeheight[plane]);
+        dst_p[i] = accum;
+    }
+}
+
+static int filter_slice(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
+{
+    NNEDIContext *s = ctx->priv;
+    AVFrame *out = s->dst;
+    AVFrame *in = s->src;
+    const float in_scale = s->in_scale;
+    const float out_scale = s->out_scale;
+    const int depth = s->depth;
+    const int interlaced = in->interlaced_frame;
+    const int tff = s->field_n == (s->field < 0 ? interlaced ? in->top_field_first : 1 :
+                                  (s->field & 1) ^ 1);
+
+
+    for (int p = 0; p < s->nb_planes; p++) {
+        const int height = s->planeheight[p];
+        const int width = s->planewidth[p];
+        const int slice_start = 2 * ((height / 2 * jobnr) / nb_jobs);
+        const int slice_end = 2 * ((height / 2 * (jobnr+1)) / nb_jobs);
+        const uint8_t *src_data = in->data[p];
+        uint8_t *dst_data = out->data[p];
+        uint8_t *dst = out->data[p] + slice_start * out->linesize[p];
+        const int src_linesize = in->linesize[p];
+        const int dst_linesize = out->linesize[p];
+        uint8_t *prescreen_buf = s->prescreen_buf + s->planewidth[0] * jobnr;
+        float *srcbuf = s->input_buf + s->input_size * jobnr;
+        const int srcbuf_stride = width + 64;
+        float *dstbuf = s->output_buf + s->input_size * jobnr;
+        const int dstbuf_stride = width;
+        const int slice_height = (slice_end - slice_start) / 2;
+        const int last_slice = slice_end == height;
+        const uint8_t *in_line;
+        uint8_t *out_line;
+        int y_out;
+
+        if (!(s->process_plane & (1 << p))) {
+            av_image_copy_plane(dst, out->linesize[p],
+                                in->data[p] + slice_start * in->linesize[p],
+                                in->linesize[p],
+                                s->linesize[p], slice_end - slice_start);
             continue;
         }
 
-        frame_data->padded_width[plane]  = dst_width + 64;
-        frame_data->padded_height[plane] = dst_height + 12;
-        frame_data->padded_stride[plane] = modnpf(frame_data->padded_width[plane] + min_pad, min_alignment); // TODO: maybe min_pad is in pixels too?
-        if (!frame_data->paddedp[plane]) {
-            frame_data->paddedp[plane] = av_malloc_array(frame_data->padded_stride[plane], frame_data->padded_height[plane]);
-            if (!frame_data->paddedp[plane])
-                return AVERROR(ENOMEM);
+        y_out    = slice_start + (tff ^ (slice_start & 1));
+        in_line  = src_data + (y_out * src_linesize);
+        out_line = dst_data + (y_out * dst_linesize);
+
+        while (y_out < slice_end) {
+            memcpy(out_line, in_line, s->linesize[p]);
+            y_out += 2;
+            in_line  += src_linesize * 2;
+            out_line += dst_linesize * 2;
         }
 
-        frame_data->dstp[plane] = s->dst->data[plane];
-        frame_data->dst_stride[plane] = s->dst->linesize[plane];
+        y_out = slice_start + ((!tff) ^ (slice_start & 1));
+
+        s->read(src_data + FFMAX(y_out - 5, tff) * src_linesize,
+                srcbuf + 32,
+                src_linesize * 2, srcbuf_stride,
+                width, 1, in_scale);
+        srcbuf += srcbuf_stride;
+
+        s->read(src_data + FFMAX(y_out - 3, tff) * src_linesize,
+                srcbuf + 32,
+                src_linesize * 2, srcbuf_stride,
+                width, 1, in_scale);
+        srcbuf += srcbuf_stride;
+
+        s->read(src_data + FFMAX(y_out - 1, tff) * src_linesize,
+                srcbuf + 32,
+                src_linesize * 2, srcbuf_stride,
+                width, 1, in_scale);
+        srcbuf += srcbuf_stride;
+
+        in_line  = src_data + FFMIN(y_out + 1, height - 1 - !tff) * src_linesize;
+        out_line = dst_data + (y_out * dst_linesize);
+
+        s->read(in_line, srcbuf + 32, src_linesize * 2, srcbuf_stride,
+                width, slice_height - last_slice, in_scale);
+
+        y_out += (slice_height - last_slice) * 2;
+
+        s->read(src_data + FFMIN(y_out + 1, height - 1 - !tff) * src_linesize,
+                srcbuf + 32 + srcbuf_stride * (slice_height - last_slice),
+                src_linesize * 2, srcbuf_stride,
+                width, 1, in_scale);
+
+        s->read(src_data + FFMIN(y_out + 3, height - 1 - !tff) * src_linesize,
+                srcbuf + 32 + srcbuf_stride * (slice_height + 1 - last_slice),
+                src_linesize * 2, srcbuf_stride,
+                width, 1, in_scale);
+
+        s->read(src_data + FFMIN(y_out + 5, height - 1 - !tff) * src_linesize,
+                srcbuf + 32 + srcbuf_stride * (slice_height + 2 - last_slice),
+                src_linesize * 2, srcbuf_stride,
+                width, 1, in_scale);
+
+        for (int y = 0; y < slice_end - slice_start; y += 2) {
+            if (s->pscrn > 1) {
+                s->prescreen[1](ctx, srcbuf + (y / 2) * srcbuf_stride + 32,
+                                srcbuf_stride, prescreen_buf, width,
+                                &s->prescreener_new[s->pscrn - 2]);
+            } else if (s->pscrn == 1) {
+                s->prescreen[0](ctx, srcbuf + (y / 2) * srcbuf_stride + 32,
+                                srcbuf_stride, prescreen_buf, width,
+                                &s->prescreener_old);
+            }
 
-        if (!frame_data->lcount[plane]) {
-            frame_data->lcount[plane] = av_calloc(dst_height, sizeof(int32_t) * 16);
-            if (!frame_data->lcount[plane])
-                return AVERROR(ENOMEM);
-        } else {
-            memset(frame_data->lcount[plane], 0, dst_height * sizeof(int32_t) * 16);
+            predictor(ctx,
+                      srcbuf + (y / 2) * srcbuf_stride + 32,
+                      srcbuf_stride,
+                      dstbuf + (y / 2) * dstbuf_stride,
+                      prescreen_buf, width,
+                      &s->coeffs[s->etype][s->nnsparam][s->nsize], s->qual == 2);
+
+            if (s->pscrn > 0)
+                interpolation(srcbuf + (y / 2) * srcbuf_stride + 32,
+                              srcbuf_stride,
+                              dstbuf + (y / 2) * dstbuf_stride,
+                              prescreen_buf, width);
         }
 
-        frame_data->field[plane] = field_n;
+        s->write(dstbuf, out_line, dstbuf_stride, dst_linesize * 2,
+                 width, slice_height, depth, out_scale);
     }
 
-    if (!frame_data->input) {
-        frame_data->input = av_malloc(512 * sizeof(float));
-        if (!frame_data->input)
-            return AVERROR(ENOMEM);
-    }
-    // evalfunc_0 requires at least padded_width[0] bytes.
-    // evalfunc_1 requires at least 512 floats.
-    if (!frame_data->temp) {
-        temp_size = FFMAX(frame_data->padded_width[0], 512 * sizeof(float));
-        frame_data->temp = av_malloc(temp_size);
-        if (!frame_data->temp)
-            return AVERROR(ENOMEM);
-    }
+    return 0;
+}
+
+static int get_frame(AVFilterContext *ctx, int is_second)
+{
+    NNEDIContext *s = ctx->priv;
+    AVFilterLink *outlink = ctx->outputs[0];
+    AVFrame *src = s->src;
 
-    // Copy src to a padded "frame" in frame_data and mirror the edges.
-    s->copy_pad(src, frame_data, s, field_n);
+    s->dst = ff_get_video_buffer(outlink, outlink->w, outlink->h);
+    if (!s->dst)
+        return AVERROR(ENOMEM);
+    av_frame_copy_props(s->dst, src);
+    s->dst->interlaced_frame = 0;
 
-    // Handles prescreening and the cubic interpolation.
-    s->evalfunc_0(s, frame_data);
+    ctx->internal->execute(ctx, filter_slice, NULL, NULL, FFMIN(s->planeheight[1] / 2, s->nb_threads));
 
-    // The rest.
-    s->evalfunc_1(s, frame_data);
+    if (s->field == -2 || s->field > 1)
+        s->field_n = !s->field_n;
 
     return 0;
 }
@@ -904,23 +816,221 @@ static int request_frame(AVFilterLink *link)
     return 0;
 }
 
+static void read(float *dst, size_t n, const float **data)
+{
+    memcpy(dst, *data, n * sizeof(float));
+    *data += n;
+}
+
+static float *allocate(float **ptr, size_t size)
+{
+    float *ret = *ptr;
+
+    *ptr += size;
+
+    return ret;
+}
+
+static int allocate_model(PredictorCoefficients *coeffs, int xdim, int ydim, int nns)
+{
+    size_t filter_size = nns * xdim * ydim;
+    size_t bias_size = nns;
+    float *data;
+
+    data = av_malloc_array(filter_size + bias_size, 4 * sizeof(float));
+    if (!data)
+        return AVERROR(ENOMEM);
+
+    coeffs->data = data;
+    coeffs->xdim = xdim;
+    coeffs->ydim = ydim;
+    coeffs->nns  = nns;
+
+    coeffs->softmax_q1 = allocate(&data, filter_size);
+    coeffs->elliott_q1 = allocate(&data, filter_size);
+    coeffs->softmax_bias_q1 = allocate(&data, bias_size);
+    coeffs->elliott_bias_q1 = allocate(&data, bias_size);
+
+    coeffs->softmax_q2 = allocate(&data, filter_size);
+    coeffs->elliott_q2 = allocate(&data, filter_size);
+    coeffs->softmax_bias_q2 = allocate(&data, bias_size);
+    coeffs->elliott_bias_q2 = allocate(&data, bias_size);
+
+    return 0;
+}
+
+static int read_weights(AVFilterContext *ctx, const float *bdata)
+{
+    NNEDIContext *s = ctx->priv;
+    int ret;
+
+    read(&s->prescreener_old.kernel_l0[0][0], 4 * 48, &bdata);
+    read(s->prescreener_old.bias_l0, 4, &bdata);
+
+    read(&s->prescreener_old.kernel_l1[0][0], 4 * 4, &bdata);
+    read(s->prescreener_old.bias_l1, 4, &bdata);
+
+    read(&s->prescreener_old.kernel_l2[0][0], 4 * 8, &bdata);
+    read(s->prescreener_old.bias_l2, 4, &bdata);
+
+    for (int i = 0; i < 3; i++) {
+        PrescreenerNewCoefficients *data = &s->prescreener_new[i];
+        float kernel_l0_shuffled[4 * 64];
+        float kernel_l1_shuffled[4 * 4];
+
+        read(kernel_l0_shuffled, 4 * 64, &bdata);
+        read(data->bias_l0, 4, &bdata);
+
+        read(kernel_l1_shuffled, 4 * 4, &bdata);
+        read(data->bias_l1, 4, &bdata);
+
+        for (int n = 0; n < 4; n++) {
+            for (int k = 0; k < 64; k++)
+                data->kernel_l0[n][k] = kernel_l0_shuffled[(k / 8) * 32 + n * 8 + k % 8];
+            for (int k = 0; k < 4; k++)
+                data->kernel_l1[n][k] = kernel_l1_shuffled[k * 4 + n];
+        }
+    }
+
+    for (int m = 0; m < 2; m++) {
+        // Grouping by neuron count.
+        for (int i = 0; i < 5; i++) {
+            int nns = NNEDI_NNS[i];
+
+            // Grouping by window size.
+            for (int j = 0; j < 7; j++) {
+                PredictorCoefficients *model = &s->coeffs[m][i][j];
+                int xdim = NNEDI_XDIM[j];
+                int ydim = NNEDI_YDIM[j];
+                size_t filter_size = xdim * ydim;
+
+                ret = allocate_model(model, xdim, ydim, nns);
+                if (ret < 0)
+                    return ret;
+
+                // Quality 1 model. NNS[i] * (XDIM[j] * YDIM[j]) * 2 coefficients.
+                read(model->softmax_q1, nns * filter_size, &bdata);
+                read(model->elliott_q1, nns * filter_size, &bdata);
+
+                // Quality 1 model bias. NNS[i] * 2 coefficients.
+                read(model->softmax_bias_q1, nns, &bdata);
+                read(model->elliott_bias_q1, nns, &bdata);
+
+                // Quality 2 model. NNS[i] * (XDIM[j] * YDIM[j]) * 2 coefficients.
+                read(model->softmax_q2, nns * filter_size, &bdata);
+                read(model->elliott_q2, nns * filter_size, &bdata);
+
+                // Quality 2 model bias. NNS[i] * 2 coefficients.
+                read(model->softmax_bias_q2, nns, &bdata);
+                read(model->elliott_bias_q2, nns, &bdata);
+            }
+        }
+    }
+
+    return 0;
+}
+
+static float mean(const float *input, int size)
+{
+    float sum = 0.;
+
+    for (int i = 0; i < size; i++)
+        sum += input[i];
+
+    return sum / size;
+}
+
+static void transform(float *input, int size, float mean, float half)
+{
+    for (int i = 0; i < size; i++)
+        input[i] = (input[i] - mean) / half;
+}
+
+static void subtract_mean_old(PrescreenerOldCoefficients *coeffs, float half)
+{
+    for (int n = 0; n < 4; n++) {
+        float m = mean(coeffs->kernel_l0[n], 48);
+
+        transform(coeffs->kernel_l0[n], 48, m, half);
+    }
+}
+
+static void subtract_mean_new(PrescreenerNewCoefficients *coeffs, float half)
+{
+    for (int n = 0; n < 4; n++) {
+        float m = mean(coeffs->kernel_l0[n], 64);
+
+        transform(coeffs->kernel_l0[n], 64, m, half);
+    }
+}
+
+static void subtract_mean_predictor(PredictorCoefficients *model)
+{
+    size_t filter_size = model->xdim * model->ydim;
+    int nns = model->nns;
+
+    float softmax_means[256]; // Average of individual softmax filters.
+    float elliott_means[256]; // Average of individual elliott filters.
+    float mean_filter[48 * 6] = { 0 }; // Pointwise average of all softmax filters.
+    float mean_bias;
+
+    // Quality 1.
+    for (int nn = 0; nn < nns; nn++) {
+        softmax_means[nn] = mean(model->softmax_q1 + nn * filter_size, filter_size);
+        elliott_means[nn] = mean(model->elliott_q1 + nn * filter_size, filter_size);
+
+        for (int k = 0; k < filter_size; k++)
+            mean_filter[k] += model->softmax_q1[nn * filter_size + k] - softmax_means[nn];
+    }
+
+    for (int k = 0; k < filter_size; k++)
+        mean_filter[k] /= nns;
+
+    mean_bias = mean(model->softmax_bias_q1, nns);
+
+    for (int nn = 0; nn < nns; nn++) {
+        for (int k = 0; k < filter_size; k++) {
+            model->softmax_q1[nn * filter_size + k] -= softmax_means[nn] + mean_filter[k];
+            model->elliott_q1[nn * filter_size + k] -= elliott_means[nn];
+        }
+        model->softmax_bias_q1[nn] -= mean_bias;
+    }
+
+    // Quality 2.
+    memset(mean_filter, 0, 48 * 6 * sizeof(float));
+
+    for (int nn = 0; nn < nns; nn++) {
+        softmax_means[nn] = mean(model->softmax_q2 + nn * filter_size, filter_size);
+        elliott_means[nn] = mean(model->elliott_q2 + nn * filter_size, filter_size);
+
+        for (int k = 0; k < filter_size; k++) {
+            mean_filter[k] += model->softmax_q2[nn * filter_size + k] - softmax_means[nn];
+        }
+    }
+
+    for (int k = 0; k < filter_size; k++)
+        mean_filter[k] /= nns;
+
+    mean_bias = mean(model->softmax_bias_q2, nns);
+
+    for (int nn = 0; nn < nns; nn++) {
+        for (int k = 0; k < filter_size; k++) {
+            model->softmax_q2[nn * filter_size + k] -= softmax_means[nn] + mean_filter[k];
+            model->elliott_q2[nn * filter_size + k] -= elliott_means[nn];
+        }
+
+        model->softmax_bias_q2[nn] -= mean_bias;
+    }
+}
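
The net effect of subtract_mean_predictor() is that each elliott filter and the softmax bias vector end up with zero mean, and the pointwise average of all softmax filters becomes the zero filter. An illustrative check of that property for the quality-1 weights (not part of the patch; it reuses mean() from above and assumes math.h and libavutil/avassert.h are included):

    static void check_zero_means_q1(const PredictorCoefficients *model)
    {
        const int filter_size = model->xdim * model->ydim;

        /* every elliott filter row had its own mean removed */
        for (int nn = 0; nn < model->nns; nn++)
            av_assert0(fabsf(mean(model->elliott_q1 + nn * filter_size, filter_size)) < 1e-3f);

        /* the mean softmax bias was removed from every bias */
        av_assert0(fabsf(mean(model->softmax_bias_q1, model->nns)) < 1e-3f);

        /* per-row means and the pointwise mean filter were both removed,
           so averaging the softmax filters over all neurons gives ~0 */
        for (int k = 0; k < filter_size; k++) {
            float column = 0.f;

            for (int nn = 0; nn < model->nns; nn++)
                column += model->softmax_q1[nn * filter_size + k];
            av_assert0(fabsf(column / model->nns) < 1e-3f);
        }
    }
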
+
 static av_cold int init(AVFilterContext *ctx)
 {
     NNEDIContext *s = ctx->priv;
     FILE *weights_file = NULL;
-    int64_t expected_size = 13574928;
     int64_t weights_size;
     float *bdata;
     size_t bytes_read;
-    const int xdia_table[NUM_NSIZE] = { 8, 16, 32, 48, 8, 16, 32 };
-    const int ydia_table[NUM_NSIZE] = { 6, 6, 6, 6, 4, 4, 4 };
-    const int nns_table[NUM_NNS] = { 16, 32, 64, 128, 256 };
-    const int dims0 = 49 * 4 + 5 * 4 + 9 * 4;
-    const int dims0new = 4 * 65 + 4 * 5;
-    const int dims1 = nns_table[s->nnsparam] * 2 * (xdia_table[s->nsize] * ydia_table[s->nsize] + 1);
-    int dims1tsize = 0;
-    int dims1offset = 0;
-    int ret = 0, i, j, k;
+    int ret = 0;
 
     weights_file = av_fopen_utf8(s->weights_file, "rb");
     if (!weights_file) {
@@ -940,7 +1050,7 @@ static av_cold int init(AVFilterContext *ctx)
         fclose(weights_file);
         av_log(ctx, AV_LOG_ERROR, "Couldn't get size of weights file.\n");
         return AVERROR(EINVAL);
-    } else if (weights_size != expected_size) {
+    } else if (weights_size != NNEDI_WEIGHTS_SIZE) {
         fclose(weights_file);
         av_log(ctx, AV_LOG_ERROR, "Unexpected weights file size.\n");
         return AVERROR(EINVAL);
@@ -952,15 +1062,14 @@ static av_cold int init(AVFilterContext *ctx)
         return AVERROR(EINVAL);
     }
 
-    bdata = (float *)av_malloc(expected_size);
+    bdata = av_malloc(NNEDI_WEIGHTS_SIZE);
     if (!bdata) {
         fclose(weights_file);
         return AVERROR(ENOMEM);
     }
 
-    bytes_read = fread(bdata, 1, expected_size, weights_file);
-
-    if (bytes_read != (size_t)expected_size) {
+    bytes_read = fread(bdata, 1, NNEDI_WEIGHTS_SIZE, weights_file);
+    if (bytes_read != NNEDI_WEIGHTS_SIZE) {
         fclose(weights_file);
         ret = AVERROR_INVALIDDATA;
         av_log(ctx, AV_LOG_ERROR, "Couldn't read weights file.\n");
@@ -969,211 +1078,102 @@ static av_cold int init(AVFilterContext *ctx)
 
     fclose(weights_file);
 
-    for (j = 0; j < NUM_NNS; j++) {
-        for (i = 0; i < NUM_NSIZE; i++) {
-            if (i == s->nsize && j == s->nnsparam)
-                dims1offset = dims1tsize;
-            dims1tsize += nns_table[j] * 2 * (xdia_table[i] * ydia_table[i] + 1) * 2;
-        }
-    }
-
-    s->weights0 = av_malloc_array(FFMAX(dims0, dims0new), sizeof(float));
-    if (!s->weights0) {
+    s->fdsp = avpriv_float_dsp_alloc(0);
+    if (!s->fdsp) {
         ret = AVERROR(ENOMEM);
         goto fail;
     }
 
-    for (i = 0; i < 2; i++) {
-        s->weights1[i] = av_malloc_array(dims1, sizeof(float));
-        if (!s->weights1[i]) {
-            ret = AVERROR(ENOMEM);
-            goto fail;
-        }
-    }
+    ret = read_weights(ctx, bdata);
+    if (ret < 0)
+        goto fail;
 
-    // Adjust prescreener weights
-    if (s->pscrn >= 2) {// using new prescreener
-        const float *bdw;
-        int16_t *ws;
-        float *wf;
-        double mean[4] = { 0.0, 0.0, 0.0, 0.0 };
-        int *offt = av_calloc(4 * 64, sizeof(int));
-
-        if (!offt) {
-            ret = AVERROR(ENOMEM);
-            goto fail;
-        }
+fail:
+    av_free(bdata);
+    return ret;
+}
 
-        for (j = 0; j < 4; j++)
-            for (k = 0; k < 64; k++)
-                offt[j * 64 + k] = ((k >> 3) << 5) + ((j & 3) << 3) + (k & 7);
-
-        bdw = bdata + dims0 + dims0new * (s->pscrn - 2);
-        ws = (int16_t *)s->weights0;
-        wf = (float *)&ws[4 * 64];
-        // Calculate mean weight of each first layer neuron
-        for (j = 0; j < 4; j++) {
-            double cmean = 0.0;
-            for (k = 0; k < 64; k++)
-                cmean += bdw[offt[j * 64 + k]];
-            mean[j] = cmean / 64.0;
-        }
-        // Factor mean removal and 1.0/127.5 scaling
-        // into first layer weights. scale to int16 range
-        for (j = 0; j < 4; j++) {
-            double scale, mval = 0.0;
-
-            for (k = 0; k < 64; k++)
-                mval = FFMAX(mval, FFABS((bdw[offt[j * 64 + k]] - mean[j]) / 127.5));
-            scale = 32767.0 / mval;
-            for (k = 0; k < 64; k++)
-                ws[offt[j * 64 + k]] = roundds(((bdw[offt[j * 64 + k]] - mean[j]) / 127.5) * scale);
-            wf[j] = (float)(mval / 32767.0);
-        }
-        memcpy(wf + 4, bdw + 4 * 64, (dims0new - 4 * 64) * sizeof(float));
-        av_free(offt);
-    } else { // using old prescreener
-        double mean[4] = { 0.0, 0.0, 0.0, 0.0 };
-        // Calculate mean weight of each first layer neuron
-        for (j = 0; j < 4; j++) {
-            double cmean = 0.0;
-            for (k = 0; k < 48; k++)
-                cmean += bdata[j * 48 + k];
-            mean[j] = cmean / 48.0;
-        }
-        if (s->fapprox & 1) {// use int16 dot products in first layer
-            int16_t *ws = (int16_t *)s->weights0;
-            float *wf = (float *)&ws[4 * 48];
-            // Factor mean removal and 1.0/127.5 scaling
-            // into first layer weights. scale to int16 range
-            for (j = 0; j < 4; j++) {
-                double scale, mval = 0.0;
-                for (k = 0; k < 48; k++)
-                    mval = FFMAX(mval, FFABS((bdata[j * 48 + k] - mean[j]) / 127.5));
-                scale = 32767.0 / mval;
-                for (k = 0; k < 48; k++)
-                    ws[j * 48 + k] = roundds(((bdata[j * 48 + k] - mean[j]) / 127.5) * scale);
-                wf[j] = (float)(mval / 32767.0);
-            }
-            memcpy(wf + 4, bdata + 4 * 48, (dims0 - 4 * 48) * sizeof(float));
-        } else {// use float dot products in first layer
-            double half = (1 << 8) - 1;
-
-            half /= 2;
-
-            // Factor mean removal and 1.0/half scaling
-            // into first layer weights.
-            for (j = 0; j < 4; j++)
-                for (k = 0; k < 48; k++)
-                    s->weights0[j * 48 + k] = (float)((bdata[j * 48 + k] - mean[j]) / half);
-            memcpy(s->weights0 + 4 * 48, bdata + 4 * 48, (dims0 - 4 * 48) * sizeof(float));
-        }
+static int config_input(AVFilterLink *inlink)
+{
+    AVFilterContext *ctx = inlink->dst;
+    NNEDIContext *s = ctx->priv;
+    const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(inlink->format);
+    int ret;
+
+    s->depth = desc->comp[0].depth;
+    s->nb_threads = ff_filter_get_nb_threads(ctx);
+    s->nb_planes = av_pix_fmt_count_planes(inlink->format);
+    if ((ret = av_image_fill_linesizes(s->linesize, inlink->format, inlink->w)) < 0)
+        return ret;
+
+    s->planewidth[1] = s->planewidth[2] = AV_CEIL_RSHIFT(inlink->w, desc->log2_chroma_w);
+    s->planewidth[0] = s->planewidth[3] = inlink->w;
+    s->planeheight[1] = s->planeheight[2] = AV_CEIL_RSHIFT(inlink->h, desc->log2_chroma_h);
+    s->planeheight[0] = s->planeheight[3] = inlink->h;
+
+    s->half = ((1 << 8) - 1) / 2.f;
+    s->out_scale = 1 << (s->depth - 8);
+    s->in_scale = 1.f / s->out_scale;
+
+    switch (s->depth) {
+    case 8:
+        s->read  = read_bytes;
+        s->write = write_bytes;
+        break;
+    default:
+        s->read  = read_words;
+        s->write = write_words;
+        break;
     }
 
-    // Adjust prediction weights
-    for (i = 0; i < 2; i++) {
-        const float *bdataT = bdata + dims0 + dims0new * 3 + dims1tsize * s->etype + dims1offset + i * dims1;
-        const int nnst = nns_table[s->nnsparam];
-        const int asize = xdia_table[s->nsize] * ydia_table[s->nsize];
-        const int boff = nnst * 2 * asize;
-        double *mean = (double *)av_calloc(asize + 1 + nnst * 2, sizeof(double));
-
-        if (!mean) {
-            ret = AVERROR(ENOMEM);
-            goto fail;
-        }
+    subtract_mean_old(&s->prescreener_old, s->half);
+    subtract_mean_new(&s->prescreener_new[0], s->half);
+    subtract_mean_new(&s->prescreener_new[1], s->half);
+    subtract_mean_new(&s->prescreener_new[2], s->half);
 
-        // Calculate mean weight of each neuron (ignore bias)
-        for (j = 0; j < nnst * 2; j++) {
-            double cmean = 0.0;
-            for (k = 0; k < asize; k++)
-                cmean += bdataT[j * asize + k];
-            mean[asize + 1 + j] = cmean / (double)asize;
-        }
-        // Calculate mean softmax neuron
-        for (j = 0; j < nnst; j++) {
-            for (k = 0; k < asize; k++)
-                mean[k] += bdataT[j * asize + k] - mean[asize + 1 + j];
-            mean[asize] += bdataT[boff + j];
-        }
-        for (j = 0; j < asize + 1; j++)
-            mean[j] /= (double)(nnst);
-
-        if (s->fapprox & 2) { // use int16 dot products
-            int16_t *ws = (int16_t *)s->weights1[i];
-            float *wf = (float *)&ws[nnst * 2 * asize];
-            // Factor mean removal into weights, remove global offset from
-            // softmax neurons, and scale weights to int16 range.
-            for (j = 0; j < nnst; j++) { // softmax neurons
-                double scale, mval = 0.0;
-                for (k = 0; k < asize; k++)
-                    mval = FFMAX(mval, FFABS(bdataT[j * asize + k] - mean[asize + 1 + j] - mean[k]));
-                scale = 32767.0 / mval;
-                for (k = 0; k < asize; k++)
-                    ws[j * asize + k] = roundds((bdataT[j * asize + k] - mean[asize + 1 + j] - mean[k]) * scale);
-                wf[(j >> 2) * 8 + (j & 3)] = (float)(mval / 32767.0);
-                wf[(j >> 2) * 8 + (j & 3) + 4] = (float)(bdataT[boff + j] - mean[asize]);
-            }
-            for (j = nnst; j < nnst * 2; j++) { // elliott neurons
-                double scale, mval = 0.0;
-                for (k = 0; k < asize; k++)
-                    mval = FFMAX(mval, FFABS(bdataT[j * asize + k] - mean[asize + 1 + j]));
-                scale = 32767.0 / mval;
-                for (k = 0; k < asize; k++)
-                    ws[j * asize + k] = roundds((bdataT[j * asize + k] - mean[asize + 1 + j]) * scale);
-                wf[(j >> 2) * 8 + (j & 3)] = (float)(mval / 32767.0);
-                wf[(j >> 2) * 8 + (j & 3) + 4] = bdataT[boff + j];
-            }
-        } else { // use float dot products
-            // Factor mean removal into weights, and remove global
-            // offset from softmax neurons.
-            for (j = 0; j < nnst * 2; j++) {
-                for (k = 0; k < asize; k++) {
-                    const double q = j < nnst ? mean[k] : 0.0;
-                    s->weights1[i][j * asize + k] = (float)(bdataT[j * asize + k] - mean[asize + 1 + j] - q);
-                }
-                s->weights1[i][boff + j] = (float)(bdataT[boff + j] - (j < nnst ? mean[asize] : 0.0));
-            }
+    s->prescreen[0] = process_old;
+    s->prescreen[1] = process_new;
+
+    for (int i = 0; i < 2; i++) {
+        for (int j = 0; j < 5; j++) {
+            for (int k = 0; k < 7; k++)
+                subtract_mean_predictor(&s->coeffs[i][j][k]);
         }
-        av_free(mean);
     }
 
-    s->nns = nns_table[s->nnsparam];
-    s->xdia = xdia_table[s->nsize];
-    s->ydia = ydia_table[s->nsize];
-    s->asize = xdia_table[s->nsize] * ydia_table[s->nsize];
-
-    s->max_value = 65535 >> 8;
+    s->prescreen_buf = av_calloc(s->nb_threads * s->planewidth[0], sizeof(*s->prescreen_buf));
+    if (!s->prescreen_buf)
+        return AVERROR(ENOMEM);
 
-    select_functions(s);
+    s->input_size = (s->planewidth[0] + 64) * (s->planeheight[0] + 6);
+    s->input_buf = av_calloc(s->nb_threads * s->input_size, sizeof(*s->input_buf));
+    if (!s->input_buf)
+        return AVERROR(ENOMEM);
 
-    s->fdsp = avpriv_float_dsp_alloc(0);
-    if (!s->fdsp)
-        ret = AVERROR(ENOMEM);
+    s->output_buf = av_calloc(s->nb_threads * s->input_size, sizeof(*s->output_buf));
+    if (!s->output_buf)
+        return AVERROR(ENOMEM);
 
-fail:
-    av_free(bdata);
-    return ret;
+    return 0;
 }
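
read_bytes/read_words (and their write counterparts) are defined earlier in the patch; the switch in config_input() only selects between 8-bit and 16-bit sample access, with in_scale/out_scale mapping higher bit depths onto the 8-bit range used by the network. A rough sketch of the general technique, not the patch's exact code:

    /* hypothetical reader for >8-bit formats: convert one plane region to
       float and scale it down to the 8-bit range */
    static void sketch_read_words(const uint8_t *srcp, float *dst,
                                  int src_stride, int dst_stride,
                                  int width, int height, float in_scale)
    {
        const uint16_t *src = (const uint16_t *)srcp;

        src_stride /= 2; /* bytes -> samples */
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++)
                dst[x] = src[x] * in_scale; /* e.g. 10-bit * 1/4 -> 8-bit range */
            dst += dst_stride;
            src += src_stride;
        }
    }

The 8-bit variant is the same loop without the scaling, and the writers apply out_scale in the opposite direction.
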
 
 static av_cold void uninit(AVFilterContext *ctx)
 {
     NNEDIContext *s = ctx->priv;
-    int i;
-
-    av_freep(&s->weights0);
 
-    for (i = 0; i < 2; i++)
-        av_freep(&s->weights1[i]);
+    av_freep(&s->prescreen_buf);
+    av_freep(&s->input_buf);
+    av_freep(&s->output_buf);
+    av_freep(&s->fdsp);
 
-    for (i = 0; i < s->nb_planes; i++) {
-        av_freep(&s->frame_data.paddedp[i]);
-        av_freep(&s->frame_data.lcount[i]);
+    for (int i = 0; i < 2; i++) {
+        for (int j = 0; j < 5; j++) {
+            for (int k = 0; k < 7; k++) {
+                av_freep(&s->coeffs[i][j][k].data);
+            }
+        }
     }
 
-    av_freep(&s->frame_data.input);
-    av_freep(&s->frame_data.temp);
-    av_freep(&s->fdsp);
     av_frame_free(&s->second);
 }
 
@@ -1207,5 +1207,6 @@ AVFilter ff_vf_nnedi = {
     .query_formats = query_formats,
     .inputs        = inputs,
     .outputs       = outputs,
-    .flags         = AVFILTER_FLAG_SUPPORT_TIMELINE_INTERNAL,
+    .flags         = AVFILTER_FLAG_SUPPORT_TIMELINE_INTERNAL | AVFILTER_FLAG_SLICE_THREADS,
+    .process_command = ff_filter_process_command,
 };
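
With AVFILTER_FLAG_SLICE_THREADS set and process_command wired to the generic ff_filter_process_command(), the AVOption-backed options (all except weights) can now be changed while the graph is running. A possible invocation, assuming the standard sendcmd syntax and the filter's existing option names:

    ffmpeg -i input.mkv -vf "sendcmd=c='5.0 nnedi qual 2',nnedi" output.mkv

Here the quality level is raised to 2 five seconds into the stream, and the filter picks up the new value on the next processed frame.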


