From 79afb7352be50581786fdb384104c178b5c6c394 Mon Sep 17 00:00:00 2001 From: mikelma Date: Sun, 1 May 2022 21:23:27 +0000 Subject: [PATCH] basic lambda function support --- src/ast/check.rs | 10 ++++++ src/ast/expr.rs | 1 + src/ast/func.rs | 78 ++++++++++++++++++++++++++++++++++++++++++++- src/ast/tree.rs | 23 ++++++++++++- src/codegen/expr.rs | 41 +++++++++++++++++++++++- src/codegen/mod.rs | 3 +- src/grammar.pest | 4 ++- src/st.rs | 14 ++++++-- 8 files changed, 167 insertions(+), 7 deletions(-) diff --git a/src/ast/check.rs b/src/ast/check.rs index efc7db2..48cc080 100644 --- a/src/ast/check.rs +++ b/src/ast/check.rs @@ -566,6 +566,16 @@ pub fn get_node_type_no_autoconv(node: &AstNode) -> Result { AstNode::MemberAccessExpr { access_types, .. } => Ok(access_types.last().unwrap().clone()), AstNode::EnumVariant { enum_name, .. } => Ok(VarType::Enum(enum_name.clone())), AstNode::String(_) => Ok(VarType::Str), + AstNode::Lambda { ret_ty, params, .. } => { + let param_ty = params + .iter() + .map(|(_, t)| t.clone()) + .collect::>(); + Ok(VarType::Fun { + param_ty, + ret_ty: ret_ty.as_ref().map(|t| Box::new(t.clone())), + }) + } AstNode::Type(ty) => Err(LogMesg::err() .name("Expected value") .cause(format!("Expected value but got type {} instead", ty))), diff --git a/src/ast/expr.rs b/src/ast/expr.rs index 1b248e6..5c07d05 100644 --- a/src/ast/expr.rs +++ b/src/ast/expr.rs @@ -357,6 +357,7 @@ pub fn parse_value(pair: Pair) -> AstNode { AstNode::String(bytes) } + Rule::lambda => func::parse_lambda(value), _ => unreachable!(), } } diff --git a/src/ast/func.rs b/src/ast/func.rs index 91f8c15..15f6d31 100644 --- a/src/ast/func.rs +++ b/src/ast/func.rs @@ -1,7 +1,7 @@ use pest::iterators::Pair; use super::{parser::*, *}; -use crate::{current_unit_st, macros, VarType}; +use crate::{current_unit_st, macros, st::SymbolTableStack, VarType}; pub fn parse_func_proto(pair: Pair) -> AstNode { let pair_str = pair.as_str(); @@ -253,3 +253,79 @@ pub fn parse_extern_func_proto(pair: Pair) -> AstNode { visibility, } } + +pub fn parse_lambda(pair: Pair) -> AstNode { + let pair_str = pair.as_str(); + let pair_loc = pair.as_span().start_pos().line_col().0; + let mut pairs = pair.into_inner(); + + // parse parameter definitions + let params = parse_params_decl(pairs.next().unwrap()); + let arg_types: Vec = params.iter().map(|x| x.1.clone()).collect(); + + // parse return type (if some) + let mut next = pairs.next().unwrap(); + let ret_ty = match next.as_rule() { + Rule::retType => match ty::parse_var_type(next.clone().into_inner().next().unwrap()) { + Ok(t) => { + next = pairs.next().unwrap(); + Some(t) + } + Err(e) => { + e.lines(pair_str).location(pair_loc).send().unwrap(); + Some(VarType::Unknown) + } + }, + Rule::stmts => None, + _ => unreachable!(), + }; + + let mut name = SymbolTableStack::gen_unique_name(); + name.push_str(".lambda"); + + let res = + current_unit_st!().record_func(&name, ret_ty.clone(), arg_types, Visibility::Pub, false); + + if let Err(e) = res { + e.lines(pair_str).location(pair_loc).send().unwrap(); + } + + // create a new table for the function's scope + current_unit_st!().push_table(); + + // register the parameters in the function's scope + params.iter().for_each(|(name, ty)| { + let res = current_unit_st!().record_var(name, ty.clone()); + if let Err(e) = res { + e.lines(pair_str).location(pair_loc).send().unwrap(); + } + }); + + // get the name of the function where the lambda is being declared. + // this never panics as lambdas must be always declared inside another function. + let parent_func_name = current_unit_st!().curr_func().unwrap().to_string(); + + // set the current function to parse in the symbol table + current_unit_st!().curr_func_set(&name); + + // parse statements block of the function + let stmts = Box::new(stmts::parse_stmts(next)); + + // pop function's scope symbol table + current_unit_st!().pop_table(); + + // restore current function's value to the parent function + current_unit_st!().curr_func_set(&parent_func_name); + + // TODO + println!("***TODO***: Captured variables in lambdas"); + let captured_vars = vec![]; + + AstNode::Lambda { + name, + ret_ty, + params, + stmts, + captured_vars, + } +} diff --git a/src/ast/tree.rs b/src/ast/tree.rs index 2b22255..75f1474 100644 --- a/src/ast/tree.rs +++ b/src/ast/tree.rs @@ -132,7 +132,7 @@ pub enum AstNode { parent_ty: VarType, // type of the parent }, - // terminals + // values Identifier(String), // TODO: Fix typo: Identifier Int8(i8), UInt8(u8), @@ -165,6 +165,13 @@ pub enum AstNode { is_const: bool, }, String(Vec), + Lambda { + name: String, + ret_ty: Option, + stmts: Box, + params: Vec<(String, VarType)>, + captured_vars: Vec<(String, VarType)>, + }, /// `Type`s are only intended to be used as builtin function parameters Type(VarType), } @@ -500,6 +507,19 @@ impl TreeItem for AstNode { STYLE_TERM.apply_to("String"), String::from_utf8_lossy(bytes) ), + AstNode::Lambda { ret_ty, params, .. } => { + let params: Vec = params.iter().map(|(_, v)| v.to_string()).collect(); + write!( + f, + "{} fun({})", + STYLE_TERM.apply_to("Lambda"), + params.join(","), + )?; + if let Some(t) = ret_ty { + write!(f, ":{}", t)?; + } + Ok(()) + } AstNode::EnumVariant { enum_name, variant_name, @@ -595,6 +615,7 @@ impl TreeItem for AstNode { .cloned() .collect::>(), ), + AstNode::Lambda { stmts, .. } => Cow::from(vec![*stmts.clone()]), _ => Cow::from(vec![]), } } diff --git a/src/codegen/expr.rs b/src/codegen/expr.rs index f442e60..aa38eb5 100644 --- a/src/codegen/expr.rs +++ b/src/codegen/expr.rs @@ -1,5 +1,5 @@ use super::*; -use inkwell::types::AnyTypeEnum; +use inkwell::types::{AnyTypeEnum, BasicTypeEnum}; use inkwell::values::{BasicMetadataValueEnum, CallableValue}; impl<'ctx> CodeGen<'ctx> { @@ -586,6 +586,8 @@ impl<'ctx> CodeGen<'ctx> { let enm = self.builder.build_load(enum_ptr, "tmp.deref"); Ok(Some(enm)) }, + AstNode::Lambda { name, ret_ty, stmts, params, captured_vars } + => self.compile_lambda_expr(name, ret_ty, stmts, params, captured_vars), _ => unreachable!("Panic caused by {:?}", node), } } @@ -1057,4 +1059,41 @@ impl<'ctx> CodeGen<'ctx> { Ok(Some(slice.as_basic_value_enum())) } + + fn compile_lambda_expr( + &mut self, + name: &str, + ret_type: &Option, + stmts: &AstNode, + params: &[(String, VarType)], + captured_vars: &[(String, VarType)], + ) -> CompRet<'ctx> { + // save the state of the codegen unit before starting to build the lambda function + let parent_insert_pos = self.builder.get_insert_block().unwrap(); + let parent_curr_func = self.curr_func.clone(); + let parent_curr_fn_ret_val = self.curr_fn_ret_val.clone(); + let parent_curr_fn_ret_bb = self.curr_fn_ret_bb.clone(); + let parent_loop_exit_bb = self.loop_exit_bb.clone(); + + // compile the lambda function's prototype (and body) as if it's a normal function + self.compile_func_proto(name, params, ret_type, false); // NOTE: inline is set to false for now + self.compile_func_decl(name, ret_type, params, stmts)?; + + // restore the state of the codegen unit before the lambda expression compilation + self.builder.position_at_end(parent_insert_pos); + self.curr_func = parent_curr_func; + self.curr_fn_ret_val = parent_curr_fn_ret_val; + self.curr_fn_ret_bb = parent_curr_fn_ret_bb; + self.loop_exit_bb = parent_loop_exit_bb; + + // return the pointer to the lambda function + // this never panics, as the function prototype is already compiled by `compile_func_proto` + let fn_val = self.module.get_function(name).unwrap(); + Ok(Some( + fn_val + .as_global_value() + .as_pointer_value() + .as_basic_value_enum(), + )) + } } diff --git a/src/codegen/mod.rs b/src/codegen/mod.rs index 6dc7a4b..c28ede3 100644 --- a/src/codegen/mod.rs +++ b/src/codegen/mod.rs @@ -3,7 +3,7 @@ use inkwell::builder::Builder; use inkwell::context::Context; use inkwell::module::Module; use inkwell::targets::TargetTriple; -use inkwell::types::{BasicType, BasicTypeEnum}; +use inkwell::types::BasicType; use inkwell::values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue}; use inkwell::{AddressSpace, FloatPredicate, IntPredicate}; @@ -174,6 +174,7 @@ impl<'ctx> CodeGen<'ctx> { | AstNode::Strct { .. } | AstNode::EnumVariant { .. } | AstNode::String(_) + | AstNode::Lambda { .. } | AstNode::Boolean(_) => self.compile_value(node), _ => unreachable!("{:#?}", node), } diff --git a/src/grammar.pest b/src/grammar.pest index a8599de..02c533d 100644 --- a/src/grammar.pest +++ b/src/grammar.pest @@ -140,7 +140,7 @@ rangeVal = { expr? } // ------------ literal values ------------ // -value = { float | number | str | array | enm | strct | boolean | id } +value = { float | number | str | array | enm | strct | boolean | lambda | id } number = ${ numPart ~ intType? } numPart = @{ ("-")? ~ ASCII_DIGIT+ } @@ -172,6 +172,8 @@ str = ${ "\"" ~ strInner ~ "\"" } | "\\" ~ "u" ~ "{"~ ASCII_HEX_DIGIT{1, 6} ~ "}" } +lambda = { "fun" ~ paramsDecl ~ retType? ~ stmts } + // ------------ misc ------------ // varType = { refType | funType | simpleType | arrayType | sliceType } diff --git a/src/st.rs b/src/st.rs index 7fd648b..9abd3a6 100644 --- a/src/st.rs +++ b/src/st.rs @@ -1,11 +1,13 @@ use console::style; -// use once_cell::sync::Lazy; +use once_cell::sync::Lazy; use std::collections::HashMap; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use super::{AstNode, LogMesg, VarType, Visibility}; +static UNIQUE_ID_COUNTER: Lazy>> = Lazy::new(|| Arc::new(Mutex::new(0))); + type SymbolTable = HashMap; #[derive(Debug)] @@ -717,6 +719,14 @@ impl SymbolTableStack { ); } } + + pub fn gen_unique_name() -> String { + { + let mut counter = UNIQUE_ID_COUNTER.lock().unwrap(); + *counter += 1; + return format!("okta.uniqueid.{}", counter); + } + } } impl Default for SymbolTableStack { -- 2.45.2