ruby_prism/
lib.rs

1//! # ruby-prism
2//!
3//! Rustified version of Ruby's prism parser.
4//!
5#![warn(clippy::all, clippy::nursery, clippy::pedantic, future_incompatible, missing_docs, nonstandard_style, rust_2018_idioms, trivial_casts, trivial_numeric_casts, unreachable_pub, unused_qualifications)]
6
7// Most of the code in this file is generated, so sometimes it generates code
8// that doesn't follow the clippy rules. We don't want to see those warnings.
9#[allow(clippy::too_many_lines, clippy::use_self)]
10mod bindings {
11    // In `build.rs`, we generate bindings based on the config.yml file. Here is
12    // where we pull in those bindings and make them part of our library.
13    include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
14}
15
16use std::ffi::{c_char, CStr};
17use std::marker::PhantomData;
18use std::mem::MaybeUninit;
19use std::ptr::NonNull;
20
21pub use self::bindings::*;
22use ruby_prism_sys::{pm_comment_t, pm_comment_type_t, pm_constant_id_list_t, pm_constant_id_t, pm_diagnostic_t, pm_integer_t, pm_location_t, pm_magic_comment_t, pm_node_destroy, pm_node_list, pm_node_t, pm_parse, pm_parser_free, pm_parser_init, pm_parser_t};
23
24/// A range in the source file.
25pub struct Location<'pr> {
26    parser: NonNull<pm_parser_t>,
27    pub(crate) start: *const u8,
28    pub(crate) end: *const u8,
29    marker: PhantomData<&'pr [u8]>,
30}
31
32impl<'pr> Location<'pr> {
33    /// Returns a byte slice for the range.
34    /// # Panics
35    /// Panics if the end offset is not greater than the start offset.
36    #[must_use]
37    pub fn as_slice(&self) -> &'pr [u8] {
38        unsafe {
39            let len = usize::try_from(self.end.offset_from(self.start)).expect("end should point to memory after start");
40            std::slice::from_raw_parts(self.start, len)
41        }
42    }
43
44    /// Return a Location from the given `pm_location_t`.
45    #[must_use]
46    pub(crate) const fn new(parser: NonNull<pm_parser_t>, loc: &'pr pm_location_t) -> Self {
47        Location {
48            parser,
49            start: loc.start,
50            end: loc.end,
51            marker: PhantomData,
52        }
53    }
54
55    /// Return a Location starting at self and ending at the end of other.
56    /// Returns None if both locations did not originate from the same parser,
57    /// or if self starts after other.
58    #[must_use]
59    pub fn join(&self, other: &Self) -> Option<Self> {
60        if self.parser != other.parser || self.start > other.start {
61            None
62        } else {
63            Some(Location {
64                parser: self.parser,
65                start: self.start,
66                end: other.end,
67                marker: PhantomData,
68            })
69        }
70    }
71
72    /// Return the start offset from the beginning of the parsed source.
73    /// # Panics
74    /// Panics if the start offset is not greater than the parser's start.
75    #[must_use]
76    pub fn start_offset(&self) -> usize {
77        unsafe {
78            let parser_start = (*self.parser.as_ptr()).start;
79            usize::try_from(self.start.offset_from(parser_start)).expect("start should point to memory after the parser's start")
80        }
81    }
82
83    /// Return the end offset from the beginning of the parsed source.
84    /// # Panics
85    /// Panics if the end offset is not greater than the parser's start.
86    #[must_use]
87    pub fn end_offset(&self) -> usize {
88        unsafe {
89            let parser_start = (*self.parser.as_ptr()).start;
90            usize::try_from(self.end.offset_from(parser_start)).expect("end should point to memory after the parser's start")
91        }
92    }
93}
94
95impl std::fmt::Debug for Location<'_> {
96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        let slice: &[u8] = self.as_slice();
98
99        let mut visible = String::new();
100        visible.push('"');
101
102        for &byte in slice {
103            let part: Vec<u8> = std::ascii::escape_default(byte).collect();
104            visible.push_str(std::str::from_utf8(&part).unwrap());
105        }
106
107        visible.push('"');
108        write!(f, "{visible}")
109    }
110}
111
112/// An iterator over the nodes in a list.
113pub struct NodeListIter<'pr> {
114    parser: NonNull<pm_parser_t>,
115    pointer: NonNull<pm_node_list>,
116    index: usize,
117    marker: PhantomData<&'pr mut pm_node_list>,
118}
119
120impl<'pr> Iterator for NodeListIter<'pr> {
121    type Item = Node<'pr>;
122
123    fn next(&mut self) -> Option<Self::Item> {
124        if self.index >= unsafe { self.pointer.as_ref().size } {
125            None
126        } else {
127            let node: *mut pm_node_t = unsafe { *(self.pointer.as_ref().nodes.add(self.index)) };
128            self.index += 1;
129            Some(Node::new(self.parser, node))
130        }
131    }
132}
133
134/// A list of nodes.
135pub struct NodeList<'pr> {
136    parser: NonNull<pm_parser_t>,
137    pointer: NonNull<pm_node_list>,
138    marker: PhantomData<&'pr mut pm_node_list>,
139}
140
141impl<'pr> NodeList<'pr> {
142    /// Returns an iterator over the nodes.
143    #[must_use]
144    pub const fn iter(&self) -> NodeListIter<'pr> {
145        NodeListIter {
146            parser: self.parser,
147            pointer: self.pointer,
148            index: 0,
149            marker: PhantomData,
150        }
151    }
152
153    /// Returns the length of the list.
154    #[must_use]
155    pub const fn len(&self) -> usize {
156        unsafe { self.pointer.as_ref().size }
157    }
158
159    /// Returns whether the list is empty.
160    #[must_use]
161    pub const fn is_empty(&self) -> bool {
162        self.len() == 0
163    }
164}
165
166impl<'pr> IntoIterator for &NodeList<'pr> {
167    type Item = Node<'pr>;
168    type IntoIter = NodeListIter<'pr>;
169    fn into_iter(self) -> Self::IntoIter {
170        self.iter()
171    }
172}
173
174impl std::fmt::Debug for NodeList<'_> {
175    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
176        write!(f, "{:?}", self.iter().collect::<Vec<_>>())
177    }
178}
179
180/// A handle for a constant ID.
181pub struct ConstantId<'pr> {
182    parser: NonNull<pm_parser_t>,
183    id: pm_constant_id_t,
184    marker: PhantomData<&'pr mut pm_constant_id_t>,
185}
186
187impl<'pr> ConstantId<'pr> {
188    const fn new(parser: NonNull<pm_parser_t>, id: pm_constant_id_t) -> Self {
189        ConstantId { parser, id, marker: PhantomData }
190    }
191
192    /// Returns a byte slice for the constant ID.
193    ///
194    /// # Panics
195    ///
196    /// Panics if the constant ID is not found in the constant pool.
197    #[must_use]
198    pub fn as_slice(&self) -> &'pr [u8] {
199        unsafe {
200            let pool = &(*self.parser.as_ptr()).constant_pool;
201            let constant = &(*pool.constants.add((self.id - 1).try_into().unwrap()));
202            std::slice::from_raw_parts(constant.start, constant.length)
203        }
204    }
205}
206
207impl std::fmt::Debug for ConstantId<'_> {
208    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209        write!(f, "{:?}", self.id)
210    }
211}
212
213/// An iterator over the constants in a list.
214pub struct ConstantListIter<'pr> {
215    parser: NonNull<pm_parser_t>,
216    pointer: NonNull<pm_constant_id_list_t>,
217    index: usize,
218    marker: PhantomData<&'pr mut pm_constant_id_list_t>,
219}
220
221impl<'pr> Iterator for ConstantListIter<'pr> {
222    type Item = ConstantId<'pr>;
223
224    fn next(&mut self) -> Option<Self::Item> {
225        if self.index >= unsafe { self.pointer.as_ref().size } {
226            None
227        } else {
228            let constant_id: pm_constant_id_t = unsafe { *(self.pointer.as_ref().ids.add(self.index)) };
229            self.index += 1;
230            Some(ConstantId::new(self.parser, constant_id))
231        }
232    }
233}
234
235/// A list of constants.
236pub struct ConstantList<'pr> {
237    /// The raw pointer to the parser where this list came from.
238    parser: NonNull<pm_parser_t>,
239
240    /// The raw pointer to the list allocated by prism.
241    pointer: NonNull<pm_constant_id_list_t>,
242
243    /// The marker to indicate the lifetime of the pointer.
244    marker: PhantomData<&'pr mut pm_constant_id_list_t>,
245}
246
247impl<'pr> ConstantList<'pr> {
248    /// Returns an iterator over the constants in the list.
249    #[must_use]
250    pub const fn iter(&self) -> ConstantListIter<'pr> {
251        ConstantListIter {
252            parser: self.parser,
253            pointer: self.pointer,
254            index: 0,
255            marker: PhantomData,
256        }
257    }
258}
259
260impl<'pr> IntoIterator for &ConstantList<'pr> {
261    type Item = ConstantId<'pr>;
262    type IntoIter = ConstantListIter<'pr>;
263    fn into_iter(self) -> Self::IntoIter {
264        self.iter()
265    }
266}
267
268impl std::fmt::Debug for ConstantList<'_> {
269    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270        write!(f, "{:?}", self.iter().collect::<Vec<_>>())
271    }
272}
273
274/// A handle for an arbitarily-sized integer.
275pub struct Integer<'pr> {
276    /// The raw pointer to the integer allocated by prism.
277    pointer: *const pm_integer_t,
278
279    /// The marker to indicate the lifetime of the pointer.
280    marker: PhantomData<&'pr mut pm_constant_id_t>,
281}
282
283impl Integer<'_> {
284    const fn new(pointer: *const pm_integer_t) -> Self {
285        Integer { pointer, marker: PhantomData }
286    }
287
288    /// Returns the sign and the u32 digits representation of the integer,
289    /// ordered least significant digit first.
290    #[must_use]
291    pub const fn to_u32_digits(&self) -> (bool, &[u32]) {
292        let negative = unsafe { (*self.pointer).negative };
293        let length = unsafe { (*self.pointer).length };
294        let values = unsafe { (*self.pointer).values };
295
296        if values.is_null() {
297            let value_ptr = unsafe { std::ptr::addr_of!((*self.pointer).value) };
298            let slice = unsafe { std::slice::from_raw_parts(value_ptr, 1) };
299            (negative, slice)
300        } else {
301            let slice = unsafe { std::slice::from_raw_parts(values, length) };
302            (negative, slice)
303        }
304    }
305}
306
307impl std::fmt::Debug for Integer<'_> {
308    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309        write!(f, "{:?}", self.pointer)
310    }
311}
312
313impl TryInto<i32> for Integer<'_> {
314    type Error = ();
315
316    fn try_into(self) -> Result<i32, Self::Error> {
317        let negative = unsafe { (*self.pointer).negative };
318        let length = unsafe { (*self.pointer).length };
319
320        if length == 0 {
321            i32::try_from(unsafe { (*self.pointer).value }).map_or(Err(()), |value| if negative { Ok(-value) } else { Ok(value) })
322        } else {
323            Err(())
324        }
325    }
326}
327
328/// A diagnostic message that came back from the parser.
329#[derive(Debug)]
330pub struct Diagnostic<'pr> {
331    diag: NonNull<pm_diagnostic_t>,
332    parser: NonNull<pm_parser_t>,
333    marker: PhantomData<&'pr pm_diagnostic_t>,
334}
335
336impl<'pr> Diagnostic<'pr> {
337    /// Returns the message associated with the diagnostic.
338    ///
339    /// # Panics
340    ///
341    /// Panics if the message is not able to be converted into a `CStr`.
342    ///
343    #[must_use]
344    pub fn message(&self) -> &str {
345        unsafe {
346            let message: *mut c_char = self.diag.as_ref().message.cast_mut();
347            CStr::from_ptr(message).to_str().expect("prism allows only UTF-8 for diagnostics.")
348        }
349    }
350
351    /// The location of the diagnostic in the source.
352    #[must_use]
353    pub const fn location(&self) -> Location<'pr> {
354        Location::new(self.parser, unsafe { &self.diag.as_ref().location })
355    }
356}
357
358/// A comment that was found during parsing.
359#[derive(Debug)]
360pub struct Comment<'pr> {
361    content: NonNull<pm_comment_t>,
362    parser: NonNull<pm_parser_t>,
363    marker: PhantomData<&'pr pm_comment_t>,
364}
365
366/// The type of the comment
367#[derive(Debug, Clone, Copy, PartialEq, Eq)]
368pub enum CommentType {
369    /// `InlineComment` corresponds to comments that start with #.
370    InlineComment,
371    /// `EmbDocComment` corresponds to comments that are surrounded by =begin and =end.
372    EmbDocComment,
373}
374
375impl<'pr> Comment<'pr> {
376    /// Returns the text of the comment.
377    ///
378    /// # Panics
379    /// Panics if the end offset is not greater than the start offset.
380    #[must_use]
381    pub fn text(&self) -> &[u8] {
382        self.location().as_slice()
383    }
384
385    /// Returns the type of the comment.
386    #[must_use]
387    pub fn type_(&self) -> CommentType {
388        let type_ = unsafe { self.content.as_ref().type_ };
389        if type_ == pm_comment_type_t::PM_COMMENT_EMBDOC {
390            CommentType::EmbDocComment
391        } else {
392            CommentType::InlineComment
393        }
394    }
395
396    /// The location of the comment in the source.
397    #[must_use]
398    pub const fn location(&self) -> Location<'pr> {
399        Location::new(self.parser, unsafe { &self.content.as_ref().location })
400    }
401}
402
403/// A magic comment that was found during parsing.
404#[derive(Debug)]
405pub struct MagicComment<'pr> {
406    comment: NonNull<pm_magic_comment_t>,
407    marker: PhantomData<&'pr pm_magic_comment_t>,
408}
409
410impl MagicComment<'_> {
411    /// Returns the text of the comment's key.
412    #[must_use]
413    pub const fn key(&self) -> &[u8] {
414        unsafe {
415            let start = self.comment.as_ref().key_start;
416            let len = self.comment.as_ref().key_length as usize;
417            std::slice::from_raw_parts(start, len)
418        }
419    }
420
421    /// Returns the text of the comment's value.
422    #[must_use]
423    pub const fn value(&self) -> &[u8] {
424        unsafe {
425            let start = self.comment.as_ref().value_start;
426            let len = self.comment.as_ref().value_length as usize;
427            std::slice::from_raw_parts(start, len)
428        }
429    }
430}
431
432/// A struct created by the `errors` or `warnings` methods on `ParseResult`. It
433/// can be used to iterate over the diagnostics in the parse result.
434pub struct Diagnostics<'pr> {
435    diagnostic: *mut pm_diagnostic_t,
436    parser: NonNull<pm_parser_t>,
437    marker: PhantomData<&'pr pm_diagnostic_t>,
438}
439
440impl<'pr> Iterator for Diagnostics<'pr> {
441    type Item = Diagnostic<'pr>;
442
443    fn next(&mut self) -> Option<Self::Item> {
444        if let Some(diagnostic) = NonNull::new(self.diagnostic) {
445            let current = Diagnostic {
446                diag: diagnostic,
447                parser: self.parser,
448                marker: PhantomData,
449            };
450            self.diagnostic = unsafe { diagnostic.as_ref().node.next.cast::<pm_diagnostic_t>() };
451            Some(current)
452        } else {
453            None
454        }
455    }
456}
457
458/// A struct created by the `comments` method on `ParseResult`. It can be used
459/// to iterate over the comments in the parse result.
460pub struct Comments<'pr> {
461    comment: *mut pm_comment_t,
462    parser: NonNull<pm_parser_t>,
463    marker: PhantomData<&'pr pm_comment_t>,
464}
465
466impl<'pr> Iterator for Comments<'pr> {
467    type Item = Comment<'pr>;
468
469    fn next(&mut self) -> Option<Self::Item> {
470        if let Some(comment) = NonNull::new(self.comment) {
471            let current = Comment {
472                content: comment,
473                parser: self.parser,
474                marker: PhantomData,
475            };
476            self.comment = unsafe { comment.as_ref().node.next.cast::<pm_comment_t>() };
477            Some(current)
478        } else {
479            None
480        }
481    }
482}
483
484/// A struct created by the `magic_comments` method on `ParseResult`. It can be used
485/// to iterate over the magic comments in the parse result.
486pub struct MagicComments<'pr> {
487    comment: *mut pm_magic_comment_t,
488    marker: PhantomData<&'pr pm_magic_comment_t>,
489}
490
491impl<'pr> Iterator for MagicComments<'pr> {
492    type Item = MagicComment<'pr>;
493
494    fn next(&mut self) -> Option<Self::Item> {
495        if let Some(comment) = NonNull::new(self.comment) {
496            let current = MagicComment { comment, marker: PhantomData };
497            self.comment = unsafe { comment.as_ref().node.next.cast::<pm_magic_comment_t>() };
498            Some(current)
499        } else {
500            None
501        }
502    }
503}
504
505/// The result of parsing a source string.
506#[derive(Debug)]
507pub struct ParseResult<'pr> {
508    source: &'pr [u8],
509    parser: NonNull<pm_parser_t>,
510    node: NonNull<pm_node_t>,
511}
512
513impl<'pr> ParseResult<'pr> {
514    /// Returns the source string that was parsed.
515    #[must_use]
516    pub const fn source(&self) -> &'pr [u8] {
517        self.source
518    }
519
520    /// Returns whether we found a `frozen_string_literal` magic comment with a true value.
521    #[must_use]
522    pub fn frozen_string_literals(&self) -> bool {
523        unsafe { (*self.parser.as_ptr()).frozen_string_literal == 1 }
524    }
525
526    /// Returns a slice of the source string that was parsed using the given
527    /// location range.
528    ///
529    /// # Panics
530    /// Panics if start offset or end offset are not valid offsets from the root.
531    #[must_use]
532    pub fn as_slice(&self, location: &Location<'pr>) -> &'pr [u8] {
533        let root = self.source.as_ptr();
534
535        let start = usize::try_from(unsafe { location.start.offset_from(root) }).expect("start should point to memory after root");
536        let end = usize::try_from(unsafe { location.end.offset_from(root) }).expect("end should point to memory after root");
537
538        &self.source[start..end]
539    }
540
541    /// Returns an iterator that can be used to iterate over the errors in the
542    /// parse result.
543    #[must_use]
544    pub fn errors(&self) -> Diagnostics<'_> {
545        unsafe {
546            let list = &mut (*self.parser.as_ptr()).error_list;
547            Diagnostics {
548                diagnostic: list.head.cast::<pm_diagnostic_t>(),
549                parser: self.parser,
550                marker: PhantomData,
551            }
552        }
553    }
554
555    /// Returns an iterator that can be used to iterate over the warnings in the
556    /// parse result.
557    #[must_use]
558    pub fn warnings(&self) -> Diagnostics<'_> {
559        unsafe {
560            let list = &mut (*self.parser.as_ptr()).warning_list;
561            Diagnostics {
562                diagnostic: list.head.cast::<pm_diagnostic_t>(),
563                parser: self.parser,
564                marker: PhantomData,
565            }
566        }
567    }
568
569    /// Returns an iterator that can be used to iterate over the comments in the
570    /// parse result.
571    #[must_use]
572    pub fn comments(&self) -> Comments<'_> {
573        unsafe {
574            let list = &mut (*self.parser.as_ptr()).comment_list;
575            Comments {
576                comment: list.head.cast::<pm_comment_t>(),
577                parser: self.parser,
578                marker: PhantomData,
579            }
580        }
581    }
582
583    /// Returns an iterator that can be used to iterate over the magic comments in the
584    /// parse result.
585    #[must_use]
586    pub fn magic_comments(&self) -> MagicComments<'_> {
587        unsafe {
588            let list = &mut (*self.parser.as_ptr()).magic_comment_list;
589            MagicComments {
590                comment: list.head.cast::<pm_magic_comment_t>(),
591                marker: PhantomData,
592            }
593        }
594    }
595
596    /// Returns an optional location of the __END__ marker and the rest of the content of the file.
597    #[must_use]
598    pub fn data_loc(&self) -> Option<Location<'_>> {
599        let location = unsafe { &(*self.parser.as_ptr()).data_loc };
600        if location.start.is_null() {
601            None
602        } else {
603            Some(Location::new(self.parser, location))
604        }
605    }
606
607    /// Returns the root node of the parse result.
608    #[must_use]
609    pub fn node(&self) -> Node<'_> {
610        Node::new(self.parser, self.node.as_ptr())
611    }
612}
613
614impl Drop for ParseResult<'_> {
615    fn drop(&mut self) {
616        unsafe {
617            pm_node_destroy(self.parser.as_ptr(), self.node.as_ptr());
618            pm_parser_free(self.parser.as_ptr());
619            drop(Box::from_raw(self.parser.as_ptr()));
620        }
621    }
622}
623
624/// Parses the given source string and returns a parse result.
625///
626/// # Panics
627///
628/// Panics if the parser fails to initialize.
629///
630#[must_use]
631pub fn parse(source: &[u8]) -> ParseResult<'_> {
632    unsafe {
633        let uninit = Box::new(MaybeUninit::<pm_parser_t>::uninit());
634        let uninit = Box::into_raw(uninit);
635
636        pm_parser_init((*uninit).as_mut_ptr(), source.as_ptr(), source.len(), std::ptr::null());
637
638        let parser = (*uninit).assume_init_mut();
639        let parser = NonNull::new_unchecked(parser);
640
641        let node = pm_parse(parser.as_ptr());
642        let node = NonNull::new_unchecked(node);
643
644        ParseResult { source, parser, node }
645    }
646}
647
648#[cfg(test)]
649mod tests {
650    use super::parse;
651
652    #[test]
653    fn comments_test() {
654        let source = "# comment 1\n# comment 2\n# comment 3\n";
655        let result = parse(source.as_ref());
656
657        for comment in result.comments() {
658            assert_eq!(super::CommentType::InlineComment, comment.type_());
659            let text = std::str::from_utf8(comment.text()).unwrap();
660            assert!(text.starts_with("# comment"));
661        }
662    }
663
664    #[test]
665    fn magic_comments_test() {
666        use crate::MagicComment;
667
668        let source = "# typed: ignore\n# typed:true\n#typed: strict\n";
669        let result = parse(source.as_ref());
670
671        let comments: Vec<MagicComment<'_>> = result.magic_comments().collect();
672        assert_eq!(3, comments.len());
673
674        assert_eq!(b"typed", comments[0].key());
675        assert_eq!(b"ignore", comments[0].value());
676
677        assert_eq!(b"typed", comments[1].key());
678        assert_eq!(b"true", comments[1].value());
679
680        assert_eq!(b"typed", comments[2].key());
681        assert_eq!(b"strict", comments[2].value());
682    }
683
684    #[test]
685    fn data_loc_test() {
686        let source = "1";
687        let result = parse(source.as_ref());
688        let data_loc = result.data_loc();
689        assert!(data_loc.is_none());
690
691        let source = "__END__\nabc\n";
692        let result = parse(source.as_ref());
693        let data_loc = result.data_loc().unwrap();
694        let slice = std::str::from_utf8(result.as_slice(&data_loc)).unwrap();
695        assert_eq!(slice, "__END__\nabc\n");
696
697        let source = "1\n2\n3\n__END__\nabc\ndef\n";
698        let result = parse(source.as_ref());
699        let data_loc = result.data_loc().unwrap();
700        let slice = std::str::from_utf8(result.as_slice(&data_loc)).unwrap();
701        assert_eq!(slice, "__END__\nabc\ndef\n");
702    }
703
704    #[test]
705    fn location_test() {
706        let source = "111 + 222 + 333";
707        let result = parse(source.as_ref());
708
709        let node = result.node();
710        let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
711        let node = node.as_call_node().unwrap().receiver().unwrap();
712        let plus = node.as_call_node().unwrap();
713        let node = plus.arguments().unwrap().arguments().iter().next().unwrap();
714
715        let location = node.as_integer_node().unwrap().location();
716        let slice = std::str::from_utf8(result.as_slice(&location)).unwrap();
717
718        assert_eq!(slice, "222");
719        assert_eq!(6, location.start_offset());
720        assert_eq!(9, location.end_offset());
721
722        let recv_loc = plus.receiver().unwrap().location();
723        assert_eq!(recv_loc.as_slice(), b"111");
724        assert_eq!(0, recv_loc.start_offset());
725        assert_eq!(3, recv_loc.end_offset());
726
727        let joined = recv_loc.join(&location).unwrap();
728        assert_eq!(joined.as_slice(), b"111 + 222");
729
730        let not_joined = location.join(&recv_loc);
731        assert!(not_joined.is_none());
732
733        {
734            let result = parse(source.as_ref());
735            let node = result.node();
736            let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
737            let node = node.as_call_node().unwrap().receiver().unwrap();
738            let plus = node.as_call_node().unwrap();
739            let node = plus.arguments().unwrap().arguments().iter().next().unwrap();
740
741            let location = node.as_integer_node().unwrap().location();
742            let not_joined = recv_loc.join(&location);
743            assert!(not_joined.is_none());
744
745            let not_joined = location.join(&recv_loc);
746            assert!(not_joined.is_none());
747        }
748
749        let location = node.location();
750        let slice = std::str::from_utf8(result.as_slice(&location)).unwrap();
751
752        assert_eq!(slice, "222");
753
754        let slice = std::str::from_utf8(location.as_slice()).unwrap();
755
756        assert_eq!(slice, "222");
757    }
758
759    #[test]
760    fn visitor_test() {
761        use super::{visit_interpolated_regular_expression_node, visit_regular_expression_node, InterpolatedRegularExpressionNode, RegularExpressionNode, Visit};
762
763        struct RegularExpressionVisitor {
764            count: usize,
765        }
766
767        impl Visit<'_> for RegularExpressionVisitor {
768            fn visit_interpolated_regular_expression_node(&mut self, node: &InterpolatedRegularExpressionNode<'_>) {
769                self.count += 1;
770                visit_interpolated_regular_expression_node(self, node);
771            }
772
773            fn visit_regular_expression_node(&mut self, node: &RegularExpressionNode<'_>) {
774                self.count += 1;
775                visit_regular_expression_node(self, node);
776            }
777        }
778
779        let source = "# comment 1\n# comment 2\nmodule Foo; class Bar; /abc #{/def/}/; end; end";
780        let result = parse(source.as_ref());
781
782        let mut visitor = RegularExpressionVisitor { count: 0 };
783        visitor.visit(&result.node());
784
785        assert_eq!(visitor.count, 2);
786    }
787
788    #[test]
789    fn node_upcast_test() {
790        use super::Node;
791
792        let source = "module Foo; end";
793        let result = parse(source.as_ref());
794
795        let node = result.node();
796        let upcast_node = node.as_program_node().unwrap().as_node();
797        assert!(matches!(upcast_node, Node::ProgramNode { .. }));
798
799        let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
800        let upcast_node = node.as_module_node().unwrap().as_node();
801        assert!(matches!(upcast_node, Node::ModuleNode { .. }));
802    }
803
804    #[test]
805    fn constant_id_test() {
806        let source = "module Foo; x = 1; end";
807        let result = parse(source.as_ref());
808
809        let node = result.node();
810        assert_eq!(node.as_program_node().unwrap().statements().body().len(), 1);
811        assert!(!node.as_program_node().unwrap().statements().body().is_empty());
812        let module = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
813        let module = module.as_module_node().unwrap();
814        let locals = module.locals().iter().collect::<Vec<_>>();
815
816        assert_eq!(locals.len(), 1);
817
818        assert_eq!(locals[0].as_slice(), b"x");
819    }
820
821    #[test]
822    fn optional_loc_test() {
823        let source = r"
824module Example
825  x = call_func(3, 4)
826  y = x.call_func 5, 6
827end
828";
829        let result = parse(source.as_ref());
830
831        let node = result.node();
832        let module = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
833        let module = module.as_module_node().unwrap();
834        let body = module.body();
835        let writes = body.iter().next().unwrap().as_statements_node().unwrap().body().iter().collect::<Vec<_>>();
836        assert_eq!(writes.len(), 2);
837
838        let asgn = &writes[0];
839        let call = asgn.as_local_variable_write_node().unwrap().value();
840        let call = call.as_call_node().unwrap();
841
842        let call_operator_loc = call.call_operator_loc();
843        assert!(call_operator_loc.is_none());
844        let closing_loc = call.closing_loc();
845        assert!(closing_loc.is_some());
846
847        let asgn = &writes[1];
848        let call = asgn.as_local_variable_write_node().unwrap().value();
849        let call = call.as_call_node().unwrap();
850
851        let call_operator_loc = call.call_operator_loc();
852        assert!(call_operator_loc.is_some());
853        let closing_loc = call.closing_loc();
854        assert!(closing_loc.is_none());
855    }
856
857    #[test]
858    fn frozen_strings_test() {
859        let source = r#"
860# frozen_string_literal: true
861"foo"
862"#;
863        let result = parse(source.as_ref());
864        assert!(result.frozen_string_literals());
865
866        let source = "3";
867        let result = parse(source.as_ref());
868        assert!(!result.frozen_string_literals());
869    }
870
871    #[test]
872    fn string_flags_test() {
873        let source = r#"
874# frozen_string_literal: true
875"foo"
876"#;
877        let result = parse(source.as_ref());
878
879        let node = result.node();
880        let string = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
881        let string = string.as_string_node().unwrap();
882        assert!(string.is_frozen());
883
884        let source = r#"
885"foo"
886"#;
887        let result = parse(source.as_ref());
888
889        let node = result.node();
890        let string = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
891        let string = string.as_string_node().unwrap();
892        assert!(!string.is_frozen());
893    }
894
895    #[test]
896    fn call_flags_test() {
897        let source = r"
898x
899";
900        let result = parse(source.as_ref());
901
902        let node = result.node();
903        let call = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
904        let call = call.as_call_node().unwrap();
905        assert!(call.is_variable_call());
906
907        let source = r"
908x&.foo
909";
910        let result = parse(source.as_ref());
911
912        let node = result.node();
913        let call = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
914        let call = call.as_call_node().unwrap();
915        assert!(call.is_safe_navigation());
916    }
917
918    #[test]
919    fn integer_flags_test() {
920        let source = r"
9210b1
922";
923        let result = parse(source.as_ref());
924
925        let node = result.node();
926        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
927        let i = i.as_integer_node().unwrap();
928        assert!(i.is_binary());
929        assert!(!i.is_decimal());
930        assert!(!i.is_octal());
931        assert!(!i.is_hexadecimal());
932
933        let source = r"
9341
935";
936        let result = parse(source.as_ref());
937
938        let node = result.node();
939        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
940        let i = i.as_integer_node().unwrap();
941        assert!(!i.is_binary());
942        assert!(i.is_decimal());
943        assert!(!i.is_octal());
944        assert!(!i.is_hexadecimal());
945
946        let source = r"
9470o1
948";
949        let result = parse(source.as_ref());
950
951        let node = result.node();
952        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
953        let i = i.as_integer_node().unwrap();
954        assert!(!i.is_binary());
955        assert!(!i.is_decimal());
956        assert!(i.is_octal());
957        assert!(!i.is_hexadecimal());
958
959        let source = r"
9600x1
961";
962        let result = parse(source.as_ref());
963
964        let node = result.node();
965        let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
966        let i = i.as_integer_node().unwrap();
967        assert!(!i.is_binary());
968        assert!(!i.is_decimal());
969        assert!(!i.is_octal());
970        assert!(i.is_hexadecimal());
971    }
972
973    #[test]
974    fn range_flags_test() {
975        let source = r"
9760..1
977";
978        let result = parse(source.as_ref());
979
980        let node = result.node();
981        let range = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
982        let range = range.as_range_node().unwrap();
983        assert!(!range.is_exclude_end());
984
985        let source = r"
9860...1
987";
988        let result = parse(source.as_ref());
989
990        let node = result.node();
991        let range = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
992        let range = range.as_range_node().unwrap();
993        assert!(range.is_exclude_end());
994    }
995
996    #[allow(clippy::too_many_lines, clippy::cognitive_complexity)]
997    #[test]
998    fn regex_flags_test() {
999        let source = r"
1000/a/i
1001";
1002        let result = parse(source.as_ref());
1003
1004        let node = result.node();
1005        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1006        let regex = regex.as_regular_expression_node().unwrap();
1007        assert!(regex.is_ignore_case());
1008        assert!(!regex.is_extended());
1009        assert!(!regex.is_multi_line());
1010        assert!(!regex.is_euc_jp());
1011        assert!(!regex.is_ascii_8bit());
1012        assert!(!regex.is_windows_31j());
1013        assert!(!regex.is_utf_8());
1014        assert!(!regex.is_once());
1015
1016        let source = r"
1017/a/x
1018";
1019        let result = parse(source.as_ref());
1020
1021        let node = result.node();
1022        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1023        let regex = regex.as_regular_expression_node().unwrap();
1024        assert!(!regex.is_ignore_case());
1025        assert!(regex.is_extended());
1026        assert!(!regex.is_multi_line());
1027        assert!(!regex.is_euc_jp());
1028        assert!(!regex.is_ascii_8bit());
1029        assert!(!regex.is_windows_31j());
1030        assert!(!regex.is_utf_8());
1031        assert!(!regex.is_once());
1032
1033        let source = r"
1034/a/m
1035";
1036        let result = parse(source.as_ref());
1037
1038        let node = result.node();
1039        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1040        let regex = regex.as_regular_expression_node().unwrap();
1041        assert!(!regex.is_ignore_case());
1042        assert!(!regex.is_extended());
1043        assert!(regex.is_multi_line());
1044        assert!(!regex.is_euc_jp());
1045        assert!(!regex.is_ascii_8bit());
1046        assert!(!regex.is_windows_31j());
1047        assert!(!regex.is_utf_8());
1048        assert!(!regex.is_once());
1049
1050        let source = r"
1051/a/e
1052";
1053        let result = parse(source.as_ref());
1054
1055        let node = result.node();
1056        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1057        let regex = regex.as_regular_expression_node().unwrap();
1058        assert!(!regex.is_ignore_case());
1059        assert!(!regex.is_extended());
1060        assert!(!regex.is_multi_line());
1061        assert!(regex.is_euc_jp());
1062        assert!(!regex.is_ascii_8bit());
1063        assert!(!regex.is_windows_31j());
1064        assert!(!regex.is_utf_8());
1065        assert!(!regex.is_once());
1066
1067        let source = r"
1068/a/n
1069";
1070        let result = parse(source.as_ref());
1071
1072        let node = result.node();
1073        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1074        let regex = regex.as_regular_expression_node().unwrap();
1075        assert!(!regex.is_ignore_case());
1076        assert!(!regex.is_extended());
1077        assert!(!regex.is_multi_line());
1078        assert!(!regex.is_euc_jp());
1079        assert!(regex.is_ascii_8bit());
1080        assert!(!regex.is_windows_31j());
1081        assert!(!regex.is_utf_8());
1082        assert!(!regex.is_once());
1083
1084        let source = r"
1085/a/s
1086";
1087        let result = parse(source.as_ref());
1088
1089        let node = result.node();
1090        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1091        let regex = regex.as_regular_expression_node().unwrap();
1092        assert!(!regex.is_ignore_case());
1093        assert!(!regex.is_extended());
1094        assert!(!regex.is_multi_line());
1095        assert!(!regex.is_euc_jp());
1096        assert!(!regex.is_ascii_8bit());
1097        assert!(regex.is_windows_31j());
1098        assert!(!regex.is_utf_8());
1099        assert!(!regex.is_once());
1100
1101        let source = r"
1102/a/u
1103";
1104        let result = parse(source.as_ref());
1105
1106        let node = result.node();
1107        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1108        let regex = regex.as_regular_expression_node().unwrap();
1109        assert!(!regex.is_ignore_case());
1110        assert!(!regex.is_extended());
1111        assert!(!regex.is_multi_line());
1112        assert!(!regex.is_euc_jp());
1113        assert!(!regex.is_ascii_8bit());
1114        assert!(!regex.is_windows_31j());
1115        assert!(regex.is_utf_8());
1116        assert!(!regex.is_once());
1117
1118        let source = r"
1119/a/o
1120";
1121        let result = parse(source.as_ref());
1122
1123        let node = result.node();
1124        let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1125        let regex = regex.as_regular_expression_node().unwrap();
1126        assert!(!regex.is_ignore_case());
1127        assert!(!regex.is_extended());
1128        assert!(!regex.is_multi_line());
1129        assert!(!regex.is_euc_jp());
1130        assert!(!regex.is_ascii_8bit());
1131        assert!(!regex.is_windows_31j());
1132        assert!(!regex.is_utf_8());
1133        assert!(regex.is_once());
1134    }
1135
1136    #[test]
1137    fn visitor_traversal_test() {
1138        use crate::{Node, Visit};
1139
1140        #[derive(Default)]
1141        struct NodeCounts {
1142            pre_parent: usize,
1143            post_parent: usize,
1144            pre_leaf: usize,
1145            post_leaf: usize,
1146        }
1147
1148        #[derive(Default)]
1149        struct CountingVisitor {
1150            counts: NodeCounts,
1151        }
1152
1153        impl Visit<'_> for CountingVisitor {
1154            fn visit_branch_node_enter(&mut self, _node: Node<'_>) {
1155                self.counts.pre_parent += 1;
1156            }
1157
1158            fn visit_branch_node_leave(&mut self) {
1159                self.counts.post_parent += 1;
1160            }
1161
1162            fn visit_leaf_node_enter(&mut self, _node: Node<'_>) {
1163                self.counts.pre_leaf += 1;
1164            }
1165
1166            fn visit_leaf_node_leave(&mut self) {
1167                self.counts.post_leaf += 1;
1168            }
1169        }
1170
1171        let source = r"
1172module Example
1173  x = call_func(3, 4)
1174  y = x.call_func 5, 6
1175end
1176";
1177        let result = parse(source.as_ref());
1178        let node = result.node();
1179        let mut visitor = CountingVisitor::default();
1180        visitor.visit(&node);
1181
1182        assert_eq!(7, visitor.counts.pre_parent);
1183        assert_eq!(7, visitor.counts.post_parent);
1184        assert_eq!(6, visitor.counts.pre_leaf);
1185        assert_eq!(6, visitor.counts.post_leaf);
1186    }
1187
1188    #[test]
1189    fn visitor_lifetime_test() {
1190        use crate::{Node, Visit};
1191
1192        #[derive(Default)]
1193        struct StackingNodeVisitor<'a> {
1194            stack: Vec<Node<'a>>,
1195            max_depth: usize,
1196        }
1197
1198        impl<'pr> Visit<'pr> for StackingNodeVisitor<'pr> {
1199            fn visit_branch_node_enter(&mut self, node: Node<'pr>) {
1200                self.stack.push(node);
1201            }
1202
1203            fn visit_branch_node_leave(&mut self) {
1204                self.stack.pop();
1205            }
1206
1207            fn visit_leaf_node_leave(&mut self) {
1208                self.max_depth = self.max_depth.max(self.stack.len());
1209            }
1210        }
1211
1212        let source = r"
1213module Example
1214  x = call_func(3, 4)
1215  y = x.call_func 5, 6
1216end
1217";
1218        let result = parse(source.as_ref());
1219        let node = result.node();
1220        let mut visitor = StackingNodeVisitor::default();
1221        visitor.visit(&node);
1222
1223        assert_eq!(0, visitor.stack.len());
1224        assert_eq!(5, visitor.max_depth);
1225    }
1226
1227    #[test]
1228    fn integer_value_test() {
1229        let result = parse("0xA".as_ref());
1230        let integer = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_integer_node().unwrap().value();
1231        let value: i32 = integer.try_into().unwrap();
1232
1233        assert_eq!(value, 10);
1234    }
1235
1236    #[test]
1237    fn integer_small_value_to_u32_digits_test() {
1238        let result = parse("0xA".as_ref());
1239        let integer = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_integer_node().unwrap().value();
1240        let (negative, digits) = integer.to_u32_digits();
1241
1242        assert!(!negative);
1243        assert_eq!(digits, &[10]);
1244    }
1245
1246    #[test]
1247    fn integer_large_value_to_u32_digits_test() {
1248        let result = parse("0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".as_ref());
1249        let integer = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_integer_node().unwrap().value();
1250        let (negative, digits) = integer.to_u32_digits();
1251
1252        assert!(!negative);
1253        assert_eq!(digits, &[4_294_967_295, 4_294_967_295, 4_294_967_295, 2_147_483_647]);
1254    }
1255
1256    #[test]
1257    fn float_value_test() {
1258        let result = parse("1.0".as_ref());
1259        let value: f64 = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_float_node().unwrap().value();
1260
1261        assert!((value - 1.0).abs() < f64::EPSILON);
1262    }
1263
1264    #[test]
1265    fn regex_value_test() {
1266        let result = parse(b"//");
1267        let node = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_regular_expression_node().unwrap();
1268        assert_eq!(node.unescaped(), b"");
1269    }
1270
1271    #[test]
1272    fn node_field_lifetime_test() {
1273        // The code below wouldn't typecheck prior to https://github.com/ruby/prism/pull/2519,
1274        // but we need to stop clippy from complaining about it.
1275        #![allow(clippy::needless_pass_by_value)]
1276
1277        use crate::Node;
1278
1279        #[derive(Default)]
1280        struct Extract<'pr> {
1281            scopes: Vec<crate::ConstantId<'pr>>,
1282        }
1283
1284        impl<'pr> Extract<'pr> {
1285            fn push_scope(&mut self, path: Node<'pr>) {
1286                if let Some(cread) = path.as_constant_read_node() {
1287                    self.scopes.push(cread.name());
1288                } else if let Some(cpath) = path.as_constant_path_node() {
1289                    if let Some(parent) = cpath.parent() {
1290                        self.push_scope(parent);
1291                    }
1292                    self.scopes.push(cpath.name().unwrap());
1293                } else {
1294                    panic!("Wrong node kind!");
1295                }
1296            }
1297        }
1298
1299        let source = "Some::Random::Constant";
1300        let result = parse(source.as_ref());
1301        let node = result.node();
1302        let mut extractor = Extract::default();
1303        extractor.push_scope(node.as_program_node().unwrap().statements().body().iter().next().unwrap());
1304        assert_eq!(3, extractor.scopes.len());
1305    }
1306
1307    #[test]
1308    fn malformed_shebang() {
1309        let source = "#!\x00";
1310        let result = parse(source.as_ref());
1311        assert!(result.errors().next().is_none());
1312        assert!(result.warnings().next().is_none());
1313    }
1314}