ruby_prism/
lib.rs

1//! # ruby-prism
2//!
3//! Rustified version of Ruby's prism parser.
4//!
5#![warn(clippy::all, clippy::nursery, clippy::pedantic, future_incompatible, missing_docs, nonstandard_style, rust_2018_idioms, trivial_casts, trivial_numeric_casts, unreachable_pub, unused_qualifications)]
6
// The `bindings` module holds code generated by `build.rs`, and that generated
// code does not always follow the clippy rules enabled at the top of this
// crate. We don't want to see those warnings, so `too_many_lines` and
// `use_self` are allowed for this module only.
#[allow(clippy::too_many_lines, clippy::use_self)]
mod bindings {
    // In `build.rs`, we generate bindings based on the config.yml file. Here is
    // where we pull in those bindings and make them part of our library.
    include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
}
15
16use std::ffi::{c_char, CStr};
17use std::marker::PhantomData;
18use std::mem::MaybeUninit;
19use std::ptr::NonNull;
20
21pub use self::bindings::*;
22use ruby_prism_sys::{pm_comment_t, pm_comment_type_t, pm_constant_id_list_t, pm_constant_id_t, pm_diagnostic_t, pm_integer_t, pm_location_t, pm_magic_comment_t, pm_node_destroy, pm_node_list, pm_node_t, pm_parse, pm_parser_free, pm_parser_init, pm_parser_t};
23
24/// A range in the source file.
25pub struct Location<'pr> {
26    parser: NonNull<pm_parser_t>,
27    pub(crate) start: *const u8,
28    pub(crate) end: *const u8,
29    marker: PhantomData<&'pr [u8]>,
30}
31
32impl<'pr> Location<'pr> {
33    /// Returns a byte slice for the range.
34    /// # Panics
35    /// Panics if the end offset is not greater than the start offset.
36    #[must_use]
37    pub fn as_slice(&self) -> &'pr [u8] {
38        unsafe {
39            let len = usize::try_from(self.end.offset_from(self.start)).expect("end should point to memory after start");
40            std::slice::from_raw_parts(self.start, len)
41        }
42    }
43
44    /// Return a Location from the given `pm_location_t`.
45    #[must_use]
46    pub(crate) const fn new(parser: NonNull<pm_parser_t>, loc: &'pr pm_location_t) -> Self {
47        Location {
48            parser,
49            start: loc.start,
50            end: loc.end,
51            marker: PhantomData,
52        }
53    }
54
55    /// Return a Location starting at self and ending at the end of other.
56    /// Returns None if both locations did not originate from the same parser,
57    /// or if self starts after other.
58    #[must_use]
59    pub fn join(&self, other: &Self) -> Option<Self> {
60        if self.parser != other.parser || self.start > other.start {
61            None
62        } else {
63            Some(Location {
64                parser: self.parser,
65                start: self.start,
66                end: other.end,
67                marker: PhantomData,
68            })
69        }
70    }
71
72    /// Return the start offset from the beginning of the parsed source.
73    /// # Panics
74    /// Panics if the start offset is not greater than the parser's start.
75    #[must_use]
76    pub fn start_offset(&self) -> usize {
77        unsafe {
78            let parser_start = (*self.parser.as_ptr()).start;
79            usize::try_from(self.start.offset_from(parser_start)).expect("start should point to memory after the parser's start")
80        }
81    }
82
83    /// Return the end offset from the beginning of the parsed source.
84    /// # Panics
85    /// Panics if the end offset is not greater than the parser's start.
86    #[must_use]
87    pub fn end_offset(&self) -> usize {
88        unsafe {
89            let parser_start = (*self.parser.as_ptr()).start;
90            usize::try_from(self.end.offset_from(parser_start)).expect("end should point to memory after the parser's start")
91        }
92    }
93}
94
95impl std::fmt::Debug for Location<'_> {
96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        let slice: &[u8] = self.as_slice();
98
99        let mut visible = String::new();
100        visible.push('"');
101
102        for &byte in slice {
103            let part: Vec<u8> = std::ascii::escape_default(byte).collect();
104            visible.push_str(std::str::from_utf8(&part).unwrap());
105        }
106
107        visible.push('"');
108        write!(f, "{visible}")
109    }
110}
111
112/// An iterator over the nodes in a list.
113pub struct NodeListIter<'pr> {
114    parser: NonNull<pm_parser_t>,
115    pointer: NonNull<pm_node_list>,
116    index: usize,
117    marker: PhantomData<&'pr mut pm_node_list>,
118}
119
120impl<'pr> Iterator for NodeListIter<'pr> {
121    type Item = Node<'pr>;
122
123    fn next(&mut self) -> Option<Self::Item> {
124        if self.index >= unsafe { self.pointer.as_ref().size } {
125            None
126        } else {
127            let node: *mut pm_node_t = unsafe { *(self.pointer.as_ref().nodes.add(self.index)) };
128            self.index += 1;
129            Some(Node::new(self.parser, node))
130        }
131    }
132}
133
134/// A list of nodes.
135pub struct NodeList<'pr> {
136    parser: NonNull<pm_parser_t>,
137    pointer: NonNull<pm_node_list>,
138    marker: PhantomData<&'pr mut pm_node_list>,
139}
140
141impl<'pr> NodeList<'pr> {
142    /// Returns an iterator over the nodes.
143    #[must_use]
144    pub const fn iter(&self) -> NodeListIter<'pr> {
145        NodeListIter {
146            parser: self.parser,
147            pointer: self.pointer,
148            index: 0,
149            marker: PhantomData,
150        }
151    }
152}
153
154impl<'pr> IntoIterator for &NodeList<'pr> {
155    type Item = Node<'pr>;
156    type IntoIter = NodeListIter<'pr>;
157    fn into_iter(self) -> Self::IntoIter {
158        self.iter()
159    }
160}
161
162impl std::fmt::Debug for NodeList<'_> {
163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        write!(f, "{:?}", self.iter().collect::<Vec<_>>())
165    }
166}
167
168/// A handle for a constant ID.
169pub struct ConstantId<'pr> {
170    parser: NonNull<pm_parser_t>,
171    id: pm_constant_id_t,
172    marker: PhantomData<&'pr mut pm_constant_id_t>,
173}
174
175impl<'pr> ConstantId<'pr> {
176    const fn new(parser: NonNull<pm_parser_t>, id: pm_constant_id_t) -> Self {
177        ConstantId { parser, id, marker: PhantomData }
178    }
179
180    /// Returns a byte slice for the constant ID.
181    ///
182    /// # Panics
183    ///
184    /// Panics if the constant ID is not found in the constant pool.
185    #[must_use]
186    pub fn as_slice(&self) -> &'pr [u8] {
187        unsafe {
188            let pool = &(*self.parser.as_ptr()).constant_pool;
189            let constant = &(*pool.constants.add((self.id - 1).try_into().unwrap()));
190            std::slice::from_raw_parts(constant.start, constant.length)
191        }
192    }
193}
194
195impl std::fmt::Debug for ConstantId<'_> {
196    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197        write!(f, "{:?}", self.id)
198    }
199}
200
201/// An iterator over the constants in a list.
202pub struct ConstantListIter<'pr> {
203    parser: NonNull<pm_parser_t>,
204    pointer: NonNull<pm_constant_id_list_t>,
205    index: usize,
206    marker: PhantomData<&'pr mut pm_constant_id_list_t>,
207}
208
209impl<'pr> Iterator for ConstantListIter<'pr> {
210    type Item = ConstantId<'pr>;
211
212    fn next(&mut self) -> Option<Self::Item> {
213        if self.index >= unsafe { self.pointer.as_ref().size } {
214            None
215        } else {
216            let constant_id: pm_constant_id_t = unsafe { *(self.pointer.as_ref().ids.add(self.index)) };
217            self.index += 1;
218            Some(ConstantId::new(self.parser, constant_id))
219        }
220    }
221}
222
223/// A list of constants.
224pub struct ConstantList<'pr> {
225    /// The raw pointer to the parser where this list came from.
226    parser: NonNull<pm_parser_t>,
227
228    /// The raw pointer to the list allocated by prism.
229    pointer: NonNull<pm_constant_id_list_t>,
230
231    /// The marker to indicate the lifetime of the pointer.
232    marker: PhantomData<&'pr mut pm_constant_id_list_t>,
233}
234
235impl<'pr> ConstantList<'pr> {
236    /// Returns an iterator over the constants in the list.
237    #[must_use]
238    pub const fn iter(&self) -> ConstantListIter<'pr> {
239        ConstantListIter {
240            parser: self.parser,
241            pointer: self.pointer,
242            index: 0,
243            marker: PhantomData,
244        }
245    }
246}
247
248impl<'pr> IntoIterator for &ConstantList<'pr> {
249    type Item = ConstantId<'pr>;
250    type IntoIter = ConstantListIter<'pr>;
251    fn into_iter(self) -> Self::IntoIter {
252        self.iter()
253    }
254}
255
256impl std::fmt::Debug for ConstantList<'_> {
257    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258        write!(f, "{:?}", self.iter().collect::<Vec<_>>())
259    }
260}
261
262/// A handle for an arbitarily-sized integer.
263pub struct Integer<'pr> {
264    /// The raw pointer to the integer allocated by prism.
265    pointer: *const pm_integer_t,
266
267    /// The marker to indicate the lifetime of the pointer.
268    marker: PhantomData<&'pr mut pm_constant_id_t>,
269}
270
271impl Integer<'_> {
272    const fn new(pointer: *const pm_integer_t) -> Self {
273        Integer { pointer, marker: PhantomData }
274    }
275
276    /// Returns the sign and the u32 digits representation of the integer,
277    /// ordered least significant digit first.
278    #[must_use]
279    pub const fn to_u32_digits(&self) -> (bool, &[u32]) {
280        let negative = unsafe { (*self.pointer).negative };
281        let length = unsafe { (*self.pointer).length };
282        let values = unsafe { (*self.pointer).values };
283
284        if values.is_null() {
285            let value_ptr = unsafe { std::ptr::addr_of!((*self.pointer).value) };
286            let slice = unsafe { std::slice::from_raw_parts(value_ptr, 1) };
287            (negative, slice)
288        } else {
289            let slice = unsafe { std::slice::from_raw_parts(values, length) };
290            (negative, slice)
291        }
292    }
293}
294
295impl std::fmt::Debug for Integer<'_> {
296    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297        write!(f, "{:?}", self.pointer)
298    }
299}
300
301impl TryInto<i32> for Integer<'_> {
302    type Error = ();
303
304    fn try_into(self) -> Result<i32, Self::Error> {
305        let negative = unsafe { (*self.pointer).negative };
306        let length = unsafe { (*self.pointer).length };
307
308        if length == 0 {
309            i32::try_from(unsafe { (*self.pointer).value }).map_or(Err(()), |value| if negative { Ok(-value) } else { Ok(value) })
310        } else {
311            Err(())
312        }
313    }
314}
315
316/// A diagnostic message that came back from the parser.
317#[derive(Debug)]
318pub struct Diagnostic<'pr> {
319    diag: NonNull<pm_diagnostic_t>,
320    parser: NonNull<pm_parser_t>,
321    marker: PhantomData<&'pr pm_diagnostic_t>,
322}
323
324impl<'pr> Diagnostic<'pr> {
325    /// Returns the message associated with the diagnostic.
326    ///
327    /// # Panics
328    ///
329    /// Panics if the message is not able to be converted into a `CStr`.
330    ///
331    #[must_use]
332    pub fn message(&self) -> &str {
333        unsafe {
334            let message: *mut c_char = self.diag.as_ref().message.cast_mut();
335            CStr::from_ptr(message).to_str().expect("prism allows only UTF-8 for diagnostics.")
336        }
337    }
338
339    /// The location of the diagnostic in the source.
340    #[must_use]
341    pub const fn location(&self) -> Location<'pr> {
342        Location::new(self.parser, unsafe { &self.diag.as_ref().location })
343    }
344}
345
346/// A comment that was found during parsing.
347#[derive(Debug)]
348pub struct Comment<'pr> {
349    content: NonNull<pm_comment_t>,
350    parser: NonNull<pm_parser_t>,
351    marker: PhantomData<&'pr pm_comment_t>,
352}
353
354/// The type of the comment
355#[derive(Debug, Clone, Copy, PartialEq, Eq)]
356pub enum CommentType {
357    /// `InlineComment` corresponds to comments that start with #.
358    InlineComment,
359    /// `EmbDocComment` corresponds to comments that are surrounded by =begin and =end.
360    EmbDocComment,
361}
362
363impl<'pr> Comment<'pr> {
364    /// Returns the text of the comment.
365    ///
366    /// # Panics
367    /// Panics if the end offset is not greater than the start offset.
368    #[must_use]
369    pub fn text(&self) -> &[u8] {
370        self.location().as_slice()
371    }
372
373    /// Returns the type of the comment.
374    #[must_use]
375    pub fn type_(&self) -> CommentType {
376        let type_ = unsafe { self.content.as_ref().type_ };
377        if type_ == pm_comment_type_t::PM_COMMENT_EMBDOC {
378            CommentType::EmbDocComment
379        } else {
380            CommentType::InlineComment
381        }
382    }
383
384    /// The location of the comment in the source.
385    #[must_use]
386    pub const fn location(&self) -> Location<'pr> {
387        Location::new(self.parser, unsafe { &self.content.as_ref().location })
388    }
389}
390
391/// A magic comment that was found during parsing.
392#[derive(Debug)]
393pub struct MagicComment<'pr> {
394    comment: NonNull<pm_magic_comment_t>,
395    marker: PhantomData<&'pr pm_magic_comment_t>,
396}
397
398impl MagicComment<'_> {
399    /// Returns the text of the comment's key.
400    #[must_use]
401    pub const fn key(&self) -> &[u8] {
402        unsafe {
403            let start = self.comment.as_ref().key_start;
404            let len = self.comment.as_ref().key_length as usize;
405            std::slice::from_raw_parts(start, len)
406        }
407    }
408
409    /// Returns the text of the comment's value.
410    #[must_use]
411    pub const fn value(&self) -> &[u8] {
412        unsafe {
413            let start = self.comment.as_ref().value_start;
414            let len = self.comment.as_ref().value_length as usize;
415            std::slice::from_raw_parts(start, len)
416        }
417    }
418}
419
420/// A struct created by the `errors` or `warnings` methods on `ParseResult`. It
421/// can be used to iterate over the diagnostics in the parse result.
422pub struct Diagnostics<'pr> {
423    diagnostic: *mut pm_diagnostic_t,
424    parser: NonNull<pm_parser_t>,
425    marker: PhantomData<&'pr pm_diagnostic_t>,
426}
427
428impl<'pr> Iterator for Diagnostics<'pr> {
429    type Item = Diagnostic<'pr>;
430
431    fn next(&mut self) -> Option<Self::Item> {
432        if let Some(diagnostic) = NonNull::new(self.diagnostic) {
433            let current = Diagnostic {
434                diag: diagnostic,
435                parser: self.parser,
436                marker: PhantomData,
437            };
438            self.diagnostic = unsafe { diagnostic.as_ref().node.next.cast::<pm_diagnostic_t>() };
439            Some(current)
440        } else {
441            None
442        }
443    }
444}
445
446/// A struct created by the `comments` method on `ParseResult`. It can be used
447/// to iterate over the comments in the parse result.
448pub struct Comments<'pr> {
449    comment: *mut pm_comment_t,
450    parser: NonNull<pm_parser_t>,
451    marker: PhantomData<&'pr pm_comment_t>,
452}
453
454impl<'pr> Iterator for Comments<'pr> {
455    type Item = Comment<'pr>;
456
457    fn next(&mut self) -> Option<Self::Item> {
458        if let Some(comment) = NonNull::new(self.comment) {
459            let current = Comment {
460                content: comment,
461                parser: self.parser,
462                marker: PhantomData,
463            };
464            self.comment = unsafe { comment.as_ref().node.next.cast::<pm_comment_t>() };
465            Some(current)
466        } else {
467            None
468        }
469    }
470}
471
472/// A struct created by the `magic_comments` method on `ParseResult`. It can be used
473/// to iterate over the magic comments in the parse result.
474pub struct MagicComments<'pr> {
475    comment: *mut pm_magic_comment_t,
476    marker: PhantomData<&'pr pm_magic_comment_t>,
477}
478
479impl<'pr> Iterator for MagicComments<'pr> {
480    type Item = MagicComment<'pr>;
481
482    fn next(&mut self) -> Option<Self::Item> {
483        if let Some(comment) = NonNull::new(self.comment) {
484            let current = MagicComment { comment, marker: PhantomData };
485            self.comment = unsafe { comment.as_ref().node.next.cast::<pm_magic_comment_t>() };
486            Some(current)
487        } else {
488            None
489        }
490    }
491}
492
493/// The result of parsing a source string.
494#[derive(Debug)]
495pub struct ParseResult<'pr> {
496    source: &'pr [u8],
497    parser: NonNull<pm_parser_t>,
498    node: NonNull<pm_node_t>,
499}
500
501impl<'pr> ParseResult<'pr> {
502    /// Returns the source string that was parsed.
503    #[must_use]
504    pub const fn source(&self) -> &'pr [u8] {
505        self.source
506    }
507
508    /// Returns whether we found a `frozen_string_literal` magic comment with a true value.
509    #[must_use]
510    pub fn frozen_string_literals(&self) -> bool {
511        unsafe { (*self.parser.as_ptr()).frozen_string_literal == 1 }
512    }
513
514    /// Returns a slice of the source string that was parsed using the given
515    /// location range.
516    ///
517    /// # Panics
518    /// Panics if start offset or end offset are not valid offsets from the root.
519    #[must_use]
520    pub fn as_slice(&self, location: &Location<'pr>) -> &'pr [u8] {
521        let root = self.source.as_ptr();
522
523        let start = usize::try_from(unsafe { location.start.offset_from(root) }).expect("start should point to memory after root");
524        let end = usize::try_from(unsafe { location.end.offset_from(root) }).expect("end should point to memory after root");
525
526        &self.source[start..end]
527    }
528
529    /// Returns an iterator that can be used to iterate over the errors in the
530    /// parse result.
531    #[must_use]
532    pub fn errors(&self) -> Diagnostics<'_> {
533        unsafe {
534            let list = &mut (*self.parser.as_ptr()).error_list;
535            Diagnostics {
536                diagnostic: list.head.cast::<pm_diagnostic_t>(),
537                parser: self.parser,
538                marker: PhantomData,
539            }
540        }
541    }
542
543    /// Returns an iterator that can be used to iterate over the warnings in the
544    /// parse result.
545    #[must_use]
546    pub fn warnings(&self) -> Diagnostics<'_> {
547        unsafe {
548            let list = &mut (*self.parser.as_ptr()).warning_list;
549            Diagnostics {
550                diagnostic: list.head.cast::<pm_diagnostic_t>(),
551                parser: self.parser,
552                marker: PhantomData,
553            }
554        }
555    }
556
557    /// Returns an iterator that can be used to iterate over the comments in the
558    /// parse result.
559    #[must_use]
560    pub fn comments(&self) -> Comments<'_> {
561        unsafe {
562            let list = &mut (*self.parser.as_ptr()).comment_list;
563            Comments {
564                comment: list.head.cast::<pm_comment_t>(),
565                parser: self.parser,
566                marker: PhantomData,
567            }
568        }
569    }
570
571    /// Returns an iterator that can be used to iterate over the magic comments in the
572    /// parse result.
573    #[must_use]
574    pub fn magic_comments(&self) -> MagicComments<'_> {
575        unsafe {
576            let list = &mut (*self.parser.as_ptr()).magic_comment_list;
577            MagicComments {
578                comment: list.head.cast::<pm_magic_comment_t>(),
579                marker: PhantomData,
580            }
581        }
582    }
583
584    /// Returns an optional location of the __END__ marker and the rest of the content of the file.
585    #[must_use]
586    pub fn data_loc(&self) -> Option<Location<'_>> {
587        let location = unsafe { &(*self.parser.as_ptr()).data_loc };
588        if location.start.is_null() {
589            None
590        } else {
591            Some(Location::new(self.parser, location))
592        }
593    }
594
595    /// Returns the root node of the parse result.
596    #[must_use]
597    pub fn node(&self) -> Node<'_> {
598        Node::new(self.parser, self.node.as_ptr())
599    }
600}
601
602impl Drop for ParseResult<'_> {
603    fn drop(&mut self) {
604        unsafe {
605            pm_node_destroy(self.parser.as_ptr(), self.node.as_ptr());
606            pm_parser_free(self.parser.as_ptr());
607            drop(Box::from_raw(self.parser.as_ptr()));
608        }
609    }
610}
611
612/// Parses the given source string and returns a parse result.
613///
614/// # Panics
615///
616/// Panics if the parser fails to initialize.
617///
618#[must_use]
619pub fn parse(source: &[u8]) -> ParseResult<'_> {
620    unsafe {
621        let uninit = Box::new(MaybeUninit::<pm_parser_t>::uninit());
622        let uninit = Box::into_raw(uninit);
623
624        pm_parser_init((*uninit).as_mut_ptr(), source.as_ptr(), source.len(), std::ptr::null());
625
626        let parser = (*uninit).assume_init_mut();
627        let parser = NonNull::new_unchecked(parser);
628
629        let node = pm_parse(parser.as_ptr());
630        let node = NonNull::new_unchecked(node);
631
632        ParseResult { source, parser, node }
633    }
634}
635
636#[cfg(test)]
637mod tests {
638    use super::parse;
639
640    #[test]
641    fn comments_test() {
642        let source = "# comment 1\n# comment 2\n# comment 3\n";
643        let result = parse(source.as_ref());
644
645        for comment in result.comments() {
646            assert_eq!(super::CommentType::InlineComment, comment.type_());
647            let text = std::str::from_utf8(comment.text()).unwrap();
648            assert!(text.starts_with("# comment"));
649        }
650    }
651
652    #[test]
653    fn magic_comments_test() {
654        use crate::MagicComment;
655
656        let source = "# typed: ignore\n# typed:true\n#typed: strict\n";
657        let result = parse(source.as_ref());
658
659        let comments: Vec<MagicComment<'_>> = result.magic_comments().collect();
660        assert_eq!(3, comments.len());
661
662        assert_eq!(b"typed", comments[0].key());
663        assert_eq!(b"ignore", comments[0].value());
664
665        assert_eq!(b"typed", comments[1].key());
666        assert_eq!(b"true", comments[1].value());
667
668        assert_eq!(b"typed", comments[2].key());
669        assert_eq!(b"strict", comments[2].value());
670    }
671
672    #[test]
673    fn data_loc_test() {
674        let source = "1";
675        let result = parse(source.as_ref());
676        let data_loc = result.data_loc();
677        assert!(data_loc.is_none());
678
679        let source = "__END__\nabc\n";
680        let result = parse(source.as_ref());
681        let data_loc = result.data_loc().unwrap();
682        let slice = std::str::from_utf8(result.as_slice(&data_loc)).unwrap();
683        assert_eq!(slice, "__END__\nabc\n");
684
685        let source = "1\n2\n3\n__END__\nabc\ndef\n";
686        let result = parse(source.as_ref());
687        let data_loc = result.data_loc().unwrap();
688        let slice = std::str::from_utf8(result.as_slice(&data_loc)).unwrap();
689        assert_eq!(slice, "__END__\nabc\ndef\n");
690    }
691
692    #[test]
693    fn location_test() {
694        let source = "111 + 222 + 333";
695        let result = parse(source.as_ref());
696
697        let node = result.node();
698        let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
699        let node = node.as_call_node().unwrap().receiver().unwrap();
700        let plus = node.as_call_node().unwrap();
701        let node = plus.arguments().unwrap().arguments().iter().next().unwrap();
702
703        let location = node.as_integer_node().unwrap().location();
704        let slice = std::str::from_utf8(result.as_slice(&location)).unwrap();
705
706        assert_eq!(slice, "222");
707        assert_eq!(6, location.start_offset());
708        assert_eq!(9, location.end_offset());
709
710        let recv_loc = plus.receiver().unwrap().location();
711        assert_eq!(recv_loc.as_slice(), b"111");
712        assert_eq!(0, recv_loc.start_offset());
713        assert_eq!(3, recv_loc.end_offset());
714
715        let joined = recv_loc.join(&location).unwrap();
716        assert_eq!(joined.as_slice(), b"111 + 222");
717
718        let not_joined = location.join(&recv_loc);
719        assert!(not_joined.is_none());
720
721        {
722            let result = parse(source.as_ref());
723            let node = result.node();
724            let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
725            let node = node.as_call_node().unwrap().receiver().unwrap();
726            let plus = node.as_call_node().unwrap();
727            let node = plus.arguments().unwrap().arguments().iter().next().unwrap();
728
729            let location = node.as_integer_node().unwrap().location();
730            let not_joined = recv_loc.join(&location);
731            assert!(not_joined.is_none());
732
733            let not_joined = location.join(&recv_loc);
734            assert!(not_joined.is_none());
735        }
736
737        let location = node.location();
738        let slice = std::str::from_utf8(result.as_slice(&location)).unwrap();
739
740        assert_eq!(slice, "222");
741
742        let slice = std::str::from_utf8(location.as_slice()).unwrap();
743
744        assert_eq!(slice, "222");
745    }
746
747    #[test]
748    fn visitor_test() {
749        use super::{visit_interpolated_regular_expression_node, visit_regular_expression_node, InterpolatedRegularExpressionNode, RegularExpressionNode, Visit};
750
751        struct RegularExpressionVisitor {
752            count: usize,
753        }
754
755        impl Visit<'_> for RegularExpressionVisitor {
756            fn visit_interpolated_regular_expression_node(&mut self, node: &InterpolatedRegularExpressionNode<'_>) {
757                self.count += 1;
758                visit_interpolated_regular_expression_node(self, node);
759            }
760
761            fn visit_regular_expression_node(&mut self, node: &RegularExpressionNode<'_>) {
762                self.count += 1;
763                visit_regular_expression_node(self, node);
764            }
765        }
766
767        let source = "# comment 1\n# comment 2\nmodule Foo; class Bar; /abc #{/def/}/; end; end";
768        let result = parse(source.as_ref());
769
770        let mut visitor = RegularExpressionVisitor { count: 0 };
771        visitor.visit(&result.node());
772
773        assert_eq!(visitor.count, 2);
774    }
775
776    #[test]
777    fn node_upcast_test() {
778        use super::Node;
779
780        let source = "module Foo; end";
781        let result = parse(source.as_ref());
782
783        let node = result.node();
784        let upcast_node = node.as_program_node().unwrap().as_node();
785        assert!(matches!(upcast_node, Node::ProgramNode { .. }));
786
787        let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
788        let upcast_node = node.as_module_node().unwrap().as_node();
789        assert!(matches!(upcast_node, Node::ModuleNode { .. }));
790    }
791
792    #[test]
793    fn constant_id_test() {
794        let source = "module Foo; x = 1; end";
795        let result = parse(source.as_ref());
796
797        let node = result.node();
798        let module = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
799        let module = module.as_module_node().unwrap();
800        let locals = module.locals().iter().collect::<Vec<_>>();
801
802        assert_eq!(locals.len(), 1);
803
804        assert_eq!(locals[0].as_slice(), b"x");
805    }
806
807    #[test]
808    fn optional_loc_test() {
809        let source = r"
810module Example
811  x = call_func(3, 4)
812  y = x.call_func 5, 6
813end
814";
815        let result = parse(source.as_ref());
816
817        let node = result.node();
818        let module = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
819        let module = module.as_module_node().unwrap();
820        let body = module.body();
821        let writes = body.iter().next().unwrap().as_statements_node().unwrap().body().iter().collect::<Vec<_>>();
822        assert_eq!(writes.len(), 2);
823
824        let asgn = &writes[0];
825        let call = asgn.as_local_variable_write_node().unwrap().value();
826        let call = call.as_call_node().unwrap();
827
828        let call_operator_loc = call.call_operator_loc();
829        assert!(call_operator_loc.is_none());
830        let closing_loc = call.closing_loc();
831        assert!(closing_loc.is_some());
832
833        let asgn = &writes[1];
834        let call = asgn.as_local_variable_write_node().unwrap().value();
835        let call = call.as_call_node().unwrap();
836
837        let call_operator_loc = call.call_operator_loc();
838        assert!(call_operator_loc.is_some());
839        let closing_loc = call.closing_loc();
840        assert!(closing_loc.is_none());
841    }
842
843    #[test]
844    fn frozen_strings_test() {
845        let source = r#"
846# frozen_string_literal: true
847"foo"
848"#;
849        let result = parse(source.as_ref());
850        assert!(result.frozen_string_literals());
851
852        let source = "3";
853        let result = parse(source.as_ref());
854        assert!(!result.frozen_string_literals());
855    }
856
857    #[test]
858    fn string_flags_test() {
859        let source = r#"
860# frozen_string_literal: true
861"foo"
862"#;
863        let result = parse(source.as_ref());
864
865        let node = result.node();
866        let string = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
867        let string = string.as_string_node().unwrap();
868        assert!(string.is_frozen());
869
870        let source = r#"
871"foo"
872"#;
873        let result = parse(source.as_ref());
874
875        let node = result.node();
876        let string = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
877        let string = string.as_string_node().unwrap();
878        assert!(!string.is_frozen());
879    }
880
881    #[test]
882    fn call_flags_test() {
883        let source = r"
884x
885";
886        let result = parse(source.as_ref());
887
888        let node = result.node();
889        let call = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
890        let call = call.as_call_node().unwrap();
891        assert!(call.is_variable_call());
892
893        let source = r"
894x&.foo
895";
896        let result = parse(source.as_ref());
897
898        let node = result.node();
899        let call = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
900        let call = call.as_call_node().unwrap();
901        assert!(call.is_safe_navigation());
902    }
903
904    #[test]
905    fn integer_flags_test() {
906        let source = r"
9070b1
908";
909        let result = parse(source.as_ref());
910
911        let node = result.node();
912        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
913        let i = i.as_integer_node().unwrap();
914        assert!(i.is_binary());
915        assert!(!i.is_decimal());
916        assert!(!i.is_octal());
917        assert!(!i.is_hexadecimal());
918
919        let source = r"
9201
921";
922        let result = parse(source.as_ref());
923
924        let node = result.node();
925        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
926        let i = i.as_integer_node().unwrap();
927        assert!(!i.is_binary());
928        assert!(i.is_decimal());
929        assert!(!i.is_octal());
930        assert!(!i.is_hexadecimal());
931
932        let source = r"
9330o1
934";
935        let result = parse(source.as_ref());
936
937        let node = result.node();
938        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
939        let i = i.as_integer_node().unwrap();
940        assert!(!i.is_binary());
941        assert!(!i.is_decimal());
942        assert!(i.is_octal());
943        assert!(!i.is_hexadecimal());
944
945        let source = r"
9460x1
947";
948        let result = parse(source.as_ref());
949
950        let node = result.node();
951        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
952        let i = i.as_integer_node().unwrap();
953        assert!(!i.is_binary());
954        assert!(!i.is_decimal());
955        assert!(!i.is_octal());
956        assert!(i.is_hexadecimal());
957    }
958
959    #[test]
960    fn range_flags_test() {
961        let source = r"
9620..1
963";
964        let result = parse(source.as_ref());
965
966        let node = result.node();
967        let range = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
968        let range = range.as_range_node().unwrap();
969        assert!(!range.is_exclude_end());
970
971        let source = r"
9720...1
973";
974        let result = parse(source.as_ref());
975
976        let node = result.node();
977        let range = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
978        let range = range.as_range_node().unwrap();
979        assert!(range.is_exclude_end());
980    }
981
982    #[allow(clippy::too_many_lines, clippy::cognitive_complexity)]
983    #[test]
984    fn regex_flags_test() {
985        let source = r"
986/a/i
987";
988        let result = parse(source.as_ref());
989
990        let node = result.node();
991        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
992        let regex = regex.as_regular_expression_node().unwrap();
993        assert!(regex.is_ignore_case());
994        assert!(!regex.is_extended());
995        assert!(!regex.is_multi_line());
996        assert!(!regex.is_euc_jp());
997        assert!(!regex.is_ascii_8bit());
998        assert!(!regex.is_windows_31j());
999        assert!(!regex.is_utf_8());
1000        assert!(!regex.is_once());
1001
1002        let source = r"
1003/a/x
1004";
1005        let result = parse(source.as_ref());
1006
1007        let node = result.node();
1008        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1009        let regex = regex.as_regular_expression_node().unwrap();
1010        assert!(!regex.is_ignore_case());
1011        assert!(regex.is_extended());
1012        assert!(!regex.is_multi_line());
1013        assert!(!regex.is_euc_jp());
1014        assert!(!regex.is_ascii_8bit());
1015        assert!(!regex.is_windows_31j());
1016        assert!(!regex.is_utf_8());
1017        assert!(!regex.is_once());
1018
1019        let source = r"
1020/a/m
1021";
1022        let result = parse(source.as_ref());
1023
1024        let node = result.node();
1025        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1026        let regex = regex.as_regular_expression_node().unwrap();
1027        assert!(!regex.is_ignore_case());
1028        assert!(!regex.is_extended());
1029        assert!(regex.is_multi_line());
1030        assert!(!regex.is_euc_jp());
1031        assert!(!regex.is_ascii_8bit());
1032        assert!(!regex.is_windows_31j());
1033        assert!(!regex.is_utf_8());
1034        assert!(!regex.is_once());
1035
1036        let source = r"
1037/a/e
1038";
1039        let result = parse(source.as_ref());
1040
1041        let node = result.node();
1042        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1043        let regex = regex.as_regular_expression_node().unwrap();
1044        assert!(!regex.is_ignore_case());
1045        assert!(!regex.is_extended());
1046        assert!(!regex.is_multi_line());
1047        assert!(regex.is_euc_jp());
1048        assert!(!regex.is_ascii_8bit());
1049        assert!(!regex.is_windows_31j());
1050        assert!(!regex.is_utf_8());
1051        assert!(!regex.is_once());
1052
1053        let source = r"
1054/a/n
1055";
1056        let result = parse(source.as_ref());
1057
1058        let node = result.node();
1059        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1060        let regex = regex.as_regular_expression_node().unwrap();
1061        assert!(!regex.is_ignore_case());
1062        assert!(!regex.is_extended());
1063        assert!(!regex.is_multi_line());
1064        assert!(!regex.is_euc_jp());
1065        assert!(regex.is_ascii_8bit());
1066        assert!(!regex.is_windows_31j());
1067        assert!(!regex.is_utf_8());
1068        assert!(!regex.is_once());
1069
1070        let source = r"
1071/a/s
1072";
1073        let result = parse(source.as_ref());
1074
1075        let node = result.node();
1076        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1077        let regex = regex.as_regular_expression_node().unwrap();
1078        assert!(!regex.is_ignore_case());
1079        assert!(!regex.is_extended());
1080        assert!(!regex.is_multi_line());
1081        assert!(!regex.is_euc_jp());
1082        assert!(!regex.is_ascii_8bit());
1083        assert!(regex.is_windows_31j());
1084        assert!(!regex.is_utf_8());
1085        assert!(!regex.is_once());
1086
1087        let source = r"
1088/a/u
1089";
1090        let result = parse(source.as_ref());
1091
1092        let node = result.node();
1093        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1094        let regex = regex.as_regular_expression_node().unwrap();
1095        assert!(!regex.is_ignore_case());
1096        assert!(!regex.is_extended());
1097        assert!(!regex.is_multi_line());
1098        assert!(!regex.is_euc_jp());
1099        assert!(!regex.is_ascii_8bit());
1100        assert!(!regex.is_windows_31j());
1101        assert!(regex.is_utf_8());
1102        assert!(!regex.is_once());
1103
1104        let source = r"
1105/a/o
1106";
1107        let result = parse(source.as_ref());
1108
1109        let node = result.node();
1110        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1111        let regex = regex.as_regular_expression_node().unwrap();
1112        assert!(!regex.is_ignore_case());
1113        assert!(!regex.is_extended());
1114        assert!(!regex.is_multi_line());
1115        assert!(!regex.is_euc_jp());
1116        assert!(!regex.is_ascii_8bit());
1117        assert!(!regex.is_windows_31j());
1118        assert!(!regex.is_utf_8());
1119        assert!(regex.is_once());
1120    }
1121
1122    #[test]
1123    fn visitor_traversal_test() {
1124        use crate::{Node, Visit};
1125
1126        #[derive(Default)]
1127        struct NodeCounts {
1128            pre_parent: usize,
1129            post_parent: usize,
1130            pre_leaf: usize,
1131            post_leaf: usize,
1132        }
1133
1134        #[derive(Default)]
1135        struct CountingVisitor {
1136            counts: NodeCounts,
1137        }
1138
1139        impl Visit<'_> for CountingVisitor {
1140            fn visit_branch_node_enter(&mut self, _node: Node<'_>) {
1141                self.counts.pre_parent += 1;
1142            }
1143
1144            fn visit_branch_node_leave(&mut self) {
1145                self.counts.post_parent += 1;
1146            }
1147
1148            fn visit_leaf_node_enter(&mut self, _node: Node<'_>) {
1149                self.counts.pre_leaf += 1;
1150            }
1151
1152            fn visit_leaf_node_leave(&mut self) {
1153                self.counts.post_leaf += 1;
1154            }
1155        }
1156
1157        let source = r"
1158module Example
1159  x = call_func(3, 4)
1160  y = x.call_func 5, 6
1161end
1162";
1163        let result = parse(source.as_ref());
1164        let node = result.node();
1165        let mut visitor = CountingVisitor::default();
1166        visitor.visit(&node);
1167
1168        assert_eq!(7, visitor.counts.pre_parent);
1169        assert_eq!(7, visitor.counts.post_parent);
1170        assert_eq!(6, visitor.counts.pre_leaf);
1171        assert_eq!(6, visitor.counts.post_leaf);
1172    }
1173
1174    #[test]
1175    fn visitor_lifetime_test() {
1176        use crate::{Node, Visit};
1177
1178        #[derive(Default)]
1179        struct StackingNodeVisitor<'a> {
1180            stack: Vec<Node<'a>>,
1181            max_depth: usize,
1182        }
1183
1184        impl<'pr> Visit<'pr> for StackingNodeVisitor<'pr> {
1185            fn visit_branch_node_enter(&mut self, node: Node<'pr>) {
1186                self.stack.push(node);
1187            }
1188
1189            fn visit_branch_node_leave(&mut self) {
1190                self.stack.pop();
1191            }
1192
1193            fn visit_leaf_node_leave(&mut self) {
1194                self.max_depth = self.max_depth.max(self.stack.len());
1195            }
1196        }
1197
1198        let source = r"
1199module Example
1200  x = call_func(3, 4)
1201  y = x.call_func 5, 6
1202end
1203";
1204        let result = parse(source.as_ref());
1205        let node = result.node();
1206        let mut visitor = StackingNodeVisitor::default();
1207        visitor.visit(&node);
1208
1209        assert_eq!(0, visitor.stack.len());
1210        assert_eq!(5, visitor.max_depth);
1211    }
1212
1213    #[test]
1214    fn integer_value_test() {
1215        let result = parse("0xA".as_ref());
1216        let integer = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_integer_node().unwrap().value();
1217        let value: i32 = integer.try_into().unwrap();
1218
1219        assert_eq!(value, 10);
1220    }
1221
1222    #[test]
1223    fn integer_small_value_to_u32_digits_test() {
1224        let result = parse("0xA".as_ref());
1225        let integer = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_integer_node().unwrap().value();
1226        let (negative, digits) = integer.to_u32_digits();
1227
1228        assert!(!negative);
1229        assert_eq!(digits, &[10]);
1230    }
1231
1232    #[test]
1233    fn integer_large_value_to_u32_digits_test() {
1234        let result = parse("0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".as_ref());
1235        let integer = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_integer_node().unwrap().value();
1236        let (negative, digits) = integer.to_u32_digits();
1237
1238        assert!(!negative);
1239        assert_eq!(digits, &[4_294_967_295, 4_294_967_295, 4_294_967_295, 2_147_483_647]);
1240    }
1241
1242    #[test]
1243    fn float_value_test() {
1244        let result = parse("1.0".as_ref());
1245        let value: f64 = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_float_node().unwrap().value();
1246
1247        assert!((value - 1.0).abs() < f64::EPSILON);
1248    }
1249
1250    #[test]
1251    fn regex_value_test() {
1252        let result = parse(b"//");
1253        let node = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_regular_expression_node().unwrap();
1254        assert_eq!(node.unescaped(), b"");
1255    }
1256
1257    #[test]
1258    fn node_field_lifetime_test() {
1259        // The code below wouldn't typecheck prior to https://github.com/ruby/prism/pull/2519,
1260        // but we need to stop clippy from complaining about it.
1261        #![allow(clippy::needless_pass_by_value)]
1262
1263        use crate::Node;
1264
1265        #[derive(Default)]
1266        struct Extract<'pr> {
1267            scopes: Vec<crate::ConstantId<'pr>>,
1268        }
1269
1270        impl<'pr> Extract<'pr> {
1271            fn push_scope(&mut self, path: Node<'pr>) {
1272                if let Some(cread) = path.as_constant_read_node() {
1273                    self.scopes.push(cread.name());
1274                } else if let Some(cpath) = path.as_constant_path_node() {
1275                    if let Some(parent) = cpath.parent() {
1276                        self.push_scope(parent);
1277                    }
1278                    self.scopes.push(cpath.name().unwrap());
1279                } else {
1280                    panic!("Wrong node kind!");
1281                }
1282            }
1283        }
1284
1285        let source = "Some::Random::Constant";
1286        let result = parse(source.as_ref());
1287        let node = result.node();
1288        let mut extractor = Extract::default();
1289        extractor.push_scope(node.as_program_node().unwrap().statements().body().iter().next().unwrap());
1290        assert_eq!(3, extractor.scopes.len());
1291    }
1292
1293    #[test]
1294    fn malformed_shebang() {
1295        let source = "#!\x00";
1296        let result = parse(source.as_ref());
1297        assert!(result.errors().next().is_none());
1298        assert!(result.warnings().next().is_none());
1299    }
1300}