hyperon_macros/
lib.rs

1use proc_macro::*;
2
/// Debugging helper compiled only for tests: prints the `Debug` form of every
/// top-level token tree of `input` on its own line and expands to nothing.
#[cfg(test)]
#[proc_macro]
pub fn print_stream(input: TokenStream) -> TokenStream {
    input.into_iter().for_each(|tree| println!("{:?}", tree));
    TokenStream::new()
}
11
12/// Constructs new Atom using MeTTa S-expressions syntax for expressions. Recognizes grounded
13/// strings, numbers and booleans. For other grounded symbols use braces.
14/// Macros has a performance penalty because it creates and uses an additional
15/// wrapper for grounded atoms.
16#[proc_macro]
17pub fn metta(input: TokenStream) -> TokenStream {
18    MettaConverter::new(input, PrinterMut::default()).run()
19}
20
21/// Similar to [metta!] but constructs a constant Atom. Main goal is to be able
22/// writing `const SOME_SYMBOL: Atom = metta_const!(SomeSymbol)`. This macros
23/// uses constant constructors for symbols, variables and expressions internally.
24/// Grounded values are not supported as they cannot be instantiated without
25/// allocating memory.
26#[proc_macro]
27pub fn metta_const(input: TokenStream) -> TokenStream {
28    MettaConverter::new(input, PrinterConst::default()).run()
29}
30
/// Flat element of the macro input as seen by the tokenizer: parenthesized
/// groups are unrolled into explicit start/end markers and gaps between
/// neighbouring tokens are represented by [`InternalToken::Space`].
#[derive(Debug)]
enum InternalToken {
    /// Opening parenthesis marker with its start and end `(line, column)` positions.
    ExprStart((usize, usize), (usize, usize)),
    /// Closing parenthesis marker with its start and end `(line, column)` positions.
    ExprEnd((usize, usize), (usize, usize)),
    /// Any other token tree, passed through unchanged.
    TokenTree(TokenTree),
    /// Separator inserted where one token does not start exactly where the
    /// previous one ended; it carries no position of its own.
    Space,
}
38
39impl InternalToken {
40    fn range(span: Span) -> (usize, usize) {
41        (span.line(), span.column())
42    }
43
44    fn start_expr(tt: &TokenTree) -> Self {
45        let (l, c) = Self::range(tt.span().start());
46        Self::ExprStart((l, c), (l, c+1))
47    }
48
49    fn end_expr(tt: &TokenTree) -> Self {
50        let (l, c) = Self::range(tt.span().start());
51        Self::ExprEnd((l, c), (l, c+1))
52    }
53
54    fn start(&self) -> (usize, usize) {
55        match self {
56            Self::ExprStart(s, _) => *s,
57            Self::ExprEnd(s, _) => *s,
58            Self::TokenTree(tt) => Self::range(tt.span().start()),
59            Self::Space => unreachable!(),
60        }
61    }
62
63    fn end(&self) -> (usize, usize) {
64        match self {
65            Self::ExprStart(_, e) => *e,
66            Self::ExprEnd(_, e) => *e,
67            Self::TokenTree(tt) => Self::range(tt.span().end()),
68            Self::Space => unreachable!(),
69        }
70    }
71}
72
/// State of the [`Tokenizer`] state machine between two input tokens.
#[derive(Debug)]
enum TokenizerState {
    /// Nothing is being accumulated.
    Start,
    /// Text of a symbol accumulated so far.
    Symbol(String),
    /// Name of a variable accumulated so far (entered after a `$`).
    Variable(String),
    /// A `+` or `-` was read; the next token decides whether it starts a
    /// signed number or a symbol.
    Sign(String),
    /// Text of a grounded literal together with its kind.
    Gnd(String, GndType),
    /// An input token which terminated the previous token and still has to be
    /// processed on the next step.
    Token(InternalToken),
}
82
/// Kind of a grounded literal recognized by the tokenizer.
#[derive(Debug)]
enum GndType {
    /// Integer literal.
    Int,
    /// Floating point literal.
    Float,
    /// String literal.
    Str,
    /// `True` or `False` identifier.
    Bool,
}
90
/// Token produced by the tokenizer and consumed by the converter.
#[derive(Debug)]
enum Token {
    /// Opening parenthesis of an expression.
    ExprStart,
    /// Closing parenthesis of an expression.
    ExprEnd,
    /// Source text of an integer literal (optionally with a sign).
    Int(String),
    /// Source text of a floating point literal (optionally with a sign).
    Float(String),
    /// Source text of a string literal.
    Str(String),
    /// Boolean value text.
    Bool(String),
    /// Variable name.
    Variable(String),
    /// Symbol name.
    Symbol(String),
    /// Brace delimited group holding a grounded value expression.
    Gnd(Group),
    /// End of the input.
    End,
}

/// Textual form of literal tokens. Implementing `Display` (instead of
/// `ToString` directly) keeps `to_string()` available through the blanket
/// implementation and additionally allows using the token in format strings.
///
/// # Panics
/// Formatting is implemented for `Int`, `Float` and `Str` only; any other
/// variant panics with "not yet implemented".
impl std::fmt::Display for Token {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Int(s) | Self::Float(s) | Self::Str(s) => f.write_str(s),
            // Listed explicitly so that adding a variant forces a decision here.
            Self::ExprStart
            | Self::ExprEnd
            | Self::Bool(_)
            | Self::Variable(_)
            | Self::Symbol(_)
            | Self::Gnd(_)
            | Self::End => todo!(),
        }
    }
}
115
/// Converts the macro input into a sequence of [`Token`]s.
struct Tokenizer {
    /// Current state of the state machine, kept between calls of `next`.
    state: TokenizerState,
    /// Flattened input: groups unrolled and `Space` markers inserted.
    input: Box<dyn Iterator<Item=InternalToken>>,
}
120
impl Tokenizer {
    /// Creates a tokenizer over `input`.
    ///
    /// Parenthesized groups are unrolled recursively via `unroll_group`, then an
    /// [`InternalToken::Space`] is inserted before every token whose start
    /// position differs from the end position of the token before it. The
    /// "previous end" is initialized to `(0, 0)`.
    fn new(input: TokenStream) -> Self {
        let mut prev_end = (0, 0);
        let input = input.into_iter().flat_map(Self::unroll_group).flat_map(
            move |it| -> Box<dyn Iterator<Item=InternalToken>> {
                // A gap between the previous token and this one means the
                // tokens were separated in the source text.
                let is_space = prev_end != it.start();
                prev_end = it.end();
                if is_space {
                    Box::new(std::iter::once(InternalToken::Space).chain(std::iter::once(it)))
                } else {
                    Box::new(std::iter::once(it))
                }
            });
        Self {
            state: TokenizerState::Start,
            input: Box::new(input),
        }
    }

    /// Flattens a token tree: a parenthesized group becomes `ExprStart`, its
    /// (recursively flattened) content and `ExprEnd`; any other token tree,
    /// including groups with other delimiters, is passed through as is.
    fn unroll_group(tt: TokenTree) -> Box<dyn Iterator<Item=InternalToken>> {
        match &tt {
            TokenTree::Group(g) if g.delimiter() == Delimiter::Parenthesis => {
                let open = std::iter::once(InternalToken::start_expr(&tt));
                let close = std::iter::once(InternalToken::end_expr(&tt));
                Box::new(open.chain(g.stream().into_iter().flat_map(Self::unroll_group)).chain(close))
            },
            _ => {
                Box::new(std::iter::once(InternalToken::TokenTree(tt)))
            }
        }
    }

    /// Returns the next token, or [`Token::End`] when the input is exhausted.
    ///
    /// Adjacent token trees (no `Space` between them) are glued into a single
    /// symbol, variable or literal. A `Space`, `ExprStart`, `ExprEnd` or the end
    /// of the input terminates the token being accumulated.
    ///
    /// # Panics
    /// Panics with "Failed to parse literal" when `litrs` cannot parse the text
    /// of a literal token.
    fn next(&mut self) -> Token {
        type TS = TokenizerState;
        type IT = InternalToken;
        type T = Token;
        loop {
            // Either re-process the token stashed by the previous step or pull
            // a fresh one from the input.
            let (state, it) = match std::mem::replace(&mut self.state, TS::Start) {
                TS::Token(t) => (TS::Start, Some(t)),
                state => (state, self.input.next()),
            };
            // Each arm yields an optional finished token and the next state.
            let (token, state) = match (state, it) {
                (TS::Start, None) => return T::End,
                (TS::Start, Some(IT::ExprStart(_, _))) => (Some(T::ExprStart), TS::Start),
                (TS::Start, Some(IT::ExprEnd(_, _))) => (Some(T::ExprEnd), TS::Start),
                (TS::Start, Some(IT::TokenTree(tt))) => {
                    match tt {
                        // Integer, float and string literals start a grounded
                        // token; any other literal kind starts a symbol.
                        TokenTree::Literal(l) => {
                            let s = l.to_string();
                            let lit = litrs::Literal::parse(s.clone()).expect("Failed to parse literal");
                            match lit {
                                litrs::Literal::Integer(_) => (None, TS::Gnd(s, GndType::Int)),
                                litrs::Literal::Float(_) => (None, TS::Gnd(s, GndType::Float)),
                                litrs::Literal::String(_) => (None, TS::Gnd(s, GndType::Str)),
                                _ => (None, TS::Symbol(s)), 
                            }
                        },
                        // `True` and `False` start a boolean, other identifiers
                        // start a symbol.
                        TokenTree::Ident(i) => {
                            let s = i.to_string();
                            if s == "True" || s == "False" {
                                (None, TS::Gnd(s, GndType::Bool))
                            } else {
                                (None, TS::Symbol(s))
                            }
                        },
                        // `$` starts a variable; the `$` itself is not part of the name.
                        TokenTree::Punct(p) if p.as_char() == '$' => {
                            (None, TS::Variable(String::new()))
                        },
                        // `+` or `-` may start either a signed number or a symbol.
                        TokenTree::Punct(p)
                            if p.as_char() == '+' || p.as_char() == '-' =>
                                (None, TS::Sign(p.to_string())),
                        // A brace delimited group is a complete grounded token.
                        TokenTree::Group(g)
                            if g.delimiter() == Delimiter::Brace =>
                                (Some(T::Gnd(g)), TS::Start),
                        tt => (None, TS::Symbol(tt.to_string())), 
                    }
                }

                // A sign directly followed by an integer or float literal is a
                // signed number; followed by anything else it becomes a symbol.
                (TS::Sign(s), Some(IT::TokenTree(tt))) => {
                    match &tt {
                        TokenTree::Literal(l) => {
                            let l = l.to_string();
                            let lit = litrs::Literal::parse(l.clone()).expect("Failed to parse literal");
                            let s = s + l.as_str();
                            match lit {
                                litrs::Literal::Integer(_) => (None, TS::Gnd(s, GndType::Int)),
                                litrs::Literal::Float(_) => (None, TS::Gnd(s, GndType::Float)),
                                _ => (None, TS::Symbol(s)), 
                            }
                        },
                        _ => (None, TS::Symbol(s + tt.to_string().as_str())), 
                    }
                }
                // A lone sign is a symbol; the terminating token is stashed to
                // be processed on the next iteration.
                (TS::Sign(s), Some(t)) => (Some(T::Symbol(s)), TS::Token(t)),
                (TS::Sign(s), None) => (Some(T::Symbol(s)), TS::Start),

                // A literal glued to a following token tree degrades to a symbol.
                (TS::Gnd(s, _), Some(IT::TokenTree(tt))) => (None, TS::Symbol(s + tt.to_string().as_str())),
                (TS::Gnd(s, typ), Some(token)) => (Some(Self::gnd_to_token(s, typ)), TS::Token(token)),
                (TS::Gnd(s, typ), None) => (Some(Self::gnd_to_token(s, typ)), TS::Start),

                // Symbols and variables absorb every adjacent token tree.
                (TS::Symbol(s), Some(IT::TokenTree(tt))) => (None, TS::Symbol(s + tt.to_string().as_str())),
                (TS::Symbol(s), Some(t)) => (Some(T::Symbol(s)), TS::Token(t)),
                (TS::Symbol(s), None) => (Some(T::Symbol(s)), TS::Start),

                (TS::Variable(s), Some(IT::TokenTree(tt))) => (None, TS::Variable(s + tt.to_string().as_str())),
                (TS::Variable(s), Some(t)) => (Some(T::Variable(s)), TS::Token(t)),
                (TS::Variable(s), None) => (Some(T::Variable(s)), TS::Start),

                // A space with nothing being accumulated is skipped.
                (TS::Start, Some(IT::Space)) => (None, TS::Start),

                // `TS::Token` is always unpacked into `TS::Start` above.
                (TS::Token(_), _) => unreachable!(),
            };

            self.state = state;
            if let Some(token) = token {
                return token
            }
        }
    }

    /// Converts the accumulated literal text `s` into the token matching its
    /// kind `t`. Boolean text is lower-cased.
    fn gnd_to_token(s: String, t: GndType) -> Token {
        match t {
            GndType::Int => Token::Int(s),
            GndType::Float => Token::Float(s),
            GndType::Str => Token::Str(s),
            GndType::Bool => Token::Bool(s.to_lowercase()),
        }
    }
}
250
/// Token stream builder shared by the printers.
struct PrinterBase {
    /// Stack of groups being built: each entry is the delimiter of an opened
    /// group and the tokens collected for it so far. The bottom entry is the
    /// top-level output stream.
    output: Vec<(Delimiter, TokenStream)>,
}
254
255impl Default for PrinterBase {
256    fn default() -> Self {
257        Self {
258            output: vec![(Delimiter::None, TokenStream::new())],
259        }
260    }
261}
262
263impl PrinterBase {
264    fn get_token_stream(&mut self) -> TokenStream {
265        assert!(self.output.len() == 1, "Unbalanced group");
266        self.output.pop().unwrap().1
267    }
268
269    fn push(&mut self, tt: TokenTree) -> &mut Self {
270        let (_, last) = self.output.last_mut().unwrap();
271        last.extend([tt].into_iter());
272        self
273    }
274
275    fn ident(&mut self, name: &str) -> &mut Self {
276        self.push(TokenTree::Ident(Ident::new(name, Span::call_site())))
277    }
278
279    fn punct(&mut self, chars: &str) -> &mut Self {
280        assert!(!chars.is_empty(), "Empty punct");
281        let mut chars = chars.chars().peekable();
282        let mut c = chars.next().unwrap();
283        while chars.peek().is_some()  {
284            let _ = self.push(TokenTree::Punct(Punct::new(c, Spacing::Joint)));
285            c = chars.next().unwrap();
286        }
287        self.push(TokenTree::Punct(Punct::new(c, Spacing::Alone)))
288    }
289
290    fn group(&mut self, d: char) -> &mut Self {
291        let (open, delimiter) = match d {
292            '(' => (true, Delimiter::Parenthesis),
293            '{' => (true, Delimiter::Brace),
294            '[' => (true, Delimiter::Bracket),
295            ')' => (false, Delimiter::Parenthesis),
296            '}' => (false, Delimiter::Brace),
297            ']' => (false, Delimiter::Bracket),
298            _ => panic!("Unexpected delimiter: {}", d),
299        };
300        if open {
301            self.output.push((delimiter, TokenStream::new()));
302            self
303        } else {
304            assert!(self.output.len() > 1, "Unbalanced group");
305            let (d, stream) = self.output.pop().unwrap();
306            assert!(d == delimiter, "Closing delimiter {:?} is not equal to opening one {:?}", delimiter, d);
307            self.push(TokenTree::Group(Group::new(delimiter, stream)))
308        }
309    }
310
311    fn literal(&mut self, lit: Literal) -> &mut Self {
312        self.push(TokenTree::Literal(lit))
313    }
314
315    fn string(&mut self, text: &str) -> &mut Self {
316        self.push(TokenTree::Literal(Literal::string(text)))
317    }
318
319    fn bool(&mut self, b: &str) {
320        self.ident("hyperon_atom").punct("::").ident("Atom").punct("::").ident("gnd").group('(')
321            .ident("hyperon_atom").punct("::").ident("gnd").punct("::").ident("bool").punct("::").ident("Bool").group('(')
322            .ident(b)
323            .group(')').group(')');
324    }
325
326    fn integer(&mut self, n: i64) {
327        self.ident("hyperon_atom").punct("::").ident("Atom").punct("::").ident("gnd").group('(')
328            .ident("hyperon_atom").punct("::").ident("gnd").punct("::").ident("number").punct("::").ident("Number").punct("::").ident("Integer").group('(')
329            .literal(Literal::i64_suffixed(n))
330            .group(')').group(')');
331    }
332
333    fn float(&mut self, f: f64) {
334        self.ident("hyperon_atom").punct("::").ident("Atom").punct("::").ident("gnd").group('(')
335            .ident("hyperon_atom").punct("::").ident("gnd").punct("::").ident("number").punct("::").ident("Number").punct("::").ident("Float").group('(')
336            .literal(Literal::f64_suffixed(f))
337            .group(')').group(')');
338    }
339
340    fn str(&mut self, s: &str) {
341        self.ident("hyperon_atom").punct("::").ident("Atom").punct("::").ident("gnd").group('(')
342            .ident("hyperon_atom").punct("::").ident("gnd").punct("::").ident("str").punct("::").ident("Str").punct("::").ident("from_str").group('(')
343            .literal(Literal::string(s))
344            .group(')').group(')');
345    }
346
347    fn gnd(&mut self, g: Group) {
348        self.group('(').punct("&&").ident("hyperon_atom").punct("::").ident("Wrap").group('(')
349            .push(TokenTree::Group(Group::new(Delimiter::Parenthesis, g.stream())))
350            .group(')').group(')')
351            .punct(".").ident("to_atom").group('(').group(')');
352    }
353
354    fn expr_delimiter(&mut self) {
355        self.punct(",");
356    }
357}
358
/// Sink used by the converter to emit Rust code constructing atoms.
trait Printer {
    /// Emits code for a symbol atom named `name`.
    fn symbol(&mut self, name: &str);
    /// Emits code for a variable atom named `name`.
    fn variable(&mut self, name: &str);
    /// Emits code for a grounded boolean; `b` is the text of the value.
    fn bool(&mut self, b: &str);
    /// Emits code for a grounded integer.
    fn integer(&mut self, n: i64);
    /// Emits code for a grounded float.
    fn float(&mut self, f: f64);
    /// Emits code for a grounded string with content `s`.
    fn str(&mut self, s: &str);
    /// Emits code for a grounded value given as a brace delimited group.
    fn gnd(&mut self, g: Group);
    /// Emits the beginning of an expression atom.
    fn expr_start(&mut self);
    /// Emits the separator between two elements of an expression.
    fn expr_delimiter(&mut self);
    /// Emits the end of an expression atom.
    fn expr_end(&mut self);
    /// Returns the generated code.
    fn get_token_stream(&mut self) -> TokenStream;
}
372
/// Printer used by the `metta` macro: supports grounded values and builds
/// names via `UniqueString::from`.
#[repr(transparent)]
#[derive(Default)]
struct PrinterMut {
    /// Underlying token stream builder.
    base: PrinterBase,
}
378
379impl Printer for PrinterMut {
380    fn symbol(&mut self, name: &str) {
381        self.base.ident("hyperon_atom").punct("::").ident("Atom").punct("::").ident("Symbol").group('(')
382            .ident("hyperon_atom").punct("::").ident("SymbolAtom").punct("::").ident("new").group('(')
383            .ident("hyperon_common").punct("::").ident("unique_string").punct("::").ident("UniqueString").punct("::").ident("from").group('(')
384            .string(name)
385            .group(')').group(')').group(')');
386    }
387
388    fn variable(&mut self, name: &str) {
389        self.base.ident("hyperon_atom").punct("::").ident("Atom").punct("::").ident("Variable").group('(')
390            .ident("hyperon_atom").punct("::").ident("VariableAtom").punct("::").ident("new").group('(')
391            .ident("hyperon_common").punct("::").ident("unique_string").punct("::").ident("UniqueString").punct("::").ident("from").group('(')
392            .string(name)
393            .group(')').group(')').group(')');
394    }
395
396    fn expr_start(&mut self) { 
397        self.base.ident("hyperon_atom").punct("::").ident("Atom").punct("::").ident("Expression").group('(')
398            .ident("hyperon_atom").punct("::").ident("ExpressionAtom").punct("::").ident("new").group('(')
399            .ident("hyperon_common").punct("::").ident("collections").punct("::").ident("CowArray").punct("::").ident("from").group('(')
400            .group('[');
401    }
402
403    fn expr_end(&mut self) {
404        self.base.group(']')
405            .group(')').group(')').group(')');
406    }
407
408    fn bool(&mut self, b: &str) { self.base.bool(b) }
409    fn integer(&mut self, n: i64) { self.base.integer(n) }
410    fn float(&mut self, f: f64) { self.base.float(f) }
411    fn str(&mut self, s: &str) { self.base.str(s) }
412    fn gnd(&mut self, g: Group) { self.base.gnd(g) }
413    fn expr_delimiter(&mut self) { self.base.expr_delimiter() }
414    fn get_token_stream(&mut self) -> TokenStream { self.base.get_token_stream() }
415}
416
/// Printer used by the `metta_const` macro: emits the `Const`/`Literal`
/// constructors and rejects grounded values.
#[repr(transparent)]
#[derive(Default)]
struct PrinterConst {
    /// Underlying token stream builder.
    base: PrinterBase,
}
422
423impl Printer for PrinterConst {
424    fn symbol(&mut self, name: &str) {
425        self.base.ident("hyperon_atom").punct("::").ident("Atom").punct("::").ident("Symbol").group('(')
426            .ident("hyperon_atom").punct("::").ident("SymbolAtom").punct("::").ident("new").group('(')
427            .ident("hyperon_common").punct("::").ident("unique_string").punct("::").ident("UniqueString").punct("::").ident("Const").group('(')
428            .string(name)
429            .group(')').group(')').group(')');
430    }
431
432    fn variable(&mut self, name: &str) {
433        self.base.ident("hyperon_atom").punct("::").ident("Atom").punct("::").ident("Variable").group('(')
434            .ident("hyperon_atom").punct("::").ident("VariableAtom").punct("::").ident("new_const").group('(')
435            .ident("hyperon_common").punct("::").ident("unique_string").punct("::").ident("UniqueString").punct("::").ident("Const").group('(')
436            .string(name)
437            .group(')').group(')').group(')');
438    }
439
440    fn expr_start(&mut self) {
441        self.base.ident("hyperon_atom").punct("::").ident("Atom").punct("::").ident("Expression").group('(')
442            .ident("hyperon_atom").punct("::").ident("ExpressionAtom").punct("::").ident("new").group('(')
443            .ident("hyperon_common").punct("::").ident("collections").punct("::").ident("CowArray").punct("::").ident("Literal").group('(')
444            .ident("const").group('{').punct("&").group('[');
445    }
446
447    fn expr_end(&mut self) {
448        self.base.group(']').group('}')
449            .group(')').group(')').group(')');
450    }
451
452    fn bool(&mut self, _b: &str) { panic!("Grounded atoms cannot be instantiated as const") }
453    fn integer(&mut self, _n: i64) { panic!("Grounded atoms cannot be instantiated as const") }
454    fn float(&mut self, _f: f64) { panic!("Grounded atoms cannot be instantiated as const") }
455    fn str(&mut self, _s: &str) { panic!("Grounded atoms cannot be instantiated as const") }
456    fn gnd(&mut self, _g: Group) { panic!("Grounded atoms cannot be instantiated as const") }
457    fn expr_delimiter(&mut self) { self.base.expr_delimiter() }
458    fn get_token_stream(&mut self) -> TokenStream { self.base.get_token_stream() }
459}
460
/// State of the converter while walking the token sequence.
#[derive(Debug, PartialEq, Clone, Copy)]
enum State {
    /// Outside of any expression.
    Start,
    /// Right after the opening parenthesis of an expression at the given
    /// nesting depth; no element has been emitted for it yet.
    ExprStart(usize),
    /// Inside an expression at the given nesting depth after at least one
    /// element, so the next element must be preceded by a delimiter.
    Expression(usize),
    /// The whole input has been consumed.
    Final,
}
468
/// Translates the tokens of a MeTTa expression into Rust code via a [`Printer`].
struct MettaConverter<P: Printer> {
    /// Current position in the expression structure.
    state: State,
    /// Source of tokens.
    input: Tokenizer,
    /// Sink which receives the generated code.
    output: P,
}
474
475impl<P: Printer> MettaConverter<P> {
476    fn new(input: TokenStream, output: P) -> Self {
477        Self{
478            state: State::Start,
479            input: Tokenizer::new(input),
480            output,
481        }
482    }
483
484    fn run(&mut self) -> TokenStream {
485        loop {
486            if self.state == State::Final {
487                break
488            }
489            self.next_state();
490        }
491        self.output.get_token_stream()
492    }
493
494    fn next_state(&mut self) {
495        let token = self.input.next();
496
497        if matches!(token, Token::End) {
498            if !matches!(self.state, State::Start) {
499                panic!("Unexpected expression end");
500            }
501            self.state = State::Final;
502            return;
503        }
504
505        if matches!(self.state, State::Expression(_))
506            && !matches!(token, Token::ExprEnd) {
507                self.output.expr_delimiter();
508        }
509
510        let mut next_state = self.state;
511
512        match token {
513            Token::Symbol(s) => self.output.symbol(&s),
514            Token::Variable(v) => self.output.variable(&v),
515            Token::Int(s) => self.output.integer(s.parse::<i64>().unwrap()),
516            Token::Float(s) => self.output.float(s.parse::<f64>().unwrap()),
517            Token::Str(s) => self.output.str(&s[1..s.len() - 1]),
518            Token::Bool(s) => self.output.bool(&s),
519            Token::Gnd(g) => self.output.gnd(g),
520
521            Token::ExprStart => {
522                self.output.expr_start();
523                next_state = match self.state {
524                    State::Start => State::ExprStart(1),
525                    State::ExprStart(n) => State::ExprStart(n + 1),
526                    State::Expression(n) => State::ExprStart(n + 1),
527                    State::Final => unreachable!(),
528                };
529            },
530            Token::ExprEnd => {
531                next_state = match self.state {
532                    State::Start => panic!("Unexpected end of expression"),
533                    State::ExprStart(1) => State::Start,
534                    State::ExprStart(n) => State::Expression(n - 1),
535                    State::Expression(1) => State::Start,
536                    State::Expression(n) => State::Expression(n - 1),
537                    State::Final => unreachable!(),
538                };
539                self.output.expr_end();
540            },
541
542            Token::End => unreachable!(),
543        }
544
545        if let State::ExprStart(n) = self.state {
546            if next_state == self.state {
547                next_state = State::Expression(n);
548            }
549        }
550        self.state = next_state;
551    }
552}