# Problem instance: one inner list per job, holding that job's tasks in the
# order they must run. Every task is a (machine_id, processing_time) tuple.
jobs_data = [
    # Job 0
    [(0, 3), (1, 2), (2, 2)],
    # Job 1
    [(0, 2), (2, 1), (1, 4)],
    # Job 2
    [(1, 4), (2, 3)],
]

from collections import defaultdict
from dataclasses import dataclass

from ortools.sat.python import cp_model

# Number of machines, derived from the data so that every machine id used in
# jobs_data receives a no-overlap constraint (a hard-coded count would silently
# leave higher-numbered machines unconstrained). 0 when there are no tasks.
MACHINES = 1 + max(
    (machine_no for job in jobs_data for machine_no, _ in job), default=-1
)
# Scheduling horizon: running every task back-to-back is always a valid
# schedule, so no start/end time ever needs to exceed the total work.
TASK_TIMES_SUM = sum(d for job in jobs_data for (_, d) in job)
model = cp_model.CpModel()


@dataclass(frozen=True)
class Task:
    """One task of a job: its CP-SAT variables plus identifying metadata.

    Kept separate from the interval variable itself so no ad-hoc attributes
    have to be attached to OR-Tools objects.
    """

    interval: "cp_model.IntervalVar"  # used by the per-machine NoOverlap
    start: "cp_model.IntVar"
    end: "cp_model.IntVar"
    duration: int
    job_id: int
    task_no: int


# maps job_id -> list of Task records, in processing order
job_tasks = defaultdict(list)

# maps machine_no -> list of Task records that run on that machine
machine_tasks = defaultdict(list)

# Define tasks (intervals)
for job_id, job in enumerate(jobs_data):
    for task_no, (machine_no, duration) in enumerate(job):
        start = model.NewIntVar(
            0, TASK_TIMES_SUM, f"start_job_{job_id}_task_{task_no}")
        end = model.NewIntVar(
            0, TASK_TIMES_SUM, f"end_job_{job_id}_task_{task_no}")
        # The interval links start + duration == end.
        interval = model.NewIntervalVar(
            start, duration, end, f"job_{job_id}_task_{task_no}")

        task = Task(interval, start, end, duration, job_id, task_no)
        job_tasks[job_id].append(task)
        machine_tasks[machine_no].append(task)


# Add ordering constraint for each job: a task may start only after the
# previous task of the same job has finished.
for tasks in job_tasks.values():
    for before, after in zip(tasks, tasks[1:]):
        model.Add(before.end <= after.start)


# Each machine can only run 1 task at a time
for m in range(MACHINES):
    model.AddNoOverlap([t.interval for t in machine_tasks[m]])


# Minimize makespan: the latest finishing time over the last task of each job.
makespan = model.NewIntVar(0, TASK_TIMES_SUM, "makespan")
model.AddMaxEquality(
    makespan,
    [tasks[-1].end for tasks in job_tasks.values()]
)
model.Minimize(makespan)


# Solve. A FEASIBLE status still carries a valid schedule (e.g. when a time
# limit stops the search before optimality is proven), so report it too.
solver = cp_model.CpSolver()
status = solver.Solve(model)
if status in (cp_model.OPTIMAL, cp_model.FEASIBLE):
    for m in range(MACHINES):
        print(f"Machine {m}")
        for task in sorted(
            machine_tasks[m], key=lambda t: solver.Value(t.start)):
            print(
                f"\t- Run job {task.job_id}, task {task.task_no}"
                f" for [{solver.Value(task.start)},"
                f"{solver.Value(task.end)}]"
            )
    print(f"Makespan: {solver.Value(makespan)}")
    if status != cp_model.OPTIMAL:
        print("Note: schedule is feasible but not proven optimal")
else:
    # Previously this case produced no output at all.
    print(f"No schedule found (solver status: {solver.StatusName(status)})")