359 lines
7.8 KiB
Go

package template
import (
"fmt"
"os"
"path/filepath"
"reflect"
"strings"
"sync"
"time"
"github.com/valyala/fasthttp"
)
// Cache loads templates from the "templates" directory under basePath and
// keeps them in memory, re-reading a cached template when its file's
// modification time moves forward.
type Cache struct {
	// mu guards templates. checkAndReload also holds it while replacing a
	// cached Template's content and modTime.
	mu        sync.RWMutex
	templates map[string]*Template // cached templates, keyed by the name passed to Load
	basePath  string               // root directory; template files live in basePath/templates
}
// RenderOptions controls how a Template is rendered.
type RenderOptions struct {
	// ResolveIncludes enables expansion of {include "name"} directives.
	ResolveIncludes bool
	// Blocks maps block names to the content substituted for {yield name}
	// directives. Named rendering records every {block "name"}...{/block}
	// definition it extracts into this map (creating it when nil), so a
	// caller-supplied map is written to.
	Blocks map[string]string
}
// Template is a single template file loaded through a Cache.
type Template struct {
	name     string    // name passed to Cache.Load, relative to the templates directory
	content  string    // raw template text; replaced in place by Cache.checkAndReload under cache.mu
	modTime  time.Time // file modification time recorded when content was last read
	filePath string    // full path of the backing file
	cache    *Cache    // owning cache, used to load templates named by {include} directives
}
// NewCache creates an empty template cache rooted at basePath; template
// files are read from basePath/templates. An empty basePath selects the
// directory that contains the running executable, or "." when that cannot
// be determined.
func NewCache(basePath string) *Cache {
	root := basePath
	if root == "" {
		root = "."
		if exe, err := os.Executable(); err == nil {
			root = filepath.Dir(exe)
		}
	}
	return &Cache{
		basePath:  root,
		templates: make(map[string]*Template),
	}
}
// Load returns the template registered under name, reading
// <basePath>/templates/<name> from disk the first time it is requested.
// For a template that is already cached, the file's modification time is
// checked first and the content refreshed when the file is newer; if that
// check fails, its error is returned together with a nil template.
func (c *Cache) Load(name string) (*Template, error) {
	c.mu.RLock()
	cached, ok := c.templates[name]
	c.mu.RUnlock()

	if !ok {
		return c.loadFromFile(name)
	}
	if err := c.checkAndReload(cached); err != nil {
		return nil, err
	}
	return cached, nil
}
// loadFromFile reads <basePath>/templates/<name>, stores the resulting
// Template in the cache under name (replacing any existing entry) and
// returns it.
//
// A missing file yields an error starting with "template file not found";
// other stat failures (for example permission errors) are reported as such.
// In both cases the underlying error is wrapped, so callers can test it with
// errors.Is(err, fs.ErrNotExist) and similar.
func (c *Cache) loadFromFile(name string) (*Template, error) {
	filePath := filepath.Join(c.basePath, "templates", name)

	info, err := os.Stat(filePath)
	if err != nil {
		// err comes straight from os.Stat, so os.IsNotExist classifies it correctly.
		if os.IsNotExist(err) {
			return nil, fmt.Errorf("template file not found: %s: %w", name, err)
		}
		return nil, fmt.Errorf("failed to stat template %s: %w", name, err)
	}

	content, err := os.ReadFile(filePath)
	if err != nil {
		return nil, fmt.Errorf("failed to read template: %w", err)
	}

	tmpl := &Template{
		name:     name,
		content:  string(content),
		modTime:  info.ModTime(),
		filePath: filePath,
		cache:    c,
	}

	c.mu.Lock()
	c.templates[name] = tmpl
	c.mu.Unlock()
	return tmpl, nil
}
// checkAndReload refreshes tmpl from disk when its file's modification time
// is later than the one recorded at the last (re)load.
//
// Errors from os.Stat or os.ReadFile are returned unchanged and leave tmpl
// untouched. tmpl.modTime and tmpl.content are only read and written while
// holding c.mu, so concurrent Load calls do not race on them. When several
// goroutines reload at once, an update is applied only if it is still newer
// than what is stored, so older content never replaces newer content.
func (c *Cache) checkAndReload(tmpl *Template) error {
	// filePath is set once at construction and never modified, so it can be
	// read without the lock.
	info, err := os.Stat(tmpl.filePath)
	if err != nil {
		return err
	}

	c.mu.RLock()
	stale := info.ModTime().After(tmpl.modTime)
	c.mu.RUnlock()
	if !stale {
		return nil
	}

	content, err := os.ReadFile(tmpl.filePath)
	if err != nil {
		return err
	}

	c.mu.Lock()
	// Re-check: another goroutine may have stored a newer version while this
	// one was reading the file.
	if info.ModTime().After(tmpl.modTime) {
		tmpl.content = string(content)
		tmpl.modTime = info.ModTime()
	}
	c.mu.Unlock()
	return nil
}
// RenderPositional renders the template with positional arguments and
// {include} resolution enabled. It is shorthand for
// RenderPositionalWithOptions with ResolveIncludes set and no blocks.
func (t *Template) RenderPositional(args ...any) string {
	opts := RenderOptions{ResolveIncludes: true}
	return t.RenderPositionalWithOptions(opts, args...)
}
// RenderPositionalWithOptions replaces each "{i}" placeholder with args[i]
// formatted by %v. When opts.ResolveIncludes is set, it then expands
// {include "..."} directives with nil data, which inserts the raw content of
// included templates without positional substitution.
//
// All placeholders are replaced in a single pass. Text introduced by one
// argument (for example a literal "{1}" inside args[0]) is therefore never
// substituted again by a later argument.
//
// NOTE(review): includes are resolved after substitution, so an argument
// containing an {include "..."} directive will be expanded. Confirm that
// untrusted input never reaches this method with ResolveIncludes enabled.
func (t *Template) RenderPositionalWithOptions(opts RenderOptions, args ...any) string {
	// checkAndReload rewrites t.content while holding the cache lock, so take
	// a snapshot under the read lock. Do not hold the lock across
	// processIncludes: it calls back into Cache.Load, which may need the
	// write lock.
	var result string
	if c := t.cache; c != nil {
		c.mu.RLock()
		result = t.content
		c.mu.RUnlock()
	} else {
		result = t.content
	}

	if len(args) > 0 {
		// "{1}" is never a prefix of "{10}" (they differ at the closing
		// brace), so at most one placeholder can match at any position.
		pairs := make([]string, 0, 2*len(args))
		for i, arg := range args {
			pairs = append(pairs, fmt.Sprintf("{%d}", i), fmt.Sprintf("%v", arg))
		}
		result = strings.NewReplacer(pairs...).Replace(result)
	}

	if opts.ResolveIncludes {
		result = t.processIncludes(result, nil, opts)
	}
	return result
}
// RenderNamed renders the template with named data and {include}
// resolution enabled. It is shorthand for RenderNamedWithOptions with
// ResolveIncludes set and no predefined blocks.
func (t *Template) RenderNamed(data map[string]any) string {
	opts := RenderOptions{ResolveIncludes: true}
	return t.RenderNamedWithOptions(opts, data)
}
// RenderNamedWithOptions renders the template against named data.
//
// The passes run in this order:
//  1. {block "name"}...{/block} definitions are cut out and recorded in
//     opts.Blocks. A caller-supplied map is written to.
//  2. If opts.ResolveIncludes is set, {include "..."} directives are expanded.
//     Included templates are rendered with the same data, and their output is
//     also subject to the passes below.
//  3. Every "{key}" placeholder for a key in data is replaced with the value
//     formatted by %v.
//  4. Remaining dotted placeholders such as "{a.b.c}" are resolved through
//     nested maps.
//  5. {yield name} directives are filled from opts.Blocks, and bare {yield}
//     is removed.
//
// Pass 3 replaces all keys in one pass. The result therefore does not depend
// on map iteration order, and text introduced by one value is never
// substituted again by another key. Dotted placeholders inside values are
// still resolved by pass 4.
//
// NOTE(review): if two keys' placeholders can match at the same position,
// which requires keys containing "}", the winner is still order-dependent.
func (t *Template) RenderNamedWithOptions(opts RenderOptions, data map[string]any) string {
	// checkAndReload rewrites t.content under the cache lock, so take a
	// snapshot under the read lock. The lock is released before
	// processIncludes, which re-enters Cache.Load.
	var result string
	if c := t.cache; c != nil {
		c.mu.RLock()
		result = t.content
		c.mu.RUnlock()
	} else {
		result = t.content
	}

	// Extract block definitions first so they are not rendered in place.
	result = t.processBlocks(result, &opts)

	// Expand includes before substitution so their content receives the data.
	if opts.ResolveIncludes {
		result = t.processIncludes(result, data, opts)
	}

	if len(data) > 0 {
		pairs := make([]string, 0, 2*len(data))
		for key, value := range data {
			pairs = append(pairs, "{"+key+"}", fmt.Sprintf("%v", value))
		}
		result = strings.NewReplacer(pairs...).Replace(result)
	}

	result = t.replaceDotNotation(result, data)
	return t.processYield(result, opts)
}
// replaceDotNotation replaces "{a.b.c}" placeholders with the value
// getNestedValue finds for that path, formatted by %v.
//
// Placeholders without a dot, and paths that resolve to nil, are left
// unchanged. Each closing "}" is paired with the nearest preceding "{". As a
// result, a placeholder preceded by an unrelated brace, as in CSS or inline
// JavaScript ("body { color: {theme.color}; }"), is still found. Inserted
// values are not rescanned.
func (t *Template) replaceDotNotation(content string, data map[string]any) string {
	var out strings.Builder
	out.Grow(len(content))

	rest := content
	for {
		openIdx := strings.IndexByte(rest, '{')
		if openIdx == -1 {
			break
		}
		closeRel := strings.IndexByte(rest[openIdx:], '}')
		if closeRel == -1 {
			break
		}
		closeIdx := openIdx + closeRel
		// Use the innermost "{" so an earlier stray brace does not swallow the
		// placeholder. rest[openIdx] is '{', so this is >= openIdx.
		openIdx = strings.LastIndexByte(rest[:closeIdx], '{')

		key := rest[openIdx+1 : closeIdx]
		if strings.Contains(key, ".") {
			if value := t.getNestedValue(data, key); value != nil {
				out.WriteString(rest[:openIdx])
				out.WriteString(fmt.Sprintf("%v", value))
				rest = rest[closeIdx+1:]
				continue
			}
		}
		out.WriteString(rest[:closeIdx+1])
		rest = rest[closeIdx+1:]
	}
	out.WriteString(rest)
	return out.String()
}
// getNestedValue resolves a dot-separated path such as "user.address.city"
// against data.
//
// Every segment except the last must name a map value. Any map type is
// accepted, and keys of non-string maps are matched by their %v formatting.
// The function returns nil when an intermediate segment is absent or is not
// a map. Otherwise it returns whatever the final lookup yields, which is nil
// when that key is absent.
func (t *Template) getNestedValue(data map[string]any, path string) any {
	segments := strings.Split(path, ".")
	leaf := segments[len(segments)-1]

	level := data
	for _, seg := range segments[:len(segments)-1] {
		child, present := level[seg]
		if !present {
			return nil
		}

		if m, ok := child.(map[string]any); ok {
			level = m
			continue
		}

		flattened := make(map[string]any)
		if m, ok := child.(map[any]any); ok {
			for key, val := range m {
				flattened[fmt.Sprintf("%v", key)] = val
			}
			level = flattened
			continue
		}

		rv := reflect.ValueOf(child)
		if rv.Kind() != reflect.Map {
			return nil
		}
		for _, key := range rv.MapKeys() {
			flattened[fmt.Sprintf("%v", key.Interface())] = rv.MapIndex(key).Interface()
		}
		level = flattened
	}
	return level[leaf]
}
// WriteTo renders data into ctx's response with {include} resolution
// enabled. See WriteToWithOptions for how data is interpreted.
func (t *Template) WriteTo(ctx *fasthttp.RequestCtx, data any) {
	opts := RenderOptions{ResolveIncludes: true}
	t.WriteToWithOptions(ctx, data, opts)
}
// WriteToWithOptions renders the template with data and writes the result to
// ctx's response body with an HTML content type.
//
// data selects the rendering mode:
//   - map[string]any, or any other map whose key kind is string (including
//     named types such as `type H map[string]any` and value types other than
//     any), is rendered with RenderNamedWithOptions;
//   - []any, or any other slice, is expanded into positional arguments;
//   - anything else is passed as the single positional argument {0}.
func (t *Template) WriteToWithOptions(ctx *fasthttp.RequestCtx, data any, opts RenderOptions) {
	var result string
	switch v := data.(type) {
	case map[string]any:
		result = t.RenderNamedWithOptions(opts, v)
	case []any:
		result = t.RenderPositionalWithOptions(opts, v...)
	default:
		rv := reflect.ValueOf(data)
		switch {
		case rv.Kind() == reflect.Map && rv.Type().Key().Kind() == reflect.String:
			// Named map types do not match the map[string]any case above, so
			// copy them into a plain map for named rendering.
			named := make(map[string]any, rv.Len())
			iter := rv.MapRange()
			for iter.Next() {
				named[iter.Key().String()] = iter.Value().Interface()
			}
			result = t.RenderNamedWithOptions(opts, named)
		case rv.Kind() == reflect.Slice:
			args := make([]any, rv.Len())
			for i := range args {
				args[i] = rv.Index(i).Interface()
			}
			result = t.RenderPositionalWithOptions(opts, args...)
		default:
			result = t.RenderPositionalWithOptions(opts, data)
		}
	}
	ctx.SetContentType("text/html; charset=utf-8")
	// RequestCtx.WriteString appends to the buffered response body; its
	// return values are not inspected here.
	ctx.WriteString(result)
}
// processIncludes expands {include "name"} directives in content, loading
// each named template through t.cache.Load.
//
// When data is non-nil, an included template is rendered with
// RenderNamedWithOptions using data and opts.Blocks. The Blocks map is
// shared, so blocks that the included template defines become visible to
// the caller. When data is nil, the included template's raw content is
// inserted and its own include directives are expanded recursively.
// A directive naming a template that fails to load is removed. A "{include "
// with no closing "}" ends processing, and the rest of content is returned
// unchanged.
//
// Directives are expanded in one left-to-right pass, so text inserted by an
// expansion is not rescanned for further directives. An include of a
// template that is already being expanded (t itself, or an ancestor in a raw
// include chain) is dropped instead of looping forever.
//
// NOTE(review): indirect cycles through rendered includes (a -> b -> a with
// non-nil data) pass through RenderNamedWithOptions and are not detected
// here.
func (t *Template) processIncludes(content string, data map[string]any, opts RenderOptions) string {
	return t.expandIncludeDirectives(content, data, opts, map[string]bool{t.name: true})
}

// expandIncludeDirectives performs the scan for processIncludes; active holds
// the names of templates whose content is currently being expanded.
func (t *Template) expandIncludeDirectives(content string, data map[string]any, opts RenderOptions, active map[string]bool) string {
	const directive = "{include "

	var out strings.Builder
	rest := content
	for {
		start := strings.Index(rest, directive)
		if start == -1 {
			break
		}
		end := strings.Index(rest[start:], "}")
		if end == -1 {
			break
		}
		end += start

		name := strings.Trim(rest[start+len(directive):end], "\" ")
		out.WriteString(rest[:start])
		out.WriteString(t.includedContent(name, data, opts, active))
		rest = rest[end+1:]
	}
	out.WriteString(rest)
	return out.String()
}

// includedContent returns the text that replaces one include directive for
// the template called name, or "" when the include is cyclic or the
// template cannot be loaded.
func (t *Template) includedContent(name string, data map[string]any, opts RenderOptions, active map[string]bool) string {
	if active[name] {
		return ""
	}
	inc, err := t.cache.Load(name)
	if err != nil {
		return ""
	}

	if data != nil {
		includeOpts := RenderOptions{
			ResolveIncludes: opts.ResolveIncludes,
			Blocks:          opts.Blocks,
		}
		return inc.RenderNamedWithOptions(includeOpts, data)
	}

	// Raw content can be rewritten by a concurrent reload under the cache lock.
	t.cache.mu.RLock()
	raw := inc.content
	t.cache.mu.RUnlock()

	active[name] = true
	expanded := inc.expandIncludeDirectives(raw, nil, opts, active)
	delete(active, name)
	return expanded
}
// processYield fills {yield name} directives from opts.Blocks, then removes
// every bare {yield}, including any that block content introduced.
//
// All named yields are substituted in one pass. The output is therefore
// independent of map iteration order, and a block whose content itself
// contains "{yield other}" is inserted verbatim rather than expanded.
// Directives naming a block that is not in opts.Blocks are left in place.
func (t *Template) processYield(content string, opts RenderOptions) string {
	result := content
	if len(opts.Blocks) > 0 {
		pairs := make([]string, 0, 2*len(opts.Blocks))
		for name, body := range opts.Blocks {
			pairs = append(pairs, "{yield "+name+"}", body)
		}
		result = strings.NewReplacer(pairs...).Replace(result)
	}
	return strings.ReplaceAll(result, "{yield}", "")
}
// processBlocks removes every {block "name"}...{/block} definition from
// content and records its body in opts.Blocks under the trimmed name. The
// map is created if nil; an existing map is written to, and later
// definitions overwrite earlier ones with the same name.
//
// Processing stops at the first definition that lacks a closing "}" or a
// matching {/block}; that definition and everything after it is returned
// unchanged.
func (t *Template) processBlocks(content string, opts *RenderOptions) string {
	const (
		openTag  = "{block "
		closeTag = "{/block}"
	)
	if opts.Blocks == nil {
		opts.Blocks = make(map[string]string)
	}

	remaining := content
	for {
		before, afterOpen, found := strings.Cut(remaining, openTag)
		if !found {
			return remaining
		}
		// openTag contains no "}", so the first "}" after it ends the name.
		rawName, afterName, found := strings.Cut(afterOpen, "}")
		if !found {
			return remaining
		}
		body, tail, found := strings.Cut(afterName, closeTag)
		if !found {
			return remaining
		}
		opts.Blocks[strings.Trim(rawName, "\" ")] = body
		remaining = before + tail
	}
}