ruby_prism/
lib.rs

1//! # ruby-prism
2//!
3//! Rustified version of Ruby's prism parser.
4//!
5#![warn(clippy::all, clippy::nursery, clippy::pedantic, future_incompatible, missing_docs, nonstandard_style, rust_2018_idioms, trivial_casts, trivial_numeric_casts, unreachable_pub, unused_qualifications)]
6
7// Most of the code in this file is generated, so sometimes it generates code
8// that doesn't follow the clippy rules. We don't want to see those warnings.
9#[allow(clippy::too_many_lines, clippy::use_self)]
10mod bindings {
11    // In `build.rs`, we generate bindings based on the config.yml file. Here is
12    // where we pull in those bindings and make them part of our library.
13    include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
14}
15
16use std::ffi::{c_char, CStr};
17use std::marker::PhantomData;
18use std::mem::MaybeUninit;
19use std::ptr::NonNull;
20
21pub use self::bindings::*;
22use ruby_prism_sys::{pm_comment_t, pm_constant_id_list_t, pm_constant_id_t, pm_diagnostic_t, pm_integer_t, pm_location_t, pm_magic_comment_t, pm_node_destroy, pm_node_list, pm_node_t, pm_parse, pm_parser_free, pm_parser_init, pm_parser_t};
23
24/// A range in the source file.
25pub struct Location<'pr> {
26    parser: NonNull<pm_parser_t>,
27    pub(crate) start: *const u8,
28    pub(crate) end: *const u8,
29    marker: PhantomData<&'pr [u8]>,
30}
31
32impl<'pr> Location<'pr> {
33    /// Returns a byte slice for the range.
34    /// # Panics
35    /// Panics if the end offset is not greater than the start offset.
36    #[must_use]
37    pub fn as_slice(&self) -> &'pr [u8] {
38        unsafe {
39            let len = usize::try_from(self.end.offset_from(self.start)).expect("end should point to memory after start");
40            std::slice::from_raw_parts(self.start, len)
41        }
42    }
43
44    /// Return a Location from the given `pm_location_t`.
45    #[must_use]
46    pub(crate) const fn new(parser: NonNull<pm_parser_t>, loc: &'pr pm_location_t) -> Self {
47        Location {
48            parser,
49            start: loc.start,
50            end: loc.end,
51            marker: PhantomData,
52        }
53    }
54
55    /// Return a Location starting at self and ending at the end of other.
56    /// Returns None if both locations did not originate from the same parser,
57    /// or if self starts after other.
58    #[must_use]
59    pub fn join(&self, other: &Self) -> Option<Self> {
60        if self.parser != other.parser || self.start > other.start {
61            None
62        } else {
63            Some(Location {
64                parser: self.parser,
65                start: self.start,
66                end: other.end,
67                marker: PhantomData,
68            })
69        }
70    }
71
72    /// Return the start offset from the beginning of the parsed source.
73    /// # Panics
74    /// Panics if the start offset is not greater than the parser's start.
75    #[must_use]
76    pub fn start_offset(&self) -> usize {
77        unsafe {
78            let parser_start = (*self.parser.as_ptr()).start;
79            usize::try_from(self.start.offset_from(parser_start)).expect("start should point to memory after the parser's start")
80        }
81    }
82
83    /// Return the end offset from the beginning of the parsed source.
84    /// # Panics
85    /// Panics if the end offset is not greater than the parser's start.
86    #[must_use]
87    pub fn end_offset(&self) -> usize {
88        unsafe {
89            let parser_start = (*self.parser.as_ptr()).start;
90            usize::try_from(self.end.offset_from(parser_start)).expect("end should point to memory after the parser's start")
91        }
92    }
93}
94
95impl std::fmt::Debug for Location<'_> {
96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        let slice: &[u8] = self.as_slice();
98
99        let mut visible = String::new();
100        visible.push('"');
101
102        for &byte in slice {
103            let part: Vec<u8> = std::ascii::escape_default(byte).collect();
104            visible.push_str(std::str::from_utf8(&part).unwrap());
105        }
106
107        visible.push('"');
108        write!(f, "{visible}")
109    }
110}
111
112/// An iterator over the nodes in a list.
113pub struct NodeListIter<'pr> {
114    parser: NonNull<pm_parser_t>,
115    pointer: NonNull<pm_node_list>,
116    index: usize,
117    marker: PhantomData<&'pr mut pm_node_list>,
118}
119
120impl<'pr> Iterator for NodeListIter<'pr> {
121    type Item = Node<'pr>;
122
123    fn next(&mut self) -> Option<Self::Item> {
124        if self.index >= unsafe { self.pointer.as_ref().size } {
125            None
126        } else {
127            let node: *mut pm_node_t = unsafe { *(self.pointer.as_ref().nodes.add(self.index)) };
128            self.index += 1;
129            Some(Node::new(self.parser, node))
130        }
131    }
132}
133
134/// A list of nodes.
135pub struct NodeList<'pr> {
136    parser: NonNull<pm_parser_t>,
137    pointer: NonNull<pm_node_list>,
138    marker: PhantomData<&'pr mut pm_node_list>,
139}
140
141impl<'pr> NodeList<'pr> {
142    /// Returns an iterator over the nodes.
143    #[must_use]
144    pub const fn iter(&self) -> NodeListIter<'pr> {
145        NodeListIter {
146            parser: self.parser,
147            pointer: self.pointer,
148            index: 0,
149            marker: PhantomData,
150        }
151    }
152}
153
154impl<'pr> IntoIterator for &NodeList<'pr> {
155    type Item = Node<'pr>;
156    type IntoIter = NodeListIter<'pr>;
157    fn into_iter(self) -> Self::IntoIter {
158        self.iter()
159    }
160}
161
162impl std::fmt::Debug for NodeList<'_> {
163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        write!(f, "{:?}", self.iter().collect::<Vec<_>>())
165    }
166}
167
168/// A handle for a constant ID.
169pub struct ConstantId<'pr> {
170    parser: NonNull<pm_parser_t>,
171    id: pm_constant_id_t,
172    marker: PhantomData<&'pr mut pm_constant_id_t>,
173}
174
175impl<'pr> ConstantId<'pr> {
176    const fn new(parser: NonNull<pm_parser_t>, id: pm_constant_id_t) -> Self {
177        ConstantId { parser, id, marker: PhantomData }
178    }
179
180    /// Returns a byte slice for the constant ID.
181    ///
182    /// # Panics
183    ///
184    /// Panics if the constant ID is not found in the constant pool.
185    #[must_use]
186    pub fn as_slice(&self) -> &'pr [u8] {
187        unsafe {
188            let pool = &(*self.parser.as_ptr()).constant_pool;
189            let constant = &(*pool.constants.add((self.id - 1).try_into().unwrap()));
190            std::slice::from_raw_parts(constant.start, constant.length)
191        }
192    }
193}
194
195impl std::fmt::Debug for ConstantId<'_> {
196    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197        write!(f, "{:?}", self.id)
198    }
199}
200
201/// An iterator over the constants in a list.
202pub struct ConstantListIter<'pr> {
203    parser: NonNull<pm_parser_t>,
204    pointer: NonNull<pm_constant_id_list_t>,
205    index: usize,
206    marker: PhantomData<&'pr mut pm_constant_id_list_t>,
207}
208
209impl<'pr> Iterator for ConstantListIter<'pr> {
210    type Item = ConstantId<'pr>;
211
212    fn next(&mut self) -> Option<Self::Item> {
213        if self.index >= unsafe { self.pointer.as_ref().size } {
214            None
215        } else {
216            let constant_id: pm_constant_id_t = unsafe { *(self.pointer.as_ref().ids.add(self.index)) };
217            self.index += 1;
218            Some(ConstantId::new(self.parser, constant_id))
219        }
220    }
221}
222
223/// A list of constants.
224pub struct ConstantList<'pr> {
225    /// The raw pointer to the parser where this list came from.
226    parser: NonNull<pm_parser_t>,
227
228    /// The raw pointer to the list allocated by prism.
229    pointer: NonNull<pm_constant_id_list_t>,
230
231    /// The marker to indicate the lifetime of the pointer.
232    marker: PhantomData<&'pr mut pm_constant_id_list_t>,
233}
234
235impl<'pr> ConstantList<'pr> {
236    /// Returns an iterator over the constants in the list.
237    #[must_use]
238    pub const fn iter(&self) -> ConstantListIter<'pr> {
239        ConstantListIter {
240            parser: self.parser,
241            pointer: self.pointer,
242            index: 0,
243            marker: PhantomData,
244        }
245    }
246}
247
248impl<'pr> IntoIterator for &ConstantList<'pr> {
249    type Item = ConstantId<'pr>;
250    type IntoIter = ConstantListIter<'pr>;
251    fn into_iter(self) -> Self::IntoIter {
252        self.iter()
253    }
254}
255
256impl std::fmt::Debug for ConstantList<'_> {
257    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258        write!(f, "{:?}", self.iter().collect::<Vec<_>>())
259    }
260}
261
262/// A handle for an arbitarily-sized integer.
263pub struct Integer<'pr> {
264    /// The raw pointer to the integer allocated by prism.
265    pointer: *const pm_integer_t,
266
267    /// The marker to indicate the lifetime of the pointer.
268    marker: PhantomData<&'pr mut pm_constant_id_t>,
269}
270
271impl Integer<'_> {
272    const fn new(pointer: *const pm_integer_t) -> Self {
273        Integer { pointer, marker: PhantomData }
274    }
275
276    /// Returns the sign and the u32 digits representation of the integer,
277    /// ordered least significant digit first.
278    #[must_use]
279    pub const fn to_u32_digits(&self) -> (bool, &[u32]) {
280        let negative = unsafe { (*self.pointer).negative };
281        let length = unsafe { (*self.pointer).length };
282        let values = unsafe { (*self.pointer).values };
283
284        if values.is_null() {
285            let value_ptr = unsafe { std::ptr::addr_of!((*self.pointer).value) };
286            let slice = unsafe { std::slice::from_raw_parts(value_ptr, 1) };
287            (negative, slice)
288        } else {
289            let slice = unsafe { std::slice::from_raw_parts(values, length) };
290            (negative, slice)
291        }
292    }
293}
294
295impl std::fmt::Debug for Integer<'_> {
296    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297        write!(f, "{:?}", self.pointer)
298    }
299}
300
301impl TryInto<i32> for Integer<'_> {
302    type Error = ();
303
304    fn try_into(self) -> Result<i32, Self::Error> {
305        let negative = unsafe { (*self.pointer).negative };
306        let length = unsafe { (*self.pointer).length };
307
308        if length == 0 {
309            i32::try_from(unsafe { (*self.pointer).value }).map_or(Err(()), |value| if negative { Ok(-value) } else { Ok(value) })
310        } else {
311            Err(())
312        }
313    }
314}
315
316/// A diagnostic message that came back from the parser.
317#[derive(Debug)]
318pub struct Diagnostic<'pr> {
319    diag: NonNull<pm_diagnostic_t>,
320    parser: NonNull<pm_parser_t>,
321    marker: PhantomData<&'pr pm_diagnostic_t>,
322}
323
324impl<'pr> Diagnostic<'pr> {
325    /// Returns the message associated with the diagnostic.
326    ///
327    /// # Panics
328    ///
329    /// Panics if the message is not able to be converted into a `CStr`.
330    ///
331    #[must_use]
332    pub fn message(&self) -> &str {
333        unsafe {
334            let message: *mut c_char = self.diag.as_ref().message.cast_mut();
335            CStr::from_ptr(message).to_str().expect("prism allows only UTF-8 for diagnostics.")
336        }
337    }
338
339    /// The location of the diagnostic in the source.
340    #[must_use]
341    pub const fn location(&self) -> Location<'pr> {
342        Location::new(self.parser, unsafe { &self.diag.as_ref().location })
343    }
344}
345
346/// A comment that was found during parsing.
347#[derive(Debug)]
348pub struct Comment<'pr> {
349    content: NonNull<pm_comment_t>,
350    parser: NonNull<pm_parser_t>,
351    marker: PhantomData<&'pr pm_comment_t>,
352}
353
354impl<'pr> Comment<'pr> {
355    /// Returns the text of the comment.
356    ///
357    /// # Panics
358    /// Panics if the end offset is not greater than the start offset.
359    #[must_use]
360    pub fn text(&self) -> &[u8] {
361        self.location().as_slice()
362    }
363
364    /// The location of the comment in the source.
365    #[must_use]
366    pub const fn location(&self) -> Location<'pr> {
367        Location::new(self.parser, unsafe { &self.content.as_ref().location })
368    }
369}
370
371/// A magic comment that was found during parsing.
372#[derive(Debug)]
373pub struct MagicComment<'pr> {
374    comment: NonNull<pm_magic_comment_t>,
375    marker: PhantomData<&'pr pm_magic_comment_t>,
376}
377
378impl MagicComment<'_> {
379    /// Returns the text of the comment's key.
380    #[must_use]
381    pub const fn key(&self) -> &[u8] {
382        unsafe {
383            let start = self.comment.as_ref().key_start;
384            let len = self.comment.as_ref().key_length as usize;
385            std::slice::from_raw_parts(start, len)
386        }
387    }
388
389    /// Returns the text of the comment's value.
390    #[must_use]
391    pub const fn value(&self) -> &[u8] {
392        unsafe {
393            let start = self.comment.as_ref().value_start;
394            let len = self.comment.as_ref().value_length as usize;
395            std::slice::from_raw_parts(start, len)
396        }
397    }
398}
399
400/// A struct created by the `errors` or `warnings` methods on `ParseResult`. It
401/// can be used to iterate over the diagnostics in the parse result.
402pub struct Diagnostics<'pr> {
403    diagnostic: *mut pm_diagnostic_t,
404    parser: NonNull<pm_parser_t>,
405    marker: PhantomData<&'pr pm_diagnostic_t>,
406}
407
408impl<'pr> Iterator for Diagnostics<'pr> {
409    type Item = Diagnostic<'pr>;
410
411    fn next(&mut self) -> Option<Self::Item> {
412        if let Some(diagnostic) = NonNull::new(self.diagnostic) {
413            let current = Diagnostic {
414                diag: diagnostic,
415                parser: self.parser,
416                marker: PhantomData,
417            };
418            self.diagnostic = unsafe { diagnostic.as_ref().node.next.cast::<pm_diagnostic_t>() };
419            Some(current)
420        } else {
421            None
422        }
423    }
424}
425
426/// A struct created by the `comments` method on `ParseResult`. It can be used
427/// to iterate over the comments in the parse result.
428pub struct Comments<'pr> {
429    comment: *mut pm_comment_t,
430    parser: NonNull<pm_parser_t>,
431    marker: PhantomData<&'pr pm_comment_t>,
432}
433
434impl<'pr> Iterator for Comments<'pr> {
435    type Item = Comment<'pr>;
436
437    fn next(&mut self) -> Option<Self::Item> {
438        if let Some(comment) = NonNull::new(self.comment) {
439            let current = Comment {
440                content: comment,
441                parser: self.parser,
442                marker: PhantomData,
443            };
444            self.comment = unsafe { comment.as_ref().node.next.cast::<pm_comment_t>() };
445            Some(current)
446        } else {
447            None
448        }
449    }
450}
451
452/// A struct created by the `magic_comments` method on `ParseResult`. It can be used
453/// to iterate over the magic comments in the parse result.
454pub struct MagicComments<'pr> {
455    comment: *mut pm_magic_comment_t,
456    marker: PhantomData<&'pr pm_magic_comment_t>,
457}
458
459impl<'pr> Iterator for MagicComments<'pr> {
460    type Item = MagicComment<'pr>;
461
462    fn next(&mut self) -> Option<Self::Item> {
463        if let Some(comment) = NonNull::new(self.comment) {
464            let current = MagicComment { comment, marker: PhantomData };
465            self.comment = unsafe { comment.as_ref().node.next.cast::<pm_magic_comment_t>() };
466            Some(current)
467        } else {
468            None
469        }
470    }
471}
472
473/// The result of parsing a source string.
474#[derive(Debug)]
475pub struct ParseResult<'pr> {
476    source: &'pr [u8],
477    parser: NonNull<pm_parser_t>,
478    node: NonNull<pm_node_t>,
479}
480
481impl<'pr> ParseResult<'pr> {
482    /// Returns the source string that was parsed.
483    #[must_use]
484    pub const fn source(&self) -> &'pr [u8] {
485        self.source
486    }
487
488    /// Returns whether we found a `frozen_string_literal` magic comment with a true value.
489    #[must_use]
490    pub fn frozen_string_literals(&self) -> bool {
491        unsafe { (*self.parser.as_ptr()).frozen_string_literal == 1 }
492    }
493
494    /// Returns a slice of the source string that was parsed using the given
495    /// location range.
496    ///
497    /// # Panics
498    /// Panics if start offset or end offset are not valid offsets from the root.
499    #[must_use]
500    pub fn as_slice(&self, location: &Location<'pr>) -> &'pr [u8] {
501        let root = self.source.as_ptr();
502
503        let start = usize::try_from(unsafe { location.start.offset_from(root) }).expect("start should point to memory after root");
504        let end = usize::try_from(unsafe { location.end.offset_from(root) }).expect("end should point to memory after root");
505
506        &self.source[start..end]
507    }
508
509    /// Returns an iterator that can be used to iterate over the errors in the
510    /// parse result.
511    #[must_use]
512    pub fn errors(&self) -> Diagnostics<'_> {
513        unsafe {
514            let list = &mut (*self.parser.as_ptr()).error_list;
515            Diagnostics {
516                diagnostic: list.head.cast::<pm_diagnostic_t>(),
517                parser: self.parser,
518                marker: PhantomData,
519            }
520        }
521    }
522
523    /// Returns an iterator that can be used to iterate over the warnings in the
524    /// parse result.
525    #[must_use]
526    pub fn warnings(&self) -> Diagnostics<'_> {
527        unsafe {
528            let list = &mut (*self.parser.as_ptr()).warning_list;
529            Diagnostics {
530                diagnostic: list.head.cast::<pm_diagnostic_t>(),
531                parser: self.parser,
532                marker: PhantomData,
533            }
534        }
535    }
536
537    /// Returns an iterator that can be used to iterate over the comments in the
538    /// parse result.
539    #[must_use]
540    pub fn comments(&self) -> Comments<'_> {
541        unsafe {
542            let list = &mut (*self.parser.as_ptr()).comment_list;
543            Comments {
544                comment: list.head.cast::<pm_comment_t>(),
545                parser: self.parser,
546                marker: PhantomData,
547            }
548        }
549    }
550
551    /// Returns an iterator that can be used to iterate over the magic comments in the
552    /// parse result.
553    #[must_use]
554    pub fn magic_comments(&self) -> MagicComments<'_> {
555        unsafe {
556            let list = &mut (*self.parser.as_ptr()).magic_comment_list;
557            MagicComments {
558                comment: list.head.cast::<pm_magic_comment_t>(),
559                marker: PhantomData,
560            }
561        }
562    }
563
564    /// Returns an optional location of the __END__ marker and the rest of the content of the file.
565    #[must_use]
566    pub fn data_loc(&self) -> Option<Location<'_>> {
567        let location = unsafe { &(*self.parser.as_ptr()).data_loc };
568        if location.start.is_null() {
569            None
570        } else {
571            Some(Location::new(self.parser, location))
572        }
573    }
574
575    /// Returns the root node of the parse result.
576    #[must_use]
577    pub fn node(&self) -> Node<'_> {
578        Node::new(self.parser, self.node.as_ptr())
579    }
580}
581
582impl Drop for ParseResult<'_> {
583    fn drop(&mut self) {
584        unsafe {
585            pm_node_destroy(self.parser.as_ptr(), self.node.as_ptr());
586            pm_parser_free(self.parser.as_ptr());
587            drop(Box::from_raw(self.parser.as_ptr()));
588        }
589    }
590}
591
592/// Parses the given source string and returns a parse result.
593///
594/// # Panics
595///
596/// Panics if the parser fails to initialize.
597///
598#[must_use]
599pub fn parse(source: &[u8]) -> ParseResult<'_> {
600    unsafe {
601        let uninit = Box::new(MaybeUninit::<pm_parser_t>::uninit());
602        let uninit = Box::into_raw(uninit);
603
604        pm_parser_init((*uninit).as_mut_ptr(), source.as_ptr(), source.len(), std::ptr::null());
605
606        let parser = (*uninit).assume_init_mut();
607        let parser = NonNull::new_unchecked(parser);
608
609        let node = pm_parse(parser.as_ptr());
610        let node = NonNull::new_unchecked(node);
611
612        ParseResult { source, parser, node }
613    }
614}
615
616#[cfg(test)]
617mod tests {
618    use super::parse;
619
620    #[test]
621    fn comments_test() {
622        let source = "# comment 1\n# comment 2\n# comment 3\n";
623        let result = parse(source.as_ref());
624
625        for comment in result.comments() {
626            let text = std::str::from_utf8(comment.text()).unwrap();
627            assert!(text.starts_with("# comment"));
628        }
629    }
630
631    #[test]
632    fn magic_comments_test() {
633        use crate::MagicComment;
634
635        let source = "# typed: ignore\n# typed:true\n#typed: strict\n";
636        let result = parse(source.as_ref());
637
638        let comments: Vec<MagicComment<'_>> = result.magic_comments().collect();
639        assert_eq!(3, comments.len());
640
641        assert_eq!(b"typed", comments[0].key());
642        assert_eq!(b"ignore", comments[0].value());
643
644        assert_eq!(b"typed", comments[1].key());
645        assert_eq!(b"true", comments[1].value());
646
647        assert_eq!(b"typed", comments[2].key());
648        assert_eq!(b"strict", comments[2].value());
649    }
650
651    #[test]
652    fn data_loc_test() {
653        let source = "1";
654        let result = parse(source.as_ref());
655        let data_loc = result.data_loc();
656        assert!(data_loc.is_none());
657
658        let source = "__END__\nabc\n";
659        let result = parse(source.as_ref());
660        let data_loc = result.data_loc().unwrap();
661        let slice = std::str::from_utf8(result.as_slice(&data_loc)).unwrap();
662        assert_eq!(slice, "__END__\nabc\n");
663
664        let source = "1\n2\n3\n__END__\nabc\ndef\n";
665        let result = parse(source.as_ref());
666        let data_loc = result.data_loc().unwrap();
667        let slice = std::str::from_utf8(result.as_slice(&data_loc)).unwrap();
668        assert_eq!(slice, "__END__\nabc\ndef\n");
669    }
670
671    #[test]
672    fn location_test() {
673        let source = "111 + 222 + 333";
674        let result = parse(source.as_ref());
675
676        let node = result.node();
677        let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
678        let node = node.as_call_node().unwrap().receiver().unwrap();
679        let plus = node.as_call_node().unwrap();
680        let node = plus.arguments().unwrap().arguments().iter().next().unwrap();
681
682        let location = node.as_integer_node().unwrap().location();
683        let slice = std::str::from_utf8(result.as_slice(&location)).unwrap();
684
685        assert_eq!(slice, "222");
686        assert_eq!(6, location.start_offset());
687        assert_eq!(9, location.end_offset());
688
689        let recv_loc = plus.receiver().unwrap().location();
690        assert_eq!(recv_loc.as_slice(), b"111");
691        assert_eq!(0, recv_loc.start_offset());
692        assert_eq!(3, recv_loc.end_offset());
693
694        let joined = recv_loc.join(&location).unwrap();
695        assert_eq!(joined.as_slice(), b"111 + 222");
696
697        let not_joined = location.join(&recv_loc);
698        assert!(not_joined.is_none());
699
700        {
701            let result = parse(source.as_ref());
702            let node = result.node();
703            let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
704            let node = node.as_call_node().unwrap().receiver().unwrap();
705            let plus = node.as_call_node().unwrap();
706            let node = plus.arguments().unwrap().arguments().iter().next().unwrap();
707
708            let location = node.as_integer_node().unwrap().location();
709            let not_joined = recv_loc.join(&location);
710            assert!(not_joined.is_none());
711
712            let not_joined = location.join(&recv_loc);
713            assert!(not_joined.is_none());
714        }
715
716        let location = node.location();
717        let slice = std::str::from_utf8(result.as_slice(&location)).unwrap();
718
719        assert_eq!(slice, "222");
720
721        let slice = std::str::from_utf8(location.as_slice()).unwrap();
722
723        assert_eq!(slice, "222");
724    }
725
726    #[test]
727    fn visitor_test() {
728        use super::{visit_interpolated_regular_expression_node, visit_regular_expression_node, InterpolatedRegularExpressionNode, RegularExpressionNode, Visit};
729
730        struct RegularExpressionVisitor {
731            count: usize,
732        }
733
734        impl Visit<'_> for RegularExpressionVisitor {
735            fn visit_interpolated_regular_expression_node(&mut self, node: &InterpolatedRegularExpressionNode<'_>) {
736                self.count += 1;
737                visit_interpolated_regular_expression_node(self, node);
738            }
739
740            fn visit_regular_expression_node(&mut self, node: &RegularExpressionNode<'_>) {
741                self.count += 1;
742                visit_regular_expression_node(self, node);
743            }
744        }
745
746        let source = "# comment 1\n# comment 2\nmodule Foo; class Bar; /abc #{/def/}/; end; end";
747        let result = parse(source.as_ref());
748
749        let mut visitor = RegularExpressionVisitor { count: 0 };
750        visitor.visit(&result.node());
751
752        assert_eq!(visitor.count, 2);
753    }
754
755    #[test]
756    fn node_upcast_test() {
757        use super::Node;
758
759        let source = "module Foo; end";
760        let result = parse(source.as_ref());
761
762        let node = result.node();
763        let upcast_node = node.as_program_node().unwrap().as_node();
764        assert!(matches!(upcast_node, Node::ProgramNode { .. }));
765
766        let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
767        let upcast_node = node.as_module_node().unwrap().as_node();
768        assert!(matches!(upcast_node, Node::ModuleNode { .. }));
769    }
770
771    #[test]
772    fn constant_id_test() {
773        let source = "module Foo; x = 1; end";
774        let result = parse(source.as_ref());
775
776        let node = result.node();
777        let module = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
778        let module = module.as_module_node().unwrap();
779        let locals = module.locals().iter().collect::<Vec<_>>();
780
781        assert_eq!(locals.len(), 1);
782
783        assert_eq!(locals[0].as_slice(), b"x");
784    }
785
786    #[test]
787    fn optional_loc_test() {
788        let source = r"
789module Example
790  x = call_func(3, 4)
791  y = x.call_func 5, 6
792end
793";
794        let result = parse(source.as_ref());
795
796        let node = result.node();
797        let module = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
798        let module = module.as_module_node().unwrap();
799        let body = module.body();
800        let writes = body.iter().next().unwrap().as_statements_node().unwrap().body().iter().collect::<Vec<_>>();
801        assert_eq!(writes.len(), 2);
802
803        let asgn = &writes[0];
804        let call = asgn.as_local_variable_write_node().unwrap().value();
805        let call = call.as_call_node().unwrap();
806
807        let call_operator_loc = call.call_operator_loc();
808        assert!(call_operator_loc.is_none());
809        let closing_loc = call.closing_loc();
810        assert!(closing_loc.is_some());
811
812        let asgn = &writes[1];
813        let call = asgn.as_local_variable_write_node().unwrap().value();
814        let call = call.as_call_node().unwrap();
815
816        let call_operator_loc = call.call_operator_loc();
817        assert!(call_operator_loc.is_some());
818        let closing_loc = call.closing_loc();
819        assert!(closing_loc.is_none());
820    }
821
822    #[test]
823    fn frozen_strings_test() {
824        let source = r#"
825# frozen_string_literal: true
826"foo"
827"#;
828        let result = parse(source.as_ref());
829        assert!(result.frozen_string_literals());
830
831        let source = "3";
832        let result = parse(source.as_ref());
833        assert!(!result.frozen_string_literals());
834    }
835
836    #[test]
837    fn string_flags_test() {
838        let source = r#"
839# frozen_string_literal: true
840"foo"
841"#;
842        let result = parse(source.as_ref());
843
844        let node = result.node();
845        let string = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
846        let string = string.as_string_node().unwrap();
847        assert!(string.is_frozen());
848
849        let source = r#"
850"foo"
851"#;
852        let result = parse(source.as_ref());
853
854        let node = result.node();
855        let string = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
856        let string = string.as_string_node().unwrap();
857        assert!(!string.is_frozen());
858    }
859
860    #[test]
861    fn call_flags_test() {
862        let source = r"
863x
864";
865        let result = parse(source.as_ref());
866
867        let node = result.node();
868        let call = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
869        let call = call.as_call_node().unwrap();
870        assert!(call.is_variable_call());
871
872        let source = r"
873x&.foo
874";
875        let result = parse(source.as_ref());
876
877        let node = result.node();
878        let call = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
879        let call = call.as_call_node().unwrap();
880        assert!(call.is_safe_navigation());
881    }
882
883    #[test]
884    fn integer_flags_test() {
885        let source = r"
8860b1
887";
888        let result = parse(source.as_ref());
889
890        let node = result.node();
891        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
892        let i = i.as_integer_node().unwrap();
893        assert!(i.is_binary());
894        assert!(!i.is_decimal());
895        assert!(!i.is_octal());
896        assert!(!i.is_hexadecimal());
897
898        let source = r"
8991
900";
901        let result = parse(source.as_ref());
902
903        let node = result.node();
904        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
905        let i = i.as_integer_node().unwrap();
906        assert!(!i.is_binary());
907        assert!(i.is_decimal());
908        assert!(!i.is_octal());
909        assert!(!i.is_hexadecimal());
910
911        let source = r"
9120o1
913";
914        let result = parse(source.as_ref());
915
916        let node = result.node();
917        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
918        let i = i.as_integer_node().unwrap();
919        assert!(!i.is_binary());
920        assert!(!i.is_decimal());
921        assert!(i.is_octal());
922        assert!(!i.is_hexadecimal());
923
924        let source = r"
9250x1
926";
927        let result = parse(source.as_ref());
928
929        let node = result.node();
930        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
931        let i = i.as_integer_node().unwrap();
932        assert!(!i.is_binary());
933        assert!(!i.is_decimal());
934        assert!(!i.is_octal());
935        assert!(i.is_hexadecimal());
936    }
937
938    #[test]
939    fn range_flags_test() {
940        let source = r"
9410..1
942";
943        let result = parse(source.as_ref());
944
945        let node = result.node();
946        let range = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
947        let range = range.as_range_node().unwrap();
948        assert!(!range.is_exclude_end());
949
950        let source = r"
9510...1
952";
953        let result = parse(source.as_ref());
954
955        let node = result.node();
956        let range = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
957        let range = range.as_range_node().unwrap();
958        assert!(range.is_exclude_end());
959    }
960
961    #[allow(clippy::too_many_lines, clippy::cognitive_complexity)]
962    #[test]
963    fn regex_flags_test() {
964        let source = r"
965/a/i
966";
967        let result = parse(source.as_ref());
968
969        let node = result.node();
970        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
971        let regex = regex.as_regular_expression_node().unwrap();
972        assert!(regex.is_ignore_case());
973        assert!(!regex.is_extended());
974        assert!(!regex.is_multi_line());
975        assert!(!regex.is_euc_jp());
976        assert!(!regex.is_ascii_8bit());
977        assert!(!regex.is_windows_31j());
978        assert!(!regex.is_utf_8());
979        assert!(!regex.is_once());
980
981        let source = r"
982/a/x
983";
984        let result = parse(source.as_ref());
985
986        let node = result.node();
987        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
988        let regex = regex.as_regular_expression_node().unwrap();
989        assert!(!regex.is_ignore_case());
990        assert!(regex.is_extended());
991        assert!(!regex.is_multi_line());
992        assert!(!regex.is_euc_jp());
993        assert!(!regex.is_ascii_8bit());
994        assert!(!regex.is_windows_31j());
995        assert!(!regex.is_utf_8());
996        assert!(!regex.is_once());
997
998        let source = r"
999/a/m
1000";
1001        let result = parse(source.as_ref());
1002
1003        let node = result.node();
1004        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1005        let regex = regex.as_regular_expression_node().unwrap();
1006        assert!(!regex.is_ignore_case());
1007        assert!(!regex.is_extended());
1008        assert!(regex.is_multi_line());
1009        assert!(!regex.is_euc_jp());
1010        assert!(!regex.is_ascii_8bit());
1011        assert!(!regex.is_windows_31j());
1012        assert!(!regex.is_utf_8());
1013        assert!(!regex.is_once());
1014
1015        let source = r"
1016/a/e
1017";
1018        let result = parse(source.as_ref());
1019
1020        let node = result.node();
1021        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1022        let regex = regex.as_regular_expression_node().unwrap();
1023        assert!(!regex.is_ignore_case());
1024        assert!(!regex.is_extended());
1025        assert!(!regex.is_multi_line());
1026        assert!(regex.is_euc_jp());
1027        assert!(!regex.is_ascii_8bit());
1028        assert!(!regex.is_windows_31j());
1029        assert!(!regex.is_utf_8());
1030        assert!(!regex.is_once());
1031
1032        let source = r"
1033/a/n
1034";
1035        let result = parse(source.as_ref());
1036
1037        let node = result.node();
1038        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1039        let regex = regex.as_regular_expression_node().unwrap();
1040        assert!(!regex.is_ignore_case());
1041        assert!(!regex.is_extended());
1042        assert!(!regex.is_multi_line());
1043        assert!(!regex.is_euc_jp());
1044        assert!(regex.is_ascii_8bit());
1045        assert!(!regex.is_windows_31j());
1046        assert!(!regex.is_utf_8());
1047        assert!(!regex.is_once());
1048
1049        let source = r"
1050/a/s
1051";
1052        let result = parse(source.as_ref());
1053
1054        let node = result.node();
1055        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1056        let regex = regex.as_regular_expression_node().unwrap();
1057        assert!(!regex.is_ignore_case());
1058        assert!(!regex.is_extended());
1059        assert!(!regex.is_multi_line());
1060        assert!(!regex.is_euc_jp());
1061        assert!(!regex.is_ascii_8bit());
1062        assert!(regex.is_windows_31j());
1063        assert!(!regex.is_utf_8());
1064        assert!(!regex.is_once());
1065
1066        let source = r"
1067/a/u
1068";
1069        let result = parse(source.as_ref());
1070
1071        let node = result.node();
1072        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1073        let regex = regex.as_regular_expression_node().unwrap();
1074        assert!(!regex.is_ignore_case());
1075        assert!(!regex.is_extended());
1076        assert!(!regex.is_multi_line());
1077        assert!(!regex.is_euc_jp());
1078        assert!(!regex.is_ascii_8bit());
1079        assert!(!regex.is_windows_31j());
1080        assert!(regex.is_utf_8());
1081        assert!(!regex.is_once());
1082
1083        let source = r"
1084/a/o
1085";
1086        let result = parse(source.as_ref());
1087
1088        let node = result.node();
1089        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1090        let regex = regex.as_regular_expression_node().unwrap();
1091        assert!(!regex.is_ignore_case());
1092        assert!(!regex.is_extended());
1093        assert!(!regex.is_multi_line());
1094        assert!(!regex.is_euc_jp());
1095        assert!(!regex.is_ascii_8bit());
1096        assert!(!regex.is_windows_31j());
1097        assert!(!regex.is_utf_8());
1098        assert!(regex.is_once());
1099    }
1100
1101    #[test]
1102    fn visitor_traversal_test() {
1103        use crate::{Node, Visit};
1104
1105        #[derive(Default)]
1106        struct NodeCounts {
1107            pre_parent: usize,
1108            post_parent: usize,
1109            pre_leaf: usize,
1110            post_leaf: usize,
1111        }
1112
1113        #[derive(Default)]
1114        struct CountingVisitor {
1115            counts: NodeCounts,
1116        }
1117
1118        impl Visit<'_> for CountingVisitor {
1119            fn visit_branch_node_enter(&mut self, _node: Node<'_>) {
1120                self.counts.pre_parent += 1;
1121            }
1122
1123            fn visit_branch_node_leave(&mut self) {
1124                self.counts.post_parent += 1;
1125            }
1126
1127            fn visit_leaf_node_enter(&mut self, _node: Node<'_>) {
1128                self.counts.pre_leaf += 1;
1129            }
1130
1131            fn visit_leaf_node_leave(&mut self) {
1132                self.counts.post_leaf += 1;
1133            }
1134        }
1135
1136        let source = r"
1137module Example
1138  x = call_func(3, 4)
1139  y = x.call_func 5, 6
1140end
1141";
1142        let result = parse(source.as_ref());
1143        let node = result.node();
1144        let mut visitor = CountingVisitor::default();
1145        visitor.visit(&node);
1146
1147        assert_eq!(7, visitor.counts.pre_parent);
1148        assert_eq!(7, visitor.counts.post_parent);
1149        assert_eq!(6, visitor.counts.pre_leaf);
1150        assert_eq!(6, visitor.counts.post_leaf);
1151    }
1152
1153    #[test]
1154    fn visitor_lifetime_test() {
1155        use crate::{Node, Visit};
1156
1157        #[derive(Default)]
1158        struct StackingNodeVisitor<'a> {
1159            stack: Vec<Node<'a>>,
1160            max_depth: usize,
1161        }
1162
1163        impl<'pr> Visit<'pr> for StackingNodeVisitor<'pr> {
1164            fn visit_branch_node_enter(&mut self, node: Node<'pr>) {
1165                self.stack.push(node);
1166            }
1167
1168            fn visit_branch_node_leave(&mut self) {
1169                self.stack.pop();
1170            }
1171
1172            fn visit_leaf_node_leave(&mut self) {
1173                self.max_depth = self.max_depth.max(self.stack.len());
1174            }
1175        }
1176
1177        let source = r"
1178module Example
1179  x = call_func(3, 4)
1180  y = x.call_func 5, 6
1181end
1182";
1183        let result = parse(source.as_ref());
1184        let node = result.node();
1185        let mut visitor = StackingNodeVisitor::default();
1186        visitor.visit(&node);
1187
1188        assert_eq!(0, visitor.stack.len());
1189        assert_eq!(5, visitor.max_depth);
1190    }
1191
1192    #[test]
1193    fn integer_value_test() {
1194        let result = parse("0xA".as_ref());
1195        let integer = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_integer_node().unwrap().value();
1196        let value: i32 = integer.try_into().unwrap();
1197
1198        assert_eq!(value, 10);
1199    }
1200
1201    #[test]
1202    fn integer_small_value_to_u32_digits_test() {
1203        let result = parse("0xA".as_ref());
1204        let integer = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_integer_node().unwrap().value();
1205        let (negative, digits) = integer.to_u32_digits();
1206
1207        assert!(!negative);
1208        assert_eq!(digits, &[10]);
1209    }
1210
1211    #[test]
1212    fn integer_large_value_to_u32_digits_test() {
1213        let result = parse("0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".as_ref());
1214        let integer = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_integer_node().unwrap().value();
1215        let (negative, digits) = integer.to_u32_digits();
1216
1217        assert!(!negative);
1218        assert_eq!(digits, &[4_294_967_295, 4_294_967_295, 4_294_967_295, 2_147_483_647]);
1219    }
1220
1221    #[test]
1222    fn float_value_test() {
1223        let result = parse("1.0".as_ref());
1224        let value: f64 = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_float_node().unwrap().value();
1225
1226        assert!((value - 1.0).abs() < f64::EPSILON);
1227    }
1228
1229    #[test]
1230    fn regex_value_test() {
1231        let result = parse(b"//");
1232        let node = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_regular_expression_node().unwrap();
1233        assert_eq!(node.unescaped(), b"");
1234    }
1235
1236    #[test]
1237    fn node_field_lifetime_test() {
1238        // The code below wouldn't typecheck prior to https://github.com/ruby/prism/pull/2519,
1239        // but we need to stop clippy from complaining about it.
1240        #![allow(clippy::needless_pass_by_value)]
1241
1242        use crate::Node;
1243
1244        #[derive(Default)]
1245        struct Extract<'pr> {
1246            scopes: Vec<crate::ConstantId<'pr>>,
1247        }
1248
1249        impl<'pr> Extract<'pr> {
1250            fn push_scope(&mut self, path: Node<'pr>) {
1251                if let Some(cread) = path.as_constant_read_node() {
1252                    self.scopes.push(cread.name());
1253                } else if let Some(cpath) = path.as_constant_path_node() {
1254                    if let Some(parent) = cpath.parent() {
1255                        self.push_scope(parent);
1256                    }
1257                    self.scopes.push(cpath.name().unwrap());
1258                } else {
1259                    panic!("Wrong node kind!");
1260                }
1261            }
1262        }
1263
1264        let source = "Some::Random::Constant";
1265        let result = parse(source.as_ref());
1266        let node = result.node();
1267        let mut extractor = Extract::default();
1268        extractor.push_scope(node.as_program_node().unwrap().statements().body().iter().next().unwrap());
1269        assert_eq!(3, extractor.scopes.len());
1270    }
1271
1272    #[test]
1273    fn malformed_shebang() {
1274        let source = "#!\x00";
1275        let result = parse(source.as_ref());
1276        assert!(result.errors().next().is_none());
1277        assert!(result.warnings().next().is_none());
1278    }
1279}