diff --git a/dub.sdl b/dub.sdl index bfd34b7..5808c6a 100644 --- a/dub.sdl +++ b/dub.sdl @@ -8,6 +8,8 @@ dependency "taggedalgebraic" version="~>0.11.23" targetType "executable" buildRequirements "requireBoundsCheck" "requireContracts" +versions "LoxConcatNonStrings" "LoxExtraNativeFuncs" + configuration "jlox" { sourcePaths "src/jlox" } diff --git a/src/jlox/expr.d b/src/jlox/expr.d index ddf4df5..e942d4d 100644 --- a/src/jlox/expr.d +++ b/src/jlox/expr.d @@ -14,6 +14,7 @@ abstract class Expr{ interface Visitor(R){ R visit(Assign expr); R visit(Binary expr); + R visit(Call expr); R visit(Grouping expr); R visit(Literal expr); R visit(Logical expr); @@ -40,6 +41,12 @@ abstract class Expr{ Expr right; mixin defCtorAndAccept; } + static class Call : typeof(this){ + Expr callee; + Token paren; + Expr[] arguments; + mixin defCtorAndAccept; + } static class Grouping : typeof(this){ Expr expression; mixin defCtorAndAccept; diff --git a/src/jlox/interpreter.d b/src/jlox/interpreter.d index 1fde9e2..fd61716 100644 --- a/src/jlox/interpreter.d +++ b/src/jlox/interpreter.d @@ -2,11 +2,11 @@ module jlox.interpreter; import std.conv; import std.stdio; +import std.algorithm; +import std.array; import std.format : format; import std.functional : ctEval; -import taggedalgebraic; - import jlox.expr; import jlox.stmt; import jlox.token; @@ -14,6 +14,7 @@ import jlox.tokentype; import jlox.token : TValue; import jlox.main; import jlox.environment; +import jlox.loxfunction; class RuntimeError : Exception{ const Token token; @@ -22,9 +23,38 @@ class RuntimeError : Exception{ this.token = token; } } +class Return : Exception{ + const TValue value; + this(TValue value){ + super(null); + this.value = value; + } +} class Interpreter : Stmt.Visitor!void, Expr.Visitor!TValue { - private Environment environment = new Environment(); + Environment globals = new Environment(); + private Environment environment; + this(){ + environment = globals; + + import std.datetime.stopwatch; + auto sw = StopWatch(AutoStart.yes); + globals.define("clock", TValue.cal(new class LoxCallable{ + int arity() => 0; + TValue call(Interpreter interpreter, TValue[] arguments) => TValue.dbl(sw.peek.total!"usecs" / (1000.0 * 1000.0)); + })); + + version(LoxExtraNativeFuncs){ + globals.define("sleep", TValue.cal(new class LoxCallable{ + int arity() => 1; + import core.thread.osthread; + TValue call(Interpreter interpreter, TValue[] arguments){ + Thread.sleep(dur!"usecs"(cast(long)(arguments[0].dblValue * 1000 * 1000))); + return TValue.nil(tvalueNil); + } + })); + } + } void interpret(Stmt[] statements){ try { foreach(statement; statements) @@ -33,10 +63,10 @@ class Interpreter : Stmt.Visitor!void, Expr.Visitor!TValue { Lox.runtimeError(error); } } - private void execute(Stmt stmt){ + package void execute(Stmt stmt){ stmt.accept(this); } - private void executeBlock(Stmt[] statements, Environment environment){ + package void executeBlock(Stmt[] statements, Environment environment){ Environment previous = this.environment; try { this.environment = environment; @@ -75,6 +105,10 @@ class Interpreter : Stmt.Visitor!void, Expr.Visitor!TValue { void visit(Stmt.Expression stmt){ evaluate(stmt.expression); } + void visit(Stmt.Function stmt){ + LoxFunction func = new LoxFunction(stmt); + environment.define(stmt.name.lexeme, TValue.cal(func)); + } void visit(Stmt.If stmt){ if(isTruthy(evaluate(stmt.condition))) execute(stmt.thenBranch); @@ -85,8 +119,12 @@ class Interpreter : Stmt.Visitor!void, Expr.Visitor!TValue { TValue value = evaluate(stmt.expression); writeln(tvalueToString(value)); } + void visit(Stmt.Return stmt){ + TValue value = stmt.value !is null ? evaluate(stmt.value) : TValue.nil(tvalueNil); + throw new Return(value); + } void visit(Stmt.Var stmt){ - environment.define(stmt.name.lexeme, stmt.initialiser is null ? TValue.nil(0) : evaluate(stmt.initialiser)); + environment.define(stmt.name.lexeme, stmt.initialiser is null ? TValue.nil(tvalueNil) : evaluate(stmt.initialiser)); } void visit(Stmt.While stmt){ while(isTruthy(evaluate(stmt.condition))) @@ -126,28 +164,30 @@ class Interpreter : Stmt.Visitor!void, Expr.Visitor!TValue { TValue left = evaluate(expr.left); TValue right = evaluate(expr.right); static string m(TokenType t, string op, string v, string vv){ - return q{case %s: return TValue.%s( left.%s %s right.%s );}.format(t, v, vv, op, vv); - } - with(TokenType) switch(expr.operator.type){ - static foreach(t, op; [ MINUS: "-", SLASH: "/", STAR: "*" ]){ + return q{case %s: checkNumberOperand(expr.operator, left); checkNumberOperand(expr.operator, right); + return TValue.%s( left.%s %s right.%s ); + }.format(t, v, vv, op, vv); + } + with(TokenType) switch(expr.operator.type){ + static foreach(t, op; [ MINUS: "-", SLASH: "/", STAR: "*" ]) mixin(ctEval!(m(t, op, "dbl", "dblValue"))); - } + case PLUS: if(left.isDbl && right.isDbl) return TValue.dbl(left.dblValue + right.dblValue); else if(left.isStr && right.isStr) return TValue.str(left.strValue ~ right.strValue); + version(LoxConcatNonStrings){ + if(left.isStr || right.isStr) + return TValue.str(tvalueToString(left) ~ tvalueToString(right)); + } checkNumberOperand(expr.operator, left); checkNumberOperand(expr.operator, right); - assert(0); - static foreach(t, op; [ GREATER: ">", GREATER_EQUAL: ">=", LESS: "<", LESS_EQUAL: "<=" ]){ - checkNumberOperand(expr.operator, left); - checkNumberOperand(expr.operator, right); + static foreach(t, op; [ GREATER: ">", GREATER_EQUAL: ">=", LESS: "<", LESS_EQUAL: "<=" ]) mixin(ctEval!(m(t, op, "bln", "dblValue"))); - } case BANG_EQUAL: return TValue.bln(!isEqual(left, right)); @@ -157,6 +197,16 @@ class Interpreter : Stmt.Visitor!void, Expr.Visitor!TValue { assert(0); } } + TValue visit(Expr.Call expr){ + TValue callee = evaluate(expr.callee); + if(!callee.isCal) + throw new RuntimeError(expr.paren, "Can only call functions and classes."); + auto arguments = expr.arguments.map!(a => evaluate(a)); + LoxCallable func = callee.calValue; + if(arguments.length != func.arity()) + throw new RuntimeError(expr.paren, "Expected " ~ func.arity().to!string ~ " arguments but got " ~ arguments.length.to!string ~ "."); + return func.call(this, arguments.array); + } TValue visit(Expr.Variable expr){ return environment.get(expr.name); } diff --git a/src/jlox/loxfunction.d b/src/jlox/loxfunction.d new file mode 100644 index 0000000..39a91ff --- /dev/null +++ b/src/jlox/loxfunction.d @@ -0,0 +1,28 @@ +module jlox.loxfunction; + +import std.conv; + +import jlox.token; +import jlox.stmt; +import jlox.interpreter; +import jlox.environment; +import jlox.util; + +class LoxFunction : LoxCallable{ + private Stmt.Function declaration; + mixin defaultCtor; + + int arity() => declaration.params.length.to!int; + TValue call(Interpreter interpreter, TValue[] arguments){ + Environment environment = new Environment(interpreter.globals); + foreach(i; 0 .. declaration.params.length) + environment.define(declaration.params[i].lexeme, arguments[i]); + try{ + interpreter.executeBlock(declaration.body, environment); + } catch(Return returnValue){ + return returnValue.value; + } + return TValue.nil(tvalueNil); + } +} + diff --git a/src/jlox/parser.d b/src/jlox/parser.d index 5cc5972..4f88207 100644 --- a/src/jlox/parser.d +++ b/src/jlox/parser.d @@ -8,6 +8,7 @@ import jlox.util; import jlox.expr; import jlox.main; import jlox.stmt; +import jlox.loxfunction; class Parser{ private Token[] tokens; @@ -84,11 +85,35 @@ class Parser{ consume(TokenType.SEMICOLON, "Expect ';' after value."); return new Stmt.Print(value); } + private Stmt returnStatement(){ + Token keyword = previous(); + Expr value; + if(!check(TokenType.SEMICOLON)) + value = expression(); + consume(TokenType.SEMICOLON, "Expect ';' after return value."); + return new Stmt.Return(keyword, value); + } private Stmt expressionStatement(){ Expr expr = expression(); consume(TokenType.SEMICOLON, "Expect ';' after expression."); return new Stmt.Expression(expr); } + private Stmt.Function fun(string kind){ + Token name = consume(TokenType.IDENTIFIER, "Expect " ~ kind ~ " name."); + consume(TokenType.LEFT_PAREN, "Expect '(' after " ~ kind ~ " name."); + Token[] parameters; + if(!check(TokenType.RIGHT_PAREN)){ + do{ + if(parameters.length >= 255) + error(peek(), "Can't have more than 255 parameters."); + parameters ~= consume(TokenType.IDENTIFIER, "Expect parameter name."); + } while(match(TokenType.COMMA)); + } + consume(TokenType.RIGHT_PAREN, "Expect ')' after parameters."); + consume(TokenType.LEFT_BRACE, "Expect '{' before " ~ kind ~ " body."); + Stmt[] body = block(); + return new Stmt.Function(name, parameters, body); + } private Stmt ifStatement(){ consume(TokenType.LEFT_PAREN, "Expect '(' after 'if'."); Expr condition = expression(); @@ -142,6 +167,8 @@ class Parser{ return whileStatement(); if(match(TokenType.PRINT)) return printStatement(); + if(match(TokenType.RETURN)) + return returnStatement(); if(match(TokenType.LEFT_BRACE)) return new Stmt.Block(block()); return expressionStatement(); @@ -227,7 +254,29 @@ class Parser{ Expr right = unary(); return new Expr.Unary(operator, right); } - return primary(); + return call(); + } + private Expr call(){ + Expr finishCall(Expr callee){ + Expr[] arguments; + if(!check(TokenType.RIGHT_PAREN)){ + do { + if(arguments.length >= 255) + error(peek(), "Can't have more than 255 arguments."); + arguments ~= expression(); + } while(match(TokenType.COMMA)); + } + Token paren = consume(TokenType.RIGHT_PAREN, "Expect ')' after arguments."); + return new Expr.Call(callee, paren, arguments); + } + Expr expr = primary(); + while(true){ + if(match(TokenType.LEFT_PAREN)) + expr = finishCall(expr); + else + break; + } + return expr; } private Expr primary(){ if(match(TokenType.IDENTIFIER)) @@ -237,7 +286,7 @@ class Parser{ if(match(TokenType.TRUE)) return new Expr.Literal(TValue.bln(true)); if(match(TokenType.NIL)) - return new Expr.Literal(TValue.nil(0)); + return new Expr.Literal(tvalueNil); if(match(TokenType.NUMBER, TokenType.STRING)) return new Expr.Literal(previous().literal); if(match(TokenType.LEFT_PAREN)){ @@ -257,6 +306,8 @@ class Parser{ } private Stmt declaration(){ try { + if(match(TokenType.FUN)) + return fun("function"); if(match(TokenType.VAR)) return varDeclaration(); return statement(); diff --git a/src/jlox/scanner.d b/src/jlox/scanner.d index bcb7740..5d0a412 100644 --- a/src/jlox/scanner.d +++ b/src/jlox/scanner.d @@ -30,7 +30,7 @@ class Scanner { current++; return true; } - private void addToken(TokenType type, TValue literal = TValue.nil(0)){ + private void addToken(TokenType type, TValue literal = TValue.nil(tvalueNil)){ string text = source[start .. current]; tokens ~= new Token(type, text, literal, line); } diff --git a/src/jlox/stmt.d b/src/jlox/stmt.d index f088a41..8310bd6 100644 --- a/src/jlox/stmt.d +++ b/src/jlox/stmt.d @@ -13,8 +13,10 @@ abstract class Stmt{ interface Visitor(R){ R visit(Block expr); R visit(Expression expr); + R visit(Function expr); R visit(If expr); R visit(Print expr); + R visit(Return expr); R visit(Var expr); R visit(While expr); } @@ -35,6 +37,12 @@ abstract class Stmt{ Expr expression; mixin defCtorAndAccept; } + static class Function : typeof(this){ + Token name; + Token[] params; + Stmt[] body; + mixin defCtorAndAccept; + } static class If : typeof(this){ Expr condition; Stmt thenBranch; @@ -45,6 +53,11 @@ abstract class Stmt{ Expr expression; mixin defCtorAndAccept; } + static class Return : typeof(this){ + Token keyword; + Expr value; + mixin defCtorAndAccept; + } static class Var : typeof(this){ Token name; Expr initialiser; diff --git a/src/jlox/token.d b/src/jlox/token.d index 83418a8..fb2c117 100644 --- a/src/jlox/token.d +++ b/src/jlox/token.d @@ -5,14 +5,22 @@ import std.conv; import taggedalgebraic; import jlox.tokentype; +import jlox.interpreter; +interface LoxCallable{ + int arity(); + TValue call(Interpreter interpreter, TValue[] arguments); +} private struct Value{ string str; double dbl; bool bln; - bool nil = false; + LoxCallable cal; + struct Nil{} + Nil nil; } alias TValue = TaggedUnion!Value; +immutable tvalueNil = Value.Nil(); string tvalueToString(TValue val){ final switch(val.kind){ case TValue.Kind.str: @@ -21,6 +29,8 @@ string tvalueToString(TValue val){ return val.dblValue.to!string; case TValue.Kind.bln: return val.blnValue ? "true" : "false"; + case TValue.Kind.cal: + return ""; case TValue.Kind.nil: return "nil"; } diff --git a/test/all.d b/test/all.d index 14ca0b7..d22369c 100755 --- a/test/all.d +++ b/test/all.d @@ -4,17 +4,18 @@ import std.process; import std.conv; void main(){ - string fib21(){ + string fib(uint n){ string r = ""; double a = 0; double temp; - for(double b = 1; a < 10000; b = temp + b){ + for(double b = 1; a <= n; b = temp + b){ r ~= a.to!string ~ "\n"; temp = a; a = b; } return r; } - assert([ "./lox", "test/fib21.lox" ].execute.output == fib21()); + assert([ "./lox", "test/fib21.lox" ].execute.output == fib(6765)); + assert([ "./lox", "test/fib10.lox" ].execute.output == fib(34)); } diff --git a/test/fib10.lox b/test/fib10.lox new file mode 100644 index 0000000..a991ddd --- /dev/null +++ b/test/fib10.lox @@ -0,0 +1,11 @@ + +fun fib(n){ + if(n <= 1) + return n; + return fib(n - 2) + fib(n - 1); +} + +for(var i = 0; i < 10; i = i + 1){ + print fib(i); +} +