mlua/
thread.rs

1use std::os::raw::{c_int, c_void};
2
3use crate::error::{Error, Result};
4#[allow(unused)]
5use crate::lua::Lua;
6use crate::types::LuaRef;
7use crate::util::{check_stack, error_traceback_thread, pop_error, StackGuard};
8use crate::value::{FromLuaMulti, IntoLuaMulti};
9
10#[cfg(not(feature = "luau"))]
11use crate::{
12    hook::{Debug, HookTriggers},
13    types::MaybeSend,
14};
15
16#[cfg(feature = "async")]
17use {
18    crate::value::MultiValue,
19    futures_util::stream::Stream,
20    std::{
21        future::Future,
22        marker::PhantomData,
23        pin::Pin,
24        ptr::NonNull,
25        task::{Context, Poll, Waker},
26    },
27};
28
/// Status of a Lua thread (coroutine).
///
/// This is the value reported by [`Thread::status`].
///
/// [`Thread::status`]: crate::Thread::status
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum ThreadStatus {
    /// The thread was just created, or is suspended because it has called `coroutine.yield`.
    ///
    /// If a thread is in this state, it can be resumed by calling [`Thread::resume`].
    ///
    /// [`Thread::resume`]: crate::Thread::resume
    Resumable,
    /// Either the thread has finished executing, or the thread is currently running.
    Unresumable,
    /// The thread has raised a Lua error during execution.
    Error,
}
43
/// Handle to an internal Lua thread (coroutine).
//
// Field `0` is the reference to the thread value; field `1` is the thread's raw `lua_State`
// pointer, looked up once in `Thread::new` and returned by `Thread::state`.
#[derive(Clone, Debug)]
pub struct Thread<'lua>(pub(crate) LuaRef<'lua>, pub(crate) *mut ffi::lua_State);
47
/// Owned handle to an internal Lua thread (coroutine).
///
/// The owned handle holds a *strong* reference to the current Lua instance.
/// Be warned, if you place it into a Lua type (e.g. [`UserData`] or a Rust callback), it is
/// *very easy* to accidentally cause reference cycles that would prevent destroying the Lua
/// instance.
///
/// [`UserData`]: crate::UserData
//
// Field `0` is the owned reference; field `1` is the raw `lua_State` pointer carried over
// from the borrowed `Thread` handle (see `Thread::into_owned`).
#[cfg(feature = "unstable")]
#[cfg_attr(docsrs, doc(cfg(feature = "unstable")))]
#[derive(Clone, Debug)]
pub struct OwnedThread(
    pub(crate) crate::types::LuaOwnedRef,
    pub(crate) *mut ffi::lua_State,
);
62
#[cfg(feature = "unstable")]
impl OwnedThread {
    /// Returns a borrowed [`Thread`] handle that refers to the same underlying Lua thread.
    #[cfg_attr(feature = "send", allow(unused))]
    pub const fn to_ref(&self) -> Thread {
        // The raw state pointer is copied as-is; only the reference needs converting.
        let raw_state = self.1;
        Thread(self.0.to_ref(), raw_state)
    }
}
71
/// Thread (coroutine) representation as an async [`Future`] or [`Stream`].
///
/// Created by [`Thread::into_async`].
///
/// Requires `feature = "async"`
///
/// [`Future`]: std::future::Future
/// [`Stream`]: futures_util::stream::Stream
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct AsyncThread<'lua, R> {
    // The coroutine that gets resumed on every poll.
    thread: Thread<'lua>,
    // Arguments for the first resume (or the error produced while converting them).
    // Taken on the first poll, after which later polls resume with no arguments.
    init_args: Option<Result<MultiValue<'lua>>>,
    // Marker for the result type `R`; no value of `R` is stored.
    ret: PhantomData<R>,
    // When `true`, the `Drop` impl offers the thread to `Lua::recycle_thread`.
    recycle: bool,
}
87
88impl<'lua> Thread<'lua> {
89    #[inline(always)]
90    pub(crate) fn new(r#ref: LuaRef<'lua>) -> Self {
91        let state = unsafe { ffi::lua_tothread(r#ref.lua.ref_thread(), r#ref.index) };
92        Thread(r#ref, state)
93    }
94
95    const fn state(&self) -> *mut ffi::lua_State {
96        self.1
97    }
98
99    /// Resumes execution of this thread.
100    ///
101    /// Equivalent to `coroutine.resume`.
102    ///
103    /// Passes `args` as arguments to the thread. If the coroutine has called `coroutine.yield`, it
104    /// will return these arguments. Otherwise, the coroutine wasn't yet started, so the arguments
105    /// are passed to its main function.
106    ///
    /// If the thread is not in the [`ThreadStatus::Resumable`] state (it is currently running, has
    /// finished execution, or has raised an error), this returns `Err(CoroutineInactive)`.
    /// Otherwise it returns `Ok` as follows:
110    ///
111    /// If the thread calls `coroutine.yield`, returns the values passed to `yield`. If the thread
112    /// `return`s values from its main function, returns those.
113    ///
114    /// # Examples
115    ///
116    /// ```
117    /// # use mlua::{Error, Lua, Result, Thread};
118    /// # fn main() -> Result<()> {
119    /// # let lua = Lua::new();
120    /// let thread: Thread = lua.load(r#"
121    ///     coroutine.create(function(arg)
122    ///         assert(arg == 42)
123    ///         local yieldarg = coroutine.yield(123)
124    ///         assert(yieldarg == 43)
125    ///         return 987
126    ///     end)
127    /// "#).eval()?;
128    ///
129    /// assert_eq!(thread.resume::<_, u32>(42)?, 123);
130    /// assert_eq!(thread.resume::<_, u32>(43)?, 987);
131    ///
132    /// // The coroutine has now returned, so `resume` will fail
133    /// match thread.resume::<_, u32>(()) {
134    ///     Err(Error::CoroutineInactive) => {},
135    ///     unexpected => panic!("unexpected result {:?}", unexpected),
136    /// }
137    /// # Ok(())
138    /// # }
139    /// ```
140    pub fn resume<A, R>(&self, args: A) -> Result<R>
141    where
142        A: IntoLuaMulti<'lua>,
143        R: FromLuaMulti<'lua>,
144    {
145        if self.status() != ThreadStatus::Resumable {
146            return Err(Error::CoroutineInactive);
147        }
148
149        let lua = self.0.lua;
150        let state = lua.state();
151        let thread_state = self.state();
152        unsafe {
153            let _sg = StackGuard::new(state);
154            let _thread_sg = StackGuard::with_top(thread_state, 0);
155
156            let nresults = self.resume_inner(args)?;
157            check_stack(state, nresults + 1)?;
158            ffi::lua_xmove(thread_state, state, nresults);
159
160            R::from_stack_multi(nresults, lua)
161        }
162    }
163
164    /// Resumes execution of this thread.
165    ///
166    /// It's similar to `resume()` but leaves `nresults` values on the thread stack.
167    unsafe fn resume_inner<A: IntoLuaMulti<'lua>>(&self, args: A) -> Result<c_int> {
168        let lua = self.0.lua;
169        let state = lua.state();
170        let thread_state = self.state();
171
172        let nargs = args.push_into_stack_multi(lua)?;
173        if nargs > 0 {
174            check_stack(thread_state, nargs)?;
175            ffi::lua_xmove(state, thread_state, nargs);
176        }
177
178        let mut nresults = 0;
179        let ret = ffi::lua_resume(thread_state, state, nargs, &mut nresults as *mut c_int);
180        if ret != ffi::LUA_OK && ret != ffi::LUA_YIELD {
181            if ret == ffi::LUA_ERRMEM {
182                // Don't call error handler for memory errors
183                return Err(pop_error(thread_state, ret));
184            }
185            check_stack(state, 3)?;
186            protect_lua!(state, 0, 1, |state| error_traceback_thread(
187                state,
188                thread_state
189            ))?;
190            return Err(pop_error(state, ret));
191        }
192
193        Ok(nresults)
194    }
195
196    /// Gets the status of the thread.
197    pub fn status(&self) -> ThreadStatus {
198        let thread_state = self.state();
199        if thread_state == self.0.lua.state() {
200            // The coroutine is currently running
201            return ThreadStatus::Unresumable;
202        }
203        unsafe {
204            let status = ffi::lua_status(thread_state);
205            if status != ffi::LUA_OK && status != ffi::LUA_YIELD {
206                ThreadStatus::Error
207            } else if status == ffi::LUA_YIELD || ffi::lua_gettop(thread_state) > 0 {
208                ThreadStatus::Resumable
209            } else {
210                ThreadStatus::Unresumable
211            }
212        }
213    }
214
215    /// Sets a 'hook' function that will periodically be called as Lua code executes.
216    ///
    /// This function is similar to [`Lua::set_hook()`] except that it sets the hook for this
    /// thread. To remove a hook call [`Lua::remove_hook()`].
219    #[cfg(not(feature = "luau"))]
220    #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
221    pub fn set_hook<F>(&self, triggers: HookTriggers, callback: F)
222    where
223        F: Fn(&Lua, Debug) -> Result<()> + MaybeSend + 'static,
224    {
225        let lua = self.0.lua;
226        unsafe {
227            lua.set_thread_hook(self.state(), triggers, callback);
228        }
229    }
230
231    /// Resets a thread
232    ///
    /// In [Lua 5.4]: cleans its call stack and closes all pending to-be-closed variables.
    /// Returns an error in case of either the original error that stopped the thread or errors
    /// in closing methods.
236    ///
237    /// In Luau: resets to the initial state of a newly created Lua thread.
238    /// Lua threads in arbitrary states (like yielded or errored) can be reset properly.
239    ///
240    /// Sets a Lua function for the thread afterwards.
241    ///
242    /// Requires `feature = "lua54"` OR `feature = "luau"`.
243    ///
244    /// [Lua 5.4]: https://www.lua.org/manual/5.4/manual.html#lua_closethread
245    #[cfg(any(feature = "lua54", feature = "luau"))]
246    #[cfg_attr(docsrs, doc(cfg(any(feature = "lua54", feature = "luau"))))]
    pub fn reset(&self, func: crate::function::Function<'lua>) -> Result<()> {
        let lua = self.0.lua;
        let thread_state = self.state();
        // A thread whose state is the one Lua is currently executing on cannot be reset.
        if thread_state == lua.state() {
            return Err(Error::runtime("cannot reset a running thread"));
        }
        unsafe {
            // Lua 5.4: non-vendored builds call `lua_resetthread`, vendored builds call
            // `lua_closethread`. Either way a non-OK status is turned into an error popped
            // from the thread stack.
            #[cfg(all(feature = "lua54", not(feature = "vendored")))]
            let status = ffi::lua_resetthread(thread_state);
            #[cfg(all(feature = "lua54", feature = "vendored"))]
            let status = ffi::lua_closethread(thread_state, lua.state());
            #[cfg(feature = "lua54")]
            if status != ffi::LUA_OK {
                return Err(pop_error(thread_state, status));
            }
            // Luau: the reset call's result is not inspected here.
            #[cfg(feature = "luau")]
            ffi::lua_resetthread(thread_state);

            // Push function to the top of the thread stack
            // (copied from its slot on the reference thread).
            ffi::lua_xpush(lua.ref_thread(), thread_state, func.0.index);

            #[cfg(feature = "luau")]
            {
                // Inherit `LUA_GLOBALSINDEX` from the main thread
                ffi::lua_xpush(lua.main_state(), thread_state, ffi::LUA_GLOBALSINDEX);
                ffi::lua_replace(thread_state, ffi::LUA_GLOBALSINDEX);
            }

            Ok(())
        }
    }
278
279    /// Converts Thread to an AsyncThread which implements [`Future`] and [`Stream`] traits.
280    ///
281    /// `args` are passed as arguments to the thread function for first call.
282    /// The object calls [`resume()`] while polling and also allows to run rust futures
283    /// to completion using an executor.
284    ///
    /// Using AsyncThread as a Stream allows iterating through `coroutine.yield()`
    /// values, whereas the Future version discards those values and polls until the final
    /// one (returned from the thread function).
288    ///
289    /// Requires `feature = "async"`
290    ///
291    /// [`Future`]: std::future::Future
292    /// [`Stream`]: futures_util::stream::Stream
293    /// [`resume()`]: https://www.lua.org/manual/5.4/manual.html#lua_resume
294    ///
295    /// # Examples
296    ///
297    /// ```
298    /// # use mlua::{Lua, Result, Thread};
299    /// use futures::stream::TryStreamExt;
300    /// # #[tokio::main]
301    /// # async fn main() -> Result<()> {
302    /// # let lua = Lua::new();
303    /// let thread: Thread = lua.load(r#"
304    ///     coroutine.create(function (sum)
305    ///         for i = 1,10 do
306    ///             sum = sum + i
307    ///             coroutine.yield(sum)
308    ///         end
309    ///         return sum
310    ///     end)
311    /// "#).eval()?;
312    ///
313    /// let mut stream = thread.into_async::<_, i64>(1);
314    /// let mut sum = 0;
315    /// while let Some(n) = stream.try_next().await? {
316    ///     sum += n;
317    /// }
318    ///
319    /// assert_eq!(sum, 286);
320    ///
321    /// # Ok(())
322    /// # }
323    /// ```
324    #[cfg(feature = "async")]
325    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
326    pub fn into_async<A, R>(self, args: A) -> AsyncThread<'lua, R>
327    where
328        A: IntoLuaMulti<'lua>,
329        R: FromLuaMulti<'lua>,
330    {
331        let args = args.into_lua_multi(self.0.lua);
332        AsyncThread {
333            thread: self,
334            init_args: Some(args),
335            ret: PhantomData,
336            recycle: false,
337        }
338    }
339
340    /// Enables sandbox mode on this thread.
341    ///
342    /// Under the hood replaces the global environment table with a new table,
343    /// that performs writes locally and proxies reads to caller's global environment.
344    ///
345    /// This mode ideally should be used together with the global sandbox mode [`Lua::sandbox()`].
346    ///
347    /// Please note that Luau links environment table with chunk when loading it into Lua state.
348    /// Therefore you need to load chunks into a thread to link with the thread environment.
349    ///
350    /// # Examples
351    ///
352    /// ```
353    /// # use mlua::{Lua, Result};
354    /// # fn main() -> Result<()> {
355    /// let lua = Lua::new();
356    /// let thread = lua.create_thread(lua.create_function(|lua2, ()| {
357    ///     lua2.load("var = 123").exec()?;
358    ///     assert_eq!(lua2.globals().get::<_, u32>("var")?, 123);
359    ///     Ok(())
360    /// })?)?;
361    /// thread.sandbox()?;
362    /// thread.resume(())?;
363    ///
364    /// // The global environment should be unchanged
365    /// assert_eq!(lua.globals().get::<_, Option<u32>>("var")?, None);
366    /// # Ok(())
367    /// # }
368    /// ```
369    ///
370    /// Requires `feature = "luau"`
371    #[cfg(any(feature = "luau", docsrs))]
372    #[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
373    #[doc(hidden)]
374    pub fn sandbox(&self) -> Result<()> {
375        let lua = self.0.lua;
376        let state = lua.state();
377        let thread_state = self.state();
378        unsafe {
379            check_stack(thread_state, 3)?;
380            check_stack(state, 3)?;
381            protect_lua!(state, 0, 0, |_| ffi::luaL_sandboxthread(thread_state))
382        }
383    }
384
385    /// Converts this thread to a generic C pointer.
386    ///
387    /// There is no way to convert the pointer back to its original value.
388    ///
389    /// Typically this function is used only for hashing and debug information.
390    #[inline]
391    pub fn to_pointer(&self) -> *const c_void {
392        self.0.to_pointer()
393    }
394
395    /// Convert this handle to owned version.
396    #[cfg(all(feature = "unstable", any(not(feature = "send"), doc)))]
397    #[cfg_attr(docsrs, doc(cfg(all(feature = "unstable", not(feature = "send")))))]
398    #[inline]
399    pub fn into_owned(self) -> OwnedThread {
400        OwnedThread(self.0.into_owned(), self.1)
401    }
402}
403
404impl<'lua> PartialEq for Thread<'lua> {
405    fn eq(&self, other: &Self) -> bool {
406        self.0 == other.0
407    }
408}
409
// Convenience methods that forward to the borrowed `Thread` handle.
#[cfg(feature = "unstable")]
impl OwnedThread {
    /// Resumes execution of this thread.
    ///
    /// Forwards to [`Thread::resume()`]; see it for details.
    pub fn resume<'lua, A, R>(&'lua self, args: A) -> Result<R>
    where
        A: IntoLuaMulti<'lua>,
        R: FromLuaMulti<'lua>,
    {
        let thread = self.to_ref();
        thread.resume(args)
    }

    /// Gets the status of the thread.
    ///
    /// Forwards to [`Thread::status()`].
    pub fn status(&self) -> ThreadStatus {
        let thread = self.to_ref();
        thread.status()
    }
}
429
#[cfg(feature = "async")]
impl<'lua, R> AsyncThread<'lua, R> {
    /// Sets whether the `Drop` impl should try to recycle the underlying thread.
    #[inline]
    pub(crate) fn set_recyclable(&mut self, recyclable: bool) {
        let Self { recycle, .. } = self;
        *recycle = recyclable;
    }
}
437
// On drop, an async thread marked as recyclable is offered back to the Lua instance.
// This impl only exists for Lua 5.4 and Luau builds.
#[cfg(feature = "async")]
#[cfg(any(feature = "lua54", feature = "luau"))]
impl<'lua, R> Drop for AsyncThread<'lua, R> {
    fn drop(&mut self) {
        // Nothing to do unless recycling was requested via `set_recyclable`.
        if self.recycle {
            unsafe {
                let lua = self.thread.0.lua;
                // For Lua 5.4 this also closes all pending to-be-closed variables
                // NOTE(review): `recycle_thread` presumably returns `false` when the thread
                // was not taken back for reuse — confirm against `Lua::recycle_thread`.
                if !lua.recycle_thread(&mut self.thread) {
                    // Lua 5.4 only: a thread left in the error state is still reset/closed.
                    #[cfg(feature = "lua54")]
                    if self.thread.status() == ThreadStatus::Error {
                        #[cfg(not(feature = "vendored"))]
                        ffi::lua_resetthread(self.thread.state());
                        #[cfg(feature = "vendored")]
                        ffi::lua_closethread(self.thread.state(), lua.state());
                    }
                }
            }
        }
    }
}
459
#[cfg(feature = "async")]
impl<'lua, R> Stream for AsyncThread<'lua, R>
where
    R: FromLuaMulti<'lua>,
{
    type Item = Result<R>;

    /// Resumes the coroutine once and yields the values it produced.
    ///
    /// The stream ends as soon as the thread is no longer resumable. A single result equal
    /// to the poll-pending marker is reported as `Poll::Pending`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // A thread that cannot be resumed any more marks the end of the stream.
        if self.thread.status() != ThreadStatus::Resumable {
            return Poll::Ready(None);
        }

        let lua = self.thread.0.lua;
        let main = lua.state();
        let co = self.thread.state();
        unsafe {
            // Guards are declared in this order on purpose: they are dropped in reverse.
            let _main_guard = StackGuard::new(main);
            let _co_guard = StackGuard::with_top(co, 0);
            let _waker_guard = WakerGuard::new(lua, cx.waker());

            // This is safe as we are not moving the whole struct
            let this = self.get_unchecked_mut();
            // The first poll consumes the stored arguments; later polls resume with none.
            let count = match this.init_args.take() {
                Some(initial) => this.thread.resume_inner(initial?)?,
                None => this.thread.resume_inner(())?,
            };

            if count == 1 && is_poll_pending(co) {
                return Poll::Pending;
            }

            check_stack(main, count + 1)?;
            ffi::lua_xmove(co, main, count);

            // Request another poll right away so the next item gets produced.
            cx.waker().wake_by_ref();
            let item = R::from_stack_multi(count, lua);
            Poll::Ready(Some(item))
        }
    }
}
500
#[cfg(feature = "async")]
impl<'lua, R> Future for AsyncThread<'lua, R>
where
    R: FromLuaMulti<'lua>,
{
    type Output = Result<R>;

    /// Resumes the coroutine once per poll until it returns from its main function.
    ///
    /// Values passed to `coroutine.yield()` are ignored; only the final results are
    /// converted into `R`. Polling a thread that is not resumable yields
    /// `Err(Error::CoroutineInactive)`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.thread.status() != ThreadStatus::Resumable {
            return Poll::Ready(Err(Error::CoroutineInactive));
        }

        let lua = self.thread.0.lua;
        let main = lua.state();
        let co = self.thread.state();
        unsafe {
            // Guards are declared in this order on purpose: they are dropped in reverse.
            let _main_guard = StackGuard::new(main);
            let _co_guard = StackGuard::with_top(co, 0);
            let _waker_guard = WakerGuard::new(lua, cx.waker());

            // This is safe as we are not moving the whole struct
            let this = self.get_unchecked_mut();
            // The first poll consumes the stored arguments; later polls resume with none.
            let count = match this.init_args.take() {
                Some(initial) => this.thread.resume_inner(initial?)?,
                None => this.thread.resume_inner(())?,
            };

            if count == 1 && is_poll_pending(co) {
                return Poll::Pending;
            }

            let yielded = ffi::lua_status(co) == ffi::LUA_YIELD;
            if yielded {
                // Ignore value returned via yield()
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }

            check_stack(main, count + 1)?;
            ffi::lua_xmove(co, main, count);

            let output = R::from_stack_multi(count, lua);
            Poll::Ready(output)
        }
    }
}
546
// Checks whether the value on top of `state`'s stack is the light userdata that
// `Lua::poll_pending()` uses as its marker.
#[cfg(feature = "async")]
#[inline(always)]
unsafe fn is_poll_pending(state: *mut ffi::lua_State) -> bool {
    let top = ffi::lua_tolightuserdata(state, -1);
    let marker = Lua::poll_pending().0;
    top == marker
}
552
// Guard that installs a waker on the `Lua` instance when created and puts the previously
// registered one back when dropped.
#[cfg(feature = "async")]
struct WakerGuard<'lua, 'a> {
    // Lua instance whose waker is being swapped.
    lua: &'lua Lua,
    // Value returned by `set_waker` when the guard was created.
    prev: NonNull<Waker>,
    // Carries the `'a` lifetime of the borrowed waker; no data is stored.
    _phantom: PhantomData<&'a ()>,
}
559
#[cfg(feature = "async")]
impl<'lua, 'a> WakerGuard<'lua, 'a> {
    /// Installs `waker` on `lua` and remembers whatever `set_waker` returned, so the `Drop`
    /// impl can put it back.
    ///
    /// This constructor always returns `Ok`.
    #[inline]
    pub fn new(lua: &'lua Lua, waker: &'a Waker) -> Result<WakerGuard<'lua, 'a>> {
        let incoming = NonNull::from(waker);
        let prev = unsafe { lua.set_waker(incoming) };
        let guard = WakerGuard {
            lua,
            prev,
            _phantom: PhantomData,
        };
        Ok(guard)
    }
}
572
#[cfg(feature = "async")]
impl<'lua, 'a> Drop for WakerGuard<'lua, 'a> {
    /// Puts back the waker that was registered before this guard was created.
    fn drop(&mut self) {
        let previous = self.prev;
        unsafe {
            self.lua.set_waker(previous);
        }
    }
}
579
#[cfg(test)]
mod assertions {
    use super::Thread;

    // Compile-time check: `Thread` must not implement `Send`.
    static_assertions::assert_not_impl_any!(Thread: Send);
}