228 lines
5.8 KiB
C++
228 lines
5.8 KiB
C++
#pragma once
|
|
|
|
#include <poll.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <algorithm>
#include <array>
#include <cctype>
#include <cerrno>
#include <charconv>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <iostream>
#include <memory>
#include <string>
#include <string_view>
#include <system_error>
#include <thread>
#include <unordered_map>
#include <vector>

#include "epoll_socket.hpp"
#include "router.hpp"
#include "parser.hpp"
#include "response.hpp"
#include "static_file_handler.hpp"
#include "kv_store.hpp"
#include "session_store.hpp"
|
|
|
|
class Server {
|
|
public:
|
|
KeyValueStore store;
|
|
SessionStore sessions;
|
|
|
|
explicit Server(uint16_t port, Router& router) : port_(port), router_(router) {}
|
|
|
|
~Server() {
|
|
stop();
|
|
store.save();
|
|
sessions.save();
|
|
for (auto& worker : workers_) {
|
|
if (worker->thread.joinable()) {
|
|
worker->thread.join();
|
|
}
|
|
}
|
|
}
|
|
|
|
bool start() {
|
|
store.load();
|
|
sessions.load();
|
|
|
|
unsigned int num_cores = std::thread::hardware_concurrency();
|
|
if (num_cores == 0) num_cores = 1;
|
|
|
|
workers_.reserve(num_cores);
|
|
|
|
for (unsigned int i = 0; i < num_cores; ++i) {
|
|
auto worker = std::make_unique<Worker>(port_, router_, static_handler_, sessions);
|
|
if (!worker->socket.start()) return false;
|
|
workers_.push_back(std::move(worker));
|
|
}
|
|
|
|
for (auto& worker : workers_) {
|
|
worker->thread = std::thread([&worker]() { worker->socket.run(); });
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
void run() {
|
|
for (auto& worker : workers_) {
|
|
if (worker->thread.joinable()) {
|
|
worker->thread.join();
|
|
}
|
|
}
|
|
}
|
|
|
|
void stop() {
|
|
for (auto& worker : workers_) {
|
|
worker->socket.stop();
|
|
}
|
|
}
|
|
|
|
static std::string get_path_param(string_view path, size_t segment_index = 0) {
|
|
size_t start = 0;
|
|
size_t current_segment = 0;
|
|
|
|
while (start < path.length()) {
|
|
if (path[start] == '/') {
|
|
start++;
|
|
continue;
|
|
}
|
|
|
|
size_t end = path.find('/', start);
|
|
if (end == std::string_view::npos) end = path.length();
|
|
|
|
if (current_segment == segment_index) {
|
|
return std::string(path.substr(start, end - start));
|
|
}
|
|
|
|
current_segment++;
|
|
start = end;
|
|
}
|
|
return "";
|
|
}
|
|
|
|
void serve_static(const std::string& static_dir, const std::string& url_prefix = "") {
|
|
static_handler_ = std::make_shared<StaticFileHandler>(static_dir, url_prefix);
|
|
}
|
|
|
|
private:
|
|
static constexpr int BUFFER_SIZE = 65536;
|
|
std::shared_ptr<StaticFileHandler> static_handler_;
|
|
|
|
struct Worker {
|
|
EpollSocket socket;
|
|
Router& router;
|
|
std::shared_ptr<StaticFileHandler>& static_handler;
|
|
SessionStore& sessions;
|
|
std::array<char, BUFFER_SIZE> buffer;
|
|
std::thread thread;
|
|
|
|
Worker(uint16_t port, Router& r, std::shared_ptr<StaticFileHandler>& sh, SessionStore& s)
|
|
: socket(port), router(r), static_handler(sh), sessions(s) {
|
|
socket.on_connection([this](int fd) { handle_connection(fd); });
|
|
socket.on_data([this](int fd) { handle_data(fd); });
|
|
socket.on_disconnect([this](int fd) { handle_disconnect(fd); });
|
|
}
|
|
|
|
void handle_connection(int client_fd) {
|
|
// Client connected
|
|
}
|
|
|
|
void handle_data(int client_fd) {
|
|
while (true) {
|
|
ssize_t bytes_read = read(client_fd, buffer.data(), buffer.size());
|
|
|
|
if (bytes_read == -1) {
|
|
if (errno == EAGAIN || errno == EWOULDBLOCK) break;
|
|
return;
|
|
}
|
|
|
|
if (bytes_read == 0) return;
|
|
|
|
std::string request_data(buffer.data(), bytes_read);
|
|
if (request_data.find("\r\n\r\n") != std::string::npos) {
|
|
process_request(client_fd, request_data);
|
|
}
|
|
}
|
|
}
|
|
|
|
void handle_disconnect(int client_fd) {
|
|
// Client disconnected
|
|
}
|
|
|
|
void process_request(int client_fd, std::string_view request_data) {
|
|
Request req = Parser::parse(request_data);
|
|
|
|
if (!req.valid) {
|
|
send_error_response(client_fd, "Bad Request", 400, req.version);
|
|
return;
|
|
}
|
|
|
|
// Handle session
|
|
std::string session_id = get_session_id(req);
|
|
if (session_id.empty()) {
|
|
session_id = sessions.create();
|
|
}
|
|
|
|
Response response;
|
|
|
|
// Try router first
|
|
if (router.handle(req, response)) {
|
|
set_session_cookie(response, session_id);
|
|
send_response(client_fd, response, req.version);
|
|
return;
|
|
}
|
|
|
|
// Then try static files
|
|
if (static_handler && static_handler->handle(req, response)) {
|
|
set_session_cookie(response, session_id);
|
|
send_response(client_fd, response, req.version);
|
|
} else {
|
|
response.status = 404;
|
|
response.set_text("Not Found");
|
|
set_session_cookie(response, session_id);
|
|
send_response(client_fd, response, req.version);
|
|
}
|
|
}
|
|
|
|
std::string get_session_id(const Request& req) {
|
|
auto it = req.headers.find("Cookie");
|
|
if (it == req.headers.end()) return "";
|
|
|
|
std::string_view cookies = it->second;
|
|
size_t pos = cookies.find("session_id=");
|
|
if (pos == std::string_view::npos) return "";
|
|
|
|
pos += 11; // length of "session_id="
|
|
size_t end = cookies.find(';', pos);
|
|
if (end == std::string_view::npos) end = cookies.length();
|
|
|
|
return std::string(cookies.substr(pos, end - pos));
|
|
}
|
|
|
|
void set_session_cookie(Response& response, const std::string& session_id) {
|
|
response.headers["Set-Cookie"] = "session_id=" + session_id + "; HttpOnly; Path=/; SameSite=Strict";
|
|
}
|
|
|
|
void send_response(int client_fd, const Response& response, string_view version) {
|
|
std::string http_response = ResponseBuilder::build_response(response, version);
|
|
send_raw_response(client_fd, http_response);
|
|
}
|
|
|
|
void send_raw_response(int client_fd, const std::string& response) {
|
|
ssize_t total_sent = 0;
|
|
ssize_t response_len = response.size();
|
|
|
|
while (total_sent < response_len) {
|
|
ssize_t sent = write(client_fd, response.data() + total_sent, response_len - total_sent);
|
|
if (sent == -1) {
|
|
if (errno == EAGAIN || errno == EWOULDBLOCK) continue;
|
|
break;
|
|
}
|
|
total_sent += sent;
|
|
}
|
|
}
|
|
|
|
void send_error_response(int client_fd, const std::string& message, int status, string_view version) {
|
|
std::string response = ResponseBuilder::build_error_response(status, message, version);
|
|
send_raw_response(client_fd, response);
|
|
}
|
|
};
|
|
|
|
uint16_t port_;
|
|
Router& router_;
|
|
std::vector<std::unique_ptr<Worker>> workers_;
|
|
};
|