ruby_prism/
lib.rs

//! # ruby-prism
//!
//! Rust bindings for Ruby's prism parser.
//!
#![warn(
    clippy::all,
    clippy::nursery,
    clippy::pedantic,
    future_incompatible,
    missing_docs,
    nonstandard_style,
    rust_2018_idioms,
    trivial_casts,
    trivial_numeric_casts,
    unreachable_pub,
    unused_qualifications
)]
6
7// Most of the code in this file is generated, so sometimes it generates code
8// that doesn't follow the clippy rules. We don't want to see those warnings.
9#[allow(clippy::too_many_lines, clippy::use_self)]
10mod bindings {
11    // In `build.rs`, we generate bindings based on the config.yml file. Here is
12    // where we pull in those bindings and make them part of our library.
13    include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
14}
15
16use std::ffi::{c_char, CStr};
17use std::marker::PhantomData;
18use std::mem::MaybeUninit;
19use std::ptr::NonNull;
20
21pub use self::bindings::*;
22use ruby_prism_sys::{pm_comment_t, pm_comment_type_t, pm_constant_id_list_t, pm_constant_id_t, pm_diagnostic_t, pm_integer_t, pm_location_t, pm_magic_comment_t, pm_node_destroy, pm_node_list, pm_node_t, pm_parse, pm_parser_free, pm_parser_init, pm_parser_t};
23
24/// A range in the source file.
25pub struct Location<'pr> {
26    parser: NonNull<pm_parser_t>,
27    pub(crate) start: *const u8,
28    pub(crate) end: *const u8,
29    marker: PhantomData<&'pr [u8]>,
30}
31
32impl<'pr> Location<'pr> {
33    /// Returns a byte slice for the range.
34    /// # Panics
35    /// Panics if the end offset is not greater than the start offset.
36    #[must_use]
37    pub fn as_slice(&self) -> &'pr [u8] {
38        unsafe {
39            let len = usize::try_from(self.end.offset_from(self.start)).expect("end should point to memory after start");
40            std::slice::from_raw_parts(self.start, len)
41        }
42    }
43
44    /// Return a Location from the given `pm_location_t`.
45    #[must_use]
46    pub(crate) const fn new(parser: NonNull<pm_parser_t>, loc: &'pr pm_location_t) -> Self {
47        Location {
48            parser,
49            start: loc.start,
50            end: loc.end,
51            marker: PhantomData,
52        }
53    }
54
55    /// Return a Location starting at self and ending at the end of other.
56    /// Returns None if both locations did not originate from the same parser,
57    /// or if self starts after other.
58    #[must_use]
59    pub fn join(&self, other: &Self) -> Option<Self> {
60        if self.parser != other.parser || self.start > other.start {
61            None
62        } else {
63            Some(Location {
64                parser: self.parser,
65                start: self.start,
66                end: other.end,
67                marker: PhantomData,
68            })
69        }
70    }
71
72    /// Return the start offset from the beginning of the parsed source.
73    /// # Panics
74    /// Panics if the start offset is not greater than the parser's start.
75    #[must_use]
76    pub fn start_offset(&self) -> usize {
77        unsafe {
78            let parser_start = (*self.parser.as_ptr()).start;
79            usize::try_from(self.start.offset_from(parser_start)).expect("start should point to memory after the parser's start")
80        }
81    }
82
83    /// Return the end offset from the beginning of the parsed source.
84    /// # Panics
85    /// Panics if the end offset is not greater than the parser's start.
86    #[must_use]
87    pub fn end_offset(&self) -> usize {
88        unsafe {
89            let parser_start = (*self.parser.as_ptr()).start;
90            usize::try_from(self.end.offset_from(parser_start)).expect("end should point to memory after the parser's start")
91        }
92    }
93}
94
95impl std::fmt::Debug for Location<'_> {
96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        let slice: &[u8] = self.as_slice();
98
99        let mut visible = String::new();
100        visible.push('"');
101
102        for &byte in slice {
103            let part: Vec<u8> = std::ascii::escape_default(byte).collect();
104            visible.push_str(std::str::from_utf8(&part).unwrap());
105        }
106
107        visible.push('"');
108        write!(f, "{visible}")
109    }
110}
111
112/// An iterator over the nodes in a list.
113pub struct NodeListIter<'pr> {
114    parser: NonNull<pm_parser_t>,
115    pointer: NonNull<pm_node_list>,
116    index: usize,
117    marker: PhantomData<&'pr mut pm_node_list>,
118}
119
120impl<'pr> Iterator for NodeListIter<'pr> {
121    type Item = Node<'pr>;
122
123    fn next(&mut self) -> Option<Self::Item> {
124        if self.index >= unsafe { self.pointer.as_ref().size } {
125            None
126        } else {
127            let node: *mut pm_node_t = unsafe { *(self.pointer.as_ref().nodes.add(self.index)) };
128            self.index += 1;
129            Some(Node::new(self.parser, node))
130        }
131    }
132}
133
134/// A list of nodes.
135pub struct NodeList<'pr> {
136    parser: NonNull<pm_parser_t>,
137    pointer: NonNull<pm_node_list>,
138    marker: PhantomData<&'pr mut pm_node_list>,
139}
140
141impl<'pr> NodeList<'pr> {
142    /// Returns an iterator over the nodes.
143    #[must_use]
144    pub const fn iter(&self) -> NodeListIter<'pr> {
145        NodeListIter {
146            parser: self.parser,
147            pointer: self.pointer,
148            index: 0,
149            marker: PhantomData,
150        }
151    }
152
153    /// Returns the length of the list.
154    #[must_use]
155    pub const fn len(&self) -> usize {
156        unsafe { self.pointer.as_ref().size }
157    }
158
159    /// Returns whether the list is empty.
160    #[must_use]
161    pub const fn is_empty(&self) -> bool {
162        self.len() == 0
163    }
164}
165
166impl<'pr> IntoIterator for &NodeList<'pr> {
167    type Item = Node<'pr>;
168    type IntoIter = NodeListIter<'pr>;
169    fn into_iter(self) -> Self::IntoIter {
170        self.iter()
171    }
172}
173
174impl std::fmt::Debug for NodeList<'_> {
175    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
176        write!(f, "{:?}", self.iter().collect::<Vec<_>>())
177    }
178}
179
180/// A handle for a constant ID.
181pub struct ConstantId<'pr> {
182    parser: NonNull<pm_parser_t>,
183    id: pm_constant_id_t,
184    marker: PhantomData<&'pr mut pm_constant_id_t>,
185}
186
187impl<'pr> ConstantId<'pr> {
188    const fn new(parser: NonNull<pm_parser_t>, id: pm_constant_id_t) -> Self {
189        ConstantId { parser, id, marker: PhantomData }
190    }
191
192    /// Returns a byte slice for the constant ID.
193    ///
194    /// # Panics
195    ///
196    /// Panics if the constant ID is not found in the constant pool.
197    #[must_use]
198    pub fn as_slice(&self) -> &'pr [u8] {
199        unsafe {
200            let pool = &(*self.parser.as_ptr()).constant_pool;
201            let constant = &(*pool.constants.add((self.id - 1).try_into().unwrap()));
202            std::slice::from_raw_parts(constant.start, constant.length)
203        }
204    }
205}
206
207impl std::fmt::Debug for ConstantId<'_> {
208    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209        write!(f, "{:?}", self.id)
210    }
211}
212
213/// An iterator over the constants in a list.
214pub struct ConstantListIter<'pr> {
215    parser: NonNull<pm_parser_t>,
216    pointer: NonNull<pm_constant_id_list_t>,
217    index: usize,
218    marker: PhantomData<&'pr mut pm_constant_id_list_t>,
219}
220
221impl<'pr> Iterator for ConstantListIter<'pr> {
222    type Item = ConstantId<'pr>;
223
224    fn next(&mut self) -> Option<Self::Item> {
225        if self.index >= unsafe { self.pointer.as_ref().size } {
226            None
227        } else {
228            let constant_id: pm_constant_id_t = unsafe { *(self.pointer.as_ref().ids.add(self.index)) };
229            self.index += 1;
230            Some(ConstantId::new(self.parser, constant_id))
231        }
232    }
233}
234
235/// A list of constants.
236pub struct ConstantList<'pr> {
237    /// The raw pointer to the parser where this list came from.
238    parser: NonNull<pm_parser_t>,
239
240    /// The raw pointer to the list allocated by prism.
241    pointer: NonNull<pm_constant_id_list_t>,
242
243    /// The marker to indicate the lifetime of the pointer.
244    marker: PhantomData<&'pr mut pm_constant_id_list_t>,
245}
246
247impl<'pr> ConstantList<'pr> {
248    /// Returns an iterator over the constants in the list.
249    #[must_use]
250    pub const fn iter(&self) -> ConstantListIter<'pr> {
251        ConstantListIter {
252            parser: self.parser,
253            pointer: self.pointer,
254            index: 0,
255            marker: PhantomData,
256        }
257    }
258
259    /// Returns the length of the list.
260    #[must_use]
261    pub const fn len(&self) -> usize {
262        unsafe { self.pointer.as_ref().size }
263    }
264
265    /// Returns whether the list is empty.
266    #[must_use]
267    pub const fn is_empty(&self) -> bool {
268        self.len() == 0
269    }
270}
271
272impl<'pr> IntoIterator for &ConstantList<'pr> {
273    type Item = ConstantId<'pr>;
274    type IntoIter = ConstantListIter<'pr>;
275    fn into_iter(self) -> Self::IntoIter {
276        self.iter()
277    }
278}
279
280impl std::fmt::Debug for ConstantList<'_> {
281    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282        write!(f, "{:?}", self.iter().collect::<Vec<_>>())
283    }
284}
285
286/// A handle for an arbitarily-sized integer.
287pub struct Integer<'pr> {
288    /// The raw pointer to the integer allocated by prism.
289    pointer: *const pm_integer_t,
290
291    /// The marker to indicate the lifetime of the pointer.
292    marker: PhantomData<&'pr mut pm_constant_id_t>,
293}
294
295impl Integer<'_> {
296    const fn new(pointer: *const pm_integer_t) -> Self {
297        Integer { pointer, marker: PhantomData }
298    }
299
300    /// Returns the sign and the u32 digits representation of the integer,
301    /// ordered least significant digit first.
302    #[must_use]
303    pub const fn to_u32_digits(&self) -> (bool, &[u32]) {
304        let negative = unsafe { (*self.pointer).negative };
305        let length = unsafe { (*self.pointer).length };
306        let values = unsafe { (*self.pointer).values };
307
308        if values.is_null() {
309            let value_ptr = unsafe { std::ptr::addr_of!((*self.pointer).value) };
310            let slice = unsafe { std::slice::from_raw_parts(value_ptr, 1) };
311            (negative, slice)
312        } else {
313            let slice = unsafe { std::slice::from_raw_parts(values, length) };
314            (negative, slice)
315        }
316    }
317}
318
319impl std::fmt::Debug for Integer<'_> {
320    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
321        write!(f, "{:?}", self.pointer)
322    }
323}
324
325impl TryInto<i32> for Integer<'_> {
326    type Error = ();
327
328    fn try_into(self) -> Result<i32, Self::Error> {
329        let negative = unsafe { (*self.pointer).negative };
330        let length = unsafe { (*self.pointer).length };
331
332        if length == 0 {
333            i32::try_from(unsafe { (*self.pointer).value }).map_or(Err(()), |value| if negative { Ok(-value) } else { Ok(value) })
334        } else {
335            Err(())
336        }
337    }
338}
339
340/// A diagnostic message that came back from the parser.
341#[derive(Debug)]
342pub struct Diagnostic<'pr> {
343    diag: NonNull<pm_diagnostic_t>,
344    parser: NonNull<pm_parser_t>,
345    marker: PhantomData<&'pr pm_diagnostic_t>,
346}
347
348impl<'pr> Diagnostic<'pr> {
349    /// Returns the message associated with the diagnostic.
350    ///
351    /// # Panics
352    ///
353    /// Panics if the message is not able to be converted into a `CStr`.
354    ///
355    #[must_use]
356    pub fn message(&self) -> &str {
357        unsafe {
358            let message: *mut c_char = self.diag.as_ref().message.cast_mut();
359            CStr::from_ptr(message).to_str().expect("prism allows only UTF-8 for diagnostics.")
360        }
361    }
362
363    /// The location of the diagnostic in the source.
364    #[must_use]
365    pub const fn location(&self) -> Location<'pr> {
366        Location::new(self.parser, unsafe { &self.diag.as_ref().location })
367    }
368}
369
370/// A comment that was found during parsing.
371#[derive(Debug)]
372pub struct Comment<'pr> {
373    content: NonNull<pm_comment_t>,
374    parser: NonNull<pm_parser_t>,
375    marker: PhantomData<&'pr pm_comment_t>,
376}
377
/// The kind of a comment.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommentType {
    /// A comment that starts with `#`.
    InlineComment,
    /// A comment enclosed between `=begin` and `=end`.
    EmbDocComment,
}
386
387impl<'pr> Comment<'pr> {
388    /// Returns the text of the comment.
389    ///
390    /// # Panics
391    /// Panics if the end offset is not greater than the start offset.
392    #[must_use]
393    pub fn text(&self) -> &[u8] {
394        self.location().as_slice()
395    }
396
397    /// Returns the type of the comment.
398    #[must_use]
399    pub fn type_(&self) -> CommentType {
400        let type_ = unsafe { self.content.as_ref().type_ };
401        if type_ == pm_comment_type_t::PM_COMMENT_EMBDOC {
402            CommentType::EmbDocComment
403        } else {
404            CommentType::InlineComment
405        }
406    }
407
408    /// The location of the comment in the source.
409    #[must_use]
410    pub const fn location(&self) -> Location<'pr> {
411        Location::new(self.parser, unsafe { &self.content.as_ref().location })
412    }
413}
414
415/// A magic comment that was found during parsing.
416#[derive(Debug)]
417pub struct MagicComment<'pr> {
418    comment: NonNull<pm_magic_comment_t>,
419    marker: PhantomData<&'pr pm_magic_comment_t>,
420}
421
422impl MagicComment<'_> {
423    /// Returns the text of the comment's key.
424    #[must_use]
425    pub const fn key(&self) -> &[u8] {
426        unsafe {
427            let start = self.comment.as_ref().key_start;
428            let len = self.comment.as_ref().key_length as usize;
429            std::slice::from_raw_parts(start, len)
430        }
431    }
432
433    /// Returns the text of the comment's value.
434    #[must_use]
435    pub const fn value(&self) -> &[u8] {
436        unsafe {
437            let start = self.comment.as_ref().value_start;
438            let len = self.comment.as_ref().value_length as usize;
439            std::slice::from_raw_parts(start, len)
440        }
441    }
442}
443
444/// A struct created by the `errors` or `warnings` methods on `ParseResult`. It
445/// can be used to iterate over the diagnostics in the parse result.
446pub struct Diagnostics<'pr> {
447    diagnostic: *mut pm_diagnostic_t,
448    parser: NonNull<pm_parser_t>,
449    marker: PhantomData<&'pr pm_diagnostic_t>,
450}
451
452impl<'pr> Iterator for Diagnostics<'pr> {
453    type Item = Diagnostic<'pr>;
454
455    fn next(&mut self) -> Option<Self::Item> {
456        if let Some(diagnostic) = NonNull::new(self.diagnostic) {
457            let current = Diagnostic {
458                diag: diagnostic,
459                parser: self.parser,
460                marker: PhantomData,
461            };
462            self.diagnostic = unsafe { diagnostic.as_ref().node.next.cast::<pm_diagnostic_t>() };
463            Some(current)
464        } else {
465            None
466        }
467    }
468}
469
470/// A struct created by the `comments` method on `ParseResult`. It can be used
471/// to iterate over the comments in the parse result.
472pub struct Comments<'pr> {
473    comment: *mut pm_comment_t,
474    parser: NonNull<pm_parser_t>,
475    marker: PhantomData<&'pr pm_comment_t>,
476}
477
478impl<'pr> Iterator for Comments<'pr> {
479    type Item = Comment<'pr>;
480
481    fn next(&mut self) -> Option<Self::Item> {
482        if let Some(comment) = NonNull::new(self.comment) {
483            let current = Comment {
484                content: comment,
485                parser: self.parser,
486                marker: PhantomData,
487            };
488            self.comment = unsafe { comment.as_ref().node.next.cast::<pm_comment_t>() };
489            Some(current)
490        } else {
491            None
492        }
493    }
494}
495
496/// A struct created by the `magic_comments` method on `ParseResult`. It can be used
497/// to iterate over the magic comments in the parse result.
498pub struct MagicComments<'pr> {
499    comment: *mut pm_magic_comment_t,
500    marker: PhantomData<&'pr pm_magic_comment_t>,
501}
502
503impl<'pr> Iterator for MagicComments<'pr> {
504    type Item = MagicComment<'pr>;
505
506    fn next(&mut self) -> Option<Self::Item> {
507        if let Some(comment) = NonNull::new(self.comment) {
508            let current = MagicComment { comment, marker: PhantomData };
509            self.comment = unsafe { comment.as_ref().node.next.cast::<pm_magic_comment_t>() };
510            Some(current)
511        } else {
512            None
513        }
514    }
515}
516
517/// The result of parsing a source string.
518#[derive(Debug)]
519pub struct ParseResult<'pr> {
520    source: &'pr [u8],
521    parser: NonNull<pm_parser_t>,
522    node: NonNull<pm_node_t>,
523}
524
525impl<'pr> ParseResult<'pr> {
526    /// Returns the source string that was parsed.
527    #[must_use]
528    pub const fn source(&self) -> &'pr [u8] {
529        self.source
530    }
531
532    /// Returns whether we found a `frozen_string_literal` magic comment with a true value.
533    #[must_use]
534    pub fn frozen_string_literals(&self) -> bool {
535        unsafe { (*self.parser.as_ptr()).frozen_string_literal == 1 }
536    }
537
538    /// Returns a slice of the source string that was parsed using the given
539    /// location range.
540    ///
541    /// # Panics
542    /// Panics if start offset or end offset are not valid offsets from the root.
543    #[must_use]
544    pub fn as_slice(&self, location: &Location<'pr>) -> &'pr [u8] {
545        let root = self.source.as_ptr();
546
547        let start = usize::try_from(unsafe { location.start.offset_from(root) }).expect("start should point to memory after root");
548        let end = usize::try_from(unsafe { location.end.offset_from(root) }).expect("end should point to memory after root");
549
550        &self.source[start..end]
551    }
552
553    /// Returns an iterator that can be used to iterate over the errors in the
554    /// parse result.
555    #[must_use]
556    pub fn errors(&self) -> Diagnostics<'_> {
557        unsafe {
558            let list = &mut (*self.parser.as_ptr()).error_list;
559            Diagnostics {
560                diagnostic: list.head.cast::<pm_diagnostic_t>(),
561                parser: self.parser,
562                marker: PhantomData,
563            }
564        }
565    }
566
567    /// Returns an iterator that can be used to iterate over the warnings in the
568    /// parse result.
569    #[must_use]
570    pub fn warnings(&self) -> Diagnostics<'_> {
571        unsafe {
572            let list = &mut (*self.parser.as_ptr()).warning_list;
573            Diagnostics {
574                diagnostic: list.head.cast::<pm_diagnostic_t>(),
575                parser: self.parser,
576                marker: PhantomData,
577            }
578        }
579    }
580
581    /// Returns an iterator that can be used to iterate over the comments in the
582    /// parse result.
583    #[must_use]
584    pub fn comments(&self) -> Comments<'_> {
585        unsafe {
586            let list = &mut (*self.parser.as_ptr()).comment_list;
587            Comments {
588                comment: list.head.cast::<pm_comment_t>(),
589                parser: self.parser,
590                marker: PhantomData,
591            }
592        }
593    }
594
595    /// Returns an iterator that can be used to iterate over the magic comments in the
596    /// parse result.
597    #[must_use]
598    pub fn magic_comments(&self) -> MagicComments<'_> {
599        unsafe {
600            let list = &mut (*self.parser.as_ptr()).magic_comment_list;
601            MagicComments {
602                comment: list.head.cast::<pm_magic_comment_t>(),
603                marker: PhantomData,
604            }
605        }
606    }
607
608    /// Returns an optional location of the __END__ marker and the rest of the content of the file.
609    #[must_use]
610    pub fn data_loc(&self) -> Option<Location<'_>> {
611        let location = unsafe { &(*self.parser.as_ptr()).data_loc };
612        if location.start.is_null() {
613            None
614        } else {
615            Some(Location::new(self.parser, location))
616        }
617    }
618
619    /// Returns the root node of the parse result.
620    #[must_use]
621    pub fn node(&self) -> Node<'_> {
622        Node::new(self.parser, self.node.as_ptr())
623    }
624}
625
626impl Drop for ParseResult<'_> {
627    fn drop(&mut self) {
628        unsafe {
629            pm_node_destroy(self.parser.as_ptr(), self.node.as_ptr());
630            pm_parser_free(self.parser.as_ptr());
631            drop(Box::from_raw(self.parser.as_ptr()));
632        }
633    }
634}
635
636/// Parses the given source string and returns a parse result.
637///
638/// # Panics
639///
640/// Panics if the parser fails to initialize.
641///
642#[must_use]
643pub fn parse(source: &[u8]) -> ParseResult<'_> {
644    unsafe {
645        let uninit = Box::new(MaybeUninit::<pm_parser_t>::uninit());
646        let uninit = Box::into_raw(uninit);
647
648        pm_parser_init((*uninit).as_mut_ptr(), source.as_ptr(), source.len(), std::ptr::null());
649
650        let parser = (*uninit).assume_init_mut();
651        let parser = NonNull::new_unchecked(parser);
652
653        let node = pm_parse(parser.as_ptr());
654        let node = NonNull::new_unchecked(node);
655
656        ParseResult { source, parser, node }
657    }
658}
659
660#[cfg(test)]
661mod tests {
662    use super::parse;
663
664    #[test]
665    fn comments_test() {
666        let source = "# comment 1\n# comment 2\n# comment 3\n";
667        let result = parse(source.as_ref());
668
669        for comment in result.comments() {
670            assert_eq!(super::CommentType::InlineComment, comment.type_());
671            let text = std::str::from_utf8(comment.text()).unwrap();
672            assert!(text.starts_with("# comment"));
673        }
674    }
675
676    #[test]
677    fn magic_comments_test() {
678        use crate::MagicComment;
679
680        let source = "# typed: ignore\n# typed:true\n#typed: strict\n";
681        let result = parse(source.as_ref());
682
683        let comments: Vec<MagicComment<'_>> = result.magic_comments().collect();
684        assert_eq!(3, comments.len());
685
686        assert_eq!(b"typed", comments[0].key());
687        assert_eq!(b"ignore", comments[0].value());
688
689        assert_eq!(b"typed", comments[1].key());
690        assert_eq!(b"true", comments[1].value());
691
692        assert_eq!(b"typed", comments[2].key());
693        assert_eq!(b"strict", comments[2].value());
694    }
695
696    #[test]
697    fn data_loc_test() {
698        let source = "1";
699        let result = parse(source.as_ref());
700        let data_loc = result.data_loc();
701        assert!(data_loc.is_none());
702
703        let source = "__END__\nabc\n";
704        let result = parse(source.as_ref());
705        let data_loc = result.data_loc().unwrap();
706        let slice = std::str::from_utf8(result.as_slice(&data_loc)).unwrap();
707        assert_eq!(slice, "__END__\nabc\n");
708
709        let source = "1\n2\n3\n__END__\nabc\ndef\n";
710        let result = parse(source.as_ref());
711        let data_loc = result.data_loc().unwrap();
712        let slice = std::str::from_utf8(result.as_slice(&data_loc)).unwrap();
713        assert_eq!(slice, "__END__\nabc\ndef\n");
714    }
715
716    #[test]
717    fn location_test() {
718        let source = "111 + 222 + 333";
719        let result = parse(source.as_ref());
720
721        let node = result.node();
722        let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
723        let node = node.as_call_node().unwrap().receiver().unwrap();
724        let plus = node.as_call_node().unwrap();
725        let node = plus.arguments().unwrap().arguments().iter().next().unwrap();
726
727        let location = node.as_integer_node().unwrap().location();
728        let slice = std::str::from_utf8(result.as_slice(&location)).unwrap();
729
730        assert_eq!(slice, "222");
731        assert_eq!(6, location.start_offset());
732        assert_eq!(9, location.end_offset());
733
734        let recv_loc = plus.receiver().unwrap().location();
735        assert_eq!(recv_loc.as_slice(), b"111");
736        assert_eq!(0, recv_loc.start_offset());
737        assert_eq!(3, recv_loc.end_offset());
738
739        let joined = recv_loc.join(&location).unwrap();
740        assert_eq!(joined.as_slice(), b"111 + 222");
741
742        let not_joined = location.join(&recv_loc);
743        assert!(not_joined.is_none());
744
745        {
746            let result = parse(source.as_ref());
747            let node = result.node();
748            let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
749            let node = node.as_call_node().unwrap().receiver().unwrap();
750            let plus = node.as_call_node().unwrap();
751            let node = plus.arguments().unwrap().arguments().iter().next().unwrap();
752
753            let location = node.as_integer_node().unwrap().location();
754            let not_joined = recv_loc.join(&location);
755            assert!(not_joined.is_none());
756
757            let not_joined = location.join(&recv_loc);
758            assert!(not_joined.is_none());
759        }
760
761        let location = node.location();
762        let slice = std::str::from_utf8(result.as_slice(&location)).unwrap();
763
764        assert_eq!(slice, "222");
765
766        let slice = std::str::from_utf8(location.as_slice()).unwrap();
767
768        assert_eq!(slice, "222");
769    }
770
771    #[test]
772    fn visitor_test() {
773        use super::{visit_interpolated_regular_expression_node, visit_regular_expression_node, InterpolatedRegularExpressionNode, RegularExpressionNode, Visit};
774
775        struct RegularExpressionVisitor {
776            count: usize,
777        }
778
779        impl Visit<'_> for RegularExpressionVisitor {
780            fn visit_interpolated_regular_expression_node(&mut self, node: &InterpolatedRegularExpressionNode<'_>) {
781                self.count += 1;
782                visit_interpolated_regular_expression_node(self, node);
783            }
784
785            fn visit_regular_expression_node(&mut self, node: &RegularExpressionNode<'_>) {
786                self.count += 1;
787                visit_regular_expression_node(self, node);
788            }
789        }
790
791        let source = "# comment 1\n# comment 2\nmodule Foo; class Bar; /abc #{/def/}/; end; end";
792        let result = parse(source.as_ref());
793
794        let mut visitor = RegularExpressionVisitor { count: 0 };
795        visitor.visit(&result.node());
796
797        assert_eq!(visitor.count, 2);
798    }
799
800    #[test]
801    fn node_upcast_test() {
802        use super::Node;
803
804        let source = "module Foo; end";
805        let result = parse(source.as_ref());
806
807        let node = result.node();
808        let upcast_node = node.as_program_node().unwrap().as_node();
809        assert!(matches!(upcast_node, Node::ProgramNode { .. }));
810
811        let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
812        let upcast_node = node.as_module_node().unwrap().as_node();
813        assert!(matches!(upcast_node, Node::ModuleNode { .. }));
814    }
815
816    #[test]
817    fn constant_id_test() {
818        let source = "module Foo; x = 1; end";
819        let result = parse(source.as_ref());
820
821        let node = result.node();
822        assert_eq!(node.as_program_node().unwrap().statements().body().len(), 1);
823        assert!(!node.as_program_node().unwrap().statements().body().is_empty());
824        let module = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
825        let module = module.as_module_node().unwrap();
826
827        assert_eq!(module.locals().len(), 1);
828        assert!(!module.locals().is_empty());
829
830        let locals = module.locals().iter().collect::<Vec<_>>();
831
832        assert_eq!(locals.len(), 1);
833
834        assert_eq!(locals[0].as_slice(), b"x");
835
836        let source = "module Foo; end";
837        let result = parse(source.as_ref());
838
839        let node = result.node();
840        assert_eq!(node.as_program_node().unwrap().statements().body().len(), 1);
841        assert!(!node.as_program_node().unwrap().statements().body().is_empty());
842        let module = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
843        let module = module.as_module_node().unwrap();
844
845        assert_eq!(module.locals().len(), 0);
846        assert!(module.locals().is_empty());
847    }
848
849    #[test]
850    fn optional_loc_test() {
851        let source = r"
852module Example
853  x = call_func(3, 4)
854  y = x.call_func 5, 6
855end
856";
857        let result = parse(source.as_ref());
858
859        let node = result.node();
860        let module = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
861        let module = module.as_module_node().unwrap();
862        let body = module.body();
863        let writes = body.iter().next().unwrap().as_statements_node().unwrap().body().iter().collect::<Vec<_>>();
864        assert_eq!(writes.len(), 2);
865
866        let asgn = &writes[0];
867        let call = asgn.as_local_variable_write_node().unwrap().value();
868        let call = call.as_call_node().unwrap();
869
870        let call_operator_loc = call.call_operator_loc();
871        assert!(call_operator_loc.is_none());
872        let closing_loc = call.closing_loc();
873        assert!(closing_loc.is_some());
874
875        let asgn = &writes[1];
876        let call = asgn.as_local_variable_write_node().unwrap().value();
877        let call = call.as_call_node().unwrap();
878
879        let call_operator_loc = call.call_operator_loc();
880        assert!(call_operator_loc.is_some());
881        let closing_loc = call.closing_loc();
882        assert!(closing_loc.is_none());
883    }
884
885    #[test]
886    fn frozen_strings_test() {
887        let source = r#"
888# frozen_string_literal: true
889"foo"
890"#;
891        let result = parse(source.as_ref());
892        assert!(result.frozen_string_literals());
893
894        let source = "3";
895        let result = parse(source.as_ref());
896        assert!(!result.frozen_string_literals());
897    }
898
899    #[test]
900    fn string_flags_test() {
901        let source = r#"
902# frozen_string_literal: true
903"foo"
904"#;
905        let result = parse(source.as_ref());
906
907        let node = result.node();
908        let string = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
909        let string = string.as_string_node().unwrap();
910        assert!(string.is_frozen());
911
912        let source = r#"
913"foo"
914"#;
915        let result = parse(source.as_ref());
916
917        let node = result.node();
918        let string = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
919        let string = string.as_string_node().unwrap();
920        assert!(!string.is_frozen());
921    }
922
923    #[test]
924    fn call_flags_test() {
925        let source = r"
926x
927";
928        let result = parse(source.as_ref());
929
930        let node = result.node();
931        let call = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
932        let call = call.as_call_node().unwrap();
933        assert!(call.is_variable_call());
934
935        let source = r"
936x&.foo
937";
938        let result = parse(source.as_ref());
939
940        let node = result.node();
941        let call = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
942        let call = call.as_call_node().unwrap();
943        assert!(call.is_safe_navigation());
944    }
945
946    #[test]
947    fn integer_flags_test() {
948        let source = r"
9490b1
950";
951        let result = parse(source.as_ref());
952
953        let node = result.node();
954        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
955        let i = i.as_integer_node().unwrap();
956        assert!(i.is_binary());
957        assert!(!i.is_decimal());
958        assert!(!i.is_octal());
959        assert!(!i.is_hexadecimal());
960
961        let source = r"
9621
963";
964        let result = parse(source.as_ref());
965
966        let node = result.node();
967        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
968        let i = i.as_integer_node().unwrap();
969        assert!(!i.is_binary());
970        assert!(i.is_decimal());
971        assert!(!i.is_octal());
972        assert!(!i.is_hexadecimal());
973
974        let source = r"
9750o1
976";
977        let result = parse(source.as_ref());
978
979        let node = result.node();
980        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
981        let i = i.as_integer_node().unwrap();
982        assert!(!i.is_binary());
983        assert!(!i.is_decimal());
984        assert!(i.is_octal());
985        assert!(!i.is_hexadecimal());
986
987        let source = r"
9880x1
989";
990        let result = parse(source.as_ref());
991
992        let node = result.node();
993        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
994        let i = i.as_integer_node().unwrap();
995        assert!(!i.is_binary());
996        assert!(!i.is_decimal());
997        assert!(!i.is_octal());
998        assert!(i.is_hexadecimal());
999    }
1000
1001    #[test]
1002    fn range_flags_test() {
1003        let source = r"
10040..1
1005";
1006        let result = parse(source.as_ref());
1007
1008        let node = result.node();
1009        let range = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1010        let range = range.as_range_node().unwrap();
1011        assert!(!range.is_exclude_end());
1012
1013        let source = r"
10140...1
1015";
1016        let result = parse(source.as_ref());
1017
1018        let node = result.node();
1019        let range = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1020        let range = range.as_range_node().unwrap();
1021        assert!(range.is_exclude_end());
1022    }
1023
1024    #[allow(clippy::too_many_lines, clippy::cognitive_complexity)]
1025    #[test]
1026    fn regex_flags_test() {
1027        let source = r"
1028/a/i
1029";
1030        let result = parse(source.as_ref());
1031
1032        let node = result.node();
1033        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1034        let regex = regex.as_regular_expression_node().unwrap();
1035        assert!(regex.is_ignore_case());
1036        assert!(!regex.is_extended());
1037        assert!(!regex.is_multi_line());
1038        assert!(!regex.is_euc_jp());
1039        assert!(!regex.is_ascii_8bit());
1040        assert!(!regex.is_windows_31j());
1041        assert!(!regex.is_utf_8());
1042        assert!(!regex.is_once());
1043
1044        let source = r"
1045/a/x
1046";
1047        let result = parse(source.as_ref());
1048
1049        let node = result.node();
1050        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1051        let regex = regex.as_regular_expression_node().unwrap();
1052        assert!(!regex.is_ignore_case());
1053        assert!(regex.is_extended());
1054        assert!(!regex.is_multi_line());
1055        assert!(!regex.is_euc_jp());
1056        assert!(!regex.is_ascii_8bit());
1057        assert!(!regex.is_windows_31j());
1058        assert!(!regex.is_utf_8());
1059        assert!(!regex.is_once());
1060
1061        let source = r"
1062/a/m
1063";
1064        let result = parse(source.as_ref());
1065
1066        let node = result.node();
1067        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1068        let regex = regex.as_regular_expression_node().unwrap();
1069        assert!(!regex.is_ignore_case());
1070        assert!(!regex.is_extended());
1071        assert!(regex.is_multi_line());
1072        assert!(!regex.is_euc_jp());
1073        assert!(!regex.is_ascii_8bit());
1074        assert!(!regex.is_windows_31j());
1075        assert!(!regex.is_utf_8());
1076        assert!(!regex.is_once());
1077
1078        let source = r"
1079/a/e
1080";
1081        let result = parse(source.as_ref());
1082
1083        let node = result.node();
1084        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1085        let regex = regex.as_regular_expression_node().unwrap();
1086        assert!(!regex.is_ignore_case());
1087        assert!(!regex.is_extended());
1088        assert!(!regex.is_multi_line());
1089        assert!(regex.is_euc_jp());
1090        assert!(!regex.is_ascii_8bit());
1091        assert!(!regex.is_windows_31j());
1092        assert!(!regex.is_utf_8());
1093        assert!(!regex.is_once());
1094
1095        let source = r"
1096/a/n
1097";
1098        let result = parse(source.as_ref());
1099
1100        let node = result.node();
1101        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1102        let regex = regex.as_regular_expression_node().unwrap();
1103        assert!(!regex.is_ignore_case());
1104        assert!(!regex.is_extended());
1105        assert!(!regex.is_multi_line());
1106        assert!(!regex.is_euc_jp());
1107        assert!(regex.is_ascii_8bit());
1108        assert!(!regex.is_windows_31j());
1109        assert!(!regex.is_utf_8());
1110        assert!(!regex.is_once());
1111
1112        let source = r"
1113/a/s
1114";
1115        let result = parse(source.as_ref());
1116
1117        let node = result.node();
1118        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1119        let regex = regex.as_regular_expression_node().unwrap();
1120        assert!(!regex.is_ignore_case());
1121        assert!(!regex.is_extended());
1122        assert!(!regex.is_multi_line());
1123        assert!(!regex.is_euc_jp());
1124        assert!(!regex.is_ascii_8bit());
1125        assert!(regex.is_windows_31j());
1126        assert!(!regex.is_utf_8());
1127        assert!(!regex.is_once());
1128
1129        let source = r"
1130/a/u
1131";
1132        let result = parse(source.as_ref());
1133
1134        let node = result.node();
1135        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1136        let regex = regex.as_regular_expression_node().unwrap();
1137        assert!(!regex.is_ignore_case());
1138        assert!(!regex.is_extended());
1139        assert!(!regex.is_multi_line());
1140        assert!(!regex.is_euc_jp());
1141        assert!(!regex.is_ascii_8bit());
1142        assert!(!regex.is_windows_31j());
1143        assert!(regex.is_utf_8());
1144        assert!(!regex.is_once());
1145
1146        let source = r"
1147/a/o
1148";
1149        let result = parse(source.as_ref());
1150
1151        let node = result.node();
1152        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1153        let regex = regex.as_regular_expression_node().unwrap();
1154        assert!(!regex.is_ignore_case());
1155        assert!(!regex.is_extended());
1156        assert!(!regex.is_multi_line());
1157        assert!(!regex.is_euc_jp());
1158        assert!(!regex.is_ascii_8bit());
1159        assert!(!regex.is_windows_31j());
1160        assert!(!regex.is_utf_8());
1161        assert!(regex.is_once());
1162    }
1163
1164    #[test]
1165    fn visitor_traversal_test() {
1166        use crate::{Node, Visit};
1167
1168        #[derive(Default)]
1169        struct NodeCounts {
1170            pre_parent: usize,
1171            post_parent: usize,
1172            pre_leaf: usize,
1173            post_leaf: usize,
1174        }
1175
1176        #[derive(Default)]
1177        struct CountingVisitor {
1178            counts: NodeCounts,
1179        }
1180
1181        impl Visit<'_> for CountingVisitor {
1182            fn visit_branch_node_enter(&mut self, _node: Node<'_>) {
1183                self.counts.pre_parent += 1;
1184            }
1185
1186            fn visit_branch_node_leave(&mut self) {
1187                self.counts.post_parent += 1;
1188            }
1189
1190            fn visit_leaf_node_enter(&mut self, _node: Node<'_>) {
1191                self.counts.pre_leaf += 1;
1192            }
1193
1194            fn visit_leaf_node_leave(&mut self) {
1195                self.counts.post_leaf += 1;
1196            }
1197        }
1198
1199        let source = r"
1200module Example
1201  x = call_func(3, 4)
1202  y = x.call_func 5, 6
1203end
1204";
1205        let result = parse(source.as_ref());
1206        let node = result.node();
1207        let mut visitor = CountingVisitor::default();
1208        visitor.visit(&node);
1209
1210        assert_eq!(7, visitor.counts.pre_parent);
1211        assert_eq!(7, visitor.counts.post_parent);
1212        assert_eq!(6, visitor.counts.pre_leaf);
1213        assert_eq!(6, visitor.counts.post_leaf);
1214    }
1215
1216    #[test]
1217    fn visitor_lifetime_test() {
1218        use crate::{Node, Visit};
1219
1220        #[derive(Default)]
1221        struct StackingNodeVisitor<'a> {
1222            stack: Vec<Node<'a>>,
1223            max_depth: usize,
1224        }
1225
1226        impl<'pr> Visit<'pr> for StackingNodeVisitor<'pr> {
1227            fn visit_branch_node_enter(&mut self, node: Node<'pr>) {
1228                self.stack.push(node);
1229            }
1230
1231            fn visit_branch_node_leave(&mut self) {
1232                self.stack.pop();
1233            }
1234
1235            fn visit_leaf_node_leave(&mut self) {
1236                self.max_depth = self.max_depth.max(self.stack.len());
1237            }
1238        }
1239
1240        let source = r"
1241module Example
1242  x = call_func(3, 4)
1243  y = x.call_func 5, 6
1244end
1245";
1246        let result = parse(source.as_ref());
1247        let node = result.node();
1248        let mut visitor = StackingNodeVisitor::default();
1249        visitor.visit(&node);
1250
1251        assert_eq!(0, visitor.stack.len());
1252        assert_eq!(5, visitor.max_depth);
1253    }
1254
1255    #[test]
1256    fn integer_value_test() {
1257        let result = parse("0xA".as_ref());
1258        let integer = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_integer_node().unwrap().value();
1259        let value: i32 = integer.try_into().unwrap();
1260
1261        assert_eq!(value, 10);
1262    }
1263
1264    #[test]
1265    fn integer_small_value_to_u32_digits_test() {
1266        let result = parse("0xA".as_ref());
1267        let integer = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_integer_node().unwrap().value();
1268        let (negative, digits) = integer.to_u32_digits();
1269
1270        assert!(!negative);
1271        assert_eq!(digits, &[10]);
1272    }
1273
1274    #[test]
1275    fn integer_large_value_to_u32_digits_test() {
1276        let result = parse("0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".as_ref());
1277        let integer = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_integer_node().unwrap().value();
1278        let (negative, digits) = integer.to_u32_digits();
1279
1280        assert!(!negative);
1281        assert_eq!(digits, &[4_294_967_295, 4_294_967_295, 4_294_967_295, 2_147_483_647]);
1282    }
1283
1284    #[test]
1285    fn float_value_test() {
1286        let result = parse("1.0".as_ref());
1287        let value: f64 = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_float_node().unwrap().value();
1288
1289        assert!((value - 1.0).abs() < f64::EPSILON);
1290    }
1291
1292    #[test]
1293    fn regex_value_test() {
1294        let result = parse(b"//");
1295        let node = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_regular_expression_node().unwrap();
1296        assert_eq!(node.unescaped(), b"");
1297    }
1298
1299    #[test]
1300    fn node_field_lifetime_test() {
1301        // The code below wouldn't typecheck prior to https://github.com/ruby/prism/pull/2519,
1302        // but we need to stop clippy from complaining about it.
1303        #![allow(clippy::needless_pass_by_value)]
1304
1305        use crate::Node;
1306
1307        #[derive(Default)]
1308        struct Extract<'pr> {
1309            scopes: Vec<crate::ConstantId<'pr>>,
1310        }
1311
1312        impl<'pr> Extract<'pr> {
1313            fn push_scope(&mut self, path: Node<'pr>) {
1314                if let Some(cread) = path.as_constant_read_node() {
1315                    self.scopes.push(cread.name());
1316                } else if let Some(cpath) = path.as_constant_path_node() {
1317                    if let Some(parent) = cpath.parent() {
1318                        self.push_scope(parent);
1319                    }
1320                    self.scopes.push(cpath.name().unwrap());
1321                } else {
1322                    panic!("Wrong node kind!");
1323                }
1324            }
1325        }
1326
1327        let source = "Some::Random::Constant";
1328        let result = parse(source.as_ref());
1329        let node = result.node();
1330        let mut extractor = Extract::default();
1331        extractor.push_scope(node.as_program_node().unwrap().statements().body().iter().next().unwrap());
1332        assert_eq!(3, extractor.scopes.len());
1333    }
1334
1335    #[test]
1336    fn malformed_shebang() {
1337        let source = "#!\x00";
1338        let result = parse(source.as_ref());
1339        assert!(result.errors().next().is_none());
1340        assert!(result.warnings().next().is_none());
1341    }
1342}