1#![warn(clippy::all, clippy::nursery, clippy::pedantic, future_incompatible, missing_docs, nonstandard_style, rust_2018_idioms, trivial_casts, trivial_numeric_casts, unreachable_pub, unused_qualifications)]
6
// Raw FFI bindings to prism's C API. The Rust source is generated at build time
// into cargo's OUT_DIR (presumably by this crate's build script) and pasted in
// textually; its public items are re-exported via `pub use self::bindings::*`.
// Lints that generated code cannot reasonably satisfy are silenced here.
#[allow(clippy::too_many_lines, clippy::use_self)]
mod bindings {
    // Generated code: do not edit by hand, change the generator instead.
    include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
}
15
16use std::ffi::{c_char, CStr};
17use std::marker::PhantomData;
18use std::mem::MaybeUninit;
19use std::ptr::NonNull;
20
21pub use self::bindings::*;
22use ruby_prism_sys::{pm_comment_t, pm_comment_type_t, pm_constant_id_list_t, pm_constant_id_t, pm_diagnostic_t, pm_integer_t, pm_location_t, pm_magic_comment_t, pm_node_destroy, pm_node_list, pm_node_t, pm_parse, pm_parser_free, pm_parser_init, pm_parser_t};
23
24pub struct Location<'pr> {
26 parser: NonNull<pm_parser_t>,
27 pub(crate) start: *const u8,
28 pub(crate) end: *const u8,
29 marker: PhantomData<&'pr [u8]>,
30}
31
32impl<'pr> Location<'pr> {
33 #[must_use]
37 pub fn as_slice(&self) -> &'pr [u8] {
38 unsafe {
39 let len = usize::try_from(self.end.offset_from(self.start)).expect("end should point to memory after start");
40 std::slice::from_raw_parts(self.start, len)
41 }
42 }
43
44 #[must_use]
46 pub(crate) const fn new(parser: NonNull<pm_parser_t>, loc: &'pr pm_location_t) -> Self {
47 Location {
48 parser,
49 start: loc.start,
50 end: loc.end,
51 marker: PhantomData,
52 }
53 }
54
55 #[must_use]
59 pub fn join(&self, other: &Self) -> Option<Self> {
60 if self.parser != other.parser || self.start > other.start {
61 None
62 } else {
63 Some(Location {
64 parser: self.parser,
65 start: self.start,
66 end: other.end,
67 marker: PhantomData,
68 })
69 }
70 }
71
72 #[must_use]
76 pub fn start_offset(&self) -> usize {
77 unsafe {
78 let parser_start = (*self.parser.as_ptr()).start;
79 usize::try_from(self.start.offset_from(parser_start)).expect("start should point to memory after the parser's start")
80 }
81 }
82
83 #[must_use]
87 pub fn end_offset(&self) -> usize {
88 unsafe {
89 let parser_start = (*self.parser.as_ptr()).start;
90 usize::try_from(self.end.offset_from(parser_start)).expect("end should point to memory after the parser's start")
91 }
92 }
93}
94
95impl std::fmt::Debug for Location<'_> {
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 let slice: &[u8] = self.as_slice();
98
99 let mut visible = String::new();
100 visible.push('"');
101
102 for &byte in slice {
103 let part: Vec<u8> = std::ascii::escape_default(byte).collect();
104 visible.push_str(std::str::from_utf8(&part).unwrap());
105 }
106
107 visible.push('"');
108 write!(f, "{visible}")
109 }
110}
111
112pub struct NodeListIter<'pr> {
114 parser: NonNull<pm_parser_t>,
115 pointer: NonNull<pm_node_list>,
116 index: usize,
117 marker: PhantomData<&'pr mut pm_node_list>,
118}
119
120impl<'pr> Iterator for NodeListIter<'pr> {
121 type Item = Node<'pr>;
122
123 fn next(&mut self) -> Option<Self::Item> {
124 if self.index >= unsafe { self.pointer.as_ref().size } {
125 None
126 } else {
127 let node: *mut pm_node_t = unsafe { *(self.pointer.as_ref().nodes.add(self.index)) };
128 self.index += 1;
129 Some(Node::new(self.parser, node))
130 }
131 }
132}
133
134pub struct NodeList<'pr> {
136 parser: NonNull<pm_parser_t>,
137 pointer: NonNull<pm_node_list>,
138 marker: PhantomData<&'pr mut pm_node_list>,
139}
140
141impl<'pr> NodeList<'pr> {
142 unsafe fn at(&self, index: usize) -> Node<'pr> {
143 let node: *mut pm_node_t = *(self.pointer.as_ref().nodes.add(index));
144 Node::new(self.parser, node)
145 }
146
147 #[must_use]
149 pub const fn iter(&self) -> NodeListIter<'pr> {
150 NodeListIter {
151 parser: self.parser,
152 pointer: self.pointer,
153 index: 0,
154 marker: PhantomData,
155 }
156 }
157
158 #[must_use]
160 pub const fn len(&self) -> usize {
161 unsafe { self.pointer.as_ref().size }
162 }
163
164 #[must_use]
166 pub const fn is_empty(&self) -> bool {
167 self.len() == 0
168 }
169
170 #[must_use]
172 pub fn first(&self) -> Option<Node<'pr>> {
173 if self.is_empty() {
174 None
175 } else {
176 Some(unsafe { self.at(0) })
177 }
178 }
179
180 #[must_use]
182 pub fn last(&self) -> Option<Node<'pr>> {
183 if self.is_empty() {
184 None
185 } else {
186 Some(unsafe { self.at(self.len() - 1) })
187 }
188 }
189}
190
191impl<'pr> IntoIterator for &NodeList<'pr> {
192 type Item = Node<'pr>;
193 type IntoIter = NodeListIter<'pr>;
194 fn into_iter(self) -> Self::IntoIter {
195 self.iter()
196 }
197}
198
199impl std::fmt::Debug for NodeList<'_> {
200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201 write!(f, "{:?}", self.iter().collect::<Vec<_>>())
202 }
203}
204
205pub struct ConstantId<'pr> {
207 parser: NonNull<pm_parser_t>,
208 id: pm_constant_id_t,
209 marker: PhantomData<&'pr mut pm_constant_id_t>,
210}
211
212impl<'pr> ConstantId<'pr> {
213 const fn new(parser: NonNull<pm_parser_t>, id: pm_constant_id_t) -> Self {
214 ConstantId { parser, id, marker: PhantomData }
215 }
216
217 #[must_use]
223 pub fn as_slice(&self) -> &'pr [u8] {
224 unsafe {
225 let pool = &(*self.parser.as_ptr()).constant_pool;
226 let constant = &(*pool.constants.add((self.id - 1).try_into().unwrap()));
227 std::slice::from_raw_parts(constant.start, constant.length)
228 }
229 }
230}
231
232impl std::fmt::Debug for ConstantId<'_> {
233 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234 write!(f, "{:?}", self.id)
235 }
236}
237
238pub struct ConstantListIter<'pr> {
240 parser: NonNull<pm_parser_t>,
241 pointer: NonNull<pm_constant_id_list_t>,
242 index: usize,
243 marker: PhantomData<&'pr mut pm_constant_id_list_t>,
244}
245
246impl<'pr> Iterator for ConstantListIter<'pr> {
247 type Item = ConstantId<'pr>;
248
249 fn next(&mut self) -> Option<Self::Item> {
250 if self.index >= unsafe { self.pointer.as_ref().size } {
251 None
252 } else {
253 let constant_id: pm_constant_id_t = unsafe { *(self.pointer.as_ref().ids.add(self.index)) };
254 self.index += 1;
255 Some(ConstantId::new(self.parser, constant_id))
256 }
257 }
258}
259
260pub struct ConstantList<'pr> {
262 parser: NonNull<pm_parser_t>,
264
265 pointer: NonNull<pm_constant_id_list_t>,
267
268 marker: PhantomData<&'pr mut pm_constant_id_list_t>,
270}
271
272impl<'pr> ConstantList<'pr> {
273 const unsafe fn at(&self, index: usize) -> ConstantId<'pr> {
274 let constant_id: pm_constant_id_t = *(self.pointer.as_ref().ids.add(index));
275 ConstantId::new(self.parser, constant_id)
276 }
277
278 #[must_use]
280 pub const fn iter(&self) -> ConstantListIter<'pr> {
281 ConstantListIter {
282 parser: self.parser,
283 pointer: self.pointer,
284 index: 0,
285 marker: PhantomData,
286 }
287 }
288
289 #[must_use]
291 pub const fn len(&self) -> usize {
292 unsafe { self.pointer.as_ref().size }
293 }
294
295 #[must_use]
297 pub const fn is_empty(&self) -> bool {
298 self.len() == 0
299 }
300
301 #[must_use]
303 pub const fn first(&self) -> Option<ConstantId<'pr>> {
304 if self.is_empty() {
305 None
306 } else {
307 Some(unsafe { self.at(0) })
308 }
309 }
310
311 #[must_use]
313 pub const fn last(&self) -> Option<ConstantId<'pr>> {
314 if self.is_empty() {
315 None
316 } else {
317 Some(unsafe { self.at(self.len() - 1) })
318 }
319 }
320}
321
322impl<'pr> IntoIterator for &ConstantList<'pr> {
323 type Item = ConstantId<'pr>;
324 type IntoIter = ConstantListIter<'pr>;
325 fn into_iter(self) -> Self::IntoIter {
326 self.iter()
327 }
328}
329
330impl std::fmt::Debug for ConstantList<'_> {
331 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332 write!(f, "{:?}", self.iter().collect::<Vec<_>>())
333 }
334}
335
336pub struct Integer<'pr> {
338 pointer: *const pm_integer_t,
340
341 marker: PhantomData<&'pr mut pm_constant_id_t>,
343}
344
345impl Integer<'_> {
346 const fn new(pointer: *const pm_integer_t) -> Self {
347 Integer { pointer, marker: PhantomData }
348 }
349
350 #[must_use]
353 pub const fn to_u32_digits(&self) -> (bool, &[u32]) {
354 let negative = unsafe { (*self.pointer).negative };
355 let length = unsafe { (*self.pointer).length };
356 let values = unsafe { (*self.pointer).values };
357
358 if values.is_null() {
359 let value_ptr = unsafe { std::ptr::addr_of!((*self.pointer).value) };
360 let slice = unsafe { std::slice::from_raw_parts(value_ptr, 1) };
361 (negative, slice)
362 } else {
363 let slice = unsafe { std::slice::from_raw_parts(values, length) };
364 (negative, slice)
365 }
366 }
367}
368
369impl std::fmt::Debug for Integer<'_> {
370 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371 write!(f, "{:?}", self.pointer)
372 }
373}
374
375impl TryInto<i32> for Integer<'_> {
376 type Error = ();
377
378 fn try_into(self) -> Result<i32, Self::Error> {
379 let negative = unsafe { (*self.pointer).negative };
380 let length = unsafe { (*self.pointer).length };
381
382 if length == 0 {
383 i32::try_from(unsafe { (*self.pointer).value }).map_or(Err(()), |value| if negative { Ok(-value) } else { Ok(value) })
384 } else {
385 Err(())
386 }
387 }
388}
389
390#[derive(Debug)]
392pub struct Diagnostic<'pr> {
393 diag: NonNull<pm_diagnostic_t>,
394 parser: NonNull<pm_parser_t>,
395 marker: PhantomData<&'pr pm_diagnostic_t>,
396}
397
398impl<'pr> Diagnostic<'pr> {
399 #[must_use]
406 pub fn message(&self) -> &str {
407 unsafe {
408 let message: *mut c_char = self.diag.as_ref().message.cast_mut();
409 CStr::from_ptr(message).to_str().expect("prism allows only UTF-8 for diagnostics.")
410 }
411 }
412
413 #[must_use]
415 pub const fn location(&self) -> Location<'pr> {
416 Location::new(self.parser, unsafe { &self.diag.as_ref().location })
417 }
418}
419
420#[derive(Debug)]
422pub struct Comment<'pr> {
423 content: NonNull<pm_comment_t>,
424 parser: NonNull<pm_parser_t>,
425 marker: PhantomData<&'pr pm_comment_t>,
426}
427
428#[derive(Debug, Clone, Copy, PartialEq, Eq)]
430pub enum CommentType {
431 InlineComment,
433 EmbDocComment,
435}
436
437impl<'pr> Comment<'pr> {
438 #[must_use]
443 pub fn text(&self) -> &[u8] {
444 self.location().as_slice()
445 }
446
447 #[must_use]
449 pub fn type_(&self) -> CommentType {
450 let type_ = unsafe { self.content.as_ref().type_ };
451 if type_ == pm_comment_type_t::PM_COMMENT_EMBDOC {
452 CommentType::EmbDocComment
453 } else {
454 CommentType::InlineComment
455 }
456 }
457
458 #[must_use]
460 pub const fn location(&self) -> Location<'pr> {
461 Location::new(self.parser, unsafe { &self.content.as_ref().location })
462 }
463}
464
465#[derive(Debug)]
467pub struct MagicComment<'pr> {
468 comment: NonNull<pm_magic_comment_t>,
469 marker: PhantomData<&'pr pm_magic_comment_t>,
470}
471
472impl MagicComment<'_> {
473 #[must_use]
475 pub const fn key(&self) -> &[u8] {
476 unsafe {
477 let start = self.comment.as_ref().key_start;
478 let len = self.comment.as_ref().key_length as usize;
479 std::slice::from_raw_parts(start, len)
480 }
481 }
482
483 #[must_use]
485 pub const fn value(&self) -> &[u8] {
486 unsafe {
487 let start = self.comment.as_ref().value_start;
488 let len = self.comment.as_ref().value_length as usize;
489 std::slice::from_raw_parts(start, len)
490 }
491 }
492}
493
494pub struct Diagnostics<'pr> {
497 diagnostic: *mut pm_diagnostic_t,
498 parser: NonNull<pm_parser_t>,
499 marker: PhantomData<&'pr pm_diagnostic_t>,
500}
501
502impl<'pr> Iterator for Diagnostics<'pr> {
503 type Item = Diagnostic<'pr>;
504
505 fn next(&mut self) -> Option<Self::Item> {
506 if let Some(diagnostic) = NonNull::new(self.diagnostic) {
507 let current = Diagnostic {
508 diag: diagnostic,
509 parser: self.parser,
510 marker: PhantomData,
511 };
512 self.diagnostic = unsafe { diagnostic.as_ref().node.next.cast::<pm_diagnostic_t>() };
513 Some(current)
514 } else {
515 None
516 }
517 }
518}
519
520pub struct Comments<'pr> {
523 comment: *mut pm_comment_t,
524 parser: NonNull<pm_parser_t>,
525 marker: PhantomData<&'pr pm_comment_t>,
526}
527
528impl<'pr> Iterator for Comments<'pr> {
529 type Item = Comment<'pr>;
530
531 fn next(&mut self) -> Option<Self::Item> {
532 if let Some(comment) = NonNull::new(self.comment) {
533 let current = Comment {
534 content: comment,
535 parser: self.parser,
536 marker: PhantomData,
537 };
538 self.comment = unsafe { comment.as_ref().node.next.cast::<pm_comment_t>() };
539 Some(current)
540 } else {
541 None
542 }
543 }
544}
545
546pub struct MagicComments<'pr> {
549 comment: *mut pm_magic_comment_t,
550 marker: PhantomData<&'pr pm_magic_comment_t>,
551}
552
553impl<'pr> Iterator for MagicComments<'pr> {
554 type Item = MagicComment<'pr>;
555
556 fn next(&mut self) -> Option<Self::Item> {
557 if let Some(comment) = NonNull::new(self.comment) {
558 let current = MagicComment { comment, marker: PhantomData };
559 self.comment = unsafe { comment.as_ref().node.next.cast::<pm_magic_comment_t>() };
560 Some(current)
561 } else {
562 None
563 }
564 }
565}
566
567#[derive(Debug)]
569pub struct ParseResult<'pr> {
570 source: &'pr [u8],
571 parser: NonNull<pm_parser_t>,
572 node: NonNull<pm_node_t>,
573}
574
575impl<'pr> ParseResult<'pr> {
576 #[must_use]
578 pub const fn source(&self) -> &'pr [u8] {
579 self.source
580 }
581
582 #[must_use]
584 pub fn frozen_string_literals(&self) -> bool {
585 unsafe { (*self.parser.as_ptr()).frozen_string_literal == 1 }
586 }
587
588 #[must_use]
594 pub fn as_slice(&self, location: &Location<'pr>) -> &'pr [u8] {
595 let root = self.source.as_ptr();
596
597 let start = usize::try_from(unsafe { location.start.offset_from(root) }).expect("start should point to memory after root");
598 let end = usize::try_from(unsafe { location.end.offset_from(root) }).expect("end should point to memory after root");
599
600 &self.source[start..end]
601 }
602
603 #[must_use]
606 pub fn errors(&self) -> Diagnostics<'_> {
607 unsafe {
608 let list = &mut (*self.parser.as_ptr()).error_list;
609 Diagnostics {
610 diagnostic: list.head.cast::<pm_diagnostic_t>(),
611 parser: self.parser,
612 marker: PhantomData,
613 }
614 }
615 }
616
617 #[must_use]
620 pub fn warnings(&self) -> Diagnostics<'_> {
621 unsafe {
622 let list = &mut (*self.parser.as_ptr()).warning_list;
623 Diagnostics {
624 diagnostic: list.head.cast::<pm_diagnostic_t>(),
625 parser: self.parser,
626 marker: PhantomData,
627 }
628 }
629 }
630
631 #[must_use]
634 pub fn comments(&self) -> Comments<'_> {
635 unsafe {
636 let list = &mut (*self.parser.as_ptr()).comment_list;
637 Comments {
638 comment: list.head.cast::<pm_comment_t>(),
639 parser: self.parser,
640 marker: PhantomData,
641 }
642 }
643 }
644
645 #[must_use]
648 pub fn magic_comments(&self) -> MagicComments<'_> {
649 unsafe {
650 let list = &mut (*self.parser.as_ptr()).magic_comment_list;
651 MagicComments {
652 comment: list.head.cast::<pm_magic_comment_t>(),
653 marker: PhantomData,
654 }
655 }
656 }
657
658 #[must_use]
660 pub fn data_loc(&self) -> Option<Location<'_>> {
661 let location = unsafe { &(*self.parser.as_ptr()).data_loc };
662 if location.start.is_null() {
663 None
664 } else {
665 Some(Location::new(self.parser, location))
666 }
667 }
668
669 #[must_use]
671 pub fn node(&self) -> Node<'_> {
672 Node::new(self.parser, self.node.as_ptr())
673 }
674}
675
/// Releases the tree and the parser created by [`parse`].
impl Drop for ParseResult<'_> {
    fn drop(&mut self) {
        // Order matters: the tree is destroyed while the parser is still valid
        // (`pm_node_destroy` receives the parser pointer), then the parser's own
        // resources are released (presumably what `pm_parser_free` does), and
        // only then is the parser's heap slot itself freed.
        unsafe {
            pm_node_destroy(self.parser.as_ptr(), self.node.as_ptr());
            pm_parser_free(self.parser.as_ptr());
            // Reconstitute the `Box` that `parse` leaked with `Box::into_raw`
            // (allocated as `Box<MaybeUninit<pm_parser_t>>`, same layout) so Rust
            // deallocates it.
            drop(Box::from_raw(self.parser.as_ptr()));
        }
    }
}
685
686#[must_use]
693pub fn parse(source: &[u8]) -> ParseResult<'_> {
694 unsafe {
695 let uninit = Box::new(MaybeUninit::<pm_parser_t>::uninit());
696 let uninit = Box::into_raw(uninit);
697
698 pm_parser_init((*uninit).as_mut_ptr(), source.as_ptr(), source.len(), std::ptr::null());
699
700 let parser = (*uninit).assume_init_mut();
701 let parser = NonNull::new_unchecked(parser);
702
703 let node = pm_parse(parser.as_ptr());
704 let node = NonNull::new_unchecked(node);
705
706 ParseResult { source, parser, node }
707 }
708}
709
710#[cfg(test)]
711mod tests {
712 use super::parse;
713
714 #[test]
715 fn comments_test() {
716 let source = "# comment 1\n# comment 2\n# comment 3\n";
717 let result = parse(source.as_ref());
718
719 for comment in result.comments() {
720 assert_eq!(super::CommentType::InlineComment, comment.type_());
721 let text = std::str::from_utf8(comment.text()).unwrap();
722 assert!(text.starts_with("# comment"));
723 }
724 }
725
726 #[test]
727 fn magic_comments_test() {
728 use crate::MagicComment;
729
730 let source = "# typed: ignore\n# typed:true\n#typed: strict\n";
731 let result = parse(source.as_ref());
732
733 let comments: Vec<MagicComment<'_>> = result.magic_comments().collect();
734 assert_eq!(3, comments.len());
735
736 assert_eq!(b"typed", comments[0].key());
737 assert_eq!(b"ignore", comments[0].value());
738
739 assert_eq!(b"typed", comments[1].key());
740 assert_eq!(b"true", comments[1].value());
741
742 assert_eq!(b"typed", comments[2].key());
743 assert_eq!(b"strict", comments[2].value());
744 }
745
746 #[test]
747 fn data_loc_test() {
748 let source = "1";
749 let result = parse(source.as_ref());
750 let data_loc = result.data_loc();
751 assert!(data_loc.is_none());
752
753 let source = "__END__\nabc\n";
754 let result = parse(source.as_ref());
755 let data_loc = result.data_loc().unwrap();
756 let slice = std::str::from_utf8(result.as_slice(&data_loc)).unwrap();
757 assert_eq!(slice, "__END__\nabc\n");
758
759 let source = "1\n2\n3\n__END__\nabc\ndef\n";
760 let result = parse(source.as_ref());
761 let data_loc = result.data_loc().unwrap();
762 let slice = std::str::from_utf8(result.as_slice(&data_loc)).unwrap();
763 assert_eq!(slice, "__END__\nabc\ndef\n");
764 }
765
    /// Exercises `Location`: slicing, start/end offsets, and `join` across
    /// same-parser and different-parser locations.
    #[test]
    fn location_test() {
        let source = "111 + 222 + 333";
        let result = parse(source.as_ref());

        // Walk down to the `222` argument of the `111 + 222` receiver call.
        let node = result.node();
        let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
        let node = node.as_call_node().unwrap().receiver().unwrap();
        let plus = node.as_call_node().unwrap();
        let node = plus.arguments().unwrap().arguments().iter().next().unwrap();

        let location = node.as_integer_node().unwrap().location();
        let slice = std::str::from_utf8(result.as_slice(&location)).unwrap();

        assert_eq!(slice, "222");
        assert_eq!(6, location.start_offset());
        assert_eq!(9, location.end_offset());

        let recv_loc = plus.receiver().unwrap().location();
        assert_eq!(recv_loc.as_slice(), b"111");
        assert_eq!(0, recv_loc.start_offset());
        assert_eq!(3, recv_loc.end_offset());

        // Joining an earlier location with a later one spans both.
        let joined = recv_loc.join(&location).unwrap();
        assert_eq!(joined.as_slice(), b"111 + 222");

        // Joining in the wrong order (self starts after other) is rejected.
        let not_joined = location.join(&recv_loc);
        assert!(not_joined.is_none());

        // A second parse of the same text yields locations from a different
        // parser, which `join` must refuse in either order.
        {
            let result = parse(source.as_ref());
            let node = result.node();
            let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
            let node = node.as_call_node().unwrap().receiver().unwrap();
            let plus = node.as_call_node().unwrap();
            let node = plus.arguments().unwrap().arguments().iter().next().unwrap();

            let location = node.as_integer_node().unwrap().location();
            let not_joined = recv_loc.join(&location);
            assert!(not_joined.is_none());

            let not_joined = location.join(&recv_loc);
            assert!(not_joined.is_none());
        }

        // Back in the outer scope: `node` is again the first parse's `222` node,
        // and both slicing paths must agree.
        let location = node.location();
        let slice = std::str::from_utf8(result.as_slice(&location)).unwrap();

        assert_eq!(slice, "222");

        let slice = std::str::from_utf8(location.as_slice()).unwrap();

        assert_eq!(slice, "222");
    }
820
821 #[test]
822 fn visitor_test() {
823 use super::{visit_interpolated_regular_expression_node, visit_regular_expression_node, InterpolatedRegularExpressionNode, RegularExpressionNode, Visit};
824
825 struct RegularExpressionVisitor {
826 count: usize,
827 }
828
829 impl Visit<'_> for RegularExpressionVisitor {
830 fn visit_interpolated_regular_expression_node(&mut self, node: &InterpolatedRegularExpressionNode<'_>) {
831 self.count += 1;
832 visit_interpolated_regular_expression_node(self, node);
833 }
834
835 fn visit_regular_expression_node(&mut self, node: &RegularExpressionNode<'_>) {
836 self.count += 1;
837 visit_regular_expression_node(self, node);
838 }
839 }
840
841 let source = "# comment 1\n# comment 2\nmodule Foo; class Bar; /abc #{/def/}/; end; end";
842 let result = parse(source.as_ref());
843
844 let mut visitor = RegularExpressionVisitor { count: 0 };
845 visitor.visit(&result.node());
846
847 assert_eq!(visitor.count, 2);
848 }
849
850 #[test]
851 fn node_upcast_test() {
852 use super::Node;
853
854 let source = "module Foo; end";
855 let result = parse(source.as_ref());
856
857 let node = result.node();
858 let upcast_node = node.as_program_node().unwrap().as_node();
859 assert!(matches!(upcast_node, Node::ProgramNode { .. }));
860
861 let node = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
862 let upcast_node = node.as_module_node().unwrap().as_node();
863 assert!(matches!(upcast_node, Node::ModuleNode { .. }));
864 }
865
866 #[test]
867 fn constant_id_test() {
868 let source = "module Foo; x = 1; y = 2; end";
869 let result = parse(source.as_ref());
870
871 let node = result.node();
872 assert_eq!(node.as_program_node().unwrap().statements().body().len(), 1);
873 assert!(!node.as_program_node().unwrap().statements().body().is_empty());
874 let module = node.as_program_node().and_then(|pn| pn.statements().body().first()).unwrap();
875 let module = module.as_module_node().unwrap();
876
877 assert_eq!(module.locals().len(), 2);
878 assert!(!module.locals().is_empty());
879
880 assert_eq!(module.locals().first().unwrap().as_slice(), b"x");
881 assert_eq!(module.locals().last().unwrap().as_slice(), b"y");
882
883 let source = "module Foo; end";
884 let result = parse(source.as_ref());
885
886 let node = result.node();
887 assert_eq!(node.as_program_node().unwrap().statements().body().len(), 1);
888 assert!(!node.as_program_node().unwrap().statements().body().is_empty());
889 let module = node.as_program_node().and_then(|pn| pn.statements().body().first()).unwrap();
890 let module = module.as_module_node().unwrap();
891
892 assert_eq!(module.locals().len(), 0);
893 assert!(module.locals().is_empty());
894 }
895
896 #[test]
897 fn optional_loc_test() {
898 let source = r"
899module Example
900 x = call_func(3, 4)
901 y = x.call_func 5, 6
902end
903";
904 let result = parse(source.as_ref());
905
906 let node = result.node();
907 let module = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
908 let module = module.as_module_node().unwrap();
909 let body = module.body();
910 let writes = body.iter().next().unwrap().as_statements_node().unwrap().body().iter().collect::<Vec<_>>();
911 assert_eq!(writes.len(), 2);
912
913 let asgn = &writes[0];
914 let call = asgn.as_local_variable_write_node().unwrap().value();
915 let call = call.as_call_node().unwrap();
916
917 let call_operator_loc = call.call_operator_loc();
918 assert!(call_operator_loc.is_none());
919 let closing_loc = call.closing_loc();
920 assert!(closing_loc.is_some());
921
922 let asgn = &writes[1];
923 let call = asgn.as_local_variable_write_node().unwrap().value();
924 let call = call.as_call_node().unwrap();
925
926 let call_operator_loc = call.call_operator_loc();
927 assert!(call_operator_loc.is_some());
928 let closing_loc = call.closing_loc();
929 assert!(closing_loc.is_none());
930 }
931
932 #[test]
933 fn frozen_strings_test() {
934 let source = r#"
935# frozen_string_literal: true
936"foo"
937"#;
938 let result = parse(source.as_ref());
939 assert!(result.frozen_string_literals());
940
941 let source = "3";
942 let result = parse(source.as_ref());
943 assert!(!result.frozen_string_literals());
944 }
945
946 #[test]
947 fn string_flags_test() {
948 let source = r#"
949# frozen_string_literal: true
950"foo"
951"#;
952 let result = parse(source.as_ref());
953
954 let node = result.node();
955 let string = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
956 let string = string.as_string_node().unwrap();
957 assert!(string.is_frozen());
958
959 let source = r#"
960"foo"
961"#;
962 let result = parse(source.as_ref());
963
964 let node = result.node();
965 let string = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
966 let string = string.as_string_node().unwrap();
967 assert!(!string.is_frozen());
968 }
969
970 #[test]
971 fn call_flags_test() {
972 let source = r"
973x
974";
975 let result = parse(source.as_ref());
976
977 let node = result.node();
978 let call = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
979 let call = call.as_call_node().unwrap();
980 assert!(call.is_variable_call());
981
982 let source = r"
983x&.foo
984";
985 let result = parse(source.as_ref());
986
987 let node = result.node();
988 let call = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
989 let call = call.as_call_node().unwrap();
990 assert!(call.is_safe_navigation());
991 }
992
993 #[test]
994 fn integer_flags_test() {
995 let source = r"
9960b1
997";
998 let result = parse(source.as_ref());
999
1000 let node = result.node();
1001 let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1002 let i = i.as_integer_node().unwrap();
1003 assert!(i.is_binary());
1004 assert!(!i.is_decimal());
1005 assert!(!i.is_octal());
1006 assert!(!i.is_hexadecimal());
1007
1008 let source = r"
10091
1010";
1011 let result = parse(source.as_ref());
1012
1013 let node = result.node();
1014 let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1015 let i = i.as_integer_node().unwrap();
1016 assert!(!i.is_binary());
1017 assert!(i.is_decimal());
1018 assert!(!i.is_octal());
1019 assert!(!i.is_hexadecimal());
1020
1021 let source = r"
10220o1
1023";
1024 let result = parse(source.as_ref());
1025
1026 let node = result.node();
1027 let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1028 let i = i.as_integer_node().unwrap();
1029 assert!(!i.is_binary());
1030 assert!(!i.is_decimal());
1031 assert!(i.is_octal());
1032 assert!(!i.is_hexadecimal());
1033
1034 let source = r"
10350x1
1036";
1037 let result = parse(source.as_ref());
1038
1039 let node = result.node();
1040 let i = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1041 let i = i.as_integer_node().unwrap();
1042 assert!(!i.is_binary());
1043 assert!(!i.is_decimal());
1044 assert!(!i.is_octal());
1045 assert!(i.is_hexadecimal());
1046 }
1047
1048 #[test]
1049 fn range_flags_test() {
1050 let source = r"
10510..1
1052";
1053 let result = parse(source.as_ref());
1054
1055 let node = result.node();
1056 let range = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1057 let range = range.as_range_node().unwrap();
1058 assert!(!range.is_exclude_end());
1059
1060 let source = r"
10610...1
1062";
1063 let result = parse(source.as_ref());
1064
1065 let node = result.node();
1066 let range = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1067 let range = range.as_range_node().unwrap();
1068 assert!(range.is_exclude_end());
1069 }
1070
1071 #[allow(clippy::too_many_lines, clippy::cognitive_complexity)]
1072 #[test]
1073 fn regex_flags_test() {
1074 let source = r"
1075/a/i
1076";
1077 let result = parse(source.as_ref());
1078
1079 let node = result.node();
1080 let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1081 let regex = regex.as_regular_expression_node().unwrap();
1082 assert!(regex.is_ignore_case());
1083 assert!(!regex.is_extended());
1084 assert!(!regex.is_multi_line());
1085 assert!(!regex.is_euc_jp());
1086 assert!(!regex.is_ascii_8bit());
1087 assert!(!regex.is_windows_31j());
1088 assert!(!regex.is_utf_8());
1089 assert!(!regex.is_once());
1090
1091 let source = r"
1092/a/x
1093";
1094 let result = parse(source.as_ref());
1095
1096 let node = result.node();
1097 let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1098 let regex = regex.as_regular_expression_node().unwrap();
1099 assert!(!regex.is_ignore_case());
1100 assert!(regex.is_extended());
1101 assert!(!regex.is_multi_line());
1102 assert!(!regex.is_euc_jp());
1103 assert!(!regex.is_ascii_8bit());
1104 assert!(!regex.is_windows_31j());
1105 assert!(!regex.is_utf_8());
1106 assert!(!regex.is_once());
1107
1108 let source = r"
1109/a/m
1110";
1111 let result = parse(source.as_ref());
1112
1113 let node = result.node();
1114 let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1115 let regex = regex.as_regular_expression_node().unwrap();
1116 assert!(!regex.is_ignore_case());
1117 assert!(!regex.is_extended());
1118 assert!(regex.is_multi_line());
1119 assert!(!regex.is_euc_jp());
1120 assert!(!regex.is_ascii_8bit());
1121 assert!(!regex.is_windows_31j());
1122 assert!(!regex.is_utf_8());
1123 assert!(!regex.is_once());
1124
1125 let source = r"
1126/a/e
1127";
1128 let result = parse(source.as_ref());
1129
1130 let node = result.node();
1131 let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1132 let regex = regex.as_regular_expression_node().unwrap();
1133 assert!(!regex.is_ignore_case());
1134 assert!(!regex.is_extended());
1135 assert!(!regex.is_multi_line());
1136 assert!(regex.is_euc_jp());
1137 assert!(!regex.is_ascii_8bit());
1138 assert!(!regex.is_windows_31j());
1139 assert!(!regex.is_utf_8());
1140 assert!(!regex.is_once());
1141
1142 let source = r"
1143/a/n
1144";
1145 let result = parse(source.as_ref());
1146
1147 let node = result.node();
1148 let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1149 let regex = regex.as_regular_expression_node().unwrap();
1150 assert!(!regex.is_ignore_case());
1151 assert!(!regex.is_extended());
1152 assert!(!regex.is_multi_line());
1153 assert!(!regex.is_euc_jp());
1154 assert!(regex.is_ascii_8bit());
1155 assert!(!regex.is_windows_31j());
1156 assert!(!regex.is_utf_8());
1157 assert!(!regex.is_once());
1158
1159 let source = r"
1160/a/s
1161";
1162 let result = parse(source.as_ref());
1163
1164 let node = result.node();
1165 let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1166 let regex = regex.as_regular_expression_node().unwrap();
1167 assert!(!regex.is_ignore_case());
1168 assert!(!regex.is_extended());
1169 assert!(!regex.is_multi_line());
1170 assert!(!regex.is_euc_jp());
1171 assert!(!regex.is_ascii_8bit());
1172 assert!(regex.is_windows_31j());
1173 assert!(!regex.is_utf_8());
1174 assert!(!regex.is_once());
1175
1176 let source = r"
1177/a/u
1178";
1179 let result = parse(source.as_ref());
1180
1181 let node = result.node();
1182 let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1183 let regex = regex.as_regular_expression_node().unwrap();
1184 assert!(!regex.is_ignore_case());
1185 assert!(!regex.is_extended());
1186 assert!(!regex.is_multi_line());
1187 assert!(!regex.is_euc_jp());
1188 assert!(!regex.is_ascii_8bit());
1189 assert!(!regex.is_windows_31j());
1190 assert!(regex.is_utf_8());
1191 assert!(!regex.is_once());
1192
1193 let source = r"
1194/a/o
1195";
1196 let result = parse(source.as_ref());
1197
1198 let node = result.node();
1199 let regex = node.as_program_node().unwrap().statements().body().iter().next().unwrap();
1200 let regex = regex.as_regular_expression_node().unwrap();
1201 assert!(!regex.is_ignore_case());
1202 assert!(!regex.is_extended());
1203 assert!(!regex.is_multi_line());
1204 assert!(!regex.is_euc_jp());
1205 assert!(!regex.is_ascii_8bit());
1206 assert!(!regex.is_windows_31j());
1207 assert!(!regex.is_utf_8());
1208 assert!(regex.is_once());
1209 }
1210
1211 #[test]
1212 fn visitor_traversal_test() {
1213 use crate::{Node, Visit};
1214
1215 #[derive(Default)]
1216 struct NodeCounts {
1217 pre_parent: usize,
1218 post_parent: usize,
1219 pre_leaf: usize,
1220 post_leaf: usize,
1221 }
1222
1223 #[derive(Default)]
1224 struct CountingVisitor {
1225 counts: NodeCounts,
1226 }
1227
1228 impl Visit<'_> for CountingVisitor {
1229 fn visit_branch_node_enter(&mut self, _node: Node<'_>) {
1230 self.counts.pre_parent += 1;
1231 }
1232
1233 fn visit_branch_node_leave(&mut self) {
1234 self.counts.post_parent += 1;
1235 }
1236
1237 fn visit_leaf_node_enter(&mut self, _node: Node<'_>) {
1238 self.counts.pre_leaf += 1;
1239 }
1240
1241 fn visit_leaf_node_leave(&mut self) {
1242 self.counts.post_leaf += 1;
1243 }
1244 }
1245
1246 let source = r"
1247module Example
1248 x = call_func(3, 4)
1249 y = x.call_func 5, 6
1250end
1251";
1252 let result = parse(source.as_ref());
1253 let node = result.node();
1254 let mut visitor = CountingVisitor::default();
1255 visitor.visit(&node);
1256
1257 assert_eq!(7, visitor.counts.pre_parent);
1258 assert_eq!(7, visitor.counts.post_parent);
1259 assert_eq!(6, visitor.counts.pre_leaf);
1260 assert_eq!(6, visitor.counts.post_leaf);
1261 }
1262
1263 #[test]
1264 fn visitor_lifetime_test() {
1265 use crate::{Node, Visit};
1266
1267 #[derive(Default)]
1268 struct StackingNodeVisitor<'a> {
1269 stack: Vec<Node<'a>>,
1270 max_depth: usize,
1271 }
1272
1273 impl<'pr> Visit<'pr> for StackingNodeVisitor<'pr> {
1274 fn visit_branch_node_enter(&mut self, node: Node<'pr>) {
1275 self.stack.push(node);
1276 }
1277
1278 fn visit_branch_node_leave(&mut self) {
1279 self.stack.pop();
1280 }
1281
1282 fn visit_leaf_node_leave(&mut self) {
1283 self.max_depth = self.max_depth.max(self.stack.len());
1284 }
1285 }
1286
1287 let source = r"
1288module Example
1289 x = call_func(3, 4)
1290 y = x.call_func 5, 6
1291end
1292";
1293 let result = parse(source.as_ref());
1294 let node = result.node();
1295 let mut visitor = StackingNodeVisitor::default();
1296 visitor.visit(&node);
1297
1298 assert_eq!(0, visitor.stack.len());
1299 assert_eq!(5, visitor.max_depth);
1300 }
1301
1302 #[test]
1303 fn integer_value_test() {
1304 let result = parse("0xA".as_ref());
1305 let integer = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_integer_node().unwrap().value();
1306 let value: i32 = integer.try_into().unwrap();
1307
1308 assert_eq!(value, 10);
1309 }
1310
1311 #[test]
1312 fn integer_small_value_to_u32_digits_test() {
1313 let result = parse("0xA".as_ref());
1314 let integer = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_integer_node().unwrap().value();
1315 let (negative, digits) = integer.to_u32_digits();
1316
1317 assert!(!negative);
1318 assert_eq!(digits, &[10]);
1319 }
1320
1321 #[test]
1322 fn integer_large_value_to_u32_digits_test() {
1323 let result = parse("0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".as_ref());
1324 let integer = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_integer_node().unwrap().value();
1325 let (negative, digits) = integer.to_u32_digits();
1326
1327 assert!(!negative);
1328 assert_eq!(digits, &[4_294_967_295, 4_294_967_295, 4_294_967_295, 2_147_483_647]);
1329 }
1330
1331 #[test]
1332 fn float_value_test() {
1333 let result = parse("1.0".as_ref());
1334 let value: f64 = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_float_node().unwrap().value();
1335
1336 assert!((value - 1.0).abs() < f64::EPSILON);
1337 }
1338
1339 #[test]
1340 fn regex_value_test() {
1341 let result = parse(b"//");
1342 let node = result.node().as_program_node().unwrap().statements().body().iter().next().unwrap().as_regular_expression_node().unwrap();
1343 assert_eq!(node.unescaped(), b"");
1344 }
1345
1346 #[test]
1347 fn node_field_lifetime_test() {
1348 #![allow(clippy::needless_pass_by_value)]
1351
1352 use crate::Node;
1353
1354 #[derive(Default)]
1355 struct Extract<'pr> {
1356 scopes: Vec<crate::ConstantId<'pr>>,
1357 }
1358
1359 impl<'pr> Extract<'pr> {
1360 fn push_scope(&mut self, path: Node<'pr>) {
1361 if let Some(cread) = path.as_constant_read_node() {
1362 self.scopes.push(cread.name());
1363 } else if let Some(cpath) = path.as_constant_path_node() {
1364 if let Some(parent) = cpath.parent() {
1365 self.push_scope(parent);
1366 }
1367 self.scopes.push(cpath.name().unwrap());
1368 } else {
1369 panic!("Wrong node kind!");
1370 }
1371 }
1372 }
1373
1374 let source = "Some::Random::Constant";
1375 let result = parse(source.as_ref());
1376 let node = result.node();
1377 let mut extractor = Extract::default();
1378 extractor.push_scope(node.as_program_node().unwrap().statements().body().iter().next().unwrap());
1379 assert_eq!(3, extractor.scopes.len());
1380 }
1381
1382 #[test]
1383 fn malformed_shebang() {
1384 let source = "#!\x00";
1385 let result = parse(source.as_ref());
1386 assert!(result.errors().next().is_none());
1387 assert!(result.warnings().next().is_none());
1388 }
1389}