~mht/merge-asm

71237a00411de371c628117659319255fc1e9bd3 — Martin Hafskjold Thoresen 1 year, 1 month ago 9aba901
Clean up a little
3 files changed, 91 insertions(+), 63 deletions(-)

A Makefile
A README.md
M main.c
A Makefile => Makefile +2 -0
@@ 0,0 1,2 @@
CC=gcc
CFLAGS=-O3 -march=native -masm=intel -Wall -Wpedantic

A README.md => README.md +21 -0
@@ 0,0 1,21 @@
# Merge Codegen Experiment

See [the blog post](https://mht.technology/post/merge).

To build, type

```
make main
```

See the `Makefile` for flags used.

To run the branching benchmark, run `./main branch`.
To run the sorting benchmark, run `./main sort`.

To get the stats from a run, use the `stats.py` script,
for instance like this:

```bash
make main && ./main sort | tee stats-sort | python stats.py
```

M main.c => main.c +68 -63
@@ 15,10 15,7 @@ static uint64_t YS[N];
static uint64_t ZS[2 * N];
static uint64_t TRUTH[2 * N];

// Forward declarations of two routines that share one signature. Judging by
// the call `branching(XS, N, YS, N, ZS, 2 * N)`, the arguments appear to be
// (first input, its length, second input, its length, output, output length);
// parameter names are omitted here, so confirm against the definitions.
void asm_nb_rev(uint64_t *, size_t, uint64_t *, size_t, uint64_t *, size_t);
void branching(uint64_t *, size_t, uint64_t *, size_t, uint64_t *, size_t);

// TODO: read and inline this into `take_time`.
// Get the difference between two `time`s.
time diff(time start, time end) {
  time temp;
  if ((end.tv_nsec - start.tv_nsec) < 0) {


@@ 31,6 28,17 @@ time diff(time start, time end) {
  return temp;
}

// Time the execution of the expression `fun`, and print the elapsed time to
// `stdout` as `<seconds>.<nanoseconds>`, nanoseconds zero-padded to 9 digits.
//
// The clock used is CLOCK_PROCESS_CPUTIME_ID, read immediately before and
// after `fun`; the two readings are subtracted with `diff`.
//
// Assumes `time` is an alias for `struct timespec` (it is passed to
// `clock_gettime`), so `tv_sec` is a `time_t` and `tv_nsec` is a `long`.
// Neither is a `size_t`, so they are cast and printed with `%lld` / `%ld`
// instead of `%zu`, which would be a format/argument mismatch.
//
// The expansion is a bare `{ ... }` block, so `TIME(x);` and `TIME(x)` are
// both valid statements at existing call sites.
#define TIME(fun)                                                              \
  {                                                                            \
    time time1, time2;                                                         \
    clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &time1);                           \
    fun;                                                                       \
    clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &time2);                           \
    time d = diff(time1, time2);                                               \
    printf("%lld.%09ld\n", (long long)d.tv_sec, (long)d.tv_nsec);              \
  }

// Comparison for `qsort`
int uint64_t_cmp(const void *a, const void *b) {
  return *((uint64_t *)a) - *((uint64_t *)b);


@@ 241,9 249,8 @@ void asm_nb_rev(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax,
      "and %[u], %[j];"
      "test %[u], %[k];"
      "jnz 1b;"
      : [minxy] "=&r"(minxy), [xi] "+&r"(xi), [yj] "+&r"(yj),
        [t] "=&r"(t), [i] "+&r"(i), [j] "+&r"(j), [k] "+&r"(k),
        [u] "=&r"(u), [zse] "+&r"(zse)
      : [minxy] "=&r"(minxy), [xi] "+&r"(xi), [yj] "+&r"(yj), [t] "=&r"(t),
        [i] "+&r"(i), [j] "+&r"(j), [k] "+&r"(k), [u] "=&r"(u), [zse] "+&r"(zse)
      : [xse] "r"(xse), [yse] "r"(yse), [one] "r"(one)
      : "memory");



@@ 253,16 260,6 @@ void asm_nb_rev(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax,
    memcpy(zse + k, xse + i, -8 * k);
}

// Time the execution of the expression `fun`, and print the elapsed time to
// `stdout` as `<seconds>.<nanoseconds>`, nanoseconds zero-padded to 9 digits.
//
// Assumes `time` is an alias for `struct timespec` (it is passed to
// `clock_gettime`), so `tv_sec` is a `time_t` and `tv_nsec` is a `long`.
// Neither is a `size_t`, so they are cast and printed with `%lld` / `%ld`
// instead of `%zu`, which would be a format/argument mismatch.
#define TIME(fun)                                                              \
  {                                                                            \
    time time1, time2;                                                         \
    clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &time1);                           \
    fun;                                                                       \
    clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &time2);                           \
    time d = diff(time1, time2);                                               \
    printf("%lld.%09ld\n", (long long)d.tv_sec, (long)d.tv_nsec);              \
  }

// Check that the resulting ZS is the same as TRUTH. If not, print out everything.
void check(void) {
  for (int i = 0; i < 2 * N; i++) {


@@ 285,34 282,35 @@ void check(void) {
  }
}

// Define a top-down merge sort built on the merge routine `variant`.
//
// `SORT(variant)` expands to two functions:
//   - `sort_<variant>2(xs, n, buf, into_buf)` recursively sorts the `n`
//     elements at `xs`, using the `n` elements at `buf` as the other half of
//     a ping-pong pair. On return `*into_buf` is 1 when the sorted run is in
//     `buf` and 0 when it is in `xs`.
//   - `sort_<variant>(xs, n, buf)` sorts `xs` and copies the result back from
//     `buf` when that is where it ended up.
//
// `variant` is called as `variant(a, len_a, b, len_b, dst, len_dst)`.
//
// The two halves can finish in different arrays when their sizes differ
// (e.g. n == 3: the 1-element left half lands in `buf`, the 2-element right
// half lands in `xs`). The right half is therefore copied next to the left
// half before merging; previously its flag was ignored and stale data was
// merged.
#define SORT(variant)                                                \
void sort_##variant##2(uint64_t *xs, size_t n, uint64_t *buf, int *into_buf) {\
  switch (n) {                                                       \
  case 1:                                                            \
    *buf = *xs;                                                      \
    /* fallthrough */                                                \
  case 0:                                                            \
    *into_buf = 1;                                                   \
    return;                                                          \
  default:                                                           \
    break;                                                           \
  }                                                                  \
  size_t h = n / 2;                                                  \
  int a_ib, b_ib;                                                    \
  sort_##variant##2(xs, h, buf, &a_ib);                              \
  sort_##variant##2(xs + h, n - h, buf + h, &b_ib);                  \
  /* Bring the right half into the array that holds the left half. */\
  if (a_ib && !b_ib) {                                               \
    memcpy(buf + h, xs + h, (n - h) * sizeof *xs);                   \
  } else if (!a_ib && b_ib) {                                        \
    memcpy(xs + h, buf + h, (n - h) * sizeof *xs);                   \
  }                                                                  \
  *into_buf = a_ib ^ 1;                                              \
  if (a_ib == 1) {                                                   \
    variant(buf, h, buf + h, n - h, xs, n);                          \
  } else {                                                           \
    variant(xs, h, xs + h, n - h, buf, n);                           \
  }                                                                  \
}                                                                    \
                                                                     \
void sort_##variant(uint64_t *xs, size_t n, uint64_t *buf) {         \
  int buf_contains_result;                                           \
  sort_##variant##2(xs, n, buf, &buf_contains_result);               \
  if (buf_contains_result) {                                         \
    memcpy(xs, buf, n * sizeof *xs);                                 \
  }                                                                  \
}
// Define a top-down merge sort built on the merge routine `variant`.
//
// `SORT(variant)` expands to two functions:
//   - `sort_<variant>2(xs, n, buf, into_buf)` recursively sorts the `n`
//     elements at `xs`, using the `n` elements at `buf` as the other half of
//     a ping-pong pair. On return `*into_buf` is 1 when the sorted run is in
//     `buf` and 0 when it is in `xs`.
//   - `sort_<variant>(xs, n, buf)` sorts `xs` and copies the result back from
//     `buf` when that is where it ended up.
//
// `variant` is called as `variant(a, len_a, b, len_b, dst, len_dst)`.
//
// The two halves can finish in different arrays when their sizes differ
// (e.g. n == 3: the 1-element left half lands in `buf`, the 2-element right
// half lands in `xs`). The right half is therefore copied next to the left
// half before merging; previously its flag was ignored and stale data was
// merged.
#define SORT(variant)                                                          \
  void sort_##variant##2(uint64_t *xs, size_t n, uint64_t *buf,                \
                         int *into_buf) {                                      \
    switch (n) {                                                               \
    case 1:                                                                    \
      *buf = *xs;                                                              \
      /* fallthrough */                                                        \
    case 0:                                                                    \
      *into_buf = 1;                                                           \
      return;                                                                  \
    default:                                                                   \
      break;                                                                   \
    }                                                                          \
    size_t h = n / 2;                                                          \
    int a_ib, b_ib;                                                            \
    sort_##variant##2(xs, h, buf, &a_ib);                                      \
    sort_##variant##2(xs + h, n - h, buf + h, &b_ib);                          \
    /* Bring the right half into the array that holds the left half. */        \
    if (a_ib && !b_ib) {                                                       \
      memcpy(buf + h, xs + h, (n - h) * sizeof *xs);                           \
    } else if (!a_ib && b_ib) {                                                \
      memcpy(xs + h, buf + h, (n - h) * sizeof *xs);                           \
    }                                                                          \
    *into_buf = a_ib ^ 1;                                                      \
    if (a_ib == 1) {                                                           \
      variant(buf, h, buf + h, n - h, xs, n);                                  \
    } else {                                                                   \
      variant(xs, h, xs + h, n - h, buf, n);                                   \
    }                                                                          \
  }                                                                            \
                                                                               \
  void sort_##variant(uint64_t *xs, size_t n, uint64_t *buf) {                 \
    int buf_contains_result;                                                   \
    sort_##variant##2(xs, n, buf, &buf_contains_result);                       \
    if (buf_contains_result) {                                                 \
      memcpy(xs, buf, n * sizeof *xs);                                         \
    }                                                                          \
  }

SORT(branching)
SORT(nonbranching_but_branching)


@@ 322,13 320,14 @@ SORT(nonbranching_reverse)
SORT(nonbranching_reverse_ternary)
SORT(asm_nb_rev)

// Run one sort routine `v` under the timer: clear all of ZS, print the
// routine's name followed by ": ", time the call `v(XS, N, ZS)`, and finally
// copy the first N elements of YS over XS (presumably restoring the input
// for the next routine; confirm against the caller).
#define TIME_SORT(v)\
  memset(ZS, 0, sizeof ZS);\
  printf("%s: ", #v);\
  TIME(v(XS, N, ZS));\
  for (int idx = 0; idx < N; idx++) { XS[idx] = YS[idx]; }
// Run one sort routine `v` under the timer: clear all of ZS, print the
// routine's name followed by ": ", time the call `v(XS, N, ZS)`, and finally
// copy the first N elements of YS over XS (presumably restoring the input
// for the next routine; confirm against the caller).
#define TIME_SORT(v)                                                           \
  memset(ZS, 0, sizeof ZS);                                                    \
  printf("%s: ", #v);                                                          \
  TIME(v(XS, N, ZS));                                                          \
  for (int idx = 0; idx < N; idx++) {                                          \
    XS[idx] = YS[idx];                                                         \
  }

void sort_time(void) {
void run_sort(void) {
  for (int i = 0; i < 10; i++) {
    srand(i);
    for (int i = 0; i < N; i++) {


@@ 346,23 345,15 @@ void sort_time(void) {
  }
}

int main(void) {

  // sort_time();
  // return 0;

  fprintf(stderr, "Generating numbers and sorting them\n");
  prepare(0);
  fprintf(stderr, "Computing TRUTH:\n");
  branching(XS, N, YS, N, ZS, 2 * N);
  memcpy(TRUTH, ZS, 8 * 2 * N);

  fprintf(stderr, "Running:\n");
void run_branch() {
  for (int s = 0; s < 10; s++) {
    fprintf(stderr, "  seed=%d:\n", s);
    fprintf(stderr, "Generating numbers and sorting them\n");
    prepare(s);
    branching(XS, N, YS, N, ZS, 2 * N);
    memcpy(TRUTH, ZS, 8 * 2 * N);

    fprintf(stderr, "Running:\n");
    for (int i = 0; i < 10; i++) {

      memset(ZS, 0, 8 * 2 * N);


@@ 401,6 392,20 @@ int main(void) {
      check();
    }
  }
}

// Program entry point. Expects exactly one argument naming the benchmark:
// "branch" runs `run_branch`, "sort" runs `run_sort`. Any other invocation
// prints a usage line to `stderr` and exits with status 1.
int main(int argc, char **argv) {
  // Only look at argv[1] when exactly one argument was supplied.
  const char *mode = (argc == 2) ? argv[1] : NULL;

  if (mode != NULL && strcmp("branch", mode) == 0) {
    run_branch();
    return 0;
  }
  if (mode != NULL && strcmp("sort", mode) == 0) {
    run_sort();
    return 0;
  }

  fprintf(stderr, "Usage: `main <branch|sort>`\n");
  return 1;
}