跳转到内容
View in the app

A better way to browse. Learn more.

彼岸论坛

A full-screen app on your home screen with push notifications, badges and more.

To install this app on iOS and iPadOS
  1. Tap the Share icon in Safari
  2. Scroll the menu and tap Add to Home Screen.
  3. Tap Add in the top-right corner.
To install this app on Android
  1. Tap the 3-dot menu (⋮) in the top-right corner of the browser.
  2. Tap Add to Home screen or Install app.
  3. Confirm by tapping Install.
欢迎抵达彼岸 彼岸花开 此处谁在 -彼岸论坛

[C++] 请教大家一个 C++线程池的问题

发表于

最近在找一个简单的 C++11 线程池实现,发现网上有很多相关的代码,在 CSDN 网上看到一个比较简洁的。但是总感觉是不是实现错了。

  1. Any 类 noncopyable 的,仅仅支持移动语义,
  2. Result 类使用了 Any 实例作为成员变量,那么 Result 类应该也是 noncopyable 的,
  3. Result SubmitTask(std::shared_ptr<Task> taskPtr);直接使用了复制语义,应该是有问题吧,可是代码能够被 vs2022 正常编译。

threadpool.h

#pragma once
#include <vector>
#include <cstdint>
#include <queue>
#include <memory>
#include <atomic>
#include <mutex>
#include <thread>
#include <condition_variable>
#include <functional>
#include <sstream>
#include <unordered_map>


// Any 类型:可以接收任意数据的类型
// 任意其他类型 template
// 能让一个类型指向其他类型,基类指针可以指向子类
class Any
{
public:
	Any() = default;
	~Any() = default;
	Any(const Any&) = delete;
	Any& operator=(const Any&) = delete;
	Any(Any&&) = default;
	Any& operator=(Any&&) = default;

	template<typename T>
	Any(T data) : m_base(std::make_unique<Derive<T>>(data)) {}

	template<typename T>
	T cast_()
	{
		Derive<T>* pd = dynamic_cast<Derive<T>*>(m_base.get());

		if (pd == nullptr) {
			throw "type is unmath!!";
		}

		return pd->m_data;
	}

private:
	// 基类
	class Base
	{
	public:
		virtual ~Base() = default;
	};

	// 派生类
	template<typename T>
	class Derive : public Base
	{
	public:
		Derive(T data) : m_data(data) {}
	public:
		T m_data;
	};

private:
	std::unique_ptr<Base> m_base;
};


// 实现一个信号量类
class Semaphore
{
public:
	Semaphore(int limit = 0) : m_resLimit(limit)
	{}

	~Semaphore() = default;

	// 获取一个信号量资源
	void wait()
	{
		std::unique_lock<std::mutex> lock(m_mtx);
		// 如果没有资源,阻塞线程
		while (m_resLimit < 1) {
			m_cond.wait(lock);
		}

		m_resLimit--;
	}

	// 增加一个信号量资源
	void post()
	{
		std::unique_lock<std::mutex> lock(m_mtx);
		m_resLimit++;
		m_cond.notify_all();

	}
private:
	int m_resLimit;  // 资源量
	std::mutex m_mtx;
	std::condition_variable m_cond;
};


// Task 类型前置声明
class Task;

// 实现接收提交到线程池的 task 任务执行完成后的返回值类型
class Result
{
public:
	Result(std::shared_ptr<Task> task, bool isValid = true);
	~Result() = default;

	// setVal
	void setVal(Any result);

	// get 方法,用户调用这个方法获取 task 的返回值
	Any get();
private:
	Any m_any;
	Semaphore m_sem;
	std::shared_ptr<Task> m_task;
	std::atomic_bool m_isValid;
};


// 任务抽象基类
class Task
{
public:
	void exec();
	void setResult(Result* res);
	virtual Any run() = 0;

private:
	Result* m_result{ nullptr };  // 不要用智能指针,task 含有 Result  Result 含有 task ,可能导致问题
};

class MyTask : public Task
{
public:
	MyTask(int start, int end) : m_start(start), m_end(end) {}

	Any run()
	{
		std::ostringstream ostr;
		ostr << std::this_thread::get_id();
		printf("thead %s, task start \n", ostr.str().c_str());

		uint64_t sum = 0;

		for (int i = m_start; i <= m_end; i++) {
			sum += i;
		}

		printf("sum %llu\n", sum);
		std::this_thread::sleep_for(std::chrono::seconds(2));
		printf("thread %s, task finish \n", ostr.str().c_str());

		return sum;
	}

private:
	int m_start;
	int m_end;
};

enum ThreadPoolMode
{
	MODE_FIXED,  // 固定数量的线程
	MODE_CACHED,  // 线程数量可以动态增长
};

class Thread
{
public:
	using ThreadFunc = std::function<void(int)>;

	Thread(ThreadFunc func);
	~Thread();

	void Start();
	int GetId() { return m_threadId; }
private:
	ThreadFunc m_func;
	static int generateId;
	int m_threadId;
};


class ThreadPool
{
public:
	ThreadPool();
	~ThreadPool();

	// 设置线程池工作模式
	void SetMode(ThreadPoolMode mode);

	// 设置任务数量上限
	void SetTaskQueMaxThreshold(int value);

	// 给线程池提交任务
	Result SubmitTask(std::shared_ptr<Task> taskPtr);

	// 开启线程池
	void Start(int initThreadSize = std::thread::hardware_concurrency());

private:
	ThreadPool(const ThreadPool&) = delete;
	ThreadPool& operator=(const ThreadPool&) = delete;

	// 定义线程函数
	void ThreadFunc(int threadId);
	bool CheckRunningState() const;

private:
	std::unordered_map<int, std::unique_ptr<Thread>> m_threadMap;  // 线程列表
	int m_initThreadSize;  // 初始的线程数量
	std::atomic_int m_curThreadSize;  // 当前线程数量

	std::queue<std::shared_ptr<Task>> m_taskQue;  // 任务队列
	std::atomic_int m_taskSize;  // 任务的数量
	int m_taskQueMaxThreshold;  // 任务队列的数量上限

	std::mutex m_taskQueMtx;  // 保证任务队列的线程安全
	std::condition_variable m_taskQueNotFullCv;  // 表示任务队列不满
	std::condition_variable m_taskQueNotEmptyCv;  // 表示任务队列不空
	std::condition_variable m_exitCv;  // 退出线程池

	ThreadPoolMode m_poolMode;  // 当前线程池的工作模式
	std::atomic_bool m_isPoolRuning;  // 当前线程工作状态
};

threadpool.cpp

#include "threadpool.h"
#include <functional>
#include <iostream>

constexpr int TASK_MAX_THRESHOLD = 1024;

ThreadPool::ThreadPool() : m_initThreadSize(4), m_taskSize(0),
m_taskQueMaxThreshold(TASK_MAX_THRESHOLD),
m_poolMode(ThreadPoolMode::MODE_FIXED)
{
}

ThreadPool::~ThreadPool()
{
	m_isPoolRuning = false;
	std::unique_lock<std::mutex> lock(m_taskQueMtx);

	// 线程 要么在阻塞中 要么在工作中
	while (m_threadMap.size() > 0) {
		m_taskQueNotEmptyCv.notify_all();  // 唤醒等待的工作线程
		m_exitCv.wait(lock);
	}
}

void ThreadPool::SetMode(ThreadPoolMode mode)
{
	if (m_isPoolRuning) { return; }  // 线程池启动后,不允许设置线程池一些参数

	m_poolMode = mode;
}

void ThreadPool::SetTaskQueMaxThreshold(int value)
{
	if (m_isPoolRuning) { return; }

	m_taskQueMaxThreshold = value;
}

Result ThreadPool::SubmitTask(std::shared_ptr<Task> taskPtr)
{
	// 获取锁
	std::unique_lock<std::mutex> lock(m_taskQueMtx);

	// 线程通信,检查任务队列是否有空余
	while (m_taskQue.size() >= m_taskQueMaxThreshold) {

		// 用于提交任务,不能阻塞太长时间,如果超过 1s ,给用户返回提交失败
		if (m_taskQueNotFullCv.wait_for(lock, std::chrono::seconds(1)) == std::cv_status::timeout) {
			return Result(taskPtr, false);
		}
	}

	// 如果有空余,把任务提交到任务队列中
	m_taskQue.emplace(taskPtr);
	m_taskSize++;

	// 因为新放了任务,任务队列肯定不为空了,在 m_taskQueNotEmptyCv 进行通知,赶快分配线程执行这个任务
	m_taskQueNotEmptyCv.notify_all();

	return Result(taskPtr);
}

void ThreadPool::Start(int initThreadSize)
{
	m_initThreadSize = initThreadSize;
	m_curThreadSize = initThreadSize;
    m_isPoolRuning = true;

	// 创建线程对象
	for (int i = 0; i < m_initThreadSize; i++) {
		auto ptr = std::make_unique<Thread>(std::bind(&ThreadPool::ThreadFunc, this, std::placeholders::_1));
		int threadId = ptr->GetId();
		m_threadMap.emplace(threadId, std::move(ptr));
	}

	// 启动所有线程
	for (auto iter = m_threadMap.cbegin(); iter != m_threadMap.end(); iter++) {
		iter->second->Start();
	}
}

void ThreadPool::ThreadFunc(int threadId)
{
	while (true) {

		// 获取锁
		std::unique_lock<std::mutex> lock(m_taskQueMtx);

		std::ostringstream ostr;
		ostr << std::this_thread::get_id();
		printf("thead %s, To Get task \n", ostr.str().c_str());

		// 判断任务队列是否为空
		while (m_taskQue.empty()) {
			if (!m_isPoolRuning) {
				m_threadMap.erase(threadId);
				m_exitCv.notify_all();

				printf("deconstructor thread exit, id = %d\n", threadId);
				return;
			}
            
			m_taskQueNotEmptyCv.wait(lock);

		}

		printf("thead %s, Getted task \n", ostr.str().c_str());
		// 不为空,获取任务
		auto taskPtr = m_taskQue.front();  // front()返回引用,auto 忽略引用属性,正好满足需要
		m_taskQue.pop();
		m_taskSize--;

		lock.unlock();  // 释放锁;

		// 如果任务队列还有任务,通知其他线程执行任务
		if (m_taskQue.size() > 0) {
			m_taskQueNotEmptyCv.notify_all();
		}

		// 通知队列已经不满
		m_taskQueNotFullCv.notify_all();

		taskPtr->exec();

		if (!m_isPoolRuning) {
			m_threadMap.erase(threadId);
			m_exitCv.notify_all();

			printf("deconstructor thread exit, id = %d\n", threadId);
			return;
		}

	}
}

bool ThreadPool::CheckRunningState() const
{
	if (m_isPoolRuning) {
		return true;
	}

	return false;
}

// 线程方法
int Thread::generateId = 0;

Thread::Thread(ThreadFunc func) : m_func(func),
								m_threadId(generateId++)
{
}

Thread::~Thread()
{
}

void Thread::Start()
{
	std::thread t(m_func, m_threadId);
	t.detach();
}

Result::Result(std::shared_ptr<Task> task, bool isValid) : m_task(task), m_isValid(isValid)
{
	m_task->setResult(this);
}

void Result::setVal(Any result)
{
	m_any = std::move(result);
	m_sem.post();  // 通知已经获得结果
}

Any Result::get()
{
	if (!m_isValid) {
		return "";
	}

	m_sem.wait();  // 等待结果
	return std::move(m_any);
}


void Task::exec()
{
	if (m_result != nullptr) {
		Any result = run();  // 这里发生多态调用

		m_result->setVal(std::move(result));
	}
}

void Task::setResult(Result* res)
{
	m_result = res;
}

main.cpp

#include "threadpool.h"

#include <chrono>
#include <iostream>

using std::cout;
using std::endl;


int main(int argc, char* argv[])
{
	{
		ThreadPool pool;
		pool.Start(4);

		Result res1 = pool.SubmitTask(std::make_shared<MyTask>(1, 100000000));
		Result res2 = pool.SubmitTask(std::make_shared<MyTask>(100000001, 200000000));
		Result res3 = pool.SubmitTask(std::make_shared<MyTask>(200000001, 300000000));

		//uint64_t sum1 = res1.get().cast_<uint64_t>();
		//uint64_t sum2 = res2.get().cast_<uint64_t>();
		//uint64_t sum3 = res3.get().cast_<uint64_t>();

		//cout << (sum1 + sum2 + sum3) << endl;
	}

	cout << "main over" << endl;

	getchar();
	return 0;
}

Featured Replies

No posts to show

创建帐户或登录来提出意见

Configure browser push notifications

Chrome (Android)
  1. Tap the lock icon next to the address bar.
  2. Tap Permissions → Notifications.
  3. Adjust your preference.
Chrome (Desktop)
  1. Click the padlock icon in the address bar.
  2. Select Site settings.
  3. Find Notifications and adjust your preference.