Causal inference with DoWhy: counterfactual analysis in a medical case

0x01. Background

In this example, we know the causal structure of the three observed variables, and we want to answer counterfactual questions such as “What would have happened if I had followed a different piece of the doctor’s advice?”

More specifically, Alice, who suffers from severe dry eyes, decided to use a telemedicine platform because she was unable to see an ophthalmologist where she lives. She reported her medical history, which indicates whether she has a rare allergy, and the platform recommended two possible eye drops with slightly different ingredients (“Option 1” and “Option 2”).

Alice did a quick search online and found that Option 1 had many positive reviews. Still, she decided to use Option 2 because her mother had used it in the past with good results. After a few days, Alice’s vision improved and her symptoms began to disappear. However, she was curious what would have happened if she had used the very popular Option 1, or even done nothing at all.

The platform lets users ask counterfactual questions, as long as they report the outcome of the option they followed.

0x02. Simulation data

We describe the SCM $f_{p1,p2}$ as an additive noise model:

$$Vision = N_V + f_{p1,p2}(Treatment, Condition)$$

We sample the intrinsic noise terms $N_T$, $N_C$ and $N_V$ of the three observed variables. The target variable Vision is the additive noise $N_V$ plus a function of its input nodes, as described below.

$Treatment = N_T$, where $N_T$ takes the value 0, 1 or 2 with a probability of about 33% each: 33% of users do nothing, 33% of users choose option 1, and 33% of users choose option 2. This is independent of whether a patient has the rare condition.

$Condition = N_C \sim \mathrm{Bernoulli}(0.01)$: whether the patient has the rare condition.

$$
\begin{aligned}
Vision &= N_V + f_{p1,p2}(Treatment, Condition) \\
&= N_V - P_1 (1 - Condition)(1 - Treatment)(2 - Treatment) \\
&\quad + 2 P_2 (1 - Condition)\, Treatment\, (2 - Treatment) \\
&\quad + P_2 (1 - Condition)(3 - Treatment)(1 - Treatment)\, Treatment \\
&\quad - 2 P_2\, Condition\, Treatment\, (2 - Treatment) \\
&\quad - P_2\, Condition\, (3 - Treatment)(1 - Treatment)\, Treatment
\end{aligned}
$$

$P_1$ is a constant that determines how much the raw vision decreases when the patient does not have the rare condition and is not taking any medication.

$P_2$ is a constant that determines how much the raw vision increases or decreases, depending on whether the patient has the condition and on the type of drops they use. More specifically, substituting each combination into the formula for Vision gives:

If Condition = 0 and Treatment = 1 then Vision = N_V + 2P_2

elif Condition = 0 and Treatment = 2 then Vision = N_V - 2P_2

elif Condition = 1 and Treatment = 1 then Vision = N_V - 2P_2

elif Condition = 1 and Treatment = 2 then Vision = N_V + 2P_2

elif Condition = 0 and Treatment = 0 then Vision = N_V - 2P_1

elif Condition = 1 and Treatment = 0 then Vision = N_V (the formula contains no nonzero term for this case)
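The case distinction can be cross-checked by evaluating the polynomial directly for every combination of Condition and Treatment. The following is a minimal sketch, assuming the constants P_1 = 0.2 and P_2 = 0.15 from the data generation code in this article; the helper name `vision_offset` is hypothetical and only used for illustration. Note that, with the polynomial exactly as written, the nonzero offsets come out as 2*P_1 and 2*P_2 in magnitude.

```python
# Constants as used in the data generation code of this article
P_1 = 0.2
P_2 = 0.15


def vision_offset(treatment, condition):
    """Evaluate f_{p1,p2}(Treatment, Condition) term by term (hypothetical helper)."""
    t, c = treatment, condition
    return (
        -P_1 * (1 - c) * (1 - t) * (2 - t)
        + 2 * P_2 * (1 - c) * t * (2 - t)
        + P_2 * (1 - c) * (3 - t) * (1 - t) * t
        - 2 * P_2 * c * t * (2 - t)
        - P_2 * c * (3 - t) * (1 - t) * t
    )


# Offset added to the raw vision N_V for each (Condition, Treatment) pair
offsets = {
    (c, t): round(vision_offset(t, c), 10) + 0.0  # "+ 0.0" normalises a possible -0.0
    for c in (0, 1)
    for t in (0, 1, 2)
}

for (c, t), offset in offsets.items():
    print(f"Condition={c}, Treatment={t}: Vision = N_V {offset:+.2f}")
```

For patients without the condition, option 1 helps and option 2 hurts; for patients with the condition the signs are reversed, which is the interaction the counterfactual analysis later has to pick up.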

Rare events such as Condition = 1 (which occurs with a probability of only 1%) require a large number of samples, otherwise the trained model cannot reflect them accurately. That is why we generate 10000 samples for the patient database.
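To get a feeling for why the sample size matters, the expected number of rare-condition patients can be estimated with simple arithmetic. This is only a back-of-the-envelope sketch based on the probabilities stated above (Bernoulli(0.01) for the condition, three equally likely treatments); the variable names are chosen for illustration.

```python
n_samples = 10000    # size of the patient database in this article
p_condition = 0.01   # Condition ~ Bernoulli(0.01)
n_treatments = 3     # Treatment is 0, 1 or 2 with equal probability

# Expected number of patients with the rare condition, and its standard deviation
expected_rare = n_samples * p_condition
std_rare = (n_samples * p_condition * (1 - p_condition)) ** 0.5

# Treatment is independent of Condition, so the rare cases split evenly across treatments
expected_rare_per_treatment = expected_rare / n_treatments

print(f"Expected rare-condition patients: {expected_rare:.0f} (std about {std_rare:.1f})")
print(f"Expected per treatment option:    {expected_rare_per_treatment:.1f}")
```

Even with 10000 samples, only a few dozen patients are expected for each combination of Condition = 1 and a given treatment, so a much smaller database would leave the model with very little evidence for these cases.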

Generate the patient database:

from scipy.stats import bernoulli
import numpy as np
from random import randint
import pandas as pd

n_unobserved = 10000
# Intrinsic noise terms of the three observed variables
unobserved_data = {
    'N_T': np.array([randint(0, 2) for p in range(n_unobserved)]),  # 0, 1 or 2 with equal probability
    'N_vision': np.random.uniform(0.4, 0.6, size=(n_unobserved,)),  # raw vision before treatment
    'N_C': bernoulli.rvs(0.01, size=n_unobserved)                   # rare condition, probability 1%
}
P_1 = 0.2
P_2 = 0.15

def create_observed_medical_data(unobserved_data, name):
    condition = unobserved_data['N_C']
    treatment = unobserved_data['N_T']
    # Vision = N_V + f_{p1,p2}(Treatment, Condition); no term is needed for Condition = 1, Treatment = 0
    vision = (unobserved_data['N_vision']
              - P_1 * (1 - condition) * (1 - treatment) * (2 - treatment)
              + 2 * P_2 * (1 - condition) * treatment * (2 - treatment)
              + P_2 * (1 - condition) * treatment * (1 - treatment) * (3 - treatment)
              - 2 * P_2 * condition * treatment * (2 - treatment)
              - P_2 * condition * treatment * (1 - treatment) * (3 - treatment))
    observed_medical_data = pd.DataFrame({'Condition': condition, 'Treatment': treatment, 'Vision': vision})
    observed_medical_data.to_csv(name, index=False)
    return observed_medical_data

medical_data = create_observed_medical_data(unobserved_data, 'patients_database.csv')

Generate the data for the specific patient (Alice):

num_samples = 1
original_vision = np.random.uniform(0.4, 0.6, size=num_samples)

def generate_specific_patient_data(num_samples, name):
    # Alice took option 2 (Treatment = 2) and has the rare condition (Condition = 1)
    return create_observed_medical_data({
        'N_T': np.full((num_samples,), 2),
        'N_C': bernoulli.rvs(1, size=num_samples),
        'N_vision': original_vision,
    }, name)

specific_patient_data = generate_specific_patient_data(num_samples, 'newly_come_patients.csv')

0x03. Read the patient database

We have a database consisting of three observed variables: a continuous variable from 0 to 1 denoting vision quality (“Vision”), a binary variable denoting whether a patient has the rare condition, i.e. the allergy (“Condition”), and a categorical variable (“Treatment”) that can take three values (0: “do nothing”, 1: “option 1” or 2: “option 2”). The first rows can be inspected as follows:

import pandas as pd

medical_data = pd.read_csv('patients_database.csv')
medical_data.head()

The output is as follows:

   Condition  Treatment    Vision
0          0          2  0.223475
1          0          2  0.197306
2          0          0  0.101252
3          0          1  0.703056
4          0          0  0.020249

medical_data.iloc[0:100].plot(figsize=(15, 10))


The dataset reflects patients’ vision after choosing one of the three treatment options, depending on whether they have the rare condition. Note that the dataset contains no information about the patients’ raw vision before treatment (i.e. the noise of the Vision variable). As we will see below, the counterfactual algorithm recovers this noise part of Vision whenever the model is a post-nonlinear model (an ANM, for example).

0x04. Modeling

We know that the Treatment node and the Condition node both cause Vision, but we do not know the structural causal model itself. However, we can learn it from the observed data. We assume that the graph correctly represents the causal relationships and that there are no hidden confounders (causal sufficiency). Given the graph and the data, we can fit a causal model and begin to answer counterfactual questions.

import networkx as nx
import dowhy.gcm as gcm

causal_model = gcm.InvertibleStructuralCausalModel(nx.DiGraph([('Treatment', 'Vision'), ('Condition', 'Vision')]))
gcm.auto.assign_causal_mechanisms(causal_model, medical_data)

gcm.util.plot(causal_model.graph)

gcm.fit(causal_model, medical_data)

0x05. Read the specific patient’s data

specific_patient_data = pd.read_csv('newly_come_patients.csv')
specific_patient_data.head()

The output is as follows:

   Condition  Treatment    Vision
0          1          2  0.857103

0x06. Answer Alice’s counterfactual question

To check a hypothetical outcome, i.e. what would have happened if an event had not occurred or had occurred differently, we use so-called counterfactual logic based on the structural causal model. We know the following: Alice chose the second treatment option (Treatment = 2). Alice has the rare allergy (Condition = 1). After treatment with option 2, Alice’s vision is the Vision value recorded in her data above (0.857103 in this run). Based on the learned structural causal model, we can recover the noise for this observation.

We can now compute the counterfactual value of her vision had the Treatment node taken a different value. Below, we look at the counterfactual value of Alice’s vision if she had received no treatment (Treatment = 0) and if she had taken the other eye drops (Treatment = 1).

# Counterfactual: Alice takes option 1 instead of option 2
counterfactual_data1 = gcm.counterfactual_samples(causal_model,
                                                  {'Treatment': lambda x: 1},
                                                  observed_data=specific_patient_data)

# Counterfactual: Alice receives no treatment
counterfactual_data2 = gcm.counterfactual_samples(causal_model,
                                                  {'Treatment': lambda x: 0},
                                                  observed_data=specific_patient_data)


import matplotlib.pyplot as plt

df_plot2 = pd.DataFrame()
df_plot2['Vision after option 2'] = specific_patient_data['Vision']
df_plot2['Counterfactual vision (option 1)'] = counterfactual_data1['Vision']
df_plot2['Counterfactual vision (No treatment)'] = counterfactual_data2['Vision']

df_plot2.plot.bar(title="Counterfactual outputs")
plt.xlabel('Alice')
plt.ylabel('Eyesight quality')
plt.legend()

The result is as follows:

Had Alice chosen option 1, her eyesight would have been worse than with option 2. This suggests that the rare condition (Condition = 1) she reported in her medical history causes an allergic reaction to the popular option 1. Alice can also see that, had she chosen neither of the recommended options, her eyesight would also have been worse than with option 2 (the Vision variable takes a smaller value).
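Because the data in this article is simulated, the same counterfactuals can also be worked out by hand from the data-generating formula, following the usual three steps: abduction (recover the noise from the observation), action (change the treatment) and prediction (recompute Vision with the recovered noise). The sketch below uses the ground-truth formula with P_1 = 0.2 and P_2 = 0.15 and the Vision value 0.857103 from Alice's record shown earlier. It is not what DoWhy does internally: the fitted model only approximates this function, so its numbers will differ somewhat. The helper names `vision_offset` and `counterfactual_vision` are hypothetical.

```python
# Ground-truth constants from the data generation code (an assumption of this sketch:
# the real platform would only have the fitted model, not these values)
P_1 = 0.2
P_2 = 0.15


def vision_offset(treatment, condition):
    """Ground-truth f_{p1,p2}(Treatment, Condition) from the simulation (hypothetical helper)."""
    t, c = treatment, condition
    return (
        -P_1 * (1 - c) * (1 - t) * (2 - t)
        + 2 * P_2 * (1 - c) * t * (2 - t)
        + P_2 * (1 - c) * (3 - t) * (1 - t) * t
        - 2 * P_2 * c * t * (2 - t)
        - P_2 * c * (3 - t) * (1 - t) * t
    )


def counterfactual_vision(observed_vision, observed_treatment, condition, new_treatment):
    # Abduction: in an additive noise model the noise is the residual of the observation
    noise = observed_vision - vision_offset(observed_treatment, condition)
    # Action and prediction: set the new treatment and reuse the recovered noise
    return noise + vision_offset(new_treatment, condition)


# Alice's record as shown in the output earlier: Condition = 1, Treatment = 2
alice_vision = 0.857103

cf_option_1 = counterfactual_vision(alice_vision, 2, 1, new_treatment=1)
cf_no_treatment = counterfactual_vision(alice_vision, 2, 1, new_treatment=0)

print(f"Observed vision (option 2):           {alice_vision:.6f}")
print(f"Counterfactual vision (option 1):     {cf_option_1:.6f}")
print(f"Counterfactual vision (no treatment): {cf_no_treatment:.6f}")
```

Under the simulated ground truth, both counterfactual values are lower than the observed one, with option 1 being the worst choice for a patient with the rare condition. This matches the ordering described above.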