我是一个编程新手,最近了解到协程这个概念。协程可以理解为用户级线程,由用户空间自己完成调度:在处理异步 I/O 时,子程序可以主动让出 CPU 交给其他协程,等事件完成后再切换回该子程序继续执行。当然用回调也可以实现同样的功能,但使用协程能让程序“看起来”是顺序执行的。
我利用 Linux 系统函数 getcontext、makecontext、swapcontext 来实现协程之间的切换。
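getcontext(ucontext_t* ucp) 把当前上下文保存到 ucp 中(常用来初始化一个新的上下文);makecontext(ucontext_t* ucp, void (*func)(), int argc, ...) 修改 ucp,使切换到该上下文时执行 func,argc 是其后传给 func 的 int 参数个数(调用前必须先用 getcontext 初始化 ucp,并设置好 uc_stack 和 uc_link);swapcontext(ucontext_t* oucp, const ucontext_t* ucp) 把当前上下文保存到 oucp,并切换到 ucp。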
// Excerpt of the system ucontext_t definition (remaining fields elided with "…").
typedef struct ucontext {
struct ucontext *uc_link; // context to switch to when the current context finishes executing
sigset_t uc_sigmask;
stack_t uc_stack; // stack used by the current context
mcontext_t uc_mcontext;
…
} ucontext_t;
我利用这些函数实现了一个简单的协程库:在一个线程中切换多个协程,用 epoll 等待 I/O 事件或定时器事件,事件发生后唤醒对应的协程;当 epoll_wait 超时(4 ms 内没有事件)时,按顺序轮流切换到没有在等待事件的协程。下面是线程工作函数:
void myuthread::thread_work(void* pvoid) {
myuthread* mut = (myuthread*)pvoid;
struct epoll_event events[1024];
while (!mut->shut_down_) {
int ret = epoll_wait(mut->epoll_fd_, events, 1024, 4);
if (ret > 0) {//有事件发生
for (int i = 0; i < ret; i ) {
ucontextrevent* ctxex = (ucontextrevent*)events[i].data.ptr;
auto ctx = ctxex->ctx_;
if (ctx->get_timer() && ctxex->fd_ == ctx->get_timer()->get_timerfd()) {
ctxex->revents_ = uthreadtimerout;
}else {
ctxex->revents_ = events[i].events;
}
ctx->resume();
if (ctx->status_ == ucontext::finished) {
mut->remove_uthread(ctx->ctx_index_);
}
}
}else if (ret == -1) {//异常
if (errno == eintr) {
continue;
}
else {
perror("epoll_wait:");
abort();
}
}else {//epoll_wait超时
mut->curr_running_index_;
if (mut->curr_running_index_ >= mut->curr_max_count_) {
mut->curr_running_index_ = -1;
continue;
}
auto ctx = mut->ctx_items_[mut->curr_running_index_];
if (ctx == nullptr || !ctx->register_fd_set.empty()) {
//如果该ctx协程正在等待某个异步事件,则不对他唤醒
continue;
}
ctx->resume();
if (ctx->status_ == ucontext::finished) {
mut->remove_uthread(ctx->ctx_index_);
}
}
}
}
使用:
#include
#include
#include
#include "myuthread.h"
#include "mytimer.h"
using namespace std;
using namespace placeholders;
// Demo task: sleeps ~2 s on a one-shot timerfd without blocking the worker thread.
// ctx: handle of the coroutine running this task; num: value echoed in the first log line.
void test1(ucontext*& ctx, int num) {
    cout << "enter test1:" << num << endl;
    mytimer timer;
    timer.set_once(2, 0);  // fire once after 2 s
    // The record lives on this coroutine's stack, so it must be unregistered
    // (remove_event below) before this frame returns.
    ucontextrevent revents(ctx, timer.get_timerfd());
    // Attach first so the scheduler can tell a timer wake-up from ordinary readiness.
    ctx->attach_timer(&timer);
    ctx->register_event(timer.get_timerfd(), EPOLLIN, &revents);
    cout << "start yield" << endl;
    ctx->yield();  // give up the CPU until the timerfd becomes readable
    if (revents.revents_ == uthreadtimerout) {
        cout << "uthreadtimerout" << endl;
    }
    ctx->remove_event(timer.get_timerfd());
    cout << "finished test1" << endl;
}
// Demo task that gives up the CPU exactly once. It registers no fd, so the
// scheduler's round-robin pass is what resumes it.
void test2(ucontext*& ctx) {
    std::cout << "test2 yield" << std::endl;
    (*ctx).yield();
    std::cout << "finished test2" << std::endl;
}
// Completion callback passed to add_task() together with test1.
void finished() {
    std::cout << "test1 callback" << std::endl;
}
int main()
{
myuthread mut(2, 8192);//参数:协程最大数量,协程栈大小
mut.add_task(bind(test1, _1, 99999),finished);//添加任务,参数:工作函数,回调函数
for (int i = 0; i < 10; i ) {
mut.add_task(test2);
}
mut.join();
return 0;
}
下面贴出源代码:我还是个小白,把代码贴出来,如果有什么错误希望大家评论告诉我,谢谢(ㅎ-ㅎ;)
ucontext.h
#include
#include
#include
#include
using namespace std;
/*
Reference values of the Linux epoll event bits:
EPOLLIN:1
EPOLLOUT:4
EPOLLRDHUP:8192
EPOLLPRI:2
EPOLLERR:8
EPOLLHUP:16
*/
// Marker stored in ucontextrevent::revents_ when the wake-up came from the coroutine's
// attached timer. It must not look like anything epoll_wait reports: the readiness bits
// above fit in the low 16 bits, and a small value such as 3 (EPOLLIN|EPOLLPRI) would be
// indistinguishable from a genuine "readable + priority data" event.
#define uthreadtimerout (1 << 30)
class ucontext;
// Task body: invoked with a reference to the handle of the coroutine that runs it.
typedef std::function<void(ucontext*&)> ucontextfunc;
// Completion callback: invoked after the task body returns.
typedef std::function<void()> callback;
class mytimer;
struct ucontextrevent {
ucontext* ctx_;
int fd_;
int revents_;
ucontextrevent(ucontext* ctx, int fd,int revents = 0) :
ctx_(ctx), fd_(fd), revents_(revents){}
};
// One coroutine: its private stack, saved ucontext_t, task, run state and the
// epoll registrations it currently holds.
class ucontext {
public:
    // index: slot in the owning scheduler's ctx_items_ (add_task passes -1 and fills it in later);
    // main_ctx: scheduler context to return to on yield() or when the task returns;
    // stack_size: usable stack bytes; epoll_fd: epoll instance used by register_event()/remove_event().
    ucontext(int index, ucontext_t* main_ctx, int stack_size, int epoll_fd);
    ~ucontext();
    // The object owns an mmap'd stack and a heap-allocated ucontext_t; a member-wise
    // copy would unmap/delete them twice, so copying is disabled.
    ucontext(const ucontext&) = delete;
    ucontext& operator=(const ucontext&) = delete;
    // Store the task body and an optional callback that runs after it returns.
    void set_func(ucontextfunc func, callback callback = 0);
    // Point ctx_ at work_func(this); must be called before the first resume().
    void make();
    // Switch from the scheduler into this coroutine (status_ becomes running).
    void resume();
    // Switch from this coroutine back to the scheduler (status_ becomes suspend).
    void yield();
    // Attach (nullptr: detach) the timer whose fd wakes this coroutine with uthreadtimerout.
    void attach_timer(mytimer*);
    // Register fd with epoll for `events`; the scheduler fills in *revents on wake-up.
    void register_event(int fd, int events, ucontextrevent* revents);
    // Unregister fd from epoll, detaching the timer if fd is its timerfd.
    void remove_event(int fd);
    // Currently attached timer, or nullptr.
    mytimer* get_timer();
private:
    // makecontext() entry point; receives `this` split into two 32-bit halves.
    static void work_func(uint32_t low32, uint32_t high32);
public:
    int ctx_index_;          // index of this object in the scheduler's ctx_items_
    ucontext_t* main_ctx_;   // scheduler (thread main loop) context
    ucontext_t* ctx_;        // this coroutine's saved context
    char* raw_stack_;        // start of the mmap'd region, guard pages included
    char* stack_;            // usable stack, just above the low guard page
    ucontextfunc func_;      // user task
    callback callback_;      // run after func_ returns
    int stack_size_;         // usable stack size in bytes
    enum { ready = 0, running, suspend, finished };
    int status_;             // one of the enum values above
    int epoll_fd_;           // epoll instance shared with the scheduler
    mytimer* mytimer_;       // attached timer, or nullptr
    // fds currently registered with epoll by this coroutine; while non-empty the
    // scheduler's round-robin pass does not resume it.
    std::unordered_set<int> register_fd_set;
};
myuthread.h
#include
#include
#include
#include
#include
#include
#include "sys/epoll.h"
using namespace std;
#include "ucontext.h"
class myuthread {
public:
myuthread(int max_uthread_count, int stack_size);
~myuthread();
void add_task(ucontextfunc func, callback callback = 0);
void join();
void destory();
private:
void remove_uthread(int index);
int get_stack_size(int stack_size);
static void thread_work(void* pvoid);
private:
int max_uthread_count_; //最大协程数量
int stack_size_; //栈大小
volatile int curr_running_index_; //当前执行协程索引
volatile int curr_max_count_; //当前最大协程数量
volatile int idle_count_; //可用协程数量
bool shut_down_; //是否退出
vector ctx_items_; //调度的协程列表
queue ctx_ready_queue_; //就绪队列,等待其它协程退出后,会被添加到ctx_items_中
thread* thread_; //线程
ucontext_t main_ctx_; //主上下文
mutex mutex_; //锁queue
mutex mutex_join_;
condition_variable cv_; //使用锁和条件变量,阻塞等待协程全部执行完毕
int epoll_fd_; //epoll fd
};
mytimer.h
#pragma once
#include
#include
using namespace std;
class ucontext;  // NOTE(review): not referenced by this header; kept for existing includers
// Wrapper around a timerfd that can be registered with epoll.
class mytimer
{
public:
    mytimer();
    ~mytimer();
    // Owns a file descriptor and a heap itimerspec; a copy would close/delete them
    // twice, so copying is disabled.
    mytimer(const mytimer&) = delete;
    mytimer& operator=(const mytimer&) = delete;
    // Expire once after `seconds` s + `millseconds` ms.
    void set_once(int seconds, int millseconds);
    // First expiry after seconds/millseconds, then every intervalseconds/intervalmillseconds.
    void set_cycle(int seconds, int millseconds, int intervalseconds, int intervalmillseconds);
    // The underlying timerfd (for epoll registration).
    int get_timerfd();
    // Disarm the timer.
    void stop();
    // Nanoseconds remaining until the timer expires.
    // NOTE(review): an int holds at most ~2.1 s worth of nanoseconds.
    int get_time();
private:
    void set_time(int seconds, int millseconds, int intervalseconds, int intervalmillseconds);
private:
    int timerfd_;
    struct itimerspec* timespec_;
};
ucontext.cpp
#include "ucontext.h"
#include
#include
#include
#include
#include "mytimer.h"
// Allocates the coroutine's context and a private stack laid out as
// [guard page][stack_size_ usable bytes][guard page]; both guard pages are PROT_NONE
// so an overflow faults instead of silently corrupting memory.
// Throws std::bad_alloc if the mapping or its protection cannot be set up.
ucontext::ucontext(int index, ucontext_t* main_ctx, int stack_size, int epoll_fd)
    :ctx_index_(index), main_ctx_(main_ctx), ctx_(nullptr),
    raw_stack_(nullptr), stack_(nullptr), func_(nullptr), callback_(nullptr),
    stack_size_(stack_size),
    status_(ready), epoll_fd_(epoll_fd),
    mytimer_(nullptr) {
    // Allocate first: if this throws, nothing else has been acquired yet.
    ctx_ = new ucontext_t;

    const int page_size = getpagesize();
    const size_t map_size = static_cast<size_t>(stack_size_) + static_cast<size_t>(page_size) * 2;
    void* mem = mmap(nullptr, map_size, PROT_READ | PROT_WRITE,
                     MAP_ANONYMOUS | MAP_PRIVATE, -1, 0);
    if (mem == MAP_FAILED) {  // mmap reports failure with MAP_FAILED, not nullptr
        delete ctx_;
        ctx_ = nullptr;
        throw std::bad_alloc();
    }
    raw_stack_ = static_cast<char*>(mem);
    stack_ = raw_stack_ + page_size;
    // Kept outside assert(): with NDEBUG the calls would otherwise disappear.
    if (mprotect(raw_stack_, page_size, PROT_NONE) != 0 ||
        mprotect(stack_ + stack_size_, page_size, PROT_NONE) != 0) {
        munmap(raw_stack_, map_size);
        raw_stack_ = nullptr;
        stack_ = nullptr;
        delete ctx_;
        ctx_ = nullptr;
        throw std::bad_alloc();
    }

    // getcontext() first, then customise: makecontext() requires a context obtained
    // from getcontext() whose uc_stack and uc_link were set afterwards.
    getcontext(ctx_);
    ctx_->uc_link = main_ctx;  // resumed when the coroutine's function returns
    ctx_->uc_stack.ss_sp = stack_;
    ctx_->uc_stack.ss_size = stack_size_;
}
// Releases the saved context and the whole mapping made by the constructor
// (usable stack plus the two guard pages).
ucontext::~ucontext() {
    delete ctx_;
    if (raw_stack_ != nullptr) {
        munmap(raw_stack_, static_cast<size_t>(stack_size_) + static_cast<size_t>(getpagesize()) * 2);
    }
}
// Stores the task body and the optional completion callback.
// Both parameters are by-value sinks, so they are moved into place rather than copied again.
void ucontext::set_func(ucontextfunc func, callback callback) {
    func_ = std::move(func);
    callback_ = std::move(callback);
}
// Binds ctx_ so that switching to it runs work_func(this). makecontext() only passes
// int-sized arguments, so the pointer is split into two 32-bit halves. The split is done
// on a uint64_t: shifting a 32-bit uintptr_t by 32 would be undefined behaviour.
void ucontext::make() {
    const uint64_t bits = reinterpret_cast<uintptr_t>(this);
    makecontext(ctx_, reinterpret_cast<void (*)(void)>(work_func), 2,
                static_cast<uint32_t>(bits), static_cast<uint32_t>(bits >> 32));
}
void ucontext::resume() {
status_ = running;
swapcontext(main_ctx_, ctx_);
}
void ucontext::yield() {
status_ = suspend;
swapcontext(ctx_, main_ctx_);
}
// Remember the timer whose fd should be reported to this coroutine as uthreadtimerout;
// nullptr detaches it.
void ucontext::attach_timer(mytimer* timer) {
    this->mytimer_ = timer;
}
// Timer set by attach_timer(), or nullptr when none is attached.
mytimer* ucontext::get_timer() {
    return this->mytimer_;
}
// Registers fd with this coroutine's epoll instance for `events`. `revents` is passed
// through epoll_event::data.ptr and must stay alive until remove_event(fd).
// The fd is recorded in register_fd_set only if epoll_ctl succeeds. Recording a failed
// registration would make the round-robin pass skip this coroutine while epoll never
// reports the fd, so it would never be resumed. On failure revents->revents_ is left
// untouched.
void ucontext::register_event(int fd, int events, ucontextrevent* revents) {
    epoll_event ev{};
    ev.events = static_cast<uint32_t>(events);
    // epoll_data is a union: only ptr is read by the scheduler, so data.fd is not set.
    ev.data.ptr = revents;
    if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &ev) == 0) {
        register_fd_set.insert(fd);
    }
}
void ucontext::remove_event(int fd) {
epoll_ctl(epoll_fd_, epoll_ctl_del, fd, nullptr);
if (mytimer_ && mytimer_->get_timerfd() == fd) {
this->attach_timer(nullptr);
}
register_fd_set.erase(fd);
}
void ucontext::work_func(uint32_t low32, uint32_t high32) {
uintptr_t ptr = (uintptr_t)low32 | ((uintptr_t)high32 << 32);
ucontext * uc = (ucontext*)ptr;
if (uc->func_) {
uc->func_(uc);
if (uc->callback_) {
uc->callback_();
}
}
uc->status_ = finished;
}
myuthread.cpp
#include "myuthread.h"
#include "mytimer.h"
#include
#include
#include
#include
#include
#include
using namespace std;
// Creates the scheduler: max_uthread_count coroutine slots, stacks rounded up to whole
// pages, one epoll instance, and the worker thread running thread_work(this).
// Aborts if the epoll instance cannot be created.
myuthread::myuthread(int max_uthread_count, int stack_size) :
    max_uthread_count_(max_uthread_count),
    stack_size_(get_stack_size(stack_size)),
    curr_running_index_(-1),
    curr_max_count_(0),
    idle_count_(max_uthread_count),
    shut_down_(false), thread_(nullptr), epoll_fd_(-1)
{
    memset(&main_ctx_, 0, sizeof(main_ctx_));
    ctx_items_.resize(max_uthread_count_, nullptr);
    epoll_fd_ = epoll_create1(0);
    if (epoll_fd_ == -1) {
        // Every coroutine registration would silently fail on an invalid epoll fd.
        perror("epoll_create1");
        abort();
    }
    // Started last, once every member the worker reads is initialised.
    thread_ = new thread(&myuthread::thread_work, static_cast<void*>(this));
}
// Stops the worker (if join()/destory() did not already), then frees every coroutine
// still scheduled or queued and closes the epoll instance.
// Destroying a still-joinable std::thread calls std::terminate, hence the join here.
myuthread::~myuthread() {
    shut_down_ = true;
    if (thread_ != nullptr && thread_->joinable()) {
        thread_->join();
    }
    delete thread_;
    // The worker has exited, so no coroutine is running on the stacks freed below.
    for (ucontext* ctx : ctx_items_) {
        delete ctx;
    }
    ctx_items_.clear();
    while (!ctx_ready_queue_.empty()) {
        delete ctx_ready_queue_.front();
        ctx_ready_queue_.pop();
    }
    if (epoll_fd_ >= 0) {
        close(epoll_fd_);
    }
}
// Creates a coroutine for func (+ optional callback). It goes into a free slot if one
// exists; otherwise it goes into ctx_ready_queue_ until remove_uthread() frees a slot.
void myuthread::add_task(ucontextfunc func, callback callback) {
    ucontext* ctx = new ucontext(-1, &main_ctx_, stack_size_, epoll_fd_);
    ctx->set_func(std::move(func), std::move(callback));

    // Hold mutex_ for the whole placement decision. remove_uthread() runs on the worker
    // thread and pulls from the ready queue under mutex_. Checking idle_count_ and then
    // pushing without the lock could leave a task queued next to a slot that was freed
    // in between.
    std::lock_guard<std::mutex> lock(mutex_);
    int index = -1;
    if (curr_max_count_ < max_uthread_count_) {
        index = curr_max_count_;  // first never-used slot
    } else if (idle_count_ > 0) {
        for (int i = 0; i < curr_max_count_; ++i) {  // reuse a slot freed by remove_uthread()
            if (ctx_items_[i] == nullptr) {
                index = i;
                break;
            }
        }
    }
    if (index < 0) {
        ctx_ready_queue_.push(ctx);  // every slot is busy
        return;
    }
    ctx->ctx_index_ = index;
    ctx->make();
    ctx_items_[index] = ctx;  // fill the slot before growing curr_max_count_
    if (index == curr_max_count_) {
        ++curr_max_count_;
    }
    --idle_count_;
}
void myuthread::join() {
while (idle_count_ < max_uthread_count_ || !ctx_ready_queue_.empty()) {
unique_lock lock(mutex_join_);
cv_.wait(lock);
}
this->shut_down_ = true;
this->thread_->join();
}
// Asks the worker loop to exit and waits for it without waiting for tasks to finish.
// Safe to call after join() or repeatedly: joining a non-joinable thread would throw.
void myuthread::destory() {
    this->shut_down_ = true;
    if (this->thread_ != nullptr && this->thread_->joinable()) {
        this->thread_->join();
    }
}
void myuthread::remove_uthread(int index) {
if (index >= 0 && index < curr_max_count_) {
delete ctx_items_[index];
ctx_items_[index] = nullptr;
idle_count_;
cv_.notify_all();
}
ucontext* ctx = nullptr;
{
lock_guard lock(mutex_);
if (!ctx_ready_queue_.empty()) {
ctx = ctx_ready_queue_.front();
ctx_ready_queue_.pop();
}
}
if (ctx) {
ctx->ctx_index_ = index;
ctx->make();
ctx_items_[index] = ctx;
--idle_count_;
}
}
// Rounds stack_size up to a whole number of pages (minimum one page). The coroutine
// stack is placed between mprotect'ed guard pages, which must be page-aligned.
int myuthread::get_stack_size(int stack_size) {
    const int page_size = getpagesize();
    if (stack_size < page_size) {
        return page_size;
    }
    int page_count = stack_size / page_size;
    if (stack_size % page_size != 0) {
        ++page_count;  // partial page: round up, never shrink the requested stack
    }
    return page_count * page_size;
}
void myuthread::thread_work(void* pvoid) {
myuthread* mut = (myuthread*)pvoid;
struct epoll_event events[1024];
while (!mut->shut_down_) {
int ret = epoll_wait(mut->epoll_fd_, events, 1024, 4);
if (ret > 0) {//有事件发生
for (int i = 0; i < ret; i ) {
ucontextrevent* ctxex = (ucontextrevent*)events[i].data.ptr;
auto ctx = ctxex->ctx_;
if (ctx->get_timer() && ctxex->fd_ == ctx->get_timer()->get_timerfd()) {
ctxex->revents_ = uthreadtimerout;
}else {
ctxex->revents_ = events[i].events;
}
ctx->resume();
if (ctx->status_ == ucontext::finished) {
mut->remove_uthread(ctx->ctx_index_);
}
}
}else if (ret == -1) {//异常
if (errno == eintr) {
continue;
}
else {
perror("epoll_wait:");
abort();
}
}else {//epoll_wait超时
mut->curr_running_index_;
if (mut->curr_running_index_ >= mut->curr_max_count_) {
mut->curr_running_index_ = -1;
continue;
}
auto ctx = mut->ctx_items_[mut->curr_running_index_];
if (ctx == nullptr || !ctx->register_fd_set.empty()) {
//如果该ctx协程正在等待某个异步事件,则不对他唤醒
continue;
}
ctx->resume();
if (ctx->status_ == ucontext::finished) {
mut->remove_uthread(ctx->ctx_index_);
}
}
}
}
mytimer.cpp
#include "mytimer.h"
#include "ucontext.h"
#include
#include
#include
#include
#include
// Creates a disarmed, non-blocking CLOCK_MONOTONIC timerfd and a zeroed itimerspec.
// Aborts if the timerfd cannot be created.
mytimer::mytimer():timerfd_(-1), timespec_(nullptr)
{
    // TFD_NONBLOCK already sets O_NONBLOCK on the new fd, so no extra fcntl() is needed.
    timerfd_ = timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK);
    if (timerfd_ == -1) {
        abort();
    }
    timespec_ = new struct itimerspec();  // value-initialised: all fields zero
}
// Frees the itimerspec and closes the timerfd created by the constructor.
mytimer::~mytimer()
{
    delete timespec_;  // deleting a null pointer is a no-op, so no guard is required
    close(timerfd_);
}
void mytimer::set_once(int seconds, int millseconds) {
this->set_time(seconds, millseconds, 0, 0);
}
void mytimer::set_cycle(int seconds, int millseconds, int intervalseconds, int intervalmillseconds) {
this->set_time(seconds, millseconds, intervalseconds, intervalmillseconds);
}
int mytimer::get_timerfd() {
return timerfd_;
}
// Arms (or, with an all-zero first expiry, disarms) the timerfd. The first expiry is
// `seconds` s + `millseconds` ms; the repeat interval is intervalseconds/intervalmillseconds
// (zero = one-shot). Aborts if timerfd_settime() rejects the values.
void mytimer::set_time(int seconds, int millseconds, int intervalseconds, int intervalmillseconds) {
    if (timerfd_ == -1 || timespec_ == nullptr)
        return;
    // tv_nsec is in nanoseconds and must stay below 1e9, so whole seconds are carried out
    // of the millisecond field and the remainder is scaled by 1,000,000 (ms -> ns).
    timespec_->it_value.tv_sec = seconds + millseconds / 1000;
    timespec_->it_value.tv_nsec = static_cast<long>(millseconds % 1000) * 1000000L;
    timespec_->it_interval.tv_sec = intervalseconds + intervalmillseconds / 1000;
    timespec_->it_interval.tv_nsec = static_cast<long>(intervalmillseconds % 1000) * 1000000L;
    if (-1 == timerfd_settime(timerfd_, 0, timespec_, nullptr)) {
        abort();
    }
}
void mytimer::stop() {
this->set_time(0, 0, 0, 0);
}
// Nanoseconds until the next expiry (0 when the timer is disarmed), or -1 if the timer
// state cannot be read. The return type is int, so the value saturates at 2147483647
// (about 2.1 s) instead of wrapping around.
int mytimer::get_time() {
    struct itimerspec t{};
    if (timerfd_gettime(timerfd_, &t) == -1) {
        return -1;
    }
    const long long ns = static_cast<long long>(t.it_value.tv_sec) * 1000000000LL
                         + t.it_value.tv_nsec;
    const long long int_max = 2147483647LL;  // INT_MAX for the 32-bit int of Linux ABIs
    return static_cast<int>(ns > int_max ? int_max : ns);
}
这就是我简单编写的全部内容,欢迎指正~