import
tensorflow as tf
import
os
import
time
from
matplotlib
import
pyplot as plt
from
IPython
import
display
# Download location of the pix2pix "facades" dataset archive.
URL = (
    "https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/"
    "datasets/facades.tar.gz"
)
# Fetch the facades archive from URL and ask Keras to extract it.
path_to_zip = tf.keras.utils.get_file(
    fname='facades.tar.gz',
    origin=URL,
    extract=True,
)

# Dataset root: assumes the archive unpacks into a 'facades/' directory next
# to the downloaded file. The trailing slash is kept, presumably so callers
# can append file patterns by plain string concatenation.
PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')
# Presumably the shuffle-buffer size for the training dataset (not used in
# this section of the file) — confirm against the input pipeline.
BUFFER_SIZE = 400
# Presumably the number of image pairs per training batch.
BATCH_SIZE = 1

# Target spatial size; random_crop() crops each image pair to this size.
IMG_HEIGHT = 256
IMG_WIDTH = 256
def load(image_file):
    """Read a side-by-side image pair from disk and split it down the middle.

    The file is decoded as a JPEG and cut vertically at half its width: the
    left half becomes ``real_image`` and the right half ``input_image``.

    Args:
        image_file: Filename (string or scalar string tensor) of the JPEG.

    Returns:
        A ``(input_image, real_image)`` tuple of ``tf.float32`` tensors with
        3 channels. When the width is odd, the right half (``input_image``)
        is one pixel wider than the left half.
    """
    image = tf.io.read_file(image_file)
    # Force RGB: random_crop() crops to exactly 3 channels, so a grayscale
    # JPEG must not be decoded with a single channel.
    image = tf.image.decode_jpeg(image, channels=3)

    half_width = tf.shape(image)[1] // 2
    real_image = image[:, :half_width, :]
    input_image = image[:, half_width:, :]

    return tf.cast(input_image, tf.float32), tf.cast(real_image, tf.float32)
def resize(input_image, real_image, height, width):
    """Resize both images of a pair to ``(height, width)``.

    Both images use nearest-neighbour interpolation, so the pair is treated
    identically.

    Args:
        input_image: First image tensor of the pair.
        real_image: Second image tensor of the pair.
        height: Output height in pixels.
        width: Output width in pixels.

    Returns:
        The resized ``(input_image, real_image)`` pair, in the same order.
    """
    target_size = [height, width]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
    resized_input, resized_real = (
        tf.image.resize(img, target_size, method=nearest)
        for img in (input_image, real_image)
    )
    return resized_input, resized_real
def random_crop(input_image, real_image):
    """Stack an (input, real) image pair and randomly crop both together.

    Both images are stacked into one tensor (so they must have the same
    shape) and cropped with a single ``tf.image.random_crop`` call. Because
    the leading crop dimension equals the stack size, the same random window
    is taken from each image and the pair stays aligned. The crop size is
    ``(IMG_HEIGHT, IMG_WIDTH)`` with 3 channels.

    Returns:
        The cropped ``(input_image, real_image)`` pair.
    """
    pair = tf.stack([input_image, real_image], axis=0)
    crop_shape = [2, IMG_HEIGHT, IMG_WIDTH, 3]
    cropped_pair = tf.image.random_crop(pair, size=crop_shape)
    cropped_input, cropped_real = cropped_pair[0], cropped_pair[1]
    return cropped_input, cropped_real
@tf.function()
def random_jitter(input_image, real_image):
    """Apply random jittering to an (input, real) image pair before training.

    Following the pix2pix paper, random jittering consists of three steps:

    1. Resize both images up to a larger size (286 x 286).
    2. Randomly crop them back to the model's target size.
    3. Randomly flip them left-to-right, always flipping both or neither.

    Returns:
        The jittered ``(input_image, real_image)`` pair.
    """
    upscaled_input, upscaled_real = resize(input_image, real_image, 286, 286)
    cropped_input, cropped_real = random_crop(upscaled_input, upscaled_real)

    # One shared draw decides the flip for both images, keeping them aligned.
    should_flip = tf.random.uniform(()) > 0.5
    if should_flip:
        cropped_input = tf.image.flip_left_right(cropped_input)
        cropped_real = tf.image.flip_left_right(cropped_real)

    return cropped_input, cropped_real