// gcc swscale.c -W -Wall -std=c99 -lswscale -lavutil

#include <assert.h>
#include <inttypes.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <libavutil/mem.h>
#include <libavutil/opt.h>
#include <libavutil/pixfmt.h>
#include <libswscale/swscale.h>

/*
 * Test image geometry. Every plane buffer is STRIDE_RGBA * H bytes, which is
 * large enough for one packed RGBA image or one 8-bit YUV 4:4:4 plane.
 */
enum {
    W           = 128,   /* image width in pixels */
    H           = 128,   /* image height in pixels */
    STRIDE_YUV  = W,     /* bytes per row of one YUV444P plane */
    STRIDE_RGBA = W * 4, /* bytes per row of packed RGBA (4 bytes/pixel) */
};

void scale(uint8_t *dst[3], uint8_t *src[3], int s_yuv, int s_range,
           int d_yuv, int d_range, int scale)
{
    struct SwsContext *sws = sws_alloc_context();

    int exact = SWS_FULL_CHR_H_INT | SWS_FULL_CHR_H_INP |
                SWS_ACCURATE_RND | SWS_BITEXACT;
    av_opt_set_int(sws, "sws_flags", exact | SWS_POINT /* | SWS_PRINT_INFO*/, 0);

    av_opt_set_int(sws, "srcw", W, 0);
    av_opt_set_int(sws, "srch", H, 0);
    av_opt_set_int(sws, "src_format", s_yuv ? PIX_FMT_YUV444P : PIX_FMT_RGBA, 0);

    av_opt_set_int(sws, "dstw", scale ? W/2 : W, 0);
    av_opt_set_int(sws, "dsth", scale ? H/2 : H, 0);
    av_opt_set_int(sws, "dst_format", d_yuv ? PIX_FMT_YUV444P : PIX_FMT_RGBA, 0);

    // needed? maybe not
    //av_opt_set_int(sws, "src_range", !!s_range, 0);
    //av_opt_set_int(sws, "dst_range", !!d_range, 0);

    sws_setColorspaceDetails(sws, sws_getCoefficients(SWS_CS_ITU709), !!s_range,
                             sws_getCoefficients(SWS_CS_ITU709), !!d_range,
                             0, 1 << 16, 1 << 16);

    int res = sws_init_context(sws, NULL, NULL);
    assert(res >= 0);

    int s_stride = s_yuv ? STRIDE_YUV : STRIDE_RGBA;
    int d_stride = d_yuv ? STRIDE_YUV : STRIDE_RGBA;
    int s_stride_a[3] = {s_stride, s_stride, s_stride};
    int d_stride_a[3] = {d_stride, d_stride, d_stride};
    sws_scale(sws, (const unsigned char *const *)src, s_stride_a, 0, H, dst, d_stride_a);
    sws_freeContext(sws);
}

/*
 * Fill the test source image with a flat level.
 *
 * yuv false: plane 0 is packed RGBA; every byte (alpha included) becomes val.
 * yuv true : plane 0 (Y) becomes val, and planes 1 and 2 (chroma) become 128,
 *            the mid-scale 8-bit chroma value (no colour).
 */
void set_white(uint8_t *c[3], bool yuv, int val)
{
    if (!yuv) {
        memset(c[0], val, STRIDE_RGBA * H);
        return;
    }

    memset(c[0], val, STRIDE_YUV * H);
    for (int plane = 1; plane < 3; plane++)
        memset(c[plane], 128, STRIDE_YUV * H);
}

/*
 * Read back one 8-bit sample from the converted image at pixel (W/4, H/4).
 * That pixel lies inside both the full-size (W x H) and the downscaled
 * (W/2 x H/2) output, so a single probe position serves every test mode.
 *
 * yuv true : returns the luma sample from plane 0 (row stride STRIDE_YUV).
 * yuv false: returns byte 1 of the 4-byte pixel in plane 0 (row stride
 *            STRIDE_RGBA), i.e. the G byte when the plane holds packed RGBA.
 */
int get_white(uint8_t *c[3], bool yuv)
{
    const int col = W / 4;
    const int row = H / 4;
    const uint8_t *plane = c[0];

    if (!yuv)
        return plane[row * STRIDE_RGBA + col * 4 + 1];
    return plane[row * STRIDE_YUV + col];
}

/* Source image planes, filled by set_white() and read by scale(). */
uint8_t *a[3];
/* Destination image planes, written by scale() and probed by get_white(). */
uint8_t *b[3];

/*
 * Run one conversion case on the global buffers a -> b and print one row.
 *
 * The low six bits of mode select the case:
 *   bit 0 (1)  destination is full range   bit 1 (2)  destination is YUV
 *   bit 2 (4)  source is full range        bit 3 (8)  source is YUV
 *   bit 4 (16) downscale the output        bit 5 (32) input level 255, else 200
 * Higher bits are ignored. The row shows the case and "input -> output" level.
 */
void test(unsigned int mode)
{
    const bool dst_full  = (mode & 1u) != 0;
    const bool dst_yuv   = (mode & 2u) != 0;
    const bool src_full  = (mode & 4u) != 0;
    const bool src_yuv   = (mode & 8u) != 0;
    const bool downscale = (mode & 16u) != 0;

    int in_level;
    if (mode & 32u)
        in_level = 255;
    else
        in_level = 200;

    set_white(a, src_yuv, in_level);
    scale(b, a, src_yuv, src_full, dst_yuv, dst_full, downscale);
    const int out_level = get_white(b, dst_yuv);

    printf("%-9s", downscale ? "scale" : "no-scale");
    printf("%s", src_yuv ? "yuv" : "rgb");
    printf("-%s ", src_full ? "ful" : "lim");
    printf("%s", dst_yuv ? "yuv" : "rgb");
    printf("-%s ", dst_full ? "ful" : "lim");
    printf(" %d -> %d\n", in_level, out_level);
}

/*
 * Allocate the global a/b plane buffers and run all 64 test modes, printing
 * a blank line after every group of 16. The buffers are freed on every exit
 * path. Returns EXIT_FAILURE if a buffer cannot be allocated, otherwise
 * EXIT_SUCCESS.
 */
int main(int argc, char **argv)
{
    /* No command-line options; silences -W unused-parameter warnings. */
    (void)argc;
    (void)argv;

    int status = EXIT_SUCCESS;

    /* STRIDE_RGBA * H bytes per plane fits one packed RGBA image or one
       YUV444P plane, so the same buffers serve every mode. */
    for (int i = 0; i < 3; i++) {
        a[i] = av_malloc(STRIDE_RGBA * H);
        b[i] = av_malloc(STRIDE_RGBA * H);
        if (!a[i] || !b[i]) {
            fprintf(stderr, "out of memory\n");
            status = EXIT_FAILURE;
            goto out;
        }
    }

    for (int i = 0; i < 64; i++) {
        test(i);
        if (i % 16 == 15)
            printf("\n");
    }

out:
    /* av_freep() accepts NULL entries and resets each pointer to NULL. */
    for (int i = 0; i < 3; i++) {
        av_freep(&a[i]);
        av_freep(&b[i]);
    }
    return status;
}
