commit 6e7e30eeaa4b47c5d602e8bb2745c2080cc81205 Author: ovizro Date: Tue Dec 24 19:51:32 2024 +0800 init repo diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..598aab4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,31 @@ +################### CMake config ################### + +build + +################ Executable program ################ + +bin/ +dist/ + +#################### Library files ################# + +lib/ + +################### VSCode config ################## + +.vscode +.VSCodeCounter + +####################### Tool chains ###################### + +toolchains/* + +#################### Test config ################### + +test.* +disabled.* +*.disabled + +###################### Misc ####################### + +!.gitkeep diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..9ac9339 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,48 @@ +cmake_minimum_required(VERSION 3.15) + +project(transport CXX) + +set(CMAKE_CXX_STANDARD 11) +if (NOT CMAKE_CROSSCOMPILING) + set(EXECUTABLE_OUTPUT_PATH ${CMAKE_SOURCE_DIR}/bin) + set(LIBRARY_OUTPUT_PATH ${CMAKE_SOURCE_DIR}/lib) +endif() + +if (UNIX) + set(CMAKE_CXX_FLAGS "-std=c++11 -Wall ${CMAKE_CXX_FLAGS}") + set(CMAKE_CXX_FLAGS_DEBUG "-g ${CMAKE_CXX_FLAGS}") + set(CMAKE_CXX_FLAGS_RELEASE "-g -O2 ${CMAKE_CXX_FLAGS}") +elseif (WIN32) + # windows platform + #add_definitions(-D_CRT_SECURE_NO_WARNINGS) + #set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /MDd") + #set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MD") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /MTd /EHsc") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MT /EHsc") +endif() + +if(MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /utf-8") +endif() + +if (CMAKE_CROSSCOMPILING) + message(STATUS "Cross compiling ...") + message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}") + message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") +endif() + +include_directories(./include) + +aux_source_directory(${CMAKE_SOURCE_DIR}/src SRC_LIST) + +link_libraries(pthread util) + +add_library(transport_static STATIC ${SRC_LIST}) +add_library(transport SHARED ${SRC_LIST}) +set_target_properties(transport_static PROPERTIES OUTPUT_NAME transport) + +if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + target_compile_options(transport PRIVATE -Wno-nonnull) +endif() + +add_subdirectory(tests) diff --git a/README.md b/README.md new file mode 100644 index 0000000..596e595 --- /dev/null +++ b/README.md @@ -0,0 +1,4 @@ +# Transport + +A simple C++ language data transfer module. + diff --git a/include/dataqueue.hpp b/include/dataqueue.hpp new file mode 100644 index 0000000..f31bc53 --- /dev/null +++ b/include/dataqueue.hpp @@ -0,0 +1,143 @@ +#ifndef _INCLUDED_QUEUE_ +#define _INCLUDED_QUEUE_ + +#include +#include +#include +#include +#include + +template +class DataQueue; + +class QueueException : public std::exception +{ +public: + QueueException(void* queue) : queue(queue) {} + + template + DataQueue* GetQueue() const noexcept + { + return static_cast*>(queue); + } + +private: + void* queue; +}; + +class QueueCleared : public QueueException +{ +public: + QueueCleared(void* queue) : QueueException(queue) {} + + const char* what() const noexcept override + { + return "queue cleared"; + } +}; + +class QueueTimeout : public QueueException +{ +public: + QueueTimeout(void* queue) : QueueException(queue) {} + + const char* what() const noexcept override + { + return "queue timeout"; + } +}; + +typedef uint8_t queue_epoch_t; + +template +class DataQueue { +public: + DataQueue() : m_CurrEpoch(0) {} + DataQueue(const DataQueue&) = delete; + + ~DataQueue() { Clear(); } + + void Push(T data) + { + std::lock_guard lock(m_Mutex); + m_Queue.push_back(std::move(data)); + m_Cond.notify_one(); + } + + T Pop() + { + std::unique_lock lock(m_Mutex); + auto epoch = m_CurrEpoch; + while (m_Queue.empty()) + { + m_Cond.wait(lock); + if (epoch != m_CurrEpoch) { + throw QueueCleared(this); + } + } + T data = std::move(m_Queue.front()); + m_Queue.pop_front(); + return data; + } + + template + T Pop(const std::chrono::duration timeout) + { + std::unique_lock lock(m_Mutex); + auto epoch = m_CurrEpoch; + while (m_Queue.empty()) + { + if (m_Cond.wait_for(lock, timeout) == std::cv_status::timeout) { + throw QueueTimeout(this); + } + if (epoch != m_CurrEpoch) { + throw QueueCleared(this); + } + } + T data = std::move(m_Queue.front()); + m_Queue.pop_front(); + return data; + } + + bool Empty() noexcept + { + std::lock_guard lock(m_Mutex); + return m_Queue.empty(); + } + + size_t Size() noexcept + { + std::lock_guard lock(m_Mutex); + return m_Queue.size(); + } + + queue_epoch_t GetEpoch() noexcept + { + std::lock_guard lock(m_Mutex); + return m_CurrEpoch; + } + + bool CheckEpoch(queue_epoch_t epoch) noexcept + { + std::lock_guard lock(m_Mutex); + return m_CurrEpoch == epoch; + } + + void Clear() noexcept + { + std::lock_guard lock(m_Mutex); + m_Queue.clear(); + m_CurrEpoch++; + m_Cond.notify_all(); + } + +protected: + std::deque m_Queue; + std::mutex m_Mutex; + std::condition_variable m_Cond; + +private: + queue_epoch_t m_CurrEpoch; +}; + +#endif \ No newline at end of file diff --git a/include/logging/interface.h b/include/logging/interface.h new file mode 100644 index 0000000..e0b3444 --- /dev/null +++ b/include/logging/interface.h @@ -0,0 +1,52 @@ +#ifndef _INCLUDE_LOGGING_INTERFACE_ +#define _INCLUDE_LOGGING_INTERFACE_ + +#include +#include +#include "level.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void log_init(const char* level); +int log_level(void); +void log_set_level(int level); +void log_log(int level, const char *fmt, ...); +void log_vlog(int level, const char *fmt, va_list ap); + +#define log_debug(...) log_log(LOG_LEVEL_DEBUG, __VA_ARGS__) +#define log_info(...) log_log(LOG_LEVEL_INFO, __VA_ARGS__) +#define log_warn(...) log_log(LOG_LEVEL_WARN, __VA_ARGS__) +#define log_error(...) log_log(LOG_LEVEL_ERROR, __VA_ARGS__) +#define log_fatal(...) log_log(LOG_LEVEL_FATAL, __VA_ARGS__) + +#define vlog_debug(...) log_vlog(LOG_LEVEL_DEBUG, __VA_ARGS__) +#define vlog_info(...) log_vlog(LOG_LEVEL_INFO, __VA_ARGS__) +#define vlog_warn(...) log_vlog(LOG_LEVEL_WARN, __VA_ARGS__) +#define vlog_error(...) log_vlog(LOG_LEVEL_ERROR, __VA_ARGS__) +#define vlog_fatal(...) log_vlog(LOG_LEVEL_FATAL, __VA_ARGS__) + +#define log_with_source(level, msg) log_log(level,\ + "Trackback (most recent call last):\n File \"%s\", line %d, in \"%s\"\n%s",\ + __FILE__, __LINE__, __ASSERT_FUNCTION, msg) +#define log_debug_with_source(msg) log_with_source(LOG_LEVEL_DEBUG, msg) +#define log_info_with_source(msg) log_with_source(LOG_LEVEL_INFO, msg) +#define log_warn_with_source(msg) log_with_source(LOG_LEVEL_WARN, msg) +#define log_error_with_source(msg) log_with_source(LOG_LEVEL_ERROR, msg) +#define log_fatal_with_source(msg) log_with_source(LOG_LEVEL_FATAL, msg) + +#define log_from_errno(level, msg) do {\ + if (errno) log_log(level,"%s: %s", msg, strerror(errno));\ +} while (0) +#define log_debug_from_errno(msg) log_from_errno(LOG_LEVEL_DEBUG, msg) +#define log_info_from_errno(msg) log_from_errno(LOG_LEVEL_INFO, msg) +#define log_warn_from_errno(msg) log_from_errno(LOG_LEVEL_WARN, msg) +#define log_error_from_errno(msg) log_from_errno(LOG_LEVEL_ERROR, msg) +#define log_fatal_from_errno(msg) log_from_errno(LOG_LEVEL_FATAL, msg) + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/include/logging/level.h b/include/logging/level.h new file mode 100644 index 0000000..e1ba713 --- /dev/null +++ b/include/logging/level.h @@ -0,0 +1,10 @@ +#ifndef _INCLUDE_LOGGING_LEVEL_ +#define _INCLUDE_LOGGING_LEVEL_ + +#define LOG_LEVEL_DEBUG 1 +#define LOG_LEVEL_INFO 2 +#define LOG_LEVEL_WARN 3 +#define LOG_LEVEL_ERROR 4 +#define LOG_LEVEL_FATAL 5 + +#endif \ No newline at end of file diff --git a/include/logging/logger.hpp b/include/logging/logger.hpp new file mode 100644 index 0000000..1af66d0 --- /dev/null +++ b/include/logging/logger.hpp @@ -0,0 +1,260 @@ +#ifndef _INCLUDE_LOGGER_ +#define _INCLUDE_LOGGER_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "level.h" + +namespace logging { + +struct Record; +class LoggerOStream; +class LoggerStreamBuf; + +class Logger +{ +public: + enum class Level: uint8_t + { + UNKNOWN = 0, + DEBUG = LOG_LEVEL_DEBUG, + INFO = LOG_LEVEL_INFO, + WARN = LOG_LEVEL_WARN, + ERROR = LOG_LEVEL_ERROR, + FATAL = LOG_LEVEL_FATAL + }; + + constexpr static const char* namesep = "::"; + + explicit Logger(Level level = Level::INFO) : _parent(nullptr), _level(level) {} + Logger(const Logger&) = delete; + virtual ~Logger() = default; + + inline Level level() const + { + if (_level == Logger::Level::UNKNOWN) { + if (_parent) { + return _parent->level(); + } + return Logger::Level::INFO; + } + return _level; + } + + inline const std::string& name() const + { + return _name; + } + + inline Logger* parent() const + { + return _parent; + } + + inline size_t children_count() const + { + return logger_cache.size(); + } + + inline void set_level(Level level) + { + _level = level; + } + + template + inline void log(const char* fmt, ...) { + va_list args; + va_start(args, fmt); + vlog(level, fmt, args); + va_end(args); + } + inline void log(Level level, const char* fmt, ...) + { + va_list args; + va_start(args, fmt); + vlog(level, fmt, args); + va_end(args); + } + inline void debug(const char* fmt, ...) + { + va_list args; + va_start(args, fmt); + vlog(Logger::Level::DEBUG, fmt, args); + va_end(args); + } + void info(const char* fmt, ...) + { + va_list args; + va_start(args, fmt); + vlog(Logger::Level::INFO, fmt, args); + va_end(args); + } + void warn(const char* fmt, ...) + { + va_list args; + va_start(args, fmt); + vlog(Logger::Level::WARN, fmt, args); + va_end(args); + } + void error(const char* fmt, ...) + { + va_list args; + va_start(args, fmt); + vlog(Logger::Level::ERROR, fmt, args); + va_end(args); + } + void fatal(const char* fmt, ...) + { + va_list args; + va_start(args, fmt); + vlog(Logger::Level::FATAL, fmt, args); + va_end(args); + } + template + inline void vlog(const char* fmt, va_list args) { + vlog(level, fmt, args); + } + template + inline void raise_from_errno(const char* msg) { + error("%s: %s", msg, strerror(errno)); + throw E(msg); + } + + void vlog(Level level, const char* fmt, va_list args); + virtual void log_message(Level level, const std::string& msg); + + LoggerOStream operator[](Level level); + LoggerOStream operator[](int level); + + void move_children_to(Logger& other); + + template + Logger* get_child(const std::string& name, Level level) + { + Logger* logger; + auto sep_pos = name.find(namesep); + do { + if (name.empty() || sep_pos == 0) { + logger = this; + break; + } + std::string base_name; + if (sep_pos == std::string::npos) { + base_name = name; + } else { + base_name = name.substr(0, sep_pos); + } + auto iter = logger_cache.find(base_name); + + if (iter != logger_cache.end()) { + logger = iter->second.get(); + break; + } + logger = new T_Logger(base_name, level, this); + logger_cache[base_name] = std::unique_ptr(logger); + } while (0); + auto seqlen = strlen(namesep); + if (sep_pos == std::string::npos || sep_pos + seqlen >= name.size()) + return logger; + return logger->get_child(name.substr(sep_pos + seqlen), level); + } + + inline void add_stream(std::ostream& stream) + { + _streams.push_back(std::unique_ptr(new std::ostream(stream.rdbuf()))); + } + + template + inline void add_stream(T&& stream) + { + _streams.push_back(std::unique_ptr(new typename std::remove_reference::type(std::move(stream)))); + } + + inline std::vector streams() const + { + std::vector streams; + for (auto& stream : _streams) + streams.push_back(stream.get()); + return streams; + } + +protected: + explicit Logger(const std::string& name, Level level = Level::INFO, Logger* parent = nullptr); + + virtual void log_record(const Record& record); + virtual void write_record(std::ostream& os, const Record& record); + + static constexpr std::ostream& default_stream = std::cerr; + + std::vector> _streams; + Logger* _parent; + +private: + std::unordered_map> logger_cache; + std::string _name; + Level _level; +}; + +class LoggerStreamBuf : public std::streambuf { +public: + LoggerStreamBuf(Logger& logger, Logger::Level level); + LoggerStreamBuf(const LoggerStreamBuf&) = delete; + LoggerStreamBuf(LoggerStreamBuf&& loggerStream) = default; + +protected: + int overflow(int c) override; + std::streamsize xsputn(const char* s, std::streamsize n) override; + int sync() override; + +private: + void flush_line(); // 将当前行写入日志 + + Logger& _logger; + Logger::Level _level; + std::string _lineBuffer; // 行缓存 +}; + +class LoggerOStream : public std::ostream { +public: + LoggerOStream(Logger& logger, Logger::Level level); + LoggerOStream(const LoggerOStream&) = delete; + LoggerOStream(LoggerOStream&& loggerStream); + +private: + LoggerStreamBuf _streamBuf; +}; + +struct Record +{ + std::string name; + std::chrono::system_clock::time_point time; + Logger::Level level; + std::string msg; + + Record(const std::string& name, Logger::Level level, const std::string& msg); +}; + +using LogLevel = Logger::Level; + +std::ostream& operator<<(std::ostream& os, const Logger::Level level); +Logger::Level str2level(const char* level); +std::unique_ptr& _get_global_logger(); +Logger& get_global_logger(); +void set_global_logger(std::unique_ptr&& logger); + +template +static inline Logger* get_logger(const std::string& name, LogLevel level = LogLevel::UNKNOWN) +{ + return _get_global_logger()->get_child(name, level); +} + +} + +#endif \ No newline at end of file diff --git a/include/transport/base.hpp b/include/transport/base.hpp new file mode 100644 index 0000000..e5ac534 --- /dev/null +++ b/include/transport/base.hpp @@ -0,0 +1,164 @@ +#ifndef _INCLUDE_TRANSPORT_BASE_ +#define _INCLUDE_TRANSPORT_BASE_ + +#include +#include +#include +#include +#include +#include "dataqueue.hpp" +#include "logging/logger.hpp" +#include "protocol.hpp" + +#define TRANSPORT_MAX_RETRY 5 +#define TRANSPORT_TIMEOUT 1000 + +namespace transport +{ + +template +class BaseTransport; + +class _transport_base { +public: + _transport_base() : is_open(false), is_closed(false) {} + _transport_base(const _transport_base&) = delete; + virtual ~_transport_base() { + close(); + } + + virtual void open() { + if (is_open) + return; + is_open = true; + std::thread([this] { + try { + send_backend(); + } catch (const QueueCleared&) {} + }).detach(); + std::thread([this] { + try { + receive_backend(); + } catch (const QueueCleared&) {} + }).detach(); + } + + virtual void close() { + is_closed = true; + } + + bool closed() const { + return is_closed; + } + +protected: + void ensure_open() { + auto& logger = *logging::get_logger("transport"); + if (is_closed) + { + logger.fatal("transport closed"); + throw std::runtime_error("transport closed"); + } + if (!is_open) + { + open(); + } + } + virtual void send_backend() = 0; + virtual void receive_backend() = 0; + + bool is_open; + bool is_closed; +}; + +class TransportToken +{ +public: + explicit TransportToken(_transport_base *transport) : transport_(transport) {} + + template + BaseTransport

*transport() const { + return dynamic_cast*>(transport_); + } + virtual bool operator==(const TransportToken &other) const { + return transport_ == other.transport_; + } + +protected: + _transport_base *transport_; + friend class _transport_base; + friend std::hash; +}; + +template +class BaseTransport : public _transport_base +{ +public: + typedef P Protocol; + typedef typename P::FrameType FrameType; + + ~BaseTransport() override + { + close(); + } + + template + inline void send(typename P::FrameType frame, std::shared_ptr token = nullptr) + { + ensure_open(); + send_que.Push(std::make_pair(frame, token)); + } + template + inline std::pair> receive(std::chrono::duration dur = std::chrono::milliseconds(0)) + { + ensure_open(); + DataPair frame_pair; + if (!dur.count()) + frame_pair = recv_que.Pop(); + else + frame_pair = recv_que.Pop(dur); + return frame_pair; + } + template + inline typename P::FrameType request(typename P::FrameType frame, int max_retry = TRANSPORT_MAX_RETRY, std::chrono::duration dur = std::chrono::milliseconds(TRANSPORT_TIMEOUT)) + { + ensure_open(); + while (max_retry--) + { + send_que.Push(std::make_pair(frame, nullptr)); + DataPair frame_pair = recv_que.Pop(dur); + if (frame_pair.first) { + return frame_pair.first; + } + auto& logger = *logging::get_logger("transport"); + logger.warn("request timeout, retrying..."); + } + return nullptr; + } + + void close() override { + is_closed = true; + recv_que.Clear(); + send_que.Clear(); + } + + typedef std::pair> DataPair; + +protected: + DataQueue send_que; + DataQueue recv_que; +}; + +} + +namespace std { + template <> + struct hash + { + size_t operator()(const transport::TransportToken &token) const + { + return std::hash()(reinterpret_cast(token.transport_)); + } + }; +} +#endif \ No newline at end of file diff --git a/include/transport/protocol.hpp b/include/transport/protocol.hpp new file mode 100644 index 0000000..ae2a8a5 --- /dev/null +++ b/include/transport/protocol.hpp @@ -0,0 +1,35 @@ +#ifndef _INCLUDE_TRANSPORT_PROTOCOL_ +#define _INCLUDE_TRANSPORT_PROTOCOL_ + +#include +#include + +namespace transport +{ + +class Protocol { +public: + typedef std::vector FrameType; + + static ssize_t pred_size(void* buf, size_t size) { + if (buf == nullptr) return 1; + return size; + } + + static FrameType make_frame(void* buf, size_t size) { + if (buf == nullptr) return FrameType(); + return FrameType((uint8_t*)buf, (uint8_t*)buf + size); + } + + static size_t frame_size(FrameType frame) { + return frame.size(); + } + + static void* frame_data(FrameType frame) { + return frame.data(); + } +}; + +} + +#endif \ No newline at end of file diff --git a/include/transport/serial_port.hpp b/include/transport/serial_port.hpp new file mode 100644 index 0000000..c4d2d39 --- /dev/null +++ b/include/transport/serial_port.hpp @@ -0,0 +1,306 @@ +#ifndef _INCLUDE_TRANSPORT_TTY_ +#define _INCLUDE_TRANSPORT_TTY_ + +#include +#include +#include +#include +#include +#include "base.hpp" + +#define TRANSPORT_SERIAL_PORT_BUFFER_SIZE 1024 * 1024 +// #define TRANSPORT_SERIAL_PORT_DEBUG +#ifdef COM_FRAME_DEBUG +#define TRANSPORT_SERIAL_PORT_DEBUG +#endif + +namespace transport +{ + +template +class SerialPortTransport : public BaseTransport

+{ +public: + SerialPortTransport(const std::string &path, int baudrate = 115200, size_t buffer_size = TRANSPORT_SERIAL_PORT_BUFFER_SIZE) + : path(path), tty_id(-1), baudrate(baudrate), buffer_size(buffer_size) {} + + SerialPortTransport(int tty_id, int baudrate = 115200, size_t buffer_size = 1024) + : tty_id(tty_id), baudrate(baudrate), buffer_size(buffer_size) {} + + ~SerialPortTransport() override + { + close(); + } + + void open() override + { + auto &logger = *logging::get_logger("transport"); + if (this->is_open) + { + return; + } + else if (this->is_closed) + { + logger.info("reopen serial port transport"); + this->is_open = false; + this->is_closed = false; + } + + if (tty_id < 0) + { + logger.info("open serial port %s", path.c_str()); + tty_id = ::open(path.c_str(), O_RDWR | O_NOCTTY | O_NONBLOCK); + if (tty_id < 0) + { + logger.raise_from_errno("open serial port failed"); + } + } + else + { + logger.info("use serial port %d", tty_id); + } + + struct termios options, old; + // deploy usart par + memset(&options, 0, sizeof(options)); + int ret = tcgetattr(tty_id, &old); + if (ret != 0) + { + logger.error("tcgetattr failed: %s", strerror(errno)); + goto end; + } + + tcflush(tty_id, TCIOFLUSH); + + switch (baudrate) + { + case 9600: + cfsetispeed(&options, B9600); + cfsetospeed(&options, B9600); + break; + case 19200: + cfsetispeed(&options, B19200); + cfsetospeed(&options, B19200); + break; + case 38400: + cfsetispeed(&options, B38400); + cfsetospeed(&options, B38400); + break; + case 57600: + cfsetispeed(&options, B57600); + cfsetospeed(&options, B57600); + break; + case 115200: + cfsetispeed(&options, B115200); + cfsetospeed(&options, B115200); + break; + case 576000: + cfsetispeed(&options, B576000); + cfsetospeed(&options, B576000); + break; + case 921600: + cfsetispeed(&options, B921600); + cfsetospeed(&options, B921600); + break; + case 2000000: + cfsetispeed(&options, B2000000); + cfsetospeed(&options, B2000000); + break; + case 3000000: + cfsetispeed(&options, B3000000); + cfsetospeed(&options, B3000000); + break; + default: + logger.error("bad baud rate %u", baudrate); + break; + } + switch (1) + { + case 0: + options.c_cflag &= ~PARENB; + options.c_cflag &= ~INPCK; + break; + case 1: + options.c_cflag |= (PARODD // 使用奇校验代替偶校验 + | PARENB); // 校验位有效 + options.c_iflag |= INPCK; // 校验有效 + break; + case 2: + options.c_cflag |= PARENB; + options.c_cflag &= ~PARODD; + options.c_iflag |= INPCK; + break; + case 3: + options.c_cflag &= ~PARENB; + options.c_cflag &= ~CSTOPB; + break; + default: + options.c_cflag &= ~PARENB; + break; + } + + options.c_cflag |= (CLOCAL | CREAD); + options.c_cflag &= ~CSIZE; + options.c_cflag &= ~CRTSCTS; + options.c_cflag |= CS8; + options.c_cflag &= ~CSTOPB; + options.c_oflag = 0; + options.c_lflag = 0; + options.c_cc[VTIME] = 0; + options.c_cc[VMIN] = 0; + // 启用输出的XON/XOFF控制字符 + // Enable software flow control (XON/XOFF) for both input and output + options.c_iflag |= (IXON | IXOFF); // Enable input and output XON/XOFF control characters + options.c_oflag |= (IXON | IXOFF); // Enable input and output XON/XOFF control characters + tcflush(tty_id, TCIFLUSH); + + if ((tcsetattr(tty_id, TCSANOW, &options)) != 0) + { + logger.error("tcsetattr failed: %s", strerror(errno)); + } + + end: + super::open(); + } + + void close() override + { + if (this->is_open && !this->closed()) + { + auto &logger = *logging::get_logger("transport"); + logger.info("close serial port %s", path.c_str()); + ::close(tty_id); + } + super::close(); + } + +protected: + void send_backend() override + { + // this->ensure_open(); + auto &logger = *logging::get_logger("transport"); + logger.debug("start serial port send backend"); + while (!this->is_closed) + { + auto frame_pair = this->send_que.Pop(); + auto frame = frame_pair.first; + if (frame_pair.second && frame_pair.second->template transport

() != this) + { + logger.error("invalid token received"); + continue; + } + size_t remaining_size = P::frame_size(frame); + if (remaining_size == 0) { + continue; + } + logger.debug("send data %zu", remaining_size); + size_t offset = 0; + while (remaining_size > 0) + { + auto written_size = write(tty_id, static_cast(P::frame_data(frame)) + offset, remaining_size); + if (written_size < 0) + { + logger.error("write serial port failed: %s", strerror(errno)); + } + else + { + remaining_size -= written_size; + offset += written_size; + } + } + } + } + + void receive_backend() override + { + // this->ensure_open(); + auto &logger = *logging::get_logger("transport"); + logger.debug("start serial port receive backend"); + bool find_head = false; + size_t min_size = P::pred_size(nullptr, 0); + if (!min_size) min_size = 1; + ssize_t recv_size; + size_t pred_size = min_size * 2; // min buffer size to keep, + size_t offset = 0; // scanned data size + size_t cached_size = 0; // read data size + assert(buffer_size >= pred_size); + + uint8_t *buffer = new uint8_t[buffer_size]; + + while (!this->is_closed) + { + recv_size = read(tty_id, ((uint8_t *)buffer) + cached_size, buffer_size - cached_size); + if (recv_size <= 0) + { + if (cached_size == offset) { + usleep(10); + continue; + } + recv_size = 0; + } +#ifdef TRANSPORT_SERIAL_PORT_DEBUG + printf("receive com data (received=%zd,cached=%zu)\nbuffer: ", + recv_size, cached_size); + for (size_t i = 0; i < cached_size + recv_size; ++i) + { + printf("%02x", ((uint8_t *)buffer)[i]); + } + putchar('\n'); +#endif + cached_size += recv_size; + + if (!find_head) + { + // update offset to scan the header + for (; offset + min_size < cached_size; ++offset) + { + ssize_t pred = P::pred_size(((uint8_t *)buffer) + offset, cached_size - offset); + find_head = pred > 0; + if (find_head) + { + pred_size = pred; + logger.debug("find valid data (length=%zu)", pred_size); + if (pred_size > buffer_size) + { + logger.error("data size is too large (%zu)\n", pred_size); + find_head = false; + pred_size = min_size * 2; + continue; + } + break; + } + } + } + + if (find_head && cached_size >= pred_size + offset) + { + // all data received + auto frame = P::make_frame((uint8_t *)buffer + offset, pred_size); + logger.debug("receive data %zu", pred_size); + this->recv_que.Push(std::make_pair(std::move(frame), std::make_shared(this))); + offset += pred_size; // update offset for next run + } + + // clear the cache when the remaining length of the cache is + if (offset && buffer_size - cached_size < pred_size) + { + cached_size -= offset; + memmove(buffer, ((uint8_t *)buffer) + offset, cached_size); + offset = 0; + } + } + delete[] buffer; + } + +private: + typedef BaseTransport

super; + + std::string path; + int tty_id; + int baudrate; + size_t buffer_size; +}; + +} + +#endif \ No newline at end of file diff --git a/include/transport/udp.hpp b/include/transport/udp.hpp new file mode 100644 index 0000000..3ebf763 --- /dev/null +++ b/include/transport/udp.hpp @@ -0,0 +1,247 @@ +#ifndef _INCLUDE_TRANSPORT_UDP_ +#define _INCLUDE_TRANSPORT_UDP_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "base.hpp" + +#define TRANSPORT_UDP_BUFFER_SIZE 1024 * 64 + +namespace transport { + +class DatagramTransportToken : public TransportToken { +public: + explicit DatagramTransportToken(_transport_base* transport, const struct sockaddr_in& addr, socklen_t addr_len) + : TransportToken(transport), addr(addr), addr_len(addr_len) {} + + bool operator==(const TransportToken& other) const override + { + auto other_token = dynamic_cast(&other); + if (!other_token) + { + return false; + } + if (transport_ != other_token->transport_ || addr_len != other_token->addr_len) + { + return false; + } + return memcmp(&addr, &other_token->addr, addr_len) == 0; + } + +protected: + struct sockaddr_in addr; + socklen_t addr_len; + + friend std::hash; + template + friend class DatagramTransport; +}; + +template +class DatagramTransport : public BaseTransport

{ +public: + explicit DatagramTransport(size_t buffer_size = TRANSPORT_UDP_BUFFER_SIZE) + : sockfd(-1), buffer_size(buffer_size) + { + memset(&bind_addr, 0, sizeof(bind_addr)); + memset(&connect_addr, 0, sizeof(connect_addr)); + bind_addr.sin_family = AF_INET; + connect_addr.sin_family = AF_INET; + } + + DatagramTransport(std::pair local_addr, std::pair remote_addr, size_t buffer_size = TRANSPORT_UDP_BUFFER_SIZE) + : DatagramTransport(buffer_size) + { + bind_addr.sin_port = htons(local_addr.second); + resolve_hostname(local_addr.first, bind_addr); + + connect_addr.sin_port = htons(remote_addr.second); + resolve_hostname(remote_addr.first, connect_addr); + } + + ~DatagramTransport() + { + close(); + } + + + void open() override + { + auto &logger = *logging::get_logger("transport"); + if (this->is_open) + { + return; + } else if (this->is_closed) + { + logger.info("reopen datagram transport"); + this->is_open = false; + this->is_closed = false; + } + + sockfd = socket(AF_INET, SOCK_DGRAM, 0); + if (sockfd < 0) { + logger.raise_from_errno("failed to create socket"); + } else { + logger.info("open socket fd %d", sockfd); + } + + super::open(); + + if (bind_addr.sin_port) + { + if (::bind(sockfd, (struct sockaddr *)&bind_addr, sizeof(bind_addr)) < 0) + { + logger.raise_from_errno("failed to bind socket"); + } + logger.info("listening on %s:%d", inet_ntoa(bind_addr.sin_addr), ntohs(bind_addr.sin_port)); + } + } + + void close() override + { + if (this->is_open && !this->closed()) + { + auto &logger = *logging::get_logger("transport"); + logger.info("close socket fd %d", sockfd); + ::close(sockfd); + } + super::close(); + } + + void bind(const std::string& address, int port) + { + this->ensure_open(); + auto &logger = *logging::get_logger("transport"); + + resolve_hostname(address, bind_addr); + bind_addr.sin_port = htons(port); + + if (::bind(sockfd, (struct sockaddr *)&bind_addr, sizeof(bind_addr)) < 0) + { + logger.raise_from_errno("failed to bind socket"); + } + logger[logging::LogLevel::INFO] << "listening on " << address << ":" << port << std::endl; + } + + void connect(const std::string& address, int port) + { + auto &logger = *logging::get_logger("transport"); + + resolve_hostname(address, connect_addr); + connect_addr.sin_port = htons(port); + logger[logging::LogLevel::INFO] << "connecting to " << address << ":" << port << std::endl; + } + + constexpr static std::pair nulladdr = {"", 0}; + +protected: + void send_backend() override + { + // this->ensure_open(); + auto &logger = *logging::get_logger("transport"); + logger.debug("start datagram send backend"); + while (!this->is_closed) + { + auto frame_pair = this->send_que.Pop(); + auto frame = frame_pair.first; + if (!P::frame_size(frame)) + continue; + auto token = dynamic_cast((frame_pair.second.get())); + struct sockaddr* addr = (struct sockaddr *)((token) ? &token->addr : &connect_addr); + socklen_t addr_len = (token) ? token->addr_len : sizeof(connect_addr); + + ssize_t sent_size = sendto(sockfd, P::frame_data(frame), P::frame_size(frame), 0, + addr, addr_len); + logger.debug("send data %zd", sent_size); + if (sent_size < 0) + { + logger.error("udp send failed: %s", strerror(errno)); + } + else if ((size_t)sent_size < P::frame_size(frame)) + { + logger.warn("sendto failed, only %zd bytes sent", sent_size); + } + } + } + + void receive_backend() override + { + // this->ensure_open(); + auto &logger = *logging::get_logger("transport"); + logger.debug("start datagram receive backend"); + uint8_t *buffer = new uint8_t[buffer_size]; + while (!this->is_closed) + { + struct sockaddr_in addr; + socklen_t addr_len = sizeof(addr); + ssize_t recv_size = recvfrom(sockfd, buffer, buffer_size, 0, + (struct sockaddr *)&addr, &addr_len); + if (recv_size < 0) + { + logger.error("udp recv failed: %s", strerror(errno)); + continue; + } + logger.debug("receive data %zd", recv_size); + ssize_t pred_size = P::pred_size(buffer, recv_size); + if (pred_size < 0) + { + logger.error("invalid frame received"); + continue; + } + auto frame = P::make_frame(buffer, recv_size); + this->recv_que.Push(std::make_pair(frame, std::make_shared(this, addr, addr_len))); + } + + delete[] buffer; + } + +private: + static void resolve_hostname(const std::string& hostname, struct sockaddr_in& result) + { + if (hostname.empty()) + { + result.sin_addr.s_addr = INADDR_ANY; + return; + } + std::string hostname_str(hostname); + struct hostent *he = gethostbyname(hostname_str.c_str()); + if (he == nullptr) + { + auto &logger = *logging::get_logger("transport"); + logger.error("failed to resolve hostname %s: %s", hostname_str.c_str(), hstrerror(h_errno)); + throw std::runtime_error("Failed to resolve hostname"); + } + memcpy(&result.sin_addr, he->h_addr_list[0], he->h_length); + } + + typedef BaseTransport

super; + + int sockfd; + struct sockaddr_in bind_addr; + struct sockaddr_in connect_addr; + size_t buffer_size; +}; + +} + +namespace std { + template<> + struct hash { + size_t operator()(const transport::DatagramTransportToken &token) const + { + std::size_t hash1 = std::hash()(token); + std::size_t hash2 = std::hash()(token.addr.sin_addr.s_addr); + std::size_t hash3 = std::hash()(token.addr.sin_port); + hash1 ^= (hash2 + 0x9e3779b9 + (hash1 << 6) + (hash1 >> 2)); + hash1 ^= (hash3 + 0x9e3779b9 + (hash1 << 6) + (hash1 >> 2)); + return hash1; + } + }; +} +#endif \ No newline at end of file diff --git a/include/transport/unix_udp.hpp b/include/transport/unix_udp.hpp new file mode 100644 index 0000000..acff38b --- /dev/null +++ b/include/transport/unix_udp.hpp @@ -0,0 +1,230 @@ +#ifndef _INCLUDE_TRANSPORT_UNIX_UDP_ +#define _INCLUDE_TRANSPORT_UNIX_UDP_ + +#include +#include +#include +#include +#include +#include +#include +#include "base.hpp" + +#define TRANSPORT_UDP_BUFFER_SIZE 1024 + +namespace transport { + +class UnixDatagramTransportToken : public TransportToken { +public: + explicit UnixDatagramTransportToken(_transport_base* transport, const struct sockaddr_un& addr, socklen_t addr_len) + : TransportToken(transport), addr(addr), addr_len(addr_len) {} + + bool operator==(const TransportToken& other) const override + { + auto other_token = dynamic_cast(&other); + if (!other_token) + { + return false; + } + if (transport_ != other_token->transport_ || addr_len != other_token->addr_len) + { + return false; + } + return memcmp(&addr, &other_token->addr, addr_len) == 0; + } + +protected: + struct sockaddr_un addr; + socklen_t addr_len; + + friend std::hash; + template + friend class UnixDatagramTransport; +}; + +template +class UnixDatagramTransport : public BaseTransport

{ +public: + explicit UnixDatagramTransport(size_t buffer_size = TRANSPORT_UDP_BUFFER_SIZE) + : sockfd(-1), buffer_size(buffer_size) + { + memset(&bind_addr, 0, sizeof(bind_addr)); + memset(&connect_addr, 0, sizeof(connect_addr)); + bind_addr.sun_family = AF_UNIX; + connect_addr.sun_family = AF_UNIX; + } + UnixDatagramTransport(const std::string& local_addr, const std::string& remote_addr, size_t buffer_size = TRANSPORT_UDP_BUFFER_SIZE) + : UnixDatagramTransport(buffer_size) + { + set_sock_path(local_addr, bind_addr); + set_sock_path(remote_addr, connect_addr); + } + + ~UnixDatagramTransport() + { + close(); + } + + void open() override + { + auto& logger = *logging::get_logger("transport"); + if (this->is_open) + { + return; + } + else if (this->is_closed) + { + logger.info("reopen datagram transport"); + this->is_open = false; + this->is_closed = false; + } + + sockfd = socket(AF_UNIX, SOCK_DGRAM, 0); + if (sockfd < 0) { + logger.raise_from_errno("failed to create socket"); + } else { + logger.info("open socket fd %d", sockfd); + } + + super::open(); + + if (*bind_addr.sun_path) + { + _bind(); + } + } + + void close() override + { + if (this->is_open && !this->closed()) + { + auto &logger = *logging::get_logger("transport"); + logger.info("close socket fd %d", sockfd); + ::close(sockfd); + if (*bind_addr.sun_path) + unlink(bind_addr.sun_path); + } + super::close(); + } + + void bind(const std::string& address) + { + this->ensure_open(); + set_sock_path(address, bind_addr); + _bind(); + } + + void connect(const std::string& address) + { + set_sock_path(address, connect_addr); + auto& logger = *logging::get_logger("transport"); + logger[logging::LogLevel::INFO] << "connecting to " << address << std::endl; + } + +protected: + void send_backend() override + { + // this->ensure_open(); + auto& logger = *logging::get_logger("transport"); + logger.debug("start datagram send backend"); + while (!this->is_closed) + { + auto frame_pair = this->send_que.Pop(); + auto frame = frame_pair.first; + if (!P::frame_size(frame)) + continue; + auto token = dynamic_cast((frame_pair.second.get())); + struct sockaddr* addr = (struct sockaddr *)((token) ? &token->addr : &connect_addr); + socklen_t addr_len = (token) ? token->addr_len : sizeof(connect_addr); + + ssize_t sent_size = sendto(sockfd, P::frame_data(frame), P::frame_size(frame), 0, + addr, addr_len); + logger.debug("send data %zd", sent_size); + if (sent_size < 0) + { + logger.error("unix udp send failed: %s", strerror(errno)); + } + else if ((size_t)sent_size < P::frame_size(frame)) + { + logger.warn("sendto failed, only %zd bytes sent", sent_size); + } + } + } + + void receive_backend() override + { + // this->ensure_open(); + auto& logger = *logging::get_logger("transport"); + logger.debug("start datagram receive backend"); + uint8_t *buffer = new uint8_t[buffer_size]; + while (!this->is_closed) + { + struct sockaddr_un addr; + socklen_t addr_len = sizeof(addr); + ssize_t recv_size = recvfrom(sockfd, buffer, buffer_size, 0, + (struct sockaddr *)&addr, &addr_len); + if (recv_size < 0) + { + logger.error("unix udp recv failed: %s", strerror(errno)); + continue; + } + logger.debug("receive data %zd", recv_size); + ssize_t pred_size = P::pred_size(buffer, recv_size); + if (pred_size < 0) + { + logger.error("invalid frame received"); + continue; + } + auto frame = P::make_frame(buffer, recv_size); + this->recv_que.Push(std::make_pair(frame, std::make_shared(this, addr, addr_len))); + } + delete[] buffer; + } + +private: + static void set_sock_path(const std::string& path, struct sockaddr_un& result) + { + if (path.size() + 1 >= sizeof(result.sun_path)) + { + auto& logger = *logging::get_logger("transport"); + logger.fatal("socket path too long"); + throw std::runtime_error("socket path too long"); + } + strncpy(result.sun_path, path.data(), path.size()); + result.sun_path[path.size()] = '\0'; + } + + void _bind() + { + auto& logger = *logging::get_logger("transport"); + unlink(bind_addr.sun_path); + if (::bind(sockfd, (struct sockaddr *)&bind_addr, sizeof(bind_addr)) < 0) + { + logger.raise_from_errno("failed to bind socket"); + } + logger.info("listening on %s", bind_addr.sun_path); + } + + typedef BaseTransport

super; + + int sockfd; + struct sockaddr_un bind_addr; + struct sockaddr_un connect_addr; + size_t buffer_size; +}; + +} + +namespace std { + template<> + struct hash { + size_t operator()(const transport::UnixDatagramTransportToken &token) const + { + std::size_t hash1 = std::hash()(token); + std::size_t hash2 = std::hash()(token.addr.sun_path); + hash1 ^= (hash2 + 0x9e3779b9 + (hash1 << 6) + (hash1 >> 2)); + return hash1; + } + }; +} +#endif \ No newline at end of file diff --git a/scripts/unittest.py b/scripts/unittest.py new file mode 100755 index 0000000..853baf1 --- /dev/null +++ b/scripts/unittest.py @@ -0,0 +1,415 @@ +#!/usr/bin/python3 + +import os +import sys +from io import StringIO +from argparse import ArgumentParser +from enum import IntEnum +from glob import glob +from subprocess import run +from time import time +from traceback import print_exception +from typing import List, NoReturn, Optional + + +__version__ = "0.1.2" + + +if sys.stdout.isatty(): + TTY_COLOR_RED = "\033[31m" + TTY_COLOR_RED_BOLD = "\033[1;31m" + TTY_COLOR_GREEN = "\033[32m" + TTY_COLOR_GREEN_BOLD = "\033[1;32m" + TTY_COLOR_YELLOW = "\033[33m" + TTY_COLOR_YELLOW_BOLD = "\033[1;33m" + TTY_COLOR_CYAN_BOLD = "\033[1;36m" + TTY_COLOR_CLEAR = "\033[0m" + + TTY_COLUMNS_SIZE = os.get_terminal_size().columns +else: + TTY_COLOR_RED = "" + TTY_COLOR_RED_BOLD = "" + TTY_COLOR_GREEN = "" + TTY_COLOR_GREEN_BOLD = "" + TTY_COLOR_YELLOW = "" + TTY_COLOR_YELLOW_BOLD = "" + TTY_COLOR_CYAN_BOLD = "" + TTY_COLOR_CLEAR = "" + + TTY_COLUMNS_SIZE = 80 + + +def print_separator_ex(lc: str, title: str, color: str) -> None: + len_str = len(title) + + print(f"{TTY_COLOR_CLEAR}{color}", end='') # 设置颜色样式 + if len_str + 2 > TTY_COLUMNS_SIZE: + print(title) + else: + print(f" {title} ".center(TTY_COLUMNS_SIZE, lc)) + print(TTY_COLOR_CLEAR, end='') # 重置颜色为默认 + + +class CTestCaseStatus(IntEnum): + NOT_RUN = -1 + PASSED = 0 + ERROR = 1 + FAILED = 16 + SKIPPED = 32 + SETUP_FAILED = 17 + TEARDOWN_FAILED = 18 + + +class CTestCaseCounter: + __slots__ = ["total_count", "passed", "error", "failed", "skipped"] + + def __init__(self) -> None: + self.total_count = 0 + self.passed: 'set[CTestCase]' = set() + self.error: 'set[CTestCase]' = set() + self.failed: 'set[CTestCase]' = set() + self.skipped: 'set[CTestCase]' = set() + + def update(self, test_case: "CTestCase") -> None: + self.total_count += 1 + if test_case.status == CTestCaseStatus.PASSED: + self.passed.add(test_case) + elif test_case.status == CTestCaseStatus.ERROR: + self.error.add(test_case) + elif test_case.status == CTestCaseStatus.SKIPPED: + self.skipped.add(test_case) + elif test_case.status in [CTestCaseStatus.FAILED, CTestCaseStatus.SETUP_FAILED, CTestCaseStatus.TEARDOWN_FAILED]: + self.failed.add(test_case) + else: + raise ValueError(f"{test_case.status} is not a valid status for counter") + + def clone(self) -> "CTestCaseCounter": + counter = CTestCaseCounter() + counter.total_count = self.total_count + counter.passed = self.passed + counter.error = self.error + counter.failed = self.failed + counter.skipped = self.skipped + return counter + + def __add__(self, other: "CTestCaseCounter") -> "CTestCaseCounter": + counter = self.clone() + counter += other + return counter + + def __iadd__(self, other: "CTestCaseCounter") -> "CTestCaseCounter": + self.total_count += other.total_count + self.passed.update(other.passed) + self.error.update(other.error) + self.failed.update(other.failed) + self.skipped.update(other.skipped) + return self + + @property + def status(self) -> CTestCaseStatus: + if self.error: + return CTestCaseStatus.ERROR + elif self.failed: + return CTestCaseStatus.FAILED + elif self.skipped: + return CTestCaseStatus.SKIPPED + elif self.passed: + return CTestCaseStatus.PASSED + else: + return CTestCaseStatus.NOT_RUN + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + ss = StringIO() + ss.write(f"total: {self.total_count}") + if self.passed: + ss.write(f", passed: {len(self.passed)}") + elif self.skipped: + ss.write(f", skipped: {len(self.skipped)}") + elif self.failed: + ss.write(f", failed: {len(self.failed)}") + elif self.error: + ss.write(f", error: {len(self.error)}") + return ss.getvalue() + + +class CTestCase: + __slots__ = ["id", "name", "status", "file", "result", "error_info"] + + def __init__(self, id: int, name: str, file: "CTestCaseFile") -> None: + self.id = id + self.name = name + self.file = file + self.status = CTestCaseStatus.NOT_RUN + self.result = None + self.error_info = None + + def run(self, *, verbose: bool = False, capture: bool = True) -> CTestCaseStatus: + try: + sys.stdout.flush() + sys.stderr.flush() + self.result = run([self.file.path, "--unittest", str(self.id)], capture_output=capture) + except Exception: + self.status = CTestCaseStatus.ERROR + self.error_info = sys.exc_info() + if not capture: + print(TTY_COLOR_RED) + print_exception(*self.error_info) + print(TTY_COLOR_CLEAR) + except KeyboardInterrupt: + self.status = CTestCaseStatus.ERROR + self.error_info = sys.exc_info() + else: + code = self.result.returncode + if code in CTestCaseStatus.__members__.values(): + self.status = CTestCaseStatus(code) + else: + self.status = CTestCaseStatus.ERROR + if verbose: + self.print_status_verbose() + else: + self.print_status() + return self.status + + def print_status(self) -> None: + if self.status == CTestCaseStatus.PASSED: + print(f"{TTY_COLOR_GREEN}.{TTY_COLOR_CLEAR}", end='') + elif self.status in [CTestCaseStatus.FAILED, CTestCaseStatus.SETUP_FAILED, CTestCaseStatus.TEARDOWN_FAILED]: + print(f"{TTY_COLOR_RED}F{TTY_COLOR_CLEAR}", end='') + elif self.status == CTestCaseStatus.ERROR: + print(f"{TTY_COLOR_RED}E{TTY_COLOR_CLEAR}", end='') + elif self.status == CTestCaseStatus.SKIPPED: + print(f"{TTY_COLOR_YELLOW}s{TTY_COLOR_CLEAR}", end='') + else: + raise ValueError(f"invalid test case status: {self.status}") + + def print_status_verbose(self) -> None: + if self.status == CTestCaseStatus.PASSED: + print(f"{self.name} {TTY_COLOR_GREEN}PASSED{TTY_COLOR_CLEAR}") + elif self.status == CTestCaseStatus.FAILED: + print(f"{self.name} {TTY_COLOR_RED}FAILED{TTY_COLOR_CLEAR}") + elif self.status == CTestCaseStatus.ERROR: + print(f"{self.name} {TTY_COLOR_RED}ERROR{TTY_COLOR_CLEAR}") + elif self.status == CTestCaseStatus.SETUP_FAILED: + print(f"{self.name} {TTY_COLOR_RED}SETUP FAILED{TTY_COLOR_CLEAR}") + elif self.status == CTestCaseStatus.TEARDOWN_FAILED: + print(f"{self.name} {TTY_COLOR_RED}TEARDOWN FAILED{TTY_COLOR_CLEAR}") + elif self.status == CTestCaseStatus.SKIPPED: + print(f"{self.name} {TTY_COLOR_YELLOW}SKIPPED{TTY_COLOR_CLEAR}") + else: + raise ValueError(f"invalid test case status: {self.status}") + + def report(self) -> Optional[str]: + if self.status == CTestCaseStatus.PASSED: + return + elif self.status == CTestCaseStatus.SKIPPED: + return f"{TTY_COLOR_YELLOW}SKIPPED{TTY_COLOR_CLEAR} {self.name}" + elif self.status == CTestCaseStatus.ERROR: + if self.error_info: + print_separator_ex('_', self.name, TTY_COLOR_RED) + print(TTY_COLOR_RED, end='') + print_exception(*self.error_info) + print(TTY_COLOR_CLEAR, end='') + assert self.error_info[0] + if str(self.error_info[1]): + error = f"{self.error_info[0].__name__}: {self.error_info[1]}" + else: + error = f"{self.error_info[0].__name__}" + return f"{TTY_COLOR_RED}ERROR{TTY_COLOR_CLEAR} {self.name} - {error}" + else: + assert self.result + if self.result.stderr: + print_separator_ex('_', self.name, TTY_COLOR_RED) + print(TTY_COLOR_RED, end='') + print(self.result.stderr.decode("utf-8"), end='') + print(TTY_COLOR_CLEAR, end='') + if self.result.stdout: + print_separator_ex('-', "Captured stdout", '') + print(self.result.stdout.decode("utf-8"), end='') + return f"{TTY_COLOR_RED}ERROR{TTY_COLOR_CLEAR} {self.name} - RuntimeError ({self.result.returncode})" + elif self.status in [CTestCaseStatus.FAILED, CTestCaseStatus.SETUP_FAILED, CTestCaseStatus.TEARDOWN_FAILED]: + assert self.result + if self.result.stderr: + print_separator_ex('_', self.name, TTY_COLOR_RED) + print(TTY_COLOR_RED, end='') + print(self.result.stderr.decode("utf-8"), end='') + print(TTY_COLOR_CLEAR, end='') + if self.result.stdout: + if self.result.stderr: + print_separator_ex('-', "Captured stdout", '') + else: + print_separator_ex('_', self.name, '') + print(self.result.stdout.decode("utf-8", "replace"), end='') + if self.status == CTestCaseStatus.FAILED: + return f"{TTY_COLOR_RED}FAILED{TTY_COLOR_CLEAR} {self.name}" + elif self.status == CTestCaseStatus.SETUP_FAILED: + return f"{TTY_COLOR_RED}FAILED{TTY_COLOR_CLEAR} {self.name} - SetupError" + else: + return f"{TTY_COLOR_RED}FAILED{TTY_COLOR_CLEAR} {self.name} - TeardownError" + else: + raise ValueError(f"invalid test case status: {self.status}") + + +class CTestCaseFile: + __slots__ = ["path", "test_cases", "collect_result", "collect_error_info", "counter"] + + def __init__(self, path: str) -> None: + self.path = path + self.test_cases: List[CTestCase] = [] + self.collect_result = None + self.collect_error_info = None + + def collect(self) -> int: + try: + result = run([self.path, "--collect"], capture_output=True) + except Exception: + self.collect_error_info = sys.exc_info() + return 0 + + if result.returncode != 0: + self.collect_result = result + return 0 + for id, name in enumerate(result.stdout.decode("ascii").split()): + self.test_cases.append(CTestCase(id, name, self)) + return len(self.test_cases) + + def run(self, verbose: bool = False, capture: bool = True) -> CTestCaseCounter: + counter = CTestCaseCounter() + if not verbose: + print(self.path, end=' ') + for i in self.test_cases: + if verbose: + print(self.path, end='::') + i.run(verbose=verbose, capture=capture) + counter.update(i) + if not verbose: + print() + return counter + + def report(self) -> List[str]: + if self.collect_result is not None: + print_separator_ex('_', f"ERROR collecting {self.path}", TTY_COLOR_RED_BOLD) + print(TTY_COLOR_RED, end='') + print(self.collect_result.stderr.decode()) + print(TTY_COLOR_CLEAR, end='') + if self.collect_result.stdout: + print_separator_ex('-', "Captured stdout", '') + print(self.collect_result.stdout.decode()) + return [f"ERROR {self.path} - CollectError ({self.collect_result.returncode})"] + elif self.collect_error_info is not None: + assert self.collect_error_info[0] + print_separator_ex('_', f"ERROR collecting {self.path}", TTY_COLOR_RED_BOLD) + print(TTY_COLOR_RED, end='') + print_exception(*self.collect_error_info) + print(TTY_COLOR_CLEAR, end='') + return [f"ERROR{TTY_COLOR_CLEAR} {self.path} - {self.collect_error_info[0].__name__}: {self.collect_error_info[1]}"] + return list(filter(None, (i.report() for i in self.test_cases))) + + @property + def error(self) -> bool: + return self.collect_result is not None or self.collect_error_info is not None + + +def report_collect_error(start_time: float, *error_files: CTestCaseFile) -> NoReturn: + print_separator_ex('=', "ERRORS", '') + summary = [] + for i in error_files: + summary.extend(i.report()) + + print_separator_ex('=', "short test summary info", TTY_COLOR_CYAN_BOLD) + for i in summary: + print(i) + print_separator_ex('!', f"Interrupted: {len(error_files)} error during collection", '') + cur_time = time() + print_separator_ex('=', f"{len(summary)} error in {cur_time - start_time:.2f}s", TTY_COLOR_RED_BOLD) + sys.exit(1) + + +def report_no_ran(start_time: float) -> NoReturn: + cur_time = time() + print_separator_ex('=', f"no tests ran in {cur_time - start_time:.2f}s", TTY_COLOR_YELLOW) + sys.exit() + +def report(start_time: float, counter: CTestCaseCounter, *, show_capture: bool = True) -> NoReturn: + cur_time = time() + summary = [] + if counter.error: + if show_capture: + print_separator_ex('=', "ERRORS", '') + for i in counter.error: + summary.append(i.report()) + if counter.failed: + if show_capture: + print_separator_ex('=', "FAILURES", TTY_COLOR_RED) + for i in counter.failed: + summary.append(i.report()) + if counter.skipped: + for i in counter.skipped: + summary.append(i.report()) + + if summary: + print_separator_ex('=', "short test summary info", TTY_COLOR_CYAN_BOLD) + for i in summary: + print(i) + + if counter.status in [CTestCaseStatus.FAILED, CTestCaseStatus.ERROR]: + color = TTY_COLOR_RED_BOLD + elif counter.status == CTestCaseStatus.SKIPPED: + color = TTY_COLOR_YELLOW_BOLD + else: + color = TTY_COLOR_GREEN_BOLD + + print_separator_ex('=', f"{counter} in {cur_time - start_time:.2f}s", color) + if counter.status in [CTestCaseStatus.FAILED, CTestCaseStatus.ERROR]: + sys.exit(1) + sys.exit() + + +if __name__ == "__main__": + parser = ArgumentParser("unittest", description="Run unit tests") + parser.add_argument("path", nargs='+', help="path to the test directory or file", default="./test_*") + parser.add_argument("-V", "--version", action="version", version=__version__) + parser.add_argument("-v", "--verbose", action="store_true", help="verbose output") + parser.add_argument("-s", "--no-capture", action="store_false", help="capture stdout and stderr") + + namespace = parser.parse_args() + + print_separator_ex("=", "test session starts", "") + print(f"platform: {sys.platform} -- Python {sys.version.split(' ')[0]}, c_unittest {__version__}") + print(f"rootdir: {os.getcwd()}") + + files: List[CTestCaseFile] = [] + total = 0 + error_files = [] + start_time = time() + + paths = [] + for p in namespace.path: + if '*' in p: + paths.extend(glob(p, recursive=True)) + elif os.path.isfile(p): + paths.append(p) + for p in paths: + f = CTestCaseFile(p) + total += f.collect() + if f.error: + error_files.append(f) + else: + files.append(f) + + if error_files: + print(f"collected {total} items / {len(error_files)} error\n") + report_collect_error(start_time, *error_files) + else: + print(f"collected {total} items\n") + + if total == 0: + report_no_ran(start_time) + + counter = CTestCaseCounter() + for f in files: + counter += f.run(verbose=namespace.verbose, capture=namespace.no_capture) + + report(start_time, counter, show_capture=namespace.no_capture) diff --git a/scripts/unittest.sh b/scripts/unittest.sh new file mode 100755 index 0000000..f02157c --- /dev/null +++ b/scripts/unittest.sh @@ -0,0 +1,16 @@ +width=$(stty size | awk '{print $2}') +# run add test_* in bin/* +for i in `ls bin/test_*` +do + echo + printf '%*s\n' "$width" '' | tr ' ' '*' + + echo run test file: $i + $i $@ + if [ $? -ne 0 ]; then + echo "\033[0m\033[1;31mtest $i failed\033[0m" + fi + echo +done + +# printf '%*s\n' "$width" '' | tr ' ' '*' diff --git a/src/logging.cpp b/src/logging.cpp new file mode 100644 index 0000000..e917d60 --- /dev/null +++ b/src/logging.cpp @@ -0,0 +1,248 @@ +#include +#include +#include +#include +#include +#include +#include +#include "logging/interface.h" +#include "logging/logger.hpp" + +using namespace logging; + +std::unique_ptr& logging::_get_global_logger() +{ + static std::unique_ptr global_logger(new Logger(LogLevel::INFO)); + return global_logger; +} + +Logger& logging::get_global_logger() +{ + return *_get_global_logger(); +} + +void logging::set_global_logger(std::unique_ptr&& logger) +{ + auto& global_logger = _get_global_logger(); + global_logger->move_children_to(*logger); + global_logger = std::move(logger); +} + +Logger::Logger(const std::string& name, Level level, Logger* parent) + : _parent(parent), _level(level) +{ + if (parent && !parent->_name.empty()) { + _name = parent->_name + namesep + name; + } else { + _name = name; + } +} + +void Logger::vlog(Level level, const char* fmt, va_list args) +{ + if (level < this->level()) + return; + + va_list args2; + va_copy(args2, args); + int size = vsnprintf(nullptr, 0, fmt, args2); + va_end(args2); + + char* buffer = new char[size + 1]; + vsnprintf(buffer, size + 1, fmt, args); + try { + log_message(level, buffer); + } catch (...) { + delete[] buffer; + throw; + } + delete[] buffer; +} + +void Logger::log_message(Level level, const std::string& msg) +{ + if (level < this->level()) + return; + + auto record = Record(name(), level, msg); + log_record(record); +} + +void Logger::log_record(const Record& record) +{ + if (record.level < level()) { + return; + } + if (!_streams.empty()) { + for (auto& stream : _streams) { + write_record(*stream, record); + } + } else if (_parent) { + _parent->log_record(record); + } else { + write_record(default_stream, record); + } +} + +void Logger::write_record(std::ostream& os, const Record& record) +{ + auto now = record.time; + + // 转换为时间_t 类型 + std::time_t now_c = std::chrono::system_clock::to_time_t(now); + + // 获取毫秒部分 + auto milliseconds = std::chrono::duration_cast(now.time_since_epoch()) % 1000; + + // 格式化时间 + std::stringstream ss; + std::tm *tm_now = std::localtime(&now_c); + ss << std::put_time(tm_now, "%Y-%m-%d %H:%M:%S") + << ',' << std::setfill('0') << std::setw(3) << milliseconds.count(); + if (record.name.empty()) { + os << ss.str() << " [" << record.level << "] " << record.msg << std::endl; + } else { + os << ss.str() << " [" << record.name << "] [" << record.level << "] " << record.msg << std::endl; + } +} + +LoggerOStream Logger::operator[](Logger::Level level) +{ + return LoggerOStream(*this, level); +} + +LoggerOStream Logger::operator[](int level) +{ + return (*this)[static_cast(level)]; +} + +void Logger::move_children_to(Logger& target_logger) +{ + if (logger_cache.empty()) return; + for (auto it = logger_cache.begin(); it != logger_cache.end(); ) { + // Transfer ownership of the child logger to the target_logger + it->second->_parent = &target_logger; + target_logger.logger_cache.insert(std::move(*it)); + it = logger_cache.erase(it); // Remove from current logger's cache + } +} + +LoggerOStream::LoggerOStream(Logger& logger, Logger::Level level) + : std::ostream(&_streamBuf), _streamBuf(logger, level) {} + +LoggerOStream::LoggerOStream(LoggerOStream&& loggerStream) + : std::ostream(std::move(loggerStream)), _streamBuf(std::move(loggerStream._streamBuf)) +{} + +LoggerStreamBuf::LoggerStreamBuf(Logger& logger, Logger::Level level) + : _logger(logger), _level(level) {} + +int LoggerStreamBuf::overflow(int c) { + if (c != EOF) { + // 处理换行符 + if (c == '\n') { + flush_line(); // 刷新当前行 + } else { + _lineBuffer += static_cast(c); // 将字符添加到行缓存 + } + } + return c; +} + +std::streamsize LoggerStreamBuf::xsputn(const char* s, std::streamsize n) { + std::streamsize lineStart = 0; // 记录行的起始位置 + + for (std::streamsize i = 0; i < n; ++i) { + if (s[i] == '\n') { + // 将当前行缓存中的内容添加到行缓存 + _lineBuffer.append(s + lineStart, i - lineStart + 1); + flush_line(); // 刷新当前行 + lineStart = i + 1; // 更新行的起始位置 + } + } + + // 处理剩余的字符 + if (lineStart < n) { + _lineBuffer.append(s + lineStart, n - lineStart); + } + return n; // 返回写入的字符总数 +} + +int LoggerStreamBuf::sync() { + return 0; // 不需要同步 +} + +void LoggerStreamBuf::flush_line() { + if (!_lineBuffer.empty()) { + // 调用 Logger 的 log 方法 + _logger.log_message(_level, _lineBuffer); + _lineBuffer.clear(); // 清空行缓存 + } +} + +Record::Record(const std::string& name, Logger::Level level, const std::string& msg) + : name(name), time(std::chrono::system_clock::now()), level(level), msg(msg) +{} + +LogLevel logging::str2level(const char* level) +{ + if (!level) return Logger::Level::INFO; + else if (strcasecmp(level, "debug") == 0) return Logger::Level::DEBUG; + else if (strcasecmp(level, "info") == 0) return Logger::Level::INFO; + else if (strcasecmp(level, "warn") == 0) return Logger::Level::WARN; + else if (strcasecmp(level, "error") == 0) return Logger::Level::ERROR; + else if (strcasecmp(level, "fatal") == 0) return Logger::Level::FATAL; + log_error("Unknown log level: %s", level); + return Logger::Level::UNKNOWN; +} + +std::ostream& logging::operator<<(std::ostream& stream, const LogLevel level) +{ + switch (level) { + case Logger::Level::DEBUG: + return stream << "DEBUG"; + case Logger::Level::INFO: + return stream << "INFO"; + case Logger::Level::WARN: + return stream << "WARN"; + case Logger::Level::ERROR: + return stream << "ERROR"; + case Logger::Level::FATAL: + return stream << "FATAL"; + default: + return stream << "UNKNOWN"; + } +} + +void log_init(const char* level) +{ + if (level == NULL || *level == '\0') { + level = "info"; + } else if (strcasecmp(level, "env") == 0 || strcasecmp(level, "auto") == 0) { + level = getenv("LOG_LEVEL"); + } + set_global_logger(std::unique_ptr(new Logger(str2level(level)))); +} + +int log_level() +{ + return static_cast(_get_global_logger()->level()); +} + +void log_set_level(int level) +{ + _get_global_logger()->set_level(static_cast(level)); +} + +void log_log(int level, const char* fmt, ...) +{ + va_list args; + va_start(args, fmt); + log_vlog(level, fmt, args); + va_end(args); +} + +void log_vlog(int level, const char* fmt, va_list args) +{ + _get_global_logger()->vlog(static_cast(level), fmt, args); +} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt new file mode 100644 index 0000000..e0fe279 --- /dev/null +++ b/tests/CMakeLists.txt @@ -0,0 +1,17 @@ +set(TESTS_DIR .) +set(TEST_COMMON_SOURCE "${TESTS_DIR}/c_testcase.cpp") +file(GLOB_RECURSE TEST_FILES "${TESTS_DIR}/test_*.cpp") + +foreach(TEST_FILE ${TEST_FILES}) + get_filename_component(TEST_NAME ${TEST_FILE} NAME_WE) + add_executable(${TEST_NAME} ${TEST_FILE} ${TEST_COMMON_SOURCE} ${SRC_LIST}) + list(APPEND TEST_EXECUTABLES "${EXECUTABLE_OUTPUT_PATH}/${TEST_NAME}") +endforeach() + +# message(STATUS "Test files: ${TEST_FILES}") +# message(STATUS "Test executables: ${TEST_EXECUTABLES}") + +add_custom_target(test + COMMAND ${CMAKE_SOURCE_DIR}/scripts/unittest.py ${TEST_EXECUTABLES} + DEPENDS ${TEST_EXECUTABLES} +) diff --git a/tests/c_testcase.cpp b/tests/c_testcase.cpp new file mode 100644 index 0000000..a1d6e35 --- /dev/null +++ b/tests/c_testcase.cpp @@ -0,0 +1,289 @@ +#include +#include +#include +#include +#include +#include +#include +#include "c_testcase.h" + +#ifdef _WIN32 +#include +#else +#include +#include +#endif + +using namespace std; + +#define TEST_CASE_STATUS_PASSED 0 +#define TEST_CASE_STATUS_SKIPPED 32 +#define TEST_CASE_STATUS_FAILED 16 +#define TEST_CASE_STATUS_SETUP_FAILED 17 +#define TEST_CASE_STATUS_TEARDOWN_FAILED 18 + +#define MAX_TESTCASE 64 + +#define DEFAULT_TTY_COL_SIZE 80 + +struct TestCase { + const char* name; + test_case func; +}; + +static TestCase test_cases[MAX_TESTCASE]; +static int test_case_total = 0; + +static interactive_func interactive = NULL; +static context_func setup = NULL; +static context_func teardown = NULL; + +static atomic_bool testcase_running; +static jmp_buf testcase_env; +static std::thread::id testcase_thread_id; +static int testcase_exit_code = 0; + +int _add_test_case(const char* name, test_case func) { + if (test_case_total == MAX_TESTCASE) { + fprintf(stderr, "too many test case\n"); + exit(1); + } + TestCase* tc = &(test_cases[test_case_total++]); + tc->name = name; + tc->func = func; + return 0; +} + +int _set_interactive(interactive_func func) { + interactive = func; + return 0; +} + +int _set_setup(context_func func) { + setup = func; + return 0; +} + +int _set_teardown(context_func func) { + teardown = func; + return 0; +} + +[[noreturn]] void test_case_abort(int exit_code) { + auto tid = std::this_thread::get_id(); + if (!testcase_running.load(std::memory_order_acquire) || tid != testcase_thread_id) { + exit(exit_code); + } + testcase_exit_code = exit_code; + longjmp(testcase_env, 1); +} + +static inline int get_tty_col(int fd) { +#ifdef _WIN32 + // Windows + HANDLE hConsole = (HANDLE)_get_osfhandle(fd); + if (hConsole == INVALID_HANDLE_VALUE) { + return DEFAULT_TTY_COL_SIZE; // 错误处理 + } + + CONSOLE_SCREEN_BUFFER_INFO csbi; + if (!GetConsoleScreenBufferInfo(hConsole, &csbi)) { + return DEFAULT_TTY_COL_SIZE; // 错误处理 + } + return csbi.srWindow.Right - csbi.srWindow.Left + 1; // 计算列数 +#else + // POSIX (Linux, macOS, etc.) + struct winsize size; + if (ioctl(fd, TIOCGWINSZ, &size) == -1) { + return DEFAULT_TTY_COL_SIZE; // 错误处理 + } + return size.ws_col; // 返回列数 +#endif +} + +static __inline void print_separator(char lc) { + int size = get_tty_col(STDOUT_FILENO); + for (int i = 0; i < size; ++i) { + putchar(lc); + } +} + +static __inline void print_separator_ex(char lc, const char* str, const char* color) { + int size = get_tty_col(STDOUT_FILENO); + int len = strlen(str); + printf("\033[0m%s", color); // 设置颜色 + if(len > size) { + printf("%s\n", str); + } else { + int pad = (size - len - 2) / 2; + for(int i = 0; i < pad; i++) { + putchar(lc); + } + printf(" %s ", str); + for(int i = 0; i < pad; i++) { + putchar(lc); + } + if((size - len) % 2) putchar(lc); + putchar('\n'); + } + printf("\033[0m"); // 重置颜色 +} + +static int collect_testcase() { + for (int i = 0; i < test_case_total; ++i) { + puts(test_cases[i].name); + } + return 0; +} + +static TestCase* get_test_case(const char* name) { + TestCase* tc = NULL; + if (*name >= '0' && *name <= '9') { + int id = atoi(name); + if (id >= 0 && id < test_case_total) { + tc = &(test_cases[id]); + } + } else { + for (int i = 0; i < test_case_total; ++i) { + if (strcmp(test_cases[i].name, name) == 0) { + tc = &(test_cases[i]); + break; + } + } + } + return tc; +} + +static int run_test_case_func(test_case func) { + bool running = false; + if (!testcase_running.compare_exchange_strong(running, true, std::memory_order_acq_rel)) { + cerr << "test case is running" << endl; + return 1; + } + if (setjmp(testcase_env)) { + return testcase_exit_code; + } + testcase_thread_id = std::this_thread::get_id(); + int ret = func(); + running = true; + if (!testcase_running.compare_exchange_strong(running, false, std::memory_order_acq_rel)) { + cerr << "test case is not running" << endl; + return 1; + } + return ret; +} + +static int unittest_testcase(TestCase* tc) { + assert(tc != NULL); + if (setup && setup(tc->name)) { + return TEST_CASE_STATUS_SETUP_FAILED; + } + int ret = run_test_case_func(tc->func); + if (teardown && teardown(tc->name)) { + return TEST_CASE_STATUS_TEARDOWN_FAILED; + } + + if (ret == SKIP_RET_NUMBER) { + return TEST_CASE_STATUS_SKIPPED; + } else if (ret != 0) { + return TEST_CASE_STATUS_FAILED; + } else { + return 0; + } +} + +int main(int argc, const char** argv) { + for (int i = 1; i < argc; ++i) { + const char* arg = argv[i]; + if (strcmp(arg, "-h") == 0 || strcmp(arg, "--help") == 0) { + printf("usage: %s [-c] [-u NAME] [-h]\n", argv[0]); + printf("\nOptions:\n"); + printf(" -i, --interactive: run in interactive mode.\n"); + printf(" -c, --collect: list all test cases.\n"); + printf(" -u, --unittest: run a single test case.\n"); + printf(" -h, --help: show the help text.\n"); + return 0; + } + if (strcmp(arg, "-i") == 0 || strcmp(arg, "--interactive") == 0) { + if (interactive) + return interactive(argc, argv); + else { + cout << "interactive mode is not supported" << endl; + return 1; + } + } + else if (strcmp(arg, "-c") == 0 || strcmp(arg, "--collect") == 0) { + return collect_testcase(); + } + else if (strcmp(arg, "-u") == 0 || strcmp(arg, "--unittest") == 0) { + const char* name = NULL; + if (i + 1 < argc) { + name = argv[++i]; + } else { + cout << "--unittest require an argument" << endl; + return 2; + } + TestCase* tc = get_test_case(name); + if (tc == NULL) { + cout << "test case " << name << " not found" << endl; + return 1; + } + return unittest_testcase(tc); + } else { + if (interactive) + return interactive(argc, argv); + else { + cout << "unknown argument '" << arg << "'" << endl; + return 1; + } + } + } + + int total = test_case_total; + int passed = 0; + int failed = 0; + int skipped = 0; + + for (int i = 0; i < total; ++i) { + print_separator('-'); + TestCase* tc = &test_cases[i]; + cout << "running " << tc->name << endl; + switch (unittest_testcase(tc)) { + case 0: + cout << "\033[0m\033[1;32mtest case \"" << tc->name << "\" passed\033[0m" << endl; + passed++; + break; + case TEST_CASE_STATUS_SKIPPED: + cout << "\033[0m\033[1;33mtest case \"" << tc->name << "\" skipped\033[0m" << endl; + skipped++; + break; + case TEST_CASE_STATUS_SETUP_FAILED: + cout << "\033[0m\033[1;31msetup \"" << tc->name << "\" failed\033[0m" << endl; + failed++; + break; + case TEST_CASE_STATUS_TEARDOWN_FAILED: + cout << "\033[0m\033[1;31mteardown \"" << tc->name << "\" failed\033[0m" << endl; + failed++; + break; + default: + cout << "\033[0m\033[1;31mtest case \"" << tc->name << "\" failed\033[0m" << endl; + failed++; + break; + } + } + + stringstream ss; + ss << "total: " << total << ", passed: " << passed << ", failed: " << failed << ", skipped: " << skipped; + string sum = ss.str(); + + const char* color; + if (failed) + color = "\033[1;31m"; + else if (skipped) + color = "\033[1;33m"; + else + color = "\033[1;32m"; + + print_separator_ex('=', sum.c_str(), color); + return 0; +} \ No newline at end of file diff --git a/tests/c_testcase.h b/tests/c_testcase.h new file mode 100644 index 0000000..295d42f --- /dev/null +++ b/tests/c_testcase.h @@ -0,0 +1,203 @@ +#ifndef _INCLUDE_C_TESTCASE_ +#define _INCLUDE_C_TESTCASE_ + +#include +#include +#include + +#ifdef __cplusplus +#include + +extern "C"{ +#endif + +#define SKIP_RET_NUMBER (*((const int*)"SKIP")) +#define SKIP_TEST do { return SKIP_RET_NUMBER; } while (0) +#define END_TEST do { return 0; } while (0) + +#define TEST_CASE(name) \ + int name (); \ + int _tc_tmp_ ## name = _add_test_case(# name, name); \ + int name () +#define INTERACTIVE \ + int interactive_impl(int argc, const char** argv); \ + int _tc_s_tmp_interactive = _set_interactive(interactive_impl); \ + int interactive_impl(int argc, const char** argv) +#define SETUP \ + int setup_impl(const char* test_case_name); \ + int _tc_s_tmp_setup = _set_setup(setup_impl); \ + int setup_impl(const char* test_case_name) +#define TEARDOWN \ + int teardown_impl(const char* test_case_name); \ + int _tc_s_tmp_teardown = _set_teardown(teardown_impl); \ + int teardown_impl(const char* test_case_name) + +#undef assert +#ifdef __cplusplus +#define assert(expr) do { \ + if (!(static_cast (expr))) { \ + ::std::cout << "assert failed: " #expr "\n" << ::std::endl; \ + ::std::cout << "file: \"" __FILE__ "\", line " << __LINE__ << ", in " << __ASSERT_FUNCTION << "\n" << ::std::endl;\ + test_case_abort(1); \ + } \ +} while (0) +#define assert_eq(expr1, expr2) do { \ + if ((expr1) != (expr2)) { \ + ::std::cout << "assert failed: " #expr1 " == " #expr2 "\n" << ::std::endl; \ + ::std::cout << "\t#0: " << (expr1) << ::std::endl; \ + ::std::cout << "\t#1: " << (expr2) << ::std::endl; \ + ::std::cout << "file: \"" __FILE__ "\", line " << __LINE__ << ", in " << __ASSERT_FUNCTION << "\n" << ::std::endl;\ + test_case_abort(1); \ + } \ +} while (0) +#define assert_ne(expr1, expr2) do { \ + if ((expr1) == (expr2)) { \ + ::std::cout << "assert failed: " #expr1 " != " #expr2 "\n" << ::std::endl; \ + ::std::cout << "\t#0: " << (expr1) << ::std::endl; \ + ::std::cout << "\t#1: " << (expr2) << ::std::endl; \ + ::std::cout << "file: \"" __FILE__ "\", line " << __LINE__ << ", in " << __ASSERT_FUNCTION << "\n" << ::std::endl;\ + test_case_abort(1); \ + } \ +} while (0) +#define assert_gt(expr1, expr2) do { \ + if ((expr1) <= (expr2)) { \ + ::std::cout << "assert failed: " #expr1 " > " #expr2 "\n" << ::std::endl; \ + ::std::cout << "\t#0: " << (expr1) << ::std::endl; \ + ::std::cout << "\t#1: " << (expr2) << ::std::endl; \ + ::std::cout << "file: \"" __FILE__ "\", line " << __LINE__ << ", in " << __ASSERT_FUNCTION << "\n" << ::std::endl;\ + test_case_abort(1); \ + } \ +} while (0) +#define assert_ls(expr1, expr2) do { \ + if ((expr1) >= (expr2)) { \ + ::std::cout << "assert failed: " #expr1 " < " #expr2 "\n" << ::std::endl; \ + ::std::cout << "\t#0: " << (expr1) << ::std::endl; \ + ::std::cout << "\t#1: " << (expr2) << ::std::endl; \ + ::std::cout << "file: \"" __FILE__ "\", line " << __LINE__ << ", in " << __ASSERT_FUNCTION << "\n" << ::std::endl;\ + test_case_abort(1); \ + } \ +} while (0) +#define assert_ge(expr1, expr2) do { \ + if ((expr1) < (expr2)) { \ + ::std::cout << "assert failed: " #expr1 " >= " #expr2 "\n" << ::std::endl; \ + ::std::cout << "\t#0: " << (expr1) << ::std::endl; \ + ::std::cout << "\t#1: " << (expr2) << ::std::endl; \ + ::std::cout << "file: \"" __FILE__ "\", line " << __LINE__ << ", in " << __ASSERT_FUNCTION << "\n" << ::std::endl;\ + test_case_abort(1); \ + } \ +} while (0) +#define assert_le(expr1, expr2) do { \ + if ((expr1) > (expr2)) { \ + ::std::cout << "assert failed: " #expr1 " <= " #expr2 "\n" << ::std::endl; \ + ::std::cout << "\t#0: " << (expr1) << ::std::endl; \ + ::std::cout << "\t#1: " << (expr2) << ::std::endl; \ + ::std::cout << "file: \"" __FILE__ "\", line " << __LINE__ << ", in " << __ASSERT_FUNCTION << "\n" << ::std::endl;\ + test_case_abort(1); \ + } \ +} while (0) +#else +#define assert(expr) do { \ + if (!((bool)(expr))) { \ + printf("assert failed: " #expr "\n"); \ + printf("file: \"%s\", line %d, in %s\n", __FILE__, __LINE__, __ASSERT_FUNCTION);\ + test_case_abort(1); \ + } \ +} while (0) +#endif +#define assert_i32_eq(expr1, expr2) do { \ + if ((int32_t)(expr1) != (uint32_t)(expr2)) { \ + printf("assert failed: " #expr1 " == " #expr2 "\n"); \ + printf("\t\t%d\t!=\t%d\n", (int32_t)(expr1), (int32_t)(expr2)); \ + printf("file: \"%s\", line %d, in %s\n", __FILE__, __LINE__, __ASSERT_FUNCTION);\ + test_case_abort(1); \ + } \ +} while (0) +#define assert_i32_ne(expr1, expr2) do { \ + if ((int32_t)(expr1) == (int32_t)(expr2)) { \ + printf("assert failed: " #expr1 " != " #expr2 "\n"); \ + printf("\t\t%d\t==\t%d\n", (int32_t)(expr1), (int32_t)(expr2)); \ + printf("file: \"%s\", line %d, in %s\n", __FILE__, __LINE__, __ASSERT_FUNCTION);\ + test_case_abort(1); \ + } \ +} while (0) +#define assert_i64_eq(expr1, expr2) do { \ + if ((expr1) != (expr2)) { \ + printf("assert failed: " #expr1 " == " #expr2 "\n"); \ + printf("\t\t%lld\t!=\t%lld\n", (int64_t)(expr1), (int64_t)(expr2)); \ + printf("file: \"%s\", line %d, in %s\n", __FILE__, __LINE__, __ASSERT_FUNCTION);\ + test_case_abort(1); \ + } \ +} while (0) +#define assert_i64_ne(expr1, expr2) do { \ + if ((expr1) == (expr2)) { \ + printf("assert failed: " #expr1 " != " #expr2 "\n"); \ + printf("\t\t%lld\t==\t%lld\n", (int64_t)(expr1), (int64_t)(expr2)); \ + printf("file: \"%s\", line %d, in %s\n", __FILE__, __LINE__, __ASSERT_FUNCTION);\ + test_case_abort(1); \ + } \ +} while (0) +#define assert_str_eq(expr1, expr2) do { \ + if (strcmp((expr1), (expr2)) != 0) { \ + printf("assert failed: " #expr1 " == " #expr2 "\n"); \ + printf("\t#0: %s\n", (expr1)); \ + printf("\t#1: %s\n", (expr2)); \ + printf("file: \"%s\", line %d, in %s\n", __FILE__, __LINE__, __ASSERT_FUNCTION);\ + test_case_abort(1); \ + } \ +} while (0) +#define assert_str_ne(expr1, expr2) do { \ + if (strcmp((expr1), (expr2)) == 0) { \ + printf("assert failed: " #expr1 " != " #expr2 "\n"); \ + printf("\t#0: %s\n", (expr1)); \ + printf("\t#1: %s\n", (expr2)); \ + printf("file: \"%s\", line %d, in %s\n", __FILE__, __LINE__, __ASSERT_FUNCTION);\ + test_case_abort(1); \ + } \ +} while (0) +#define assert_mem_eq(expr1, expr2, size) do { \ + if (memcmp((expr1), (expr2), (size)) != 0) { \ + printf("assertion failed: %s == %s\n", #expr1, #expr2); \ + printf("\t#0: "); \ + for (size_t i = 0; i < (size); i++) { \ + printf("%02X", ((uint8_t*)(expr1))[i]); \ + } \ + printf("\n\t#1: "); \ + for (size_t i = 0; i < (size); i++) { \ + printf("%02X", ((uint8_t*)(expr2))[i]); \ + } \ + printf("\nfile: \"%s\", line %d, in %s\n", __FILE__, __LINE__, __ASSERT_FUNCTION);\ + test_case_abort(1); \ + } \ +} while (0) +#define assert_mem_ne(expr1, expr2, size) do { \ + if (memcmp((expr1), (expr2), (size)) == 0) { \ + printf("assertion failed: %s != %s\n", #expr1, #expr2); \ + printf("\t#0: "); \ + for (int i = 0; i < (size); i++) { \ + printf("%02X", ((uint8_t*)(expr1))[i]); \ + } \ + printf("\n\t#1: "); \ + for (int i = 0; i < (size); i++) { \ + printf("%02X", ((uint8_t*)(expr2))[i]); \ + } \ + printf("\nfile: \"%s\", line %d, in %s\n", __FILE__, __LINE__, __ASSERT_FUNCTION);\ + test_case_abort(1); \ + } \ +} while (0) + + +typedef int (*test_case)(); +typedef int (*interactive_func)(int argc, const char** argv); +typedef int (*context_func)(const char* name); + +int _add_test_case(const char* name, test_case func); +int _set_interactive(interactive_func func); +int _set_setup(context_func func); +int _set_teardown(context_func func); +[[noreturn]] void test_case_abort(int exit_code); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/tests/test_base.cpp b/tests/test_base.cpp new file mode 100644 index 0000000..0eefc7a --- /dev/null +++ b/tests/test_base.cpp @@ -0,0 +1,40 @@ +#include "transport/base.hpp" +#include "transport/protocol.hpp" +#include "c_testcase.h" + +using namespace transport; + +class TestTransport: public BaseTransport { +protected: + void send_backend() override { + while (!is_closed) { + DataPair pair = send_que.Pop(); + recv_que.Push(std::move(pair)); + } + } + + void receive_backend() override {} +}; + +const int timeout = 3; + +TEST_CASE(test_init) { + TestTransport t; + END_TEST; +} + +TEST_CASE(test_transport) { + TestTransport t; + assert(!t.closed()); + + t.open(); + t.send(std::vector(10, 1)); + auto data_pair = t.receive(std::chrono::seconds(timeout)); + assert_eq(data_pair.first.size(), 10); + assert_eq(data_pair.first[1], 1); + assert(!data_pair.second); + + t.close(); + assert(t.closed()); + END_TEST; +} diff --git a/tests/test_datagram.cpp b/tests/test_datagram.cpp new file mode 100644 index 0000000..76e4b53 --- /dev/null +++ b/tests/test_datagram.cpp @@ -0,0 +1,40 @@ +#include "transport/udp.hpp" +#include "transport/protocol.hpp" +#include "c_testcase.h" + +using namespace transport; + +TEST_CASE(test_init) { + DatagramTransport transport; + transport.open(); + END_TEST; +} + +TEST_CASE(test_send_recv) { + DatagramTransport transport_server; + transport_server.open(); + transport_server.bind("127.0.0.1", 12345); + + DatagramTransport transport_client; + transport_client.open(); + transport_client.connect("127.0.0.1", 12345); + transport_client.send(std::vector{0x01, 0x02, 0x03, 0x04}); + + auto [frame, token] = transport_server.receive(std::chrono::seconds(3)); + assert_eq(frame.size(), 4); + assert(token); + assert_eq(frame[0], 0x01); + assert_eq(frame[1], 0x02); + assert_eq(frame[2], 0x03); + assert_eq(frame[3], 0x04); + + transport_server.send(std::vector{0x04, 0x03, 0x02, 0x01}, token); + auto [frame2, token2] = transport_client.receive(std::chrono::seconds(3)); + assert_eq(frame2.size(), 4); + assert(token2); + assert_eq(frame2[0], 0x04); + assert_eq(frame2[1], 0x03); + assert_eq(frame2[2], 0x02); + assert_eq(frame2[3], 0x01); + END_TEST; +} diff --git a/tests/test_serial_port.cpp b/tests/test_serial_port.cpp new file mode 100644 index 0000000..69db0ce --- /dev/null +++ b/tests/test_serial_port.cpp @@ -0,0 +1,66 @@ +#include +#include +#include +#include "c_testcase.h" +#include "transport/serial_port.hpp" +#include "transport/protocol.hpp" + +using namespace transport; + +int master_fd = -1; +int slave_fd = -1; + +const int timeout = 3; + + +SETUP { + if (openpty(&master_fd, &slave_fd, NULL, NULL, NULL) < 0) { + throw std::runtime_error("openpty failed"); + } + fcntl(master_fd, F_SETFL, fcntl(master_fd, F_GETFL) | O_NONBLOCK); + fcntl(slave_fd, F_SETFL, fcntl(slave_fd, F_GETFL) | O_NONBLOCK); + return 0; +} + +TEARDOWN { + close(master_fd); + close(slave_fd); + master_fd = -1; + slave_fd = -1; + return 0; +} + +TEST_CASE(test_pty) { + char buffer[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + ssize_t ret; + ret = write(slave_fd, buffer, 10); + assert_eq(ret, 10); + ret = read(master_fd, buffer, 10); + assert_eq(ret, 10); + for (int i = 0; i < 10; i++) { + assert_eq(buffer[i], i); + } + END_TEST; +} + +TEST_CASE(test_serial_port) { + SerialPortTransport t1(master_fd); + SerialPortTransport t2(slave_fd); + + assert(!t1.closed()); + assert(!t2.closed()); + std::pair> data_pair; + std::thread ([&] { + t2.open(); + t2.send(std::vector(10, 2)); + }).detach(); + t1.open(); + data_pair = t1.receive(std::chrono::seconds(timeout)); + assert_eq(data_pair.first.size(), 10); + assert_eq(data_pair.first[0], 2); + assert(data_pair.second); + assert_eq(data_pair.second->transport(), &t1); + t1.close(); + assert(t1.closed()); + END_TEST; +} diff --git a/tests/test_unix_datagram.cpp b/tests/test_unix_datagram.cpp new file mode 100644 index 0000000..ff84a89 --- /dev/null +++ b/tests/test_unix_datagram.cpp @@ -0,0 +1,38 @@ +#include "transport/unix_udp.hpp" +#include "transport/protocol.hpp" +#include "c_testcase.h" + +using namespace transport; + +TEST_CASE(test_init) { + UnixDatagramTransport transport; + transport.open(); + END_TEST; +} + +TEST_CASE(test_send_recv) { + UnixDatagramTransport transport_server("/tmp/vxup_test.sock", ""); + UnixDatagramTransport transport_client("/tmp/vxup_test1.sock", "/tmp/vxup_test.sock"); + transport_server.open(); + transport_client.open(); + transport_client.send(std::vector{0x01, 0x02, 0x03, 0x04}); + + auto [frame, token] = transport_server.receive(std::chrono::seconds(3)); + assert_eq(frame.size(), 4); + assert(token); + assert_eq(frame[0], 0x01); + assert_eq(frame[1], 0x02); + assert_eq(frame[2], 0x03); + assert_eq(frame[3], 0x04); + + transport_server.send(std::vector{0x04, 0x03, 0x02, 0x01}, token); + auto [frame2, token2] = transport_client.receive(std::chrono::seconds(3)); + assert_eq(frame2.size(), 4); + assert(token2); + assert_eq(frame2[0], 0x04); + assert_eq(frame2[1], 0x03); + assert_eq(frame2[2], 0x02); + assert_eq(frame2[3], 0x01); + END_TEST; + END_TEST; +}