#%%
import numpy as np # linear algebra
import pydicom
import os
import scipy.ndimage
import matplotlib.pyplot as plt
from skimage import measure, morphology
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
INPUT_FOLDER = '/nfs3-p1/zsxm/dataset/aortic_dissection/CTA1(+)/+C/'
# Load the scans in a given folder path
def load_scan(path):
    slices = [pydicom.dcmread(os.path.join(path, s)) for s in os.listdir(path)]
    slices.sort(key=lambda x: float(x.InstanceNumber))
    try:
        slice_thickness = np.abs(slices[0].ImagePositionPatient[2] - slices[1].ImagePositionPatient[2])
    except AttributeError:
        slice_thickness = np.abs(slices[0].SliceLocation - slices[1].SliceLocation)
    # Store the computed inter-slice distance so that resample() below can rely on SliceThickness
    for s in slices:
        s.SliceThickness = slice_thickness
    return slices
def get_pixels_hu(slices):
    image = np.stack([s.pixel_array for s in slices])
    # Convert to int16 (the raw pixel data is sometimes uint16);
    # this should be safe as the stored values stay well below 32k
    image = image.astype(np.int16)
    # Set outside-of-scan pixels to 0
    # The intercept is usually -1024, so air is approximately 0
    image[image == -2000] = 0
    # Convert to Hounsfield units (HU)
    for slice_number in range(len(slices)):
        intercept = slices[slice_number].RescaleIntercept
        slope = slices[slice_number].RescaleSlope
        if slope != 1:
            image[slice_number] = slope * image[slice_number].astype(np.float64)
            image[slice_number] = image[slice_number].astype(np.int16)
        image[slice_number] += np.int16(intercept)
    return np.array(image, dtype=np.int16)
def set_window(image, slices, w_center, w_width):
    image_copy = image.copy()
    #image_copy = np.clip(image_copy, w_center-int(w_width/2), w_center+int(w_width/2))
    # Set values outside the [w_center - w_width/2, w_center + w_width/2] window back to the rescale intercept
    for slice_number in range(len(slices)):
        image_copy[slice_number][image_copy[slice_number] > w_center+int(w_width/2)] = np.int16(slices[slice_number].RescaleIntercept)
        image_copy[slice_number][image_copy[slice_number] < w_center-int(w_width/2)] = np.int16(slices[slice_number].RescaleIntercept)
    return image_copy
patient = load_scan(INPUT_FOLDER)
patient_pixels = get_pixels_hu(patient)
# plt.hist(patient_pixels.flatten(), bins=80, color='c')
# plt.xlabel("Hounsfield Units (HU)")
# plt.ylabel("Frequency")
# plt.show()
# # Show some slice in the middle
# plt.imshow(patient_pixels[80], cmap=plt.cm.gray)
# plt.show()
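# set_window is defined above but never used in this script. A minimal usage
# sketch (kept commented out like the plots above; the 40/400 soft-tissue
# window is an illustrative choice, not a value taken from this repository):
# windowed = set_window(patient_pixels, patient, w_center=40, w_width=400)
# plt.imshow(windowed[80], cmap=plt.cm.gray)
# plt.show()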
def resample(image, scan, new_spacing=[1,1,1]):
    # Determine current pixel spacing (slice thickness + in-plane pixel spacing)
    spacing = np.array([scan[0].SliceThickness] + list(scan[0].PixelSpacing), dtype=np.float32)
    resize_factor = spacing / new_spacing
    new_real_shape = image.shape * resize_factor
    new_shape = np.round(new_real_shape)
    real_resize_factor = new_shape / image.shape
    new_spacing = spacing / real_resize_factor
    # Use scipy.ndimage.zoom directly (the scipy.ndimage.interpolation alias is deprecated)
    image = scipy.ndimage.zoom(image, real_resize_factor, mode='nearest')
    return image, new_spacing
pix_resampled, spacing = resample(patient_pixels, patient, [1,1,1])
print("Shape before resampling\t", patient_pixels.shape)
print("Shape after resampling\t", pix_resampled.shape)
def plot_3d(image, threshold=-300):
    # Position the scan upright,
    # so the head of the patient would be at the top facing the camera
    p = image.transpose(2, 1, 0)
    verts, faces = measure.marching_cubes_classic(p, threshold)
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    # Fancy indexing: `verts[faces]` to generate a collection of triangles
    mesh = Poly3DCollection(verts[faces], alpha=0.70)
    face_color = [0.45, 0.45, 0.75]
    mesh.set_facecolor(face_color)
    ax.add_collection3d(mesh)
    ax.set_xlim(0, p.shape[0])
    ax.set_ylim(0, p.shape[1])
    ax.set_zlim(0, p.shape[2])
    plt.show()
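# The threshold passed below (400 HU) renders roughly the bone surfaces; the
# default of -300 HU would also include soft tissue. These are conventional HU
# cut-offs rather than values specific to this dataset.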
plot_3d(pix_resampled, 400)
# %%
from train import create_net
from utils.eval import eval_net
import torch
import os
from PIL import Image
from torchvision.datasets import ImageFolder
from torchvision import transforms as T
from torch.utils.data import DataLoader
from utils.datasets import LabelSampler
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = create_net(device, load_model='checkpoints/10-18_13:44:44/Net_best.pth')
transform = T.Compose([
    T.Resize(51),       # Resize so the shorter side is 51 pixels, keeping the aspect ratio
    T.CenterCrop(51),   # Crop the central 51x51 patch
    T.ToTensor(),       # Convert the PIL Image to a tensor with values scaled to [0, 1]
    #T.Normalize(mean=[.5], std=[.5])  # Standardize to [-1, 1] with the given mean and std
])
val = ImageFolder('/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/val', transform=transform, loader=lambda path: Image.open(path))
val_loader = DataLoader(val, batch_size=128, sampler=LabelSampler(val), num_workers=8, pin_memory=True, drop_last=False)
eval_net(net, val_loader, len(val), device, True, PR_curve_save_dir='./')
# %%