832 lines
19 KiB
Go

package template
import (
	"cmp"
	"fmt"
	"maps"
	"os"
	"path/filepath"
	"reflect"
	"slices"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/valyala/fasthttp"
)
// Cache is the package-level TemplateCache singleton. It is nil until
// InitializeCache assigns it; the package-level helpers RenderToContext and
// RenderNamed load their templates through it.
var Cache *TemplateCache
// TemplateCache keeps templates that have been read from disk, keyed by the
// name they were loaded under.
type TemplateCache struct {
// mu guards the templates map and the content/modTime updates performed
// when a cached template is reloaded.
mu sync.RWMutex
// templates maps a template name to its cached entry.
templates map[string]*Template
// basePath is the directory that contains the "templates" folder.
basePath string
}
// Template is a single template file held in a TemplateCache.
type Template struct {
// name is the name the template was loaded under.
name string
// content is the raw template text read from filePath.
content string
// modTime is the file modification time observed when content was last read.
modTime time.Time
// filePath is the on-disk location the template was read from.
filePath string
// cache is the owning cache; it is used to resolve {include ...} directives.
cache *TemplateCache
}
// NewCache creates an empty TemplateCache rooted at basePath.
//
// When basePath is empty the directory of the running executable is used; if
// that cannot be determined the current directory (".") is used instead.
func NewCache(basePath string) *TemplateCache {
	root := basePath
	if root == "" {
		root = "."
		if exe, err := os.Executable(); err == nil {
			root = filepath.Dir(exe)
		}
	}

	cache := &TemplateCache{basePath: root}
	cache.templates = make(map[string]*Template)
	return cache
}
// InitializeCache replaces the package-level Cache with a fresh, empty
// TemplateCache created by NewCache(basePath).
func InitializeCache(basePath string) {
	fresh := NewCache(basePath)
	Cache = fresh
}
// Load returns the template called name.
//
// A template that is already cached is first checked against its file via
// checkAndReload and then returned; an error from that check is returned with
// a nil template. A template that is not cached yet is read from disk.
func (c *TemplateCache) Load(name string) (*Template, error) {
	c.mu.RLock()
	cached, hit := c.templates[name]
	c.mu.RUnlock()

	if !hit {
		return c.loadFromFile(name)
	}
	if err := c.checkAndReload(cached); err != nil {
		return nil, err
	}
	return cached, nil
}
// loadFromFile reads <basePath>/templates/<name> from disk, stores the
// resulting Template in the cache under name and returns it.
//
// Both the stat and the read error are wrapped with %w, so callers can inspect
// the cause (for example errors.Is(err, fs.ErrNotExist)). On error nothing is
// added to the cache and the returned template is nil.
func (c *TemplateCache) loadFromFile(name string) (*Template, error) {
	filePath := filepath.Join(c.basePath, "templates", name)

	info, err := os.Stat(filePath)
	if err != nil {
		// Keep the historical message prefix but no longer discard the cause.
		return nil, fmt.Errorf("template file not found: %s: %w", name, err)
	}

	content, err := os.ReadFile(filePath)
	if err != nil {
		return nil, fmt.Errorf("failed to read template: %w", err)
	}

	tmpl := &Template{
		name:     name,
		content:  string(content),
		modTime:  info.ModTime(),
		filePath: filePath,
		cache:    c,
	}

	c.mu.Lock()
	c.templates[name] = tmpl
	c.mu.Unlock()

	return tmpl, nil
}
// checkAndReload re-reads tmpl from disk when the file's modification time is
// newer than the one recorded on tmpl.
//
// It returns the os.Stat or os.ReadFile error unchanged; tmpl is left
// untouched in that case.
func (c *TemplateCache) checkAndReload(tmpl *Template) error {
	info, err := os.Stat(tmpl.filePath)
	if err != nil {
		return err
	}

	// modTime is written under c.mu below, so it has to be read under the
	// lock as well to avoid a data race between concurrent Load calls.
	c.mu.RLock()
	cachedAt := tmpl.modTime
	c.mu.RUnlock()

	if !info.ModTime().After(cachedAt) {
		return nil
	}

	content, err := os.ReadFile(tmpl.filePath)
	if err != nil {
		return err
	}

	c.mu.Lock()
	// Re-check under the write lock: another goroutine may already have
	// stored a newer version while the file was being read here.
	if info.ModTime().After(tmpl.modTime) {
		tmpl.content = string(content)
		tmpl.modTime = info.ModTime()
	}
	c.mu.Unlock()

	return nil
}
// RenderPositional substitutes {0}, {1}, ... in the template content with the
// %v formatting of the corresponding argument and then expands {include ...}
// directives. Included templates are inserted as raw content (no data is
// passed to them).
func (t *Template) RenderPositional(args ...any) string {
	rendered := t.content
	for idx := range args {
		token := fmt.Sprintf("{%d}", idx)
		rendered = strings.ReplaceAll(rendered, token, fmt.Sprint(args[idx]))
	}
	return t.processIncludes(rendered, nil)
}
// RenderNamed renders the template with named data.
//
// The pipeline runs in a fixed order: block definitions are extracted,
// includes are expanded, loops and then conditionals are evaluated, {yield}
// directives are filled from the extracted blocks, every {key} is replaced by
// the %v formatting of data[key], and finally dotted placeholders such as
// {user.name} are resolved.
func (t *Template) RenderNamed(data map[string]any) string {
	// Blocks are pulled out first so the later stages do not see their bodies.
	blocks := map[string]string{}
	out := t.processBlocks(t.content, blocks)

	out = t.processIncludes(out, data)
	out = t.processConditionals(t.processLoops(out, data), data)
	out = t.processYield(out, blocks, data)

	// Plain {key} placeholders.
	for name, val := range data {
		out = strings.ReplaceAll(out, "{"+name+"}", fmt.Sprint(val))
	}

	return t.replaceDotNotation(out, data)
}
// replaceDotNotation replaces every {a.b.c} placeholder in content whose path
// resolves (via getNestedValue) to a non-nil value with the %v formatting of
// that value. Placeholders without a dot, or whose path does not resolve, are
// left untouched.
//
// The candidate placeholder is always the innermost "{...}" pair: a stray
// opening brace earlier in the text (for example the body brace of an inline
// script or style rule) no longer swallows the placeholder that follows it.
func (t *Template) replaceDotNotation(content string, data map[string]any) string {
	result := content
	start := 0
	for {
		open := strings.Index(result[start:], "{")
		if open == -1 {
			break
		}
		open += start

		end := strings.Index(result[open:], "}")
		if end == -1 {
			break
		}
		end += open

		// Rewind to the last "{" before the closing brace so that text such
		// as `f(){ x = "{user.name}" }` still yields the path "user.name".
		if inner := strings.LastIndex(result[open:end], "{"); inner > 0 {
			open += inner
		}

		placeholder := result[open+1 : end]
		if strings.Contains(placeholder, ".") {
			if value := t.getNestedValue(data, placeholder); value != nil {
				text := fmt.Sprintf("%v", value)
				result = result[:open] + text + result[end+1:]
				// Continue after the inserted text; it is not re-scanned.
				start = open + len(text)
				continue
			}
		}
		start = end + 1
	}
	return result
}
// getNestedValue resolves a dot-separated path such as "user.address.city"
// against data.
//
// Each segment is looked up as a map key when the current node is a
// map[string]any and as a struct field (via getStructField) otherwise. Maps of
// any other type met along the way are converted to map[string]any with their
// keys formatted using %v. It returns nil when a segment cannot be resolved.
func (t *Template) getNestedValue(data map[string]any, path string) any {
	segments := strings.Split(path, ".")
	last := len(segments) - 1

	var node any = data
	for idx, seg := range segments {
		fields, isMap := node.(map[string]any)

		// Last segment: return whatever is found, including nil.
		if idx == last {
			if isMap {
				return fields[seg]
			}
			return t.getStructField(node, seg)
		}

		// Intermediate segment: step down one level or give up.
		var child any
		if isMap {
			found, present := fields[seg]
			if !present {
				return nil
			}
			child = found
		} else {
			child = t.getStructField(node, seg)
			if child == nil {
				return nil
			}
		}

		// Normalize any non-map[string]any map so the next lookup can use
		// plain string keys; every other value is carried over unchanged.
		if _, ready := child.(map[string]any); !ready {
			if rv := reflect.ValueOf(child); rv.Kind() == reflect.Map {
				converted := make(map[string]any)
				for it := rv.MapRange(); it.Next(); {
					converted[fmt.Sprint(it.Key().Interface())] = it.Value().Interface()
				}
				child = converted
			}
		}
		node = child
	}
	return nil
}
// getStructField returns the value of the field called fieldName on obj, where
// obj is a struct or a non-nil pointer to a struct.
//
// It returns nil when obj is nil, is not a struct, has no such field, when the
// field is unexported, or when the field is promoted through a nil embedded
// pointer. The last two cases previously panicked inside the reflect package,
// which let a template placeholder such as {user.password} crash a render.
func (t *Template) getStructField(obj any, fieldName string) any {
	if obj == nil {
		return nil
	}

	rv := reflect.ValueOf(obj)
	if rv.Kind() == reflect.Ptr {
		if rv.IsNil() {
			return nil
		}
		rv = rv.Elem()
	}
	if rv.Kind() != reflect.Struct {
		return nil
	}

	sf, ok := rv.Type().FieldByName(fieldName)
	if !ok {
		return nil
	}
	// FieldByIndexErr reports a nil embedded pointer as an error instead of
	// panicking; CanInterface is false for unexported fields.
	field, err := rv.FieldByIndexErr(sf.Index)
	if err != nil || !field.CanInterface() {
		return nil
	}
	return field.Interface()
}
// getLength returns the reflect length of value when it is a slice, array, map
// or string, and 0 for nil and for every other kind.
func (t *Template) getLength(value any) int {
	if value == nil {
		return 0
	}
	rv := reflect.ValueOf(value)
	if k := rv.Kind(); k == reflect.Slice || k == reflect.Array || k == reflect.Map || k == reflect.String {
		return rv.Len()
	}
	return 0
}
// WriteTo renders the template with data and writes the result to ctx with the
// content type "text/html; charset=utf-8".
//
// A map[string]any is rendered with RenderNamed. A []any, or any other slice
// (expanded element by element), is rendered with RenderPositional. Every
// other value is passed to RenderPositional as the single argument {0}.
func (t *Template) WriteTo(ctx *fasthttp.RequestCtx, data any) {
	var body string
	if named, ok := data.(map[string]any); ok {
		body = t.RenderNamed(named)
	} else if list, ok := data.([]any); ok {
		body = t.RenderPositional(list...)
	} else if rv := reflect.ValueOf(data); rv.Kind() == reflect.Slice {
		// Typed slices such as []string are unpacked into positional args.
		elems := make([]any, 0, rv.Len())
		for i := 0; i < rv.Len(); i++ {
			elems = append(elems, rv.Index(i).Interface())
		}
		body = t.RenderPositional(elems...)
	} else {
		body = t.RenderPositional(data)
	}

	ctx.SetContentType("text/html; charset=utf-8")
	ctx.WriteString(body)
}
// processIncludes expands {include "template.html"} directives in content.
//
// The named template is loaded through the owning cache. With non-nil data the
// included template is rendered via RenderNamed(data); with nil data its raw
// content is inserted. A directive whose template cannot be loaded is removed.
// Scanning stops at the first directive that has no closing brace.
func (t *Template) processIncludes(content string, data map[string]any) string {
	const opener = "{include "

	out := content
	for {
		open := strings.Index(out, opener)
		if open < 0 {
			return out
		}
		rel := strings.Index(out[open:], "}")
		if rel < 0 {
			return out
		}
		closeAt := open + rel

		// The template name may be quoted and padded with spaces.
		target := strings.Trim(out[open+len(opener):closeAt], "\" ")

		replacement := ""
		if included, err := t.cache.Load(target); err == nil {
			if data == nil {
				replacement = included.content
			} else {
				replacement = included.RenderNamed(data)
			}
		}
		out = out[:open] + replacement + out[closeAt+1:]
	}
}
// processBlocks removes every {block "name"}...{/block} section from content
// and records its body in blocks under the block name (surrounding quotes and
// spaces trimmed). It returns the content without the block definitions.
// Scanning stops at the first block whose header or end tag is missing.
func (t *Template) processBlocks(content string, blocks map[string]string) string {
	const (
		opener = "{block "
		closer = "{/block}"
	)

	out := content
	for {
		open := strings.Index(out, opener)
		if open < 0 {
			break
		}
		rel := strings.Index(out[open:], "}")
		if rel < 0 {
			break
		}
		headerEnd := open + rel

		body := out[headerEnd+1:]
		bodyLen := strings.Index(body, closer)
		if bodyLen < 0 {
			break
		}

		label := strings.Trim(out[open+len(opener):headerEnd], "\" ")
		blocks[label] = body[:bodyLen]

		// Cut the whole definition, tags included, out of the template.
		out = out[:open] + body[bodyLen+len(closer):]
	}
	return out
}
// processYield fills {yield "name"} directives with the matching entry of
// blocks. Each block body has its loops and then its conditionals evaluated
// against data before it is inserted. Any bare {yield} left afterwards is
// removed.
func (t *Template) processYield(content string, blocks map[string]string, data map[string]any) string {
	out := content
	for name, raw := range blocks {
		rendered := t.processConditionals(t.processLoops(raw, data), data)
		directive := fmt.Sprintf("{yield \"%s\"}", name)
		out = strings.ReplaceAll(out, directive, rendered)
	}
	return strings.ReplaceAll(out, "{yield}", "")
}
// processLoops expands {for item in items}...{/for} and
// {for key,value in map}...{/for} constructs in content using expandLoop.
//
// The end tag is matched by nesting depth, so a loop that contains another
// loop hands its complete body (inner loop included) to expandLoop. Pairing
// the header with the first {/for} instead cut nested loops in half.
// Scanning stops at the first loop whose header or end tag is missing.
func (t *Template) processLoops(content string, data map[string]any) string {
	const (
		openTag = "{for "
		endTag  = "{/for}"
	)

	result := content
	for {
		start := strings.Index(result, openTag)
		if start == -1 {
			break
		}
		headerEnd := strings.Index(result[start:], "}")
		if headerEnd == -1 {
			break
		}
		headerEnd += start
		header := result[start+len(openTag) : headerEnd]
		contentStart := headerEnd + 1

		// Find the {/for} that closes this loop by tracking nesting depth.
		contentEnd := -1
		depth := 1
		for pos := contentStart; depth > 0; {
			closeAt := strings.Index(result[pos:], endTag)
			if closeAt == -1 {
				break // unbalanced: no end tag left
			}
			if openAt := strings.Index(result[pos:], openTag); openAt != -1 && openAt < closeAt {
				depth++
				pos += openAt + len(openTag)
				continue
			}
			depth--
			if depth == 0 {
				contentEnd = pos + closeAt
				break
			}
			pos += closeAt + len(endTag)
		}
		if contentEnd == -1 {
			break
		}

		expanded := t.expandLoop(header, result[contentStart:contentEnd], data)
		result = result[:start] + expanded + result[contentEnd+len(endTag):]
	}
	return result
}
// expandLoop renders a single loop construct.
//
// header is the text between "{for " and "}" and must have the form
// "item in source" or "key,value in source"; content is the loop body. source
// is resolved with getNestedValue and may be a slice, an array or a map:
//
//   - slice/array: value is bound to each element and key (if requested) to
//     the element index;
//   - map: value is bound to each map value and key (if requested) to the map
//     key. Entries are visited in sorted key order so that the rendered output
//     is deterministic (Go randomizes native map iteration).
//
// It returns "" when the header is malformed, when the source resolves to nil,
// or when the source is of any other kind.
func (t *Template) expandLoop(header, content string, data map[string]any) string {
	parts := strings.Split(strings.TrimSpace(header), " in ")
	if len(parts) != 2 {
		return ""
	}
	varPart := strings.TrimSpace(parts[0])
	source := t.getNestedValue(data, strings.TrimSpace(parts[1]))
	if source == nil {
		return ""
	}

	// "key,value" binds two variables; a plain name binds the value only.
	keyVar, valueVar, paired := strings.Cut(varPart, ",")
	if paired {
		keyVar, valueVar = strings.TrimSpace(keyVar), strings.TrimSpace(valueVar)
	} else {
		valueVar = varPart
	}

	var out strings.Builder
	emit := func(key, value any) {
		// Every iteration gets its own copy of data plus the loop variables.
		iterData := make(map[string]any, len(data)+2)
		maps.Copy(iterData, data)
		if paired {
			iterData[keyVar] = key
		}
		iterData[valueVar] = value
		out.WriteString(t.renderLoopBody(content, iterData))
	}

	rv := reflect.ValueOf(source)
	switch rv.Kind() {
	case reflect.Slice, reflect.Array:
		for i := 0; i < rv.Len(); i++ {
			emit(i, rv.Index(i).Interface())
		}
	case reflect.Map:
		for _, key := range sortedMapKeys(rv) {
			emit(key.Interface(), rv.MapIndex(key).Interface())
		}
	}
	return out.String()
}

// renderLoopBody renders one loop iteration: nested loops, then conditionals,
// then {name} placeholders, then dotted placeholders, all against iterData.
func (t *Template) renderLoopBody(content string, iterData map[string]any) string {
	body := t.processLoops(content, iterData)
	body = t.processConditionals(body, iterData)
	for k, v := range iterData {
		body = strings.ReplaceAll(body, "{"+k+"}", fmt.Sprintf("%v", v))
	}
	return t.replaceDotNotation(body, iterData)
}

// sortedMapKeys returns the keys of the map value m in a deterministic order.
func sortedMapKeys(m reflect.Value) []reflect.Value {
	keys := m.MapKeys()
	slices.SortFunc(keys, compareMapKeys)
	return keys
}

// compareMapKeys orders two map keys: keys of different kinds are ordered by
// kind, integers, unsigned integers, floats and strings by their natural
// order, and everything else by its %v formatting.
func compareMapKeys(a, b reflect.Value) int {
	// Keys of a map[any]T arrive wrapped in an interface value.
	for a.Kind() == reflect.Interface && !a.IsNil() {
		a = a.Elem()
	}
	for b.Kind() == reflect.Interface && !b.IsNil() {
		b = b.Elem()
	}
	if a.Kind() != b.Kind() {
		return cmp.Compare(a.Kind(), b.Kind())
	}
	switch {
	case a.CanInt():
		return cmp.Compare(a.Int(), b.Int())
	case a.CanUint():
		return cmp.Compare(a.Uint(), b.Uint())
	case a.CanFloat():
		return cmp.Compare(a.Float(), b.Float())
	case a.Kind() == reflect.String:
		return cmp.Compare(a.String(), b.String())
	}
	return cmp.Compare(fmt.Sprint(a.Interface()), fmt.Sprint(b.Interface()))
}
// processConditionals evaluates {if condition}...{/if} and
// {if condition}...{else}...{/if} constructs in content.
//
// The closing {/if} is matched by nesting depth. The branch selected by
// evaluateCondition has its loops and nested conditionals processed before it
// replaces the whole construct. Scanning stops at the first {if whose header
// or matching {/if} is missing.
func (t *Template) processConditionals(content string, data map[string]any) string {
	const (
		ifOpen  = "{if "
		ifClose = "{/if}"
		elseTag = "{else}"
	)

	out := content
	for {
		open := strings.Index(out, ifOpen)
		if open < 0 {
			return out
		}
		rel := strings.Index(out[open:], "}")
		if rel < 0 {
			return out
		}
		headerEnd := open + rel
		condition := strings.TrimSpace(out[open+len(ifOpen) : headerEnd])
		bodyStart := headerEnd + 1

		// Walk forward over nested {if ...}/{/if} pairs until the depth
		// returns to zero; that {/if} closes the construct found above.
		bodyEnd := -1
		depth := 1
	scan:
		for cursor := bodyStart; cursor < len(out); {
			nextOpen := strings.Index(out[cursor:], ifOpen)
			nextClose := strings.Index(out[cursor:], ifClose)
			switch {
			case nextOpen >= 0 && (nextClose < 0 || nextOpen < nextClose):
				depth++
				cursor += nextOpen + len(ifOpen)
			case nextClose >= 0:
				depth--
				if depth == 0 {
					bodyEnd = cursor + nextClose
					break scan
				}
				cursor += nextClose + len(ifClose)
			default:
				break scan
			}
		}
		if bodyEnd < 0 {
			return out
		}

		// Split the body at an {else} that belongs to this level, if any.
		body := out[bodyStart:bodyEnd]
		chosen, otherwise := body, ""
		if at := t.findElseAtLevel(body); at != -1 {
			chosen, otherwise = body[:at], body[at+len(elseTag):]
		}
		if !t.evaluateCondition(condition, data) {
			chosen = otherwise
		}

		chosen = t.processConditionals(t.processLoops(chosen, data), data)
		out = out[:open] + chosen + out[bodyEnd+len(ifClose):]
	}
}
// findElseAtLevel returns the byte offset of the first {else} in content that
// is not inside a nested {if ...}...{/if} pair, or -1 when there is none.
func (t *Template) findElseAtLevel(content string) int {
	const (
		ifTag   = "{if "
		elseTag = "{else}"
		endTag  = "{/if}"
	)

	depth := 0
	for cursor := 0; cursor < len(content); {
		rest := content[cursor:]

		// Pick whichever tag occurs first in the remaining text.
		nearest, tag := -1, ""
		for _, candidate := range [...]string{ifTag, elseTag, endTag} {
			if off := strings.Index(rest, candidate); off != -1 && (nearest == -1 || off < nearest) {
				nearest, tag = off, candidate
			}
		}
		if nearest == -1 {
			return -1
		}

		at := cursor + nearest
		switch tag {
		case ifTag:
			depth++
		case elseTag:
			if depth == 0 {
				return at
			}
		case endTag:
			depth--
		}
		cursor = at + len(tag)
	}
	return -1
}
// evaluateCondition evaluates a template condition against data.
//
// Supported forms, checked in this order:
//   - "a and b and ...": true when every clause is true;
//   - "left OP right" with OP one of >=, <=, !=, ==, >, < (the operator must
//     occur exactly once in the condition);
//   - a single expression such as "user.name", "#items" or "items", which is
//     tested for truthiness.
func (t *Template) evaluateCondition(condition string, data map[string]any) bool {
	expr := strings.TrimSpace(condition)

	if strings.Contains(expr, " and ") {
		for _, clause := range strings.Split(expr, " and ") {
			if !t.evaluateCondition(strings.TrimSpace(clause), data) {
				return false
			}
		}
		return true
	}

	// Two-character operators come first so ">=" is not read as ">".
	for _, op := range [...]string{">=", "<=", "!=", "==", ">", "<"} {
		if strings.Count(expr, op) != 1 {
			continue
		}
		lhs, rhs, _ := strings.Cut(expr, op)
		leftValue := t.getConditionValue(strings.TrimSpace(lhs), data)
		rightValue := t.getConditionValue(strings.TrimSpace(rhs), data)
		return t.compareValues(leftValue, rightValue, op)
	}

	return t.isTruthy(t.getConditionValue(expr, data))
}
// getConditionValue resolves one operand of a condition. The forms are tried
// in this order:
//
//   - "#path": the length (getLength) of the value at path;
//   - a number accepted by strconv.ParseFloat: returned as float64;
//   - a double-quoted string literal: returned without the quotes;
//   - a dotted path: resolved with getNestedValue;
//   - a key of data: the value stored under it.
//
// Anything else is returned as the expression text itself.
func (t *Template) getConditionValue(expr string, data map[string]any) any {
	expr = strings.TrimSpace(expr)

	// Length operator.
	if path, ok := strings.CutPrefix(expr, "#"); ok {
		return t.getLength(t.getNestedValue(data, path))
	}

	// Numeric literal.
	if num, err := strconv.ParseFloat(expr, 64); err == nil {
		return num
	}

	// String literal. The length check matters: a lone `"` satisfies both the
	// prefix and the suffix test and used to panic on expr[1:0].
	if len(expr) >= 2 && strings.HasPrefix(expr, "\"") && strings.HasSuffix(expr, "\"") {
		return expr[1 : len(expr)-1]
	}

	// Variable reference.
	if strings.Contains(expr, ".") {
		return t.getNestedValue(data, expr)
	}
	if value, ok := data[expr]; ok {
		return value
	}
	return expr
}
// compareValues applies the comparison operator op to left and right.
//
// "==" and "!=" compare the %v formatting of both values. In addition, two
// non-string numeric values that are numerically equal count as equal, so an
// int 1000000 matches the parsed literal 1000000 (which formats as "1e+06").
// Strings are never coerced for equality.
//
// ">", ">=", "<" and "<=" compare both sides as float64 (see toFloat) and
// report false when either side is not numeric. Unknown operators report
// false.
func (t *Template) compareValues(left, right any, op string) bool {
	switch op {
	case "==", "!=":
		equal := fmt.Sprintf("%v", left) == fmt.Sprintf("%v", right)
		if !equal {
			_, leftIsString := left.(string)
			_, rightIsString := right.(string)
			if !leftIsString && !rightIsString {
				leftNum, leftOk := t.toFloat(left)
				rightNum, rightOk := t.toFloat(right)
				equal = leftOk && rightOk && leftNum == rightNum
			}
		}
		return equal == (op == "==")

	case ">", ">=", "<", "<=":
		leftNum, leftOk := t.toFloat(left)
		rightNum, rightOk := t.toFloat(right)
		if !leftOk || !rightOk {
			return false
		}
		switch op {
		case ">":
			return leftNum > rightNum
		case ">=":
			return leftNum >= rightNum
		case "<":
			return leftNum < rightNum
		case "<=":
			return leftNum <= rightNum
		}
	}
	return false
}
// toFloat converts value to float64.
//
// Every signed and unsigned integer kind and both float kinds are accepted,
// including named types built on them; strings are parsed with
// strconv.ParseFloat. The boolean result is false for nil, for strings that do
// not parse, and for every other kind.
func (t *Template) toFloat(value any) (float64, bool) {
	rv := reflect.ValueOf(value)
	switch rv.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return float64(rv.Int()), true
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return float64(rv.Uint()), true
	case reflect.Float32, reflect.Float64:
		return rv.Float(), true
	case reflect.String:
		if f, err := strconv.ParseFloat(rv.String(), 64); err == nil {
			return f, true
		}
	}
	return 0, false
}
// isTruthy reports whether value counts as true in a template condition.
//
// False values are: nil, false, zero of any integer, unsigned integer or float
// kind, the empty string, empty slices, arrays and maps, and nil pointers.
// Everything else (structs, functions, channels, ...) is true.
//
// Numeric kinds are handled uniformly, so int64(0), uint(0) or float32(0) are
// false just like int(0) and float64(0).
func (t *Template) isTruthy(value any) bool {
	if value == nil {
		return false
	}

	rv := reflect.ValueOf(value)
	switch rv.Kind() {
	case reflect.Bool:
		return rv.Bool()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return rv.Int() != 0
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return rv.Uint() != 0
	case reflect.Float32, reflect.Float64:
		return rv.Float() != 0
	case reflect.String, reflect.Slice, reflect.Array, reflect.Map:
		return rv.Len() > 0
	case reflect.Ptr:
		return !rv.IsNil()
	}
	return true
}
// RenderToContext loads templateName through the package-level Cache, renders
// it with data and writes the result to ctx.
//
// It returns true on success. When the cache has not been initialized (see
// InitializeCache) or the template cannot be loaded, it sets status 500,
// writes a "Template error: ..." message to the response and returns false;
// an uninitialized cache no longer causes a nil-pointer panic.
func RenderToContext(ctx *fasthttp.RequestCtx, templateName string, data map[string]any) bool {
	if Cache == nil {
		ctx.SetStatusCode(fasthttp.StatusInternalServerError)
		fmt.Fprintf(ctx, "Template error: %v", "template cache is not initialized")
		return false
	}

	tmpl, err := Cache.Load(templateName)
	if err != nil {
		ctx.SetStatusCode(fasthttp.StatusInternalServerError)
		fmt.Fprintf(ctx, "Template error: %v", err)
		return false
	}

	tmpl.WriteTo(ctx, data)
	return true
}
// RenderNamed loads templateName through the package-level Cache and returns
// it rendered with data.
//
// It returns an error when the cache has not been initialized (see
// InitializeCache) or when the template cannot be loaded; the load error is
// wrapped with %w.
func RenderNamed(templateName string, data map[string]any) (string, error) {
	if Cache == nil {
		// Calling Load on a nil cache would panic on the mutex access.
		return "", fmt.Errorf("failed to load template %s: template cache is not initialized", templateName)
	}

	tmpl, err := Cache.Load(templateName)
	if err != nil {
		return "", fmt.Errorf("failed to load template %s: %w", templateName, err)
	}
	return tmpl.RenderNamed(data), nil
}