Stable Diffusion on Amazon SageMaker

Journey to learn about transformer models in SageMaker. Inspired by philschmid.
NLP
transformers
stablediffusion
Published

November 4, 2022

**Using SageMaker Data Science Kernel**

Installing the necessary packages

!pip install "sagemaker==2.116.0" "huggingface_hub==0.10.1" --upgrade --quiet
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
import sagemaker
import boto3
sess = sagemaker.Session()
# sagemaker session bucket -> used for uploading data, models and logs
# sagemaker will automatically create this bucket if it does not exist
sagemaker_session_bucket=None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")
sagemaker role arn: arn:aws:iam::332886139243:role/mod-6297809195fe4845-SageMakerExecutionRole-1BV8UTN3E51SV
sagemaker bucket: sagemaker-us-east-1-332886139243
sagemaker session region: us-east-1

Create Stable Diffusion inference script

Here we are going to create a custom inference script by providing an inference.py file.

!mkdir code
%%writefile code/requirements.txt
diffusers==0.6.0
transformers==4.23.1
Writing code/requirements.txt
%%writefile code/inference.py
import base64
import torch
from io import BytesIO
from diffusers import StableDiffusionPipeline


def model_fn(model_dir):
    # Load stable diffusion and move it to the GPU
    pipe = StableDiffusionPipeline.from_pretrained(model_dir, torch_dtype=torch.float16)
    pipe = pipe.to("cuda")

    return pipe


def predict_fn(data, pipe):

    # get prompt & parameters
    prompt = data.pop("inputs", data)
    # set valid HP for stable diffusion
    num_inference_steps = data.pop("num_inference_steps", 50)
    guidance_scale = data.pop("guidance_scale", 7.5)
    num_images_per_prompt = data.pop("num_images_per_prompt", 4)

    # run generation with parameters
    generated_images = pipe(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
    )["images"]

    # create response
    encoded_images = []
    for image in generated_images:
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        encoded_images.append(base64.b64encode(buffered.getvalue()).decode())

    # create response
    return {"generated_images": encoded_images}
Writing code/inference.py

Create SageMaker model.tar.gz artifact

To use the inference.py we need to package it in a model.tar.gz archive that the container can use. We will also use the huggingface_hub SDK to download the CompVis/stable-diffusion-v1-4 model.

If you plan to reproduce this work and use the Stable Diffusion model, please make sure you read the license information here.

from distutils.dir_util import copy_tree
from pathlib import Path
from huggingface_hub import snapshot_download
import random

HF_MODEL_ID="CompVis/stable-diffusion-v1-4"
HF_TOKEN="" # your hf token: https://huggingface.co/settings/tokens
assert len(HF_TOKEN) > 0, "Please set HF_TOKEN to your huggingface token. You can find it here: https://huggingface.co/settings/tokens"

# download snapshot
snapshot_dir = snapshot_download(repo_id=HF_MODEL_ID,revision="fp16",use_auth_token=HF_TOKEN)

# create model dir
model_tar = Path(f"model-{random.getrandbits(16)}")
model_tar.mkdir(exist_ok=True)

# copy snapshot to model dir
copy_tree(snapshot_dir, str(model_tar))

['model-61770/model_index.json',
 'model-61770/safety_checker/pytorch_model.bin',
 'model-61770/safety_checker/config.json',
 'model-61770/text_encoder/pytorch_model.bin',
 'model-61770/text_encoder/config.json',
 'model-61770/vae/diffusion_pytorch_model.bin',
 'model-61770/vae/config.json',
 'model-61770/unet/diffusion_pytorch_model.bin',
 'model-61770/unet/config.json',
 'model-61770/.gitattributes',
 'model-61770/README.md',
 'model-61770/tokenizer/tokenizer_config.json',
 'model-61770/tokenizer/merges.txt',
 'model-61770/tokenizer/special_tokens_map.json',
 'model-61770/tokenizer/vocab.json',
 'model-61770/feature_extractor/preprocessor_config.json',
 'model-61770/v1-variants-scores.jpg',
 'model-61770/scheduler/scheduler_config.json',
 'model-61770/scheduler/.ipynb_checkpoints/scheduler_config-checkpoint.json']
# copy code/ to model dir
copy_tree("code/", str(model_tar.joinpath("code")))
['model-61770/code/requirements.txt', 'model-61770/code/inference.py']
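
Optionally, we can smoke-test the inference handlers locally against the downloaded snapshot before packaging it. This is only a sketch: it assumes the kernel runs on a GPU instance and already has diffusers and transformers installed (the versions from code/requirements.txt); skip it on a CPU-only kernel.

# optional local smoke test of the handlers (assumes a GPU and diffusers/transformers installed)
import sys
sys.path.append("code")

from inference import model_fn, predict_fn

pipe = model_fn(str(model_tar))
result = predict_fn({"inputs": "a photo of an astronaut riding a horse", "num_images_per_prompt": 1}, pipe)

# the response mirrors what the endpoint will return: base64-encoded JPEGs
print(len(result["generated_images"]))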

Let's create the model.tar.gz using the following structure:

model.tar.gz/
|- model_index.json
|- unet/
|- code/

import tarfile
import os

# helper to create the model.tar.gz
def compress(tar_dir=None, output_file="model.tar.gz"):
    parent_dir = os.getcwd()
    os.chdir(tar_dir)
    with tarfile.open(os.path.join(parent_dir, output_file), "w:gz") as tar:
        for item in os.listdir("."):
            print(item)
            tar.add(item, arcname=item)
    os.chdir(parent_dir)

compress(str(model_tar))
model_index.json
safety_checker
text_encoder
vae
unet
.gitattributes
README.md
code
tokenizer
feature_extractor
v1-variants-scores.jpg
scheduler
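
Before uploading, an optional sanity check that the archive has the layout described above:

import tarfile

# confirm model_index.json and code/inference.py sit at the expected paths in the archive
with tarfile.open("model.tar.gz", "r:gz") as tar:
    names = tar.getnames()

print("model_index.json" in names, "code/inference.py" in names)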

Let's upload the model to S3.

from sagemaker.s3 import S3Uploader

# upload model.tar.gz to s3
s3_model_uri=S3Uploader.upload(local_path="model.tar.gz", desired_s3_uri=f"s3://{sess.default_bucket()}/stable-diffusion-v1-4")

print(f"model uploaded to: {s3_model_uri}")
model uploaded to: s3://sagemaker-us-east-1-332886139243/stable-diffusion-v1-4/model.tar.gz

Deploy the model to Amazon SageMaker

In this step we are going to deploy the model to Amazon SageMaker using the HuggingFaceModel class to create a real-time inference endpoint.

The endpoint will use an ml.g4dn.xlarge instance.

from sagemaker.huggingface.model import HuggingFaceModel


# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
   model_data=s3_model_uri,      # path to your model and script
   role=role,                    # iam role with permissions to create an Endpoint
   transformers_version="4.17",  # transformers version used
   pytorch_version="1.10",       # pytorch version used
   py_version='py38',            # python version used
)

# deploy the model to a real-time endpoint
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g4dn.xlarge"
    )
----------!
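
Deployment takes a few minutes. If the kernel restarts while the endpoint is already running, we can reattach a predictor to it by name instead of redeploying. A small sketch; the endpoint name below is a placeholder, the real one is available as predictor.endpoint_name:

from sagemaker.huggingface.model import HuggingFacePredictor

# reattach to an already-running endpoint instead of redeploying (replace the placeholder name)
predictor = HuggingFacePredictor(
    endpoint_name="<your-endpoint-name>",
    sagemaker_session=sess,
)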

Generate images using the deployed model

After running .deploy() we have a predictor that we can use to request inferences. The endpoint expects a JSON payload with at least the inputs key.

The predictor.predict() function returns a JSON response with the generated_images key, which contains the generated images as base64-encoded strings. To decode the response we add a small helper function decode_base64_image, which takes a base64-encoded string and returns a PIL.Image object, and display_images, which takes a list of PIL.Image objects and displays them as a grid.

from PIL import Image
from io import BytesIO
from IPython.display import display
import base64
import matplotlib.pyplot as plt

# helper decoder
def decode_base64_image(image_string):
    base64_image = base64.b64decode(image_string)
    buffer = BytesIO(base64_image)
    return Image.open(buffer)

# display PIL images as grid
def display_images(images=None,columns=3, width=100, height=100):
    plt.figure(figsize=(width, height))
    for i, image in enumerate(images):
        plt.subplot(int(len(images) / columns + 1), columns, i + 1)
        plt.axis('off')
        plt.imshow(image)
    plt.show()

Now let's generate some images. I will generate 3 images for the prompt “Paradise, green, beauty, birds, rivers, trees, nature, serenity, eternal happiness”.

num_images_per_prompt = 3
prompt = "Paradise, green, beauty, birds, rivers, trees, nature, serenity, eternal happiness"

# run prediction
response = predictor.predict(data={
  "inputs": prompt,
  "num_images_per_prompt" : num_images_per_prompt
  }
)

# decode images
decoded_images = [decode_base64_image(image) for image in response["generated_images"]]

# visualize generation
display_images(decoded_images)
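
The same endpoint can also be called outside the SageMaker Python SDK, for example from another application, through the low-level runtime API. A minimal sketch assuming the same JSON contract as above:

import json
import boto3

# call the endpoint via the SageMaker runtime API instead of the SDK predictor
runtime = boto3.client("sagemaker-runtime")

response = runtime.invoke_endpoint(
    EndpointName=predictor.endpoint_name,
    ContentType="application/json",
    Body=json.dumps({"inputs": prompt, "num_images_per_prompt": 1}),
)

payload = json.loads(response["Body"].read())
print(len(payload["generated_images"]))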

Please delete the model and endpoint to avoid unnecessary cost.

#predictor.delete_model()
#predictor.delete_endpoint()
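
As a final check, we can list the remaining endpoints in the account to confirm nothing is left running; a small sketch using boto3:

import boto3

# list endpoints in the current region to confirm the endpoint is gone
sm_client = boto3.client("sagemaker")
for ep in sm_client.list_endpoints()["Endpoints"]:
    print(ep["EndpointName"], ep["EndpointStatus"])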