222 lines
5.4 KiB
Go

package main
import (
"flag"
"fmt"
"log"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"
"dk/internal/control"
"dk/internal/database"
"dk/internal/helpers/email"
"dk/internal/helpers/markdown"
"dk/internal/models/users"
"dk/internal/routes"
"dk/internal/template"
sushi "git.sharkk.net/Sharkk/Sushi"
"git.sharkk.net/Sharkk/Sushi/auth"
"git.sharkk.net/Sharkk/Sushi/csrf"
"git.sharkk.net/Sharkk/Sushi/session"
"git.sharkk.net/Sharkk/Sushi/timing"
sashimi "git.sharkk.net/Sharkk/Sashimi"
)
// main is the command-line entry point. It dispatches on the first argument:
//
//	(no command)     start the server on the default port
//	-p <port> ...    start the server, parsing the leading flags
//	serve [-p port]  start the server
//	migrate [...]    hand off to handleMigrationCommand
//
// Any other first argument prints the list of available commands to stderr
// and exits with status 1.
func main() {
	var port string
	flag.StringVar(&port, "p", "3000", "Port to run server on")

	// No arguments at all: run the server with the default port.
	if len(os.Args) < 2 {
		startServer(port)
		return
	}

	// A leading flag means "no command, but with options" (e.g. `dk -p 8080`).
	// Without this branch the flag itself would be reported as an unknown
	// command even though "(no command)" is a documented way to start the
	// server. flag.Parse parses os.Args[1:] on the default flag set.
	if strings.HasPrefix(os.Args[1], "-") {
		flag.Parse()
		startServer(port)
		return
	}

	switch os.Args[1] {
	case "serve":
		// The default flag set uses flag.ExitOnError, so a bad flag ends the
		// process inside Parse; the returned error needs no handling here.
		flag.CommandLine.Parse(os.Args[2:])
		startServer(port)
	case "migrate":
		handleMigrationCommand()
	default:
		fmt.Fprintf(os.Stderr, "Unknown command: %s\n", os.Args[1])
		fmt.Fprintln(os.Stderr, "Available commands:")
		fmt.Fprintln(os.Stderr, " serve - Start the server")
		fmt.Fprintln(os.Stderr, " migrate - Run pending migrations")
		fmt.Fprintln(os.Stderr, " migrate new - Create a new migration")
		fmt.Fprintln(os.Stderr, " migrate status - Show migration status")
		fmt.Fprintln(os.Stderr, " (no command) - Start the server")
		os.Exit(1)
	}
}
// getDBPath returns the location of the database file, data/dk.db, resolved
// against the process's current working directory. The only failure mode is
// the working directory lookup, whose error is wrapped and returned.
func getDBPath() (string, error) {
	root, err := os.Getwd()
	if err != nil {
		err = fmt.Errorf("failed to get current working directory: %w", err)
		return "", err
	}

	dbFile := filepath.Join(root, "data", "dk.db")
	return dbFile, nil
}
// getMigrationsDir returns the sql directory resolved against the process's
// current working directory. The only failure mode is the working directory
// lookup, whose error is wrapped and returned.
func getMigrationsDir() (string, error) {
	root, err := os.Getwd()
	if err != nil {
		err = fmt.Errorf("failed to get current working directory: %w", err)
		return "", err
	}

	sqlDir := filepath.Join(root, "sql")
	return sqlDir, nil
}
// initDatabase resolves the database file path via getDBPath and passes it to
// database.Init. It returns the first error encountered, unwrapped.
func initDatabase() error {
	path, err := getDBPath()
	if err == nil {
		err = database.Init(path)
	}
	return err
}
// handleMigrationCommand implements the "migrate" command. It initializes the
// database, builds a migrator over the sql directory, and then dispatches on
// os.Args[2]:
//
//	(none)       migrator.Run
//	new <name>   migrator.CreateNew, with the remaining arguments joined by spaces
//	status       migrator.Status
//
// Any failure is logged and ends the process with a non-zero status; an
// unknown subcommand prints the available subcommands to stderr and exits
// with status 1.
func handleMigrationCommand() {
	if err := initDatabase(); err != nil {
		log.Fatalf("Failed to initialize database: %v", err)
	}
	// Covers the paths that return normally.
	defer database.Close()

	// log.Fatalf and os.Exit end the process without running deferred calls,
	// so every failure path below closes the database explicitly first.
	fatalf := func(format string, args ...any) {
		database.Close()
		log.Fatalf(format, args...)
	}

	migrationsDir, err := getMigrationsDir()
	if err != nil {
		fatalf("Failed to get migrations directory: %v", err)
	}
	migrator := sashimi.NewMigrator(database.DB(), migrationsDir)

	// Bare "migrate": no subcommand given.
	if len(os.Args) < 3 {
		if err := migrator.Run(); err != nil {
			fatalf("Migration failed: %v", err)
		}
		return
	}

	subcommand := os.Args[2]
	switch subcommand {
	case "new":
		if len(os.Args) < 4 {
			fatalf("Usage: migrate new <migration_name>")
		}
		// Allow unquoted multi-word names: `migrate new add users table`.
		migrationName := strings.Join(os.Args[3:], " ")
		if err := migrator.CreateNew(migrationName); err != nil {
			fatalf("Failed to create migration: %v", err)
		}
	case "status":
		if err := migrator.Status(); err != nil {
			fatalf("Failed to get migration status: %v", err)
		}
	default:
		fmt.Fprintf(os.Stderr, "Unknown migration subcommand: %s\n", subcommand)
		fmt.Fprintln(os.Stderr, "Available subcommands:")
		fmt.Fprintln(os.Stderr, " (none) - Run pending migrations")
		fmt.Fprintln(os.Stderr, " new - Create a new migration")
		fmt.Fprintln(os.Stderr, " status - Show migration status")
		database.Close()
		os.Exit(1)
	}
}
// startServer prints the startup banner and runs the server on the given
// port via start. If start returns an error, it is logged and the process
// exits through log.Fatalf.
func startServer(port string) {
	fmt.Println("Dragon Knight is starting!")

	err := start(port)
	if err == nil {
		return
	}
	log.Fatalf("Server failed: %v", err)
}
// start sets up and runs the web server on the given port, then waits for an
// interrupt or SIGTERM before saving sessions and returning.
//
// Setup happens in this order: working directory lookup, database
// initialization, control settings (data/control.json), template cache and
// the "md" template function, email, sessions (data/sessions.json), the
// middleware chain, and finally the routes. The database is closed and the
// control settings are saved by deferred calls when start returns.
//
// It returns an error only if the working directory cannot be determined or
// the database cannot be initialized.
func start(port string) error {
	// Minimum age, in seconds, of a user's LastOnline value before the
	// middleware below rewrites it.
	const lastOnlineInterval = 540 // 540 seconds = 9 minutes

	root, err := os.Getwd()
	if err != nil {
		return fmt.Errorf("failed to get current working directory: %w", err)
	}

	if err = initDatabase(); err != nil {
		return fmt.Errorf("failed to initialize database: %w", err)
	}
	defer database.Close()

	control.Init(filepath.Join(root, "data/control.json"))
	defer control.Save()
	settings := control.Get()

	template.InitializeCache(root)
	template.RegisterFunc("md", markdown.MarkdownTemplateFunc)

	email.Init(
		settings.EmailMode,
		settings.EmailFilePath,
		settings.SMTPHost,
		settings.SMTPPort,
		settings.SMTPUsername,
		settings.SMTPPassword,
	)

	authenticator := auth.New(getUserByID)
	app := sushi.New()
	sushi.InitSessions(filepath.Join(root, "data/sessions.json"))

	// Middleware is registered in this exact order.
	app.Use(session.Middleware())
	app.Use(authenticator.Middleware())
	app.Use(csrf.Middleware())
	app.Use(timing.Middleware())

	// For authenticated requests, write the current time to the user's
	// last_online column once the stored value is at least
	// lastOnlineInterval seconds old.
	app.Use(func(ctx sushi.Ctx, next func()) {
		if !ctx.IsAuthenticated() {
			next()
			return
		}

		current := ctx.GetCurrentUser().(*users.User)
		stamp := time.Now().Unix()
		if stamp-current.LastOnline >= lastOnlineInterval {
			database.Update("users", map[string]any{"last_online": stamp}, "id", current.ID)
		}
		next()
	})

	app.Get("/", routes.Index)

	// Routes in this group go through auth.RequireAuth("/login").
	authed := app.Group("")
	authed.Use(auth.RequireAuth("/login"))
	authed.Get("/explore", routes.Explore)
	authed.Post("/move", routes.Move)
	authed.Get("/teleport/:to", routes.Teleport)
	authed.Get("/change-password", routes.ShowChangePassword)
	authed.Post("/change-password", routes.ChangePassword)

	routes.RegisterAuthRoutes(app)
	routes.RegisterTownRoutes(app)
	routes.RegisterFightRoutes(app)
	routes.RegisterForumRoutes(app)
	routes.RegisterHelpRoutes(app)
	routes.RegisterAdminRoutes(app)

	app.Get("/assets/*path", sushi.Static(root))

	listenAddr := ":" + port
	log.Printf("Server starting on %s", listenAddr)

	// Buffered so a signal delivered before the receive below is not lost.
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, os.Interrupt, syscall.SIGTERM)

	// NOTE(review): whatever Listen reports (e.g. a bind failure) is not
	// inspected here, so start would keep waiting for a signal — confirm
	// Listen's signature and whether its result should be surfaced.
	go app.Listen(listenAddr)

	<-quit
	log.Println("\nShutting down! Beginning cleanup...")
	log.Println("Saving sessions...")
	sushi.SaveSessions()
	log.Println("Server stopped")
	return nil
}
// getUserByID looks up a user with users.Find and returns it as an any; it is
// the lookup callback handed to auth.New in start.
//
// It panics if users.Find returns an error. The panic message includes the
// underlying error so the cause of the failed lookup is not lost.
func getUserByID(userID int) any {
	user, err := users.Find(userID)
	if err != nil {
		panic(fmt.Sprintf("Error finding user ID %d: %v", userID, err))
	}
	return user
}