19  SOLID in Python

19.1 Single Responsibility Principle (SRP)

19.1.1 Bad

class User:
    """A user that also stores posts, builds timelines and edits its profile.

    Deliberate SRP violation: identity data, post storage, timeline assembly
    and profile editing all live in one class, so a change to any of those
    concerns forces a change here.
    """

    def __init__(self, user_id: str, username: str, email: str):
        self.user_id, self.username, self.email = user_id, username, email
        self.posts = []  # posts authored by this user, in creation order

    def create_post(self, content: str) -> dict:
        """Create a post, keep it on the user and return it.

        Ids are 1-based and derived from how many posts are already stored.
        """
        new_post = dict(id=len(self.posts) + 1, content=content, likes=0)
        self.posts.append(new_post)
        return new_post

    def get_timeline(self) -> list:
        """Return the user's timeline.

        Placeholder: fetching and sorting posts from followed users is not
        implemented, so this currently returns ``None``.
        """
        return None

    def update_profile(self, new_username: str = None, new_email: str = None):
        """Change username and/or email; falsy arguments leave a field as-is."""
        if new_username:
            self.username = new_username
        if new_email:
            self.email = new_email

19.1.2 Good

┌─────────────────┐
│      User       │
├─────────────────┤
│ + user_id       │
│ + username      │
│ + email         │
└─────────────────┘

┌─────────────────┐     ┌─────────────────┐     ┌──────────────────┐
│  PostManager    │     │ TimelineService │     │ ProfileManager   │
├─────────────────┤     ├─────────────────┤     ├──────────────────┤
│ + create_post() │     │ + get_timeline()│     │+ update_profile()│
│ + generate_..() │     └─────────────────┘     └──────────────────┘
└─────────────────┘
        │
        │ uses
        ▼
   ┌─────────┐
   │  User   │
   └─────────┘
class User:
    """Plain user record holding identity and contact details only.

    Posting, timelines and profile updates are handled by separate service
    classes, so this class has a single reason to change.
    """

    def __init__(self, user_id: str, username: str, email: str):
        self.user_id, self.username, self.email = user_id, username, email


class PostManager:
    """Creates posts on behalf of users (post-related responsibility only)."""

    def create_post(self, user: User, content: str) -> dict:
        """Build a new post authored by *user* and return it.

        Args:
            user: Author of the post; only ``user.user_id`` is read.
            content: Post body.

        Returns:
            A dict with ``id``, ``user_id``, ``content`` and ``likes`` keys.
        """
        post = {
            "id": self.generate_post_id(),
            "user_id": user.user_id,
            "content": content,
            "likes": 0,
        }
        # Logic to save the post (e.g. through a repository) would go here.
        return post

    def generate_post_id(self) -> str:
        """Return a new unique post id as a 32-character hex string (UUID4).

        Random UUIDs need no shared counter, so ids stay unique across
        PostManager instances.
        """
        import uuid  # local import keeps this example self-contained

        return uuid.uuid4().hex


class TimelineService:
    """Assembles timelines (timeline-related responsibility only)."""

    def get_timeline(self, user: User) -> list:
        """Return the timeline for *user*.

        Placeholder: the complex logic to fetch and sort posts from followed
        users is not implemented yet, so this currently returns ``None``.
        """
        return None


class ProfileManager:
    """Applies profile changes to a user (profile responsibility only)."""

    def update_profile(
        self, user: User, new_username: str = None, new_email: str = None
    ):
        """Overwrite *user*'s username and/or email with the values given.

        Falsy values (``None`` or ``""``) leave the corresponding field as-is.
        Username is applied before email.
        """
        for attr, value in (("username", new_username), ("email", new_email)):
            if value:
                setattr(user, attr, value)
        # Additional logic for profile updates, like triggering email
        # verification, would go here.

19.2 Open–Closed Principle (OCP)

19.2.1 Bad

class Rectangle:
    """Rectangle described only by its dimensions; it has no area logic."""

    def __init__(self, width, height):
        self.width, self.height = width, height


class Circle:
    """Circle described only by its radius; it has no area logic."""

    def __init__(self, radius):
        self.radius = radius  # area is computed externally by AreaCalculator


class AreaCalculator:
    """Computes areas by type-checking every supported shape.

    Deliberate OCP violation: supporting a new shape means editing
    ``calculate_area`` and adding another ``isinstance`` branch.
    """

    def calculate_area(self, shape):
        """Return the area of *shape*.

        Raises:
            ValueError: if *shape* is neither a ``Rectangle`` nor a ``Circle``.
        """
        import math  # math.pi rather than a rounded literal such as 3.14

        if isinstance(shape, Rectangle):
            return shape.width * shape.height
        elif isinstance(shape, Circle):
            return math.pi * shape.radius**2
        else:
            raise ValueError("Unsupported shape")


# Usage: every shape is routed through the isinstance-based calculator.
calculator = AreaCalculator()
rectangle, circle = Rectangle(5, 4), Circle(3)

print(f"Rectangle area: {calculator.calculate_area(rectangle)}")
print(f"Circle area: {calculator.calculate_area(circle)}")

19.2.2 Good

                    ┌─────────────────┐
                    │ <<abstract>>    │
                    │     Shape       │
                    ├─────────────────┤
                    │ + area()        │
                    └─────────────────┘
                            △
                            │
           ┌────────────────┼────────────────┐
           │                │                │
   ┌───────────────┐ ┌─────────────┐ ┌─────────────┐
   │  Rectangle    │ │   Circle    │ │  Triangle   │
   ├───────────────┤ ├─────────────┤ ├─────────────┤
   │ + width       │ │ + radius    │ │ + base      │
   │ + height      │ ├─────────────┤ │ + height    │
   ├───────────────┤ │ + area()    │ ├─────────────┤
   │ + area()      │ └─────────────┘ │ + area()    │
   └───────────────┘                 └─────────────┘

   ┌──────────────────────┐
   │  AreaCalculator      │
   ├──────────────────────┤
   │ + calculate_area()   │───uses───> Shape
   └──────────────────────┘
import math
from abc import ABC, abstractmethod


class Shape(ABC):
    """Abstract base for all shapes; each subclass computes its own area."""

    @abstractmethod
    def area(self):
        """Return the area of this shape."""


class Rectangle(Shape):
    """Rectangle that knows how to compute its own area."""

    def __init__(self, width, height):
        self.width, self.height = width, height

    def area(self):
        """Return width * height."""
        result = self.width * self.height
        return result


class Circle(Shape):
    """Circle that knows how to compute its own area."""

    def __init__(self, radius):
        self.radius = radius

    def area(self):
        """Return pi * radius ** 2."""
        result = math.pi * self.radius**2
        return result


class AreaCalculator:
    """Computes areas without knowing any concrete shape type."""

    def calculate_area(self, shape: Shape):
        """Return ``shape.area()``; new Shape subclasses need no change here."""
        area = shape.area()
        return area


# Usage: the calculator only ever calls shape.area().
calculator = AreaCalculator()
rectangle = Rectangle(5, 4)
circle = Circle(3)

print(f"Rectangle area: {calculator.calculate_area(rectangle)}")
print(f"Circle area: {calculator.calculate_area(circle)}")


# Adding a new shape without modifying AreaCalculator
class Triangle(Shape):
    """Triangle given by its base and perpendicular height."""

    def __init__(self, base, height):
        self.base, self.height = base, height

    def area(self):
        """Return half of base * height."""
        result = 0.5 * self.base * self.height
        return result


triangle = Triangle(6, 4)  # the unchanged calculator handles the new shape
print("Triangle area:", calculator.calculate_area(triangle))

19.3 Interface Segregation Principle (ISP)

19.3.1 Bad

from abc import ABC, abstractmethod


class MultimediaPlayer(ABC):
    """One "fat" interface covering playback, lyrics and video features.

    Deliberate ISP violation: every implementer must provide all four
    methods, even those that are irrelevant to it.
    """

    @abstractmethod
    def play_media(self, file: str) -> None:
        """Start playing *file*."""

    @abstractmethod
    def stop_media(self) -> None:
        """Stop playback."""

    @abstractmethod
    def display_lyrics(self, file: str) -> None:
        """Show lyrics for *file*."""

    @abstractmethod
    def apply_video_filter(self, filter: str) -> None:
        """Apply the named video *filter*."""


class MusicPlayer(MultimediaPlayer):
    """Audio player forced to implement the video part of the fat interface."""

    def play_media(self, file: str) -> None:
        """Announce playback of *file*."""
        message = f"Playing music: {file}"
        print(message)

    def stop_media(self) -> None:
        """Announce that playback stopped."""
        print("Stopping music")

    def display_lyrics(self, file: str) -> None:
        """Announce lyrics display for *file*."""
        message = f"Displaying lyrics for: {file}"
        print(message)

    def apply_video_filter(self, filter: str) -> None:
        """Refuse the request: video filters make no sense for audio.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("MusicPlayer does not support video filters")


class VideoPlayer(MultimediaPlayer):
    """Video player; implementation elided in this example.

    No abstract method is overridden here, so the class is still abstract
    and instantiating it raises ``TypeError``.
    """

19.3.2 Good

┌──────────────────────┐  ┌──────────────────────┐  ┌──────────────────────┐
│ <<abstract>>         │  │ <<abstract>>         │  │ <<abstract>>         │
│  MediaPlayable       │  │ LyricsDisplayable    │  │  VideoFilterable     │
├──────────────────────┤  ├──────────────────────┤  ├──────────────────────┤
│ + play_media()       │  │ + display_lyrics()   │  │ + apply_video_...()  │
│ + stop_media()       │  └──────────────────────┘  └──────────────────────┘
└──────────────────────┘

  Each player below inherits only the interfaces listed in its box:

┌──────────────────────┐  ┌──────────────────────┐  ┌──────────────────────┐
│  MusicPlayer         │  │  VideoPlayer         │  │  BasicAudioPlayer    │
├──────────────────────┤  ├──────────────────────┤  ├──────────────────────┤
│ implements:          │  │ implements:          │  │ implements only:     │
│  MediaPlayable       │  │  MediaPlayable       │  │  MediaPlayable       │
│  LyricsDisplayable   │  │  VideoFilterable     │  │                      │
└──────────────────────┘  └──────────────────────┘  └──────────────────────┘
from abc import ABC, abstractmethod


class MediaPlayable(ABC):
    """Minimal playback interface that every player implements."""

    @abstractmethod
    def play_media(self, file: str) -> None:
        """Start playing *file*."""

    @abstractmethod
    def stop_media(self) -> None:
        """Stop playback."""


class LyricsDisplayable(ABC):
    """Optional capability: players that can show song lyrics."""

    @abstractmethod
    def display_lyrics(self, file: str) -> None:
        """Show lyrics for *file*."""


class VideoFilterable(ABC):
    """Optional capability: players that can apply video filters."""

    @abstractmethod
    def apply_video_filter(self, filter: str) -> None:
        """Apply the named video *filter*."""


class MusicPlayer(MediaPlayable, LyricsDisplayable):
    """Music player: playback plus lyrics, with no video methods to stub out."""

    def play_media(self, file: str) -> None:
        """Announce playback of *file*."""
        message = f"Playing music: {file}"
        print(message)

    def stop_media(self) -> None:
        """Announce that playback stopped."""
        print("Stopping music")

    def display_lyrics(self, file: str) -> None:
        """Announce lyrics display for *file*."""
        message = f"Displaying lyrics for: {file}"
        print(message)


class VideoPlayer(MediaPlayable, VideoFilterable):
    """Video player: playback plus video filters, with no lyrics method."""

    def play_media(self, file: str) -> None:
        """Announce playback of *file*."""
        message = f"Playing video: {file}"
        print(message)

    def stop_media(self) -> None:
        """Announce that playback stopped."""
        print("Stopping video")

    def apply_video_filter(self, filter: str) -> None:
        """Announce that *filter* is being applied."""
        message = f"Applying video filter: {filter}"
        print(message)


class BasicAudioPlayer(MediaPlayable):
    """Bare-bones player that needs nothing beyond play and stop."""

    def play_media(self, file: str) -> None:
        """Announce playback of *file*."""
        message = f"Playing audio: {file}"
        print(message)

    def stop_media(self) -> None:
        """Announce that playback stopped."""
        print("Stopping audio")

19.4 Liskov Substitution Principle (LSP)

19.4.1 Bad

class Vehicle:
    """Combustion vehicle whose tank starts full and burns 10 km per liter."""

    def __init__(self, fuel_capacity: float):
        self._fuel_capacity = fuel_capacity
        self._fuel_level = fuel_capacity  # start with a full tank

    def fuel_level(self) -> float:
        """Return the remaining fuel."""
        return self._fuel_level

    def consume_fuel(self, distance: float) -> None:
        """Burn the fuel needed for *distance* km.

        Raises:
            ValueError: if the tank holds less than is required; the level is
                left unchanged in that case.
        """
        required = distance / 10  # Assume 10 km per liter for simplicity
        if self._fuel_level - required < 0:
            raise ValueError("Not enough fuel to cover the distance")
        self._fuel_level -= required


class ElectricCar(Vehicle):
    """Vehicle subclass that reuses the inherited fuel fields for battery charge.

    LSP violation: ``fuel_level`` now reports kWh and ``consume_fuel`` uses a
    different rate, so code written against Vehicle misreports the result.
    """

    def __init__(self, battery_capacity: float):
        # The battery capacity is stored in the inherited "fuel" fields.
        super().__init__(battery_capacity)

    def consume_fuel(self, distance: float) -> None:
        """Draw the charge needed for *distance* km.

        Raises:
            ValueError: if the remaining charge is insufficient.
        """
        needed = distance / 5  # Assume 5 km per kWh for simplicity
        if self._fuel_level - needed < 0:
            raise ValueError("Not enough charge to cover the distance")
        self._fuel_level -= needed


def drive_vehicle(vehicle: Vehicle, distance: float) -> None:
    """Drive *vehicle* and print the consumption, labelled in liters.

    The hard-coded "liters" label is where the LSP violation shows: it holds
    for Vehicle but not for a subclass whose level is measured in kWh.
    """
    before = vehicle.fuel_level()
    vehicle.consume_fuel(distance)
    used = before - vehicle.fuel_level()
    print(f"Fuel consumed: {used:.2f} liters")


# Usage
car = Vehicle(50)  # 50 liter tank
electric_car = ElectricCar(50)  # 50 kWh battery

drive_vehicle(car, 100)  # Works fine: prints "Fuel consumed: 10.00 liters"
# Prints "Fuel consumed: 20.00 liters", but those 20 units are kWh: the
# subclass silently breaks the caller's "liters" assumption.
drive_vehicle(electric_car, 100)

19.4.2 Good

                    ┌─────────────────────┐
                    │ <<abstract>>        │
                    │   PowerSource       │
                    ├─────────────────────┤
                    │ # _capacity         │
                    │ # _level            │
                    ├─────────────────────┤
                    │ + level()           │
                    │ + consume()         │
                    └─────────────────────┘
                            △
                            │
                   ┌────────┴────────┐
                   │                 │
          ┌────────────────┐  ┌─────────────┐
          │   FuelTank     │  │   Battery   │
          ├────────────────┤  ├─────────────┤
          │ + consume()    │  │ + consume() │
          └────────────────┘  └─────────────┘
                   │                 │
                   │                 │
                   └────────┬────────┘
                            │
                     used by (composition)
                            │
                   ┌────────▼────────────┐
                   │     Vehicle         │
                   ├─────────────────────┤
                   │ - _power_source     │
                   ├─────────────────────┤
                   │ + power_level()     │
                   │ + drive()           │
                   └─────────────────────┘
from abc import ABC, abstractmethod


class PowerSource(ABC):
    """Abstract energy store (fuel tank, battery, ...) that starts full."""

    def __init__(self, capacity: float):
        # Capacity and current level begin equal: the source starts full.
        self._capacity = self._level = capacity

    def level(self) -> float:
        """Return the energy currently stored."""
        return self._level

    @abstractmethod
    def consume(self, distance: float) -> float:
        """Draw the energy needed for *distance* km and return the amount."""


class FuelTank(PowerSource):
    """Liquid-fuel tank that burns fuel at 10 km per liter."""

    def consume(self, distance: float) -> float:
        """Burn the fuel needed for *distance* km and return the amount used.

        Raises:
            ValueError: if *distance* is negative (which would otherwise add
                fuel and overfill the tank), or if the tank holds less fuel
                than required. The level is unchanged when this is raised.
        """
        if distance < 0:
            raise ValueError("Distance cannot be negative")
        fuel_consumed = distance / 10  # Assume 10 km per liter for simplicity
        if self._level - fuel_consumed < 0:
            raise ValueError("Not enough fuel to cover the distance")
        self._level -= fuel_consumed
        return fuel_consumed


class Battery(PowerSource):
    """Battery that draws charge at 5 km per kWh."""

    def consume(self, distance: float) -> float:
        """Draw the charge needed for *distance* km and return the amount used.

        Raises:
            ValueError: if *distance* is negative (which would otherwise add
                charge and overfill the battery), or if the remaining charge
                is insufficient. The level is unchanged when this is raised.
        """
        if distance < 0:
            raise ValueError("Distance cannot be negative")
        energy_consumed = distance / 5  # Assume 5 km per kWh for simplicity
        if self._level - energy_consumed < 0:
            raise ValueError("Not enough charge to cover the distance")
        self._level -= energy_consumed
        return energy_consumed


class Vehicle:
    """Vehicle that delegates all energy handling to an injected PowerSource.

    It uses composition instead of inheritance. Switching between FuelTank and
    Battery needs no Vehicle subclass, so every Vehicle keeps the same contract.
    """

    def __init__(self, power_source: PowerSource):
        self._power_source = power_source

    def power_level(self) -> float:
        """Return the power source's remaining level."""
        source = self._power_source
        return source.level()

    def drive(self, distance: float) -> float:
        """Drive *distance* km and return the energy consumed.

        Any exception raised by the power source's ``consume`` propagates.
        """
        source = self._power_source
        return source.consume(distance)


def drive_vehicle(vehicle: Vehicle, distance: float) -> None:
    """Drive *vehicle* for *distance* km and print the energy used.

    A ``ValueError`` raised by ``vehicle.drive`` (e.g. not enough fuel or
    charge) is reported instead of propagated. Only the drive call sits in
    the ``try`` so that an unrelated error while formatting the report is
    not misreported as a failed journey.
    """
    try:
        energy_consumed = vehicle.drive(distance)
    except ValueError as e:
        print(f"Unable to complete journey: {e}")
    else:
        print(f"Energy consumed: {energy_consumed:.2f} units")


# Usage: one Vehicle class, configured with different power sources.
fuel_car = Vehicle(FuelTank(50))  # 50 liter tank
electric_car = Vehicle(Battery(50))  # 50 kWh battery

drive_vehicle(fuel_car, 100)  # Prints: Energy consumed: 10.00 units
drive_vehicle(electric_car, 100)  # Prints: Energy consumed: 20.00 units

19.5 Dependency Inversion Principle (DIP)

19.5.1 Bad

class UserEntity:
    """User that constructs its own MySQL database (deliberate DIP violation).

    The high-level entity is hard-wired to a concrete low-level class, so it
    cannot use another database or a test double without being edited.
    """

    def __init__(self, user_id: str):
        self.user_id = user_id
        # Direct dependency on a low-level module
        self.database = MySQLDatabase()

    def save(self):
        """Insert this user's id into the ``users`` table."""
        self.database.insert("users", {"id": self.user_id})


class MySQLDatabase:
    """Stand-in for a MySQL client that only prints what it would insert."""

    def insert(self, table: str, data: dict):
        """Print the row that would be inserted into *table*."""
        message = f"Inserting {data} into {table} table in MySQL"
        print(message)

19.5.2 Good

                    ┌──────────────────────┐
                    │ <<abstract>>         │
                    │ DatabaseInterface    │
                    ├──────────────────────┤
                    │ + insert()           │
                    └──────────────────────┘
                            △
                            │ implements
           ┌────────────────┼─────────────────┐
           │                │                 │
   ┌───────────────┐ ┌──────────────┐ ┌──────────────┐
   │ MySQLDatabase │ │PostgreSQL... │ │MockDatabase  │
   ├───────────────┤ ├──────────────┤ ├──────────────┤
   │ + insert()    │ │ + insert()   │ │ + insert()   │
   └───────────────┘ └──────────────┘ └──────────────┘

   ┌─────────────────────┐
   │    UserEntity       │
   ├─────────────────────┤
   │ - user_id           │
   │ - database          │───depends on───> DatabaseInterface
   ├─────────────────────┤
   │ + save()            │
   └─────────────────────┘
from abc import ABC, abstractmethod


class DatabaseInterface(ABC):
    """Storage abstraction that high-level entities depend on."""

    @abstractmethod
    def insert(self, table: str, data: dict):
        """Insert *data* as a row of *table*."""


class UserEntity:
    """User whose storage backend is injected (dependency inversion).

    The entity depends only on the DatabaseInterface abstraction, so any
    implementation can be supplied by the caller.
    """

    def __init__(self, user_id: str, database: DatabaseInterface):
        self.user_id = user_id
        self.database = database  # injected by the caller, never built here

    def save(self):
        """Insert this user's id into the ``users`` table of the injected DB."""
        self.database.insert("users", {"id": self.user_id})


class MySQLDatabase(DatabaseInterface):
    """MySQL implementation of DatabaseInterface (prints instead of writing)."""

    def insert(self, table: str, data: dict):
        """Print the row that would be inserted into *table*."""
        message = f"Inserting {data} into {table} table in MySQL"
        print(message)


class PostgreSQLDatabase(DatabaseInterface):
    """PostgreSQL implementation of DatabaseInterface (prints instead of writing)."""

    def insert(self, table: str, data: dict):
        """Print the row that would be inserted into *table*."""
        message = f"Inserting {data} into {table} table in PostgreSQL"
        print(message)


# Usage: the same UserEntity works with any DatabaseInterface implementation.
mysql_db = MySQLDatabase()
postgres_db = PostgreSQLDatabase()
user = UserEntity("123", mysql_db)
another_user = UserEntity("456", postgres_db)
user.save()
another_user.save()


class MockDatabase(DatabaseInterface):
    """In-memory test double that records every insert call."""

    def __init__(self):
        self.inserted_data = []  # (table, data) tuples in call order

    def insert(self, table: str, data: dict):
        """Record the call instead of writing anywhere."""
        call = (table, data)
        self.inserted_data.append(call)


# In a test: the injected mock records the insert instead of touching a DB.
mock_db = MockDatabase()
user = UserEntity("test_user", mock_db)
user.save()
assert mock_db.inserted_data == [("users", {"id": "test_user"})]