# Source code for dask_kubernetes.common.objects

"""
Convenience functions for creating pod templates.
"""

import copy
import json
from collections import namedtuple

from kubernetes import client
from kubernetes.client.configuration import Configuration

from dask_kubernetes.constants import (
    KUBECLUSTER_CONTAINER_NAME,
    MAX_CLUSTER_NAME_LEN,
    VALID_CLUSTER_NAME,
)
from dask_kubernetes.exceptions import ValidationError

# Minimal stand-in for a Response-like object: the only field exposed is ``data``.
_FakeResponse = namedtuple("_FakeResponse", "data")


class DummyApiClient(client.ApiClient):
    """Kubernetes API client used only for (de)serialization.

    ``ApiClient.__init__`` is deliberately not called, so no thread pool is
    started when this object is created; only a copy of the default
    ``Configuration`` is attached. ``deserialize`` is adapted so it accepts a
    plain Python dictionary instead of a Response-like object.
    """

    def __init__(self):
        # Intentionally no super().__init__() -- see the class docstring.
        self.configuration = Configuration.get_default_copy()

    def deserialize(self, dict_, klass):
        """Deserialize the dictionary ``dict_`` into an instance of ``klass``.

        The dictionary is JSON-encoded and wrapped in ``_FakeResponse`` so the
        parent ``deserialize`` receives a response-like object with ``.data``.
        """
        wrapped = _FakeResponse(data=json.dumps(dict_))
        return super().deserialize(wrapped, klass)


# Shared, module-level client used by the helpers in this module to convert
# between Kubernetes model objects and plain Python data.
SERIALIZATION_API_CLIENT = DummyApiClient()


def _set_k8s_attribute(obj, attribute, value):
    """
    Set a specific value on a kubernetes object's attribute

    obj
        an object from Kubernetes Python API client
    attribute
        Should be a Kubernetes API style attribute (with camelCase)
    value
        Can be anything (string, list, dict, k8s objects) that can be
        accepted by the k8s python client

    If the attribute already holds a value, that value is first converted to
    plain JSON-style data. Then:

    - an existing dict is deep-merged with ``value``. ``value`` is converted
      the same way first, so k8s objects are accepted here too.
    - an existing list is extended with ``value``.
    - anything else is replaced by ``value``.

    Raises
    ------
    ValueError
        If ``attribute`` is not one of the JSON names in ``obj.attribute_map``.
    """
    # All k8s python client objects have an 'attribute_map' property
    # which has as keys python style attribute names (api_client)
    # and as values the kubernetes JSON API style attribute names
    # (apiClient). We want to allow users to use the JSON API style attribute
    # names only.
    attribute_name = next(
        (
            python_attribute
            for python_attribute, json_attribute in obj.attribute_map.items()
            if json_attribute == attribute
        ),
        None,
    )
    if attribute_name is None:
        raise ValueError(
            "Attribute must be one of {}".format(list(obj.attribute_map.values()))
        )

    current_value = getattr(obj, attribute_name, None)
    if current_value is not None:
        # This will ensure that current_value is something JSONable,
        # so a dict, list, or scalar
        current_value = SERIALIZATION_API_CLIENT.sanitize_for_serialization(
            current_value
        )

    if isinstance(current_value, dict):
        # Deep merge our dictionaries. ``value`` may itself be a k8s model
        # object (e.g. V1ResourceRequirements). merge_dictionaries can only
        # walk plain dicts, so normalise ``value`` the same way as
        # current_value.
        value = SERIALIZATION_API_CLIENT.sanitize_for_serialization(value)
        setattr(obj, attribute_name, merge_dictionaries(current_value, value))
    elif isinstance(current_value, list):
        # Just append lists
        setattr(obj, attribute_name, current_value + value)
    else:
        # Replace everything else
        setattr(obj, attribute_name, value)


def merge_dictionaries(a, b, path=None, update=True):
    """
    Merge two dictionaries recursively.

    ``b`` is merged into ``a`` in place, and ``a`` is returned.

    - Keys present only in ``b`` are copied into ``a``.
    - Nested dicts are merged recursively, using the same ``update`` policy.
    - Lists are merged element by element. Elements of ``b`` beyond the end
      of the list in ``a`` are appended.
    - Any other differing value is replaced by ``b``'s value when ``update``
      is true. Otherwise a ``ValueError`` naming the dotted key path is
      raised.

    ``path`` is the list of keys leading to ``a``; it is used only to build
    conflict messages.

    Adapted from https://stackoverflow.com/a/25270947
    """
    if path is None:
        path = []
    for key in b:
        if key in a:
            a[key] = _merge_values(a[key], b[key], path + [str(key)], update)
        else:
            a[key] = b[key]
    return a


def _merge_values(current, new, path, update):
    """Merge ``new`` into ``current`` (found at key ``path``) and return the result."""
    if isinstance(current, dict) and isinstance(new, dict):
        # Propagate ``update`` so nested conflicts obey the caller's policy.
        return merge_dictionaries(current, new, path, update=update)
    if current == new:
        return current  # same leaf value
    if isinstance(current, list) and isinstance(new, list):
        for idx, item in enumerate(new):
            if idx < len(current):
                current[idx] = _merge_values(
                    current[idx], item, path + [str(idx)], update
                )
            else:
                # ``new`` is longer: keep its extra elements instead of failing.
                current.append(item)
        return current
    if update:
        return new
    raise ValueError("Conflict at %s" % ".".join(path))


def make_pod_spec(
    image,
    labels=None,
    threads_per_worker=1,
    env=None,
    extra_container_config=None,
    extra_pod_config=None,
    resources=None,
    memory_limit=None,
    memory_request=None,
    cpu_limit=None,
    cpu_request=None,
    gpu_limit=None,
    annotations=None,
):
    """
    Create generic pod template from input parameters

    Parameters
    ----------
    image : str
        Docker image name
    labels : dict
        Dict of labels to pass to ``V1ObjectMeta``
    threads_per_worker : int
        Number of threads per each worker
    env : dict
        Dict of environment variables to pass to ``V1Container``
    extra_container_config : dict
        Extra config attributes to set on the container object
    extra_pod_config : dict
        Extra config attributes to set on the pod object
    resources : str
        Resources for task constraints like "GPU=2 MEM=10e9". Resources are
        applied separately to each worker process (only relevant when
        starting multiple worker processes). Passed to the `--resources`
        option in ``dask-worker``.
    memory_limit : int, float, or str
        Bytes of memory per process that the worker can use (applied to both
        ``dask-worker --memory-limit`` and
        ``spec.containers[].resources.limits.memory``). This can be:

        - an integer (bytes), note 0 is a special case for no memory management.
        - a float (bytes). Note: fraction of total system memory is not supported by k8s.
        - a string (like 5GiB or 5000M). Note: 'GB' is not supported by k8s.
        - 'auto' for automatically computing the memory limit.

        For ``0`` and ``'auto'`` only the ``dask-worker`` flag is set; no
        container memory limit is added. If ``None`` (the default),
        ``--memory-limit`` is not passed and no container memory limit is set.
    memory_request : int, float, or str
        Like ``memory_limit`` (applied only to
        ``spec.containers[].resources.requests.memory`` and ignored by
        ``dask-worker``).
    cpu_limit : float or str
        CPU resource limits (applied to ``spec.containers[].resources.limits.cpu``).
    cpu_request : float or str
        CPU resource requests (applied to ``spec.containers[].resources.requests.cpu``).
    gpu_limit : int
        GPU resource limits (applied to
        ``spec.containers[].resources.limits."nvidia.com/gpu"``).
    annotations : dict
        Dict of annotations passed to ``V1ObjectMeta``

    Returns
    -------
    pod : V1Pod

    Examples
    --------
    >>> make_pod_spec(image='ghcr.io/dask/dask:latest', memory_limit='4G', memory_request='4G')
    """
    # ``None`` defaults instead of ``{}``. The labels/annotations dicts end up
    # referenced by the returned pod, so a shared mutable default could be
    # modified through one pod and leak into every later call.
    labels = {} if labels is None else labels
    annotations = {} if annotations is None else annotations
    env = {} if env is None else env
    extra_container_config = (
        {} if extra_container_config is None else extra_container_config
    )
    extra_pod_config = {} if extra_pod_config is None else extra_pod_config

    args = [
        "dask-worker",
        "$(DASK_SCHEDULER_ADDRESS)",
        "--nthreads",
        str(threads_per_worker),
        "--death-timeout",
        "60",
    ]
    # 0 is meaningful to dask-worker (documented above as "no memory
    # management"), so only a missing value suppresses the flag.
    if memory_limit is not None and memory_limit != "":
        args.extend(["--memory-limit", str(memory_limit)])
    if resources:
        args.extend(["--resources", str(resources)])

    pod = client.V1Pod(
        metadata=client.V1ObjectMeta(labels=labels, annotations=annotations),
        spec=client.V1PodSpec(
            restart_policy="Never",
            containers=[
                client.V1Container(
                    name=KUBECLUSTER_CONTAINER_NAME,
                    image=image,
                    args=args,
                    env=[client.V1EnvVar(name=k, value=v) for k, v in env.items()],
                )
            ],
        ),
    )

    container_resources = client.V1ResourceRequirements(limits={}, requests={})
    if cpu_request:
        container_resources.requests["cpu"] = cpu_request
    if memory_request:
        container_resources.requests["memory"] = memory_request
    if cpu_limit:
        container_resources.limits["cpu"] = cpu_limit
    if gpu_limit:
        container_resources.limits["nvidia.com/gpu"] = gpu_limit
    # 'auto' is a dask-worker setting, not a Kubernetes quantity, so it must
    # not be copied into the container limits (0 is skipped as before).
    if memory_limit and memory_limit != "auto":
        container_resources.limits["memory"] = memory_limit
    pod.spec.containers[0].resources = container_resources

    for key, value in extra_container_config.items():
        _set_k8s_attribute(pod.spec.containers[0], key, value)

    for key, value in extra_pod_config.items():
        _set_k8s_attribute(pod.spec, key, value)

    return pod
def make_pod_from_dict(dict_):
    """Build a ``V1Pod`` from a plain dictionary.

    Containers without a ``name`` get a default ``dask-<index>`` name. The
    input is deep-copied first, so the caller's dictionary is not modified.
    """
    dict_ = copy.deepcopy(dict_)
    containers = dict_.get("spec", {}).get("containers", [])
    for i, container in enumerate(containers):
        container.setdefault("name", f"dask-{i}")
    return SERIALIZATION_API_CLIENT.deserialize(dict_, client.V1Pod)


def make_service_from_dict(dict_):
    """Build a ``V1Service`` from a plain dictionary."""
    return SERIALIZATION_API_CLIENT.deserialize(dict_, client.V1Service)


def make_pdb_from_dict(dict_):
    """Build a ``V1PodDisruptionBudget`` from a plain dictionary."""
    return SERIALIZATION_API_CLIENT.deserialize(dict_, client.V1PodDisruptionBudget)


def _ensure_metadata_labels(template):
    """Make sure ``template.metadata`` and ``template.metadata.labels`` exist (in place)."""
    # Lets later code modify labels without a lot of `is None` checks.
    if template.metadata is None:
        template.metadata = client.V1ObjectMeta()
    if template.metadata.labels is None:
        template.metadata.labels = {}


def _dask_tolerations(pod_type):
    """Return the default NoSchedule tolerations for nodes dedicated to ``pod_type``."""
    return [
        client.V1Toleration(
            key=key, operator="Equal", value=pod_type, effect="NoSchedule"
        )
        # GKE currently does not permit creating taints on a node pool with a
        # `/` in the key field, hence the underscore variant as well.
        for key in ("k8s.dask.org/dedicated", "k8s.dask.org_dedicated")
    ]


def _with_default_node_affinity(affinity, mode, pod_type):
    """Add the ``k8s.dask.org/node-purpose`` term to ``affinity`` (created if None) and return it."""
    if affinity is None:
        affinity = client.V1Affinity()
    if affinity.node_affinity is None:
        affinity.node_affinity = client.V1NodeAffinity()
    node_affinity = affinity.node_affinity

    # a common object for both a preferred and a required node affinity
    node_selector_term = client.V1NodeSelectorTerm(
        match_expressions=[
            client.V1NodeSelectorRequirement(
                key="k8s.dask.org/node-purpose", operator="In", values=[pod_type]
            )
        ]
    )

    if mode == "required":
        required = node_affinity.required_during_scheduling_ignored_during_execution
        if required is None:
            required = client.V1NodeSelector(node_selector_terms=[])
            node_affinity.required_during_scheduling_ignored_during_execution = (
                required
            )
        required.node_selector_terms.append(node_selector_term)
    elif mode == "preferred":
        preferred = node_affinity.preferred_during_scheduling_ignored_during_execution
        if preferred is None:
            preferred = []
            node_affinity.preferred_during_scheduling_ignored_during_execution = (
                preferred
            )
        preferred.append(
            client.V1PreferredSchedulingTerm(preference=node_selector_term, weight=100)
        )
    else:
        raise ValueError(
            'Attribute apply_default_affinity must be one of "none", "preferred", or "required".'
        )
    return affinity


def clean_pod_template(
    pod_template, apply_default_affinity="preferred", pod_type="worker"
):
    """Normalize pod template

    Returns a deep copy of ``pod_template`` in which:

    - ``metadata``, ``metadata.labels`` and the first container's ``env``
      exist, so they can be modified later without ``is None`` checks.
    - NoSchedule tolerations for ``k8s.dask.org/dedicated`` and
      ``k8s.dask.org_dedicated`` with value ``pod_type`` are appended.
    - unless ``apply_default_affinity`` is ``"none"``, a node affinity on
      ``k8s.dask.org/node-purpose In [pod_type]`` is added. It is either a
      ``"required"`` selector term or a ``"preferred"`` term with weight 100.

    Raises
    ------
    ValueError
        If ``apply_default_affinity`` is not "none", "preferred" or "required".
    """
    pod_template = copy.deepcopy(pod_template)

    _ensure_metadata_labels(pod_template)
    container = pod_template.spec.containers[0]
    if container.env is None:
        container.env = []

    tolerations = _dask_tolerations(pod_type)
    if pod_template.spec.tolerations is None:
        pod_template.spec.tolerations = tolerations
    else:
        pod_template.spec.tolerations.extend(tolerations)

    if apply_default_affinity != "none":
        pod_template.spec.affinity = _with_default_node_affinity(
            pod_template.spec.affinity, apply_default_affinity, pod_type
        )

    return pod_template


def clean_service_template(service_template):
    """Normalize service template.

    Returns a deep copy of ``service_template`` in which ``metadata`` and
    ``metadata.labels`` are guaranteed to exist.
    """
    service_template = copy.deepcopy(service_template)
    _ensure_metadata_labels(service_template)
    return service_template


def clean_pdb_template(pdb_template):
    """Normalize pod disruption budget template.

    Returns a deep copy of ``pdb_template`` in which ``metadata``,
    ``metadata.labels`` and ``spec.selector`` are guaranteed to exist.
    """
    pdb_template = copy.deepcopy(pdb_template)
    _ensure_metadata_labels(pdb_template)
    if pdb_template.spec.selector is None:
        pdb_template.spec.selector = client.V1LabelSelector()
    return pdb_template
def validate_cluster_name(cluster_name: str) -> None:
    """Check that ``cluster_name`` is an acceptable DaskCluster name.

    Returns ``None`` when ``VALID_CLUSTER_NAME`` matches the name. Otherwise
    it raises ``ValidationError``, whose message describes the lowercase
    RFC 1123 rules and the ``MAX_CLUSTER_NAME_LEN`` limit. The length limit
    is presumably encoded in the ``VALID_CLUSTER_NAME`` pattern; confirm in
    ``dask_kubernetes.constants``.
    """
    if VALID_CLUSTER_NAME.match(cluster_name):
        return

    error_message = (
        f"The DaskCluster {cluster_name} is invalid: a lowercase RFC 1123 subdomain must "
        "consist of lower case alphanumeric characters, '-' or '.', and must start "
        "and end with an alphanumeric character. DaskCluster name must also be under "
        f"{MAX_CLUSTER_NAME_LEN} characters."
    )
    raise ValidationError(message=error_message)