diff --git a/core/engine/src/bytecompiler/mod.rs b/core/engine/src/bytecompiler/mod.rs index dc13a911d3..385ec317af 100644 --- a/core/engine/src/bytecompiler/mod.rs +++ b/core/engine/src/bytecompiler/mod.rs @@ -1510,6 +1510,31 @@ impl<'ctx> ByteCompiler<'ctx> { } } + fn optimize(bytecode: &mut Vec, flags: CodeBlockFlags) { + // only perform Tail Call Optimisation in strict mode + if flags.contains(CodeBlockFlags::STRICT) { + // if a sequence of Call, SetReturnValue, CheckReturn, Return is found + // replace Call with TailCall to allow frame elimination + for i in 0..bytecode.len() { + let opcode: Opcode = bytecode[i].into(); + if matches!(opcode, Opcode::Call) { + if bytecode.len() > i + 5 { + let second = bytecode[i + 2].into(); + let third = bytecode[i + 3].into(); + let fourth = bytecode[i + 4].into(); + if let (Opcode::SetReturnValue, Opcode::CheckReturn, Opcode::Return) = + (second, third, fourth) + { + bytecode[i] = Opcode::TailCall as u8; + } + } else { + return; + } + } + } + } + } + /// Finish compiling code with the [`ByteCompiler`] and return the generated [`CodeBlock`]. #[inline] #[must_use] @@ -1521,6 +1546,8 @@ impl<'ctx> ByteCompiler<'ctx> { } self.r#return(false); + Self::optimize(&mut self.bytecode, self.code_block_flags); + CodeBlock { name: self.function_name, length: self.length, diff --git a/core/engine/src/vm/code_block.rs b/core/engine/src/vm/code_block.rs index 1ceec0f8ae..f69f728335 100644 --- a/core/engine/src/vm/code_block.rs +++ b/core/engine/src/vm/code_block.rs @@ -444,6 +444,9 @@ impl CodeBlock { | Instruction::Call { argument_count: value, } + | Instruction::TailCall { + argument_count: value, + } | Instruction::New { argument_count: value, } @@ -746,8 +749,7 @@ impl CodeBlock { | Instruction::Reserved54 | Instruction::Reserved55 | Instruction::Reserved56 - | Instruction::Reserved57 - | Instruction::Reserved58 => unreachable!("Reserved opcodes are unrechable"), + | Instruction::Reserved57 => unreachable!("Reserved opcodes are unreachable"), } } } diff --git a/core/engine/src/vm/flowgraph/mod.rs b/core/engine/src/vm/flowgraph/mod.rs index b612b6115d..cdea9b5da5 100644 --- a/core/engine/src/vm/flowgraph/mod.rs +++ b/core/engine/src/vm/flowgraph/mod.rs @@ -201,6 +201,7 @@ impl CodeBlock { } Instruction::CallEval { .. } | Instruction::Call { .. } + | Instruction::TailCall { .. } | Instruction::New { .. } | Instruction::SuperCall { .. } | Instruction::ConcatToString { .. } @@ -515,8 +516,7 @@ impl CodeBlock { | Instruction::Reserved54 | Instruction::Reserved55 | Instruction::Reserved56 - | Instruction::Reserved57 - | Instruction::Reserved58 => unreachable!("Reserved opcodes are unrechable"), + | Instruction::Reserved57 => unreachable!("Reserved opcodes are unreachable"), } } diff --git a/core/engine/src/vm/opcode/call/mod.rs b/core/engine/src/vm/opcode/call/mod.rs index 500bb46721..d722f0587d 100644 --- a/core/engine/src/vm/opcode/call/mod.rs +++ b/core/engine/src/vm/opcode/call/mod.rs @@ -1,5 +1,5 @@ use crate::{ - builtins::{promise::PromiseCapability, Promise}, + builtins::{function::OrdinaryFunction, promise::PromiseCapability, Promise}, error::JsNativeError, module::{ModuleKind, Referrer}, object::FunctionObjectBuilder, @@ -200,6 +200,82 @@ impl Operation for Call { } } +/// `TailCall` implements the Opcode Operation for `Opcode::TailCall` +/// +/// Operation: +/// - Tail call a function +#[derive(Debug, Clone, Copy)] +pub(crate) struct TailCall; + +impl TailCall { + fn operation(context: &mut Context, argument_count: usize) -> JsResult { + let at = context.vm.stack.len() - argument_count; + let func = &context.vm.stack[at - 1]; + + let Some(object) = func.as_object() else { + return Err(JsNativeError::typ() + .with_message("not a callable function") + .into()); + }; + + let is_ordinary_function = object.is::(); + object.__call__(argument_count).resolve(context)?; + + // only tail call for ordinary functions + // don't tail call on the main script + // TODO: the 3 needs to be reviewed + if is_ordinary_function && context.vm.frames.len() > 3 { + // check that caller is also ordinary function + let frames = &context.vm.frames; + let caller_frame = &frames[frames.len() - 2]; + let caller_function = caller_frame + .function(&context.vm) + .expect("there must be a caller function"); + if caller_function.is::() { + // remove caller's CallFrame + let frames = &mut context.vm.frames; + let caller_frame = frames.swap_remove(frames.len() - 2); + + // remove caller's prologue from stack + // this + func + arguments + let to_remove = 1 + 1 + caller_frame.argument_count as usize; + context + .vm + .stack + .drain((caller_frame.fp as usize)..(caller_frame.fp as usize + to_remove)); + + // update invoked function's fp + let frames = &mut context.vm.frames; + let invoked_frame = frames.last_mut().expect("invoked frame must exist"); + invoked_frame.set_exit_early(caller_frame.exit_early()); + invoked_frame.fp -= to_remove as u32; + } + } + Ok(CompletionType::Normal) + } +} + +impl Operation for TailCall { + const NAME: &'static str = "TailCall"; + const INSTRUCTION: &'static str = "INST - TailCall"; + const COST: u8 = 3; + + fn execute(context: &mut Context) -> JsResult { + let argument_count = context.vm.read::(); + Self::operation(context, argument_count as usize) + } + + fn execute_with_u16_operands(context: &mut Context) -> JsResult { + let argument_count = context.vm.read::() as usize; + Self::operation(context, argument_count) + } + + fn execute_with_u32_operands(context: &mut Context) -> JsResult { + let argument_count = context.vm.read::(); + Self::operation(context, argument_count as usize) + } +} + #[derive(Debug, Clone, Copy)] pub(crate) struct CallSpread; diff --git a/core/engine/src/vm/opcode/mod.rs b/core/engine/src/vm/opcode/mod.rs index 2d0bc52e79..f004daa72d 100644 --- a/core/engine/src/vm/opcode/mod.rs +++ b/core/engine/src/vm/opcode/mod.rs @@ -1700,6 +1700,13 @@ generate_opcodes! { /// Stack: this, func, argument_1, ... argument_n **=>** result Call { argument_count: VaryingOperand }, + /// Tail call a function. + /// + /// Operands: argument_count: `u32` + /// + /// Stack: this, func, argument_1, ... argument_n **=>** + TailCall { argument_count: VaryingOperand }, + /// Call a function where the arguments contain spreads. /// /// Operands: @@ -2220,8 +2227,6 @@ generate_opcodes! { Reserved56 => Reserved, /// Reserved [`Opcode`]. Reserved57 => Reserved, - /// Reserved [`Opcode`]. - Reserved58 => Reserved, } /// Specific opcodes for bindings.