// mlua/function.rs

1use std::cell::RefCell;
2use std::mem;
3use std::os::raw::{c_int, c_void};
4use std::ptr;
5use std::slice;
6
7use crate::error::{Error, Result};
8use crate::lua::Lua;
9use crate::table::Table;
10use crate::types::{Callback, LuaRef, MaybeSend};
11use crate::util::{
12    assert_stack, check_stack, linenumber_to_usize, pop_error, ptr_to_lossy_str, ptr_to_str,
13    StackGuard,
14};
15use crate::value::{FromLuaMulti, IntoLua, IntoLuaMulti, Value};
16
17#[cfg(feature = "async")]
18use {
19    crate::types::AsyncCallback,
20    futures_util::future::{self, Future},
21};
22
23/// Handle to an internal Lua function.
24#[derive(Clone, Debug)]
25pub struct Function<'lua>(pub(crate) LuaRef<'lua>);
26
/// Owned handle to an internal Lua function.
///
/// This handle keeps a *strong* reference to the Lua instance it belongs to.
/// Take care when storing it inside a Lua-owned value (for example a [`UserData`] or a Rust
/// callback): doing so can very easily form a reference cycle that prevents the Lua instance
/// from ever being destroyed.
///
/// [`UserData`]: crate::UserData
#[cfg(feature = "unstable")]
#[cfg_attr(docsrs, doc(cfg(feature = "unstable")))]
#[derive(Debug, Clone)]
pub struct OwnedFunction(pub(crate) crate::types::LuaOwnedRef);
38
#[cfg(feature = "unstable")]
impl OwnedFunction {
    /// Returns a borrowed [`Function`] handle referring to the same Lua function.
    #[cfg_attr(feature = "send", allow(unused))]
    pub const fn to_ref(&self) -> Function {
        let owned_ref = &self.0;
        Function(owned_ref.to_ref())
    }
}
47
/// Describes a function: its name, its kind, where it came from and where it is defined.
///
/// See the [`Lua Debug Interface`] for the precise meaning of each field.
///
/// [`Lua Debug Interface`]: https://www.lua.org/manual/5.4/manual.html#4.7
#[derive(Debug, Clone)]
pub struct FunctionInfo {
    /// A reasonable name for the function, or `None` when no name could be found.
    pub name: Option<String>,
    /// How `name` was resolved (`global`, `local`, `method`, `field`, `upvalue`, ...).
    ///
    /// Always `None` for Luau.
    pub name_what: Option<&'static str>,
    /// `Lua` for a Lua function, `C` for a C function, or `main` for the main part of a chunk.
    pub what: &'static str,
    /// Source of the chunk that created the function.
    pub source: Option<String>,
    /// A "printable" form of `source`, suitable for use in error messages.
    pub short_src: Option<String>,
    /// Line on which the function definition starts.
    pub line_defined: Option<usize>,
    /// Line on which the function definition ends (not set by Luau).
    pub last_line_defined: Option<usize>,
}
72
/// Luau function coverage snapshot.
///
/// [`Function::coverage`] produces one snapshot for each executed function.
#[cfg(any(feature = "luau", doc))]
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CoverageInfo {
    /// Name of the function, or `None` when Luau did not report one.
    pub function: Option<String>,
    /// Line on which the function is defined, as reported by Luau.
    pub line_defined: i32,
    /// Depth value reported by Luau for this function (presumably its nesting depth — confirm).
    pub depth: i32,
    /// Hit counters reported by Luau for this function (presumably indexed by line — confirm).
    pub hits: Vec<i32>,
}
83
84impl<'lua> Function<'lua> {
85    /// Calls the function, passing `args` as function arguments.
86    ///
87    /// The function's return values are converted to the generic type `R`.
88    ///
89    /// # Examples
90    ///
91    /// Call Lua's built-in `tostring` function:
92    ///
93    /// ```
94    /// # use mlua::{Function, Lua, Result};
95    /// # fn main() -> Result<()> {
96    /// # let lua = Lua::new();
97    /// let globals = lua.globals();
98    ///
99    /// let tostring: Function = globals.get("tostring")?;
100    ///
101    /// assert_eq!(tostring.call::<_, String>(123)?, "123");
102    ///
103    /// # Ok(())
104    /// # }
105    /// ```
106    ///
107    /// Call a function with multiple arguments:
108    ///
109    /// ```
110    /// # use mlua::{Function, Lua, Result};
111    /// # fn main() -> Result<()> {
112    /// # let lua = Lua::new();
113    /// let sum: Function = lua.load(
114    ///     r#"
115    ///         function(a, b)
116    ///             return a + b
117    ///         end
118    /// "#).eval()?;
119    ///
120    /// assert_eq!(sum.call::<_, u32>((3, 4))?, 3 + 4);
121    ///
122    /// # Ok(())
123    /// # }
124    /// ```
125    pub fn call<A: IntoLuaMulti<'lua>, R: FromLuaMulti<'lua>>(&self, args: A) -> Result<R> {
126        let lua = self.0.lua;
127        let state = lua.state();
128        unsafe {
129            let _sg = StackGuard::new(state);
130            check_stack(state, 2)?;
131
132            // Push error handler
133            lua.push_error_traceback();
134            let stack_start = ffi::lua_gettop(state);
135            // Push function and the arguments
136            lua.push_ref(&self.0);
137            let nargs = args.push_into_stack_multi(lua)?;
138            // Call the function
139            let ret = ffi::lua_pcall(state, nargs, ffi::LUA_MULTRET, stack_start);
140            if ret != ffi::LUA_OK {
141                return Err(pop_error(state, ret));
142            }
143            // Get the results
144            let nresults = ffi::lua_gettop(state) - stack_start;
145            R::from_stack_multi(nresults, lua)
146        }
147    }
148
149    /// Returns a future that, when polled, calls `self`, passing `args` as function arguments,
150    /// and drives the execution.
151    ///
152    /// Internally it wraps the function to an [`AsyncThread`].
153    ///
154    /// Requires `feature = "async"`
155    ///
156    /// # Examples
157    ///
158    /// ```
159    /// use std::time::Duration;
160    /// # use mlua::{Lua, Result};
161    /// # #[tokio::main]
162    /// # async fn main() -> Result<()> {
163    /// # let lua = Lua::new();
164    ///
165    /// let sleep = lua.create_async_function(move |_lua, n: u64| async move {
166    ///     tokio::time::sleep(Duration::from_millis(n)).await;
167    ///     Ok(())
168    /// })?;
169    ///
170    /// sleep.call_async(10).await?;
171    ///
172    /// # Ok(())
173    /// # }
174    /// ```
175    ///
176    /// [`AsyncThread`]: crate::AsyncThread
177    #[cfg(feature = "async")]
178    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
179    pub fn call_async<A, R>(&self, args: A) -> impl Future<Output = Result<R>> + 'lua
180    where
181        A: IntoLuaMulti<'lua>,
182        R: FromLuaMulti<'lua> + 'lua,
183    {
184        let lua = self.0.lua;
185        let thread_res = lua.create_recycled_thread(self).map(|th| {
186            let mut th = th.into_async(args);
187            th.set_recyclable(true);
188            th
189        });
190        async move { thread_res?.await }
191    }
192
193    /// Returns a function that, when called, calls `self`, passing `args` as the first set of
194    /// arguments.
195    ///
196    /// If any arguments are passed to the returned function, they will be passed after `args`.
197    ///
198    /// # Examples
199    ///
200    /// ```
201    /// # use mlua::{Function, Lua, Result};
202    /// # fn main() -> Result<()> {
203    /// # let lua = Lua::new();
204    /// let sum: Function = lua.load(
205    ///     r#"
206    ///         function(a, b)
207    ///             return a + b
208    ///         end
209    /// "#).eval()?;
210    ///
211    /// let bound_a = sum.bind(1)?;
212    /// assert_eq!(bound_a.call::<_, u32>(2)?, 1 + 2);
213    ///
214    /// let bound_a_and_b = sum.bind(13)?.bind(57)?;
215    /// assert_eq!(bound_a_and_b.call::<_, u32>(())?, 13 + 57);
216    ///
217    /// # Ok(())
218    /// # }
219    /// ```
220    pub fn bind<A: IntoLuaMulti<'lua>>(&self, args: A) -> Result<Function<'lua>> {
221        unsafe extern "C-unwind" fn args_wrapper_impl(state: *mut ffi::lua_State) -> c_int {
222            let nargs = ffi::lua_gettop(state);
223            let nbinds = ffi::lua_tointeger(state, ffi::lua_upvalueindex(1)) as c_int;
224            ffi::luaL_checkstack(state, nbinds, ptr::null());
225
226            for i in 0..nbinds {
227                ffi::lua_pushvalue(state, ffi::lua_upvalueindex(i + 2));
228            }
229            if nargs > 0 {
230                ffi::lua_rotate(state, 1, nbinds);
231            }
232
233            nargs + nbinds
234        }
235
236        let lua = self.0.lua;
237        let state = lua.state();
238
239        let args = args.into_lua_multi(lua)?;
240        let nargs = args.len() as c_int;
241
242        if nargs == 0 {
243            return Ok(self.clone());
244        }
245
246        if nargs + 1 > ffi::LUA_MAX_UPVALUES {
247            return Err(Error::BindError);
248        }
249
250        let args_wrapper = unsafe {
251            let _sg = StackGuard::new(state);
252            check_stack(state, nargs + 3)?;
253
254            ffi::lua_pushinteger(state, nargs as ffi::lua_Integer);
255            for arg in args {
256                lua.push_value(arg)?;
257            }
258            protect_lua!(state, nargs + 1, 1, fn(state) {
259                ffi::lua_pushcclosure(state, args_wrapper_impl, ffi::lua_gettop(state));
260            })?;
261
262            Function(lua.pop_ref())
263        };
264
265        lua.load(
266            r#"
267            local func, args_wrapper = ...
268            return function(...)
269                return func(args_wrapper(...))
270            end
271            "#,
272        )
273        .try_cache()
274        .set_name("__mlua_bind")
275        .call((self.clone(), args_wrapper))
276    }
277
278    /// Returns the environment of the Lua function.
279    ///
280    /// By default Lua functions shares a global environment.
281    ///
282    /// This function always returns `None` for Rust/C functions.
283    pub fn environment(&self) -> Option<Table> {
284        let lua = self.0.lua;
285        let state = lua.state();
286        unsafe {
287            let _sg = StackGuard::new(state);
288            assert_stack(state, 1);
289
290            lua.push_ref(&self.0);
291            if ffi::lua_iscfunction(state, -1) != 0 {
292                return None;
293            }
294
295            #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))]
296            ffi::lua_getfenv(state, -1);
297            #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
298            for i in 1..=255 {
299                // Traverse upvalues until we find the _ENV one
300                match ffi::lua_getupvalue(state, -1, i) {
301                    s if s.is_null() => break,
302                    s if std::ffi::CStr::from_ptr(s as _).to_bytes() == b"_ENV" => break,
303                    _ => ffi::lua_pop(state, 1),
304                }
305            }
306
307            if ffi::lua_type(state, -1) != ffi::LUA_TTABLE {
308                return None;
309            }
310            Some(Table(lua.pop_ref()))
311        }
312    }
313
314    /// Sets the environment of the Lua function.
315    ///
316    /// The environment is a table that is used as the global environment for the function.
317    /// Returns `true` if environment successfully changed, `false` otherwise.
318    ///
319    /// This function does nothing for Rust/C functions.
320    pub fn set_environment(&self, env: Table) -> Result<bool> {
321        let lua = self.0.lua;
322        let state = lua.state();
323        unsafe {
324            let _sg = StackGuard::new(state);
325            check_stack(state, 2)?;
326
327            lua.push_ref(&self.0);
328            if ffi::lua_iscfunction(state, -1) != 0 {
329                return Ok(false);
330            }
331
332            #[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))]
333            {
334                lua.push_ref(&env.0);
335                ffi::lua_setfenv(state, -2);
336            }
337            #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
338            for i in 1..=255 {
339                match ffi::lua_getupvalue(state, -1, i) {
340                    s if s.is_null() => return Ok(false),
341                    s if std::ffi::CStr::from_ptr(s as _).to_bytes() == b"_ENV" => {
342                        ffi::lua_pop(state, 1);
343                        // Create an anonymous function with the new environment
344                        let f_with_env = lua
345                            .load("return _ENV")
346                            .set_environment(env)
347                            .try_cache()
348                            .into_function()?;
349                        lua.push_ref(&f_with_env.0);
350                        ffi::lua_upvaluejoin(state, -2, i, -1, 1);
351                        break;
352                    }
353                    _ => ffi::lua_pop(state, 1),
354                }
355            }
356
357            Ok(true)
358        }
359    }
360
361    /// Returns information about the function.
362    ///
363    /// Corresponds to the `>Sn` what mask for [`lua_getinfo`] when applied to the function.
364    ///
365    /// [`lua_getinfo`]: https://www.lua.org/manual/5.4/manual.html#lua_getinfo
366    pub fn info(&self) -> FunctionInfo {
367        let lua = self.0.lua;
368        let state = lua.state();
369        unsafe {
370            let _sg = StackGuard::new(state);
371            assert_stack(state, 1);
372
373            let mut ar: ffi::lua_Debug = mem::zeroed();
374            lua.push_ref(&self.0);
375            #[cfg(not(feature = "luau"))]
376            let res = ffi::lua_getinfo(state, cstr!(">Sn"), &mut ar);
377            #[cfg(feature = "luau")]
378            let res = ffi::lua_getinfo(state, -1, cstr!("sn"), &mut ar);
379            mlua_assert!(res != 0, "lua_getinfo failed with `>Sn`");
380
381            FunctionInfo {
382                name: ptr_to_lossy_str(ar.name).map(|s| s.into_owned()),
383                #[cfg(not(feature = "luau"))]
384                name_what: match ptr_to_str(ar.namewhat) {
385                    Some("") => None,
386                    val => val,
387                },
388                #[cfg(feature = "luau")]
389                name_what: None,
390                what: ptr_to_str(ar.what).unwrap_or("main"),
391                source: ptr_to_lossy_str(ar.source).map(|s| s.into_owned()),
392                #[cfg(not(feature = "luau"))]
393                short_src: ptr_to_lossy_str(ar.short_src.as_ptr()).map(|s| s.into_owned()),
394                #[cfg(feature = "luau")]
395                short_src: ptr_to_lossy_str(ar.short_src).map(|s| s.into_owned()),
396                line_defined: linenumber_to_usize(ar.linedefined),
397                #[cfg(not(feature = "luau"))]
398                last_line_defined: linenumber_to_usize(ar.lastlinedefined),
399                #[cfg(feature = "luau")]
400                last_line_defined: None,
401            }
402        }
403    }
404
405    /// Dumps the function as a binary chunk.
406    ///
407    /// If `strip` is true, the binary representation may not include all debug information
408    /// about the function, to save space.
409    ///
410    /// For Luau a [Compiler] can be used to compile Lua chunks to bytecode.
411    ///
412    /// [Compiler]: crate::chunk::Compiler
413    #[cfg(not(feature = "luau"))]
414    #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
415    pub fn dump(&self, strip: bool) -> Vec<u8> {
416        unsafe extern "C-unwind" fn writer(
417            _state: *mut ffi::lua_State,
418            buf: *const c_void,
419            buf_len: usize,
420            data: *mut c_void,
421        ) -> c_int {
422            let data = &mut *(data as *mut Vec<u8>);
423            let buf = slice::from_raw_parts(buf as *const u8, buf_len);
424            data.extend_from_slice(buf);
425            0
426        }
427
428        let lua = self.0.lua;
429        let state = lua.state();
430        let mut data: Vec<u8> = Vec::new();
431        unsafe {
432            let _sg = StackGuard::new(state);
433            assert_stack(state, 1);
434
435            lua.push_ref(&self.0);
436            let data_ptr = &mut data as *mut Vec<u8> as *mut c_void;
437            ffi::lua_dump(state, writer, data_ptr, strip as i32);
438            ffi::lua_pop(state, 1);
439        }
440
441        data
442    }
443
444    /// Retrieves recorded coverage information about this Lua function including inner calls.
445    ///
446    /// This function takes a callback as an argument and calls it providing [`CoverageInfo`] snapshot
447    /// per each executed inner function.
448    ///
449    /// Recording of coverage information is controlled by [`Compiler::set_coverage_level`] option.
450    ///
451    /// Requires `feature = "luau"`
452    ///
453    /// [`Compiler::set_coverage_level`]: crate::chunk::Compiler::set_coverage_level
454    #[cfg(any(feature = "luau", doc))]
455    #[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
456    pub fn coverage<F>(&self, mut func: F)
457    where
458        F: FnMut(CoverageInfo),
459    {
460        use std::ffi::CStr;
461        use std::os::raw::c_char;
462
463        unsafe extern "C-unwind" fn callback<F: FnMut(CoverageInfo)>(
464            data: *mut c_void,
465            function: *const c_char,
466            line_defined: c_int,
467            depth: c_int,
468            hits: *const c_int,
469            size: usize,
470        ) {
471            let function = if !function.is_null() {
472                Some(CStr::from_ptr(function).to_string_lossy().to_string())
473            } else {
474                None
475            };
476            let rust_callback = &mut *(data as *mut F);
477            rust_callback(CoverageInfo {
478                function,
479                line_defined,
480                depth,
481                hits: slice::from_raw_parts(hits, size).to_vec(),
482            });
483        }
484
485        let lua = self.0.lua;
486        let state = lua.state();
487        unsafe {
488            let _sg = StackGuard::new(state);
489            assert_stack(state, 1);
490
491            lua.push_ref(&self.0);
492            let func_ptr = &mut func as *mut F as *mut c_void;
493            ffi::lua_getcoverage(state, -1, func_ptr, callback::<F>);
494        }
495    }
496
497    /// Converts this function to a generic C pointer.
498    ///
499    /// There is no way to convert the pointer back to its original value.
500    ///
501    /// Typically this function is used only for hashing and debug information.
502    #[inline]
503    pub fn to_pointer(&self) -> *const c_void {
504        self.0.to_pointer()
505    }
506
507    /// Creates a deep clone of the Lua function.
508    ///
509    /// Copies the function prototype and all its upvalues to the
510    /// newly created function.
511    ///
512    /// This function returns shallow clone (same handle) for Rust/C functions.
513    /// Requires `feature = "luau"`
514    #[cfg(feature = "luau")]
515    #[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
516    pub fn deep_clone(&self) -> Self {
517        let ref_thread = self.0.lua.ref_thread();
518        unsafe {
519            if ffi::lua_iscfunction(ref_thread, self.0.index) != 0 {
520                return self.clone();
521            }
522
523            ffi::lua_clonefunction(ref_thread, self.0.index);
524            Function(self.0.lua.pop_ref_thread())
525        }
526    }
527
528    /// Convert this handle to owned version.
529    #[cfg(all(feature = "unstable", any(not(feature = "send"), doc)))]
530    #[cfg_attr(docsrs, doc(cfg(all(feature = "unstable", not(feature = "send")))))]
531    #[inline]
532    pub fn into_owned(self) -> OwnedFunction {
533        OwnedFunction(self.0.into_owned())
534    }
535}
536
537impl<'lua> PartialEq for Function<'lua> {
538    fn eq(&self, other: &Self) -> bool {
539        self.0 == other.0
540    }
541}
542
// Convenience methods that forward to the borrowed `Function` handle
#[cfg(feature = "unstable")]
impl OwnedFunction {
    /// Calls the function, passing `args` as function arguments.
    ///
    /// This is a shortcut for [`Function::call()`] on the borrowed handle.
    #[inline]
    pub fn call<'lua, A, R>(&'lua self, args: A) -> Result<R>
    where
        A: IntoLuaMulti<'lua>,
        R: FromLuaMulti<'lua>,
    {
        Function::call(&self.to_ref(), args)
    }

    /// Returns a future that, when polled, calls `self`, passing `args` as function arguments,
    /// and drives the execution.
    ///
    /// This is a shortcut for [`Function::call_async()`] on the borrowed handle.
    #[cfg(feature = "async")]
    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
    #[inline]
    pub async fn call_async<'lua, A, R>(&'lua self, args: A) -> Result<R>
    where
        A: IntoLuaMulti<'lua>,
        R: FromLuaMulti<'lua> + 'lua,
    {
        Function::call_async(&self.to_ref(), args).await
    }
}
573
574pub(crate) struct WrappedFunction<'lua>(pub(crate) Callback<'lua, 'static>);
575
/// An async Rust callback produced by [`Function::wrap_async`]; it is turned into a Lua
/// function when converted with [`IntoLua`].
#[cfg(feature = "async")]
pub(crate) struct WrappedAsyncFunction<'lua>(
    // Boxed callback invoked with the Lua instance and the call arguments.
    pub(crate) AsyncCallback<'lua, 'static>,
);
578
579impl<'lua> Function<'lua> {
580    /// Wraps a Rust function or closure, returning an opaque type that implements [`IntoLua`] trait.
581    #[inline]
582    pub fn wrap<A, R, F>(func: F) -> impl IntoLua<'lua>
583    where
584        A: FromLuaMulti<'lua>,
585        R: IntoLuaMulti<'lua>,
586        F: Fn(&'lua Lua, A) -> Result<R> + MaybeSend + 'static,
587    {
588        WrappedFunction(Box::new(move |lua, nargs| unsafe {
589            let args = A::from_stack_args(nargs, 1, None, lua)?;
590            func(lua, args)?.push_into_stack_multi(lua)
591        }))
592    }
593
594    /// Wraps a Rust mutable closure, returning an opaque type that implements [`IntoLua`] trait.
595    #[inline]
596    pub fn wrap_mut<A, R, F>(func: F) -> impl IntoLua<'lua>
597    where
598        A: FromLuaMulti<'lua>,
599        R: IntoLuaMulti<'lua>,
600        F: FnMut(&'lua Lua, A) -> Result<R> + MaybeSend + 'static,
601    {
602        let func = RefCell::new(func);
603        WrappedFunction(Box::new(move |lua, nargs| unsafe {
604            let mut func = func
605                .try_borrow_mut()
606                .map_err(|_| Error::RecursiveMutCallback)?;
607            let args = A::from_stack_args(nargs, 1, None, lua)?;
608            func(lua, args)?.push_into_stack_multi(lua)
609        }))
610    }
611
612    /// Wraps a Rust async function or closure, returning an opaque type that implements [`IntoLua`] trait.
613    #[cfg(feature = "async")]
614    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
615    pub fn wrap_async<A, R, F, FR>(func: F) -> impl IntoLua<'lua>
616    where
617        A: FromLuaMulti<'lua>,
618        R: IntoLuaMulti<'lua>,
619        F: Fn(&'lua Lua, A) -> FR + MaybeSend + 'static,
620        FR: Future<Output = Result<R>> + 'lua,
621    {
622        WrappedAsyncFunction(Box::new(move |lua, args| unsafe {
623            let args = match A::from_lua_args(args, 1, None, lua) {
624                Ok(args) => args,
625                Err(e) => return Box::pin(future::err(e)),
626            };
627            let fut = func(lua, args);
628            Box::pin(async move { fut.await?.push_into_stack_multi(lua) })
629        }))
630    }
631}
632
633impl<'lua> IntoLua<'lua> for WrappedFunction<'lua> {
634    #[inline]
635    fn into_lua(self, lua: &'lua Lua) -> Result<Value<'lua>> {
636        lua.create_callback(self.0).map(Value::Function)
637    }
638}
639
#[cfg(feature = "async")]
impl<'lua> IntoLua<'lua> for WrappedAsyncFunction<'lua> {
    /// Creates a Lua function from the wrapped async callback and returns it as a
    /// [`Value::Function`].
    #[inline]
    fn into_lua(self, lua: &'lua Lua) -> Result<Value<'lua>> {
        let WrappedAsyncFunction(callback) = self;
        let function = lua.create_async_callback(callback)?;
        Ok(Value::Function(function))
    }
}
647
// Compile-time checks on the auto traits implemented by the function handles.
#[cfg(test)]
mod assertions {
    use super::*;
    use static_assertions::assert_not_impl_any;

    // The borrowed handle must never be `Send`.
    assert_not_impl_any!(Function: Send);

    // Without the `send` feature the owned handle must not be `Send` either.
    #[cfg(all(feature = "unstable", not(feature = "send")))]
    assert_not_impl_any!(OwnedFunction: Send);
}