mlua/serde/
de.rs

1use std::cell::RefCell;
2use std::os::raw::c_void;
3use std::rc::Rc;
4use std::result::Result as StdResult;
5use std::string::String as StdString;
6
7use rustc_hash::FxHashSet;
8use serde::de::{self, IntoDeserializer};
9
10use crate::error::{Error, Result};
11use crate::table::{Table, TablePairs, TableSequence};
12use crate::userdata::AnyUserData;
13use crate::value::Value;
14
15/// A struct for deserializing Lua values into Rust values.
16#[derive(Debug)]
17pub struct Deserializer<'lua> {
18    value: Value<'lua>,
19    options: Options,
20    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
21}
22
/// A struct with options to change default deserializer behavior.
///
/// Create it with [`Options::new`] (or `Default`) and adjust individual settings with the
/// chainable builder methods, e.g. `Options::new().sort_keys(true)`.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct Options {
    /// If true, an attempt to deserialize types such as [`Function`], [`Thread`],
    /// [`LightUserData`] and [`Error`] will cause an error.
    /// Otherwise values of these types are skipped when iterating over tables, or deserialized
    /// as the unit type.
    ///
    /// Default: **true**
    ///
    /// [`Function`]: crate::Function
    /// [`Thread`]: crate::Thread
    /// [`LightUserData`]: crate::LightUserData
    /// [`Error`]: crate::Error
    pub deny_unsupported_types: bool,

    /// If true, an attempt to deserialize a recursive table (a table that refers to itself,
    /// directly or through nested tables) will cause an error.
    /// Otherwise such self-references are skipped when iterating over the table.
    ///
    /// Default: **true**
    pub deny_recursive_tables: bool,

    /// If true, keys in tables will be iterated in sorted order when deserializing maps.
    ///
    /// Default: **false**
    pub sort_keys: bool,
}

impl Default for Options {
    fn default() -> Self {
        Self::new()
    }
}

impl Options {
    /// Returns a new instance of `Options` with default parameters.
    pub const fn new() -> Self {
        Options {
            deny_unsupported_types: true,
            deny_recursive_tables: true,
            sort_keys: false,
        }
    }

    /// Sets [`deny_unsupported_types`] option.
    ///
    /// [`deny_unsupported_types`]: #structfield.deny_unsupported_types
    #[must_use]
    pub const fn deny_unsupported_types(mut self, enabled: bool) -> Self {
        self.deny_unsupported_types = enabled;
        self
    }

    /// Sets [`deny_recursive_tables`] option.
    ///
    /// [`deny_recursive_tables`]: #structfield.deny_recursive_tables
    #[must_use]
    pub const fn deny_recursive_tables(mut self, enabled: bool) -> Self {
        self.deny_recursive_tables = enabled;
        self
    }

    /// Sets [`sort_keys`] option.
    ///
    /// [`sort_keys`]: #structfield.sort_keys
    #[must_use]
    pub const fn sort_keys(mut self, enabled: bool) -> Self {
        self.sort_keys = enabled;
        self
    }
}
95
96impl<'lua> Deserializer<'lua> {
97    /// Creates a new Lua Deserializer for the `Value`.
98    pub fn new(value: Value<'lua>) -> Self {
99        Self::new_with_options(value, Options::default())
100    }
101
102    /// Creates a new Lua Deserializer for the `Value` with custom options.
103    pub fn new_with_options(value: Value<'lua>, options: Options) -> Self {
104        Deserializer {
105            value,
106            options,
107            visited: Rc::new(RefCell::new(FxHashSet::default())),
108        }
109    }
110
111    fn from_parts(
112        value: Value<'lua>,
113        options: Options,
114        visited: Rc<RefCell<FxHashSet<*const c_void>>>,
115    ) -> Self {
116        Deserializer {
117            value,
118            options,
119            visited,
120        }
121    }
122}
123
124impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
125    type Error = Error;
126
127    #[inline]
128    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
129    where
130        V: de::Visitor<'de>,
131    {
132        match self.value {
133            Value::Nil => visitor.visit_unit(),
134            Value::Boolean(b) => visitor.visit_bool(b),
135            #[allow(clippy::useless_conversion)]
136            Value::Integer(i) => visitor.visit_i64(i.into()),
137            #[allow(clippy::useless_conversion)]
138            Value::Number(n) => visitor.visit_f64(n.into()),
139            #[cfg(feature = "luau")]
140            Value::Vector(_) => self.deserialize_seq(visitor),
141            Value::String(s) => match s.to_str() {
142                Ok(s) => visitor.visit_str(s),
143                Err(_) => visitor.visit_bytes(s.as_bytes()),
144            },
145            Value::Table(ref t) if t.raw_len() > 0 || t.is_array() => self.deserialize_seq(visitor),
146            Value::Table(_) => self.deserialize_map(visitor),
147            Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_none(),
148            Value::UserData(ud) if ud.is_serializable() => {
149                serde_userdata(ud, |value| value.deserialize_any(visitor))
150            }
151            #[cfg(feature = "luau")]
152            Value::UserData(ud) if ud.1 == crate::types::SubtypeId::Buffer => unsafe {
153                let mut size = 0usize;
154                let buf = ffi::lua_tobuffer(ud.0.lua.ref_thread(), ud.0.index, &mut size);
155                mlua_assert!(!buf.is_null(), "invalid Luau buffer");
156                let buf = std::slice::from_raw_parts(buf as *const u8, size);
157                visitor.visit_bytes(buf)
158            },
159            Value::Function(_)
160            | Value::Thread(_)
161            | Value::UserData(_)
162            | Value::LightUserData(_)
163            | Value::Error(_) => {
164                if self.options.deny_unsupported_types {
165                    let msg = format!("unsupported value type `{}`", self.value.type_name());
166                    Err(de::Error::custom(msg))
167                } else {
168                    visitor.visit_unit()
169                }
170            }
171        }
172    }
173
174    #[inline]
175    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
176    where
177        V: de::Visitor<'de>,
178    {
179        match self.value {
180            Value::Nil => visitor.visit_none(),
181            Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_none(),
182            _ => visitor.visit_some(self),
183        }
184    }
185
186    #[inline]
187    fn deserialize_enum<V>(
188        self,
189        name: &'static str,
190        variants: &'static [&'static str],
191        visitor: V,
192    ) -> Result<V::Value>
193    where
194        V: de::Visitor<'de>,
195    {
196        let (variant, value, _guard) = match self.value {
197            Value::Table(table) => {
198                let _guard = RecursionGuard::new(&table, &self.visited);
199
200                let mut iter = table.pairs::<StdString, Value>();
201                let (variant, value) = match iter.next() {
202                    Some(v) => v?,
203                    None => {
204                        return Err(de::Error::invalid_value(
205                            de::Unexpected::Map,
206                            &"map with a single key",
207                        ))
208                    }
209                };
210
211                if iter.next().is_some() {
212                    return Err(de::Error::invalid_value(
213                        de::Unexpected::Map,
214                        &"map with a single key",
215                    ));
216                }
217                let skip = check_value_for_skip(&value, self.options, &self.visited)
218                    .map_err(|err| Error::DeserializeError(err.to_string()))?;
219                if skip {
220                    return Err(de::Error::custom("bad enum value"));
221                }
222
223                (variant, Some(value), Some(_guard))
224            }
225            Value::String(variant) => (variant.to_str()?.to_owned(), None, None),
226            Value::UserData(ud) if ud.is_serializable() => {
227                return serde_userdata(ud, |value| value.deserialize_enum(name, variants, visitor));
228            }
229            _ => return Err(de::Error::custom("bad enum value")),
230        };
231
232        visitor.visit_enum(EnumDeserializer {
233            variant,
234            value,
235            options: self.options,
236            visited: self.visited,
237        })
238    }
239
240    #[inline]
241    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
242    where
243        V: de::Visitor<'de>,
244    {
245        match self.value {
246            #[cfg(feature = "luau")]
247            Value::Vector(vec) => {
248                let mut deserializer = VecDeserializer {
249                    vec,
250                    next: 0,
251                    options: self.options,
252                    visited: self.visited,
253                };
254                visitor.visit_seq(&mut deserializer)
255            }
256            Value::Table(t) => {
257                let _guard = RecursionGuard::new(&t, &self.visited);
258
259                let len = t.raw_len();
260                let mut deserializer = SeqDeserializer {
261                    seq: t.sequence_values(),
262                    options: self.options,
263                    visited: self.visited,
264                };
265                let seq = visitor.visit_seq(&mut deserializer)?;
266                if deserializer.seq.count() == 0 {
267                    Ok(seq)
268                } else {
269                    Err(de::Error::invalid_length(
270                        len,
271                        &"fewer elements in the table",
272                    ))
273                }
274            }
275            Value::UserData(ud) if ud.is_serializable() => {
276                serde_userdata(ud, |value| value.deserialize_seq(visitor))
277            }
278            value => Err(de::Error::invalid_type(
279                de::Unexpected::Other(value.type_name()),
280                &"table",
281            )),
282        }
283    }
284
285    #[inline]
286    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
287    where
288        V: de::Visitor<'de>,
289    {
290        self.deserialize_seq(visitor)
291    }
292
293    #[inline]
294    fn deserialize_tuple_struct<V>(
295        self,
296        _name: &'static str,
297        _len: usize,
298        visitor: V,
299    ) -> Result<V::Value>
300    where
301        V: de::Visitor<'de>,
302    {
303        self.deserialize_seq(visitor)
304    }
305
306    #[inline]
307    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
308    where
309        V: de::Visitor<'de>,
310    {
311        match self.value {
312            Value::Table(t) => {
313                let _guard = RecursionGuard::new(&t, &self.visited);
314
315                let mut deserializer = MapDeserializer {
316                    pairs: MapPairs::new(t, self.options.sort_keys)?,
317                    value: None,
318                    options: self.options,
319                    visited: self.visited,
320                    processed: 0,
321                };
322                let map = visitor.visit_map(&mut deserializer)?;
323                let count = deserializer.pairs.count();
324                if count == 0 {
325                    Ok(map)
326                } else {
327                    Err(de::Error::invalid_length(
328                        deserializer.processed + count,
329                        &"fewer elements in the table",
330                    ))
331                }
332            }
333            Value::UserData(ud) if ud.is_serializable() => {
334                serde_userdata(ud, |value| value.deserialize_map(visitor))
335            }
336            value => Err(de::Error::invalid_type(
337                de::Unexpected::Other(value.type_name()),
338                &"table",
339            )),
340        }
341    }
342
343    #[inline]
344    fn deserialize_struct<V>(
345        self,
346        _name: &'static str,
347        _fields: &'static [&'static str],
348        visitor: V,
349    ) -> Result<V::Value>
350    where
351        V: de::Visitor<'de>,
352    {
353        self.deserialize_map(visitor)
354    }
355
356    #[inline]
357    fn deserialize_newtype_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value>
358    where
359        V: de::Visitor<'de>,
360    {
361        match self.value {
362            Value::UserData(ud) if ud.is_serializable() => {
363                serde_userdata(ud, |value| value.deserialize_newtype_struct(name, visitor))
364            }
365            _ => visitor.visit_newtype_struct(self),
366        }
367    }
368
369    #[inline]
370    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
371    where
372        V: de::Visitor<'de>,
373    {
374        match self.value {
375            Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_unit(),
376            _ => self.deserialize_any(visitor),
377        }
378    }
379
380    #[inline]
381    fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
382    where
383        V: de::Visitor<'de>,
384    {
385        match self.value {
386            Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_unit(),
387            _ => self.deserialize_any(visitor),
388        }
389    }
390
391    serde::forward_to_deserialize_any! {
392        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string bytes
393        byte_buf identifier ignored_any
394    }
395}
396
397struct SeqDeserializer<'lua> {
398    seq: TableSequence<'lua, Value<'lua>>,
399    options: Options,
400    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
401}
402
403impl<'lua, 'de> de::SeqAccess<'de> for SeqDeserializer<'lua> {
404    type Error = Error;
405
406    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
407    where
408        T: de::DeserializeSeed<'de>,
409    {
410        loop {
411            match self.seq.next() {
412                Some(value) => {
413                    let value = value?;
414                    let skip = check_value_for_skip(&value, self.options, &self.visited)
415                        .map_err(|err| Error::DeserializeError(err.to_string()))?;
416                    if skip {
417                        continue;
418                    }
419                    let visited = Rc::clone(&self.visited);
420                    let deserializer = Deserializer::from_parts(value, self.options, visited);
421                    return seed.deserialize(deserializer).map(Some);
422                }
423                None => return Ok(None),
424            }
425        }
426    }
427
428    fn size_hint(&self) -> Option<usize> {
429        match self.seq.size_hint() {
430            (lower, Some(upper)) if lower == upper => Some(upper),
431            _ => None,
432        }
433    }
434}
435
/// `SeqAccess` over the components of a Luau vector.
#[cfg(feature = "luau")]
struct VecDeserializer {
    /// The vector being read.
    vec: crate::types::Vector,
    /// Index of the next component to yield.
    next: usize,
    /// Options propagated to each component's deserializer.
    options: Options,
    /// Shared set of tables currently being traversed.
    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
}
443
#[cfg(feature = "luau")]
impl<'de> de::SeqAccess<'de> for VecDeserializer {
    type Error = Error;

    /// Yields the vector components one at a time as Lua numbers.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: de::DeserializeSeed<'de>,
    {
        match self.vec.0.get(self.next) {
            Some(&n) => {
                self.next += 1;
                let visited = Rc::clone(&self.visited);
                let deserializer =
                    Deserializer::from_parts(Value::Number(n as _), self.options, visited);
                seed.deserialize(deserializer).map(Some)
            }
            None => Ok(None),
        }
    }

    /// Returns the number of components that have not been yielded yet.
    ///
    /// serde defines `size_hint` as the number of *remaining* elements, so the count shrinks
    /// as components are consumed instead of always reporting the full vector size.
    fn size_hint(&self) -> Option<usize> {
        Some(crate::types::Vector::SIZE.saturating_sub(self.next))
    }
}
468
469pub(crate) enum MapPairs<'lua> {
470    Iter(TablePairs<'lua, Value<'lua>, Value<'lua>>),
471    Vec(Vec<(Value<'lua>, Value<'lua>)>),
472}
473
474impl<'lua> MapPairs<'lua> {
475    pub(crate) fn new(t: Table<'lua>, sort_keys: bool) -> Result<Self> {
476        if sort_keys {
477            let mut pairs = t.pairs::<Value, Value>().collect::<Result<Vec<_>>>()?;
478            pairs.sort_by(|(a, _), (b, _)| b.cmp(a)); // reverse order as we pop values from the end
479            Ok(MapPairs::Vec(pairs))
480        } else {
481            Ok(MapPairs::Iter(t.pairs::<Value, Value>()))
482        }
483    }
484
485    pub(crate) fn count(self) -> usize {
486        match self {
487            MapPairs::Iter(iter) => iter.count(),
488            MapPairs::Vec(vec) => vec.len(),
489        }
490    }
491
492    pub(crate) fn size_hint(&self) -> (usize, Option<usize>) {
493        match self {
494            MapPairs::Iter(iter) => iter.size_hint(),
495            MapPairs::Vec(vec) => (vec.len(), Some(vec.len())),
496        }
497    }
498}
499
500impl<'lua> Iterator for MapPairs<'lua> {
501    type Item = Result<(Value<'lua>, Value<'lua>)>;
502
503    fn next(&mut self) -> Option<Self::Item> {
504        match self {
505            MapPairs::Iter(iter) => iter.next(),
506            MapPairs::Vec(vec) => vec.pop().map(Ok),
507        }
508    }
509}
510
511struct MapDeserializer<'lua> {
512    pairs: MapPairs<'lua>,
513    value: Option<Value<'lua>>,
514    options: Options,
515    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
516    processed: usize,
517}
518
519impl<'lua, 'de> de::MapAccess<'de> for MapDeserializer<'lua> {
520    type Error = Error;
521
522    fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
523    where
524        T: de::DeserializeSeed<'de>,
525    {
526        loop {
527            match self.pairs.next() {
528                Some(item) => {
529                    let (key, value) = item?;
530                    let skip_key = check_value_for_skip(&key, self.options, &self.visited)
531                        .map_err(|err| Error::DeserializeError(err.to_string()))?;
532                    let skip_value = check_value_for_skip(&value, self.options, &self.visited)
533                        .map_err(|err| Error::DeserializeError(err.to_string()))?;
534                    if skip_key || skip_value {
535                        continue;
536                    }
537                    self.processed += 1;
538                    self.value = Some(value);
539                    let visited = Rc::clone(&self.visited);
540                    let key_de = Deserializer::from_parts(key, self.options, visited);
541                    return seed.deserialize(key_de).map(Some);
542                }
543                None => return Ok(None),
544            }
545        }
546    }
547
548    fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value>
549    where
550        T: de::DeserializeSeed<'de>,
551    {
552        match self.value.take() {
553            Some(value) => {
554                let visited = Rc::clone(&self.visited);
555                seed.deserialize(Deserializer::from_parts(value, self.options, visited))
556            }
557            None => Err(de::Error::custom("value is missing")),
558        }
559    }
560
561    fn size_hint(&self) -> Option<usize> {
562        match self.pairs.size_hint() {
563            (lower, Some(upper)) if lower == upper => Some(upper),
564            _ => None,
565        }
566    }
567}
568
569struct EnumDeserializer<'lua> {
570    variant: StdString,
571    value: Option<Value<'lua>>,
572    options: Options,
573    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
574}
575
576impl<'lua, 'de> de::EnumAccess<'de> for EnumDeserializer<'lua> {
577    type Error = Error;
578    type Variant = VariantDeserializer<'lua>;
579
580    fn variant_seed<T>(self, seed: T) -> Result<(T::Value, Self::Variant)>
581    where
582        T: de::DeserializeSeed<'de>,
583    {
584        let variant = self.variant.into_deserializer();
585        let variant_access = VariantDeserializer {
586            value: self.value,
587            options: self.options,
588            visited: self.visited,
589        };
590        seed.deserialize(variant).map(|v| (v, variant_access))
591    }
592}
593
594struct VariantDeserializer<'lua> {
595    value: Option<Value<'lua>>,
596    options: Options,
597    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
598}
599
600impl<'lua, 'de> de::VariantAccess<'de> for VariantDeserializer<'lua> {
601    type Error = Error;
602
603    fn unit_variant(self) -> Result<()> {
604        match self.value {
605            Some(_) => Err(de::Error::invalid_type(
606                de::Unexpected::NewtypeVariant,
607                &"unit variant",
608            )),
609            None => Ok(()),
610        }
611    }
612
613    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
614    where
615        T: de::DeserializeSeed<'de>,
616    {
617        match self.value {
618            Some(value) => {
619                seed.deserialize(Deserializer::from_parts(value, self.options, self.visited))
620            }
621            None => Err(de::Error::invalid_type(
622                de::Unexpected::UnitVariant,
623                &"newtype variant",
624            )),
625        }
626    }
627
628    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
629    where
630        V: de::Visitor<'de>,
631    {
632        match self.value {
633            Some(value) => serde::Deserializer::deserialize_seq(
634                Deserializer::from_parts(value, self.options, self.visited),
635                visitor,
636            ),
637            None => Err(de::Error::invalid_type(
638                de::Unexpected::UnitVariant,
639                &"tuple variant",
640            )),
641        }
642    }
643
644    fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
645    where
646        V: de::Visitor<'de>,
647    {
648        match self.value {
649            Some(value) => serde::Deserializer::deserialize_map(
650                Deserializer::from_parts(value, self.options, self.visited),
651                visitor,
652            ),
653            None => Err(de::Error::invalid_type(
654                de::Unexpected::UnitVariant,
655                &"struct variant",
656            )),
657        }
658    }
659}
660
661// Adds `ptr` to the `visited` map and removes on drop
662// Used to track recursive tables but allow to traverse same tables multiple times
663pub(crate) struct RecursionGuard {
664    ptr: *const c_void,
665    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
666}
667
668impl RecursionGuard {
669    #[inline]
670    pub(crate) fn new(table: &Table, visited: &Rc<RefCell<FxHashSet<*const c_void>>>) -> Self {
671        let visited = Rc::clone(visited);
672        let ptr = table.to_pointer();
673        visited.borrow_mut().insert(ptr);
674        RecursionGuard { ptr, visited }
675    }
676}
677
678impl Drop for RecursionGuard {
679    fn drop(&mut self) {
680        self.visited.borrow_mut().remove(&self.ptr);
681    }
682}
683
684// Checks `options` and decides should we emit an error or skip next element
685pub(crate) fn check_value_for_skip(
686    value: &Value,
687    options: Options,
688    visited: &RefCell<FxHashSet<*const c_void>>,
689) -> StdResult<bool, &'static str> {
690    match value {
691        Value::Table(table) => {
692            let ptr = table.to_pointer();
693            if visited.borrow().contains(&ptr) {
694                if options.deny_recursive_tables {
695                    return Err("recursive table detected");
696                }
697                return Ok(true); // skip
698            }
699        }
700        Value::UserData(ud) if ud.is_serializable() => {}
701        Value::Function(_)
702        | Value::Thread(_)
703        | Value::UserData(_)
704        | Value::LightUserData(_)
705        | Value::Error(_)
706            if !options.deny_unsupported_types =>
707        {
708            return Ok(true); // skip
709        }
710        _ => {}
711    }
712    Ok(false) // do not skip
713}
714
715fn serde_userdata<V>(
716    ud: AnyUserData,
717    f: impl FnOnce(serde_value::Value) -> std::result::Result<V, serde_value::DeserializerError>,
718) -> Result<V> {
719    let value = serde_value::to_value(ud).map_err(|err| Error::SerializeError(err.to_string()))?;
720    f(value).map_err(|err| Error::DeserializeError(err.to_string()))
721}