Super detailed! A practical guide to the DALL·E text-to-image model

Recently I needed to run inference with DALL·E. Working from the existing open-source code, I ran into several issues that need attention, and I want to record them in this blog.

The code I used is mainly the inference_pipeline.ipynb notebook in the https://github.com/borisdayma/dalle-mini repository.

Operating environment: Ubuntu server

Note: this blog covers DALL·E inference only, not the training process.

Contents

  • 1. Environment configuration
  • 2. Model download
  • 3. Program conversion
  • 4. Running the program
  • 5. Troubleshooting guide

1. Environment configuration

I recommend creating a dedicated dalle environment with Anaconda and doing all of the configuration below inside it, so that its package versions do not conflict with the libraries in your other environments.

Use the following commands to create a new environment named dalle and activate it:

conda create -n dalle python==3.8.0
conda activate dalle

Run the following commands in the terminal (inside the dalle environment) to install the required Python libraries:

# Install the JAX libraries dalle-mini needs (note: the version must be exactly 0.3.25)
# Required only for colab environments + GPU
pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Install the dalle-mini library
pip install dalle-mini
# Install VQGAN
pip install -q git+https://github.com/patil-suraj/vqgan-jax.git

PS: If pip cannot download VQGAN because of network problems, use Plan B: download the repository https://github.com/patil-suraj/vqgan-jax to the server and unpack it, cd into the unpacked directory, and run python setup.py install in the terminal to install VQGAN.
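Since the 0.3.25 pin is easy to break (see the jaxlib 0.4.13 problem in section 5), a quick check inside the dalle environment can save time. This is a minimal sketch under my own assumptions: it only reads installed package metadata through the standard importlib.metadata module, treats 0.3.25 as the required version for both jax and jaxlib, and the names check_pins and installed_version are hypothetical.

```python
"""Check that jax/jaxlib in the current environment match the 0.3.25 pin (sketch)."""
from importlib import metadata

# Assumed pins, taken from the pip command above
EXPECTED = {"jax": "0.3.25", "jaxlib": "0.3.25"}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_pins(expected=EXPECTED, get_version=installed_version):
    """Return a list of problems; an empty list means every pin matches."""
    problems = []
    for package, wanted in expected.items():
        found = get_version(package)
        if found is None:
            problems.append(f"{package} is not installed (expected {wanted})")
        elif found.split("+")[0] != wanted:
            # CUDA wheels may carry a local suffix such as "+cuda11.cudnn82"; ignore it
            problems.append(f"{package}=={found} installed, expected {wanted}")
    return problems


if __name__ == "__main__":
    issues = check_pins()
    print("\n".join(issues) if issues else "jax/jaxlib pins look correct")
```

The get_version parameter is only there so the check can be exercised without touching the real environment.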

2. Model download

Because of network problems, I downloaded the models in advance and loaded them from local paths. Note that two models are involved in this project: the DALL·E mini model (DalleBart) turns the text prompt into a sequence of discrete image tokens, and VQGAN decodes those tokens into the final image. We therefore need to download both the DALL·E model and the VQGAN model.
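How the two models fit together can be checked with a little arithmetic. This sketch is based on my reading of the names and shapes that appear in this post, not on the dalle-mini source: I assume "f16" in vqgan_imagenet_f16_16384 is the VQGAN spatial downsampling factor, "16384" is its codebook size (the error in section 5 shows an embedding of type float32[16384,256]), and each image is 256 tokens long (the indices array in that error has shape int32[1,2,256]). The helper names are hypothetical.

```python
"""Text -> image tokens -> pixels: a back-of-the-envelope size check (sketch)."""
import math


def image_side_pixels(num_tokens, downsample):
    """Tokens form a square grid; each token covers downsample x downsample pixels."""
    grid = math.isqrt(num_tokens)
    if grid * grid != num_tokens:
        raise ValueError("token count must be a perfect square")
    return grid * downsample


def bits_per_token(codebook_size):
    """Bits needed to index one codebook entry."""
    return math.ceil(math.log2(codebook_size))


side = image_side_pixels(num_tokens=256, downsample=16)
print(side, bits_per_token(16384))  # 256 14
```

A 16 x 16 token grid upsampled by 16 gives 256 x 256 pixels, which matches the reshape((-1, 256, 256, 3)) in the script in section 3.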

DALL·E model download address:
mini version: https://huggingface.co/dalle-mini/dalle-mini/tree/main
Mega version: https://huggingface.co/dalle-mini/dalle-mega/tree/main

VQGAN model download address:
https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/tree/main

After downloading, copy both model folders to the server and note where you save them; these local paths are needed in step 3.
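Before pointing the script at the local folders, it helps to confirm they are complete, because a missing file makes the libraries fall back to downloading it (exactly the failure described in section 5). This is a sketch under an assumption: the required file lists below reflect the usual layout of Flax checkpoints on the Hugging Face Hub (config.json plus flax_model.msgpack) and the enwiki-words-frequency.txt file mentioned later; compare them with the file listings on the download pages above and edit as needed.

```python
"""Check that the local model folders contain the files we expect (sketch)."""
from pathlib import Path

DALLE_MODEL = "/newdata/SD/dalle-mini/dalle-mini"
VQGAN_REPO = "/newdata/SD/dalle-mini/vqgan_imagenet_f16_16384"

# Assumed minimum file sets (verify against the Hub listings)
REQUIRED = {
    DALLE_MODEL: ["config.json", "flax_model.msgpack", "enwiki-words-frequency.txt"],
    VQGAN_REPO: ["config.json", "flax_model.msgpack"],
}


def missing_files(model_dir, required):
    """Return the required file names that are not present in model_dir."""
    root = Path(model_dir)
    return [name for name in required if not (root / name).is_file()]


if __name__ == "__main__":
    for model_dir, required in REQUIRED.items():
        missing = missing_files(model_dir, required)
        print(f"{model_dir}: {'OK' if not missing else f'missing {missing}'}")
```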

3. Program conversion

I prefer working with .py files rather than notebooks, so I first converted the notebook into a Python script with the same name using jupyter nbconvert --to script inference_pipeline.ipynb. The main content of the script is shown below (the CLIP ranking part is omitted). I changed the model paths DALLE_MODEL and VQGAN_REPO to local paths (where the two models were saved in step 2). The comments in the file are fairly detailed.

# dalle-mini model
DALLE_MODEL = "/newdata/SD/dalle-mini/dalle-mini"
DALLE_COMMIT_ID = None
# VQGAN model
VQGAN_REPO = "/newdata/SD/dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

import jax
import jax.numpy as jnp

# check how many devices are available (print it: a bare expression shows nothing in a script)
print(jax.local_device_count())

# Load models & tokenizer
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
# Load dalle-mini
model, params = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False)
# Load VQGAN
vqgan, vqgan_params = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False)

# Model parameters are replicated on each device for faster inference.
from flax.jax_utils import replicate
params = replicate(params)
vqgan_params = replicate(vqgan_params)

# Model functions are compiled and parallelized to take advantage of multiple devices.
from functools import partial

# model inference
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
    return model.generate(
        **tokenized_prompt,
        prng_key=key,
        params=params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=condition_scale,
    )

# decode images
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    return vqgan.decode_code(indices, params=params)

# Keys are passed to the model on each device to generate unique inference per device.
import random

# create a random key
seed = random.randint(0, 2**32 - 1)
key = jax.random.PRNGKey(seed)

# ##  Text Prompt
# Our model requires processing prompts.

from dalle_mini import DalleBartProcessor
# load the processor from the same local folder as the model
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
# Let's define some text prompts
prompts = [
    "sunset over a lake in the mountains",
    "the Eiffel tower landing on the moon",
]
# print(prompts)
# Note: we could use the same prompt multiple times for faster inference.
tokenized_prompts = processor(prompts)
# Finally we replicate the prompts onto each device.
tokenized_prompt = replicate(tokenized_prompts)

# ##  We generate images using dalle-mini model and decode them with the VQGAN.

# number of predictions per prompt
n_predictions = 8

# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)
gen_top_k = None
gen_top_p = None
temperature = None
cond_scale = 10.0  # the higher the value, the more closely the image follows the prompt

from flax.training.common_utils import shard_prng_key
import numpy as np
from PIL import Image
from tqdm.notebook import trange

print(f"Prompts: {prompts}\n")
# generate images
images = []
for i in trange(max(n_predictions // jax.device_count(), 1)):
    # get a new key
    key, subkey = jax.random.split(key)  # jax.device_count() returns the number of available JAX devices (1 in my case)
    # generate images
    encoded_images = p_generate(
        tokenized_prompt,
        shard_prng_key(subkey),
        params,
        gen_top_k,
        gen_top_p,
        temperature,
        cond_scale,
    )
    # remove BOS
    encoded_images = encoded_images.sequences[..., 1:]
    decoded_images = p_decode(encoded_images, vqgan_params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    
    for idx, decoded_img in enumerate(decoded_images):
        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
        images.append(img)
...
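If jupyter/nbconvert is not available on the server, the notebook-to-script conversion at the start of this section can be approximated with the standard library, because an .ipynb file is plain JSON with a list of cells. This is a simplified stand-in, not nbconvert itself: I assume a standard nbformat-4 notebook, turn markdown cells into comments (which is where lines such as "# ##  Text Prompt" in the script above come from), and comment out shell/magic lines like !pip install, whereas nbconvert rewrites magics differently. notebook_to_script is a hypothetical name.

```python
"""Minimal stand-in for `jupyter nbconvert --to script` (sketch)."""
import json
from pathlib import Path


def notebook_to_script(ipynb_path, py_path=None):
    nb = json.loads(Path(ipynb_path).read_text(encoding="utf8"))
    chunks = []
    for cell in nb.get("cells", []):
        source = cell.get("source", "")
        # nbformat stores source either as one string or as a list of lines
        text = "".join(source) if isinstance(source, list) else source
        if cell.get("cell_type") == "code":
            # "!cmd" and "%magic" lines are not valid Python, so comment them out
            lines = [("# " + line) if line.lstrip().startswith(("!", "%")) else line
                     for line in text.splitlines()]
        elif cell.get("cell_type") == "markdown":
            lines = ["# " + line for line in text.splitlines()]
        else:
            continue  # skip raw cells
        chunks.append("\n".join(lines))
    script = "\n\n".join(chunks) + "\n"
    if py_path is None:
        py_path = Path(ipynb_path).with_suffix(".py")  # same name, .py extension
    Path(py_path).write_text(script, encoding="utf8")
    return script
```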

4. Running the program

Run the program with python /newdata/SD/inference_dalle-mini.py. If everything goes well, you get the images generated by DALL·E directly!
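How many images one run produces follows from the generation loop in section 3: it runs max(n_predictions // jax.device_count(), 1) iterations, and each iteration yields one image per prompt on every device. The sketch below just mirrors that arithmetic; images_per_run is a hypothetical helper, and the numbers in the example are the ones from this post (2 prompts, n_predictions = 8, a single CPU device).

```python
"""Number of images produced by the generation loop in section 3 (sketch)."""


def images_per_run(n_prompts, n_predictions, n_devices):
    # same expression as the loop bound: max(n_predictions // jax.device_count(), 1)
    iterations = max(n_predictions // n_devices, 1)
    # each iteration: every device generates one image per (replicated) prompt
    return iterations * n_devices * n_prompts


print(images_per_run(n_prompts=2, n_predictions=8, n_devices=1))  # 16
```

Because of the integer division, the total is only n_predictions × n_prompts when n_predictions is a multiple of the device count: with 3 devices the same settings give 12 images, and with 16 devices they give 32.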

5. Troubleshooting guide

Because of environmental factors and a few mistakes on my part, I still ran into problems while running the program. Below I share the error messages and their solutions.

  • Downloading a specific file failed because of network problems. The error message is as follows:
...
requests.exceptions.ConnectTimeout: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /dalle-mini/dalle-mini/resolve/main/enwiki-words-frequency.txt (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7faae4168460>, 'Connection to huggingface.co timed out. (connect timeout=10)'))"), '(Request ID: 61b7c191-3fb8-4dfa-9025-e9acd4ee4d28)')

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/newdata/SD/inference_dalle-mini.py", line 84, in <module>
    processor = DalleBartProcessor.from_pretrained("/newdata/SD/dalle-mini/dalle-mini", revision=DALLE_COMMIT_ID) # force_download=True, , local_only=True
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/utils.py", line 25, in from_pretrained
    return super(PretrainedFromWandbMixin, cls).from_pretrained(
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/processor.py", line 62, in from_pretrained
    return cls(tokenizer, config.normalize_text, config.max_text_length)
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/processor.py", line 21, in __init__
    self.text_processor = TextNormalizer()
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py", line 215, in __init__
    self._hashtag_processor = HashtagProcessor()
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py", line 25, in __init__
    # wiki_word_frequency = hf_hub_download(
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/huggingface_hub/utils/_validators.py", line 118, in _inner_fn
    return fn(*args, **kwargs)
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/huggingface_hub/file_download.py", line 1363, in hf_hub_download
    raise LocalEntryNotFoundError(
huggingface_hub.utils._errors.LocalEntryNotFoundError: An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on.

Following the traceback above, locate this part of the /root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py file:

...
class HashtagProcessor:
    # Adapted from wordninja library
    # We use our wikipedia word count + a good heuristic to make it work
    def __init__(self):
        wiki_word_frequency = hf_hub_download(
            "dalle-mini/dalle-mini", filename="enwiki-words-frequency.txt"
        )
        self._word_cost = (
            l.split()[0]
            for l in Path(wiki_word_frequency).read_text(encoding="utf8").splitlines()
        )
...

So the root cause is that, when execution reaches this point, the program cannot find a local copy of enwiki-words-frequency.txt. (The file does exist in my local model folder, which puzzled me at first. A likely explanation is that hf_hub_download is called with the hard-coded Hub repo id "dalle-mini/dalle-mini", so it only looks in the Hugging Face cache, not in the local directory passed to from_pretrained.) It therefore tries to download the file from huggingface.co, and on a poor connection the download fails with the error above. The fix is as follows:

...
class HashtagProcessor:
    # Adapted from wordninja library
    # We use our wikipedia word count + a good heuristic to make it work
    def __init__(self):
        wiki_word_frequency = "/newdata/SD/dalle-mini/dalle-mini/enwiki-words-frequency.txt"
        self._word_cost = (
            l.split()[0]
            for l in Path(wiki_word_frequency).read_text(encoding="utf8").splitlines()
        )
...

In other words, assign the local path of enwiki-words-frequency.txt directly to the wiki_word_frequency variable and leave everything else unchanged; that solves the problem.
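Editing a file inside site-packages works, but the change is lost whenever dalle-mini is reinstalled. An alternative I have not verified end to end is to wrap the download function at runtime so that a local copy is preferred. This sketch assumes that dalle_mini/model/text.py imports hf_hub_download by name at module level (the traceback above is consistent with that) and that the file sits in the local DALL·E folder from section 2; LOCAL_DIRS and make_local_first are hypothetical names.

```python
"""Local-first wrapper around a hub download function (sketch)."""
from pathlib import Path

# Hypothetical mapping from Hub repo id to the folder downloaded in section 2
LOCAL_DIRS = {"dalle-mini/dalle-mini": "/newdata/SD/dalle-mini/dalle-mini"}


def make_local_first(download_fn, local_dirs=LOCAL_DIRS):
    """Return a drop-in function that checks local_dirs before calling download_fn."""
    def local_first(repo_id, filename, **kwargs):
        local_dir = local_dirs.get(repo_id)
        if local_dir is not None:
            candidate = Path(local_dir) / filename
            if candidate.is_file():
                return str(candidate)  # found locally: no network access at all
        return download_fn(repo_id, filename=filename, **kwargs)
    return local_first
```

To use it, the patch would go in the inference script before the processor is created: import dalle_mini.model.text as text_module, then set text_module.hf_hub_download = make_local_first(text_module.hf_hub_download), and only then call DalleBartProcessor.from_pretrained(...).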

  • Version conflicts caused by improper installation
Fix for "Couldn't invoke ptxas --version"

This error came from version conflicts between different Python libraries. dalle-mini requires jax and jaxlib 0.3.25, but after installing with pip install dalle-mini my jaxlib version was 0.4.13. In my environment pip install jaxlib could not find version 0.3.25, and downgrading caused incompatibilities with other libraries such as flax and orbax-checkpoint. After every attempt to downgrade jaxlib failed, I found that the answer had been in the notebook all along: pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Lesson learned: read the official documentation (here, the notebook) first; it saves a lot of detours!!!

  • Easter Egg: A very strange error:
The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/newdata/SD/inference_dalle-mini.py", line 130, in <module>
    decoded_images = p_decode(encoded_images, vqgan_params)
ValueError: pmap got inconsistent sizes for array axes to be mapped:
  * most axes (101 of them) had size 512, e.g. axis 0 of argument params['decoder']['conv_in']['bias'] of type float32[512];
  * some axes (71 of them) had size 3, e.g. axis 0 of argument params['decoder']['conv_in']['kernel'] of type float32[3,3,256,512];
  * some axes (69 of them) had size 256, e.g. axis 0 of argument params['decoder']['up_1']['block_0']['norm1']['bias'] of type float32[256];
  * some axes (67 of them) had size 128, e.g. axis 0 of argument params['decoder']['norm_out']['bias'] of type float32[128];
  * some axes (35 of them) had size 1, e.g. axis 0 of argument indices of type int32[1,2,256];
  * one axis had size 16384: axis 0 of argument params['quantize']['embedding']['embedding'] of type float32[16384,256]

It turned out that I had accidentally commented out the following line while debugging. Without it, pmap tries to split every VQGAN parameter along its first axis, and those axes have different sizes. This bug was the hardest one to track down, and it left me speechless:

vqgan_params = replicate(vqgan_params)
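The error makes sense once you recall that jax.pmap maps over axis 0 of every argument, so every leaf must have the same leading size, namely the device count; flax.jax_utils.replicate provides this by adding a new leading device axis. The sketch below imitates both sides with NumPy only, using shapes copied from the error message; replicate_tree and leading_sizes are hypothetical helpers, not Flax or JAX APIs.

```python
"""Why unreplicated params break pmap: a NumPy-only illustration (sketch)."""
import numpy as np


def replicate_tree(tree, n_devices):
    """Add a leading device axis to every array in a nested dict (conceptually what replicate does)."""
    if isinstance(tree, dict):
        return {k: replicate_tree(v, n_devices) for k, v in tree.items()}
    arr = np.asarray(tree)
    return np.broadcast_to(arr, (n_devices,) + arr.shape)  # view, no copy


def leading_sizes(tree, prefix=""):
    """Map each leaf path to the size of axis 0, i.e. what pmap would split over devices."""
    if isinstance(tree, dict):
        out = {}
        for k, v in tree.items():
            out.update(leading_sizes(v, f"{prefix}/{k}"))
        return out
    return {prefix: np.shape(tree)[0]}


# A few parameter shapes copied from the error message above
params = {
    "decoder": {"conv_in": {"bias": np.zeros(512), "kernel": np.zeros((3, 3, 256, 512))}},
    "quantize": {"embedding": {"embedding": np.zeros((16384, 256))}},
}
print(sorted(set(leading_sizes(params).values())))                    # [3, 512, 16384]: inconsistent
print(sorted(set(leading_sizes(replicate_tree(params, 1)).values())))  # [1]: consistent
```

The error also lists the indices argument as int32[1,2,256]: its leading axis is already 1 (one device), which is why only the parameters were inconsistent.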
  • The graphics card cannot be used properly, possibly because of version restrictions
    Since dalle-mini pins jax and jaxlib to 0.3.25, these two packages cannot be upgraded to their latest versions. I suspect this is why the following errors appear:
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: Couldn't get ptxas version string: INTERNAL: Couldn't invoke ptxas --version

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to load in-memory CUBIN: CUDA_ERROR_INVALID_IMAGE: device kernel image is invalid
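  • Some warnings also appear while the program runs. Their source paths (external/org_tensorflow/...) show that jaxlib's XLA backend is built from the TensorFlow code base. My program did not detect the graphics card, so it could only run on the CPU; judging from the log, the likely causes are a missing CUDA 11 runtime library (libcudart.so.11.0) and a mismatch between the loaded NVIDIA kernel driver (525.53.0) and the user-space driver library (530.41.3).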
2023-11-07 11:30:35.139851: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-11-07 11:30:35.257514: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-11-07 11:30:35.258648: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-11-07 11:30:35.628768: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_SYSTEM_DRIVER_MISMATCH: system has unsupported display driver / cuda driver combination
2023-11-07 11:30:35.628915: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:313] kernel version 525.53.0 does not match DSO version 530.41.3 -- cannot find working devices in this configuration
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Prompts: ['sunset over a lake in the mountains', 'the Eiffel tower landing on the moon']

  0%| | 0/8 [00:00<?, ?it/s]
/root/anaconda3/envs/dalle/lib/python3.8/site-packages/jax/_src/ops/scatter.py:87: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=float16 to dtype=float32. In future JAX releases this will result in an error.
  warnings.warn("scatter inputs have incompatible types: cannot safely cast "
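Since JAX only prints a warning ("No GPU/TPU found, falling back to CPU") and then carries on, a slow CPU run can go unnoticed. A small guard at the top of the script can stop early instead. This is a sketch: it relies on jax.devices() returning device objects with a .platform string such as "cpu" or "gpu"; require_accelerator is a hypothetical helper, and the devices argument is passed in so the check can be tried without a GPU.

```python
"""Fail fast when JAX silently falls back to the CPU (sketch)."""


def require_accelerator(devices):
    """Return the platforms found, or raise if JAX only sees the CPU."""
    platforms = sorted({d.platform for d in devices})
    if platforms == ["cpu"]:
        raise RuntimeError(
            "JAX only sees the CPU; check the CUDA runtime/driver messages in the log"
        )
    return platforms
```

In the inference script, call require_accelerator(jax.devices()) right after import jax; on my machine it would have raised immediately instead of starting a long CPU run.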

Postscript: This was my first time working with a program built on the JAX framework. It felt quite new and a little different from PyTorch. I learned that JAX is closely related to TensorFlow: both compile through the XLA compiler. If anything in this post is inaccurate, corrections are very welcome!
