DALL·E 2 Text-to-Image Model Practice Guide

Preface: This post records the setup steps and the debugging process for running inference with the DALL·E 2 model.

Related post: Super detailed! DALL·E Text-to-Image Model Practical Guide

Contents

  • 1. Environment setup and pre-trained model preparation
    • Environment setup
    • Pre-trained model download
  • 2. Code
  • 3. BUG & DEBUG
    • URLError
    • CUDA error
    • RuntimeError
    • PydanticUserError

1. Environment setup and pre-trained model preparation

The code repository used in this article is: https://github.com/lucidrains/DALLE2-pytorch

Environment setup

pip install dalle2-pytorch

Pre-trained model download

The pre-trained weights are available at: https://huggingface.co/laion/DALLE2-PyTorch
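The inference script in Section 2 loads four files from a local weights/ directory. The file names on the Hugging Face page may differ (for example, they may sit in subfolders), so rename the downloads to match what the script loads. A minimal pre-flight check, assuming that layout; the helper name missing_weights is hypothetical:

```python
from pathlib import Path
from typing import List, Union

# File names taken from the inference script in this guide (assumed to live in weights/)
REQUIRED_FILES = [
    "prior_config.json",
    "prior_latest.pth",
    "decoder_config.json",
    "decoder_latest.pth",
]


def missing_weights(weights_dir: Union[str, Path] = "weights") -> List[str]:
    """Return the required files that are not present in weights_dir."""
    root = Path(weights_dir)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_weights()
    print("all weight files found" if not missing else f"missing: {', '.join(missing)}")
```

Running this before the inference script catches a wrong path early, instead of failing after CLIP has already been downloaded and the prior built.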

2. Code

The complete DALL·E 2 inference script is as follows (from @cest-andre in Issue #282):

import torch
from torchvision.transforms import ToPILImage
from dalle2_pytorch import DALLE2
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig


# Diffusion prior: maps the CLIP text embedding of the prompt to a CLIP image embedding
prior_config = TrainDiffusionPriorConfig.from_json_path("weights/prior_config.json").prior
prior = prior_config.create().cuda()

prior_model_state = torch.load("weights/prior_latest.pth", map_location="cpu")
prior.load_state_dict(prior_model_state, strict=True)

# Decoder: generates an image conditioned on the CLIP image embedding
decoder_config = TrainDecoderConfig.from_json_path("weights/decoder_config.json").decoder
decoder = decoder_config.create().cuda()

decoder_model_state = torch.load("weights/decoder_latest.pth", map_location="cpu")["model"]

# Copy the CLIP weights loaded by decoder_config.create() into the checkpoint dict,
# so that strict loading does not report the "clip.*" keys as missing
clip_state = decoder.clip.state_dict()
for k, v in clip_state.items():
    decoder_model_state["clip." + k] = v

decoder.load_state_dict(decoder_model_state, strict=True)

# Full pipeline: prompt -> prior -> decoder -> image tensor
dalle2 = DALLE2(prior=prior, decoder=decoder).cuda()

images = dalle2(
    ['your prompt here'],
    cond_scale = 2.  # classifier-free guidance scale for the decoder
).cpu()

print(images.shape)  # (batch, channels, height, width)

for img in images:
    img = ToPILImage()(img)
    img.show()  # on a headless server, use img.save("output.png") instead

3. BUG & DEBUG

URLError

The error message is as follows:

Traceback (most recent call last):
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 1350, in do_open
    h.request(req.get_method(), req.selector, req.data, headers,
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1255, in request
    self._send_request(method, url, body, headers, encode_chunked)
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1301, in _send_request
    self.endheaders(body, encode_chunked=encode_chunked)
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1250, in endheaders
    self._send_output(message_body, encode_chunked=encode_chunked)
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1010, in _send_output
    self.send(msg)
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 950, in send
    self.connect()
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1424, in connect
    self.sock = self._context.wrap_socket(self.sock,
  File "/root/anaconda3/envs/ldm/lib/python3.8/ssl.py", line 500, in wrap_socket
    return self.sslsocket_class._create(
  File "/root/anaconda3/envs/ldm/lib/python3.8/ssl.py", line 1040, in _create
    self.do_handshake()
  File "/root/anaconda3/envs/ldm/lib/python3.8/ssl.py", line 1309, in do_handshake
    self._sslobj.do_handshake()
ConnectionResetError: [Errno 104] Connection reset by peer

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/newdata/SD/extra/dalle2_cest.py", line 11, in <module>
    prior = prior_config.create().cuda()
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/train_configs.py", line 185, in create
    clip = self.clip.create()
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/train_configs.py", line 122, in create
    return OpenAIClipAdapter(self.model)
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/dalle2_pytorch.py", line 313, in __init__
    openai_clip, preprocess = clip.load(name)
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/clip/clip.py", line 122, in load
    model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/clip/clip.py", line 59, in _download
    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 222, in urlopen
    return opener.open(url, data, timeout)
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 525, in open
    response = self._open(req, data)
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 542, in _open
    result = self._call_chain(self.handle_open, protocol, protocol +
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 502, in _call_chain
    result = func(*args)
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 1393, in https_open
    return self.do_open(http.client.HTTPSConnection, req,
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 1353, in do_open
    raise URLError(err)
urllib.error.URLError: <urlopen error [Errno 104] Connection reset by peer>

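Cause: while the prior is being created, dalle2_pytorch calls clip.load(), which downloads the OpenAI CLIP weights into ~/.cache/clip via urllib.request.urlopen(), and the remote server resets the connection during the SSL handshake.

Fix: open /root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py and find the try block in do_open() at around line 1349. Add a short sleep after the request is sent, as marked in the code below: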

try:
    h.request(req.get_method(), req.selector, req.data, headers,
              encode_chunked=req.has_header('Transfer-encoding'))
    time.sleep(0.5) #Added line
except OSError as err: # timeout error
    raise URLError(err)

CUDA error

RuntimeError: CUDA error: no kernel image is available for execution on the device
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

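Solution: This is a version mismatch. The installed PyTorch build does not match the system CUDA setup, so replace it with a build that does. For example, my system CUDA version is 12.0, and the following command installs the CUDA 11.8 build of PyTorch 2.0.0 (a CUDA 12.0 driver can run CUDA 11.8 builds):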

pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 torchaudio==2.0.1+cu118 -f https://download.pytorch.org/whl/torch_stable.html
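"No kernel image is available" usually means the installed PyTorch binary was not compiled for your GPU's compute capability. A minimal sketch to check this, assuming the standard CUDA compatibility rules: an sm_XY binary runs on GPUs with capability X.Z where Z >= Y, and compute_XY PTX can be JIT-compiled for any GPU with capability >= X.Y. The function names kernel_available and check_current_device are hypothetical; torch.cuda.get_device_capability() and torch.cuda.get_arch_list() are PyTorch functions.

```python
import re
from typing import Iterable, Tuple


def kernel_available(capability: Tuple[int, int], arch_list: Iterable[str]) -> bool:
    """Return True if a build compiled for arch_list can run on a GPU with the
    given (major, minor) compute capability."""
    major, minor = capability
    for arch in arch_list:
        m = re.fullmatch(r"(sm|compute)_(\d+)(\d)([a-z]?)", arch)
        if m is None:
            continue  # ignore entries in an unexpected format
        kind, a_major, a_minor, suffix = m.group(1), int(m.group(2)), int(m.group(3)), m.group(4)
        if suffix:
            # Suffixed targets (e.g. sm_90a) are treated conservatively as exact-match only
            if (a_major, a_minor) == (major, minor):
                return True
        elif kind == "sm":
            # SASS for X.Y runs on GPUs with the same major X and minor >= Y
            if a_major == major and a_minor <= minor:
                return True
        elif (a_major, a_minor) <= (major, minor):
            # PTX for X.Y can be JIT-compiled for any GPU with capability >= X.Y
            return True
    return False


def check_current_device() -> None:
    """Print whether the installed PyTorch has kernels for GPU 0 (requires torch)."""
    import torch  # imported lazily so kernel_available() works without torch

    if not torch.cuda.is_available():
        print("CUDA is not available to this PyTorch build")
        return
    cap = torch.cuda.get_device_capability(0)
    archs = torch.cuda.get_arch_list()
    print(f"GPU 0: sm_{cap[0]}{cap[1]}; PyTorch built for: {archs}")
    print("compatible kernels:", kernel_available(cap, archs))
```

Run check_current_device() in the same environment as the inference script; if it prints False, install a PyTorch build whose architecture list covers your GPU.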

RuntimeError

Traceback (most recent call last):
  File "/newdata/SD/extra/dalle2_cest.py", line 14, in <module>
    prior.load_state_dict(prior_model_state, strict=True)
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for DiffusionPrior:
        Missing key(s) in state_dict: "net.null_text_encodings", "net.null_text_embeds", "net.null_image_embed".
        Unexpected key(s) in state_dict: "net.null_text_embed".

Solution 1: Pass strict=False instead of strict=True to load_state_dict():

...
prior.load_state_dict(prior_model_state, strict=False)

decoder.load_state_dict(decoder_model_state, strict=False)
...

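However, with strict=False the mismatched parameters are simply skipped: the missing ones (the null embeddings listed above) keep their random initial values. This can degrade the model's output into mosaic-like images, which is clearly not what we want.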
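Before falling back to strict=False, it helps to see exactly which keys disagree. A minimal sketch, assuming you pass in the key sets of the model and of the loaded checkpoint; the names diff_state_dict_keys and KeyDiff are hypothetical:

```python
from typing import Iterable, List, NamedTuple


class KeyDiff(NamedTuple):
    missing: List[str]     # in the model but not the checkpoint (would keep init values)
    unexpected: List[str]  # in the checkpoint but not the model (would be ignored)


def diff_state_dict_keys(model_keys: Iterable[str], checkpoint_keys: Iterable[str]) -> KeyDiff:
    """Compare parameter names without loading anything."""
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    return KeyDiff(sorted(model_set - ckpt_set), sorted(ckpt_set - model_set))


# Usage with the objects from the inference script (requires torch and the weights):
#   diff = diff_state_dict_keys(prior.state_dict().keys(), prior_model_state.keys())
#   print(diff.missing, diff.unexpected)
```

PyTorch's load_state_dict(..., strict=False) also returns an object with missing_keys and unexpected_keys attributes, so printing that return value gives the same information after loading.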

Solution 2: Follow cest-andre’s answer in Issue #282.

Step (1): Downgrade dalle2-pytorch to version 1.1.0:

pip install dalle2-pytorch==1.1.0

Step (2): After downgrading, fix a small bug in the installed dalle2_pytorch/dalle2_pytorch.py: replace line 2940 (the self.decoder.sample(...) call) with the following, which passes image_embed as a keyword argument:

images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)

PydanticUserError

After downgrading dalle2_pytorch, running the script raises the following error:

Traceback (most recent call last):
  File "/newdata/SD/extra/dalle2_cest.py", line 8, in <module>
    from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/train_configs.py", line 34, in <module>
    class TrainSplitConfig(BaseModel):
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/train_configs.py", line 40, in TrainSplitConfig
    def validate_all(cls, fields):
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/pydantic/deprecated/class_validators.py", line 222, in root_validator
    return root_validator()(*__args) # type: ignore
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/pydantic/deprecated/class_validators.py", line 228, in root_validator
    raise PydanticUserError(
pydantic.errors.PydanticUserError: If you use `@root_validator` with pre=False (the default) you MUST specify `skip_on_failure=True`. Note that `@root_validator` is deprecated and should be replaced with `@model_validator`.

Solution: Following JasbirCodeSpace’s answer in the Issues, downgrade Pydantic to 1.x. dalle2-pytorch 1.1.0 uses the Pydantic 1-style @root_validator, which Pydantic 2 rejects as shown above:

pip install pydantic==1.10.6
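Since the fixes above depend on two pinned versions, a quick environment check can save a rerun. A minimal sketch using the standard library's importlib.metadata (Python 3.8+); the constraint table reflects the versions used in this guide, and the names REQUIREMENTS and find_problems are hypothetical:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, List

# Version constraints from this guide; adjust them for your own setup
REQUIREMENTS: Dict[str, Callable[[str], bool]] = {
    "dalle2_pytorch": lambda v: v == "1.1.0",
    "pydantic": lambda v: v.split(".")[0] == "1",  # the 1.x API is required
}


def find_problems(get_version: Callable[[str], str] = version) -> List[str]:
    """Return one message per package that is missing or has an unsupported version."""
    problems = []
    for pkg, ok in REQUIREMENTS.items():
        try:
            v = get_version(pkg)
        except PackageNotFoundError:
            problems.append(f"{pkg}: not installed")
            continue
        if not ok(v):
            problems.append(f"{pkg}: found {v}")
    return problems


if __name__ == "__main__":
    issues = find_problems()
    print("environment OK" if not issues else "\n".join(issues))
```

The version lookup is injectable so the check can be exercised without the packages installed.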

At this point, the model can complete the inference process. Hehe! Below is the image generated for the prompt "a red car":

[Figure: image generated for the prompt "a red car"]

Postscript: Thanks to those who came before us for paving the way!

Reference links

  1. lucidrains/DALLE2-pytorch, GitHub Issue #282: https://github.com/lucidrains/DALLE2-pytorch/issues/282
  2. Tidosti, "python requests error ConnectionError: ('Connection aborted.', error(104, 'Connection reset by peer'))", CSDN Blog
  3. "GPU PyTorch (CUDA 12.1) quick installation from the Tsinghua mirror: a step-by-step tutorial for beginners", CSDN Blog