diff --git a/compiler/compiler.go b/compiler/compiler.go index eb7615c..d87a544 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -5,26 +5,64 @@ import ( "git.sharkk.net/Sharkk/Mako/vm" ) -// Compiler converts AST to bytecode +// Compile converts AST to bytecode func Compile(program *parser.Program) *vm.Bytecode { c := &compiler{ constants: []any{}, instructions: []vm.Instruction{}, + scopes: []scope{}, } + // Start in global scope + c.enterScope() + for _, stmt := range program.Statements { c.compileStatement(stmt) } + c.exitScope() + return &vm.Bytecode{ Constants: c.constants, Instructions: c.instructions, } } +type scope struct { + variables map[string]bool +} + type compiler struct { constants []any instructions []vm.Instruction + scopes []scope +} + +func (c *compiler) enterScope() { + c.scopes = append(c.scopes, scope{ + variables: make(map[string]bool), + }) + c.emit(vm.OpEnterScope, 0) +} + +func (c *compiler) exitScope() { + c.scopes = c.scopes[:len(c.scopes)-1] + c.emit(vm.OpExitScope, 0) +} + +func (c *compiler) declareVariable(name string) { + if len(c.scopes) > 0 { + c.scopes[len(c.scopes)-1].variables[name] = true + } +} + +func (c *compiler) isLocalVariable(name string) bool { + for i := len(c.scopes) - 1; i >= 0; i-- { + if _, ok := c.scopes[i].variables[name]; ok { + return true + } + } + return false } func (c *compiler) compileStatement(stmt parser.Statement) { @@ -32,10 +70,31 @@ func (c *compiler) compileStatement(stmt parser.Statement) { case *parser.VariableStatement: c.compileExpression(s.Value) nameIndex := c.addConstant(s.Name.Value) - c.emit(vm.OpSetGlobal, nameIndex) + + // Use SetGlobal for top-level variables to persist between REPL lines + if len(c.scopes) <= 1 { + c.emit(vm.OpSetGlobal, nameIndex) + } else { + c.declareVariable(s.Name.Value) + c.emit(vm.OpSetLocal, nameIndex) + } + + case *parser.IndexAssignmentStatement: + c.compileExpression(s.Left) + c.compileExpression(s.Index) + c.compileExpression(s.Value) + c.emit(vm.OpSetIndex, 0) + case *parser.EchoStatement: c.compileExpression(s.Value) c.emit(vm.OpEcho, 0) + + case *parser.BlockStatement: + c.enterScope() + for _, blockStmt := range s.Statements { + c.compileStatement(blockStmt) + } + c.exitScope() } } @@ -44,12 +103,47 @@ func (c *compiler) compileExpression(expr parser.Expression) { case *parser.StringLiteral: constIndex := c.addConstant(e.Value) c.emit(vm.OpConstant, constIndex) + case *parser.NumberLiteral: constIndex := c.addConstant(e.Value) c.emit(vm.OpConstant, constIndex) + case *parser.Identifier: nameIndex := c.addConstant(e.Value) - c.emit(vm.OpGetGlobal, nameIndex) + + // Check if it's a local variable first + if c.isLocalVariable(e.Value) { + c.emit(vm.OpGetLocal, nameIndex) + } else { + // Otherwise treat as global + c.emit(vm.OpGetGlobal, nameIndex) + } + + case *parser.TableLiteral: + c.emit(vm.OpNewTable, 0) + + for key, value := range e.Pairs { + c.emit(vm.OpDup, 0) + + // Special handling for identifier keys in tables + if ident, ok := key.(*parser.Identifier); ok { + // Treat identifiers as string literals in table keys + strIndex := c.addConstant(ident.Value) + c.emit(vm.OpConstant, strIndex) + } else { + // For other expressions, compile normally + c.compileExpression(key) + } + + c.compileExpression(value) + c.emit(vm.OpSetIndex, 0) + c.emit(vm.OpPop, 0) + } + + case *parser.IndexExpression: + c.compileExpression(e.Left) + c.compileExpression(e.Index) + c.emit(vm.OpGetIndex, 0) } } diff --git a/lexer/lexer.go b/lexer/lexer.go index 75382aa..c692242 100644 --- a/lexer/lexer.go +++ b/lexer/lexer.go @@ -10,6 +10,11 @@ const ( TokenEqual TokenEcho TokenSemicolon + TokenLeftBrace + TokenRightBrace + TokenLeftBracket + TokenRightBracket + TokenComma ) type Token struct { @@ -55,6 +60,16 @@ func (l *Lexer) NextToken() Token { return tok case 0: tok = Token{Type: TokenEOF, Value: ""} + case '{': + tok = Token{Type: TokenLeftBrace, Value: "{"} + case '}': + tok = Token{Type: TokenRightBrace, Value: "}"} + case '[': + tok = Token{Type: TokenLeftBracket, Value: "["} + case ']': + tok = Token{Type: TokenRightBracket, Value: "]"} + case ',': + tok = Token{Type: TokenComma, Value: ","} default: if isLetter(l.ch) { tok.Value = l.readIdentifier() diff --git a/main.go b/main.go index faad6c4..80fe43f 100644 --- a/main.go +++ b/main.go @@ -16,7 +16,7 @@ func main() { scanner := bufio.NewScanner(os.Stdin) virtualMachine := vm.New() - fmt.Println("LuaGo Interpreter (type 'exit' to quit)") + fmt.Println("Mako REPL (type 'exit' to quit)") for { fmt.Print(">> ") if !scanner.Scan() { diff --git a/parser/ast.go b/parser/ast.go index dd3780c..9b7c66d 100644 --- a/parser/ast.go +++ b/parser/ast.go @@ -67,3 +67,42 @@ type NumberLiteral struct { func (nl *NumberLiteral) expressionNode() {} func (nl *NumberLiteral) TokenLiteral() string { return nl.Token.Value } + +// TableLiteral represents a table: {key1 = val1, key2 = val2} +type TableLiteral struct { + Token lexer.Token + Pairs map[Expression]Expression +} + +func (tl *TableLiteral) expressionNode() {} +func (tl *TableLiteral) TokenLiteral() string { return tl.Token.Value } + +// IndexExpression represents table access: table[key] +type IndexExpression struct { + Token lexer.Token + Left Expression + Index Expression +} + +func (ie *IndexExpression) expressionNode() {} +func (ie *IndexExpression) TokenLiteral() string { return ie.Token.Value } + +// IndexAssignmentStatement represents: table[key] = value +type IndexAssignmentStatement struct { + Token lexer.Token + Left Expression + Index Expression + Value Expression +} + +func (ias *IndexAssignmentStatement) statementNode() {} +func (ias *IndexAssignmentStatement) TokenLiteral() string { return ias.Token.Value } + +// BlockStatement represents a block of code enclosed in braces +type BlockStatement struct { + Token lexer.Token + Statements []Statement +} + +func (bs *BlockStatement) statementNode() {} +func (bs *BlockStatement) TokenLiteral() string { return bs.Token.Value } diff --git a/parser/parser.go b/parser/parser.go index 459f9bd..7492f04 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -49,13 +49,34 @@ func (p *Parser) parseStatement() Statement { case lexer.TokenIdentifier: if p.peekToken.Type == lexer.TokenEqual { return p.parseVariableStatement() + } else if p.peekToken.Type == lexer.TokenLeftBracket { + return p.parseIndexAssignmentStatement() } case lexer.TokenEcho: return p.parseEchoStatement() + case lexer.TokenLeftBrace: + return p.parseBlockStatement() } return nil } +func (p *Parser) parseBlockStatement() *BlockStatement { + block := &BlockStatement{Token: p.curToken} + block.Statements = []Statement{} + + p.nextToken() // Skip '{' + + for p.curToken.Type != lexer.TokenRightBrace && p.curToken.Type != lexer.TokenEOF { + stmt := p.parseStatement() + if stmt != nil { + block.Statements = append(block.Statements, stmt) + } + p.nextToken() + } + + return block +} + func (p *Parser) parseVariableStatement() *VariableStatement { stmt := &VariableStatement{Token: p.curToken} @@ -64,18 +85,7 @@ func (p *Parser) parseVariableStatement() *VariableStatement { p.nextToken() // Skip identifier p.nextToken() // Skip = - switch p.curToken.Type { - case lexer.TokenString: - stmt.Value = &StringLiteral{Token: p.curToken, Value: p.curToken.Value} - case lexer.TokenNumber: - num, err := strconv.ParseFloat(p.curToken.Value, 64) - if err != nil { - p.errors = append(p.errors, fmt.Sprintf("could not parse %q as float", p.curToken.Value)) - } - stmt.Value = &NumberLiteral{Token: p.curToken, Value: num} - case lexer.TokenIdentifier: - stmt.Value = &Identifier{Token: p.curToken, Value: p.curToken.Value} - } + stmt.Value = p.parseExpression() if p.peekToken.Type == lexer.TokenSemicolon { p.nextToken() @@ -89,18 +99,7 @@ func (p *Parser) parseEchoStatement() *EchoStatement { p.nextToken() - switch p.curToken.Type { - case lexer.TokenString: - stmt.Value = &StringLiteral{Token: p.curToken, Value: p.curToken.Value} - case lexer.TokenNumber: - num, err := strconv.ParseFloat(p.curToken.Value, 64) - if err != nil { - p.errors = append(p.errors, fmt.Sprintf("could not parse %q as float", p.curToken.Value)) - } - stmt.Value = &NumberLiteral{Token: p.curToken, Value: num} - case lexer.TokenIdentifier: - stmt.Value = &Identifier{Token: p.curToken, Value: p.curToken.Value} - } + stmt.Value = p.parseExpression() if p.peekToken.Type == lexer.TokenSemicolon { p.nextToken() @@ -108,3 +107,141 @@ func (p *Parser) parseEchoStatement() *EchoStatement { return stmt } + +func (p *Parser) parseIndexAssignmentStatement() *IndexAssignmentStatement { + stmt := &IndexAssignmentStatement{ + Token: p.curToken, + Left: &Identifier{Token: p.curToken, Value: p.curToken.Value}, + } + + p.nextToken() // Skip identifier + p.nextToken() // Skip '[' + + stmt.Index = p.parseExpression() + + if p.peekToken.Type != lexer.TokenRightBracket { + p.errors = append(p.errors, "expected ] after index expression") + return stmt + } + + p.nextToken() // Skip index + p.nextToken() // Skip ']' + + // Fix: Check current token, not peek token + if p.curToken.Type != lexer.TokenEqual { + p.errors = append(p.errors, "expected = after index expression") + return stmt + } + + p.nextToken() // Skip = + + stmt.Value = p.parseExpression() + + if p.peekToken.Type == lexer.TokenSemicolon { + p.nextToken() + } + + return stmt +} + +func (p *Parser) parseExpression() Expression { + switch p.curToken.Type { + case lexer.TokenString: + return &StringLiteral{Token: p.curToken, Value: p.curToken.Value} + case lexer.TokenNumber: + num, err := strconv.ParseFloat(p.curToken.Value, 64) + if err != nil { + p.errors = append(p.errors, fmt.Sprintf("could not parse %q as float", p.curToken.Value)) + } + return &NumberLiteral{Token: p.curToken, Value: num} + case lexer.TokenIdentifier: + if p.peekToken.Type == lexer.TokenLeftBracket { + return p.parseIndexExpression() + } + return &Identifier{Token: p.curToken, Value: p.curToken.Value} + case lexer.TokenLeftBrace: + return p.parseTableLiteral() + } + return nil +} + +func (p *Parser) parseTableLiteral() *TableLiteral { + table := &TableLiteral{ + Token: p.curToken, + Pairs: make(map[Expression]Expression), + } + + p.nextToken() // Skip '{' + + if p.curToken.Type == lexer.TokenRightBrace { + return table // Empty table + } + + // Parse the first key-value pair + key := p.parseExpression() + + if p.peekToken.Type != lexer.TokenEqual { + p.errors = append(p.errors, "expected = after table key") + return table + } + + p.nextToken() // Skip key + p.nextToken() // Skip = + + value := p.parseExpression() + table.Pairs[key] = value + + p.nextToken() // Skip value + + // Parse remaining key-value pairs + for p.curToken.Type == lexer.TokenComma { + p.nextToken() // Skip comma + + if p.curToken.Type == lexer.TokenRightBrace { + break // Allow trailing comma + } + + key = p.parseExpression() + + if p.peekToken.Type != lexer.TokenEqual { + p.errors = append(p.errors, "expected = after table key") + return table + } + + p.nextToken() // Skip key + p.nextToken() // Skip = + + value = p.parseExpression() + table.Pairs[key] = value + + p.nextToken() // Skip value + } + + if p.curToken.Type != lexer.TokenRightBrace { + p.errors = append(p.errors, "expected } or , after table entry") + } + + return table +} + +func (p *Parser) parseIndexExpression() *IndexExpression { + exp := &IndexExpression{ + Token: p.curToken, + Left: &Identifier{Token: p.curToken, Value: p.curToken.Value}, + } + + p.nextToken() // Skip identifier + p.nextToken() // Skip '[' + + exp.Index = p.parseExpression() + + if p.peekToken.Type != lexer.TokenRightBracket { + p.errors = append(p.errors, "expected ] after index expression") + return exp + } + + p.nextToken() // Skip index + p.nextToken() // Skip ']' + + return exp +} diff --git a/types/types.go b/types/types.go index 812d015..78e6210 100644 --- a/types/types.go +++ b/types/types.go @@ -7,6 +7,7 @@ const ( TypeNumber TypeString TypeBoolean + TypeTable // New type for tables ) type Value struct { @@ -14,6 +15,10 @@ type Value struct { Data any } +func NewNull() Value { + return Value{Type: TypeNull, Data: nil} +} + func NewString(s string) Value { return Value{Type: TypeString, Data: s} } @@ -22,10 +27,86 @@ func NewNumber(n float64) Value { return Value{Type: TypeNumber, Data: n} } -func NewBoolean(b bool) Value { - return Value{Type: TypeBoolean, Data: b} +// TableEntry maintains insertion order +type TableEntry struct { + Key Value + Value Value } -func NewNull() Value { - return Value{Type: TypeNull, Data: nil} +// Table with ordered entries +type Table struct { + Entries []TableEntry // Preserves insertion order + HashMap map[string]int // Fast lookups for string keys + NumMap map[float64]int // Fast lookups for number keys + BoolMap map[bool]int // Fast lookups for boolean keys +} + +func NewTable() *Table { + return &Table{ + Entries: []TableEntry{}, + HashMap: make(map[string]int), + NumMap: make(map[float64]int), + BoolMap: make(map[bool]int), + } +} + +func NewTableValue() Value { + return Value{Type: TypeTable, Data: NewTable()} +} + +// TableSet preserves insertion order +func (t *Table) Set(key, value Value) { + idx := -1 + + switch key.Type { + case TypeString: + if i, ok := t.HashMap[key.Data.(string)]; ok { + idx = i + } + case TypeNumber: + if i, ok := t.NumMap[key.Data.(float64)]; ok { + idx = i + } + case TypeBoolean: + if i, ok := t.BoolMap[key.Data.(bool)]; ok { + idx = i + } + } + + if idx >= 0 { + // Update existing entry + t.Entries[idx].Value = value + } else { + // Add new entry + t.Entries = append(t.Entries, TableEntry{Key: key, Value: value}) + idx = len(t.Entries) - 1 + + // Update lookup maps + switch key.Type { + case TypeString: + t.HashMap[key.Data.(string)] = idx + case TypeNumber: + t.NumMap[key.Data.(float64)] = idx + case TypeBoolean: + t.BoolMap[key.Data.(bool)] = idx + } + } +} + +func (t *Table) Get(key Value) Value { + switch key.Type { + case TypeString: + if i, ok := t.HashMap[key.Data.(string)]; ok { + return t.Entries[i].Value + } + case TypeNumber: + if i, ok := t.NumMap[key.Data.(float64)]; ok { + return t.Entries[i].Value + } + case TypeBoolean: + if i, ok := t.BoolMap[key.Data.(bool)]; ok { + return t.Entries[i].Value + } + } + return NewNull() } diff --git a/vm/vm.go b/vm/vm.go index 2ffe12c..c7e3c5f 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -9,10 +9,19 @@ import ( type Opcode byte const ( - OpConstant Opcode = iota - OpSetGlobal - OpGetGlobal + OpConstant Opcode = iota + OpSetLocal // Set local variable + OpGetLocal // Get local variable + OpSetGlobal // Set global variable + OpGetGlobal // Get global variable OpEcho + OpNewTable + OpSetIndex + OpGetIndex + OpDup + OpPop + OpEnterScope // New opcode for entering a block + OpExitScope // New opcode for exiting a block ) type Instruction struct { @@ -25,9 +34,15 @@ type Bytecode struct { Instructions []Instruction } +// Scope represents a lexical scope +type Scope struct { + Variables map[string]types.Value +} + type VM struct { constants []any globals map[string]types.Value + scopes []Scope // Stack of local scopes stack []types.Value sp int // Stack pointer } @@ -35,7 +50,8 @@ type VM struct { func New() *VM { return &VM{ globals: make(map[string]types.Value), - stack: make([]types.Value, 1024), // Fixed stack size for now + scopes: []Scope{}, // Initially no scopes + stack: make([]types.Value, 1024), sp: 0, } } @@ -48,7 +64,6 @@ func (vm *VM) Run(bytecode *Bytecode) { switch instruction.Opcode { case OpConstant: - // Push constant to stack constIndex := instruction.Operand constant := vm.constants[constIndex] @@ -59,15 +74,49 @@ func (vm *VM) Run(bytecode *Bytecode) { vm.push(types.NewNumber(v)) } + case OpSetLocal: + constIndex := instruction.Operand + name := vm.constants[constIndex].(string) + value := vm.pop() + + // Set in current scope if it exists + if len(vm.scopes) > 0 { + vm.scopes[len(vm.scopes)-1].Variables[name] = value + } else { + // No scope, set as global + vm.globals[name] = value + } + + case OpGetLocal: + constIndex := instruction.Operand + name := vm.constants[constIndex].(string) + + // Check local scopes from innermost to outermost + found := false + for i := len(vm.scopes) - 1; i >= 0; i-- { + if val, ok := vm.scopes[i].Variables[name]; ok { + vm.push(val) + found = true + break + } + } + + // If not found in locals, check globals + if !found { + if val, ok := vm.globals[name]; ok { + vm.push(val) + } else { + vm.push(types.NewNull()) + } + } + case OpSetGlobal: - // Set global variable constIndex := instruction.Operand name := vm.constants[constIndex].(string) value := vm.pop() vm.globals[name] = value case OpGetGlobal: - // Get global variable constIndex := instruction.Operand name := vm.constants[constIndex].(string) if val, ok := vm.globals[name]; ok { @@ -76,8 +125,59 @@ func (vm *VM) Run(bytecode *Bytecode) { vm.push(types.NewNull()) } + case OpEnterScope: + // Push a new scope + vm.scopes = append(vm.scopes, Scope{ + Variables: make(map[string]types.Value), + }) + + case OpExitScope: + // Pop the current scope + if len(vm.scopes) > 0 { + vm.scopes = vm.scopes[:len(vm.scopes)-1] + } + + case OpNewTable: + vm.push(types.NewTableValue()) + + case OpSetIndex: + value := vm.pop() + key := vm.pop() + tableVal := vm.pop() + + if tableVal.Type != types.TypeTable { + fmt.Println("Error: attempt to index non-table value") + vm.push(types.NewNull()) + continue + } + + table := tableVal.Data.(*types.Table) + table.Set(key, value) + vm.push(tableVal) + + case OpGetIndex: + key := vm.pop() + tableVal := vm.pop() + + if tableVal.Type != types.TypeTable { + fmt.Println("Error: attempt to index non-table value") + vm.push(types.NewNull()) + continue + } + + table := tableVal.Data.(*types.Table) + value := table.Get(key) + vm.push(value) + + case OpDup: + if vm.sp > 0 { + vm.push(vm.stack[vm.sp-1]) + } + + case OpPop: + vm.pop() + case OpEcho: - // Print value value := vm.pop() switch value.Type { case types.TypeString: @@ -88,6 +188,8 @@ func (vm *VM) Run(bytecode *Bytecode) { fmt.Println(value.Data.(bool)) case types.TypeNull: fmt.Println("null") + case types.TypeTable: + fmt.Println(vm.formatTable(value.Data.(*types.Table))) } } } @@ -102,3 +204,32 @@ func (vm *VM) pop() types.Value { vm.sp-- return vm.stack[vm.sp] } + +func (vm *VM) formatTable(table *types.Table) string { + result := "{" + for i, entry := range table.Entries { + result += vm.formatValue(entry.Key) + " = " + vm.formatValue(entry.Value) + if i < len(table.Entries)-1 { + result += ", " + } + } + result += "}" + return result +} + +func (vm *VM) formatValue(value types.Value) string { + switch value.Type { + case types.TypeString: + return "\"" + value.Data.(string) + "\"" + case types.TypeNumber: + return fmt.Sprintf("%v", value.Data.(float64)) + case types.TypeBoolean: + return fmt.Sprintf("%v", value.Data.(bool)) + case types.TypeNull: + return "null" + case types.TypeTable: + return vm.formatTable(value.Data.(*types.Table)) + default: + return "unknown" + } +}