3  Baseline Characteristics Table

This article demonstrates how to create a baseline characteristics table for clinical study reports using rtflite.

3.1 Overview

Baseline characteristics tables summarize demographic and clinical characteristics of study participants at enrollment. These tables are essential for understanding the study population and assessing comparability between treatment groups.

3.2 Imports

import polars as pl
import rtflite as rtf
from importlib.resources import files

3.3 Data Preparation

# Locate the ADSL parquet file inside the rtflite.data package.
data_path = files("rtflite.data").joinpath("adsl.parquet")

# Keep rows whose SAFFL flag is "Y", narrow to the columns the table needs,
# then recode SEX through the F/M mapping and title-case RACE.
adsl_baseline = (
    pl.read_parquet(data_path)
    .filter(pl.col("SAFFL").eq("Y"))
    .select("USUBJID", "TRT01P", "AGE", "SEX", "RACE")
    .with_columns(
        pl.col("SEX").replace({"F": "Female", "M": "Male"}),
        pl.col("RACE").str.to_titlecase(),
    )
)

3.4 Statistics Function

def get_statistics(df, var, is_continuous=False):
    """Summarize ``var`` within each ``TRT01P`` group and over all rows.

    Args:
        df: Polars DataFrame containing ``TRT01P`` and the ``var`` column.
        var: Name of the column to summarize.
        is_continuous: Selects numeric summaries instead of counts.

    Returns:
        A ``(by_treatment, overall)`` pair.

        Continuous: ``by_treatment`` is a DataFrame with one row per
        ``TRT01P`` value and the columns ``mean``, ``sd``, ``median``,
        ``min`` and ``max``; ``overall`` is a tuple holding those same five
        statistics, in that order, computed over every row of ``df``.

        Categorical: both items are DataFrames with a ``len`` count column
        and a ``formatted`` column reading ``"<count> (<percent>%)"``. The
        percentage denominator is the ``TRT01P`` group size for
        ``by_treatment`` and ``df.height`` for ``overall``.
    """

    def count_with_pct(denominator):
        # "<count> (<percent>%)" with the percentage rounded to one decimal.
        share = (100 * pl.col("len") / denominator).round(1)
        return pl.format("{} ({}%)", pl.col("len"), share).alias("formatted")

    if not is_continuous:
        # Rows per treatment group, used as the per-group denominator.
        group_sizes = df.group_by("TRT01P").len().rename({"len": "total"})

        per_group = (
            df.group_by(["TRT01P", var])
            .len()
            .join(group_sizes, on="TRT01P")
            .with_columns(count_with_pct(pl.col("total")))
        )
        pooled = df.group_by(var).len().with_columns(count_with_pct(df.height))
        return per_group, pooled

    column = pl.col(var)
    summary = [
        column.mean().round(1).alias("mean"),
        column.std().round(1).alias("sd"),
        column.median().round(1).alias("median"),
        column.min().alias("min"),
        column.max().alias("max"),
    ]
    per_group = df.group_by("TRT01P").agg(summary)
    pooled = df.select(summary).row(0)
    return per_group, pooled

3.5 Build Table Data

# Treatment arms in the order their columns appear in the table
treatments = [
    "Placebo",
    "Xanomeline Low Dose",
    "Xanomeline High Dose",
]

# Number of rows in adsl_baseline for each TRT01P value
treatment_counts = {
    arm: n_rows
    for arm, n_rows in adsl_baseline.group_by("TRT01P").len().iter_rows()
}

def create_variable_rows(df, var_name, categories=None, is_continuous=False):
    """Build the display rows for one variable of the baseline table.

    The first row is a header that carries only the variable name. The rows
    after it are indented by two spaces and hold one cell per entry of the
    module-level ``treatments`` list, followed by one overall cell.

    Args:
        df: Polars DataFrame containing ``TRT01P`` and the ``var_name`` column.
        var_name: Column to summarize; also used as the header label.
        categories: Levels to list, in display order, for a categorical
            variable. When ``None``, the distinct non-null values of
            ``df[var_name]`` are used in sorted order. Ignored when
            ``is_continuous`` is true.
        is_continuous: When true, emit a "Mean (SD)" row and a
            "Median [Min, Max]" row instead of one row per category.

    Returns:
        A list of rows; each row is a list of ``len(treatments) + 2`` strings.
    """
    # One value cell per treatment arm plus the overall column.
    n_value_cells = len(treatments) + 1
    rows = [[var_name] + [""] * n_value_cells]

    by_treatment, overall_stats = get_statistics(df, var_name, is_continuous)

    if is_continuous:
        # Index the per-arm statistics once instead of filtering per row.
        stat_columns = ["mean", "sd", "median", "min", "max"]
        stats_by_arm = {
            arm: stats
            for arm, *stats in by_treatment.select(
                ["TRT01P", *stat_columns]
            ).iter_rows()
        }

        mean_row = ["  Mean (SD)"]
        median_row = ["  Median [Min, Max]"]
        for trt in treatments:
            if trt in stats_by_arm:
                mean, sd, median, min_val, max_val = stats_by_arm[trt]
                mean_row.append(f"{mean} ({sd})")
                median_row.append(f"{median} [{min_val}, {max_val}]")
            else:
                # Arm absent from the data: leave both cells blank.
                mean_row.append("")
                median_row.append("")

        all_mean, all_sd, all_median, all_min, all_max = overall_stats
        mean_row.append(f"{all_mean} ({all_sd})")
        median_row.append(f"{all_median} [{all_min}, {all_max}]")

        rows.extend([mean_row, median_row])
        return rows

    if categories is None:
        # The signature allows omitting categories; fall back to the levels
        # actually present rather than failing on iteration over None.
        categories = df[var_name].drop_nulls().unique().sort().to_list()

    # Shown for any arm/category combination that has no rows.
    empty_cell = "0 (0.0%)"

    cell_by_arm_and_level = {
        (arm, level): text
        for arm, level, text in by_treatment.select(
            ["TRT01P", var_name, "formatted"]
        ).iter_rows()
    }
    overall_by_level = dict(
        overall_stats.select([var_name, "formatted"]).iter_rows()
    )

    for cat in categories:
        row = [f"  {cat}"]
        row.extend(
            cell_by_arm_and_level.get((trt, cat), empty_cell)
            for trt in treatments
        )
        row.append(overall_by_level.get(cat, empty_cell))
        rows.append(row)

    return rows

# Assemble the table body: SEX rows, then AGE rows, then RACE rows
table_data = [
    *create_variable_rows(adsl_baseline, "SEX", ["Female", "Male"]),
    *create_variable_rows(adsl_baseline, "AGE", is_continuous=True),
    *create_variable_rows(
        adsl_baseline,
        "RACE",
        [
            "Black Or African American",
            "White",
            "American Indian Or Alaska Native",
        ],
    ),
]

# Each inner list is one table row
df_baseline = pl.DataFrame(table_data, orient="row")

df_baseline

3.6 Create RTF Output

# Header cells: a blank label column, one "<arm>\n(N=<count>)" cell per
# treatment, and a final overall cell sized by the full dataset
col_headers = [
    "",
    *(f"{trt}\n(N={treatment_counts[trt]})" for trt in treatments),
    f"Overall\n(N={adsl_baseline.height})",
]

# Label column is left-justified and double width; the four value columns
# are centered and single width
doc_baseline = rtf.RTFDocument(
    df=df_baseline,
    rtf_title=rtf.RTFTitle(
        text=[
            "Baseline Characteristics of Participants",
            "(All Participants Randomized)",
        ],
    ),
    rtf_column_header=rtf.RTFColumnHeader(
        text=col_headers,
        text_justification=["l", "c", "c", "c", "c"],
    ),
    rtf_body=rtf.RTFBody(
        col_rel_width=[2, 1, 1, 1, 1],
        text_justification=["l", "c", "c", "c", "c"],
    ),
)

doc_baseline.write_rtf("../rtf/tlf_baseline.rtf")