菜鸟笔记
提升您的技术认知

C++ 利用Linux函数makecontext等实现简单协程-ag真人官方网

我是一个编程新手,最近了解到协程这个概念。协程可以理解为用户级线程,在用户空间实现调度:在处理异步 I/O 时,可以在子程序中让出 CPU 交给其他协程,等事件完成后再切换回该子程序。当然用回调也可以实现,但是使用协程会使程序“看起来”是顺序执行的。

我利用 Linux 系统函数 getcontext、makecontext、swapcontext 来实现协程之间的切换。

getcontext(ucontext_t* ucp) 把当前上下文保存到 ucp 中(用于初始化上下文);makecontext(ucontext_t* ucp, void (*func)(void), int argc, ...) 绑定切换到该上下文时要执行的函数,argc 是其后传给该函数的 int 参数个数;swapcontext(ucontext_t* oucp, const ucontext_t* ucp) 把当前上下文保存到 oucp,并切换到 ucp。

    typedef struct ucontext {
               struct ucontext *uc_link;          // context to switch to once the current context finishes executing
               sigset_t         uc_sigmask;
               stack_t          uc_stack;         // stack used by the current context
               mcontext_t       uc_mcontext;
               …
           } ucontext_t;

我利用这些实现了一个简单的协程:在一个线程中实现多个协程之间的切换,使用 epoll 等待 I/O 事件或定时器事件,事件发生后唤醒对应的协程;当没有事件发生(epoll_wait 超时)时,则按顺序轮询切换到各个协程。下面是线程工作函数:

void	myuthread::thread_work(void* pvoid) {
	myuthread* mut = (myuthread*)pvoid;
	struct epoll_event events[1024];
	while (!mut->shut_down_) {
		int ret = epoll_wait(mut->epoll_fd_, events, 1024, 4);
		if (ret > 0) {//有事件发生
			for (int i = 0; i < ret; i  ) {
				ucontextrevent* ctxex = (ucontextrevent*)events[i].data.ptr;
				auto ctx = ctxex->ctx_;
				if (ctx->get_timer() && ctxex->fd_ == ctx->get_timer()->get_timerfd()) {
					ctxex->revents_ = uthreadtimerout;
				}else {
					ctxex->revents_ = events[i].events;
				}
				ctx->resume();
				if (ctx->status_ == ucontext::finished) {
					mut->remove_uthread(ctx->ctx_index_);
				}
			}
		}else if (ret == -1) {//异常
			if (errno == eintr) {
				continue;
			}
			else {
				perror("epoll_wait:");
				abort();
			}
		}else {//epoll_wait超时
			  mut->curr_running_index_;
			if (mut->curr_running_index_ >= mut->curr_max_count_) {
				mut->curr_running_index_ = -1;
				continue;
			}
			auto ctx = mut->ctx_items_[mut->curr_running_index_];
			if (ctx == nullptr || !ctx->register_fd_set.empty()) {
				//如果该ctx协程正在等待某个异步事件,则不对他唤醒
				continue;
			}
			ctx->resume();
			if (ctx->status_ == ucontext::finished) {
				mut->remove_uthread(ctx->ctx_index_);
			}
		}
	}
}

使用:

#include 
#include 
#include 
#include "myuthread.h"
#include "mytimer.h"
using namespace std;
using namespace placeholders;
/**
 * Demo coroutine: arms a one-shot 2 s timer, registers its fd for EPOLLIN,
 * yields, and reports whether it was resumed because the timer expired.
 *
 * @param ctx the coroutine this function runs on.
 * @param num value printed on entry.
 */
void	test1(ucontext*& ctx, int num) {
	cout << "enter test1:" << num << endl;
	mytimer timer;
	timer.set_once(2, 0);	// one-shot timer, 2 seconds
	// `revents` lives on this coroutine's stack; the scheduler writes the
	// wake-up reason into it before resuming us.
	ucontextrevent revents(ctx,timer.get_timerfd());
	ctx->register_event(timer.get_timerfd(), EPOLLIN, &revents);	// register the timer event
	ctx->attach_timer(&timer);
	cout << "start yield" << endl;
	ctx->yield();	// give up the CPU until the timer event arrives
	if (revents.revents_ == uthreadtimerout) {
		cout << "uthreadtimerout" << endl;
	}
	// Unregister before `timer` and `revents` go out of scope.
	ctx->remove_event(timer.get_timerfd());
	cout << "finished test1" << endl;
}
// Demo coroutine with no I/O: prints, yields once to the scheduler, and
// prints again after being resumed.
void	test2(ucontext*& ctx) {
	std::cout << "test2 yield" << std::endl;
	ctx->yield();	// control returns here on the next resume()
	std::cout << "finished test2" << std::endl;
}
// Completion callback passed to add_task() together with test1.
void	finished() {
	std::cout << "test1 callback" << std::endl;
}
// Demo driver: a scheduler with 2 coroutine slots receives 11 tasks, so most
// of them wait in the ready queue until a slot frees up.
int main()
{
	myuthread mut(2, 8192);	// arguments: max number of coroutines, coroutine stack size
	mut.add_task(bind(test1, _1, 99999),finished);	// arguments: work function, completion callback
	for (int i = 0; i < 10; i++) {
		mut.add_task(test2);
	}
	mut.join();	// block until every task has finished
	return 0;
}

下面贴出源代码:我还是个小白,把代码贴出来,如果有什么错误希望大家评论告诉我,谢谢(ㅎ-ㅎ;)

ucontext.h

#include 
#include 
#include 
#include 
using namespace std;
/*
Numeric values of the epoll event bits that can be stored in
ucontextrevent::revents_:
EPOLLIN:1
EPOLLOUT:4
EPOLLRDHUP:8192
EPOLLPRI:2
EPOLLERR:8
EPOLLHUP:16
uthreadtimerout is the extra value reported when the ready fd is the
coroutine's attached timer.
NOTE(review): 3 is also EPOLLIN|EPOLLPRI, so a non-timer fd reporting both
bits is indistinguishable from a timeout - confirm whether that matters.
*/
#define uthreadtimerout 3
class ucontext;
// Task body: receives (by reference) the pointer to the coroutine it runs on.
typedef function<void(ucontext*&)>      ucontextfunc;
// Completion callback invoked after the task body returns.
typedef function<void()>                callback;
class mytimer;
struct ucontextrevent {
	ucontext*    ctx_;
	int          fd_;
	int          revents_;
	ucontextrevent(ucontext* ctx, int fd,int revents = 0) :
		ctx_(ctx), fd_(fd), revents_(revents){}
};
class ucontext {
public:
	ucontext(int index, ucontext_t* main_ctx, int stack_size, int epoll_fd);
	~ucontext();
	void	set_func(ucontextfunc func, callback callback = 0);
	void	make();
	void	resume();
	void	yield();
	void	attach_timer(mytimer*);
	void	register_event(int fd, int events, ucontextrevent* revents);
	void	remove_event(int fd);
	mytimer* get_timer();
private:
	static void	work_func(uint32_t low32, uint32_t high32);
public:
        int             ctx_index_;		//当前对象在myuthread对象的数组中保存的索引
	ucontext_t*     main_ctx_;		//线程主逻辑上下文
	ucontext_t*     ctx_;			//当前上下文
	char*           raw_stack_;		//上下文使用栈空间
	char*           stack_;			//栈空间(安全保护)
	ucontextfunc    func_;			//用户任务
	callback        callback_;		//任务回调
	int             stack_size_;	//栈大小
	enum { ready = 0, running, suspend, finished };
	int             status_;		//当前协程状态
	int             epoll_fd_;		//epoll fd
	mytimer*        mytimer_;		
	unordered_set    register_fd_set;
};

myuthread.h

#include 
#include 
#include 
#include 
#include 
#include 
#include "sys/epoll.h"
using namespace std;
#include "ucontext.h"
class myuthread {
public:
	myuthread(int max_uthread_count, int stack_size);
	~myuthread();
	void	add_task(ucontextfunc func, callback callback = 0);
	void	join();
	void	destory();
private:
	void	remove_uthread(int index);
	int		get_stack_size(int stack_size);
	static void thread_work(void* pvoid);
private:
	int					max_uthread_count_;		//最大协程数量
	int					stack_size_;			//栈大小
	volatile int		curr_running_index_;	//当前执行协程索引
	volatile int		curr_max_count_;		//当前最大协程数量
	volatile int		idle_count_;			//可用协程数量
	bool				shut_down_;				//是否退出
	vector	ctx_items_;				//调度的协程列表
	queue	ctx_ready_queue_;		//就绪队列,等待其它协程退出后,会被添加到ctx_items_中
	thread*				thread_;				//线程
	ucontext_t			main_ctx_;				//主上下文
	mutex				mutex_;					//锁queue
	mutex				mutex_join_;			
	condition_variable	cv_;					//使用锁和条件变量,阻塞等待协程全部执行完毕
	int					epoll_fd_;				//epoll fd
};

mytimer.h

#pragma once
#include 
#include 
using namespace std;
class ucontext;
/**
 * Thin wrapper around a timerfd: owns the descriptor and the itimerspec used
 * to arm it.
 */
class mytimer
{
public:
	mytimer();
	~mytimer();
	// Owns a file descriptor and a heap allocation; a copy would close/free
	// them twice, so copying is disabled.
	mytimer(const mytimer&) = delete;
	mytimer& operator=(const mytimer&) = delete;
	// Arms a one-shot timer.
	void    set_once(int seconds, int millseconds);
	// Arms a timer with an initial expiry and a repeat interval.
	void	set_cycle(int seconds, int millseconds, int intervalseconds, int intervalmillseconds);
	int		get_timerfd();
	void	stop();
	// Returns how many nanoseconds remain until the timer expires.
	int 	get_time();
private:
	void	set_time(int seconds, int millseconds, int intervalseconds, int intervalmillseconds);
private:
	int					timerfd_;
	struct itimerspec*	timespec_;
};

ucontext.cpp

#include "ucontext.h"
#include 
#include 
#include 
#include 
#include "mytimer.h"
/**
 * Creates a coroutine context with a private stack.
 *
 * The stack is an anonymous mapping of stack_size bytes plus one page on each
 * side; both extra pages are made inaccessible so running off either end of
 * the stack faults instead of corrupting neighbouring memory.
 *
 * @param index      slot index inside the owning scheduler (-1 if not placed yet).
 * @param main_ctx   scheduler context to return to when the task finishes.
 * @param stack_size usable stack size in bytes (expected to be page-aligned).
 * @param epoll_fd   epoll instance used by register_event()/remove_event().
 */
ucontext::ucontext(int index, ucontext_t* main_ctx, int stack_size, int epoll_fd)
	:ctx_index_(index), main_ctx_(main_ctx), ctx_(nullptr),
	raw_stack_(nullptr), stack_(nullptr), func_(nullptr), callback_(nullptr),
	stack_size_(stack_size),
	status_(ready), epoll_fd_(epoll_fd),
	mytimer_(nullptr){
	// Create the coroutine's private stack.
	const auto page_size = getpagesize();
	void* mapping = mmap(nullptr, stack_size_ + page_size * 2,
		PROT_READ | PROT_WRITE, MAP_ANONYMOUS | MAP_PRIVATE, -1, 0);
	// mmap reports failure with MAP_FAILED, not with a null pointer.
	assert(mapping != MAP_FAILED);
	raw_stack_ = static_cast<char*>(mapping);
	stack_ = raw_stack_ + page_size;
	// The calls are kept outside assert() so the guard pages are still
	// installed when assertions are compiled out (NDEBUG).
	const int low_guard = mprotect(raw_stack_, page_size, PROT_NONE);
	const int high_guard = mprotect(raw_stack_ + stack_size_ + page_size, page_size, PROT_NONE);
	assert(low_guard == 0 && high_guard == 0);
	(void)low_guard;
	(void)high_guard;
	// Zero the context, then let getcontext() initialise it BEFORE the stack
	// and successor are filled in, which is the order makecontext() expects.
	ctx_ = new ucontext_t();
	const int got = getcontext(ctx_);
	assert(got == 0);
	(void)got;
	ctx_->uc_link = main_ctx;
	ctx_->uc_stack.ss_sp = stack_;
	ctx_->uc_stack.ss_size = stack_size_;
}
// Releases the context object and the whole stack mapping (usable stack plus
// the two guard pages), i.e. the same length that was mapped by the constructor.
ucontext::~ucontext() {
	delete ctx_;
	munmap(raw_stack_, stack_size_ + getpagesize() * 2);
}
// Remembers the task body and the callback that work_func() will invoke.
void	ucontext::set_func(ucontextfunc func, callback callback) {
	this->func_ = func;
	this->callback_ = callback;
}
// Binds work_func as the entry point of this context. makecontext() only
// forwards int-sized arguments, so `this` is passed as two 32-bit halves.
void	ucontext::make() {
	// Widen to 64 bits first: shifting a 32-bit uintptr_t by 32 would be
	// undefined on targets with 32-bit pointers.
	const uint64_t self = reinterpret_cast<uintptr_t>(this);
	makecontext(ctx_, reinterpret_cast<void(*)(void)>(&ucontext::work_func), 2,
		static_cast<uint32_t>(self), static_cast<uint32_t>(self >> 32));
}
void	ucontext::resume() {
	status_ = running;
	swapcontext(main_ctx_, ctx_);
}
void	ucontext::yield() {
	status_ = suspend;
	swapcontext(ctx_, main_ctx_);
}
// Records (without taking ownership of) the timer associated with this
// coroutine; passing nullptr detaches it.
void	ucontext::attach_timer(mytimer* timer) {
	this->mytimer_ = timer;
}
// Returns the timer set by attach_timer(), or nullptr when none is attached.
mytimer* ucontext::get_timer() {
	return this->mytimer_;
}
/**
 * Registers `fd` with the epoll instance on behalf of this coroutine.
 *
 * @param fd      file descriptor to watch.
 * @param events  epoll event mask to wait for.
 * @param revents record handed back by epoll (through data.ptr) when the fd
 *                becomes ready; it must stay alive until remove_event(fd).
 */
void	ucontext::register_event(int fd, int events, ucontextrevent* revents) {
	epoll_event ev{};
	ev.events = events;
	// epoll_event::data is a union, so only the pointer is stored; the fd is
	// recoverable through revents->fd_.
	ev.data.ptr = revents;
	// Track the fd only when the registration succeeded: the scheduler never
	// polls a coroutine whose register_fd_set is non-empty, so recording an fd
	// that epoll does not know about would leave the coroutine asleep forever.
	if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &ev) == 0) {
		register_fd_set.insert(fd);
	}
}
/**
 * Removes `fd` from the epoll instance and from this coroutine's set of
 * pending registrations. If `fd` belongs to the attached timer, the timer is
 * detached as well.
 */
void	ucontext::remove_event(int fd) {
	epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr);
	if (mytimer_ && mytimer_->get_timerfd() == fd) {
		this->attach_timer(nullptr);
	}
	register_fd_set.erase(fd);
}
void	ucontext::work_func(uint32_t low32, uint32_t high32) {
	uintptr_t ptr = (uintptr_t)low32 | ((uintptr_t)high32 << 32);
	ucontext * uc = (ucontext*)ptr;
	if (uc->func_) {
		uc->func_(uc);
		if (uc->callback_) {
			uc->callback_();
		}
	}
	uc->status_ = finished;
}

myuthread.cpp

#include "myuthread.h"
#include "mytimer.h"
#include 
#include 
#include 
#include 
#include 
#include 
using namespace std;
/**
 * Creates the scheduler and immediately starts its worker thread.
 *
 * @param max_uthread_count number of coroutine slots.
 * @param stack_size        requested coroutine stack size; rounded up to a
 *                          whole number of pages by get_stack_size().
 */
myuthread::myuthread(int max_uthread_count, int stack_size) :
	max_uthread_count_(max_uthread_count), 
	stack_size_(get_stack_size(stack_size)),
	curr_running_index_(-1), 
	curr_max_count_(0), 
	idle_count_(max_uthread_count),
	shut_down_(false), thread_(nullptr), epoll_fd_(0)
{
	memset(&main_ctx_, 0, sizeof(main_ctx_));
	ctx_items_.resize(max_uthread_count_, nullptr);
	epoll_fd_ = epoll_create(100);
	if (epoll_fd_ == -1) {
		// Without an epoll instance the worker loop cannot run at all; fail
		// the same way thread_work() does on an epoll_wait() error.
		perror("epoll_create:");
		abort();
	}
	// Started last so the worker only ever sees a fully initialised object.
	thread_ = new thread(&myuthread::thread_work, static_cast<void*>(this));
}
/**
 * Stops the worker thread (if join()/destory() has not already done so) and
 * releases every coroutine, queued task and the epoll descriptor.
 */
myuthread::~myuthread() {
	// Destroying a std::thread that is still joinable calls std::terminate,
	// so make sure the worker has stopped first.
	shut_down_ = true;
	if (thread_ != nullptr && thread_->joinable()) {
		thread_->join();
	}
	delete thread_;
	// The containers hold raw owning pointers; clearing them alone would leak
	// the contexts and their stacks. Deleting a null slot is a no-op.
	for (ucontext* ctx : ctx_items_) {
		delete ctx;
	}
	ctx_items_.clear();
	while (!ctx_ready_queue_.empty()) {
		delete ctx_ready_queue_.front();
		ctx_ready_queue_.pop();
	}
	if (epoll_fd_ >= 0) {
		close(epoll_fd_);
	}
}
/**
 * Schedules a new task.
 *
 * The task gets a slot in ctx_items_ right away when one is available
 * (either a never-used slot or one freed by a finished coroutine); otherwise
 * it is parked in the ready queue until remove_uthread() frees a slot.
 *
 * NOTE(review): this runs on the caller's thread while the worker thread
 * reads and writes the same slots and counters without a common lock -
 * confirm whether callers are expected to add tasks concurrently with
 * running coroutines.
 *
 * @param func     task body.
 * @param callback optional callback run after `func` returns.
 */
void	myuthread::add_task(ucontextfunc func, callback callback) {
	auto ctx = new ucontext(-1, &main_ctx_, stack_size_,epoll_fd_);
	ctx->set_func(func, callback);

	int index = -1;
	const bool grows = curr_max_count_ < max_uthread_count_;
	if (grows) {
		// There are still never-used slots: take the next one.
		index = curr_max_count_;
	}
	else if (idle_count_ > 0) {
		// All slots have been used at least once; search from both ends for
		// one that has been freed.
		for (int i = 0, j = curr_max_count_ - 1; i <= j; i++, j--) {
			if (ctx_items_[i] == nullptr) {
				index = i;
				break;
			}
			if (ctx_items_[j] == nullptr) {
				index = j;
				break;
			}
		}
	}

	if (index < 0) {
		// No slot available (or none found): wait in the ready queue instead
		// of indexing ctx_items_ with -1.
		lock_guard<mutex> lock(mutex_);
		ctx_ready_queue_.push(ctx);
		return;
	}

	ctx->ctx_index_ = index;
	ctx->make();
	ctx_items_[index] = ctx;
	if (grows) {
		// Publish the slot before extending the range the worker iterates.
		++curr_max_count_;
	}
	--idle_count_;
}
void	myuthread::join() {
	while (idle_count_ < max_uthread_count_ || !ctx_ready_queue_.empty()) {
		unique_lock lock(mutex_join_);
		cv_.wait(lock);
	}
	this->shut_down_ = true;
	this->thread_->join();
}
// Requests the worker thread to exit and waits for it, without waiting for
// pending tasks to finish.
void	myuthread::destory() {
	this->shut_down_ = true;
	// Calling join() on a thread that was already joined (for example after
	// myuthread::join()) throws std::system_error, so check first.
	if (this->thread_->joinable()) {
		this->thread_->join();
	}
}
void	myuthread::remove_uthread(int index) {
	if (index >= 0 && index < curr_max_count_) {
		delete ctx_items_[index];
		ctx_items_[index] = nullptr;
		  idle_count_;
		cv_.notify_all();
	}
	ucontext* ctx = nullptr;
	{
		lock_guard lock(mutex_);
		if (!ctx_ready_queue_.empty()) {
			ctx = ctx_ready_queue_.front();
			ctx_ready_queue_.pop();
		}
	}
	if (ctx) {
		ctx->ctx_index_ = index;
		ctx->make();
		ctx_items_[index] = ctx;
		--idle_count_;
	}
}
/**
 * Rounds a requested stack size up to a whole number of pages.
 *
 * @param stack_size requested size in bytes.
 * @return one page when the request is smaller than a page, otherwise the
 *         smallest multiple of the page size that is >= stack_size.
 */
int		myuthread::get_stack_size(int stack_size) {
	const int page_size = getpagesize();
	if (stack_size < page_size) {
		return page_size;
	}
	int page_count = stack_size / page_size;
	if ((stack_size % page_size) > 0) {
		// A partial page still needs a full page.
		++page_count;
	}
	return page_count * page_size;
}
void	myuthread::thread_work(void* pvoid) {
	myuthread* mut = (myuthread*)pvoid;
	struct epoll_event events[1024];
	while (!mut->shut_down_) {
		int ret = epoll_wait(mut->epoll_fd_, events, 1024, 4);
		if (ret > 0) {//有事件发生
			for (int i = 0; i < ret; i  ) {
				ucontextrevent* ctxex = (ucontextrevent*)events[i].data.ptr;
				auto ctx = ctxex->ctx_;
				if (ctx->get_timer() && ctxex->fd_ == ctx->get_timer()->get_timerfd()) {
					ctxex->revents_ = uthreadtimerout;
				}else {
					ctxex->revents_ = events[i].events;
				}
				ctx->resume();
				if (ctx->status_ == ucontext::finished) {
					mut->remove_uthread(ctx->ctx_index_);
				}
			}
		}else if (ret == -1) {//异常
			if (errno == eintr) {
				continue;
			}
			else {
				perror("epoll_wait:");
				abort();
			}
		}else {//epoll_wait超时
			  mut->curr_running_index_;
			if (mut->curr_running_index_ >= mut->curr_max_count_) {
				mut->curr_running_index_ = -1;
				continue;
			}
			auto ctx = mut->ctx_items_[mut->curr_running_index_];
			if (ctx == nullptr || !ctx->register_fd_set.empty()) {
				//如果该ctx协程正在等待某个异步事件,则不对他唤醒
				continue;
			}
			ctx->resume();
			if (ctx->status_ == ucontext::finished) {
				mut->remove_uthread(ctx->ctx_index_);
			}
		}
	}
}

mytimer.cpp

#include "mytimer.h"
#include "ucontext.h"
#include 
#include 
#include 
#include 
#include 
/**
 * Creates a non-blocking timerfd on the monotonic clock and a zeroed
 * itimerspec to arm it with. Aborts the process if the timerfd cannot be
 * created.
 */
mytimer::mytimer():timerfd_(-1), timespec_(nullptr)
{
	// TFD_NONBLOCK already puts the descriptor in non-blocking mode, so no
	// separate fcntl(F_SETFL, O_NONBLOCK) round trip is needed.
	timerfd_ = timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK);
	if (timerfd_ == -1) {
		abort();
	}
	// Value-initialisation zeroes every field.
	timespec_ = new itimerspec();
}
// Frees the itimerspec and closes the timer descriptor.
mytimer::~mytimer()
{
	delete timespec_;	// deleting a null pointer is a no-op
	close(timerfd_);
}
void    mytimer::set_once(int seconds, int millseconds) {
	this->set_time(seconds, millseconds, 0, 0);
}
void	mytimer::set_cycle(int seconds, int millseconds, int intervalseconds, int intervalmillseconds) {
	this->set_time(seconds, millseconds, intervalseconds, intervalmillseconds);
}
int		mytimer::get_timerfd() {
	return timerfd_;
}
/**
 * Programs the timerfd.
 *
 * @param seconds             initial expiry, whole seconds.
 * @param millseconds         initial expiry, additional milliseconds.
 * @param intervalseconds     repeat interval, whole seconds.
 * @param intervalmillseconds repeat interval, additional milliseconds.
 *
 * Does nothing when the timer was not created. Aborts the process when
 * timerfd_settime() fails.
 */
void	mytimer::set_time(int seconds, int millseconds, int intervalseconds, int intervalmillseconds) {
	if (timerfd_ == -1 || timespec_ == nullptr)
		return;
	// 1 ms = 1,000,000 ns. Whole seconds contained in the millisecond
	// arguments are carried into tv_sec so tv_nsec stays below one second,
	// which timerfd_settime() requires.
	const long nanos_per_milli = 1000000L;
	timespec_->it_value.tv_sec = seconds + millseconds / 1000;
	timespec_->it_value.tv_nsec = (millseconds % 1000) * nanos_per_milli;
	timespec_->it_interval.tv_sec = intervalseconds + intervalmillseconds / 1000;
	timespec_->it_interval.tv_nsec = (intervalmillseconds % 1000) * nanos_per_milli;
	if (-1 == timerfd_settime(timerfd_, 0, timespec_, nullptr)) {
		abort();
	}
}
void	mytimer::stop() {
	this->set_time(0, 0, 0, 0);
}
/**
 * Returns how many nanoseconds remain until the timer expires.
 *
 * @return the remaining time in nanoseconds, clamped to the largest value a
 *         32-bit int can hold (about 2.1 s), or -1 when the timer state could
 *         not be read.
 */
int 	mytimer::get_time() {
	struct itimerspec t = {};
	if (timerfd_gettime(timerfd_, &t) == -1) {
		return -1;
	}
	const long long nanos_per_second = 1000000000LL;
	const long long int_max = 2147483647LL;	// int is 32 bits on Linux, where timerfd exists
	// Computed in 64 bits: the product overflows an int after ~2 seconds.
	const long long remaining = t.it_value.tv_sec * nanos_per_second + t.it_value.tv_nsec;
	if (remaining > int_max) {
		return static_cast<int>(int_max);
	}
	return static_cast<int>(remaining);
}

这就是我简单编写的全部内容,欢迎指正~

网站地图