#include <assert.h>
#include <stdarg.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include "env.h"
#include "eval.h"
#include "expr.h"
#include "utils.h"
/*
 * Decide whether an already-evaluated expression counts as "true" for the
 * conditional/logical builtins (if, and, or, not).
 *
 * Only a non-empty list or a non-zero integer is truthy; every other kind
 * of expression (strings, pairs, symbols, ...) is treated as false.
 */
static bool
is_truthy(struct expression *expr)
{
	switch (expr->kind) {
	case EXPR_LIST:
		return expr->list != NULL;
	case EXPR_INTEGER:
		return expr->integer != 0;
	default:
		/* TODO: more rules? */
		return false;
	}
}
static struct expression *
env_add(struct location loc, struct list_expression *list,
struct binding **bindings)
{
intmax_t res = 0;
while (list) {
struct expression *value = eval(list->expr, bindings);
if (value->kind != EXPR_INTEGER) {
error(value->loc, "Cannot perform addition on non-integers");
}
res += value->integer;
list = list->next;
}
struct expression *new = mkexpression();
new->kind = EXPR_INTEGER;
new->integer = res;
return new;
}
static struct expression *
env_sub(struct location loc, struct list_expression *list,
struct binding **bindings)
{
if (!list) {
error(loc, "Expected at least one argument to subtract");
}
struct expression *first = eval(list->expr, bindings);
list = list->next;
intmax_t res = first->integer;
while (list) {
struct expression *value = eval(list->expr, bindings);
if (value->kind != EXPR_INTEGER) {
error(value->loc, "Cannot perform subtraction on non-integers");
}
res -= value->integer;
list = list->next;
}
struct expression *new = mkexpression();
new->kind = EXPR_INTEGER;
new->integer = res;
return new;
}
static struct expression *
env_mul(struct location loc, struct list_expression *list,
struct binding **bindings)
{
intmax_t res = 1;
while (list) {
struct expression *value = eval(list->expr, bindings);
if (value->kind != EXPR_INTEGER) {
error(value->loc, "Cannot perform multiplication on non-integers");
}
res *= value->integer;
list = list->next;
}
struct expression *new = mkexpression();
new->kind = EXPR_INTEGER;
new->integer = res;
return new;
}
static struct expression *
env_div(struct location loc, struct list_expression *list,
struct binding **bindings)
{
if (!list) {
error(loc, "Expected at least one argument to divide");
}
struct expression *first = eval(list->expr, bindings);
list = list->next;
intmax_t res = first->integer;
while (list) {
struct expression *value = eval(list->expr, bindings);
if (value->kind != EXPR_INTEGER) {
error(value->loc, "Cannot perform division on non-integers");
}
if (value->integer == 0) {
error(value->loc, "Cannot divide by zero");
}
res /= value->integer;
list = list->next;
}
struct expression *new = mkexpression();
new->kind = EXPR_INTEGER;
new->integer = res;
return new;
}
/*
 * Relational operator selected by each comparison builtin and handed to
 * cmp(). The explicit values match the original implicit numbering.
 */
enum cmp_kind {
	CMP_EQ = 0, /* (=  a b) */
	CMP_GE = 1, /* (>= a b) */
	CMP_GT = 2, /* (>  a b) */
	CMP_LE = 3, /* (<= a b) */
	CMP_LT = 4, /* (<  a b) */
};
static struct expression *
cmp(struct location loc, struct list_expression *list,
struct binding **bindings, enum cmp_kind cmp)
{
if (!list || !list->next) {
error(loc, "Not enough arguments");
}
struct expression *first = eval(list->expr, bindings);
struct expression *second = eval(list->next->expr, bindings);
if (first->kind != EXPR_INTEGER || second->kind != EXPR_INTEGER) {
error(loc, "Cannot compare non-integers");
}
bool res;
switch (cmp) {
case CMP_EQ:
res = first->integer == second->integer;
break;
case CMP_GE:
res = first->integer >= second->integer;
break;
case CMP_GT:
res = first->integer > second->integer;
break;
case CMP_LE:
res = first->integer <= second->integer;
break;
case CMP_LT:
res = first->integer < second->integer;
break;
}
struct expression *new = mkexpression();
new->kind = EXPR_INTEGER;
new->integer = res ? 1 : 0;
return new;
}
static struct expression *
env_eq(struct location loc, struct list_expression *list,
struct binding **bindings)
{
return cmp(loc, list, bindings, CMP_EQ);
}
static struct expression *
env_ge(struct location loc, struct list_expression *list,
struct binding **bindings)
{
return cmp(loc, list, bindings, CMP_GE);
}
static struct expression *
env_gt(struct location loc, struct list_expression *list,
struct binding **bindings)
{
return cmp(loc, list, bindings, CMP_GT);
}
static struct expression *
env_le(struct location loc, struct list_expression *list,
struct binding **bindings)
{
return cmp(loc, list, bindings, CMP_LE);
}
static struct expression *
env_lt(struct location loc, struct list_expression *list,
struct binding **bindings)
{
return cmp(loc, list, bindings, CMP_LT);
}
/*
 * (and a b ...) -- 1 if every argument is truthy, otherwise 0.
 *
 * Arguments are evaluated left to right and evaluation stops at the first
 * falsy one, so later arguments are not evaluated. (and) yields 1.
 */
static struct expression *
env_and(struct location loc, struct list_expression *list,
    struct binding **bindings)
{
	struct list_expression *arg = list;
	while (arg != NULL && is_truthy(eval(arg->expr, bindings))) {
		arg = arg->next;
	}
	/* Running off the end of the list means nothing was falsy. */
	struct expression *result = mkexpression();
	result->kind = EXPR_INTEGER;
	result->integer = (arg == NULL) ? 1 : 0;
	return result;
}
/*
 * (not x) -- 1 if x is falsy, 0 if it is truthy.
 * Only the first argument is evaluated; any others are ignored.
 */
static struct expression *
env_not(struct location loc, struct list_expression *list,
    struct binding **bindings)
{
	if (list == NULL) {
		error(loc, "Expected an argument");
	}
	bool truthy = is_truthy(eval(list->expr, bindings));
	struct expression *result = mkexpression();
	result->kind = EXPR_INTEGER;
	result->integer = truthy ? 0 : 1;
	return result;
}
/*
 * (or a b ...) -- 1 if any argument is truthy, otherwise 0.
 *
 * Arguments are evaluated left to right and evaluation stops at the first
 * truthy one, so later arguments are not evaluated. (or) yields 0.
 */
static struct expression *
env_or(struct location loc, struct list_expression *list,
    struct binding **bindings)
{
	struct list_expression *arg = list;
	while (arg != NULL && !is_truthy(eval(arg->expr, bindings))) {
		arg = arg->next;
	}
	/* Stopping early (arg still non-NULL) means a truthy value was found. */
	struct expression *result = mkexpression();
	result->kind = EXPR_INTEGER;
	result->integer = (arg != NULL) ? 1 : 0;
	return result;
}
static struct expression *
env_define(struct location loc, struct list_expression *list,
struct binding **bindings)
{
if (!list || !list->next) {
error(loc, "Not enough arguments");
}
struct expression *handle = list->expr;
if (handle->kind != EXPR_SYMBOL) {
error(loc, "Invalid define; expected a symbol");
}
struct expression *value = eval(list->next->expr, bindings);
if (binding_lookup(*bindings, handle->symbol)) {
error(loc, "'%s' is already defined", handle->symbol);
}
binding_insert(bindings, handle->symbol, value);
return mknil();
}
static struct expression *
env_set(struct location loc, struct list_expression *list,
struct binding **bindings)
{
if (!list || !list->next) {
error(loc, "Not enough arguments");
}
struct expression *handle = list->expr;
if (handle->kind != EXPR_SYMBOL) {
error(loc, "Invalid set!; expected a symbol");
}
struct expression *value = eval(list->next->expr, bindings);
struct binding *bind = binding_lookup(*bindings, handle->symbol);
if (!bind) {
error(loc, "Undefined symbol '%s'", handle->symbol);
}
expression_free(bind->expr);
bind->expr = value;
return mknil();
}
/*
 * (if cond then else) -- evaluate `cond`, then evaluate and return only
 * the branch it selects. Arguments beyond the third are ignored.
 */
static struct expression *
env_if(struct location loc, struct list_expression *list,
    struct binding **bindings)
{
	if (list == NULL || list->next == NULL || list->next->next == NULL) {
		error(loc, "Not enough arguments");
	}
	struct list_expression *then_arg = list->next;
	struct list_expression *else_arg = then_arg->next;
	struct expression *condition = eval(list->expr, bindings);
	struct expression *chosen =
	    is_truthy(condition) ? then_arg->expr : else_arg->expr;
	return eval(chosen, bindings);
}
/*
 * (do a b ...) -- evaluate each argument in order for its side effects,
 * freeing every result, and return nil.
 */
static struct expression *
env_do(struct location loc, struct list_expression *list,
    struct binding **bindings)
{
	for (struct list_expression *step = list; step != NULL; step = step->next) {
		struct expression *discarded = eval(step->expr, bindings);
		expression_free(discarded);
	}
	return mknil();
}
/*
 * (cons a b) -- build a new pair whose first element is the value of `a`
 * and whose second is the value of `b`; `a` is evaluated before `b`.
 */
static struct expression *
env_cons(struct location loc, struct list_expression *list,
    struct binding **bindings)
{
	if (list == NULL || list->next == NULL) {
		error(loc, "Not enough arguments");
	}
	struct expression *cell = mkexpression();
	cell->kind = EXPR_PAIR;
	struct list_expression *head_arg = list;
	struct list_expression *tail_arg = list->next;
	cell->pair[0] = eval(head_arg->expr, bindings);
	cell->pair[1] = eval(tail_arg->expr, bindings);
	return cell;
}
/*
 * (car p) -- return a copy (via expression_dup) of the first element of
 * the pair `p`.
 */
static struct expression *
env_car(struct location loc, struct list_expression *list,
    struct binding **bindings)
{
	if (list == NULL) {
		error(loc, "Not enough arguments");
	}
	struct expression *cell = eval(list->expr, bindings);
	if (cell->kind != EXPR_PAIR) {
		error(loc, "car must be used on a pair");
	}
	struct expression *head = cell->pair[0];
	return expression_dup(head);
}
/*
 * (cdr p) -- return a copy (via expression_dup) of the second element of
 * the pair `p`.
 */
static struct expression *
env_cdr(struct location loc, struct list_expression *list,
    struct binding **bindings)
{
	if (list == NULL) {
		error(loc, "Not enough arguments");
	}
	struct expression *cell = eval(list->expr, bindings);
	if (cell->kind != EXPR_PAIR) {
		error(loc, "cdr must be used on a pair");
	}
	struct expression *tail = cell->pair[1];
	return expression_dup(tail);
}
/*
 * (string-concat a b) -- new string holding `a` followed by `b`.
 *
 * Both arguments must evaluate to strings. The result buffer is filled
 * with memcpy rather than strcat: strcat requires its destination to
 * already hold a NUL-terminated string. A freshly xmalloc'd buffer does
 * not (assuming xmalloc, like malloc, leaves memory uninitialized).
 */
static struct expression *
env_string_concat(struct location loc, struct list_expression *list,
    struct binding **bindings)
{
	/* TODO: varargs */
	if (!list || !list->next) {
		error(loc, "Not enough arguments");
	}
	struct expression *a = eval(list->expr, bindings);
	struct expression *b = eval(list->next->expr, bindings);
	if (a->kind != EXPR_STRING || b->kind != EXPR_STRING) {
		error(loc, "Cannot concatenate non-strings");
	}
	size_t alen = strlen(a->string);
	size_t blen = strlen(b->string);
	char *new = xmalloc(alen + blen + 1);
	memcpy(new, a->string, alen);
	/* blen + 1 also copies b's terminating NUL. */
	memcpy(new + alen, b->string, blen + 1);
	struct expression *expr = mkexpression();
	expr->kind = EXPR_STRING;
	expr->string = new;
	return expr;
}
/*
 * (string-sub str start [end]) -- new string containing the bytes of
 * `str` in the half-open range [start, end).
 *
 * `end` defaults to the length of `str`, so (string-sub "hello" 1) is
 * "ello". This matches the exclusive meaning of an explicit `end`, which
 * may equal the length. The previous default of length - 1 dropped the
 * last character and wrapped around for the empty string.
 *
 * Indices are range-checked as intmax_t before conversion to size_t, so
 * negative values are rejected instead of silently wrapping. The result
 * is always NUL-terminated; the old strncpy call never wrote a
 * terminator here.
 */
static struct expression *
env_string_sub(struct location loc, struct list_expression *list,
    struct binding **bindings)
{
	if (!list || !list->next) {
		error(loc, "Not enough arguments");
	}
	struct expression *str = eval(list->expr, bindings);
	if (str->kind != EXPR_STRING) {
		error(loc, "Expected first argument to be a string");
	}
	size_t slen = strlen(str->string);
	struct expression *estart = eval(list->next->expr, bindings);
	if (estart->kind != EXPR_INTEGER) {
		error(loc, "Index must be a number");
	}
	intmax_t start = estart->integer;
	intmax_t end = (intmax_t)slen;
	if (list->next->next) {
		struct expression *eend = eval(list->next->next->expr, bindings);
		if (eend->kind != EXPR_INTEGER) {
			error(loc, "Index must be a number");
		}
		end = eend->integer;
	}
	if (start < 0 || start > end) {
		error(loc, "Invalid slice");
	}
	/* Here 0 <= start <= end, so the unsigned comparison is safe. */
	if ((uintmax_t)end > slen) {
		error(loc, "End index outside of string");
	}
	size_t len = (size_t)(end - start);
	assert(len <= slen);
	char *new = xmalloc(len + 1);
	memcpy(new, str->string + (size_t)start, len);
	new[len] = '\0';
	struct expression *expr = mkexpression();
	expr->kind = EXPR_STRING;
	expr->string = new;
	return expr;
}
/*
 * (string-len str) -- number of bytes in `str`, as reported by strlen,
 * returned as an integer expression.
 */
static struct expression *
env_string_len(struct location loc, struct list_expression *list,
    struct binding **bindings)
{
	if (list == NULL) {
		error(loc, "Not enough arguments");
	}
	struct expression *arg = eval(list->expr, bindings);
	if (arg->kind != EXPR_STRING) {
		error(loc, "Expected argument to be a string");
	}
	size_t length = strlen(arg->string);
	struct expression *result = mkexpression();
	result->kind = EXPR_INTEGER;
	result->integer = length;
	return result;
}
/*
 * (display a b ...) -- evaluate each argument in order and write it to
 * stdout with no separator; returns nil.
 *
 * Strings are written verbatim with fputs; every other kind of value is
 * rendered by expression_print.
 */
static struct expression *
env_display(struct location loc, struct list_expression *list,
    struct binding **bindings)
{
	for (; list != NULL; list = list->next) {
		struct expression *shown = eval(list->expr, bindings);
		if (shown->kind == EXPR_STRING) {
			fputs(shown->string, stdout);
		} else {
			expression_print(stdout, shown);
		}
	}
	return mknil();
}
/*
 * (newline) -- write a single '\n' to stdout and return nil. Any
 * arguments are ignored and not evaluated.
 */
static struct expression *
env_newline(struct location loc, struct list_expression *list,
    struct binding **bindings)
{
	(void)loc;
	(void)list;
	(void)bindings;
	putchar('\n');
	return mknil();
}
/*
 * Table mapping builtin names to their implementations. The final entry,
 * with both fields NULL, marks the end of the table.
 */
struct env_function
environment[] = {
	/* integer arithmetic */
	{ .name = "+",             .func = env_add },
	{ .name = "-",             .func = env_sub },
	{ .name = "*",             .func = env_mul },
	{ .name = "/",             .func = env_div },
	/* integer comparison */
	{ .name = "=",             .func = env_eq },
	{ .name = ">=",            .func = env_ge },
	{ .name = ">",             .func = env_gt },
	{ .name = "<=",            .func = env_le },
	{ .name = "<",             .func = env_lt },
	/* logic */
	{ .name = "and",           .func = env_and },
	{ .name = "not",           .func = env_not },
	{ .name = "or",            .func = env_or },
	/* bindings and control flow */
	{ .name = "define",        .func = env_define },
	{ .name = "set!",          .func = env_set },
	{ .name = "if",            .func = env_if },
	{ .name = "do",            .func = env_do },
	/* pairs */
	{ .name = "cons",          .func = env_cons },
	{ .name = "car",           .func = env_car },
	{ .name = "cdr",           .func = env_cdr },
	/* strings */
	{ .name = "string-concat", .func = env_string_concat },
	{ .name = "string-sub",    .func = env_string_sub },
	{ .name = "string-len",    .func = env_string_len },
	/* output */
	{ .name = "display",       .func = env_display },
	{ .name = "newline",       .func = env_newline },
	/* end-of-table sentinel */
	{ .name = NULL,            .func = NULL },
};