1use std::cell::RefCell;
2use std::os::raw::c_void;
3use std::rc::Rc;
4use std::result::Result as StdResult;
5use std::string::String as StdString;
6
7use rustc_hash::FxHashSet;
8use serde::de::{self, IntoDeserializer};
9
10use crate::error::{Error, Result};
11use crate::table::{Table, TablePairs, TableSequence};
12use crate::userdata::AnyUserData;
13use crate::value::Value;
14
15#[derive(Debug)]
17pub struct Deserializer<'lua> {
18 value: Value<'lua>,
19 options: Options,
20 visited: Rc<RefCell<FxHashSet<*const c_void>>>,
21}
22
/// Switches that tune how [`Deserializer`] treats Lua values.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct Options {
    /// When `true`, a function, thread, non-serializable userdata, non-null light userdata
    /// or error value that reaches `deserialize_any` produces an
    /// ``unsupported value type `...` `` error.
    ///
    /// When `false`, such a value is visited as unit if deserialized directly, and is
    /// skipped (together with any light userdata) when it appears inside a table.
    pub deny_unsupported_types: bool,

    /// When `true`, meeting a table that is already being deserialized fails with
    /// `recursive table detected`; when `false` that nested occurrence is skipped.
    pub deny_recursive_tables: bool,

    /// When `true`, all entries of a map-like table are collected and ordered by key
    /// before being visited; when `false` they are visited as the table iterator yields them.
    pub sort_keys: bool,
}
51
52impl Default for Options {
53 fn default() -> Self {
54 Self::new()
55 }
56}
57
58impl Options {
59 pub const fn new() -> Self {
61 Options {
62 deny_unsupported_types: true,
63 deny_recursive_tables: true,
64 sort_keys: false,
65 }
66 }
67
68 #[must_use]
72 pub const fn deny_unsupported_types(mut self, enabled: bool) -> Self {
73 self.deny_unsupported_types = enabled;
74 self
75 }
76
77 #[must_use]
81 pub const fn deny_recursive_tables(mut self, enabled: bool) -> Self {
82 self.deny_recursive_tables = enabled;
83 self
84 }
85
86 #[must_use]
90 pub const fn sort_keys(mut self, enabled: bool) -> Self {
91 self.sort_keys = enabled;
92 self
93 }
94}
95
96impl<'lua> Deserializer<'lua> {
97 pub fn new(value: Value<'lua>) -> Self {
99 Self::new_with_options(value, Options::default())
100 }
101
102 pub fn new_with_options(value: Value<'lua>, options: Options) -> Self {
104 Deserializer {
105 value,
106 options,
107 visited: Rc::new(RefCell::new(FxHashSet::default())),
108 }
109 }
110
111 fn from_parts(
112 value: Value<'lua>,
113 options: Options,
114 visited: Rc<RefCell<FxHashSet<*const c_void>>>,
115 ) -> Self {
116 Deserializer {
117 value,
118 options,
119 visited,
120 }
121 }
122}
123
124impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
125 type Error = Error;
126
127 #[inline]
128 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
129 where
130 V: de::Visitor<'de>,
131 {
132 match self.value {
133 Value::Nil => visitor.visit_unit(),
134 Value::Boolean(b) => visitor.visit_bool(b),
135 #[allow(clippy::useless_conversion)]
136 Value::Integer(i) => visitor.visit_i64(i.into()),
137 #[allow(clippy::useless_conversion)]
138 Value::Number(n) => visitor.visit_f64(n.into()),
139 #[cfg(feature = "luau")]
140 Value::Vector(_) => self.deserialize_seq(visitor),
141 Value::String(s) => match s.to_str() {
142 Ok(s) => visitor.visit_str(s),
143 Err(_) => visitor.visit_bytes(s.as_bytes()),
144 },
145 Value::Table(ref t) if t.raw_len() > 0 || t.is_array() => self.deserialize_seq(visitor),
146 Value::Table(_) => self.deserialize_map(visitor),
147 Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_none(),
148 Value::UserData(ud) if ud.is_serializable() => {
149 serde_userdata(ud, |value| value.deserialize_any(visitor))
150 }
151 #[cfg(feature = "luau")]
152 Value::UserData(ud) if ud.1 == crate::types::SubtypeId::Buffer => unsafe {
153 let mut size = 0usize;
154 let buf = ffi::lua_tobuffer(ud.0.lua.ref_thread(), ud.0.index, &mut size);
155 mlua_assert!(!buf.is_null(), "invalid Luau buffer");
156 let buf = std::slice::from_raw_parts(buf as *const u8, size);
157 visitor.visit_bytes(buf)
158 },
159 Value::Function(_)
160 | Value::Thread(_)
161 | Value::UserData(_)
162 | Value::LightUserData(_)
163 | Value::Error(_) => {
164 if self.options.deny_unsupported_types {
165 let msg = format!("unsupported value type `{}`", self.value.type_name());
166 Err(de::Error::custom(msg))
167 } else {
168 visitor.visit_unit()
169 }
170 }
171 }
172 }
173
174 #[inline]
175 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
176 where
177 V: de::Visitor<'de>,
178 {
179 match self.value {
180 Value::Nil => visitor.visit_none(),
181 Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_none(),
182 _ => visitor.visit_some(self),
183 }
184 }
185
186 #[inline]
187 fn deserialize_enum<V>(
188 self,
189 name: &'static str,
190 variants: &'static [&'static str],
191 visitor: V,
192 ) -> Result<V::Value>
193 where
194 V: de::Visitor<'de>,
195 {
196 let (variant, value, _guard) = match self.value {
197 Value::Table(table) => {
198 let _guard = RecursionGuard::new(&table, &self.visited);
199
200 let mut iter = table.pairs::<StdString, Value>();
201 let (variant, value) = match iter.next() {
202 Some(v) => v?,
203 None => {
204 return Err(de::Error::invalid_value(
205 de::Unexpected::Map,
206 &"map with a single key",
207 ))
208 }
209 };
210
211 if iter.next().is_some() {
212 return Err(de::Error::invalid_value(
213 de::Unexpected::Map,
214 &"map with a single key",
215 ));
216 }
217 let skip = check_value_for_skip(&value, self.options, &self.visited)
218 .map_err(|err| Error::DeserializeError(err.to_string()))?;
219 if skip {
220 return Err(de::Error::custom("bad enum value"));
221 }
222
223 (variant, Some(value), Some(_guard))
224 }
225 Value::String(variant) => (variant.to_str()?.to_owned(), None, None),
226 Value::UserData(ud) if ud.is_serializable() => {
227 return serde_userdata(ud, |value| value.deserialize_enum(name, variants, visitor));
228 }
229 _ => return Err(de::Error::custom("bad enum value")),
230 };
231
232 visitor.visit_enum(EnumDeserializer {
233 variant,
234 value,
235 options: self.options,
236 visited: self.visited,
237 })
238 }
239
240 #[inline]
241 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
242 where
243 V: de::Visitor<'de>,
244 {
245 match self.value {
246 #[cfg(feature = "luau")]
247 Value::Vector(vec) => {
248 let mut deserializer = VecDeserializer {
249 vec,
250 next: 0,
251 options: self.options,
252 visited: self.visited,
253 };
254 visitor.visit_seq(&mut deserializer)
255 }
256 Value::Table(t) => {
257 let _guard = RecursionGuard::new(&t, &self.visited);
258
259 let len = t.raw_len();
260 let mut deserializer = SeqDeserializer {
261 seq: t.sequence_values(),
262 options: self.options,
263 visited: self.visited,
264 };
265 let seq = visitor.visit_seq(&mut deserializer)?;
266 if deserializer.seq.count() == 0 {
267 Ok(seq)
268 } else {
269 Err(de::Error::invalid_length(
270 len,
271 &"fewer elements in the table",
272 ))
273 }
274 }
275 Value::UserData(ud) if ud.is_serializable() => {
276 serde_userdata(ud, |value| value.deserialize_seq(visitor))
277 }
278 value => Err(de::Error::invalid_type(
279 de::Unexpected::Other(value.type_name()),
280 &"table",
281 )),
282 }
283 }
284
285 #[inline]
286 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
287 where
288 V: de::Visitor<'de>,
289 {
290 self.deserialize_seq(visitor)
291 }
292
293 #[inline]
294 fn deserialize_tuple_struct<V>(
295 self,
296 _name: &'static str,
297 _len: usize,
298 visitor: V,
299 ) -> Result<V::Value>
300 where
301 V: de::Visitor<'de>,
302 {
303 self.deserialize_seq(visitor)
304 }
305
306 #[inline]
307 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
308 where
309 V: de::Visitor<'de>,
310 {
311 match self.value {
312 Value::Table(t) => {
313 let _guard = RecursionGuard::new(&t, &self.visited);
314
315 let mut deserializer = MapDeserializer {
316 pairs: MapPairs::new(t, self.options.sort_keys)?,
317 value: None,
318 options: self.options,
319 visited: self.visited,
320 processed: 0,
321 };
322 let map = visitor.visit_map(&mut deserializer)?;
323 let count = deserializer.pairs.count();
324 if count == 0 {
325 Ok(map)
326 } else {
327 Err(de::Error::invalid_length(
328 deserializer.processed + count,
329 &"fewer elements in the table",
330 ))
331 }
332 }
333 Value::UserData(ud) if ud.is_serializable() => {
334 serde_userdata(ud, |value| value.deserialize_map(visitor))
335 }
336 value => Err(de::Error::invalid_type(
337 de::Unexpected::Other(value.type_name()),
338 &"table",
339 )),
340 }
341 }
342
343 #[inline]
344 fn deserialize_struct<V>(
345 self,
346 _name: &'static str,
347 _fields: &'static [&'static str],
348 visitor: V,
349 ) -> Result<V::Value>
350 where
351 V: de::Visitor<'de>,
352 {
353 self.deserialize_map(visitor)
354 }
355
356 #[inline]
357 fn deserialize_newtype_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value>
358 where
359 V: de::Visitor<'de>,
360 {
361 match self.value {
362 Value::UserData(ud) if ud.is_serializable() => {
363 serde_userdata(ud, |value| value.deserialize_newtype_struct(name, visitor))
364 }
365 _ => visitor.visit_newtype_struct(self),
366 }
367 }
368
369 #[inline]
370 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
371 where
372 V: de::Visitor<'de>,
373 {
374 match self.value {
375 Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_unit(),
376 _ => self.deserialize_any(visitor),
377 }
378 }
379
380 #[inline]
381 fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
382 where
383 V: de::Visitor<'de>,
384 {
385 match self.value {
386 Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_unit(),
387 _ => self.deserialize_any(visitor),
388 }
389 }
390
391 serde::forward_to_deserialize_any! {
392 bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string bytes
393 byte_buf identifier ignored_any
394 }
395}
396
397struct SeqDeserializer<'lua> {
398 seq: TableSequence<'lua, Value<'lua>>,
399 options: Options,
400 visited: Rc<RefCell<FxHashSet<*const c_void>>>,
401}
402
403impl<'lua, 'de> de::SeqAccess<'de> for SeqDeserializer<'lua> {
404 type Error = Error;
405
406 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
407 where
408 T: de::DeserializeSeed<'de>,
409 {
410 loop {
411 match self.seq.next() {
412 Some(value) => {
413 let value = value?;
414 let skip = check_value_for_skip(&value, self.options, &self.visited)
415 .map_err(|err| Error::DeserializeError(err.to_string()))?;
416 if skip {
417 continue;
418 }
419 let visited = Rc::clone(&self.visited);
420 let deserializer = Deserializer::from_parts(value, self.options, visited);
421 return seed.deserialize(deserializer).map(Some);
422 }
423 None => return Ok(None),
424 }
425 }
426 }
427
428 fn size_hint(&self) -> Option<usize> {
429 match self.seq.size_hint() {
430 (lower, Some(upper)) if lower == upper => Some(upper),
431 _ => None,
432 }
433 }
434}
435
/// `SeqAccess` implementation that yields the components of a Luau vector one by one.
#[cfg(feature = "luau")]
struct VecDeserializer {
    /// The vector whose components are being read.
    vec: crate::types::Vector,
    /// Index of the component that will be returned next.
    next: usize,
    /// Options passed on to the deserializer of every component.
    options: Options,
    /// Shared set of tables that are currently being deserialized.
    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
}
443
#[cfg(feature = "luau")]
impl<'de> de::SeqAccess<'de> for VecDeserializer {
    type Error = Error;

    /// Deserializes the next vector component as a Lua number, or returns `Ok(None)`
    /// once every component has been handed out.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: de::DeserializeSeed<'de>,
    {
        match self.vec.0.get(self.next) {
            Some(&n) => {
                self.next += 1;
                let visited = Rc::clone(&self.visited);
                let deserializer =
                    Deserializer::from_parts(Value::Number(n as _), self.options, visited);
                seed.deserialize(deserializer).map(Some)
            }
            None => Ok(None),
        }
    }

    /// Number of components that have not been returned yet.
    ///
    /// serde defines this hint as the count of *remaining* elements, so the components
    /// already consumed are subtracted from the vector size. `saturating_sub` keeps the
    /// result at zero should `next` ever exceed the size.
    fn size_hint(&self) -> Option<usize> {
        Some(crate::types::Vector::SIZE.saturating_sub(self.next))
    }
}
468
469pub(crate) enum MapPairs<'lua> {
470 Iter(TablePairs<'lua, Value<'lua>, Value<'lua>>),
471 Vec(Vec<(Value<'lua>, Value<'lua>)>),
472}
473
474impl<'lua> MapPairs<'lua> {
475 pub(crate) fn new(t: Table<'lua>, sort_keys: bool) -> Result<Self> {
476 if sort_keys {
477 let mut pairs = t.pairs::<Value, Value>().collect::<Result<Vec<_>>>()?;
478 pairs.sort_by(|(a, _), (b, _)| b.cmp(a)); Ok(MapPairs::Vec(pairs))
480 } else {
481 Ok(MapPairs::Iter(t.pairs::<Value, Value>()))
482 }
483 }
484
485 pub(crate) fn count(self) -> usize {
486 match self {
487 MapPairs::Iter(iter) => iter.count(),
488 MapPairs::Vec(vec) => vec.len(),
489 }
490 }
491
492 pub(crate) fn size_hint(&self) -> (usize, Option<usize>) {
493 match self {
494 MapPairs::Iter(iter) => iter.size_hint(),
495 MapPairs::Vec(vec) => (vec.len(), Some(vec.len())),
496 }
497 }
498}
499
500impl<'lua> Iterator for MapPairs<'lua> {
501 type Item = Result<(Value<'lua>, Value<'lua>)>;
502
503 fn next(&mut self) -> Option<Self::Item> {
504 match self {
505 MapPairs::Iter(iter) => iter.next(),
506 MapPairs::Vec(vec) => vec.pop().map(Ok),
507 }
508 }
509}
510
511struct MapDeserializer<'lua> {
512 pairs: MapPairs<'lua>,
513 value: Option<Value<'lua>>,
514 options: Options,
515 visited: Rc<RefCell<FxHashSet<*const c_void>>>,
516 processed: usize,
517}
518
519impl<'lua, 'de> de::MapAccess<'de> for MapDeserializer<'lua> {
520 type Error = Error;
521
522 fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
523 where
524 T: de::DeserializeSeed<'de>,
525 {
526 loop {
527 match self.pairs.next() {
528 Some(item) => {
529 let (key, value) = item?;
530 let skip_key = check_value_for_skip(&key, self.options, &self.visited)
531 .map_err(|err| Error::DeserializeError(err.to_string()))?;
532 let skip_value = check_value_for_skip(&value, self.options, &self.visited)
533 .map_err(|err| Error::DeserializeError(err.to_string()))?;
534 if skip_key || skip_value {
535 continue;
536 }
537 self.processed += 1;
538 self.value = Some(value);
539 let visited = Rc::clone(&self.visited);
540 let key_de = Deserializer::from_parts(key, self.options, visited);
541 return seed.deserialize(key_de).map(Some);
542 }
543 None => return Ok(None),
544 }
545 }
546 }
547
548 fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value>
549 where
550 T: de::DeserializeSeed<'de>,
551 {
552 match self.value.take() {
553 Some(value) => {
554 let visited = Rc::clone(&self.visited);
555 seed.deserialize(Deserializer::from_parts(value, self.options, visited))
556 }
557 None => Err(de::Error::custom("value is missing")),
558 }
559 }
560
561 fn size_hint(&self) -> Option<usize> {
562 match self.pairs.size_hint() {
563 (lower, Some(upper)) if lower == upper => Some(upper),
564 _ => None,
565 }
566 }
567}
568
569struct EnumDeserializer<'lua> {
570 variant: StdString,
571 value: Option<Value<'lua>>,
572 options: Options,
573 visited: Rc<RefCell<FxHashSet<*const c_void>>>,
574}
575
576impl<'lua, 'de> de::EnumAccess<'de> for EnumDeserializer<'lua> {
577 type Error = Error;
578 type Variant = VariantDeserializer<'lua>;
579
580 fn variant_seed<T>(self, seed: T) -> Result<(T::Value, Self::Variant)>
581 where
582 T: de::DeserializeSeed<'de>,
583 {
584 let variant = self.variant.into_deserializer();
585 let variant_access = VariantDeserializer {
586 value: self.value,
587 options: self.options,
588 visited: self.visited,
589 };
590 seed.deserialize(variant).map(|v| (v, variant_access))
591 }
592}
593
594struct VariantDeserializer<'lua> {
595 value: Option<Value<'lua>>,
596 options: Options,
597 visited: Rc<RefCell<FxHashSet<*const c_void>>>,
598}
599
600impl<'lua, 'de> de::VariantAccess<'de> for VariantDeserializer<'lua> {
601 type Error = Error;
602
603 fn unit_variant(self) -> Result<()> {
604 match self.value {
605 Some(_) => Err(de::Error::invalid_type(
606 de::Unexpected::NewtypeVariant,
607 &"unit variant",
608 )),
609 None => Ok(()),
610 }
611 }
612
613 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
614 where
615 T: de::DeserializeSeed<'de>,
616 {
617 match self.value {
618 Some(value) => {
619 seed.deserialize(Deserializer::from_parts(value, self.options, self.visited))
620 }
621 None => Err(de::Error::invalid_type(
622 de::Unexpected::UnitVariant,
623 &"newtype variant",
624 )),
625 }
626 }
627
628 fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
629 where
630 V: de::Visitor<'de>,
631 {
632 match self.value {
633 Some(value) => serde::Deserializer::deserialize_seq(
634 Deserializer::from_parts(value, self.options, self.visited),
635 visitor,
636 ),
637 None => Err(de::Error::invalid_type(
638 de::Unexpected::UnitVariant,
639 &"tuple variant",
640 )),
641 }
642 }
643
644 fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
645 where
646 V: de::Visitor<'de>,
647 {
648 match self.value {
649 Some(value) => serde::Deserializer::deserialize_map(
650 Deserializer::from_parts(value, self.options, self.visited),
651 visitor,
652 ),
653 None => Err(de::Error::invalid_type(
654 de::Unexpected::UnitVariant,
655 &"struct variant",
656 )),
657 }
658 }
659}
660
661pub(crate) struct RecursionGuard {
664 ptr: *const c_void,
665 visited: Rc<RefCell<FxHashSet<*const c_void>>>,
666}
667
668impl RecursionGuard {
669 #[inline]
670 pub(crate) fn new(table: &Table, visited: &Rc<RefCell<FxHashSet<*const c_void>>>) -> Self {
671 let visited = Rc::clone(visited);
672 let ptr = table.to_pointer();
673 visited.borrow_mut().insert(ptr);
674 RecursionGuard { ptr, visited }
675 }
676}
677
678impl Drop for RecursionGuard {
679 fn drop(&mut self) {
680 self.visited.borrow_mut().remove(&self.ptr);
681 }
682}
683
684pub(crate) fn check_value_for_skip(
686 value: &Value,
687 options: Options,
688 visited: &RefCell<FxHashSet<*const c_void>>,
689) -> StdResult<bool, &'static str> {
690 match value {
691 Value::Table(table) => {
692 let ptr = table.to_pointer();
693 if visited.borrow().contains(&ptr) {
694 if options.deny_recursive_tables {
695 return Err("recursive table detected");
696 }
697 return Ok(true); }
699 }
700 Value::UserData(ud) if ud.is_serializable() => {}
701 Value::Function(_)
702 | Value::Thread(_)
703 | Value::UserData(_)
704 | Value::LightUserData(_)
705 | Value::Error(_)
706 if !options.deny_unsupported_types =>
707 {
708 return Ok(true); }
710 _ => {}
711 }
712 Ok(false) }
714
715fn serde_userdata<V>(
716 ud: AnyUserData,
717 f: impl FnOnce(serde_value::Value) -> std::result::Result<V, serde_value::DeserializerError>,
718) -> Result<V> {
719 let value = serde_value::to_value(ud).map_err(|err| Error::SerializeError(err.to_string()))?;
720 f(value).map_err(|err| Error::DeserializeError(err.to_string()))
721}