~mikelma/oktac

79afb7352be50581786fdb384104c178b5c6c394 — mikelma 5 months ago 4b86137 main
basic lambda function support
M src/ast/check.rs => src/ast/check.rs +10 -0
@@ 566,6 566,16 @@ pub fn get_node_type_no_autoconv(node: &AstNode) -> Result<VarType, LogMesg> {
        AstNode::MemberAccessExpr { access_types, .. } => Ok(access_types.last().unwrap().clone()),
        AstNode::EnumVariant { enum_name, .. } => Ok(VarType::Enum(enum_name.clone())),
        AstNode::String(_) => Ok(VarType::Str),
        AstNode::Lambda { ret_ty, params, .. } => {
            let param_ty = params
                .iter()
                .map(|(_, t)| t.clone())
                .collect::<Vec<VarType>>();
            Ok(VarType::Fun {
                param_ty,
                ret_ty: ret_ty.as_ref().map(|t| Box::new(t.clone())),
            })
        }
        AstNode::Type(ty) => Err(LogMesg::err()
            .name("Expected value")
            .cause(format!("Expected value but got type {} instead", ty))),

M src/ast/expr.rs => src/ast/expr.rs +1 -0
@@ 357,6 357,7 @@ pub fn parse_value(pair: Pair<Rule>) -> AstNode {

            AstNode::String(bytes)
        }
        Rule::lambda => func::parse_lambda(value),
        _ => unreachable!(),
    }
}

M src/ast/func.rs => src/ast/func.rs +77 -1
@@ 1,7 1,7 @@
use pest::iterators::Pair;

use super::{parser::*, *};
use crate::{current_unit_st, macros, VarType};
use crate::{current_unit_st, macros, st::SymbolTableStack, VarType};

pub fn parse_func_proto(pair: Pair<Rule>) -> AstNode {
    let pair_str = pair.as_str();


@@ 253,3 253,79 @@ pub fn parse_extern_func_proto(pair: Pair<Rule>) -> AstNode {
        visibility,
    }
}

pub fn parse_lambda(pair: Pair<Rule>) -> AstNode {
    let pair_str = pair.as_str();
    let pair_loc = pair.as_span().start_pos().line_col().0;
    let mut pairs = pair.into_inner();

    // parse parameter definitions
    let params = parse_params_decl(pairs.next().unwrap());
    let arg_types: Vec<VarType> = params.iter().map(|x| x.1.clone()).collect();

    // parse return type (if some)
    let mut next = pairs.next().unwrap();
    let ret_ty = match next.as_rule() {
        Rule::retType => match ty::parse_var_type(next.clone().into_inner().next().unwrap()) {
            Ok(t) => {
                next = pairs.next().unwrap();
                Some(t)
            }
            Err(e) => {
                e.lines(pair_str).location(pair_loc).send().unwrap();
                Some(VarType::Unknown)
            }
        },
        Rule::stmts => None,
        _ => unreachable!(),
    };

    let mut name = SymbolTableStack::gen_unique_name();
    name.push_str(".lambda");

    let res =
        current_unit_st!().record_func(&name, ret_ty.clone(), arg_types, Visibility::Pub, false);

    if let Err(e) = res {
        e.lines(pair_str).location(pair_loc).send().unwrap();
    }

    // create a new table for the function's scope
    current_unit_st!().push_table();

    // register the parameters in the function's scope
    params.iter().for_each(|(name, ty)| {
        let res = current_unit_st!().record_var(name, ty.clone());
        if let Err(e) = res {
            e.lines(pair_str).location(pair_loc).send().unwrap();
        }
    });

    // get the name of the function where the lambda is being declared.
    // this never panics as lambdas must be always declared inside another function.
    let parent_func_name = current_unit_st!().curr_func().unwrap().to_string();

    // set the current function to parse in the symbol table
    current_unit_st!().curr_func_set(&name);

    // parse statements block of the function
    let stmts = Box::new(stmts::parse_stmts(next));

    // pop function's scope symbol table
    current_unit_st!().pop_table();

    // restore current function's value to the parent function
    current_unit_st!().curr_func_set(&parent_func_name);

    // TODO
    println!("***TODO***: Captured variables in lambdas");
    let captured_vars = vec![];

    AstNode::Lambda {
        name,
        ret_ty,
        params,
        stmts,
        captured_vars,
    }
}

M src/ast/tree.rs => src/ast/tree.rs +22 -1
@@ 132,7 132,7 @@ pub enum AstNode {
        parent_ty: VarType, // type of the parent
    },

    // terminals
    // values
    Identifier(String), // TODO: Fix typo: Identifier
    Int8(i8),
    UInt8(u8),


@@ 165,6 165,13 @@ pub enum AstNode {
        is_const: bool,
    },
    String(Vec<u8>),
    Lambda {
        name: String,
        ret_ty: Option<VarType>,
        stmts: Box<AstNode>,
        params: Vec<(String, VarType)>,
        captured_vars: Vec<(String, VarType)>,
    },
    /// `Type`s are only intended to be used as builtin function parameters
    Type(VarType),
}


@@ 500,6 507,19 @@ impl TreeItem for AstNode {
                STYLE_TERM.apply_to("String"),
                String::from_utf8_lossy(bytes)
            ),
            AstNode::Lambda { ret_ty, params, .. } => {
                let params: Vec<String> = params.iter().map(|(_, v)| v.to_string()).collect();
                write!(
                    f,
                    "{} fun({})",
                    STYLE_TERM.apply_to("Lambda"),
                    params.join(","),
                )?;
                if let Some(t) = ret_ty {
                    write!(f, ":{}", t)?;
                }
                Ok(())
            }
            AstNode::EnumVariant {
                enum_name,
                variant_name,


@@ 595,6 615,7 @@ impl TreeItem for AstNode {
                    .cloned()
                    .collect::<Vec<AstNode>>(),
            ),
            AstNode::Lambda { stmts, .. } => Cow::from(vec![*stmts.clone()]),
            _ => Cow::from(vec![]),
        }
    }

M src/codegen/expr.rs => src/codegen/expr.rs +40 -1
@@ 1,5 1,5 @@
use super::*;
use inkwell::types::AnyTypeEnum;
use inkwell::types::{AnyTypeEnum, BasicTypeEnum};
use inkwell::values::{BasicMetadataValueEnum, CallableValue};

impl<'ctx> CodeGen<'ctx> {


@@ 586,6 586,8 @@ impl<'ctx> CodeGen<'ctx> {
                let enm = self.builder.build_load(enum_ptr, "tmp.deref");
                Ok(Some(enm))
            },
            AstNode::Lambda { name, ret_ty, stmts, params, captured_vars }
                => self.compile_lambda_expr(name, ret_ty, stmts, params, captured_vars),
            _ => unreachable!("Panic caused by {:?}", node),
        }
    }


@@ 1057,4 1059,41 @@ impl<'ctx> CodeGen<'ctx> {

        Ok(Some(slice.as_basic_value_enum()))
    }

    fn compile_lambda_expr(
        &mut self,
        name: &str,
        ret_type: &Option<VarType>,
        stmts: &AstNode,
        params: &[(String, VarType)],
        captured_vars: &[(String, VarType)],
    ) -> CompRet<'ctx> {
        // save the state of the codegen unit before starting to build the lambda function
        let parent_insert_pos = self.builder.get_insert_block().unwrap();
        let parent_curr_func = self.curr_func.clone();
        let parent_curr_fn_ret_val = self.curr_fn_ret_val.clone();
        let parent_curr_fn_ret_bb = self.curr_fn_ret_bb.clone();
        let parent_loop_exit_bb = self.loop_exit_bb.clone();

        // compile the lambda function's prototype (and body) as if it's a normal function
        self.compile_func_proto(name, params, ret_type, false); // NOTE: inline is set to false for now
        self.compile_func_decl(name, ret_type, params, stmts)?;

        // restore the state of the codegen unit before the lambda expression compilation
        self.builder.position_at_end(parent_insert_pos);
        self.curr_func = parent_curr_func;
        self.curr_fn_ret_val = parent_curr_fn_ret_val;
        self.curr_fn_ret_bb = parent_curr_fn_ret_bb;
        self.loop_exit_bb = parent_loop_exit_bb;

        // return the pointer to the lambda function
        // this never panics, as the function prototype is already compiled by `compile_func_proto`
        let fn_val = self.module.get_function(name).unwrap();
        Ok(Some(
            fn_val
                .as_global_value()
                .as_pointer_value()
                .as_basic_value_enum(),
        ))
    }
}

M src/codegen/mod.rs => src/codegen/mod.rs +2 -1
@@ 3,7 3,7 @@ use inkwell::builder::Builder;
use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::targets::TargetTriple;
use inkwell::types::{BasicType, BasicTypeEnum};
use inkwell::types::BasicType;
use inkwell::values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue};
use inkwell::{AddressSpace, FloatPredicate, IntPredicate};



@@ 174,6 174,7 @@ impl<'ctx> CodeGen<'ctx> {
            | AstNode::Strct { .. }
            | AstNode::EnumVariant { .. }
            | AstNode::String(_)
            | AstNode::Lambda { .. }
            | AstNode::Boolean(_) => self.compile_value(node),
            _ => unreachable!("{:#?}", node),
        }

M src/grammar.pest => src/grammar.pest +3 -1
@@ 140,7 140,7 @@ rangeVal = { expr? }

// ------------ literal values  ------------ //

value = { float | number | str | array | enm | strct | boolean | id }
value = { float | number | str | array | enm | strct | boolean | lambda | id }

number = ${ numPart ~ intType? }
    numPart = @{ ("-")? ~ ASCII_DIGIT+ }


@@ 172,6 172,8 @@ str = ${ "\"" ~ strInner ~ "\"" }
        | "\\" ~ "u" ~ "{"~ ASCII_HEX_DIGIT{1, 6} ~ "}"
    }

lambda = { "fun" ~ paramsDecl ~ retType? ~ stmts }

// ------------ misc ------------ //

varType = { refType | funType | simpleType | arrayType | sliceType }

M src/st.rs => src/st.rs +12 -2
@@ 1,11 1,13 @@
use console::style;
// use once_cell::sync::Lazy;
use once_cell::sync::Lazy;

use std::collections::HashMap;
use std::sync::Arc;
use std::sync::{Arc, Mutex};

use super::{AstNode, LogMesg, VarType, Visibility};

static UNIQUE_ID_COUNTER: Lazy<Arc<Mutex<usize>>> = Lazy::new(|| Arc::new(Mutex::new(0)));

type SymbolTable = HashMap<String, (SymbolInfo, SymbolType)>;

#[derive(Debug)]


@@ 717,6 719,14 @@ impl SymbolTableStack {
            );
        }
    }

    pub fn gen_unique_name() -> String {
        {
            let mut counter = UNIQUE_ID_COUNTER.lock().unwrap();
            *counter += 1;
            return format!("okta.uniqueid.{}", counter);
        }
    }
}

impl Default for SymbolTableStack {