437 lines
9.9 KiB
Go

package template
import (
"fmt"
"os"
"path/filepath"
"reflect"
"strings"
"sync"
"time"
"github.com/valyala/fasthttp"
)
// Cache is the package-level TemplateCache used by the RenderToContext and
// RenderNamed helper functions. It starts out nil; InitializeCache assigns it.
var Cache *TemplateCache
// TemplateCache keeps loaded templates in memory, keyed by template name,
// and remembers the base directory under which template files are looked up.
type TemplateCache struct {
// mu is locked around accesses to the templates map and around updates of a
// cached template's content and modification time.
mu sync.RWMutex
// templates maps a template name (as passed to Load) to its loaded Template.
templates map[string]*Template
// basePath is the directory that contains the "templates" folder.
basePath string
}
// RenderOptions controls how a Template is rendered.
type RenderOptions struct {
// ResolveIncludes enables expansion of {include "..."} directives.
ResolveIncludes bool
// Blocks maps block names to block content. Named rendering stores every
// {block "name"}...{/block} section it extracts here and substitutes the
// entries for matching {yield name} directives.
Blocks map[string]string
}
// Template is a single template file held in memory by a TemplateCache.
type Template struct {
// name is the name the template was loaded under.
name string
// content is the raw, unrendered text read from the file.
content string
// modTime is the file's modification time when content was last read; it is
// compared against the file on disk to decide whether to reload.
modTime time.Time
// filePath is the full path of the backing file.
filePath string
// cache is the owning TemplateCache, used to load templates named by
// include directives.
cache *TemplateCache
}
// NewCache returns an empty TemplateCache rooted at basePath.
//
// An empty basePath selects the directory of the running executable; when
// that cannot be determined the current directory (".") is used instead.
func NewCache(basePath string) *TemplateCache {
	root := basePath
	if root == "" {
		// Start from the fallback and upgrade it if the executable is known.
		root = "."
		if exePath, err := os.Executable(); err == nil {
			root = filepath.Dir(exePath)
		}
	}

	cache := new(TemplateCache)
	cache.templates = map[string]*Template{}
	cache.basePath = root
	return cache
}
// InitializeCache builds a new TemplateCache for basePath and installs it as
// the package-level Cache, replacing whatever value Cache held before.
func InitializeCache(basePath string) {
	fresh := NewCache(basePath)
	Cache = fresh
}
// Load returns the template registered under name.
//
// A template that is not cached yet is read from disk via loadFromFile. A
// cached template is first passed to checkAndReload so that a newer file on
// disk is picked up; if that check fails, Load returns the error and no
// template.
func (c *TemplateCache) Load(name string) (*Template, error) {
	c.mu.RLock()
	cached, ok := c.templates[name]
	c.mu.RUnlock()

	if !ok {
		return c.loadFromFile(name)
	}
	if err := c.checkAndReload(cached); err != nil {
		return nil, err
	}
	return cached, nil
}
// loadFromFile reads the file <basePath>/templates/<name>, wraps it in a
// Template and stores that Template in the cache under name.
//
// If the file cannot be stat'ed the error reads "template file not found"
// and carries only the template name; a read failure wraps the underlying
// error. Nothing is cached when an error is returned.
func (c *TemplateCache) loadFromFile(name string) (*Template, error) {
	path := filepath.Join(c.basePath, "templates", name)

	stat, statErr := os.Stat(path)
	if statErr != nil {
		return nil, fmt.Errorf("template file not found: %s", name)
	}
	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return nil, fmt.Errorf("failed to read template: %w", readErr)
	}

	loaded := new(Template)
	loaded.name = name
	loaded.content = string(raw)
	loaded.modTime = stat.ModTime()
	loaded.filePath = path
	loaded.cache = c

	// Only the map insertion needs the write lock.
	c.mu.Lock()
	defer c.mu.Unlock()
	c.templates[name] = loaded
	return loaded, nil
}
// checkAndReload refreshes tmpl from disk when its backing file has a newer
// modification time than the one recorded at the last read.
//
// It returns the error from os.Stat or os.ReadFile unchanged; in that case
// tmpl is left untouched.
//
// tmpl.modTime is read and written only while holding c.mu, because several
// goroutines may run Load for the same cached template at once. File I/O is
// done outside the lock.
//
// NOTE: the render methods of Template read content without taking c.mu, so
// this function alone does not make a concurrent reload-while-render safe.
func (c *TemplateCache) checkAndReload(tmpl *Template) error {
	info, err := os.Stat(tmpl.filePath)
	if err != nil {
		return err
	}
	onDisk := info.ModTime()

	c.mu.RLock()
	stale := onDisk.After(tmpl.modTime)
	c.mu.RUnlock()
	if !stale {
		return nil
	}

	content, err := os.ReadFile(tmpl.filePath)
	if err != nil {
		return err
	}

	c.mu.Lock()
	defer c.mu.Unlock()
	// Re-check under the write lock: another goroutine may have reloaded the
	// same (or a newer) version while the file was being read.
	if onDisk.After(tmpl.modTime) {
		tmpl.content = string(content)
		tmpl.modTime = onDisk
	}
	return nil
}
// RenderPositional renders the template with positional arguments using the
// default options, i.e. with include resolution switched on.
func (t *Template) RenderPositional(args ...any) string {
	defaults := RenderOptions{ResolveIncludes: true}
	return t.RenderPositionalWithOptions(defaults, args...)
}
// RenderPositionalWithOptions substitutes positional placeholders in the
// template content: every occurrence of {0} is replaced by args[0], {1} by
// args[1], and so on, each formatted with fmt's default verb.
//
// Substitution happens first; include directives are expanded afterwards and
// only when opts.ResolveIncludes is set. Included content is inserted
// without data, so placeholders inside it are not substituted.
func (t *Template) RenderPositionalWithOptions(opts RenderOptions, args ...any) string {
	rendered := t.content
	for index := range args {
		token := fmt.Sprintf("{%d}", index)
		rendered = strings.ReplaceAll(rendered, token, fmt.Sprint(args[index]))
	}

	if !opts.ResolveIncludes {
		return rendered
	}
	return t.processIncludes(rendered, nil, opts)
}
// RenderNamed renders the template with named data using the default
// options, i.e. with include resolution switched on.
func (t *Template) RenderNamed(data map[string]any) string {
	defaults := RenderOptions{ResolveIncludes: true}
	return t.RenderNamedWithOptions(defaults, data)
}
// RenderNamedWithOptions renders the template content with named data.
//
// The steps run in a fixed order:
//  1. block definitions are extracted into opts.Blocks and removed;
//  2. include directives are expanded (only if opts.ResolveIncludes is set);
//  3. every {key} is replaced by the formatted value of data[key];
//  4. dotted placeholders such as {a.b} are resolved against data;
//  5. yield directives are filled in from opts.Blocks.
//
// Keys are substituted in map iteration order, which Go leaves unspecified.
func (t *Template) RenderNamedWithOptions(opts RenderOptions, data map[string]any) string {
	// opts is a local copy, but processBlocks may still add entries to a
	// Blocks map supplied by the caller.
	rendered := t.processBlocks(t.content, &opts)

	if opts.ResolveIncludes {
		rendered = t.processIncludes(rendered, data, opts)
	}

	for name, val := range data {
		rendered = strings.ReplaceAll(rendered, "{"+name+"}", fmt.Sprint(val))
	}

	rendered = t.replaceDotNotation(rendered, data)
	return t.processYield(rendered, opts)
}
// replaceDotNotation replaces placeholders of the form {a.b.c} with the
// value that getNestedValue finds for the dotted path in data.
//
// A brace pair is treated as a placeholder only if its text contains a "."
// and the lookup yields a non-nil value; anything else is copied through
// unchanged. Substituted values are not scanned again, so a value containing
// brace syntax is emitted literally.
//
// When a candidate cannot be resolved, scanning resumes right after its
// opening brace rather than after its closing brace. This lets a placeholder
// that follows an unrelated "{" (for example in inline JavaScript or CSS,
// "fn() { x = {user.name}; }") still be found.
func (t *Template) replaceDotNotation(content string, data map[string]any) string {
	var out strings.Builder
	out.Grow(len(content))

	rest := content
	for {
		open := strings.IndexByte(rest, '{')
		if open == -1 {
			break
		}
		closeOffset := strings.IndexByte(rest[open:], '}')
		if closeOffset == -1 {
			// No closing brace anywhere after this point: nothing left to do.
			break
		}
		closeIdx := open + closeOffset

		placeholder := rest[open+1 : closeIdx]
		if strings.Contains(placeholder, ".") {
			if value := t.getNestedValue(data, placeholder); value != nil {
				out.WriteString(rest[:open])
				out.WriteString(fmt.Sprintf("%v", value))
				rest = rest[closeIdx+1:]
				continue
			}
		}

		// Not a resolvable placeholder: keep the brace and look for the next
		// candidate immediately after it.
		out.WriteString(rest[:open+1])
		rest = rest[open+1:]
	}

	out.WriteString(rest)
	return out.String()
}
// getNestedValue resolves a dotted path such as "user.address.city" against
// data and returns the value found, or nil when any step cannot be resolved.
//
// At each step a map[string]any is indexed by the path segment; any other
// value is handed to getStructField. Intermediate results that are maps of
// another type are converted to map[string]any, with keys rendered through
// fmt's default formatting, so the next segment can index them.
func (t *Template) getNestedValue(data map[string]any, path string) any {
	segments := strings.Split(path, ".")
	lastIdx := len(segments) - 1

	var node any = data
	for _, segment := range segments[:lastIdx] {
		// Step down one level.
		var child any
		if asMap, isMap := node.(map[string]any); isMap {
			value, present := asMap[segment]
			if !present {
				return nil
			}
			child = value
		} else {
			child = t.getStructField(node, segment)
			if child == nil {
				return nil
			}
		}

		// Normalise the child so the next segment can be applied to it.
		if m, ok := child.(map[string]any); ok {
			node = m
			continue
		}
		if m, ok := child.(map[any]any); ok {
			converted := make(map[string]any)
			for k, v := range m {
				converted[fmt.Sprint(k)] = v
			}
			node = converted
			continue
		}
		reflected := reflect.ValueOf(child)
		if reflected.Kind() != reflect.Map {
			// Structs (and everything else) are kept as they are.
			node = child
			continue
		}
		converted := make(map[string]any)
		for _, k := range reflected.MapKeys() {
			converted[fmt.Sprint(k.Interface())] = reflected.MapIndex(k).Interface()
		}
		node = converted
	}

	// Final segment: plain lookup, no normalisation of the result.
	leaf := segments[lastIdx]
	if asMap, isMap := node.(map[string]any); isMap {
		return asMap[leaf]
	}
	return t.getStructField(node, leaf)
}
// getStructField returns the value of the field named fieldName from obj,
// which must be a struct or a non-nil pointer to a struct. Promoted fields of
// embedded structs are found as well.
//
// It returns nil, never panics, when obj is nil or a nil pointer, is not a
// struct, has no such field, the field is unexported, or the field is
// promoted through a nil embedded pointer.
func (t *Template) getStructField(obj any, fieldName string) any {
	if obj == nil {
		return nil
	}

	rv := reflect.ValueOf(obj)
	if rv.Kind() == reflect.Ptr {
		if rv.IsNil() {
			return nil
		}
		rv = rv.Elem()
	}
	if rv.Kind() != reflect.Struct {
		return nil
	}

	sf, found := rv.Type().FieldByName(fieldName)
	if !found {
		return nil
	}
	// FieldByIndexErr reports, instead of panicking, when the path to a
	// promoted field runs through a nil embedded pointer.
	field, err := rv.FieldByIndexErr(sf.Index)
	if err != nil {
		return nil
	}
	// Reflection finds unexported fields too, but Interface panics on them.
	// Field names come from template text, so treat them as not found.
	if !field.CanInterface() {
		return nil
	}
	return field.Interface()
}
// WriteTo renders the template with data and writes the result to ctx using
// the default options, i.e. with include resolution switched on.
func (t *Template) WriteTo(ctx *fasthttp.RequestCtx, data any) {
	defaults := RenderOptions{ResolveIncludes: true}
	t.WriteToWithOptions(ctx, data, defaults)
}
// WriteToWithOptions renders the template with data and writes the result to
// ctx as "text/html; charset=utf-8".
//
// The rendering mode is chosen from the dynamic type of data:
//   - map[string]any: named rendering;
//   - []any or any other slice type: positional rendering, one argument per
//     element;
//   - anything else (including nil): positional rendering with data as the
//     single argument.
func (t *Template) WriteToWithOptions(ctx *fasthttp.RequestCtx, data any, opts RenderOptions) {
	var body string
	if named, isNamed := data.(map[string]any); isNamed {
		body = t.RenderNamedWithOptions(opts, named)
	} else {
		positional, isList := data.([]any)
		if !isList {
			if seq := reflect.ValueOf(data); seq.Kind() == reflect.Slice {
				// Typed slice: box each element into an any.
				positional = make([]any, 0, seq.Len())
				for i := 0; i < seq.Len(); i++ {
					positional = append(positional, seq.Index(i).Interface())
				}
			} else {
				positional = []any{data}
			}
		}
		body = t.RenderPositionalWithOptions(opts, positional...)
	}

	ctx.SetContentType("text/html; charset=utf-8")
	ctx.WriteString(body)
}
// processIncludes expands {include "template.html"} directives in content.
//
// The name between "{include " and the next "}" is trimmed of spaces and
// double quotes and loaded through the owning cache. With non-nil data the
// included template is rendered by RenderNamedWithOptions, sharing
// opts.Blocks with the caller; with nil data its raw content is inserted. A
// directive whose template cannot be loaded is removed. An unterminated
// directive stops processing and leaves the remainder untouched.
//
// After each expansion the whole text is searched again from the start, so
// directives contained in inserted content are expanded too. No include
// cycle detection is performed here.
func (t *Template) processIncludes(content string, data map[string]any, opts RenderOptions) string {
	const marker = "{include "

	out := content
	for {
		open := strings.Index(out, marker)
		if open < 0 {
			return out
		}
		closeOffset := strings.Index(out[open:], "}")
		if closeOffset < 0 {
			return out
		}
		closeIdx := open + closeOffset

		templateName := strings.Trim(out[open+len(marker):closeIdx], "\" ")

		// An empty replacement drops the directive when loading fails.
		replacement := ""
		if included, err := t.cache.Load(templateName); err == nil {
			if data == nil {
				replacement = included.content
			} else {
				nested := RenderOptions{
					ResolveIncludes: opts.ResolveIncludes,
					Blocks:          opts.Blocks,
				}
				replacement = included.RenderNamedWithOptions(nested, data)
			}
		}

		out = out[:open] + replacement + out[closeIdx+1:]
	}
}
// processYield fills in yield directives for template inheritance.
//
// For every entry in opts.Blocks, each "{yield <name>}" is replaced by that
// block's content. Every bare "{yield}" is then removed. Named yields with
// no matching block are left in the output unchanged.
func (t *Template) processYield(content string, opts RenderOptions) string {
	rendered := content
	// Ranging over a nil Blocks map is a no-op, so only the bare-yield
	// cleanup below runs in that case.
	for name, body := range opts.Blocks {
		rendered = strings.ReplaceAll(rendered, fmt.Sprintf("{yield %s}", name), body)
	}
	return strings.ReplaceAll(rendered, "{yield}", "")
}
// processBlocks extracts {block "name"}...{/block} sections from content.
//
// Each section's body is stored in opts.Blocks under its name (spaces and
// double quotes trimmed), overwriting an existing entry of the same name,
// and the whole section is removed from the returned text. opts.Blocks is
// allocated when nil. A block without a closing "}" on its header or without
// a "{/block}" tag stops processing and leaves the remainder untouched.
func (t *Template) processBlocks(content string, opts *RenderOptions) string {
	const (
		openTag  = "{block "
		closeTag = "{/block}"
	)

	if opts.Blocks == nil {
		opts.Blocks = map[string]string{}
	}

	remaining := content
	for {
		open := strings.Index(remaining, openTag)
		if open < 0 {
			break
		}
		headerLen := strings.Index(remaining[open:], "}")
		if headerLen < 0 {
			break
		}
		headerEnd := open + headerLen // index of the "}" closing the header
		bodyStart := headerEnd + 1

		bodyLen := strings.Index(remaining[bodyStart:], closeTag)
		if bodyLen < 0 {
			break
		}
		bodyEnd := bodyStart + bodyLen

		name := strings.Trim(remaining[open+len(openTag):headerEnd], "\" ")
		opts.Blocks[name] = remaining[bodyStart:bodyEnd]

		// Cut the whole definition, header through closing tag, out of the text.
		remaining = remaining[:open] + remaining[bodyEnd+len(closeTag):]
	}
	return remaining
}
// RenderToContext loads templateName from the package-level Cache, renders
// it with data and writes the result to ctx.
//
// It returns true on success. If the cache has not been initialised (see
// InitializeCache) or the template cannot be loaded, it sets status 500,
// writes "Template error: <reason>" to the response and returns false.
func RenderToContext(ctx *fasthttp.RequestCtx, templateName string, data map[string]any) bool {
	// A nil Cache would otherwise cause a nil-pointer panic inside Load.
	if Cache == nil {
		ctx.SetStatusCode(fasthttp.StatusInternalServerError)
		fmt.Fprint(ctx, "Template error: template cache not initialized")
		return false
	}

	tmpl, err := Cache.Load(templateName)
	if err != nil {
		ctx.SetStatusCode(fasthttp.StatusInternalServerError)
		fmt.Fprintf(ctx, "Template error: %v", err)
		return false
	}

	tmpl.WriteTo(ctx, data)
	return true
}
// RenderNamed loads templateName from the package-level Cache and renders it
// with data using named substitution.
//
// It returns the rendered text, or an error if the cache has not been
// initialised (see InitializeCache) or the template cannot be loaded. A load
// error is wrapped, so errors.Is and errors.As still see the cause.
func RenderNamed(templateName string, data map[string]any) (string, error) {
	// A nil Cache would otherwise cause a nil-pointer panic inside Load.
	if Cache == nil {
		return "", fmt.Errorf("failed to load template %s: template cache not initialized", templateName)
	}

	tmpl, err := Cache.Load(templateName)
	if err != nil {
		return "", fmt.Errorf("failed to load template %s: %w", templateName, err)
	}
	return tmpl.RenderNamed(data), nil
}