Causal Inference

splita provides 19 causal inference classes for when standard A/B testing is not possible or sufficient. These methods handle observational data, interference, cluster randomization, and more.

DifferenceInDifferences (DiD)

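Classic two-period DiD, for settings where you cannot randomize at the user level. It compares the pre-to-post change in the treatment group with the pre-to-post change in the control group.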

from splita import DifferenceInDifferences
import numpy as np

rng = np.random.default_rng(42)

# Pre and post measurements for control and treatment groups
pre_ctrl = rng.normal(10, 2, 500)
post_ctrl = rng.normal(10.5, 2, 500)   # natural trend
pre_trt = rng.normal(10, 2, 500)
post_trt = rng.normal(12, 2, 500)      # trend + treatment effect

did = DifferenceInDifferences()
result = did.fit(pre_ctrl, post_ctrl, pre_trt, post_trt)
print(result.ate)        # ~1.5 (treatment effect net of trend: (12 - 10) - (10.5 - 10))
print(result.ci)         # confidence interval
print(result.pvalue)
print(result.parallel_trends_pvalue)  # test the key assumption
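
For intuition, the two-period DiD point estimate is the treated group's change minus the control group's change. The following is a minimal NumPy sketch of that arithmetic, independent of splita. The variable names and the independent-samples standard error are assumptions for illustration and are not necessarily what `DifferenceInDifferences` computes internally.

```python
import numpy as np

rng = np.random.default_rng(42)
pre_ctrl = rng.normal(10, 2, 500)
post_ctrl = rng.normal(10.5, 2, 500)
pre_trt = rng.normal(10, 2, 500)
post_trt = rng.normal(12, 2, 500)

# Change over time within each group
ctrl_change = post_ctrl.mean() - pre_ctrl.mean()   # natural trend only
trt_change = post_trt.mean() - pre_trt.mean()      # trend + treatment effect

# DiD: subtract the trend observed in the control group
did_estimate = trt_change - ctrl_change

# Assumption: the four samples are independent, so their variances add
samples = (pre_ctrl, post_ctrl, pre_trt, post_trt)
se = np.sqrt(sum(s.var(ddof=1) / s.size for s in samples))
ci = (did_estimate - 1.96 * se, did_estimate + 1.96 * se)

print(did_estimate, ci)
```

If the same users are measured before and after, the pre and post samples are correlated and this standard error would be too conservative; a paired or regression-based variance would be more appropriate in that case.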

SyntheticControl

Construct a synthetic counterfactual from a weighted combination of donor units.

from splita import SyntheticControl
import numpy as np

rng = np.random.default_rng(42)

# Time series: 20 pre-treatment periods, 10 post-treatment
n_pre, n_post = 20, 10
n_donors = 5

treated = np.concatenate([
    rng.normal(50, 2, n_pre),
    rng.normal(55, 2, n_post),  # treatment effect of ~5
])
donors = rng.normal(50, 2, (n_donors, n_pre + n_post))

sc = SyntheticControl()
result = sc.fit(treated, donors, n_pre=n_pre)
print(result.ate)          # ~5.0
print(result.weights)      # donor weights
print(result.ci)
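
To make "a weighted combination of donor units" concrete, here is a rough NumPy sketch of one common formulation: choose non-negative donor weights that sum to 1 and minimize the pre-period squared error, then use the weighted donors as the post-period counterfactual. The helper names (`project_to_simplex`, `synthetic_control_weights`), the projected-gradient solver, and the simulated data are all assumptions for illustration; this is not a description of how `SyntheticControl` is implemented.

```python
import numpy as np

def project_to_simplex(v):
    """Euclidean projection of v onto {w : w >= 0, sum(w) = 1}."""
    u = np.sort(v)[::-1]
    css = np.cumsum(u) - 1.0
    k = np.arange(1, len(v) + 1)
    rho = np.nonzero(u - css / k > 0)[0][-1]
    theta = css[rho] / (rho + 1)
    return np.maximum(v - theta, 0.0)

def synthetic_control_weights(treated_pre, donors_pre, n_iter=2000):
    """Projected gradient descent on the pre-period squared error."""
    # Because the weights sum to 1, subtracting the per-period donor mean
    # from both sides leaves the residuals unchanged but improves conditioning.
    center = donors_pre.mean(axis=0)
    D = (donors_pre - center).T               # shape (n_pre, n_donors)
    y = treated_pre - center
    w = np.full(D.shape[1], 1.0 / D.shape[1])
    step = 1.0 / np.linalg.norm(D, 2) ** 2    # 1 / Lipschitz constant of the gradient
    for _ in range(n_iter):
        w = project_to_simplex(w - step * (D.T @ (D @ w - y)))
    return w

rng = np.random.default_rng(0)
n_pre, n_post, n_donors = 20, 10, 5
donors = rng.normal(50, 2, (n_donors, n_pre + n_post))

# Hypothetical treated unit: a mix of the first two donors plus small noise
true_w = np.array([0.6, 0.4, 0.0, 0.0, 0.0])
treated = true_w @ donors + rng.normal(0, 0.2, n_pre + n_post)
treated[n_pre:] += 5.0                        # treatment effect of 5 after period n_pre

weights = synthetic_control_weights(treated[:n_pre], donors[:, :n_pre])
synthetic = weights @ donors                  # counterfactual path for all periods
effect = np.mean(treated[n_pre:] - synthetic[n_pre:])

print(weights.round(2), effect)
```

Unlike the independent-noise data in the example above, the treated series here is built from the donors, which is the situation in which fitted weights carry real information about the counterfactual.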

PropensityScoreMatching

For observational data where treatment assignment is non-random:

from splita import PropensityScoreMatching
import numpy as np

rng = np.random.default_rng(42)
n = 2000

# Covariates
X = rng.normal(0, 1, (n, 3))

# Treatment assignment depends on covariates (confounding)
propensity = 1 / (1 + np.exp(-(X[:, 0] + 0.5 * X[:, 1])))
treatment = rng.binomial(1, propensity)

# Outcome depends on treatment AND covariates
y = 2.0 * treatment + X[:, 0] + 0.5 * X[:, 1] + rng.normal(0, 1, n)

psm = PropensityScoreMatching()
result = psm.fit(y, treatment, X)
print(result.ate)     # ~2.0
print(result.ci)
print(result.n_matched)
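
The idea behind matching can be sketched in two steps: estimate each unit's probability of treatment from its covariates, then compare every treated unit with the control unit whose estimated probability is closest. The NumPy sketch below assumes a logistic propensity model fitted by Newton's method and 1-nearest-neighbor matching with replacement; these choices, and all variable names, are illustrative assumptions rather than the documented behavior of `PropensityScoreMatching`. Strictly speaking, this variant estimates the effect on the treated (ATT), which coincides with the ATE here because the simulated effect is constant.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2000
X = rng.normal(0, 1, (n, 3))
true_propensity = 1 / (1 + np.exp(-(X[:, 0] + 0.5 * X[:, 1])))
t = rng.binomial(1, true_propensity)
y = 2.0 * t + X[:, 0] + 0.5 * X[:, 1] + rng.normal(0, 1, n)

# Step 1: logistic regression of treatment on covariates (Newton's method)
A = np.column_stack([np.ones(n), X])
beta = np.zeros(A.shape[1])
for _ in range(25):
    p = 1 / (1 + np.exp(-A @ beta))
    gradient = A.T @ (t - p)
    hessian = (A * (p * (1 - p))[:, None]).T @ A
    beta += np.linalg.solve(hessian, gradient)
score = 1 / (1 + np.exp(-A @ beta))

# Step 2: match each treated unit to the control with the closest score
treated_idx = np.flatnonzero(t == 1)
control_idx = np.flatnonzero(t == 0)
distance = np.abs(score[treated_idx][:, None] - score[control_idx][None, :])
matched_control = control_idx[distance.argmin(axis=1)]

att = np.mean(y[treated_idx] - y[matched_control])
naive = y[t == 1].mean() - y[t == 0].mean()   # biased upward by confounding

print(naive, att)
```

The naive difference in means absorbs the effect of the confounders, while the matched comparison removes most of it.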

ClusterExperiment

For cluster-randomized experiments (e.g., by city, store, or school):

from splita import ClusterExperiment
import numpy as np

rng = np.random.default_rng(42)

# 50 clusters, ~100 users each
n_clusters = 50
cluster_sizes = rng.integers(80, 120, n_clusters)

ctrl_clusters = [rng.normal(10, 2, s) for s in cluster_sizes[:25]]
trt_clusters = [rng.normal(11, 2, s) for s in cluster_sizes[25:]]

ctrl = np.concatenate(ctrl_clusters)
trt = np.concatenate(trt_clusters)
ctrl_ids = np.concatenate([np.full(s, i) for i, s in enumerate(cluster_sizes[:25])])
trt_ids = np.concatenate([np.full(s, i + 25) for i, s in enumerate(cluster_sizes[25:])])

ce = ClusterExperiment()
result = ce.fit(ctrl, trt, ctrl_ids, trt_ids)
print(result.ate)
print(result.ci)
print(result.icc)  # intra-cluster correlation (expected near 0 here: the simulated users share no cluster-level effect)
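
Cluster randomization matters because users in the same cluster tend to be correlated, which makes user-level standard errors too small. The sketch below simulates a shared cluster-level effect, estimates the ICC with the one-way ANOVA estimator, and compares a user-level standard error with one computed from cluster means. Equal cluster sizes, the ANOVA estimator, and the cluster-means analysis are assumptions chosen for simplicity; splita's `ClusterExperiment` may use a different method.

```python
import numpy as np

rng = np.random.default_rng(0)
n_clusters, m = 50, 100                       # equal cluster sizes keep the formulas simple
is_treated = np.arange(n_clusters) >= 25

# Each cluster gets its own random shift (sd 1); users add noise (sd 2),
# so the true ICC is 1 / (1 + 4) = 0.2.
cluster_effect = rng.normal(0, 1.0, n_clusters)
y = (10 + 1.0 * is_treated[:, None] + cluster_effect[:, None]
     + rng.normal(0, 2.0, (n_clusters, m)))
ctrl, trt = y[~is_treated], y[is_treated]

def anova_icc(arm):
    """One-way ANOVA estimator of the ICC for equal-sized clusters."""
    k, size = arm.shape
    cluster_means = arm.mean(axis=1)
    msb = size * np.sum((cluster_means - arm.mean()) ** 2) / (k - 1)
    msw = np.sum((arm - cluster_means[:, None]) ** 2) / (k * (size - 1))
    return (msb - msw) / (msb + (size - 1) * msw)

# Estimate within each arm so the treatment effect does not inflate the ICC
icc = 0.5 * (anova_icc(ctrl) + anova_icc(trt))
design_effect = 1 + (m - 1) * icc

# Analyze at the level of randomization: one mean per cluster
ctrl_means, trt_means = ctrl.mean(axis=1), trt.mean(axis=1)
ate = trt_means.mean() - ctrl_means.mean()
se_cluster = np.sqrt(trt_means.var(ddof=1) / trt_means.size
                     + ctrl_means.var(ddof=1) / ctrl_means.size)
se_naive = np.sqrt(trt.var(ddof=1) / trt.size + ctrl.var(ddof=1) / ctrl.size)

print(icc, design_effect, ate, se_naive, se_cluster)
```

The design effect `1 + (m - 1) * icc` approximates how much the variance is inflated relative to user-level randomization, which is why the cluster-level standard error is several times larger than the naive one.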

SwitchbackExperiment

For time-based switchback designs where treatment alternates over time:

from splita import SwitchbackExperiment
import numpy as np

rng = np.random.default_rng(42)

# 100 time slots, alternating treatment
n_slots = 100
treatment_mask = np.array([i % 2 for i in range(n_slots)])
outcomes = rng.normal(10, 2, n_slots) + 1.5 * treatment_mask

sb = SwitchbackExperiment()
result = sb.fit(outcomes, treatment_mask)
print(result.ate)   # ~1.5
print(result.ci)
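
As a baseline for what a switchback analysis estimates, the sketch below computes a difference in means across time slots and a randomization-test p-value. It assumes that each slot's assignment was drawn at random (a strictly alternating schedule would not justify the reshuffling step) and that there is no carryover between slots. The function name and settings are hypothetical and do not reflect `SwitchbackExperiment` internals.

```python
import numpy as np

rng = np.random.default_rng(0)
n_slots = 200

# Assumption: every time slot is independently assigned to treatment or control
assignment = rng.integers(0, 2, n_slots)
outcomes = rng.normal(10, 2, n_slots) + 1.5 * assignment

def diff_in_means(y, a):
    return y[a == 1].mean() - y[a == 0].mean()

observed = diff_in_means(outcomes, assignment)

# Randomization test: reshuffle the slot assignment and recompute the statistic
n_perm = 2000
perm_stats = np.array([
    diff_in_means(outcomes, rng.permutation(assignment)) for _ in range(n_perm)
])
p_value = (1 + np.sum(np.abs(perm_stats) >= abs(observed))) / (1 + n_perm)

print(observed, p_value)
```

When outcomes are autocorrelated or effects spill over into the next slot, slot-level observations are no longer exchangeable, and longer blocks or burn-in periods are usually needed.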

DoublyRobustEstimator

Augmented Inverse Propensity Weighting (AIPW): the estimate stays consistent as long as at least one of the two models, the outcome model or the propensity model, is correctly specified:

from splita import DoublyRobustEstimator
import numpy as np

rng = np.random.default_rng(42)
n = 2000
X = rng.normal(0, 1, (n, 3))
treatment = rng.binomial(1, 0.5, n)
y = 2.0 * treatment + X[:, 0] + rng.normal(0, 1, n)

dr = DoublyRobustEstimator()
result = dr.fit(y, treatment, X)
print(result.ate)
print(result.ci)
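
The AIPW estimator combines outcome-model predictions with an inverse-propensity-weighted correction of their residuals. The sketch below spells out that formula in NumPy under stated assumptions: the propensity is known to be 0.5 (as in a randomized experiment), the outcome models are ordinary least squares fitted separately per arm, and there is no cross-fitting. The names are illustrative, and `DoublyRobustEstimator` may make different modeling choices.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2000
X = rng.normal(0, 1, (n, 3))
t = rng.binomial(1, 0.5, n)
y = 2.0 * t + X[:, 0] + rng.normal(0, 1, n)

# Propensity model. Assumption: known and constant, as in a randomized experiment
e = np.full(n, 0.5)

# Outcome models: one linear regression per arm, predicted for every unit
A = np.column_stack([np.ones(n), X])

def fit_predict(mask):
    coef, *_ = np.linalg.lstsq(A[mask], y[mask], rcond=None)
    return A @ coef

mu1 = fit_predict(t == 1)    # predicted outcome under treatment
mu0 = fit_predict(t == 0)    # predicted outcome under control

# AIPW score: model-based contrast plus weighted residual corrections
psi = (mu1 - mu0
       + t * (y - mu1) / e
       - (1 - t) * (y - mu0) / (1 - e))

ate = psi.mean()
se = psi.std(ddof=1) / np.sqrt(n)
ci = (ate - 1.96 * se, ate + 1.96 * se)

print(ate, ci)
```

If the outcome models are wrong but the propensities are right, the residual terms correct the bias; if the propensities are wrong but the outcome models are right, the residual terms average to roughly zero.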

TMLE

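Targeted Maximum Likelihood Estimation (TMLE), a doubly robust estimator that fits an initial outcome model and then updates it with a targeting step based on the propensity score: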

from splita import TMLE
import numpy as np

rng = np.random.default_rng(42)
n = 2000
X = rng.normal(0, 1, (n, 3))
treatment = rng.binomial(1, 0.5, n)
y = 2.0 * treatment + X[:, 0] + rng.normal(0, 1, n)

tmle = TMLE()
result = tmle.fit(y, treatment, X)
print(result.ate)
print(result.ci)

Other causal methods

| Class | Description | Use case |
|---|---|---|
| SurrogateEstimator | Short-term to long-term effect prediction | When you can't wait for the long-term outcome |
| SurrogateIndex | Multi-surrogate index with cross-fitting | Multiple surrogate metrics |
| InterferenceExperiment | Network interference with Horvitz-Thompson | Social networks, marketplaces |
| BipartiteExperiment | Cross-side exposure mapping | Two-sided platforms |
| ContinuousTreatmentEffect | Dose-response curves | Variable treatment intensity |
| DynamicCausalEffect | Time-varying treatment effects | Treatment effects that change over time |
| EffectTransport | Transport effects across populations | Generalizing from one population to another |
| GeoExperiment | Bayesian synthetic control for geo experiments | Marketing geo-tests |
| InstrumentalVariables | Two-stage least squares | Non-compliance, encouragement designs |
| MarketplaceExperiment | Buyer/seller randomization bias | Two-sided marketplaces |
| MediationAnalysis | Baron-Kenny mediation (ACME) | Understanding why a treatment works |
| RegressionDiscontinuity | Sharp RDD with local linear regression | Policy cutoffs (age, score thresholds) |

When to use which

| Scenario | Recommended method |
|---|---|
| Standard A/B test with clusters | ClusterExperiment |
| Cannot randomize; have pre/post data | DifferenceInDifferences |
| Cannot randomize; have covariates | PropensityScoreMatching or DoublyRobustEstimator |
| Geo-level marketing test | GeoExperiment or SyntheticControl |
| Time-based alternating design | SwitchbackExperiment |
| Network effects between users | InterferenceExperiment |
| Need robust estimation (either model can be wrong) | DoublyRobustEstimator or TMLE |