#include "pch.h"
#include <windows.h>
#include <ntsecapi.h>
#include <string>
#include <vector>
#include <regex>
#include <shared_mutex>
#include <fstream>
#include <filesystem>
#include <ntstatus.h>
#include <cwctype>
#include <algorithm>
#include <climits>


// Configuration files; each is re-read when its last-write time changes.
//   CONFIG_PATH           - one regex per line; a password matching any is rejected.
//   CONFIG_BLACKLIST_PATH - one substring per line; matched against the lowercased password.
//   CONFIG_ENFORCE_PATH   - INI-style file carrying the "enforce=" switch.
// constexpr makes the pointers themselves immutable (the originals could be reassigned).
static constexpr const wchar_t* CONFIG_PATH = L"C:\\ProgramData\\regex.txt";
static constexpr const wchar_t* CONFIG_BLACKLIST_PATH = L"C:\\ProgramData\\blacklist.txt";
static constexpr const wchar_t* CONFIG_ENFORCE_PATH = L"C:\\ProgramData\\passwordfilter.ini";

// Passwords longer than this many UTF-16 code units are not copied; the
// filter then accepts them without checking.
static constexpr size_t MAX_PWD_LEN = 512;

// Compiled regexes from CONFIG_PATH and the file timestamp they were built from.
static std::vector<std::wregex> g_patterns;
static std::filesystem::file_time_type g_patterns_last_write_time{};
static std::shared_mutex g_patterns_lock;  // guards the two variables above

// Lowercased blacklist entries from CONFIG_BLACKLIST_PATH and their timestamp.
static std::vector<std::wstring> g_blacklist;
static std::filesystem::file_time_type g_blacklist_last_write_time{};
static std::shared_mutex g_blacklist_lock;  // guards the two variables above

// Enforcement switch: while 0, PasswordFilter accepts every password unchecked.
// Starts at 0 and only changes once CONFIG_ENFORCE_PATH has been parsed.
// NOTE(review): a missing .ini therefore leaves the filter disabled (fail-open);
// confirm this default is intended.
static int g_enforce = 0;
static std::filesystem::file_time_type g_enforce_last_write_time{};
static std::shared_mutex g_enforce_lock;  // guards the two variables above

// Lowercases `s` in place, one UTF-16 code unit at a time, via towlower
// (so the result follows the current C locale's case mapping).
static void ToLower(std::wstring& s) {
    for (wchar_t& ch : s) {
        ch = static_cast<wchar_t>(towlower(ch));
    }
}

// Converts a UTF-8 byte string to UTF-16 using MultiByteToWideChar.
// Returns true and replaces `out` on success; an empty input is valid and
// yields an empty `out`. Returns false and leaves `out` untouched when the input
// exceeds the int-sized length the Win32 API accepts, or when either
// conversion call fails. Invalid sequences are not rejected (no
// MB_ERR_INVALID_CHARS); per the Win32 docs they become U+FFFD.
static bool Utf8ToUtf16(const std::string& utf8, std::wstring& out) {
    if (utf8.empty()) {
        // MultiByteToWideChar fails on a zero-length input, but "" is valid UTF-8.
        out.clear();
        return true;
    }
    if (utf8.size() > static_cast<size_t>(INT_MAX)) {
        return false;  // would be truncated by the int length parameter
    }
    const int srcLen = static_cast<int>(utf8.size());
    const int wlen = MultiByteToWideChar(CP_UTF8, 0, utf8.data(), srcLen, nullptr, 0);
    if (wlen <= 0) return false;

    std::wstring converted(static_cast<size_t>(wlen), L'\0');
    const int written = MultiByteToWideChar(CP_UTF8, 0, utf8.data(), srcLen, &converted[0], wlen);
    if (written != wlen) {
        return false;  // don't hand back a NUL-filled buffer on failure
    }
    out.swap(converted);
    return true;
}

// Removes leading and trailing characters for which iswspace() is true, in place.
// A string made only of whitespace becomes empty.
static void TrimWhitespace(std::wstring& s) {
    std::size_t first = 0;
    while (first < s.size() && iswspace(s[first])) {
        ++first;
    }
    if (first == s.size()) {
        s.clear();
        return;
    }
    // s[first] is not whitespace, so this scan stops at first + 1 at the latest.
    std::size_t last = s.size();
    while (last > first && iswspace(s[last - 1])) {
        --last;
    }
    s.erase(last);
    s.erase(0, first);
}

// Reads the UTF-8 text file at `path` and appends each line, converted to
// UTF-16 and whitespace-trimmed, to `out`. Blank lines and lines that fail
// UTF-8 conversion are skipped. A leading UTF-8 byte-order mark is removed,
// so it cannot end up glued to the first entry.
// Returns false only if the file cannot be opened; otherwise true.
static bool ReadUtf8Lines(const wchar_t* path, std::vector<std::wstring>& out) {
    // std::filesystem::path keeps the wide path intact portably; the
    // ifstream(const wchar_t*) constructor is an MSVC extension.
    const std::filesystem::path file{path};
    std::ifstream f(file, std::ios::binary);
    if (!f) return false;

    static const char kUtf8Bom[] = "\xEF\xBB\xBF";
    std::string line;
    bool firstLine = true;
    while (std::getline(f, line)) {
        if (firstLine) {
            firstLine = false;
            // U+FEFF is not whitespace for iswspace, so TrimWhitespace would keep it.
            if (line.compare(0, 3, kUtf8Bom) == 0) {
                line.erase(0, 3);
            }
        }
        std::wstring wline;
        if (!Utf8ToUtf16(line, wline)) continue;
        // Also strips the trailing '\r' of CRLF files (iswspace('\r') is true).
        TrimWhitespace(wline);
        if (!wline.empty()) {
            out.push_back(std::move(wline));
        }
    }
    return true;
}

// Copies a UNICODE_STRING into `out`. `Length` is a byte count, so it is
// divided by sizeof(WCHAR). Exactly that many units are copied, with no
// reliance on a terminator.
// Returns:
//   true  - `out` holds the copy (cleared for a zero-length string).
//   false - `ustr` is null, its Buffer is null despite a non-zero Length, or it
//           is longer than MAX_PWD_LEN code units; `out` is left unchanged.
static bool UniStrToWstring(const UNICODE_STRING* ustr, std::wstring &out) {
    if (!ustr) return false;
    const size_t len = ustr->Length / sizeof(WCHAR);
    if (len == 0) {
        out.clear();
        return true;
    }
    // Guard against a malformed descriptor instead of dereferencing null.
    if (!ustr->Buffer) {
        return false;
    }
    if (len > MAX_PWD_LEN) {
        return false;
    }
    out.assign(ustr->Buffer, ustr->Buffer + len);
    return true;
}

static void LoadPatternsIfNeeded() {
    try {
        std::error_code ec;
        auto ft = std::filesystem::last_write_time(CONFIG_PATH, ec);
        if (ec) return;
        {
            std::shared_lock readLock(g_patterns_lock);
            if (ft == g_patterns_last_write_time) return;
        }
        std::unique_lock writeLock(g_patterns_lock);
        if (ft == g_patterns_last_write_time) return;
        std::vector<std::wstring> lines;
        if (!ReadUtf8Lines(CONFIG_PATH, lines)) {
            return;
        }
        std::vector<std::wregex> patterns;
        patterns.reserve(lines.size());
        for (const auto& line : lines) {
            try {
                if (!line.empty()) {
                    patterns.emplace_back(line, std::regex::ECMAScript | std::regex::optimize | std::regex::icase);
                }
            } catch (...) {
                continue;
            }
        }
        g_patterns.swap(patterns);
        g_patterns_last_write_time = ft;
    } catch (...) {
    }
}

// Re-reads CONFIG_ENFORCE_PATH when its last-write time has changed and sets
// g_enforce from the first valid "enforce = <integer>" line.
//
// Accepted format (INI-like):
//   # comment   /   ; comment      -> ignored
//   [section]                      -> ignored (no '=')
//   enforce = 1                    -> key compared case-insensitively,
//                                     whitespace around '=' allowed
// The value is parsed with std::stoi, which takes the leading integer, so
// "1 ; note" still reads as 1. Non-numeric or out-of-range values are skipped
// rather than aborting the load. If no valid line is found, g_enforce keeps
// its previous value.
// If the file is missing or unreadable, the timestamp is not recorded, so the
// next call retries. Never throws.
static void LoadEnforceIfNeeded() {
    try {
        std::error_code ec;
        auto ft = std::filesystem::last_write_time(CONFIG_ENFORCE_PATH, ec);
        if (ec) return;
        {
            std::shared_lock readLock(g_enforce_lock);
            if (ft == g_enforce_last_write_time) return;
        }
        std::unique_lock writeLock(g_enforce_lock);
        if (ft == g_enforce_last_write_time) return;
        std::vector<std::wstring> lines;
        if (!ReadUtf8Lines(CONFIG_ENFORCE_PATH, lines)) {
            return;
        }

        int enforce = g_enforce;
        for (const auto& raw : lines) {
            // Comment lines must not be honoured (previously "# enforce=0" was).
            if (raw.empty() || raw[0] == L'#' || raw[0] == L';') continue;

            const auto eq = raw.find(L'=');
            if (eq == std::wstring::npos) continue;

            // Require the key to be exactly "enforce", not merely to contain it
            // (previously "noenforce=0" matched).
            std::wstring name = raw.substr(0, eq);
            TrimWhitespace(name);
            ToLower(name);
            if (name != L"enforce") continue;

            std::wstring val = raw.substr(eq + 1);
            TrimWhitespace(val);
            if (val.empty()) continue;

            int parsed = 0;
            try {
                parsed = std::stoi(val);
            } catch (const std::exception&) {
                // Previously this escaped to the outer catch, skipped the
                // timestamp update and forced a re-read on every call.
                continue;
            }
            enforce = parsed;
            break;
        }
        g_enforce = enforce;
        g_enforce_last_write_time = ft;
    } catch (...) {
    }
}

static void LoadBlacklistIfNeeded() {
    try {
        std::error_code ec;
        auto ft = std::filesystem::last_write_time(CONFIG_BLACKLIST_PATH, ec);
        if (ec) return;
        {
            std::shared_lock readLock(g_blacklist_lock);
            if (ft == g_blacklist_last_write_time) return;
        }
        std::unique_lock writeLock(g_blacklist_lock);
        if (ft == g_blacklist_last_write_time) return;
        std::vector<std::wstring> lines;
        if (!ReadUtf8Lines(CONFIG_BLACKLIST_PATH, lines)) {
            return;
        }

        for (auto& s : lines) ToLower(s);

        g_blacklist.swap(lines);
        g_blacklist_last_write_time = ft;
    } catch (...) {
    }
}

// LSA initialization hook for the password filter. It preloads the enforce
// flag, the blacklist and the regex patterns, so later PasswordFilter calls
// normally find the caches current. Always reports success; the loaders are
// best-effort and never throw.
extern "C" BOOLEAN __stdcall InitializeChangeNotify(void) {
    using Loader = void (*)();
    const Loader loaders[] = {
        &LoadEnforceIfNeeded,
        &LoadBlacklistIfNeeded,
        &LoadPatternsIfNeeded,
    };
    for (Loader load : loaders) {
        load();
    }
    return TRUE;
}

// LSA password-filter callback. Returns TRUE to accept the candidate password
// and FALSE to reject it.
//
// When the enforce flag is 0, every password is accepted. Otherwise the
// password is copied and lowercased, and the blacklist and pattern caches are
// refreshed. The password is then rejected if:
//   - any non-empty blacklist entry occurs in it as a substring, or
//   - any configured regex matches anywhere in it.
// These cases are accepted: the password cannot be copied (null, or longer
// than MAX_PWD_LEN); the password is empty; any exception occurs (fail-open).
// AccountName, FullName and SetOperation are not consulted. The local copy is
// wiped with SecureZeroMemory on every exit path.
//
// NOTE(review): patterns come from an admin-editable file and run inside LSA;
// a pathological regex may backtrack heavily or exhaust the stack, which
// catch(...) cannot recover from. Confirm the patterns are vetted.
extern "C" BOOLEAN __stdcall PasswordFilter(
    PUNICODE_STRING AccountName,
    PUNICODE_STRING FullName,
    PUNICODE_STRING Password,
    BOOLEAN SetOperation
) {
    std::wstring candidate;

    // Scrubs the copied password when the function returns, on every path,
    // including the exception handler below.
    struct PasswordWiper {
        std::wstring& text;
        ~PasswordWiper() {
            SecureZeroMemory(text.data(), text.size() * sizeof(wchar_t));
        }
    } wiper{candidate};

    try {
        LoadEnforceIfNeeded();
        bool enforcing = false;
        {
            std::shared_lock<std::shared_mutex> guard(g_enforce_lock);
            enforcing = (g_enforce != 0);
        }
        if (!enforcing) {
            return TRUE;
        }

        if (!UniStrToWstring(Password, candidate) || candidate.empty()) {
            return TRUE;
        }
        ToLower(candidate);

        LoadBlacklistIfNeeded();
        LoadPatternsIfNeeded();

        bool rejected = false;
        {
            std::shared_lock<std::shared_mutex> guard(g_blacklist_lock);
            rejected = std::any_of(g_blacklist.cbegin(), g_blacklist.cend(),
                [&candidate](const std::wstring& banned) {
                    return !banned.empty() && candidate.find(banned) != std::wstring::npos;
                });
        }
        if (rejected) {
            return FALSE;
        }

        {
            std::shared_lock<std::shared_mutex> guard(g_patterns_lock);
            rejected = std::any_of(g_patterns.cbegin(), g_patterns.cend(),
                [&candidate](const std::wregex& rule) {
                    return std::regex_search(candidate, rule);
                });
        }
        return rejected ? FALSE : TRUE;
    } catch (...) {
        return TRUE;
    }
}

// LSA notification that a password change has completed. This package keeps
// no per-user state, so it ignores every argument and always reports success.
extern "C" NTSTATUS __stdcall PasswordChangeNotify(
    PUNICODE_STRING UserName,
    ULONG RelativeId,
    PUNICODE_STRING NewPassword
) {
    UNREFERENCED_PARAMETER(UserName);
    UNREFERENCED_PARAMETER(RelativeId);
    UNREFERENCED_PARAMETER(NewPassword);
    return STATUS_SUCCESS;
}