Source code for workgraph_collections.qe.bands

"""BandsWorkGraph."""

from aiida import orm
from aiida_workgraph import WorkGraph, task, build_task
from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain
from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain
from aiida_quantumespresso.calculations.functions.seekpath_structure_analysis import (
    seekpath_structure_analysis,
)

# Wrap the seekpath calcfunction as a workgraph task. Only the two outputs
# consumed by ``bands_workgraph`` are declared as output sockets.
SeekpathTask = build_task(
    seekpath_structure_analysis,
    outputs=[
        {"name": socket_name}
        for socket_name in ("primitive_structure", "explicit_kpoints")
    ],
)


@task()
def inspect_relax(parameters):
    """Inspect relax calculation.

    Read the band count reported in the relaxation output parameters.

    :param parameters: node exposing ``get_dict()`` whose result contains a
        ``number_of_bands`` key (in this module, the relax workchain's
        ``output_parameters``).
    :return: ``orm.Int`` holding that number of bands.
    """
    number_of_bands = parameters.get_dict()["number_of_bands"]
    return orm.Int(number_of_bands)
@task.calcfunction()
def update_scf_parameters(parameters, current_number_of_bands=None):
    """Update scf parameters.

    Return a copy of the scf ``parameters`` in which ``SYSTEM.nbnd`` defaults to
    the band count of a previous relaxation, when one is available.

    :param parameters: ``orm.Dict`` with the pw.x input parameters of the scf step.
    :param current_number_of_bands: optional node (e.g. ``orm.Int``) with the
        number of bands found by the relaxation. An ``nbnd`` already present in
        ``parameters`` is never overridden. When ``None`` (no relaxation was
        run), ``nbnd`` is left untouched.
    :return: a new ``orm.Dict`` with the updated parameters.
    """
    parameters = parameters.get_dict()
    system = parameters.setdefault("SYSTEM", {})
    # Without a relaxation there is no band count to propagate; setting
    # ``nbnd`` to ``None`` would store a null value in the pw.x SYSTEM namelist.
    if current_number_of_bands is not None:
        # Store the plain Python value rather than the AiiDA node itself.
        nbnd = getattr(current_number_of_bands, "value", current_number_of_bands)
        system.setdefault("nbnd", nbnd)
    return orm.Dict(parameters)
@task.calcfunction()
def update_bands_parameters(parameters, scf_parameters, nbands_factor=None):
    """Update bands parameters.

    Return a copy of the bands ``parameters`` with ``SYSTEM.nbnd`` determined
    from the scf output.

    :param parameters: ``orm.Dict`` with the pw.x input parameters of the bands step.
    :param scf_parameters: scf output parameters; must contain
        ``number_of_bands`` and, when ``nbands_factor`` is truthy,
        ``number_of_electrons``.
    :param nbands_factor: optional node with a ``value`` attribute. When truthy,
        ``nbnd`` is forced to the largest of ``int(0.5 * nelectron * factor)``,
        ``int(0.5 * nelectron) + 4`` and the scf band count. Otherwise the scf
        band count is used unless ``nbnd`` is already set.
    :return: a new ``orm.Dict`` with the updated parameters.
    """
    new_parameters = parameters.get_dict()
    system = new_parameters.setdefault("SYSTEM", {})
    scf_output = scf_parameters.get_dict()

    if not nbands_factor:
        # No factor requested: reuse the scf band count unless set explicitly.
        system.setdefault("nbnd", scf_output["number_of_bands"])
        return orm.Dict(new_parameters)

    factor = nbands_factor.value
    scf_bands = int(scf_output["number_of_bands"])
    electrons = int(scf_output["number_of_electrons"])
    system["nbnd"] = max(
        int(0.5 * electrons * factor),
        int(0.5 * electrons) + 4,
        scf_bands,
    )
    return orm.Dict(new_parameters)
@task.graph_builder()
def bands_workgraph(
    structure: orm.StructureData = None,
    code: orm.Code = None,
    pseudo_family: str = None,
    pseudos: dict = None,
    inputs: dict = None,
    run_relax: bool = False,
    bands_kpoints_distance: float = None,
    nbands_factor: float = None,
) -> WorkGraph:
    """BandsWorkGraph.

    Build a WorkGraph with these tasks, in order:

    - an optional ``PwRelaxWorkChain`` (when ``run_relax`` is true);
    - an optional seekpath analysis (when ``bands_kpoints_distance`` is given),
      whose primitive structure and explicit k-points replace the structure
      and the bands k-points;
    - an scf ``PwBaseWorkChain``;
    - a bands ``PwBaseWorkChain`` that uses the scf ``remote_folder`` as its
      ``parent_folder``.

    :param structure: input structure.
    :param code: code passed to every pw step.
    :param pseudo_family: label of a pseudopotential family group. When given,
        the pseudos loaded from it replace ``pseudos``.
    :param pseudos: pseudopotentials, used when no family is given.
    :param inputs: optional extra inputs per step under the keys ``"relax"``,
        ``"scf"`` and ``"bands"``. The caller's dictionaries are not modified.
    :param run_relax: whether to relax the structure first.
    :param bands_kpoints_distance: reference distance passed to seekpath.
    :param nbands_factor: optional factor forwarded to
        ``update_bands_parameters`` to set ``nbnd`` of the bands step.
    :return: the assembled ``WorkGraph``.
    """
    inputs = {} if inputs is None else inputs
    # These may be overridden by the optional relax / seekpath steps below.
    bands_kpoints = None
    current_number_of_bands = None

    # Load the pseudopotential family.
    if pseudo_family is not None:
        pseudo_family = orm.load_group(pseudo_family)
        pseudos = pseudo_family.get_pseudos(structure=structure)

    wg = WorkGraph("BandsStructure")

    # ------- relax -----------
    if run_relax:
        relax_task = wg.tasks.new(PwRelaxWorkChain, name="relax", structure=structure)
        # Shallow copy: ``update`` below must not leak into the caller's dict.
        relax_inputs = dict(inputs.get("relax", {}))
        relax_inputs.update(
            {
                "base.pw.code": code,
                "base.pw.pseudos": pseudos,
            }
        )
        relax_task.set(relax_inputs)
        # Downstream tasks use the relaxed structure.
        structure = relax_task.outputs["output_structure"]
        # -------- inspect_relax -----------
        inspect_relax_task = wg.tasks.new(
            inspect_relax,
            name="inspect_relax",
            parameters=relax_task.outputs["output_parameters"],
        )
        current_number_of_bands = inspect_relax_task.outputs["result"]

    # -------- seekpath -----------
    if bands_kpoints_distance is not None:
        seekpath_task = wg.tasks.new(
            SeekpathTask,
            name="seekpath",
            structure=structure,
            kwargs={"reference_distance": orm.Float(bands_kpoints_distance)},
        )
        structure = seekpath_task.outputs["primitive_structure"]
        bands_kpoints = seekpath_task.outputs["explicit_kpoints"]

    # -------- scf -----------
    scf_inputs = dict(inputs.get("scf", {"pw": {}}))
    scf_parameters = wg.tasks.new(
        update_scf_parameters,
        name="scf_parameters",
        # Tolerate a user-supplied "scf" namespace without a "pw" entry.
        parameters=scf_inputs.get("pw", {}).get("parameters", {}),
        current_number_of_bands=current_number_of_bands,
    )
    scf_task = wg.tasks.new(PwBaseWorkChain, name="scf")
    scf_inputs.update(
        {
            "pw.code": code,
            "pw.structure": structure,
            "pw.pseudos": pseudos,
            "pw.parameters": scf_parameters.outputs["result"],
        }
    )
    scf_task.set(scf_inputs)

    # -------- bands -----------
    bands_inputs = dict(inputs.get("bands", {"pw": {}}))
    bands_parameters = wg.tasks.new(
        update_bands_parameters,
        name="bands_parameters",
        parameters=bands_inputs.get("pw", {}).get("parameters", {}),
        nbands_factor=nbands_factor,
        scf_parameters=scf_task.outputs["output_parameters"],
    )
    bands_task = wg.tasks.new(PwBaseWorkChain, name="bands", kpoints=bands_kpoints)
    bands_inputs.update(
        {
            "pw.code": code,
            "pw.structure": structure,
            "pw.pseudos": pseudos,
            "pw.parent_folder": scf_task.outputs["remote_folder"],
            "pw.parameters": bands_parameters.outputs["result"],
        }
    )
    bands_task.set(bands_inputs)
    return wg