#include <iostream>
#include <type_traits>
#include <utility>
// MyClass class definition
/// Simple value holder used as the payload type (and base type) for SmartPtr.
class MyClass {
private:
    int data;  // value reported by printData()
public:
    /// Constructs an instance holding `value` (0 by default).
    MyClass(int value = 0) : data(value) {}

    /// Virtual so that deleting a derived object through a MyClass* (as
    /// SmartPtr<MyClass> does when handed a derived object) is well-defined.
    virtual ~MyClass() { std::cout << "MyClass destructor called." << std::endl; }

    /// Writes "Data: <value>" followed by a newline to std::cout.
    /// Does not modify the object, so it is callable on const instances.
    void printData() const { std::cout << "Data: " << data << std::endl; }
};
// SmartPtr class template definition
/// Reference-counted owning pointer restricted to MyClass and types derived
/// from it.
///
/// All SmartPtr instances that share an object also share one heap-allocated
/// counter; the object and the counter are deleted when the last owner is
/// destroyed or reassigned. An empty SmartPtr (holding nullptr) owns neither
/// an object nor a counter.
///
/// The counter is a plain int with no synchronization, so owners of the same
/// object must not be copied/destroyed concurrently from different threads.
template <typename T>
class SmartPtr {
    static_assert(std::is_base_of<MyClass, T>::value, "SmartPtr can only point to MyClass and its derived classes.");
private:
    T* ptr;         // managed object, or nullptr when empty
    int* refCount;  // shared owner count; nullptr exactly when ptr is nullptr

    // Drops this owner's reference (deleting the object and counter if it was
    // the last one) and leaves *this empty.
    void release() noexcept {
        if (refCount && --(*refCount) == 0) {
            delete ptr;
            delete refCount;
        }
        ptr = nullptr;
        refCount = nullptr;
    }

public:
    /// Takes ownership of `p`, which must come from `new` or be nullptr.
    /// No counter is allocated for nullptr, so empty pointers never leak.
    /// If allocating the counter throws, `p` is deleted before the exception
    /// propagates, so ownership is never silently lost.
    SmartPtr(T* p = nullptr) : ptr(p), refCount(nullptr) {
        if (ptr) {
            try {
                refCount = new int(1);
            } catch (...) {
                delete ptr;
                throw;
            }
        }
    }

    /// Shares ownership with `other` (no-op on the count if `other` is empty).
    SmartPtr(const SmartPtr& other) noexcept : ptr(other.ptr), refCount(other.refCount) {
        if (refCount) {
            ++(*refCount);
        }
    }

    /// Steals ownership from `other`, leaving it empty; the count is unchanged.
    SmartPtr(SmartPtr&& other) noexcept
        : ptr(std::exchange(other.ptr, nullptr)),
          refCount(std::exchange(other.refCount, nullptr)) {}

    /// Copy-and-swap: the new object is acquired before the old one is
    /// released, which makes self-assignment and assignment between owners of
    /// the same object safe.
    SmartPtr& operator=(const SmartPtr& other) noexcept {
        SmartPtr copy(other);
        swap(copy);
        return *this;
    }

    /// Takes over `other`'s ownership; the previous object of *this is
    /// released when the temporary goes out of scope.
    SmartPtr& operator=(SmartPtr&& other) noexcept {
        SmartPtr moved(std::move(other));
        swap(moved);
        return *this;
    }

    ~SmartPtr() { release(); }

    /// Exchanges the managed object and counter with `other`.
    void swap(SmartPtr& other) noexcept {
        std::swap(ptr, other.ptr);
        std::swap(refCount, other.refCount);
    }

    /// Dereferences the managed object. Precondition: *this is not empty.
    T& operator*() const {
        return *ptr;
    }

    /// Member access to the managed object. Precondition: *this is not empty.
    T* operator->() const noexcept {
        return ptr;
    }

    /// Returns the managed pointer (nullptr when empty) without affecting ownership.
    T* get() const noexcept {
        return ptr;
    }

    /// Number of SmartPtr instances sharing the object; 0 when empty.
    int useCount() const noexcept {
        return refCount ? *refCount : 0;
    }

    /// True when *this manages an object.
    explicit operator bool() const noexcept {
        return ptr != nullptr;
    }

    /// Two SmartPtrs are equal when they point to the same object (or are both empty).
    bool operator==(const SmartPtr& other) const noexcept {
        return ptr == other.ptr;
    }

    bool operator!=(const SmartPtr& other) const noexcept {
        return !(*this == other);
    }
};
// Test code
// Exercises SmartPtr: construction from a raw pointer, copy construction, and
// copy assignment into an initially empty pointer. All three handles share one
// MyClass object, which is destroyed once, when the last of them goes out of
// scope at the end of main.
int main() {
    SmartPtr<MyClass> sp1(new MyClass(10));
    SmartPtr<MyClass> sp2(sp1);  // shares ownership with sp1
    SmartPtr<MyClass> sp3;       // starts empty
    sp3 = sp2;                   // now shares ownership with sp1 and sp2
    sp1->printData();
    (*sp2).printData();
    // SmartPtr itself has no printData member; reach the pointee through ->.
    sp3->printData();
    return 0;
}