71237a00411de371c628117659319255fc1e9bd3 — Martin Hafskjold Thoresen 5 months ago 9aba901 master
Clean up a little
3 files changed, 91 insertions(+), 63 deletions(-)

A Makefile
A README.md
M main.c
A Makefile => Makefile +2 -0
@@ 0,0 1,2 @@
+ CC=gcc
+ CFLAGS=-O3 -march=native -masm=intel -Wall -Wpedantic

A README.md => README.md +21 -0
@@ 0,0 1,21 @@
+ # Merge Codegen Experiment
+ 
+ See [the blog post](https://mht.technology/post/merge).
+ 
+ To run, type 
+ 
+ ```
+ make main
+ ```
+ 
+ See the `Makefile` for flags used.
+ 
+ To run the branching benchmark, run `./main branch`.
+ To run the sorting benchmark, run `./main sorting`.
+ 
+ To get the stats from a run, use the `stats.py` script,
+ for instance like this:
+ 
+ ```bash
+ make main && ./main sort | tee stats-sort | python stats.py
+ ```

M main.c => main.c +68 -63
@@ 15,10 15,7 @@ static uint64_t ZS[2 * N];
  static uint64_t TRUTH[2 * N];
  
- void asm_nb_rev(uint64_t *, size_t, uint64_t *, size_t, uint64_t *, size_t);
- void branching(uint64_t *, size_t, uint64_t *, size_t, uint64_t *, size_t);
- 
- // TODO: read and inline this into `take_time`.
+ // Get the difference between two `time`s.
  time diff(time start, time end) {
    time temp;
    if ((end.tv_nsec - start.tv_nsec) < 0) {


@@ 31,6 28,17 @@ return temp;
  }
  
+ // Time the execution of the expression `fun`, and print it to `stdout`.
+ #define TIME(fun)                                                              \
+   {                                                                            \
+     time time1, time2;                                                         \
+     clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &time1);                           \
+     fun;                                                                       \
+     clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &time2);                           \
+     time d = diff(time1, time2);                                               \
+     printf("%zu.%09zu\n", d.tv_sec, d.tv_nsec);                                \
+   }
+ 
  // Comparison for `qsort`
  int uint64_t_cmp(const void *a, const void *b) {
    return *((uint64_t *)a) - *((uint64_t *)b);


@@ 241,9 249,8 @@ "and %[u], %[j];"
        "test %[u], %[k];"
        "jnz 1b;"
-       : [minxy] "=&r"(minxy), [xi] "+&r"(xi), [yj] "+&r"(yj),
-         [t] "=&r"(t), [i] "+&r"(i), [j] "+&r"(j), [k] "+&r"(k),
-         [u] "=&r"(u), [zse] "+&r"(zse)
+       : [minxy] "=&r"(minxy), [xi] "+&r"(xi), [yj] "+&r"(yj), [t] "=&r"(t),
+         [i] "+&r"(i), [j] "+&r"(j), [k] "+&r"(k), [u] "=&r"(u), [zse] "+&r"(zse)
        : [xse] "r"(xse), [yse] "r"(yse), [one] "r"(one)
        : "memory");
  


@@ 253,16 260,6 @@ memcpy(zse + k, xse + i, -8 * k);
  }
  
- #define TIME(fun)                                                              \
-   {                                                                            \
-     time time1, time2;                                                         \
-     clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &time1);                           \
-     fun;                                                                       \
-     clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &time2);                           \
-     time d = diff(time1, time2);                                               \
-     printf("%zu.%09zu\n", d.tv_sec, d.tv_nsec);                                \
-   }
- 
  // Check that the resuling ZS is the same as TRUTH. If not, print of everything.
  void check(void) {
    for (int i = 0; i < 2 * N; i++) {


@@ 285,34 282,35 @@ }
  }
  
- #define SORT(variant)                                                \
- void sort_##variant##2(uint64_t *xs, size_t n, uint64_t *buf, int *into_buf) {\
-   switch (n) {                                                       \
-   case 1:                                                            \
-     *buf = *xs;                                                      \
-   case 0:                                                            \
-     *into_buf = 1;                                                   \
-     return;                                                          \
-   }                                                                  \
-   size_t h = n / 2;                                                  \
-   int a_ib, b_ib;                                                    \
-   sort_##variant##2(xs, h, buf, &a_ib);                              \
-   sort_##variant##2(xs + h, n - h, buf + h, &b_ib);                  \
-   *into_buf = a_ib ^ 1;                                              \
-   if (a_ib == 1) {                                                   \
-     variant(buf, h, buf + h, n - h, xs, n);                          \
-   } else {                                                           \
-     variant(xs, h, xs + h, n - h, buf, n);                           \
-   }                                                                  \
- }                                                                    \
-                                                                      \
- void sort_##variant(uint64_t *xs, size_t n, uint64_t *buf) {       \
-   int buf_contains_result;                                           \
-   sort_##variant##2(xs, n, buf, &buf_contains_result);               \
-   if (buf_contains_result) {                                         \
-     memcpy(xs, buf, 8 * n);                                          \
-   }                                                                  \
- }
+ #define SORT(variant)                                                          \
+   void sort_##variant##2(uint64_t * xs, size_t n, uint64_t * buf,              \
+                          int *into_buf) {                                      \
+     switch (n) {                                                               \
+     case 1:                                                                    \
+       *buf = *xs;                                                              \
+     case 0:                                                                    \
+       *into_buf = 1;                                                           \
+       return;                                                                  \
+     }                                                                          \
+     size_t h = n / 2;                                                          \
+     int a_ib, b_ib;                                                            \
+     sort_##variant##2(xs, h, buf, &a_ib);                                      \
+     sort_##variant##2(xs + h, n - h, buf + h, &b_ib);                          \
+     *into_buf = a_ib ^ 1;                                                      \
+     if (a_ib == 1) {                                                           \
+       variant(buf, h, buf + h, n - h, xs, n);                                  \
+     } else {                                                                   \
+       variant(xs, h, xs + h, n - h, buf, n);                                   \
+     }                                                                          \
+   }                                                                            \
+                                                                                \
+   void sort_##variant(uint64_t *xs, size_t n, uint64_t *buf) {                 \
+     int buf_contains_result;                                                   \
+     sort_##variant##2(xs, n, buf, &buf_contains_result);                       \
+     if (buf_contains_result) {                                                 \
+       memcpy(xs, buf, 8 * n);                                                  \
+     }                                                                          \
+   }
  
  SORT(branching)
  SORT(nonbranching_but_branching)


@@ 322,13 320,14 @@ SORT(nonbranching_reverse_ternary)
  SORT(asm_nb_rev)
  
- #define TIME_SORT(v)\
-   memset(ZS, 0, 8 * 2 * N);\
-   printf("%s: ", #v);\
-   TIME(v(XS, N, ZS));\
-   for (int i = 0; i < N; i++) XS[i] = YS[i];
+ #define TIME_SORT(v)                                                           \
+   memset(ZS, 0, 8 * 2 * N);                                                    \
+   printf("%s: ", #v);                                                          \
+   TIME(v(XS, N, ZS));                                                          \
+   for (int i = 0; i < N; i++)                                                  \
+     XS[i] = YS[i];
  
- void sort_time(void) {
+ void run_sort(void) {
    for (int i = 0; i < 10; i++) {
      srand(i);
      for (int i = 0; i < N; i++) {


@@ 346,23 345,15 @@ }
  }
  
- int main(void) {
- 
-   // sort_time();
-   // return 0;
- 
-   fprintf(stderr, "Generating numbers and sorting them\n");
-   prepare(0);
-   fprintf(stderr, "Computing TRUTH:\n");
-   branching(XS, N, YS, N, ZS, 2 * N);
-   memcpy(TRUTH, ZS, 8 * 2 * N);
- 
-   fprintf(stderr, "Running:\n");
+ void run_branch() {
    for (int s = 0; s < 10; s++) {
      fprintf(stderr, "  seed=%d:\n", s);
+     fprintf(stderr, "Generating numbers and sorting them\n");
      prepare(s);
      branching(XS, N, YS, N, ZS, 2 * N);
      memcpy(TRUTH, ZS, 8 * 2 * N);
+ 
+     fprintf(stderr, "Running:\n");
      for (int i = 0; i < 10; i++) {
  
        memset(ZS, 0, 8 * 2 * N);


@@ 401,6 392,20 @@ check();
      }
    }
+ }
  
+ int main(int argc, char **argv) {
+   if (argc != 2) {
+     fprintf(stderr, "Usage: `main <branch|sort>`\n");
+     return 1;
+   }
+   if (strcmp("branch", argv[1]) == 0) {
+     run_branch();
+   } else if (strcmp("sort", argv[1]) == 0) {
+     run_sort();
+   } else {
+     fprintf(stderr, "Usage: `main <branch|sort>`\n");
+     return 1;
+   }
    return 0;
  }