ruby_prism/
lib.rs

1//! # ruby-prism
2//!
3//! Rustified version of Ruby's prism parser.
4//!
5#![warn(clippy::all, clippy::nursery, clippy::pedantic, future_incompatible, missing_docs, nonstandard_style, rust_2018_idioms, trivial_casts, trivial_numeric_casts, unreachable_pub, unused_qualifications)]
6
7// Most of the code in this file is generated, so sometimes it generates code
8// that doesn't follow the clippy rules. We don't want to see those warnings.
9#[allow(clippy::too_many_lines, clippy::use_self)]
10mod bindings {
11    // In `build.rs`, we generate bindings based on the config.yml file. Here is
12    // where we pull in those bindings and make them part of our library.
13    include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
14}
15
16use std::ffi::{c_char, CStr};
17use std::marker::PhantomData;
18use std::mem::MaybeUninit;
19use std::ptr::NonNull;
20
21pub use self::bindings::*;
22use ruby_prism_sys::{pm_comment_t, pm_comment_type_t, pm_constant_id_list_t, pm_constant_id_t, pm_diagnostic_t, pm_integer_t, pm_location_t, pm_magic_comment_t, pm_node_destroy, pm_node_list, pm_node_t, pm_parse, pm_parser_free, pm_parser_init, pm_parser_t};
23
24/// A range in the source file.
25pub struct Location<'pr> {
26    parser: NonNull<pm_parser_t>,
27    pub(crate) start: *const u8,
28    pub(crate) end: *const u8,
29    marker: PhantomData<&'pr [u8]>,
30}
31
32impl<'pr> Location<'pr> {
33    /// Returns a byte slice for the range.
34    /// # Panics
35    /// Panics if the end offset is not greater than the start offset.
36    #[must_use]
37    pub fn as_slice(&self) -> &'pr [u8] {
38        unsafe {
39            let len = usize::try_from(self.end.offset_from(self.start)).expect("end should point to memory after start");
40            std::slice::from_raw_parts(self.start, len)
41        }
42    }
43
44    /// Return a Location from the given `pm_location_t`.
45    #[must_use]
46    pub(crate) const fn new(parser: NonNull<pm_parser_t>, loc: &'pr pm_location_t) -> Self {
47        Location {
48            parser,
49            start: loc.start,
50            end: loc.end,
51            marker: PhantomData,
52        }
53    }
54
55    /// Return a Location starting at self and ending at the end of other.
56    /// Returns None if both locations did not originate from the same parser,
57    /// or if self starts after other.
58    #[must_use]
59    pub fn join(&self, other: &Self) -> Option<Self> {
60        if self.parser != other.parser || self.start > other.start {
61            None
62        } else {
63            Some(Location {
64                parser: self.parser,
65                start: self.start,
66                end: other.end,
67                marker: PhantomData,
68            })
69        }
70    }
71
72    /// Return the start offset from the beginning of the parsed source.
73    /// # Panics
74    /// Panics if the start offset is not greater than the parser's start.
75    #[must_use]
76    pub fn start_offset(&self) -> usize {
77        unsafe {
78            let parser_start = (*self.parser.as_ptr()).start;
79            usize::try_from(self.start.offset_from(parser_start)).expect("start should point to memory after the parser's start")
80        }
81    }
82
83    /// Return the end offset from the beginning of the parsed source.
84    /// # Panics
85    /// Panics if the end offset is not greater than the parser's start.
86    #[must_use]
87    pub fn end_offset(&self) -> usize {
88        unsafe {
89            let parser_start = (*self.parser.as_ptr()).start;
90            usize::try_from(self.end.offset_from(parser_start)).expect("end should point to memory after the parser's start")
91        }
92    }
93}
94
95impl std::fmt::Debug for Location<'_> {
96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        let slice: &[u8] = self.as_slice();
98
99        let mut visible = String::new();
100        visible.push('"');
101
102        for &byte in slice {
103            let part: Vec<u8> = std::ascii::escape_default(byte).collect();
104            visible.push_str(std::str::from_utf8(&part).unwrap());
105        }
106
107        visible.push('"');
108        write!(f, "{visible}")
109    }
110}
111
112/// An iterator over the nodes in a list.
113pub struct NodeListIter<'pr> {
114    parser: NonNull<pm_parser_t>,
115    pointer: NonNull<pm_node_list>,
116    index: usize,
117    marker: PhantomData<&'pr mut pm_node_list>,
118}
119
120impl<'pr> Iterator for NodeListIter<'pr> {
121    type Item = Node<'pr>;
122
123    fn next(&mut self) -> Option<Self::Item> {
124        if self.index >= unsafe { self.pointer.as_ref().size } {
125            None
126        } else {
127            let node: *mut pm_node_t = unsafe { *(self.pointer.as_ref().nodes.add(self.index)) };
128            self.index += 1;
129            Some(Node::new(self.parser, node))
130        }
131    }
132}
133
134/// A list of nodes.
135pub struct NodeList<'pr> {
136    parser: NonNull<pm_parser_t>,
137    pointer: NonNull<pm_node_list>,
138    marker: PhantomData<&'pr mut pm_node_list>,
139}
140
141impl<'pr> NodeList<'pr> {
142    unsafe fn at(&self, index: usize) -> Node<'pr> {
143        let node: *mut pm_node_t = *(self.pointer.as_ref().nodes.add(index));
144        Node::new(self.parser, node)
145    }
146
147    /// Returns an iterator over the nodes.
148    #[must_use]
149    pub const fn iter(&self) -> NodeListIter<'pr> {
150        NodeListIter {
151            parser: self.parser,
152            pointer: self.pointer,
153            index: 0,
154            marker: PhantomData,
155        }
156    }
157
158    /// Returns the length of the list.
159    #[must_use]
160    pub const fn len(&self) -> usize {
161        unsafe { self.pointer.as_ref().size }
162    }
163
164    /// Returns whether the list is empty.
165    #[must_use]
166    pub const fn is_empty(&self) -> bool {
167        self.len() == 0
168    }
169
170    /// Returns the first element of the list, or `None` if it is empty.
171    #[must_use]
172    pub fn first(&self) -> Option<Node<'pr>> {
173        if self.is_empty() {
174            None
175        } else {
176            Some(unsafe { self.at(0) })
177        }
178    }
179
180    /// Returns the last element of the list, or `None` if it is empty.
181    #[must_use]
182    pub fn last(&self) -> Option<Node<'pr>> {
183        if self.is_empty() {
184            None
185        } else {
186            Some(unsafe { self.at(self.len() - 1) })
187        }
188    }
189}
190
191impl<'pr> IntoIterator for &NodeList<'pr> {
192    type Item = Node<'pr>;
193    type IntoIter = NodeListIter<'pr>;
194    fn into_iter(self) -> Self::IntoIter {
195        self.iter()
196    }
197}
198
199impl std::fmt::Debug for NodeList<'_> {
200    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        write!(f, "{:?}", self.iter().collect::<Vec<_>>())
202    }
203}
204
205/// A handle for a constant ID.
206pub struct ConstantId<'pr> {
207    parser: NonNull<pm_parser_t>,
208    id: pm_constant_id_t,
209    marker: PhantomData<&'pr mut pm_constant_id_t>,
210}
211
212impl<'pr> ConstantId<'pr> {
213    const fn new(parser: NonNull<pm_parser_t>, id: pm_constant_id_t) -> Self {
214        ConstantId { parser, id, marker: PhantomData }
215    }
216
217    /// Returns a byte slice for the constant ID.
218    ///
219    /// # Panics
220    ///
221    /// Panics if the constant ID is not found in the constant pool.
222    #[must_use]
223    pub fn as_slice(&self) -> &'pr [u8] {
224        unsafe {
225            let pool = &(*self.parser.as_ptr()).constant_pool;
226            let constant = &(*pool.constants.add((self.id - 1).try_into().unwrap()));
227            std::slice::from_raw_parts(constant.start, constant.length)
228        }
229    }
230}
231
232impl std::fmt::Debug for ConstantId<'_> {
233    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234        write!(f, "{:?}", self.id)
235    }
236}
237
238/// An iterator over the constants in a list.
239pub struct ConstantListIter<'pr> {
240    parser: NonNull<pm_parser_t>,
241    pointer: NonNull<pm_constant_id_list_t>,
242    index: usize,
243    marker: PhantomData<&'pr mut pm_constant_id_list_t>,
244}
245
246impl<'pr> Iterator for ConstantListIter<'pr> {
247    type Item = ConstantId<'pr>;
248
249    fn next(&mut self) -> Option<Self::Item> {
250        if self.index >= unsafe { self.pointer.as_ref().size } {
251            None
252        } else {
253            let constant_id: pm_constant_id_t = unsafe { *(self.pointer.as_ref().ids.add(self.index)) };
254            self.index += 1;
255            Some(ConstantId::new(self.parser, constant_id))
256        }
257    }
258}
259
260/// A list of constants.
261pub struct ConstantList<'pr> {
262    /// The raw pointer to the parser where this list came from.
263    parser: NonNull<pm_parser_t>,
264
265    /// The raw pointer to the list allocated by prism.
266    pointer: NonNull<pm_constant_id_list_t>,
267
268    /// The marker to indicate the lifetime of the pointer.
269    marker: PhantomData<&'pr mut pm_constant_id_list_t>,
270}
271
272impl<'pr> ConstantList<'pr> {
273    const unsafe fn at(&self, index: usize) -> ConstantId<'pr> {
274        let constant_id: pm_constant_id_t = *(self.pointer.as_ref().ids.add(index));
275        ConstantId::new(self.parser, constant_id)
276    }
277
278    /// Returns an iterator over the constants in the list.
279    #[must_use]
280    pub const fn iter(&self) -> ConstantListIter<'pr> {
281        ConstantListIter {
282            parser: self.parser,
283            pointer: self.pointer,
284            index: 0,
285            marker: PhantomData,
286        }
287    }
288
289    /// Returns the length of the list.
290    #[must_use]
291    pub const fn len(&self) -> usize {
292        unsafe { self.pointer.as_ref().size }
293    }
294
295    /// Returns whether the list is empty.
296    #[must_use]
297    pub const fn is_empty(&self) -> bool {
298        self.len() == 0
299    }
300
301    /// Returns the first element of the list, or `None` if it is empty.
302    #[must_use]
303    pub const fn first(&self) -> Option<ConstantId<'pr>> {
304        if self.is_empty() {
305            None
306        } else {
307            Some(unsafe { self.at(0) })
308        }
309    }
310
311    /// Returns the last element of the list, or `None` if it is empty.
312    #[must_use]
313    pub const fn last(&self) -> Option<ConstantId<'pr>> {
314        if self.is_empty() {
315            None
316        } else {
317            Some(unsafe { self.at(self.len() - 1) })
318        }
319    }
320}
321
322impl<'pr> IntoIterator for &ConstantList<'pr> {
323    type Item = ConstantId<'pr>;
324    type IntoIter = ConstantListIter<'pr>;
325    fn into_iter(self) -> Self::IntoIter {
326        self.iter()
327    }
328}
329
330impl std::fmt::Debug for ConstantList<'_> {
331    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332        write!(f, "{:?}", self.iter().collect::<Vec<_>>())
333    }
334}
335
336/// A handle for an arbitarily-sized integer.
337pub struct Integer<'pr> {
338    /// The raw pointer to the integer allocated by prism.
339    pointer: *const pm_integer_t,
340
341    /// The marker to indicate the lifetime of the pointer.
342    marker: PhantomData<&'pr mut pm_constant_id_t>,
343}
344
345impl Integer<'_> {
346    const fn new(pointer: *const pm_integer_t) -> Self {
347        Integer { pointer, marker: PhantomData }
348    }
349
350    /// Returns the sign and the u32 digits representation of the integer,
351    /// ordered least significant digit first.
352    #[must_use]
353    pub const fn to_u32_digits(&self) -> (bool, &[u32]) {
354        let negative = unsafe { (*self.pointer).negative };
355        let length = unsafe { (*self.pointer).length };
356        let values = unsafe { (*self.pointer).values };
357
358        if values.is_null() {
359            let value_ptr = unsafe { std::ptr::addr_of!((*self.pointer).value) };
360            let slice = unsafe { std::slice::from_raw_parts(value_ptr, 1) };
361            (negative, slice)
362        } else {
363            let slice = unsafe { std::slice::from_raw_parts(values, length) };
364            (negative, slice)
365        }
366    }
367}
368
369impl std::fmt::Debug for Integer<'_> {
370    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371        write!(f, "{:?}", self.pointer)
372    }
373}
374
375impl TryInto<i32> for Integer<'_> {
376    type Error = ();
377
378    fn try_into(self) -> Result<i32, Self::Error> {
379        let negative = unsafe { (*self.pointer).negative };
380        let length = unsafe { (*self.pointer).length };
381
382        if length == 0 {
383            i32::try_from(unsafe { (*self.pointer).value }).map_or(Err(()), |value| if negative { Ok(-value) } else { Ok(value) })
384        } else {
385            Err(())
386        }
387    }
388}
389
390/// A diagnostic message that came back from the parser.
391#[derive(Debug)]
392pub struct Diagnostic<'pr> {
393    diag: NonNull<pm_diagnostic_t>,
394    parser: NonNull<pm_parser_t>,
395    marker: PhantomData<&'pr pm_diagnostic_t>,
396}
397
398impl<'pr> Diagnostic<'pr> {
399    /// Returns the message associated with the diagnostic.
400    ///
401    /// # Panics
402    ///
403    /// Panics if the message is not able to be converted into a `CStr`.
404    ///
405    #[must_use]
406    pub fn message(&self) -> &str {
407        unsafe {
408            let message: *mut c_char = self.diag.as_ref().message.cast_mut();
409            CStr::from_ptr(message).to_str().expect("prism allows only UTF-8 for diagnostics.")
410        }
411    }
412
413    /// The location of the diagnostic in the source.
414    #[must_use]
415    pub const fn location(&self) -> Location<'pr> {
416        Location::new(self.parser, unsafe { &self.diag.as_ref().location })
417    }
418}
419
420/// A comment that was found during parsing.
421#[derive(Debug)]
422pub struct Comment<'pr> {
423    content: NonNull<pm_comment_t>,
424    parser: NonNull<pm_parser_t>,
425    marker: PhantomData<&'pr pm_comment_t>,
426}
427
428/// The type of the comment
429#[derive(Debug, Clone, Copy, PartialEq, Eq)]
430pub enum CommentType {
431    /// `InlineComment` corresponds to comments that start with #.
432    InlineComment,
433    /// `EmbDocComment` corresponds to comments that are surrounded by =begin and =end.
434    EmbDocComment,
435}
436
437impl<'pr> Comment<'pr> {
438    /// Returns the text of the comment.
439    ///
440    /// # Panics
441    /// Panics if the end offset is not greater than the start offset.
442    #[must_use]
443    pub fn text(&self) -> &[u8] {
444        self.location().as_slice()
445    }
446
447    /// Returns the type of the comment.
448    #[must_use]
449    pub fn type_(&self) -> CommentType {
450        let type_ = unsafe { self.content.as_ref().type_ };
451        if type_ == pm_comment_type_t::PM_COMMENT_EMBDOC {
452            CommentType::EmbDocComment
453        } else {
454            CommentType::InlineComment
455        }
456    }
457
458    /// The location of the comment in the source.
459    #[must_use]
460    pub const fn location(&self) -> Location<'pr> {
461        Location::new(self.parser, unsafe { &self.content.as_ref().location })
462    }
463}
464
465/// A magic comment that was found during parsing.
466#[derive(Debug)]
467pub struct MagicComment<'pr> {
468    comment: NonNull<pm_magic_comment_t>,
469    marker: PhantomData<&'pr pm_magic_comment_t>,
470}
471
472impl MagicComment<'_> {
473    /// Returns the text of the comment's key.
474    #[must_use]
475    pub const fn key(&self) -> &[u8] {
476        unsafe {
477            let start = self.comment.as_ref().key_start;
478            let len = self.comment.as_ref().key_length as usize;
479            std::slice::from_raw_parts(start, len)
480        }
481    }
482
483    /// Returns the text of the comment's value.
484    #[must_use]
485    pub const fn value(&self) -> &[u8] {
486        unsafe {
487            let start = self.comment.as_ref().value_start;
488            let len = self.comment.as_ref().value_length as usize;
489            std::slice::from_raw_parts(start, len)
490        }
491    }
492}
493
494/// A struct created by the `errors` or `warnings` methods on `ParseResult`. It
495/// can be used to iterate over the diagnostics in the parse result.
496pub struct Diagnostics<'pr> {
497    diagnostic: *mut pm_diagnostic_t,
498    parser: NonNull<pm_parser_t>,
499    marker: PhantomData<&'pr pm_diagnostic_t>,
500}
501
502impl<'pr> Iterator for Diagnostics<'pr> {
503    type Item = Diagnostic<'pr>;
504
505    fn next(&mut self) -> Option<Self::Item> {
506        if let Some(diagnostic) = NonNull::new(self.diagnostic) {
507            let current = Diagnostic {
508                diag: diagnostic,
509                parser: self.parser,
510                marker: PhantomData,
511            };
512            self.diagnostic = unsafe { diagnostic.as_ref().node.next.cast::<pm_diagnostic_t>() };
513            Some(current)
514        } else {
515            None
516        }
517    }
518}
519
520/// A struct created by the `comments` method on `ParseResult`. It can be used
521/// to iterate over the comments in the parse result.
522pub struct Comments<'pr> {
523    comment: *mut pm_comment_t,
524    parser: NonNull<pm_parser_t>,
525    marker: PhantomData<&'pr pm_comment_t>,
526}
527
528impl<'pr> Iterator for Comments<'pr> {
529    type Item = Comment<'pr>;
530
531    fn next(&mut self) -> Option<Self::Item> {
532        if let Some(comment) = NonNull::new(self.comment) {
533            let current = Comment {
534                content: comment,
535                parser: self.parser,
536                marker: PhantomData,
537            };
538            self.comment = unsafe { comment.as_ref().node.next.cast::<pm_comment_t>() };
539            Some(current)
540        } else {
541            None
542        }
543    }
544}
545
546/// A struct created by the `magic_comments` method on `ParseResult`. It can be used
547/// to iterate over the magic comments in the parse result.
548pub struct MagicComments<'pr> {
549    comment: *mut pm_magic_comment_t,
550    marker: PhantomData<&'pr pm_magic_comment_t>,
551}
552
553impl<'pr> Iterator for MagicComments<'pr> {
554    type Item = MagicComment<'pr>;
555
556    fn next(&mut self) -> Option<Self::Item> {
557        if let Some(comment) = NonNull::new(self.comment) {
558            let current = MagicComment { comment, marker: PhantomData };
559            self.comment = unsafe { comment.as_ref().node.next.cast::<pm_magic_comment_t>() };
560            Some(current)
561        } else {
562            None
563        }
564    }
565}
566
567/// The result of parsing a source string.
568#[derive(Debug)]
569pub struct ParseResult<'pr> {
570    source: &'pr [u8],
571    parser: NonNull<pm_parser_t>,
572    node: NonNull<pm_node_t>,
573}
574
575impl<'pr> ParseResult<'pr> {
576    /// Returns the source string that was parsed.
577    #[must_use]
578    pub const fn source(&self) -> &'pr [u8] {
579        self.source
580    }
581
582    /// Returns whether we found a `frozen_string_literal` magic comment with a true value.
583    #[must_use]
584    pub fn frozen_string_literals(&self) -> bool {
585        unsafe { (*self.parser.as_ptr()).frozen_string_literal == 1 }
586    }
587
588    /// Returns a slice of the source string that was parsed using the given
589    /// location range.
590    ///
591    /// # Panics
592    /// Panics if start offset or end offset are not valid offsets from the root.
593    #[must_use]
594    pub fn as_slice(&self, location: &Location<'pr>) -> &'pr [u8] {
595        let root = self.source.as_ptr();
596
597        let start = usize::try_from(unsafe { location.start.offset_from(root) }).expect("start should point to memory after root");
598        let end = usize::try_from(unsafe { location.end.offset_from(root) }).expect("end should point to memory after root");
599
600        &self.source[start..end]
601    }
602
603    /// Returns an iterator that can be used to iterate over the errors in the
604    /// parse result.
605    #[must_use]
606    pub fn errors(&self) -> Diagnostics<'_> {
607        unsafe {
608            let list = &mut (*self.parser.as_ptr()).error_list;
609            Diagnostics {
610                diagnostic: list.head.cast::<pm_diagnostic_t>(),
611                parser: self.parser,
612                marker: PhantomData,
613            }
614        }
615    }
616
617    /// Returns an iterator that can be used to iterate over the warnings in the
618    /// parse result.
619    #[must_use]
620    pub fn warnings(&self) -> Diagnostics<'_> {
621        unsafe {
622            let list = &mut (*self.parser.as_ptr()).warning_list;
623            Diagnostics {
624                diagnostic: list.head.cast::<pm_diagnostic_t>(),
625                parser: self.parser,
626                marker: PhantomData,
627            }
628        }
629    }
630
631    /// Returns an iterator that can be used to iterate over the comments in the
632    /// parse result.
633    #[must_use]
634    pub fn comments(&self) -> Comments<'_> {
635        unsafe {
636            let list = &mut (*self.parser.as_ptr()).comment_list;
637            Comments {
638                comment: list.head.cast::<pm_comment_t>(),
639                parser: self.parser,
640                marker: PhantomData,
641            }
642        }
643    }
644
645    /// Returns an iterator that can be used to iterate over the magic comments in the
646    /// parse result.
647    #[must_use]
648    pub fn magic_comments(&self) -> MagicComments<'_> {
649        unsafe {
650            let list = &mut (*self.parser.as_ptr()).magic_comment_list;
651            MagicComments {
652                comment: list.head.cast::<pm_magic_comment_t>(),
653                marker: PhantomData,
654            }
655        }
656    }
657
658    /// Returns an optional location of the __END__ marker and the rest of the content of the file.
659    #[must_use]
660    pub fn data_loc(&self) -> Option<Location<'_>> {
661        let location = unsafe { &(*self.parser.as_ptr()).data_loc };
662        if location.start.is_null() {
663            None
664        } else {
665            Some(Location::new(self.parser, location))
666        }
667    }
668
669    /// Returns the root node of the parse result.
670    #[must_use]
671    pub fn node(&self) -> Node<'_> {
672        Node::new(self.parser, self.node.as_ptr())
673    }
674}
675
676impl Drop for ParseResult<'_> {
677    fn drop(&mut self) {
678        unsafe {
679            pm_node_destroy(self.parser.as_ptr(), self.node.as_ptr());
680            pm_parser_free(self.parser.as_ptr());
681            drop(Box::from_raw(self.parser.as_ptr()));
682        }
683    }
684}
685
686/// Parses the given source string and returns a parse result.
687///
688/// # Panics
689///
690/// Panics if the parser fails to initialize.
691///
692#[must_use]
693pub fn parse(source: &[u8]) -> ParseResult<'_> {
694    unsafe {
695        let uninit = Box::new(MaybeUninit::<pm_parser_t>::uninit());
696        let uninit = Box::into_raw(uninit);
697
698        pm_parser_init((*uninit).as_mut_ptr(), source.as_ptr(), source.len(), std::ptr::null());
699
700        let parser = (*uninit).assume_init_mut();
701        let parser = NonNull::new_unchecked(parser);
702
703        let node = pm_parse(parser.as_ptr());
704        let node = NonNull::new_unchecked(node);
705
706        ParseResult { source, parser, node }
707    }
708}
709
710#[cfg(test)]
711mod tests {
712    use super::parse;
713
714    #[test]
715    fn comments_test() {
716        let source = "# comment 1\n# comment 2\n# comment 3\n";
717        let result = parse(source.as_ref());
718
719        for comment in result.comments() {
720            assert_eq!(super::CommentType::InlineComment, comment.type_());
721            let text = std::str::from_utf8(comment.text()).unwrap();
722            assert!(text.starts_with("# comment"));
723        }
724    }
725
726    #[test]
727    fn magic_comments_test() {
728        use crate::MagicComment;
729
730        let source = "# typed: ignore\n# typed:true\n#typed: strict\n";
731        let result = parse(source.as_ref());
732
733        let comments: Vec<MagicComment<'_>> = result.magic_comments().collect();
734        assert_eq!(3, comments.len());
735
736        assert_eq!(b"typed", comments[0].key());
737        assert_eq!(b"ignore", comments[0].value());
738
739        assert_eq!(b"typed", comments[1].key());
740        assert_eq!(b"true", comments[1].value());
741
742        assert_eq!(b"typed", comments[2].key());
743        assert_eq!(b"strict", comments[2].value());
744    }
745
746    #[test]
747    fn data_loc_test() {
748        let source = "1";
749        let result = parse(source.as_ref());
750        let data_loc = result.data_loc();
751        assert!(data_loc.is_none());
752
753        let source = "__END__\nabc\n";
754        let result = parse(source.as_ref());
755        let data_loc = result.data_loc().unwrap();
756        let slice = std::str::from_utf8(result.as_slice(&data_loc)).unwrap();
757        assert_eq!(slice, "__END__\nabc\n");
758
759        let source = "1\n2\n3\n__END__\nabc\ndef\n";
760        let result = parse(source.as_ref());
761        let data_loc = result.data_loc().unwrap();
762        let slice = std::str::from_utf8(result.as_slice(&data_loc)).unwrap();
763        assert_eq!(slice, "__END__\nabc\ndef\n");
764    }
765
766    #[test]
767    fn location_test() {
768        let source = "111 + 222 + 333";
769        let result = parse(source.as_ref());
770
771        let node = result.node();
772        let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
773        let node = node.as_call_node().unwrap().receiver().unwrap();
774        let plus = node.as_call_node().unwrap();
775        let node = plus.arguments().unwrap().arguments().iter().next().unwrap();
776
777        let location = node.as_integer_node().unwrap().location();
778        let slice = std::str::from_utf8(result.as_slice(&location)).unwrap();
779
780        assert_eq!(slice, "222");
781        assert_eq!(6, location.start_offset());
782        assert_eq!(9, location.end_offset());
783
784        let recv_loc = plus.receiver().unwrap().location();
785        assert_eq!(recv_loc.as_slice(), b"111");
786        assert_eq!(0, recv_loc.start_offset());
787        assert_eq!(3, recv_loc.end_offset());
788
789        let joined = recv_loc.join(&location).unwrap();
790        assert_eq!(joined.as_slice(), b"111 + 222");
791
792        let not_joined = location.join(&recv_loc);
793        assert!(not_joined.is_none());
794
795        {
796            let result = parse(source.as_ref());
797            let node = result.node();
798            let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
799            let node = node.as_call_node().unwrap().receiver().unwrap();
800            let plus = node.as_call_node().unwrap();
801            let node = plus.arguments().unwrap().arguments().iter().next().unwrap();
802
803            let location = node.as_integer_node().unwrap().location();
804            let not_joined = recv_loc.join(&location);
805            assert!(not_joined.is_none());
806
807            let not_joined = location.join(&recv_loc);
808            assert!(not_joined.is_none());
809        }
810
811        let location = node.location();
812        let slice = std::str::from_utf8(result.as_slice(&location)).unwrap();
813
814        assert_eq!(slice, "222");
815
816        let slice = std::str::from_utf8(location.as_slice()).unwrap();
817
818        assert_eq!(slice, "222");
819    }
820
821    #[test]
822    fn visitor_test() {
823        use super::{visit_interpolated_regular_expression_node, visit_regular_expression_node, InterpolatedRegularExpressionNode, RegularExpressionNode, Visit};
824
825        struct RegularExpressionVisitor {
826            count: usize,
827        }
828
829        impl Visit<'_> for RegularExpressionVisitor {
830            fn visit_interpolated_regular_expression_node(&mut self, node: &InterpolatedRegularExpressionNode<'_>) {
831                self.count += 1;
832                visit_interpolated_regular_expression_node(self, node);
833            }
834
835            fn visit_regular_expression_node(&mut self, node: &RegularExpressionNode<'_>) {
836                self.count += 1;
837                visit_regular_expression_node(self, node);
838            }
839        }
840
841        let source = "# comment 1\n# comment 2\nmodule Foo; class Bar; /abc #{/def/}/; end; end";
842        let result = parse(source.as_ref());
843
844        let mut visitor = RegularExpressionVisitor { count: 0 };
845        visitor.visit(&result.node());
846
847        assert_eq!(visitor.count, 2);
848    }
849
850    #[test]
851    fn node_upcast_test() {
852        use super::Node;
853
854        let source = "module Foo; end";
855        let result = parse(source.as_ref());
856
857        let node = result.node();
858        let upcast_node = node.as_program_node().unwrap().as_node();
859        assert!(matches!(upcast_node, Node::ProgramNode { .. }));
860
861        let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
862        let upcast_node = node.as_module_node().unwrap().as_node();
863        assert!(matches!(upcast_node, Node::ModuleNode { .. }));
864    }
865
866    #[test]
867    fn constant_id_test() {
868        let source = "module Foo; x = 1; y = 2; end";
869        let result = parse(source.as_ref());
870
871        let node = result.node();
872        assert_eq!(node.as_program_node().unwrap().statements().body().len(), 1);
873        assert!(!node.as_program_node().unwrap().statements().body().is_empty());
874        let module = node.as_program_node().and_then(|pn| pn.statements().body().first()).unwrap();
875        let module = module.as_module_node().unwrap();
876
877        assert_eq!(module.locals().len(), 2);
878        assert!(!module.locals().is_empty());
879
880        assert_eq!(module.locals().first().unwrap().as_slice(), b"x");
881        assert_eq!(module.locals().last().unwrap().as_slice(), b"y");
882
883        let source = "module Foo; end";
884        let result = parse(source.as_ref());
885
886        let node = result.node();
887        assert_eq!(node.as_program_node().unwrap().statements().body().len(), 1);
888        assert!(!node.as_program_node().unwrap().statements().body().is_empty());
889        let module = node.as_program_node().and_then(|pn| pn.statements().body().first()).unwrap();
890        let module = module.as_module_node().unwrap();
891
892        assert_eq!(module.locals().len(), 0);
893        assert!(module.locals().is_empty());
894    }
895
896    #[test]
897    fn optional_loc_test() {
898        let source = r"
899module Example
900  x = call_func(3, 4)
901  y = x.call_func 5, 6
902end
903";
904        let result = parse(source.as_ref());
905
906        let node = result.node();
907        let module = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
908        let module = module.as_module_node().unwrap();
909        let body = module.body();
910        let writes = body.iter().next().unwrap().as_statements_node().unwrap().body().iter().collect::<Vec<_>>();
911        assert_eq!(writes.len(), 2);
912
913        let asgn = &writes[0];
914        let call = asgn.as_local_variable_write_node().unwrap().value();
915        let call = call.as_call_node().unwrap();
916
917        let call_operator_loc = call.call_operator_loc();
918        assert!(call_operator_loc.is_none());
919        let closing_loc = call.closing_loc();
920        assert!(closing_loc.is_some());
921
922        let asgn = &writes[1];
923        let call = asgn.as_local_variable_write_node().unwrap().value();
924        let call = call.as_call_node().unwrap();
925
926        let call_operator_loc = call.call_operator_loc();
927        assert!(call_operator_loc.is_some());
928        let closing_loc = call.closing_loc();
929        assert!(closing_loc.is_none());
930    }
931
932    #[test]
933    fn frozen_strings_test() {
934        let source = r#"
935# frozen_string_literal: true
936"foo"
937"#;
938        let result = parse(source.as_ref());
939        assert!(result.frozen_string_literals());
940
941        let source = "3";
942        let result = parse(source.as_ref());
943        assert!(!result.frozen_string_literals());
944    }
945
946    #[test]
947    fn string_flags_test() {
948        let source = r#"
949# frozen_string_literal: true
950"foo"
951"#;
952        let result = parse(source.as_ref());
953
954        let node = result.node();
955        let string = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
956        let string = string.as_string_node().unwrap();
957        assert!(string.is_frozen());
958
959        let source = r#"
960"foo"
961"#;
962        let result = parse(source.as_ref());
963
964        let node = result.node();
965        let string = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
966        let string = string.as_string_node().unwrap();
967        assert!(!string.is_frozen());
968    }
969
970    #[test]
971    fn call_flags_test() {
972        let source = r"
973x
974";
975        let result = parse(source.as_ref());
976
977        let node = result.node();
978        let call = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
979        let call = call.as_call_node().unwrap();
980        assert!(call.is_variable_call());
981
982        let source = r"
983x&.foo
984";
985        let result = parse(source.as_ref());
986
987        let node = result.node();
988        let call = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
989        let call = call.as_call_node().unwrap();
990        assert!(call.is_safe_navigation());
991    }
992
993    #[test]
994    fn integer_flags_test() {
995        let source = r"
9960b1
997";
998        let result = parse(source.as_ref());
999
1000        let node = result.node();
1001        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1002        let i = i.as_integer_node().unwrap();
1003        assert!(i.is_binary());
1004        assert!(!i.is_decimal());
1005        assert!(!i.is_octal());
1006        assert!(!i.is_hexadecimal());
1007
1008        let source = r"
10091
1010";
1011        let result = parse(source.as_ref());
1012
1013        let node = result.node();
1014        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1015        let i = i.as_integer_node().unwrap();
1016        assert!(!i.is_binary());
1017        assert!(i.is_decimal());
1018        assert!(!i.is_octal());
1019        assert!(!i.is_hexadecimal());
1020
1021        let source = r"
10220o1
1023";
1024        let result = parse(source.as_ref());
1025
1026        let node = result.node();
1027        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1028        let i = i.as_integer_node().unwrap();
1029        assert!(!i.is_binary());
1030        assert!(!i.is_decimal());
1031        assert!(i.is_octal());
1032        assert!(!i.is_hexadecimal());
1033
1034        let source = r"
10350x1
1036";
1037        let result = parse(source.as_ref());
1038
1039        let node = result.node();
1040        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1041        let i = i.as_integer_node().unwrap();
1042        assert!(!i.is_binary());
1043        assert!(!i.is_decimal());
1044        assert!(!i.is_octal());
1045        assert!(i.is_hexadecimal());
1046    }
1047
1048    #[test]
1049    fn range_flags_test() {
1050        let source = r"
10510..1
1052";
1053        let result = parse(source.as_ref());
1054
1055        let node = result.node();
1056        let range = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1057        let range = range.as_range_node().unwrap();
1058        assert!(!range.is_exclude_end());
1059
1060        let source = r"
10610...1
1062";
1063        let result = parse(source.as_ref());
1064
1065        let node = result.node();
1066        let range = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1067        let range = range.as_range_node().unwrap();
1068        assert!(range.is_exclude_end());
1069    }
1070
1071    #[allow(clippy::too_many_lines, clippy::cognitive_complexity)]
1072    #[test]
1073    fn regex_flags_test() {
1074        let source = r"
1075/a/i
1076";
1077        let result = parse(source.as_ref());
1078
1079        let node = result.node();
1080        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1081        let regex = regex.as_regular_expression_node().unwrap();
1082        assert!(regex.is_ignore_case());
1083        assert!(!regex.is_extended());
1084        assert!(!regex.is_multi_line());
1085        assert!(!regex.is_euc_jp());
1086        assert!(!regex.is_ascii_8bit());
1087        assert!(!regex.is_windows_31j());
1088        assert!(!regex.is_utf_8());
1089        assert!(!regex.is_once());
1090
1091        let source = r"
1092/a/x
1093";
1094        let result = parse(source.as_ref());
1095
1096        let node = result.node();
1097        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1098        let regex = regex.as_regular_expression_node().unwrap();
1099        assert!(!regex.is_ignore_case());
1100        assert!(regex.is_extended());
1101        assert!(!regex.is_multi_line());
1102        assert!(!regex.is_euc_jp());
1103        assert!(!regex.is_ascii_8bit());
1104        assert!(!regex.is_windows_31j());
1105        assert!(!regex.is_utf_8());
1106        assert!(!regex.is_once());
1107
1108        let source = r"
1109/a/m
1110";
1111        let result = parse(source.as_ref());
1112
1113        let node = result.node();
1114        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1115        let regex = regex.as_regular_expression_node().unwrap();
1116        assert!(!regex.is_ignore_case());
1117        assert!(!regex.is_extended());
1118        assert!(regex.is_multi_line());
1119        assert!(!regex.is_euc_jp());
1120        assert!(!regex.is_ascii_8bit());
1121        assert!(!regex.is_windows_31j());
1122        assert!(!regex.is_utf_8());
1123        assert!(!regex.is_once());
1124
1125        let source = r"
1126/a/e
1127";
1128        let result = parse(source.as_ref());
1129
1130        let node = result.node();
1131        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1132        let regex = regex.as_regular_expression_node().unwrap();
1133        assert!(!regex.is_ignore_case());
1134        assert!(!regex.is_extended());
1135        assert!(!regex.is_multi_line());
1136        assert!(regex.is_euc_jp());
1137        assert!(!regex.is_ascii_8bit());
1138        assert!(!regex.is_windows_31j());
1139        assert!(!regex.is_utf_8());
1140        assert!(!regex.is_once());
1141
1142        let source = r"
1143/a/n
1144";
1145        let result = parse(source.as_ref());
1146
1147        let node = result.node();
1148        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1149        let regex = regex.as_regular_expression_node().unwrap();
1150        assert!(!regex.is_ignore_case());
1151        assert!(!regex.is_extended());
1152        assert!(!regex.is_multi_line());
1153        assert!(!regex.is_euc_jp());
1154        assert!(regex.is_ascii_8bit());
1155        assert!(!regex.is_windows_31j());
1156        assert!(!regex.is_utf_8());
1157        assert!(!regex.is_once());
1158
1159        let source = r"
1160/a/s
1161";
1162        let result = parse(source.as_ref());
1163
1164        let node = result.node();
1165        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1166        let regex = regex.as_regular_expression_node().unwrap();
1167        assert!(!regex.is_ignore_case());
1168        assert!(!regex.is_extended());
1169        assert!(!regex.is_multi_line());
1170        assert!(!regex.is_euc_jp());
1171        assert!(!regex.is_ascii_8bit());
1172        assert!(regex.is_windows_31j());
1173        assert!(!regex.is_utf_8());
1174        assert!(!regex.is_once());
1175
1176        let source = r"
1177/a/u
1178";
1179        let result = parse(source.as_ref());
1180
1181        let node = result.node();
1182        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1183        let regex = regex.as_regular_expression_node().unwrap();
1184        assert!(!regex.is_ignore_case());
1185        assert!(!regex.is_extended());
1186        assert!(!regex.is_multi_line());
1187        assert!(!regex.is_euc_jp());
1188        assert!(!regex.is_ascii_8bit());
1189        assert!(!regex.is_windows_31j());
1190        assert!(regex.is_utf_8());
1191        assert!(!regex.is_once());
1192
1193        let source = r"
1194/a/o
1195";
1196        let result = parse(source.as_ref());
1197
1198        let node = result.node();
1199        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1200        let regex = regex.as_regular_expression_node().unwrap();
1201        assert!(!regex.is_ignore_case());
1202        assert!(!regex.is_extended());
1203        assert!(!regex.is_multi_line());
1204        assert!(!regex.is_euc_jp());
1205        assert!(!regex.is_ascii_8bit());
1206        assert!(!regex.is_windows_31j());
1207        assert!(!regex.is_utf_8());
1208        assert!(regex.is_once());
1209    }
1210
1211    #[test]
1212    fn visitor_traversal_test() {
1213        use crate::{Node, Visit};
1214
1215        #[derive(Default)]
1216        struct NodeCounts {
1217            pre_parent: usize,
1218            post_parent: usize,
1219            pre_leaf: usize,
1220            post_leaf: usize,
1221        }
1222
1223        #[derive(Default)]
1224        struct CountingVisitor {
1225            counts: NodeCounts,
1226        }
1227
1228        impl Visit<'_> for CountingVisitor {
1229            fn visit_branch_node_enter(&mut self, _node: Node<'_>) {
1230                self.counts.pre_parent += 1;
1231            }
1232
1233            fn visit_branch_node_leave(&mut self) {
1234                self.counts.post_parent += 1;
1235            }
1236
1237            fn visit_leaf_node_enter(&mut self, _node: Node<'_>) {
1238                self.counts.pre_leaf += 1;
1239            }
1240
1241            fn visit_leaf_node_leave(&mut self) {
1242                self.counts.post_leaf += 1;
1243            }
1244        }
1245
1246        let source = r"
1247module Example
1248  x = call_func(3, 4)
1249  y = x.call_func 5, 6
1250end
1251";
1252        let result = parse(source.as_ref());
1253        let node = result.node();
1254        let mut visitor = CountingVisitor::default();
1255        visitor.visit(&node);
1256
1257        assert_eq!(7, visitor.counts.pre_parent);
1258        assert_eq!(7, visitor.counts.post_parent);
1259        assert_eq!(6, visitor.counts.pre_leaf);
1260        assert_eq!(6, visitor.counts.post_leaf);
1261    }
1262
1263    #[test]
1264    fn visitor_lifetime_test() {
1265        use crate::{Node, Visit};
1266
1267        #[derive(Default)]
1268        struct StackingNodeVisitor<'a> {
1269            stack: Vec<Node<'a>>,
1270            max_depth: usize,
1271        }
1272
1273        impl<'pr> Visit<'pr> for StackingNodeVisitor<'pr> {
1274            fn visit_branch_node_enter(&mut self, node: Node<'pr>) {
1275                self.stack.push(node);
1276            }
1277
1278            fn visit_branch_node_leave(&mut self) {
1279                self.stack.pop();
1280            }
1281
1282            fn visit_leaf_node_leave(&mut self) {
1283                self.max_depth = self.max_depth.max(self.stack.len());
1284            }
1285        }
1286
1287        let source = r"
1288module Example
1289  x = call_func(3, 4)
1290  y = x.call_func 5, 6
1291end
1292";
1293        let result = parse(source.as_ref());
1294        let node = result.node();
1295        let mut visitor = StackingNodeVisitor::default();
1296        visitor.visit(&node);
1297
1298        assert_eq!(0, visitor.stack.len());
1299        assert_eq!(5, visitor.max_depth);
1300    }
1301
1302    #[test]
1303    fn integer_value_test() {
1304        let result = parse("0xA".as_ref());
1305        let integer = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_integer_node().unwrap().value();
1306        let value: i32 = integer.try_into().unwrap();
1307
1308        assert_eq!(value, 10);
1309    }
1310
1311    #[test]
1312    fn integer_small_value_to_u32_digits_test() {
1313        let result = parse("0xA".as_ref());
1314        let integer = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_integer_node().unwrap().value();
1315        let (negative, digits) = integer.to_u32_digits();
1316
1317        assert!(!negative);
1318        assert_eq!(digits, &[10]);
1319    }
1320
1321    #[test]
1322    fn integer_large_value_to_u32_digits_test() {
1323        let result = parse("0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".as_ref());
1324        let integer = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_integer_node().unwrap().value();
1325        let (negative, digits) = integer.to_u32_digits();
1326
1327        assert!(!negative);
1328        assert_eq!(digits, &[4_294_967_295, 4_294_967_295, 4_294_967_295, 2_147_483_647]);
1329    }
1330
1331    #[test]
1332    fn float_value_test() {
1333        let result = parse("1.0".as_ref());
1334        let value: f64 = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_float_node().unwrap().value();
1335
1336        assert!((value - 1.0).abs() < f64::EPSILON);
1337    }
1338
1339    #[test]
1340    fn regex_value_test() {
1341        let result = parse(b"//");
1342        let node = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_regular_expression_node().unwrap();
1343        assert_eq!(node.unescaped(), b"");
1344    }
1345
1346    #[test]
1347    fn node_field_lifetime_test() {
1348        // The code below wouldn't typecheck prior to https://github.com/ruby/prism/pull/2519,
1349        // but we need to stop clippy from complaining about it.
1350        #![allow(clippy::needless_pass_by_value)]
1351
1352        use crate::Node;
1353
1354        #[derive(Default)]
1355        struct Extract<'pr> {
1356            scopes: Vec<crate::ConstantId<'pr>>,
1357        }
1358
1359        impl<'pr> Extract<'pr> {
1360            fn push_scope(&mut self, path: Node<'pr>) {
1361                if let Some(cread) = path.as_constant_read_node() {
1362                    self.scopes.push(cread.name());
1363                } else if let Some(cpath) = path.as_constant_path_node() {
1364                    if let Some(parent) = cpath.parent() {
1365                        self.push_scope(parent);
1366                    }
1367                    self.scopes.push(cpath.name().unwrap());
1368                } else {
1369                    panic!("Wrong node kind!");
1370                }
1371            }
1372        }
1373
1374        let source = "Some::Random::Constant";
1375        let result = parse(source.as_ref());
1376        let node = result.node();
1377        let mut extractor = Extract::default();
1378        extractor.push_scope(node.as_program_node().unwrap().statements().body().iter().next().unwrap());
1379        assert_eq!(3, extractor.scopes.len());
1380    }
1381
1382    #[test]
1383    fn malformed_shebang() {
1384        let source = "#!\x00";
1385        let result = parse(source.as_ref());
1386        assert!(result.errors().next().is_none());
1387        assert!(result.warnings().next().is_none());
1388    }
1389}