init repo

This commit is contained in:
ovizro 2024-12-24 19:51:32 +08:00
commit 6e7e30eeaa
22 changed files with 2902 additions and 0 deletions

31
.gitignore vendored Normal file
View File

@ -0,0 +1,31 @@
################### CMake config ###################
build
################ Executable program ################
bin/
dist/
#################### Library files #################
lib/
################### VSCode config ##################
.vscode
.VSCodeCounter
####################### Tool chains ######################
toolchains/*
#################### Test config ###################
test.*
disabled.*
*.disabled
###################### Misc #######################
!.gitkeep

48
CMakeLists.txt Normal file
View File

@ -0,0 +1,48 @@
cmake_minimum_required(VERSION 3.15)
project(transport CXX)
set(CMAKE_CXX_STANDARD 11)
if (NOT CMAKE_CROSSCOMPILING)
set(EXECUTABLE_OUTPUT_PATH ${CMAKE_SOURCE_DIR}/bin)
set(LIBRARY_OUTPUT_PATH ${CMAKE_SOURCE_DIR}/lib)
endif()
if (UNIX)
set(CMAKE_CXX_FLAGS "-std=c++11 -Wall ${CMAKE_CXX_FLAGS}")
set(CMAKE_CXX_FLAGS_DEBUG "-g ${CMAKE_CXX_FLAGS}")
set(CMAKE_CXX_FLAGS_RELEASE "-g -O2 ${CMAKE_CXX_FLAGS}")
elseif (WIN32)
# windows platform
#add_definitions(-D_CRT_SECURE_NO_WARNINGS)
#set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /MDd")
#set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MD")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /MTd /EHsc")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MT /EHsc")
endif()
if(MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /utf-8")
endif()
if (CMAKE_CROSSCOMPILING)
message(STATUS "Cross compiling ...")
message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
endif()
include_directories(./include)
aux_source_directory(${CMAKE_SOURCE_DIR}/src SRC_LIST)
link_libraries(pthread util)
add_library(transport_static STATIC ${SRC_LIST})
add_library(transport SHARED ${SRC_LIST})
set_target_properties(transport_static PROPERTIES OUTPUT_NAME transport)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
target_compile_options(transport PRIVATE -Wno-nonnull)
endif()
add_subdirectory(tests)

4
README.md Normal file
View File

@ -0,0 +1,4 @@
# Transport
A simple C++ language data transfer module.

143
include/dataqueue.hpp Normal file
View File

@ -0,0 +1,143 @@
#ifndef _INCLUDED_QUEUE_
#define _INCLUDED_QUEUE_
#include <deque>
#include <mutex>
#include <chrono>
#include <condition_variable>
#include <stdexcept>
template <typename T>
class DataQueue;
class QueueException : public std::exception
{
public:
QueueException(void* queue) : queue(queue) {}
template <typename T>
DataQueue<T>* GetQueue() const noexcept
{
return static_cast<DataQueue<T>*>(queue);
}
private:
void* queue;
};
class QueueCleared : public QueueException
{
public:
QueueCleared(void* queue) : QueueException(queue) {}
const char* what() const noexcept override
{
return "queue cleared";
}
};
class QueueTimeout : public QueueException
{
public:
QueueTimeout(void* queue) : QueueException(queue) {}
const char* what() const noexcept override
{
return "queue timeout";
}
};
typedef uint8_t queue_epoch_t;
template <typename T>
class DataQueue {
public:
DataQueue() : m_CurrEpoch(0) {}
DataQueue(const DataQueue&) = delete;
~DataQueue() { Clear(); }
void Push(T data)
{
std::lock_guard<std::mutex> lock(m_Mutex);
m_Queue.push_back(std::move(data));
m_Cond.notify_one();
}
T Pop()
{
std::unique_lock<std::mutex> lock(m_Mutex);
auto epoch = m_CurrEpoch;
while (m_Queue.empty())
{
m_Cond.wait(lock);
if (epoch != m_CurrEpoch) {
throw QueueCleared(this);
}
}
T data = std::move(m_Queue.front());
m_Queue.pop_front();
return data;
}
template <typename Rep = uint64_t, typename Period = std::milli>
T Pop(const std::chrono::duration<Rep, Period> timeout)
{
std::unique_lock<std::mutex> lock(m_Mutex);
auto epoch = m_CurrEpoch;
while (m_Queue.empty())
{
if (m_Cond.wait_for(lock, timeout) == std::cv_status::timeout) {
throw QueueTimeout(this);
}
if (epoch != m_CurrEpoch) {
throw QueueCleared(this);
}
}
T data = std::move(m_Queue.front());
m_Queue.pop_front();
return data;
}
bool Empty() noexcept
{
std::lock_guard<std::mutex> lock(m_Mutex);
return m_Queue.empty();
}
size_t Size() noexcept
{
std::lock_guard<std::mutex> lock(m_Mutex);
return m_Queue.size();
}
queue_epoch_t GetEpoch() noexcept
{
std::lock_guard<std::mutex> lock(m_Mutex);
return m_CurrEpoch;
}
bool CheckEpoch(queue_epoch_t epoch) noexcept
{
std::lock_guard<std::mutex> lock(m_Mutex);
return m_CurrEpoch == epoch;
}
void Clear() noexcept
{
std::lock_guard<std::mutex> lock(m_Mutex);
m_Queue.clear();
m_CurrEpoch++;
m_Cond.notify_all();
}
protected:
std::deque<T> m_Queue;
std::mutex m_Mutex;
std::condition_variable m_Cond;
private:
queue_epoch_t m_CurrEpoch;
};
#endif

View File

@ -0,0 +1,52 @@
#ifndef _INCLUDE_LOGGING_INTERFACE_
#define _INCLUDE_LOGGING_INTERFACE_
#include <assert.h>
#include <stdarg.h>
#include "level.h"
#ifdef __cplusplus
extern "C" {
#endif
void log_init(const char* level);
int log_level(void);
void log_set_level(int level);
void log_log(int level, const char *fmt, ...);
void log_vlog(int level, const char *fmt, va_list ap);
#define log_debug(...) log_log(LOG_LEVEL_DEBUG, __VA_ARGS__)
#define log_info(...) log_log(LOG_LEVEL_INFO, __VA_ARGS__)
#define log_warn(...) log_log(LOG_LEVEL_WARN, __VA_ARGS__)
#define log_error(...) log_log(LOG_LEVEL_ERROR, __VA_ARGS__)
#define log_fatal(...) log_log(LOG_LEVEL_FATAL, __VA_ARGS__)
#define vlog_debug(...) log_vlog(LOG_LEVEL_DEBUG, __VA_ARGS__)
#define vlog_info(...) log_vlog(LOG_LEVEL_INFO, __VA_ARGS__)
#define vlog_warn(...) log_vlog(LOG_LEVEL_WARN, __VA_ARGS__)
#define vlog_error(...) log_vlog(LOG_LEVEL_ERROR, __VA_ARGS__)
#define vlog_fatal(...) log_vlog(LOG_LEVEL_FATAL, __VA_ARGS__)
#define log_with_source(level, msg) log_log(level,\
"Trackback (most recent call last):\n File \"%s\", line %d, in \"%s\"\n%s",\
__FILE__, __LINE__, __ASSERT_FUNCTION, msg)
#define log_debug_with_source(msg) log_with_source(LOG_LEVEL_DEBUG, msg)
#define log_info_with_source(msg) log_with_source(LOG_LEVEL_INFO, msg)
#define log_warn_with_source(msg) log_with_source(LOG_LEVEL_WARN, msg)
#define log_error_with_source(msg) log_with_source(LOG_LEVEL_ERROR, msg)
#define log_fatal_with_source(msg) log_with_source(LOG_LEVEL_FATAL, msg)
#define log_from_errno(level, msg) do {\
if (errno) log_log(level,"%s: %s", msg, strerror(errno));\
} while (0)
#define log_debug_from_errno(msg) log_from_errno(LOG_LEVEL_DEBUG, msg)
#define log_info_from_errno(msg) log_from_errno(LOG_LEVEL_INFO, msg)
#define log_warn_from_errno(msg) log_from_errno(LOG_LEVEL_WARN, msg)
#define log_error_from_errno(msg) log_from_errno(LOG_LEVEL_ERROR, msg)
#define log_fatal_from_errno(msg) log_from_errno(LOG_LEVEL_FATAL, msg)
#ifdef __cplusplus
}
#endif
#endif

10
include/logging/level.h Normal file
View File

@ -0,0 +1,10 @@
#ifndef _INCLUDE_LOGGING_LEVEL_
#define _INCLUDE_LOGGING_LEVEL_
#define LOG_LEVEL_DEBUG 1
#define LOG_LEVEL_INFO 2
#define LOG_LEVEL_WARN 3
#define LOG_LEVEL_ERROR 4
#define LOG_LEVEL_FATAL 5
#endif

260
include/logging/logger.hpp Normal file
View File

@ -0,0 +1,260 @@
#ifndef _INCLUDE_LOGGER_
#define _INCLUDE_LOGGER_
#include <stdint.h>
#include <stdarg.h>
#include <errno.h>
#include <string.h>
#include <vector>
#include <memory>
#include <iostream>
#include <chrono>
#include <unordered_map>
#include "level.h"
namespace logging {
struct Record;
class LoggerOStream;
class LoggerStreamBuf;
class Logger
{
public:
enum class Level: uint8_t
{
UNKNOWN = 0,
DEBUG = LOG_LEVEL_DEBUG,
INFO = LOG_LEVEL_INFO,
WARN = LOG_LEVEL_WARN,
ERROR = LOG_LEVEL_ERROR,
FATAL = LOG_LEVEL_FATAL
};
constexpr static const char* namesep = "::";
explicit Logger(Level level = Level::INFO) : _parent(nullptr), _level(level) {}
Logger(const Logger&) = delete;
virtual ~Logger() = default;
inline Level level() const
{
if (_level == Logger::Level::UNKNOWN) {
if (_parent) {
return _parent->level();
}
return Logger::Level::INFO;
}
return _level;
}
inline const std::string& name() const
{
return _name;
}
inline Logger* parent() const
{
return _parent;
}
inline size_t children_count() const
{
return logger_cache.size();
}
inline void set_level(Level level)
{
_level = level;
}
template<Level level>
inline void log(const char* fmt, ...) {
va_list args;
va_start(args, fmt);
vlog(level, fmt, args);
va_end(args);
}
inline void log(Level level, const char* fmt, ...)
{
va_list args;
va_start(args, fmt);
vlog(level, fmt, args);
va_end(args);
}
inline void debug(const char* fmt, ...)
{
va_list args;
va_start(args, fmt);
vlog(Logger::Level::DEBUG, fmt, args);
va_end(args);
}
void info(const char* fmt, ...)
{
va_list args;
va_start(args, fmt);
vlog(Logger::Level::INFO, fmt, args);
va_end(args);
}
void warn(const char* fmt, ...)
{
va_list args;
va_start(args, fmt);
vlog(Logger::Level::WARN, fmt, args);
va_end(args);
}
void error(const char* fmt, ...)
{
va_list args;
va_start(args, fmt);
vlog(Logger::Level::ERROR, fmt, args);
va_end(args);
}
void fatal(const char* fmt, ...)
{
va_list args;
va_start(args, fmt);
vlog(Logger::Level::FATAL, fmt, args);
va_end(args);
}
template<Level level>
inline void vlog(const char* fmt, va_list args) {
vlog(level, fmt, args);
}
template<typename E = std::runtime_error>
inline void raise_from_errno(const char* msg) {
error("%s: %s", msg, strerror(errno));
throw E(msg);
}
void vlog(Level level, const char* fmt, va_list args);
virtual void log_message(Level level, const std::string& msg);
LoggerOStream operator[](Level level);
LoggerOStream operator[](int level);
void move_children_to(Logger& other);
template<class T_Logger = Logger>
Logger* get_child(const std::string& name, Level level)
{
Logger* logger;
auto sep_pos = name.find(namesep);
do {
if (name.empty() || sep_pos == 0) {
logger = this;
break;
}
std::string base_name;
if (sep_pos == std::string::npos) {
base_name = name;
} else {
base_name = name.substr(0, sep_pos);
}
auto iter = logger_cache.find(base_name);
if (iter != logger_cache.end()) {
logger = iter->second.get();
break;
}
logger = new T_Logger(base_name, level, this);
logger_cache[base_name] = std::unique_ptr<Logger>(logger);
} while (0);
auto seqlen = strlen(namesep);
if (sep_pos == std::string::npos || sep_pos + seqlen >= name.size())
return logger;
return logger->get_child(name.substr(sep_pos + seqlen), level);
}
inline void add_stream(std::ostream& stream)
{
_streams.push_back(std::unique_ptr<std::ostream>(new std::ostream(stream.rdbuf())));
}
template <typename T>
inline void add_stream(T&& stream)
{
_streams.push_back(std::unique_ptr<std::ostream>(new typename std::remove_reference<T>::type(std::move(stream))));
}
inline std::vector<std::ostream*> streams() const
{
std::vector<std::ostream*> streams;
for (auto& stream : _streams)
streams.push_back(stream.get());
return streams;
}
protected:
explicit Logger(const std::string& name, Level level = Level::INFO, Logger* parent = nullptr);
virtual void log_record(const Record& record);
virtual void write_record(std::ostream& os, const Record& record);
static constexpr std::ostream& default_stream = std::cerr;
std::vector<std::unique_ptr<std::ostream>> _streams;
Logger* _parent;
private:
std::unordered_map<std::string, std::unique_ptr<Logger>> logger_cache;
std::string _name;
Level _level;
};
class LoggerStreamBuf : public std::streambuf {
public:
LoggerStreamBuf(Logger& logger, Logger::Level level);
LoggerStreamBuf(const LoggerStreamBuf&) = delete;
LoggerStreamBuf(LoggerStreamBuf&& loggerStream) = default;
protected:
int overflow(int c) override;
std::streamsize xsputn(const char* s, std::streamsize n) override;
int sync() override;
private:
void flush_line(); // 将当前行写入日志
Logger& _logger;
Logger::Level _level;
std::string _lineBuffer; // 行缓存
};
class LoggerOStream : public std::ostream {
public:
LoggerOStream(Logger& logger, Logger::Level level);
LoggerOStream(const LoggerOStream&) = delete;
LoggerOStream(LoggerOStream&& loggerStream);
private:
LoggerStreamBuf _streamBuf;
};
struct Record
{
std::string name;
std::chrono::system_clock::time_point time;
Logger::Level level;
std::string msg;
Record(const std::string& name, Logger::Level level, const std::string& msg);
};
using LogLevel = Logger::Level;
std::ostream& operator<<(std::ostream& os, const Logger::Level level);
Logger::Level str2level(const char* level);
std::unique_ptr<Logger>& _get_global_logger();
Logger& get_global_logger();
void set_global_logger(std::unique_ptr<Logger>&& logger);
template<class T_Logger = Logger>
static inline Logger* get_logger(const std::string& name, LogLevel level = LogLevel::UNKNOWN)
{
return _get_global_logger()->get_child<T_Logger>(name, level);
}
}
#endif

164
include/transport/base.hpp Normal file
View File

@ -0,0 +1,164 @@
#ifndef _INCLUDE_TRANSPORT_BASE_
#define _INCLUDE_TRANSPORT_BASE_
#include <stdint.h>
#include <chrono>
#include <memory>
#include <thread>
#include <functional>
#include "dataqueue.hpp"
#include "logging/logger.hpp"
#include "protocol.hpp"
#define TRANSPORT_MAX_RETRY 5
#define TRANSPORT_TIMEOUT 1000
namespace transport
{
template <typename P>
class BaseTransport;
class _transport_base {
public:
_transport_base() : is_open(false), is_closed(false) {}
_transport_base(const _transport_base&) = delete;
virtual ~_transport_base() {
close();
}
virtual void open() {
if (is_open)
return;
is_open = true;
std::thread([this] {
try {
send_backend();
} catch (const QueueCleared&) {}
}).detach();
std::thread([this] {
try {
receive_backend();
} catch (const QueueCleared&) {}
}).detach();
}
virtual void close() {
is_closed = true;
}
bool closed() const {
return is_closed;
}
protected:
void ensure_open() {
auto& logger = *logging::get_logger("transport");
if (is_closed)
{
logger.fatal("transport closed");
throw std::runtime_error("transport closed");
}
if (!is_open)
{
open();
}
}
virtual void send_backend() = 0;
virtual void receive_backend() = 0;
bool is_open;
bool is_closed;
};
class TransportToken
{
public:
explicit TransportToken(_transport_base *transport) : transport_(transport) {}
template <typename P>
BaseTransport<P> *transport() const {
return dynamic_cast<BaseTransport<P>*>(transport_);
}
virtual bool operator==(const TransportToken &other) const {
return transport_ == other.transport_;
}
protected:
_transport_base *transport_;
friend class _transport_base;
friend std::hash<TransportToken>;
};
template <typename P>
class BaseTransport : public _transport_base
{
public:
typedef P Protocol;
typedef typename P::FrameType FrameType;
~BaseTransport() override
{
close();
}
template <typename Rep = uint64_t, typename Period = std::milli>
inline void send(typename P::FrameType frame, std::shared_ptr<TransportToken> token = nullptr)
{
ensure_open();
send_que.Push(std::make_pair(frame, token));
}
template <typename Rep = uint64_t, typename Period = std::milli>
inline std::pair<typename P::FrameType, std::shared_ptr<TransportToken>> receive(std::chrono::duration<Rep, Period> dur = std::chrono::milliseconds(0))
{
ensure_open();
DataPair frame_pair;
if (!dur.count())
frame_pair = recv_que.Pop();
else
frame_pair = recv_que.Pop(dur);
return frame_pair;
}
template <typename Rep = uint64_t, typename Period = std::milli>
inline typename P::FrameType request(typename P::FrameType frame, int max_retry = TRANSPORT_MAX_RETRY, std::chrono::duration<Rep, Period> dur = std::chrono::milliseconds(TRANSPORT_TIMEOUT))
{
ensure_open();
while (max_retry--)
{
send_que.Push(std::make_pair(frame, nullptr));
DataPair frame_pair = recv_que.Pop(dur);
if (frame_pair.first) {
return frame_pair.first;
}
auto& logger = *logging::get_logger("transport");
logger.warn("request timeout, retrying...");
}
return nullptr;
}
void close() override {
is_closed = true;
recv_que.Clear();
send_que.Clear();
}
typedef std::pair<typename P::FrameType, std::shared_ptr<TransportToken>> DataPair;
protected:
DataQueue<DataPair> send_que;
DataQueue<DataPair> recv_que;
};
}
namespace std {
template <>
struct hash<transport::TransportToken>
{
size_t operator()(const transport::TransportToken &token) const
{
return std::hash<uintptr_t>()(reinterpret_cast<uintptr_t>(token.transport_));
}
};
}
#endif

View File

@ -0,0 +1,35 @@
#ifndef _INCLUDE_TRANSPORT_PROTOCOL_
#define _INCLUDE_TRANSPORT_PROTOCOL_
#include <stdint.h>
#include <vector>
namespace transport
{
class Protocol {
public:
typedef std::vector<uint8_t> FrameType;
static ssize_t pred_size(void* buf, size_t size) {
if (buf == nullptr) return 1;
return size;
}
static FrameType make_frame(void* buf, size_t size) {
if (buf == nullptr) return FrameType();
return FrameType((uint8_t*)buf, (uint8_t*)buf + size);
}
static size_t frame_size(FrameType frame) {
return frame.size();
}
static void* frame_data(FrameType frame) {
return frame.data();
}
};
}
#endif

View File

@ -0,0 +1,306 @@
#ifndef _INCLUDE_TRANSPORT_TTY_
#define _INCLUDE_TRANSPORT_TTY_
#include <fcntl.h>
#include <termios.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <assert.h>
#include "base.hpp"
#define TRANSPORT_SERIAL_PORT_BUFFER_SIZE 1024 * 1024
// #define TRANSPORT_SERIAL_PORT_DEBUG
#ifdef COM_FRAME_DEBUG
#define TRANSPORT_SERIAL_PORT_DEBUG
#endif
namespace transport
{
template <typename P>
class SerialPortTransport : public BaseTransport<P>
{
public:
SerialPortTransport(const std::string &path, int baudrate = 115200, size_t buffer_size = TRANSPORT_SERIAL_PORT_BUFFER_SIZE)
: path(path), tty_id(-1), baudrate(baudrate), buffer_size(buffer_size) {}
SerialPortTransport(int tty_id, int baudrate = 115200, size_t buffer_size = 1024)
: tty_id(tty_id), baudrate(baudrate), buffer_size(buffer_size) {}
~SerialPortTransport() override
{
close();
}
void open() override
{
auto &logger = *logging::get_logger("transport");
if (this->is_open)
{
return;
}
else if (this->is_closed)
{
logger.info("reopen serial port transport");
this->is_open = false;
this->is_closed = false;
}
if (tty_id < 0)
{
logger.info("open serial port %s", path.c_str());
tty_id = ::open(path.c_str(), O_RDWR | O_NOCTTY | O_NONBLOCK);
if (tty_id < 0)
{
logger.raise_from_errno("open serial port failed");
}
}
else
{
logger.info("use serial port %d", tty_id);
}
struct termios options, old;
// deploy usart par
memset(&options, 0, sizeof(options));
int ret = tcgetattr(tty_id, &old);
if (ret != 0)
{
logger.error("tcgetattr failed: %s", strerror(errno));
goto end;
}
tcflush(tty_id, TCIOFLUSH);
switch (baudrate)
{
case 9600:
cfsetispeed(&options, B9600);
cfsetospeed(&options, B9600);
break;
case 19200:
cfsetispeed(&options, B19200);
cfsetospeed(&options, B19200);
break;
case 38400:
cfsetispeed(&options, B38400);
cfsetospeed(&options, B38400);
break;
case 57600:
cfsetispeed(&options, B57600);
cfsetospeed(&options, B57600);
break;
case 115200:
cfsetispeed(&options, B115200);
cfsetospeed(&options, B115200);
break;
case 576000:
cfsetispeed(&options, B576000);
cfsetospeed(&options, B576000);
break;
case 921600:
cfsetispeed(&options, B921600);
cfsetospeed(&options, B921600);
break;
case 2000000:
cfsetispeed(&options, B2000000);
cfsetospeed(&options, B2000000);
break;
case 3000000:
cfsetispeed(&options, B3000000);
cfsetospeed(&options, B3000000);
break;
default:
logger.error("bad baud rate %u", baudrate);
break;
}
switch (1)
{
case 0:
options.c_cflag &= ~PARENB;
options.c_cflag &= ~INPCK;
break;
case 1:
options.c_cflag |= (PARODD // 使用奇校验代替偶校验
| PARENB); // 校验位有效
options.c_iflag |= INPCK; // 校验有效
break;
case 2:
options.c_cflag |= PARENB;
options.c_cflag &= ~PARODD;
options.c_iflag |= INPCK;
break;
case 3:
options.c_cflag &= ~PARENB;
options.c_cflag &= ~CSTOPB;
break;
default:
options.c_cflag &= ~PARENB;
break;
}
options.c_cflag |= (CLOCAL | CREAD);
options.c_cflag &= ~CSIZE;
options.c_cflag &= ~CRTSCTS;
options.c_cflag |= CS8;
options.c_cflag &= ~CSTOPB;
options.c_oflag = 0;
options.c_lflag = 0;
options.c_cc[VTIME] = 0;
options.c_cc[VMIN] = 0;
// 启用输出的XON/XOFF控制字符
// Enable software flow control (XON/XOFF) for both input and output
options.c_iflag |= (IXON | IXOFF); // Enable input and output XON/XOFF control characters
options.c_oflag |= (IXON | IXOFF); // Enable input and output XON/XOFF control characters
tcflush(tty_id, TCIFLUSH);
if ((tcsetattr(tty_id, TCSANOW, &options)) != 0)
{
logger.error("tcsetattr failed: %s", strerror(errno));
}
end:
super::open();
}
void close() override
{
if (this->is_open && !this->closed())
{
auto &logger = *logging::get_logger("transport");
logger.info("close serial port %s", path.c_str());
::close(tty_id);
}
super::close();
}
protected:
void send_backend() override
{
// this->ensure_open();
auto &logger = *logging::get_logger("transport");
logger.debug("start serial port send backend");
while (!this->is_closed)
{
auto frame_pair = this->send_que.Pop();
auto frame = frame_pair.first;
if (frame_pair.second && frame_pair.second->template transport<P>() != this)
{
logger.error("invalid token received");
continue;
}
size_t remaining_size = P::frame_size(frame);
if (remaining_size == 0) {
continue;
}
logger.debug("send data %zu", remaining_size);
size_t offset = 0;
while (remaining_size > 0)
{
auto written_size = write(tty_id, static_cast<uint8_t*>(P::frame_data(frame)) + offset, remaining_size);
if (written_size < 0)
{
logger.error("write serial port failed: %s", strerror(errno));
}
else
{
remaining_size -= written_size;
offset += written_size;
}
}
}
}
void receive_backend() override
{
// this->ensure_open();
auto &logger = *logging::get_logger("transport");
logger.debug("start serial port receive backend");
bool find_head = false;
size_t min_size = P::pred_size(nullptr, 0);
if (!min_size) min_size = 1;
ssize_t recv_size;
size_t pred_size = min_size * 2; // min buffer size to keep,
size_t offset = 0; // scanned data size
size_t cached_size = 0; // read data size
assert(buffer_size >= pred_size);
uint8_t *buffer = new uint8_t[buffer_size];
while (!this->is_closed)
{
recv_size = read(tty_id, ((uint8_t *)buffer) + cached_size, buffer_size - cached_size);
if (recv_size <= 0)
{
if (cached_size == offset) {
usleep(10);
continue;
}
recv_size = 0;
}
#ifdef TRANSPORT_SERIAL_PORT_DEBUG
printf("receive com data (received=%zd,cached=%zu)\nbuffer: ",
recv_size, cached_size);
for (size_t i = 0; i < cached_size + recv_size; ++i)
{
printf("%02x", ((uint8_t *)buffer)[i]);
}
putchar('\n');
#endif
cached_size += recv_size;
if (!find_head)
{
// update offset to scan the header
for (; offset + min_size < cached_size; ++offset)
{
ssize_t pred = P::pred_size(((uint8_t *)buffer) + offset, cached_size - offset);
find_head = pred > 0;
if (find_head)
{
pred_size = pred;
logger.debug("find valid data (length=%zu)", pred_size);
if (pred_size > buffer_size)
{
logger.error("data size is too large (%zu)\n", pred_size);
find_head = false;
pred_size = min_size * 2;
continue;
}
break;
}
}
}
if (find_head && cached_size >= pred_size + offset)
{
// all data received
auto frame = P::make_frame((uint8_t *)buffer + offset, pred_size);
logger.debug("receive data %zu", pred_size);
this->recv_que.Push(std::make_pair(std::move(frame), std::make_shared<TransportToken>(this)));
offset += pred_size; // update offset for next run
}
// clear the cache when the remaining length of the cache is
if (offset && buffer_size - cached_size < pred_size)
{
cached_size -= offset;
memmove(buffer, ((uint8_t *)buffer) + offset, cached_size);
offset = 0;
}
}
delete[] buffer;
}
private:
typedef BaseTransport<P> super;
std::string path;
int tty_id;
int baudrate;
size_t buffer_size;
};
}
#endif

247
include/transport/udp.hpp Normal file
View File

@ -0,0 +1,247 @@
#ifndef _INCLUDE_TRANSPORT_UDP_
#define _INCLUDE_TRANSPORT_UDP_
#include <memory>
#include <string>
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <sys/socket.h>
#include "base.hpp"
#define TRANSPORT_UDP_BUFFER_SIZE 1024 * 64
namespace transport {
class DatagramTransportToken : public TransportToken {
public:
explicit DatagramTransportToken(_transport_base* transport, const struct sockaddr_in& addr, socklen_t addr_len)
: TransportToken(transport), addr(addr), addr_len(addr_len) {}
bool operator==(const TransportToken& other) const override
{
auto other_token = dynamic_cast<const DatagramTransportToken*>(&other);
if (!other_token)
{
return false;
}
if (transport_ != other_token->transport_ || addr_len != other_token->addr_len)
{
return false;
}
return memcmp(&addr, &other_token->addr, addr_len) == 0;
}
protected:
struct sockaddr_in addr;
socklen_t addr_len;
friend std::hash<DatagramTransportToken>;
template<typename P>
friend class DatagramTransport;
};
template<typename P>
class DatagramTransport : public BaseTransport<P> {
public:
explicit DatagramTransport(size_t buffer_size = TRANSPORT_UDP_BUFFER_SIZE)
: sockfd(-1), buffer_size(buffer_size)
{
memset(&bind_addr, 0, sizeof(bind_addr));
memset(&connect_addr, 0, sizeof(connect_addr));
bind_addr.sin_family = AF_INET;
connect_addr.sin_family = AF_INET;
}
DatagramTransport(std::pair<const char*, int> local_addr, std::pair<const char*, int> remote_addr, size_t buffer_size = TRANSPORT_UDP_BUFFER_SIZE)
: DatagramTransport(buffer_size)
{
bind_addr.sin_port = htons(local_addr.second);
resolve_hostname(local_addr.first, bind_addr);
connect_addr.sin_port = htons(remote_addr.second);
resolve_hostname(remote_addr.first, connect_addr);
}
~DatagramTransport()
{
close();
}
void open() override
{
auto &logger = *logging::get_logger("transport");
if (this->is_open)
{
return;
} else if (this->is_closed)
{
logger.info("reopen datagram transport");
this->is_open = false;
this->is_closed = false;
}
sockfd = socket(AF_INET, SOCK_DGRAM, 0);
if (sockfd < 0) {
logger.raise_from_errno("failed to create socket");
} else {
logger.info("open socket fd %d", sockfd);
}
super::open();
if (bind_addr.sin_port)
{
if (::bind(sockfd, (struct sockaddr *)&bind_addr, sizeof(bind_addr)) < 0)
{
logger.raise_from_errno("failed to bind socket");
}
logger.info("listening on %s:%d", inet_ntoa(bind_addr.sin_addr), ntohs(bind_addr.sin_port));
}
}
void close() override
{
if (this->is_open && !this->closed())
{
auto &logger = *logging::get_logger("transport");
logger.info("close socket fd %d", sockfd);
::close(sockfd);
}
super::close();
}
void bind(const std::string& address, int port)
{
this->ensure_open();
auto &logger = *logging::get_logger("transport");
resolve_hostname(address, bind_addr);
bind_addr.sin_port = htons(port);
if (::bind(sockfd, (struct sockaddr *)&bind_addr, sizeof(bind_addr)) < 0)
{
logger.raise_from_errno("failed to bind socket");
}
logger[logging::LogLevel::INFO] << "listening on " << address << ":" << port << std::endl;
}
void connect(const std::string& address, int port)
{
auto &logger = *logging::get_logger("transport");
resolve_hostname(address, connect_addr);
connect_addr.sin_port = htons(port);
logger[logging::LogLevel::INFO] << "connecting to " << address << ":" << port << std::endl;
}
constexpr static std::pair<const char*, int> nulladdr = {"", 0};
protected:
void send_backend() override
{
// this->ensure_open();
auto &logger = *logging::get_logger("transport");
logger.debug("start datagram send backend");
while (!this->is_closed)
{
auto frame_pair = this->send_que.Pop();
auto frame = frame_pair.first;
if (!P::frame_size(frame))
continue;
auto token = dynamic_cast<DatagramTransportToken *>((frame_pair.second.get()));
struct sockaddr* addr = (struct sockaddr *)((token) ? &token->addr : &connect_addr);
socklen_t addr_len = (token) ? token->addr_len : sizeof(connect_addr);
ssize_t sent_size = sendto(sockfd, P::frame_data(frame), P::frame_size(frame), 0,
addr, addr_len);
logger.debug("send data %zd", sent_size);
if (sent_size < 0)
{
logger.error("udp send failed: %s", strerror(errno));
}
else if ((size_t)sent_size < P::frame_size(frame))
{
logger.warn("sendto failed, only %zd bytes sent", sent_size);
}
}
}
void receive_backend() override
{
// this->ensure_open();
auto &logger = *logging::get_logger("transport");
logger.debug("start datagram receive backend");
uint8_t *buffer = new uint8_t[buffer_size];
while (!this->is_closed)
{
struct sockaddr_in addr;
socklen_t addr_len = sizeof(addr);
ssize_t recv_size = recvfrom(sockfd, buffer, buffer_size, 0,
(struct sockaddr *)&addr, &addr_len);
if (recv_size < 0)
{
logger.error("udp recv failed: %s", strerror(errno));
continue;
}
logger.debug("receive data %zd", recv_size);
ssize_t pred_size = P::pred_size(buffer, recv_size);
if (pred_size < 0)
{
logger.error("invalid frame received");
continue;
}
auto frame = P::make_frame(buffer, recv_size);
this->recv_que.Push(std::make_pair(frame, std::make_shared<DatagramTransportToken>(this, addr, addr_len)));
}
delete[] buffer;
}
private:
static void resolve_hostname(const std::string& hostname, struct sockaddr_in& result)
{
if (hostname.empty())
{
result.sin_addr.s_addr = INADDR_ANY;
return;
}
std::string hostname_str(hostname);
struct hostent *he = gethostbyname(hostname_str.c_str());
if (he == nullptr)
{
auto &logger = *logging::get_logger("transport");
logger.error("failed to resolve hostname %s: %s", hostname_str.c_str(), hstrerror(h_errno));
throw std::runtime_error("Failed to resolve hostname");
}
memcpy(&result.sin_addr, he->h_addr_list[0], he->h_length);
}
typedef BaseTransport<P> super;
int sockfd;
struct sockaddr_in bind_addr;
struct sockaddr_in connect_addr;
size_t buffer_size;
};
}
namespace std {
template<>
struct hash<transport::DatagramTransportToken> {
size_t operator()(const transport::DatagramTransportToken &token) const
{
std::size_t hash1 = std::hash<transport::TransportToken>()(token);
std::size_t hash2 = std::hash<in_addr_t>()(token.addr.sin_addr.s_addr);
std::size_t hash3 = std::hash<in_port_t>()(token.addr.sin_port);
hash1 ^= (hash2 + 0x9e3779b9 + (hash1 << 6) + (hash1 >> 2));
hash1 ^= (hash3 + 0x9e3779b9 + (hash1 << 6) + (hash1 >> 2));
return hash1;
}
};
}
#endif

View File

@ -0,0 +1,230 @@
#ifndef _INCLUDE_TRANSPORT_UNIX_UDP_
#define _INCLUDE_TRANSPORT_UNIX_UDP_
#include <memory>
#include <string>
#include <unistd.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <netdb.h>
#include <sys/socket.h>
#include "base.hpp"
#define TRANSPORT_UDP_BUFFER_SIZE 1024
namespace transport {
class UnixDatagramTransportToken : public TransportToken {
public:
explicit UnixDatagramTransportToken(_transport_base* transport, const struct sockaddr_un& addr, socklen_t addr_len)
: TransportToken(transport), addr(addr), addr_len(addr_len) {}
bool operator==(const TransportToken& other) const override
{
auto other_token = dynamic_cast<const UnixDatagramTransportToken*>(&other);
if (!other_token)
{
return false;
}
if (transport_ != other_token->transport_ || addr_len != other_token->addr_len)
{
return false;
}
return memcmp(&addr, &other_token->addr, addr_len) == 0;
}
protected:
struct sockaddr_un addr;
socklen_t addr_len;
friend std::hash<UnixDatagramTransportToken>;
template <typename P>
friend class UnixDatagramTransport;
};
template <typename P>
class UnixDatagramTransport : public BaseTransport<P> {
public:
explicit UnixDatagramTransport(size_t buffer_size = TRANSPORT_UDP_BUFFER_SIZE)
: sockfd(-1), buffer_size(buffer_size)
{
memset(&bind_addr, 0, sizeof(bind_addr));
memset(&connect_addr, 0, sizeof(connect_addr));
bind_addr.sun_family = AF_UNIX;
connect_addr.sun_family = AF_UNIX;
}
UnixDatagramTransport(const std::string& local_addr, const std::string& remote_addr, size_t buffer_size = TRANSPORT_UDP_BUFFER_SIZE)
: UnixDatagramTransport(buffer_size)
{
set_sock_path(local_addr, bind_addr);
set_sock_path(remote_addr, connect_addr);
}
~UnixDatagramTransport()
{
close();
}
void open() override
{
auto& logger = *logging::get_logger("transport");
if (this->is_open)
{
return;
}
else if (this->is_closed)
{
logger.info("reopen datagram transport");
this->is_open = false;
this->is_closed = false;
}
sockfd = socket(AF_UNIX, SOCK_DGRAM, 0);
if (sockfd < 0) {
logger.raise_from_errno("failed to create socket");
} else {
logger.info("open socket fd %d", sockfd);
}
super::open();
if (*bind_addr.sun_path)
{
_bind();
}
}
void close() override
{
if (this->is_open && !this->closed())
{
auto &logger = *logging::get_logger("transport");
logger.info("close socket fd %d", sockfd);
::close(sockfd);
if (*bind_addr.sun_path)
unlink(bind_addr.sun_path);
}
super::close();
}
void bind(const std::string& address)
{
this->ensure_open();
set_sock_path(address, bind_addr);
_bind();
}
void connect(const std::string& address)
{
set_sock_path(address, connect_addr);
auto& logger = *logging::get_logger("transport");
logger[logging::LogLevel::INFO] << "connecting to " << address << std::endl;
}
protected:
void send_backend() override
{
// this->ensure_open();
auto& logger = *logging::get_logger("transport");
logger.debug("start datagram send backend");
while (!this->is_closed)
{
auto frame_pair = this->send_que.Pop();
auto frame = frame_pair.first;
if (!P::frame_size(frame))
continue;
auto token = dynamic_cast<UnixDatagramTransportToken *>((frame_pair.second.get()));
struct sockaddr* addr = (struct sockaddr *)((token) ? &token->addr : &connect_addr);
socklen_t addr_len = (token) ? token->addr_len : sizeof(connect_addr);
ssize_t sent_size = sendto(sockfd, P::frame_data(frame), P::frame_size(frame), 0,
addr, addr_len);
logger.debug("send data %zd", sent_size);
if (sent_size < 0)
{
logger.error("unix udp send failed: %s", strerror(errno));
}
else if ((size_t)sent_size < P::frame_size(frame))
{
logger.warn("sendto failed, only %zd bytes sent", sent_size);
}
}
}
void receive_backend() override
{
// this->ensure_open();
auto& logger = *logging::get_logger("transport");
logger.debug("start datagram receive backend");
uint8_t *buffer = new uint8_t[buffer_size];
while (!this->is_closed)
{
struct sockaddr_un addr;
socklen_t addr_len = sizeof(addr);
ssize_t recv_size = recvfrom(sockfd, buffer, buffer_size, 0,
(struct sockaddr *)&addr, &addr_len);
if (recv_size < 0)
{
logger.error("unix udp recv failed: %s", strerror(errno));
continue;
}
logger.debug("receive data %zd", recv_size);
ssize_t pred_size = P::pred_size(buffer, recv_size);
if (pred_size < 0)
{
logger.error("invalid frame received");
continue;
}
auto frame = P::make_frame(buffer, recv_size);
this->recv_que.Push(std::make_pair(frame, std::make_shared<UnixDatagramTransportToken>(this, addr, addr_len)));
}
delete[] buffer;
}
private:
static void set_sock_path(const std::string& path, struct sockaddr_un& result)
{
if (path.size() + 1 >= sizeof(result.sun_path))
{
auto& logger = *logging::get_logger("transport");
logger.fatal("socket path too long");
throw std::runtime_error("socket path too long");
}
strncpy(result.sun_path, path.data(), path.size());
result.sun_path[path.size()] = '\0';
}
void _bind()
{
auto& logger = *logging::get_logger("transport");
unlink(bind_addr.sun_path);
if (::bind(sockfd, (struct sockaddr *)&bind_addr, sizeof(bind_addr)) < 0)
{
logger.raise_from_errno("failed to bind socket");
}
logger.info("listening on %s", bind_addr.sun_path);
}
typedef BaseTransport<P> super;
int sockfd;
struct sockaddr_un bind_addr;
struct sockaddr_un connect_addr;
size_t buffer_size;
};
}
namespace std {
template<>
struct hash<transport::UnixDatagramTransportToken> {
size_t operator()(const transport::UnixDatagramTransportToken &token) const
{
std::size_t hash1 = std::hash<transport::TransportToken>()(token);
std::size_t hash2 = std::hash<std::string>()(token.addr.sun_path);
hash1 ^= (hash2 + 0x9e3779b9 + (hash1 << 6) + (hash1 >> 2));
return hash1;
}
};
}
#endif

415
scripts/unittest.py Executable file
View File

@ -0,0 +1,415 @@
#!/usr/bin/python3
import os
import sys
from io import StringIO
from argparse import ArgumentParser
from enum import IntEnum
from glob import glob
from subprocess import run
from time import time
from traceback import print_exception
from typing import List, NoReturn, Optional
__version__ = "0.1.2"
if sys.stdout.isatty():
TTY_COLOR_RED = "\033[31m"
TTY_COLOR_RED_BOLD = "\033[1;31m"
TTY_COLOR_GREEN = "\033[32m"
TTY_COLOR_GREEN_BOLD = "\033[1;32m"
TTY_COLOR_YELLOW = "\033[33m"
TTY_COLOR_YELLOW_BOLD = "\033[1;33m"
TTY_COLOR_CYAN_BOLD = "\033[1;36m"
TTY_COLOR_CLEAR = "\033[0m"
TTY_COLUMNS_SIZE = os.get_terminal_size().columns
else:
TTY_COLOR_RED = ""
TTY_COLOR_RED_BOLD = ""
TTY_COLOR_GREEN = ""
TTY_COLOR_GREEN_BOLD = ""
TTY_COLOR_YELLOW = ""
TTY_COLOR_YELLOW_BOLD = ""
TTY_COLOR_CYAN_BOLD = ""
TTY_COLOR_CLEAR = ""
TTY_COLUMNS_SIZE = 80
def print_separator_ex(lc: str, title: str, color: str) -> None:
len_str = len(title)
print(f"{TTY_COLOR_CLEAR}{color}", end='') # 设置颜色样式
if len_str + 2 > TTY_COLUMNS_SIZE:
print(title)
else:
print(f" {title} ".center(TTY_COLUMNS_SIZE, lc))
print(TTY_COLOR_CLEAR, end='') # 重置颜色为默认
class CTestCaseStatus(IntEnum):
NOT_RUN = -1
PASSED = 0
ERROR = 1
FAILED = 16
SKIPPED = 32
SETUP_FAILED = 17
TEARDOWN_FAILED = 18
class CTestCaseCounter:
__slots__ = ["total_count", "passed", "error", "failed", "skipped"]
def __init__(self) -> None:
self.total_count = 0
self.passed: 'set[CTestCase]' = set()
self.error: 'set[CTestCase]' = set()
self.failed: 'set[CTestCase]' = set()
self.skipped: 'set[CTestCase]' = set()
def update(self, test_case: "CTestCase") -> None:
self.total_count += 1
if test_case.status == CTestCaseStatus.PASSED:
self.passed.add(test_case)
elif test_case.status == CTestCaseStatus.ERROR:
self.error.add(test_case)
elif test_case.status == CTestCaseStatus.SKIPPED:
self.skipped.add(test_case)
elif test_case.status in [CTestCaseStatus.FAILED, CTestCaseStatus.SETUP_FAILED, CTestCaseStatus.TEARDOWN_FAILED]:
self.failed.add(test_case)
else:
raise ValueError(f"{test_case.status} is not a valid status for counter")
def clone(self) -> "CTestCaseCounter":
counter = CTestCaseCounter()
counter.total_count = self.total_count
counter.passed = self.passed
counter.error = self.error
counter.failed = self.failed
counter.skipped = self.skipped
return counter
def __add__(self, other: "CTestCaseCounter") -> "CTestCaseCounter":
counter = self.clone()
counter += other
return counter
def __iadd__(self, other: "CTestCaseCounter") -> "CTestCaseCounter":
self.total_count += other.total_count
self.passed.update(other.passed)
self.error.update(other.error)
self.failed.update(other.failed)
self.skipped.update(other.skipped)
return self
@property
def status(self) -> CTestCaseStatus:
if self.error:
return CTestCaseStatus.ERROR
elif self.failed:
return CTestCaseStatus.FAILED
elif self.skipped:
return CTestCaseStatus.SKIPPED
elif self.passed:
return CTestCaseStatus.PASSED
else:
return CTestCaseStatus.NOT_RUN
def __repr__(self) -> str:
return f"<counter total: {self.total_count}, passed: {len(self.passed)}, error: {len(self.error)}, failed: {len(self.failed)}, skipped: {len(self.skipped)}>"
def __str__(self) -> str:
ss = StringIO()
ss.write(f"total: {self.total_count}")
if self.passed:
ss.write(f", passed: {len(self.passed)}")
elif self.skipped:
ss.write(f", skipped: {len(self.skipped)}")
elif self.failed:
ss.write(f", failed: {len(self.failed)}")
elif self.error:
ss.write(f", error: {len(self.error)}")
return ss.getvalue()
class CTestCase:
__slots__ = ["id", "name", "status", "file", "result", "error_info"]
def __init__(self, id: int, name: str, file: "CTestCaseFile") -> None:
self.id = id
self.name = name
self.file = file
self.status = CTestCaseStatus.NOT_RUN
self.result = None
self.error_info = None
def run(self, *, verbose: bool = False, capture: bool = True) -> CTestCaseStatus:
try:
sys.stdout.flush()
sys.stderr.flush()
self.result = run([self.file.path, "--unittest", str(self.id)], capture_output=capture)
except Exception:
self.status = CTestCaseStatus.ERROR
self.error_info = sys.exc_info()
if not capture:
print(TTY_COLOR_RED)
print_exception(*self.error_info)
print(TTY_COLOR_CLEAR)
except KeyboardInterrupt:
self.status = CTestCaseStatus.ERROR
self.error_info = sys.exc_info()
else:
code = self.result.returncode
if code in CTestCaseStatus.__members__.values():
self.status = CTestCaseStatus(code)
else:
self.status = CTestCaseStatus.ERROR
if verbose:
self.print_status_verbose()
else:
self.print_status()
return self.status
def print_status(self) -> None:
if self.status == CTestCaseStatus.PASSED:
print(f"{TTY_COLOR_GREEN}.{TTY_COLOR_CLEAR}", end='')
elif self.status in [CTestCaseStatus.FAILED, CTestCaseStatus.SETUP_FAILED, CTestCaseStatus.TEARDOWN_FAILED]:
print(f"{TTY_COLOR_RED}F{TTY_COLOR_CLEAR}", end='')
elif self.status == CTestCaseStatus.ERROR:
print(f"{TTY_COLOR_RED}E{TTY_COLOR_CLEAR}", end='')
elif self.status == CTestCaseStatus.SKIPPED:
print(f"{TTY_COLOR_YELLOW}s{TTY_COLOR_CLEAR}", end='')
else:
raise ValueError(f"invalid test case status: {self.status}")
def print_status_verbose(self) -> None:
if self.status == CTestCaseStatus.PASSED:
print(f"{self.name} {TTY_COLOR_GREEN}PASSED{TTY_COLOR_CLEAR}")
elif self.status == CTestCaseStatus.FAILED:
print(f"{self.name} {TTY_COLOR_RED}FAILED{TTY_COLOR_CLEAR}")
elif self.status == CTestCaseStatus.ERROR:
print(f"{self.name} {TTY_COLOR_RED}ERROR{TTY_COLOR_CLEAR}")
elif self.status == CTestCaseStatus.SETUP_FAILED:
print(f"{self.name} {TTY_COLOR_RED}SETUP FAILED{TTY_COLOR_CLEAR}")
elif self.status == CTestCaseStatus.TEARDOWN_FAILED:
print(f"{self.name} {TTY_COLOR_RED}TEARDOWN FAILED{TTY_COLOR_CLEAR}")
elif self.status == CTestCaseStatus.SKIPPED:
print(f"{self.name} {TTY_COLOR_YELLOW}SKIPPED{TTY_COLOR_CLEAR}")
else:
raise ValueError(f"invalid test case status: {self.status}")
def report(self) -> Optional[str]:
if self.status == CTestCaseStatus.PASSED:
return
elif self.status == CTestCaseStatus.SKIPPED:
return f"{TTY_COLOR_YELLOW}SKIPPED{TTY_COLOR_CLEAR} {self.name}"
elif self.status == CTestCaseStatus.ERROR:
if self.error_info:
print_separator_ex('_', self.name, TTY_COLOR_RED)
print(TTY_COLOR_RED, end='')
print_exception(*self.error_info)
print(TTY_COLOR_CLEAR, end='')
assert self.error_info[0]
if str(self.error_info[1]):
error = f"{self.error_info[0].__name__}: {self.error_info[1]}"
else:
error = f"{self.error_info[0].__name__}"
return f"{TTY_COLOR_RED}ERROR{TTY_COLOR_CLEAR} {self.name} - {error}"
else:
assert self.result
if self.result.stderr:
print_separator_ex('_', self.name, TTY_COLOR_RED)
print(TTY_COLOR_RED, end='')
print(self.result.stderr.decode("utf-8"), end='')
print(TTY_COLOR_CLEAR, end='')
if self.result.stdout:
print_separator_ex('-', "Captured stdout", '')
print(self.result.stdout.decode("utf-8"), end='')
return f"{TTY_COLOR_RED}ERROR{TTY_COLOR_CLEAR} {self.name} - RuntimeError ({self.result.returncode})"
elif self.status in [CTestCaseStatus.FAILED, CTestCaseStatus.SETUP_FAILED, CTestCaseStatus.TEARDOWN_FAILED]:
assert self.result
if self.result.stderr:
print_separator_ex('_', self.name, TTY_COLOR_RED)
print(TTY_COLOR_RED, end='')
print(self.result.stderr.decode("utf-8"), end='')
print(TTY_COLOR_CLEAR, end='')
if self.result.stdout:
if self.result.stderr:
print_separator_ex('-', "Captured stdout", '')
else:
print_separator_ex('_', self.name, '')
print(self.result.stdout.decode("utf-8", "replace"), end='')
if self.status == CTestCaseStatus.FAILED:
return f"{TTY_COLOR_RED}FAILED{TTY_COLOR_CLEAR} {self.name}"
elif self.status == CTestCaseStatus.SETUP_FAILED:
return f"{TTY_COLOR_RED}FAILED{TTY_COLOR_CLEAR} {self.name} - SetupError"
else:
return f"{TTY_COLOR_RED}FAILED{TTY_COLOR_CLEAR} {self.name} - TeardownError"
else:
raise ValueError(f"invalid test case status: {self.status}")
class CTestCaseFile:
__slots__ = ["path", "test_cases", "collect_result", "collect_error_info", "counter"]
def __init__(self, path: str) -> None:
self.path = path
self.test_cases: List[CTestCase] = []
self.collect_result = None
self.collect_error_info = None
def collect(self) -> int:
try:
result = run([self.path, "--collect"], capture_output=True)
except Exception:
self.collect_error_info = sys.exc_info()
return 0
if result.returncode != 0:
self.collect_result = result
return 0
for id, name in enumerate(result.stdout.decode("ascii").split()):
self.test_cases.append(CTestCase(id, name, self))
return len(self.test_cases)
def run(self, verbose: bool = False, capture: bool = True) -> CTestCaseCounter:
counter = CTestCaseCounter()
if not verbose:
print(self.path, end=' ')
for i in self.test_cases:
if verbose:
print(self.path, end='::')
i.run(verbose=verbose, capture=capture)
counter.update(i)
if not verbose:
print()
return counter
def report(self) -> List[str]:
if self.collect_result is not None:
print_separator_ex('_', f"ERROR collecting {self.path}", TTY_COLOR_RED_BOLD)
print(TTY_COLOR_RED, end='')
print(self.collect_result.stderr.decode())
print(TTY_COLOR_CLEAR, end='')
if self.collect_result.stdout:
print_separator_ex('-', "Captured stdout", '')
print(self.collect_result.stdout.decode())
return [f"ERROR {self.path} - CollectError ({self.collect_result.returncode})"]
elif self.collect_error_info is not None:
assert self.collect_error_info[0]
print_separator_ex('_', f"ERROR collecting {self.path}", TTY_COLOR_RED_BOLD)
print(TTY_COLOR_RED, end='')
print_exception(*self.collect_error_info)
print(TTY_COLOR_CLEAR, end='')
return [f"ERROR{TTY_COLOR_CLEAR} {self.path} - {self.collect_error_info[0].__name__}: {self.collect_error_info[1]}"]
return list(filter(None, (i.report() for i in self.test_cases)))
@property
def error(self) -> bool:
return self.collect_result is not None or self.collect_error_info is not None
def report_collect_error(start_time: float, *error_files: CTestCaseFile) -> NoReturn:
print_separator_ex('=', "ERRORS", '')
summary = []
for i in error_files:
summary.extend(i.report())
print_separator_ex('=', "short test summary info", TTY_COLOR_CYAN_BOLD)
for i in summary:
print(i)
print_separator_ex('!', f"Interrupted: {len(error_files)} error during collection", '')
cur_time = time()
print_separator_ex('=', f"{len(summary)} error in {cur_time - start_time:.2f}s", TTY_COLOR_RED_BOLD)
sys.exit(1)
def report_no_ran(start_time: float) -> NoReturn:
cur_time = time()
print_separator_ex('=', f"no tests ran in {cur_time - start_time:.2f}s", TTY_COLOR_YELLOW)
sys.exit()
def report(start_time: float, counter: CTestCaseCounter, *, show_capture: bool = True) -> NoReturn:
cur_time = time()
summary = []
if counter.error:
if show_capture:
print_separator_ex('=', "ERRORS", '')
for i in counter.error:
summary.append(i.report())
if counter.failed:
if show_capture:
print_separator_ex('=', "FAILURES", TTY_COLOR_RED)
for i in counter.failed:
summary.append(i.report())
if counter.skipped:
for i in counter.skipped:
summary.append(i.report())
if summary:
print_separator_ex('=', "short test summary info", TTY_COLOR_CYAN_BOLD)
for i in summary:
print(i)
if counter.status in [CTestCaseStatus.FAILED, CTestCaseStatus.ERROR]:
color = TTY_COLOR_RED_BOLD
elif counter.status == CTestCaseStatus.SKIPPED:
color = TTY_COLOR_YELLOW_BOLD
else:
color = TTY_COLOR_GREEN_BOLD
print_separator_ex('=', f"{counter} in {cur_time - start_time:.2f}s", color)
if counter.status in [CTestCaseStatus.FAILED, CTestCaseStatus.ERROR]:
sys.exit(1)
sys.exit()
if __name__ == "__main__":
parser = ArgumentParser("unittest", description="Run unit tests")
parser.add_argument("path", nargs='+', help="path to the test directory or file", default="./test_*")
parser.add_argument("-V", "--version", action="version", version=__version__)
parser.add_argument("-v", "--verbose", action="store_true", help="verbose output")
parser.add_argument("-s", "--no-capture", action="store_false", help="capture stdout and stderr")
namespace = parser.parse_args()
print_separator_ex("=", "test session starts", "")
print(f"platform: {sys.platform} -- Python {sys.version.split(' ')[0]}, c_unittest {__version__}")
print(f"rootdir: {os.getcwd()}")
files: List[CTestCaseFile] = []
total = 0
error_files = []
start_time = time()
paths = []
for p in namespace.path:
if '*' in p:
paths.extend(glob(p, recursive=True))
elif os.path.isfile(p):
paths.append(p)
for p in paths:
f = CTestCaseFile(p)
total += f.collect()
if f.error:
error_files.append(f)
else:
files.append(f)
if error_files:
print(f"collected {total} items / {len(error_files)} error\n")
report_collect_error(start_time, *error_files)
else:
print(f"collected {total} items\n")
if total == 0:
report_no_ran(start_time)
counter = CTestCaseCounter()
for f in files:
counter += f.run(verbose=namespace.verbose, capture=namespace.no_capture)
report(start_time, counter, show_capture=namespace.no_capture)

16
scripts/unittest.sh Executable file
View File

@ -0,0 +1,16 @@
width=$(stty size | awk '{print $2}')
# run add test_* in bin/*
for i in `ls bin/test_*`
do
echo
printf '%*s\n' "$width" '' | tr ' ' '*'
echo run test file: $i
$i $@
if [ $? -ne 0 ]; then
echo "\033[0m\033[1;31mtest $i failed\033[0m"
fi
echo
done
# printf '%*s\n' "$width" '' | tr ' ' '*'

248
src/logging.cpp Normal file
View File

@ -0,0 +1,248 @@
#include <time.h>
#include <string.h>
#include <chrono>
#include <iomanip>
#include <ctime>
#include <memory>
#include <sstream>
#include "logging/interface.h"
#include "logging/logger.hpp"
using namespace logging;
std::unique_ptr<Logger>& logging::_get_global_logger()
{
static std::unique_ptr<Logger> global_logger(new Logger(LogLevel::INFO));
return global_logger;
}
Logger& logging::get_global_logger()
{
return *_get_global_logger();
}
void logging::set_global_logger(std::unique_ptr<Logger>&& logger)
{
auto& global_logger = _get_global_logger();
global_logger->move_children_to(*logger);
global_logger = std::move(logger);
}
Logger::Logger(const std::string& name, Level level, Logger* parent)
: _parent(parent), _level(level)
{
if (parent && !parent->_name.empty()) {
_name = parent->_name + namesep + name;
} else {
_name = name;
}
}
void Logger::vlog(Level level, const char* fmt, va_list args)
{
if (level < this->level())
return;
va_list args2;
va_copy(args2, args);
int size = vsnprintf(nullptr, 0, fmt, args2);
va_end(args2);
char* buffer = new char[size + 1];
vsnprintf(buffer, size + 1, fmt, args);
try {
log_message(level, buffer);
} catch (...) {
delete[] buffer;
throw;
}
delete[] buffer;
}
void Logger::log_message(Level level, const std::string& msg)
{
if (level < this->level())
return;
auto record = Record(name(), level, msg);
log_record(record);
}
void Logger::log_record(const Record& record)
{
if (record.level < level()) {
return;
}
if (!_streams.empty()) {
for (auto& stream : _streams) {
write_record(*stream, record);
}
} else if (_parent) {
_parent->log_record(record);
} else {
write_record(default_stream, record);
}
}
void Logger::write_record(std::ostream& os, const Record& record)
{
auto now = record.time;
// 转换为时间_t 类型
std::time_t now_c = std::chrono::system_clock::to_time_t(now);
// 获取毫秒部分
auto milliseconds = std::chrono::duration_cast<std::chrono::milliseconds>(now.time_since_epoch()) % 1000;
// 格式化时间
std::stringstream ss;
std::tm *tm_now = std::localtime(&now_c);
ss << std::put_time(tm_now, "%Y-%m-%d %H:%M:%S")
<< ',' << std::setfill('0') << std::setw(3) << milliseconds.count();
if (record.name.empty()) {
os << ss.str() << " [" << record.level << "] " << record.msg << std::endl;
} else {
os << ss.str() << " [" << record.name << "] [" << record.level << "] " << record.msg << std::endl;
}
}
LoggerOStream Logger::operator[](Logger::Level level)
{
return LoggerOStream(*this, level);
}
LoggerOStream Logger::operator[](int level)
{
return (*this)[static_cast<Logger::Level>(level)];
}
void Logger::move_children_to(Logger& target_logger)
{
if (logger_cache.empty()) return;
for (auto it = logger_cache.begin(); it != logger_cache.end(); ) {
// Transfer ownership of the child logger to the target_logger
it->second->_parent = &target_logger;
target_logger.logger_cache.insert(std::move(*it));
it = logger_cache.erase(it); // Remove from current logger's cache
}
}
LoggerOStream::LoggerOStream(Logger& logger, Logger::Level level)
: std::ostream(&_streamBuf), _streamBuf(logger, level) {}
LoggerOStream::LoggerOStream(LoggerOStream&& loggerStream)
: std::ostream(std::move(loggerStream)), _streamBuf(std::move(loggerStream._streamBuf))
{}
LoggerStreamBuf::LoggerStreamBuf(Logger& logger, Logger::Level level)
: _logger(logger), _level(level) {}
int LoggerStreamBuf::overflow(int c) {
if (c != EOF) {
// 处理换行符
if (c == '\n') {
flush_line(); // 刷新当前行
} else {
_lineBuffer += static_cast<char>(c); // 将字符添加到行缓存
}
}
return c;
}
std::streamsize LoggerStreamBuf::xsputn(const char* s, std::streamsize n) {
std::streamsize lineStart = 0; // 记录行的起始位置
for (std::streamsize i = 0; i < n; ++i) {
if (s[i] == '\n') {
// 将当前行缓存中的内容添加到行缓存
_lineBuffer.append(s + lineStart, i - lineStart + 1);
flush_line(); // 刷新当前行
lineStart = i + 1; // 更新行的起始位置
}
}
// 处理剩余的字符
if (lineStart < n) {
_lineBuffer.append(s + lineStart, n - lineStart);
}
return n; // 返回写入的字符总数
}
int LoggerStreamBuf::sync() {
return 0; // 不需要同步
}
void LoggerStreamBuf::flush_line() {
if (!_lineBuffer.empty()) {
// 调用 Logger 的 log 方法
_logger.log_message(_level, _lineBuffer);
_lineBuffer.clear(); // 清空行缓存
}
}
Record::Record(const std::string& name, Logger::Level level, const std::string& msg)
: name(name), time(std::chrono::system_clock::now()), level(level), msg(msg)
{}
LogLevel logging::str2level(const char* level)
{
if (!level) return Logger::Level::INFO;
else if (strcasecmp(level, "debug") == 0) return Logger::Level::DEBUG;
else if (strcasecmp(level, "info") == 0) return Logger::Level::INFO;
else if (strcasecmp(level, "warn") == 0) return Logger::Level::WARN;
else if (strcasecmp(level, "error") == 0) return Logger::Level::ERROR;
else if (strcasecmp(level, "fatal") == 0) return Logger::Level::FATAL;
log_error("Unknown log level: %s", level);
return Logger::Level::UNKNOWN;
}
std::ostream& logging::operator<<(std::ostream& stream, const LogLevel level)
{
switch (level) {
case Logger::Level::DEBUG:
return stream << "DEBUG";
case Logger::Level::INFO:
return stream << "INFO";
case Logger::Level::WARN:
return stream << "WARN";
case Logger::Level::ERROR:
return stream << "ERROR";
case Logger::Level::FATAL:
return stream << "FATAL";
default:
return stream << "UNKNOWN";
}
}
void log_init(const char* level)
{
if (level == NULL || *level == '\0') {
level = "info";
} else if (strcasecmp(level, "env") == 0 || strcasecmp(level, "auto") == 0) {
level = getenv("LOG_LEVEL");
}
set_global_logger(std::unique_ptr<Logger>(new Logger(str2level(level))));
}
int log_level()
{
return static_cast<int>(_get_global_logger()->level());
}
void log_set_level(int level)
{
_get_global_logger()->set_level(static_cast<Logger::Level>(level));
}
void log_log(int level, const char* fmt, ...)
{
va_list args;
va_start(args, fmt);
log_vlog(level, fmt, args);
va_end(args);
}
void log_vlog(int level, const char* fmt, va_list args)
{
_get_global_logger()->vlog(static_cast<Logger::Level>(level), fmt, args);
}

17
tests/CMakeLists.txt Normal file
View File

@ -0,0 +1,17 @@
set(TESTS_DIR .)
set(TEST_COMMON_SOURCE "${TESTS_DIR}/c_testcase.cpp")
file(GLOB_RECURSE TEST_FILES "${TESTS_DIR}/test_*.cpp")
foreach(TEST_FILE ${TEST_FILES})
get_filename_component(TEST_NAME ${TEST_FILE} NAME_WE)
add_executable(${TEST_NAME} ${TEST_FILE} ${TEST_COMMON_SOURCE} ${SRC_LIST})
list(APPEND TEST_EXECUTABLES "${EXECUTABLE_OUTPUT_PATH}/${TEST_NAME}")
endforeach()
# message(STATUS "Test files: ${TEST_FILES}")
# message(STATUS "Test executables: ${TEST_EXECUTABLES}")
add_custom_target(test
COMMAND ${CMAKE_SOURCE_DIR}/scripts/unittest.py ${TEST_EXECUTABLES}
DEPENDS ${TEST_EXECUTABLES}
)

289
tests/c_testcase.cpp Normal file
View File

@ -0,0 +1,289 @@
#include <signal.h>
#include <setjmp.h>
#include <iostream>
#include <string>
#include <sstream>
#include <atomic>
#include <thread>
#include "c_testcase.h"
#ifdef _WIN32
#include <windows.h>
#else
#include <sys/ioctl.h>
#include <unistd.h>
#endif
using namespace std;
#define TEST_CASE_STATUS_PASSED 0
#define TEST_CASE_STATUS_SKIPPED 32
#define TEST_CASE_STATUS_FAILED 16
#define TEST_CASE_STATUS_SETUP_FAILED 17
#define TEST_CASE_STATUS_TEARDOWN_FAILED 18
#define MAX_TESTCASE 64
#define DEFAULT_TTY_COL_SIZE 80
struct TestCase {
const char* name;
test_case func;
};
static TestCase test_cases[MAX_TESTCASE];
static int test_case_total = 0;
static interactive_func interactive = NULL;
static context_func setup = NULL;
static context_func teardown = NULL;
static atomic_bool testcase_running;
static jmp_buf testcase_env;
static std::thread::id testcase_thread_id;
static int testcase_exit_code = 0;
int _add_test_case(const char* name, test_case func) {
if (test_case_total == MAX_TESTCASE) {
fprintf(stderr, "too many test case\n");
exit(1);
}
TestCase* tc = &(test_cases[test_case_total++]);
tc->name = name;
tc->func = func;
return 0;
}
int _set_interactive(interactive_func func) {
interactive = func;
return 0;
}
int _set_setup(context_func func) {
setup = func;
return 0;
}
int _set_teardown(context_func func) {
teardown = func;
return 0;
}
[[noreturn]] void test_case_abort(int exit_code) {
auto tid = std::this_thread::get_id();
if (!testcase_running.load(std::memory_order_acquire) || tid != testcase_thread_id) {
exit(exit_code);
}
testcase_exit_code = exit_code;
longjmp(testcase_env, 1);
}
static inline int get_tty_col(int fd) {
#ifdef _WIN32
// Windows
HANDLE hConsole = (HANDLE)_get_osfhandle(fd);
if (hConsole == INVALID_HANDLE_VALUE) {
return DEFAULT_TTY_COL_SIZE; // 错误处理
}
CONSOLE_SCREEN_BUFFER_INFO csbi;
if (!GetConsoleScreenBufferInfo(hConsole, &csbi)) {
return DEFAULT_TTY_COL_SIZE; // 错误处理
}
return csbi.srWindow.Right - csbi.srWindow.Left + 1; // 计算列数
#else
// POSIX (Linux, macOS, etc.)
struct winsize size;
if (ioctl(fd, TIOCGWINSZ, &size) == -1) {
return DEFAULT_TTY_COL_SIZE; // 错误处理
}
return size.ws_col; // 返回列数
#endif
}
static __inline void print_separator(char lc) {
int size = get_tty_col(STDOUT_FILENO);
for (int i = 0; i < size; ++i) {
putchar(lc);
}
}
static __inline void print_separator_ex(char lc, const char* str, const char* color) {
int size = get_tty_col(STDOUT_FILENO);
int len = strlen(str);
printf("\033[0m%s", color); // 设置颜色
if(len > size) {
printf("%s\n", str);
} else {
int pad = (size - len - 2) / 2;
for(int i = 0; i < pad; i++) {
putchar(lc);
}
printf(" %s ", str);
for(int i = 0; i < pad; i++) {
putchar(lc);
}
if((size - len) % 2) putchar(lc);
putchar('\n');
}
printf("\033[0m"); // 重置颜色
}
static int collect_testcase() {
for (int i = 0; i < test_case_total; ++i) {
puts(test_cases[i].name);
}
return 0;
}
static TestCase* get_test_case(const char* name) {
TestCase* tc = NULL;
if (*name >= '0' && *name <= '9') {
int id = atoi(name);
if (id >= 0 && id < test_case_total) {
tc = &(test_cases[id]);
}
} else {
for (int i = 0; i < test_case_total; ++i) {
if (strcmp(test_cases[i].name, name) == 0) {
tc = &(test_cases[i]);
break;
}
}
}
return tc;
}
static int run_test_case_func(test_case func) {
bool running = false;
if (!testcase_running.compare_exchange_strong(running, true, std::memory_order_acq_rel)) {
cerr << "test case is running" << endl;
return 1;
}
if (setjmp(testcase_env)) {
return testcase_exit_code;
}
testcase_thread_id = std::this_thread::get_id();
int ret = func();
running = true;
if (!testcase_running.compare_exchange_strong(running, false, std::memory_order_acq_rel)) {
cerr << "test case is not running" << endl;
return 1;
}
return ret;
}
static int unittest_testcase(TestCase* tc) {
assert(tc != NULL);
if (setup && setup(tc->name)) {
return TEST_CASE_STATUS_SETUP_FAILED;
}
int ret = run_test_case_func(tc->func);
if (teardown && teardown(tc->name)) {
return TEST_CASE_STATUS_TEARDOWN_FAILED;
}
if (ret == SKIP_RET_NUMBER) {
return TEST_CASE_STATUS_SKIPPED;
} else if (ret != 0) {
return TEST_CASE_STATUS_FAILED;
} else {
return 0;
}
}
int main(int argc, const char** argv) {
for (int i = 1; i < argc; ++i) {
const char* arg = argv[i];
if (strcmp(arg, "-h") == 0 || strcmp(arg, "--help") == 0) {
printf("usage: %s [-c] [-u NAME] [-h]\n", argv[0]);
printf("\nOptions:\n");
printf(" -i, --interactive: run in interactive mode.\n");
printf(" -c, --collect: list all test cases.\n");
printf(" -u, --unittest: run a single test case.\n");
printf(" -h, --help: show the help text.\n");
return 0;
}
if (strcmp(arg, "-i") == 0 || strcmp(arg, "--interactive") == 0) {
if (interactive)
return interactive(argc, argv);
else {
cout << "interactive mode is not supported" << endl;
return 1;
}
}
else if (strcmp(arg, "-c") == 0 || strcmp(arg, "--collect") == 0) {
return collect_testcase();
}
else if (strcmp(arg, "-u") == 0 || strcmp(arg, "--unittest") == 0) {
const char* name = NULL;
if (i + 1 < argc) {
name = argv[++i];
} else {
cout << "--unittest require an argument" << endl;
return 2;
}
TestCase* tc = get_test_case(name);
if (tc == NULL) {
cout << "test case " << name << " not found" << endl;
return 1;
}
return unittest_testcase(tc);
} else {
if (interactive)
return interactive(argc, argv);
else {
cout << "unknown argument '" << arg << "'" << endl;
return 1;
}
}
}
int total = test_case_total;
int passed = 0;
int failed = 0;
int skipped = 0;
for (int i = 0; i < total; ++i) {
print_separator('-');
TestCase* tc = &test_cases[i];
cout << "running " << tc->name << endl;
switch (unittest_testcase(tc)) {
case 0:
cout << "\033[0m\033[1;32mtest case \"" << tc->name << "\" passed\033[0m" << endl;
passed++;
break;
case TEST_CASE_STATUS_SKIPPED:
cout << "\033[0m\033[1;33mtest case \"" << tc->name << "\" skipped\033[0m" << endl;
skipped++;
break;
case TEST_CASE_STATUS_SETUP_FAILED:
cout << "\033[0m\033[1;31msetup \"" << tc->name << "\" failed\033[0m" << endl;
failed++;
break;
case TEST_CASE_STATUS_TEARDOWN_FAILED:
cout << "\033[0m\033[1;31mteardown \"" << tc->name << "\" failed\033[0m" << endl;
failed++;
break;
default:
cout << "\033[0m\033[1;31mtest case \"" << tc->name << "\" failed\033[0m" << endl;
failed++;
break;
}
}
stringstream ss;
ss << "total: " << total << ", passed: " << passed << ", failed: " << failed << ", skipped: " << skipped;
string sum = ss.str();
const char* color;
if (failed)
color = "\033[1;31m";
else if (skipped)
color = "\033[1;33m";
else
color = "\033[1;32m";
print_separator_ex('=', sum.c_str(), color);
return 0;
}

203
tests/c_testcase.h Normal file
View File

@ -0,0 +1,203 @@
#ifndef _INCLUDE_C_TESTCASE_
#define _INCLUDE_C_TESTCASE_
#include <assert.h>
#include <stdio.h>
#include <string.h>
#ifdef __cplusplus
#include <iostream>
extern "C"{
#endif
#define SKIP_RET_NUMBER (*((const int*)"SKIP"))
#define SKIP_TEST do { return SKIP_RET_NUMBER; } while (0)
#define END_TEST do { return 0; } while (0)
#define TEST_CASE(name) \
int name (); \
int _tc_tmp_ ## name = _add_test_case(# name, name); \
int name ()
#define INTERACTIVE \
int interactive_impl(int argc, const char** argv); \
int _tc_s_tmp_interactive = _set_interactive(interactive_impl); \
int interactive_impl(int argc, const char** argv)
#define SETUP \
int setup_impl(const char* test_case_name); \
int _tc_s_tmp_setup = _set_setup(setup_impl); \
int setup_impl(const char* test_case_name)
#define TEARDOWN \
int teardown_impl(const char* test_case_name); \
int _tc_s_tmp_teardown = _set_teardown(teardown_impl); \
int teardown_impl(const char* test_case_name)
#undef assert
#ifdef __cplusplus
#define assert(expr) do { \
if (!(static_cast <bool> (expr))) { \
::std::cout << "assert failed: " #expr "\n" << ::std::endl; \
::std::cout << "file: \"" __FILE__ "\", line " << __LINE__ << ", in " << __ASSERT_FUNCTION << "\n" << ::std::endl;\
test_case_abort(1); \
} \
} while (0)
#define assert_eq(expr1, expr2) do { \
if ((expr1) != (expr2)) { \
::std::cout << "assert failed: " #expr1 " == " #expr2 "\n" << ::std::endl; \
::std::cout << "\t#0: " << (expr1) << ::std::endl; \
::std::cout << "\t#1: " << (expr2) << ::std::endl; \
::std::cout << "file: \"" __FILE__ "\", line " << __LINE__ << ", in " << __ASSERT_FUNCTION << "\n" << ::std::endl;\
test_case_abort(1); \
} \
} while (0)
#define assert_ne(expr1, expr2) do { \
if ((expr1) == (expr2)) { \
::std::cout << "assert failed: " #expr1 " != " #expr2 "\n" << ::std::endl; \
::std::cout << "\t#0: " << (expr1) << ::std::endl; \
::std::cout << "\t#1: " << (expr2) << ::std::endl; \
::std::cout << "file: \"" __FILE__ "\", line " << __LINE__ << ", in " << __ASSERT_FUNCTION << "\n" << ::std::endl;\
test_case_abort(1); \
} \
} while (0)
#define assert_gt(expr1, expr2) do { \
if ((expr1) <= (expr2)) { \
::std::cout << "assert failed: " #expr1 " > " #expr2 "\n" << ::std::endl; \
::std::cout << "\t#0: " << (expr1) << ::std::endl; \
::std::cout << "\t#1: " << (expr2) << ::std::endl; \
::std::cout << "file: \"" __FILE__ "\", line " << __LINE__ << ", in " << __ASSERT_FUNCTION << "\n" << ::std::endl;\
test_case_abort(1); \
} \
} while (0)
#define assert_ls(expr1, expr2) do { \
if ((expr1) >= (expr2)) { \
::std::cout << "assert failed: " #expr1 " < " #expr2 "\n" << ::std::endl; \
::std::cout << "\t#0: " << (expr1) << ::std::endl; \
::std::cout << "\t#1: " << (expr2) << ::std::endl; \
::std::cout << "file: \"" __FILE__ "\", line " << __LINE__ << ", in " << __ASSERT_FUNCTION << "\n" << ::std::endl;\
test_case_abort(1); \
} \
} while (0)
#define assert_ge(expr1, expr2) do { \
if ((expr1) < (expr2)) { \
::std::cout << "assert failed: " #expr1 " >= " #expr2 "\n" << ::std::endl; \
::std::cout << "\t#0: " << (expr1) << ::std::endl; \
::std::cout << "\t#1: " << (expr2) << ::std::endl; \
::std::cout << "file: \"" __FILE__ "\", line " << __LINE__ << ", in " << __ASSERT_FUNCTION << "\n" << ::std::endl;\
test_case_abort(1); \
} \
} while (0)
#define assert_le(expr1, expr2) do { \
if ((expr1) > (expr2)) { \
::std::cout << "assert failed: " #expr1 " <= " #expr2 "\n" << ::std::endl; \
::std::cout << "\t#0: " << (expr1) << ::std::endl; \
::std::cout << "\t#1: " << (expr2) << ::std::endl; \
::std::cout << "file: \"" __FILE__ "\", line " << __LINE__ << ", in " << __ASSERT_FUNCTION << "\n" << ::std::endl;\
test_case_abort(1); \
} \
} while (0)
#else
#define assert(expr) do { \
if (!((bool)(expr))) { \
printf("assert failed: " #expr "\n"); \
printf("file: \"%s\", line %d, in %s\n", __FILE__, __LINE__, __ASSERT_FUNCTION);\
test_case_abort(1); \
} \
} while (0)
#endif
#define assert_i32_eq(expr1, expr2) do { \
if ((int32_t)(expr1) != (uint32_t)(expr2)) { \
printf("assert failed: " #expr1 " == " #expr2 "\n"); \
printf("\t\t%d\t!=\t%d\n", (int32_t)(expr1), (int32_t)(expr2)); \
printf("file: \"%s\", line %d, in %s\n", __FILE__, __LINE__, __ASSERT_FUNCTION);\
test_case_abort(1); \
} \
} while (0)
#define assert_i32_ne(expr1, expr2) do { \
if ((int32_t)(expr1) == (int32_t)(expr2)) { \
printf("assert failed: " #expr1 " != " #expr2 "\n"); \
printf("\t\t%d\t==\t%d\n", (int32_t)(expr1), (int32_t)(expr2)); \
printf("file: \"%s\", line %d, in %s\n", __FILE__, __LINE__, __ASSERT_FUNCTION);\
test_case_abort(1); \
} \
} while (0)
#define assert_i64_eq(expr1, expr2) do { \
if ((expr1) != (expr2)) { \
printf("assert failed: " #expr1 " == " #expr2 "\n"); \
printf("\t\t%lld\t!=\t%lld\n", (int64_t)(expr1), (int64_t)(expr2)); \
printf("file: \"%s\", line %d, in %s\n", __FILE__, __LINE__, __ASSERT_FUNCTION);\
test_case_abort(1); \
} \
} while (0)
#define assert_i64_ne(expr1, expr2) do { \
if ((expr1) == (expr2)) { \
printf("assert failed: " #expr1 " != " #expr2 "\n"); \
printf("\t\t%lld\t==\t%lld\n", (int64_t)(expr1), (int64_t)(expr2)); \
printf("file: \"%s\", line %d, in %s\n", __FILE__, __LINE__, __ASSERT_FUNCTION);\
test_case_abort(1); \
} \
} while (0)
#define assert_str_eq(expr1, expr2) do { \
if (strcmp((expr1), (expr2)) != 0) { \
printf("assert failed: " #expr1 " == " #expr2 "\n"); \
printf("\t#0: %s\n", (expr1)); \
printf("\t#1: %s\n", (expr2)); \
printf("file: \"%s\", line %d, in %s\n", __FILE__, __LINE__, __ASSERT_FUNCTION);\
test_case_abort(1); \
} \
} while (0)
#define assert_str_ne(expr1, expr2) do { \
if (strcmp((expr1), (expr2)) == 0) { \
printf("assert failed: " #expr1 " != " #expr2 "\n"); \
printf("\t#0: %s\n", (expr1)); \
printf("\t#1: %s\n", (expr2)); \
printf("file: \"%s\", line %d, in %s\n", __FILE__, __LINE__, __ASSERT_FUNCTION);\
test_case_abort(1); \
} \
} while (0)
#define assert_mem_eq(expr1, expr2, size) do { \
if (memcmp((expr1), (expr2), (size)) != 0) { \
printf("assertion failed: %s == %s\n", #expr1, #expr2); \
printf("\t#0: "); \
for (size_t i = 0; i < (size); i++) { \
printf("%02X", ((uint8_t*)(expr1))[i]); \
} \
printf("\n\t#1: "); \
for (size_t i = 0; i < (size); i++) { \
printf("%02X", ((uint8_t*)(expr2))[i]); \
} \
printf("\nfile: \"%s\", line %d, in %s\n", __FILE__, __LINE__, __ASSERT_FUNCTION);\
test_case_abort(1); \
} \
} while (0)
#define assert_mem_ne(expr1, expr2, size) do { \
if (memcmp((expr1), (expr2), (size)) == 0) { \
printf("assertion failed: %s != %s\n", #expr1, #expr2); \
printf("\t#0: "); \
for (int i = 0; i < (size); i++) { \
printf("%02X", ((uint8_t*)(expr1))[i]); \
} \
printf("\n\t#1: "); \
for (int i = 0; i < (size); i++) { \
printf("%02X", ((uint8_t*)(expr2))[i]); \
} \
printf("\nfile: \"%s\", line %d, in %s\n", __FILE__, __LINE__, __ASSERT_FUNCTION);\
test_case_abort(1); \
} \
} while (0)
typedef int (*test_case)();
typedef int (*interactive_func)(int argc, const char** argv);
typedef int (*context_func)(const char* name);
int _add_test_case(const char* name, test_case func);
int _set_interactive(interactive_func func);
int _set_setup(context_func func);
int _set_teardown(context_func func);
[[noreturn]] void test_case_abort(int exit_code);
#ifdef __cplusplus
}
#endif
#endif

40
tests/test_base.cpp Normal file
View File

@ -0,0 +1,40 @@
#include "transport/base.hpp"
#include "transport/protocol.hpp"
#include "c_testcase.h"
using namespace transport;
class TestTransport: public BaseTransport<Protocol> {
protected:
void send_backend() override {
while (!is_closed) {
DataPair pair = send_que.Pop();
recv_que.Push(std::move(pair));
}
}
void receive_backend() override {}
};
const int timeout = 3;
TEST_CASE(test_init) {
TestTransport t;
END_TEST;
}
TEST_CASE(test_transport) {
TestTransport t;
assert(!t.closed());
t.open();
t.send(std::vector<uint8_t>(10, 1));
auto data_pair = t.receive(std::chrono::seconds(timeout));
assert_eq(data_pair.first.size(), 10);
assert_eq(data_pair.first[1], 1);
assert(!data_pair.second);
t.close();
assert(t.closed());
END_TEST;
}

40
tests/test_datagram.cpp Normal file
View File

@ -0,0 +1,40 @@
#include "transport/udp.hpp"
#include "transport/protocol.hpp"
#include "c_testcase.h"
using namespace transport;
TEST_CASE(test_init) {
DatagramTransport<Protocol> transport;
transport.open();
END_TEST;
}
TEST_CASE(test_send_recv) {
DatagramTransport<Protocol> transport_server;
transport_server.open();
transport_server.bind("127.0.0.1", 12345);
DatagramTransport<Protocol> transport_client;
transport_client.open();
transport_client.connect("127.0.0.1", 12345);
transport_client.send(std::vector<uint8_t>{0x01, 0x02, 0x03, 0x04});
auto [frame, token] = transport_server.receive(std::chrono::seconds(3));
assert_eq(frame.size(), 4);
assert(token);
assert_eq(frame[0], 0x01);
assert_eq(frame[1], 0x02);
assert_eq(frame[2], 0x03);
assert_eq(frame[3], 0x04);
transport_server.send(std::vector<uint8_t>{0x04, 0x03, 0x02, 0x01}, token);
auto [frame2, token2] = transport_client.receive(std::chrono::seconds(3));
assert_eq(frame2.size(), 4);
assert(token2);
assert_eq(frame2[0], 0x04);
assert_eq(frame2[1], 0x03);
assert_eq(frame2[2], 0x02);
assert_eq(frame2[3], 0x01);
END_TEST;
}

View File

@ -0,0 +1,66 @@
#include <pty.h>
#include <fcntl.h>
#include <thread>
#include "c_testcase.h"
#include "transport/serial_port.hpp"
#include "transport/protocol.hpp"
using namespace transport;
int master_fd = -1;
int slave_fd = -1;
const int timeout = 3;
SETUP {
if (openpty(&master_fd, &slave_fd, NULL, NULL, NULL) < 0) {
throw std::runtime_error("openpty failed");
}
fcntl(master_fd, F_SETFL, fcntl(master_fd, F_GETFL) | O_NONBLOCK);
fcntl(slave_fd, F_SETFL, fcntl(slave_fd, F_GETFL) | O_NONBLOCK);
return 0;
}
TEARDOWN {
close(master_fd);
close(slave_fd);
master_fd = -1;
slave_fd = -1;
return 0;
}
TEST_CASE(test_pty) {
char buffer[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
ssize_t ret;
ret = write(slave_fd, buffer, 10);
assert_eq(ret, 10);
ret = read(master_fd, buffer, 10);
assert_eq(ret, 10);
for (int i = 0; i < 10; i++) {
assert_eq(buffer[i], i);
}
END_TEST;
}
TEST_CASE(test_serial_port) {
SerialPortTransport<Protocol> t1(master_fd);
SerialPortTransport<Protocol> t2(slave_fd);
assert(!t1.closed());
assert(!t2.closed());
std::pair<Protocol::FrameType, std::shared_ptr<TransportToken>> data_pair;
std::thread ([&] {
t2.open();
t2.send(std::vector<uint8_t>(10, 2));
}).detach();
t1.open();
data_pair = t1.receive(std::chrono::seconds(timeout));
assert_eq(data_pair.first.size(), 10);
assert_eq(data_pair.first[0], 2);
assert(data_pair.second);
assert_eq(data_pair.second->transport<Protocol>(), &t1);
t1.close();
assert(t1.closed());
END_TEST;
}

View File

@ -0,0 +1,38 @@
#include "transport/unix_udp.hpp"
#include "transport/protocol.hpp"
#include "c_testcase.h"
using namespace transport;
TEST_CASE(test_init) {
UnixDatagramTransport<Protocol> transport;
transport.open();
END_TEST;
}
TEST_CASE(test_send_recv) {
UnixDatagramTransport<Protocol> transport_server("/tmp/vxup_test.sock", "");
UnixDatagramTransport<Protocol> transport_client("/tmp/vxup_test1.sock", "/tmp/vxup_test.sock");
transport_server.open();
transport_client.open();
transport_client.send(std::vector<uint8_t>{0x01, 0x02, 0x03, 0x04});
auto [frame, token] = transport_server.receive(std::chrono::seconds(3));
assert_eq(frame.size(), 4);
assert(token);
assert_eq(frame[0], 0x01);
assert_eq(frame[1], 0x02);
assert_eq(frame[2], 0x03);
assert_eq(frame[3], 0x04);
transport_server.send(std::vector<uint8_t>{0x04, 0x03, 0x02, 0x01}, token);
auto [frame2, token2] = transport_client.receive(std::chrono::seconds(3));
assert_eq(frame2.size(), 4);
assert(token2);
assert_eq(frame2[0], 0x04);
assert_eq(frame2[1], 0x03);
assert_eq(frame2[2], 0x02);
assert_eq(frame2[3], 0x01);
END_TEST;
END_TEST;
}