Recently I needed to run inference with DALL·E. Working from the existing open-source code, I ran into several issues that deserve attention, and I am recording them in this blog post.
The source code I use is mainly the Inference pipeline.ipynb notebook in the https://github.com/borisdayma/dalle-mini repository.
Operating environment: Ubuntu server
Note: This blog only covers DALL·E inference, not the training process.
Contents
- 1. Environment configuration
- 2. Model download
- 3. Program conversion
- 4. Program running
- 5. BUG Clearing Guide
1. Environment configuration
It is recommended to use anaconda to create a new dalle environment, and then perform relevant configurations in this environment to avoid version conflicts with other libraries in the environment.
Use the following command to create a new environment named dalle:
conda create -n dalle python==3.8.0
Run the following commands in the terminal to install the required python libraries:
# Install the dependent libraries required for dalle to run (note that the version can only be 0.3.25)
# Required only for colab environments + GPU
pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Install dalle specific libraries
pip install dalle-mini
# Install VQGAN
pip install -q git+https://github.com/patil-suraj/vqgan-jax.git
PS: If you cannot install VQGAN through pip because of network connection problems, fall back to plan B: download the repository https://github.com/patil-suraj/vqgan-jax to the server and decompress it. Then use the cd command to change into the extracted repository directory, and run python setup.py install in the terminal to install VQGAN.
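Since the rest of this post depends on the pinned versions, it can help to confirm what actually ended up in the environment before running anything. The following is a minimal sketch using only the Python standard library; the helper names installed_version and check_pins are my own and are not part of jax or dalle-mini.

```python
from importlib import metadata

# Pins taken from the install commands above; adjust if your setup differs.
EXPECTED = {"jax": "0.3.25", "jaxlib": "0.3.25"}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_pins(expected):
    """Map each package name to a tuple (installed, wanted, versions_match)."""
    report = {}
    for name, wanted in expected.items():
        found = installed_version(name)
        report[name] = (found, wanted, found == wanted)
    return report


if __name__ == "__main__":
    for name, (found, wanted, ok) in check_pins(EXPECTED).items():
        status = "OK" if ok else "MISMATCH"
        print(f"{name}: installed={found} expected={wanted} [{status}]")
```

Running this inside the dalle environment right after installation shows immediately whether a later pip install has silently replaced jaxlib.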
2. Model download
Because of network connection problems, I adopted the strategy of downloading the models locally in advance and loading them from disk. The first thing to make clear is that in this project DALL·E generates a sequence of image tokens from the text prompt, and VQGAN decodes those tokens into the final image, so we need to download the DALL·E and VQGAN models separately.
DALL·E model download address:
mini version: https://huggingface.co/dalle-mini/dalle-mini/tree/main
Mega version: https://huggingface.co/dalle-mini/dalle-mega/tree/main
VQGAN model download address:
https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/tree/main
After downloading, copy the models to the server and note the paths where they are saved.
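A manual download can easily miss a file, so a quick check of the local model folder may save a confusing error later. The sketch below uses only the standard library. The file names in DALLE_FILES are assumptions on my part (enwiki-words-frequency.txt is the one that matters later in this post), so compare them with the file listing on the Hugging Face page; missing_files is a hypothetical helper.

```python
from pathlib import Path


def missing_files(model_dir, required):
    """Return the names in `required` that are not regular files in `model_dir`."""
    root = Path(model_dir)
    return [name for name in required if not (root / name).is_file()]


# Assumed file names: adjust them to what the Hugging Face file listing shows.
DALLE_FILES = ["config.json", "flax_model.msgpack", "enwiki-words-frequency.txt"]

if __name__ == "__main__":
    gaps = missing_files("/newdata/SD/dalle-mini/dalle-mini", DALLE_FILES)
    print("all files present" if not gaps else f"missing: {gaps}")
```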
3. Program conversion
Compared with ipynb files, I personally prefer working with py files, so I first convert the notebook with the command jupyter nbconvert --to script "Inference pipeline.ipynb" (the quotes are needed because the file name contains a space), which produces a py file with the same name. The main content of the file is as follows (excluding the CLIP ranking part). The model paths DALLE_MODEL and VQGAN_REPO have been changed to local paths (the paths where the two models were saved in step 2). As you can see, the file is also commented in reasonable detail.
# dalle-mini
DALLE_MODEL = "/newdata/SD/dalle-mini/dalle-mini"
DALLE_COMMIT_ID = None

# VQGAN model
VQGAN_REPO = "/newdata/SD/dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

import jax
import jax.numpy as jnp

# check how many devices are available
jax.local_device_count()

# Load models & tokenizer
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel

# Load dalle-mini
model, params = DalleBart.from_pretrained(
    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False
)

# Load VQGAN
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False
)

# Model parameters are replicated on each device for faster inference.
from flax.jax_utils import replicate

params = replicate(params)
vqgan_params = replicate(vqgan_params)

# Model functions are compiled and parallelized to take advantage of multiple devices.
from functools import partial

# model inference
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
    return model.generate(
        **tokenized_prompt,
        prng_key=key,
        params=params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=condition_scale,
    )

# decode image
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    return vqgan.decode_code(indices, params=params)

# Keys are passed to the model on each device to generate unique inference per device.
import random

# create a random key
seed = random.randint(0, 2**32 - 1)
key = jax.random.PRNGKey(seed)

# ## Text Prompt
# Our model requires processing prompts.
from dalle_mini import DalleBartProcessor
# from transformers import AutoProcessor

processor = DalleBartProcessor.from_pretrained("/newdata/SD/dalle-mini/dalle-mini", revision=DALLE_COMMIT_ID) # force_download=True, , local_only=True

# Let's define some text prompts
prompts = [
    "sunset over a lake in the mountains",
    "the Eiffel tower landing on the moon",
]
# print(prompts)

# Note: we could use the same prompt multiple times for faster inference.
tokenized_prompts = processor(prompts)

# Finally we replicate the prompts onto each device.
tokenized_prompt = replicate(tokenized_prompts)

# ## We generate images using dalle-mini model and decode them with the VQGAN.
# number of predictions per prompt
n_predictions = 8

# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)
gen_top_k = None
gen_top_p = None
temperature = None
cond_scale = 10.0  # The higher the value, the closer the generated image is to prompt

from flax.training.common_utils import shard_prng_key
import numpy as np
from PIL import Image
from tqdm.notebook import trange

print(f"Prompts: {prompts}\n")

# generate images
images = []
for i in trange(max(n_predictions // jax.device_count(), 1)):
    # get a new key
    key, subkey = jax.random.split(key)
    # jax.device_count()=1, returns the number of available jax devices
    # generate images
    encoded_images = p_generate(
        tokenized_prompt,
        shard_prng_key(subkey),
        params,
        gen_top_k,
        gen_top_p,
        temperature,
        cond_scale,
    )
    # remove BOS
    encoded_images = encoded_images.sequences[..., 1:]
    decoded_images = p_decode(encoded_images, vqgan_params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    for idx, decoded_img in enumerate(decoded_images):
        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
        images.append(img)
...
4. Program running
Use the command python /newdata/SD/inference_dalle-mini.py to run the program. Ideally, you will directly get the images generated by dalle!
5. BUG Clearing Guide
Because of external environmental factors and some mistakes of my own, I still ran into several problems while running the program. Here I will share the error messages and the solutions I found.
- The download of specific files failed due to network problems, and the error message is as follows:
...
requests.exceptions.ConnectTimeout: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /dalle-mini/dalle-mini/resolve/main/enwiki-words-frequency.txt (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7faae4168460>, 'Connection to huggingface.co timed out. (connect timeout=10)'))"), '(Request ID: 61b7c191-3fb8-4dfa-9025-e9acd4ee4d28)')

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/newdata/SD/inference_dalle-mini.py", line 84, in <module>
    processor = DalleBartProcessor.from_pretrained("/newdata/SD/dalle-mini/dalle-mini", revision=DALLE_COMMIT_ID) # force_download=True, , local_only=True
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/utils.py", line 25, in from_pretrained
    return super(PretrainedFromWandbMixin, cls).from_pretrained(
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/processor.py", line 62, in from_pretrained
    return cls(tokenizer, config.normalize_text, config.max_text_length)
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/processor.py", line 21, in __init__
    self.text_processor = TextNormalizer()
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py", line 215, in __init__
    self._hashtag_processor = HashtagProcessor()
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py", line 25, in __init__
    # wiki_word_frequency = hf_hub_download(
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/huggingface_hub/utils/_validators.py", line 118, in _inner_fn
    return fn(*args, **kwargs)
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/huggingface_hub/file_download.py", line 1363, in hf_hub_download
    raise LocalEntryNotFoundError(
huggingface_hub.utils._errors.LocalEntryNotFoundError: An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on.
Following the error message above, locate the following content in the file /root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py:
...
class HashtagProcessor:
    # Adapted from wordninja library
    # We use our wikipedia word count + a good heuristic to make it work
    def __init__(self):
        wiki_word_frequency = hf_hub_download(
            "dalle-mini/dalle-mini", filename="enwiki-words-frequency.txt"
        )
        self._word_cost = (
            l.split()[0]
            for l in Path(wiki_word_frequency).read_text(encoding="utf8").splitlines()
        )
...
So the root of the problem is that when the program reaches this point, it does not find enwiki-words-frequency.txt locally. The file does exist in my local model folder, which confused me at first. The likely explanation is that hf_hub_download is called here with the hard-coded repository id "dalle-mini/dalle-mini", so it looks only in the Hugging Face cache directory and not in the folder passed to from_pretrained. The program then tried to download the file from huggingface.co, but because of poor network conditions the connection timed out and the error was raised. The solution is as follows:
...
class HashtagProcessor:
    # Adapted from wordninja library
    # We use our wikipedia word count + a good heuristic to make it work
    def __init__(self):
        wiki_word_frequency = "/newdata/SD/dalle-mini/dalle-mini/enwiki-words-frequency.txt"
        self._word_cost = (
            l.split()[0]
            for l in Path(wiki_word_frequency).read_text(encoding="utf8").splitlines()
        )
...
That is, the local path of the enwiki-words-frequency.txt file is assigned directly to the wiki_word_frequency variable and the rest stays unchanged, which solves the problem.
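Hard-coding an absolute path inside site-packages works, but it breaks as soon as the model folder moves. One possible variation, sketched here with the standard library only, reads the path from an environment variable and fails with a clear message if the file is missing. The variable name DALLE_WORD_FREQ_FILE and both helper functions are hypothetical names invented for this example; they are not part of dalle-mini or huggingface_hub.

```python
import os
from pathlib import Path

# Hypothetical environment variable name, chosen for this example only.
ENV_VAR = "DALLE_WORD_FREQ_FILE"
DEFAULT_PATH = "/newdata/SD/dalle-mini/dalle-mini/enwiki-words-frequency.txt"


def word_frequency_path(default=DEFAULT_PATH):
    """Return the word-frequency file path, preferring the environment override."""
    candidate = Path(os.environ.get(ENV_VAR, default))
    if not candidate.is_file():
        raise FileNotFoundError(f"word-frequency file not found: {candidate}")
    return candidate


def load_words(path):
    """Read the first whitespace-separated token of every non-empty line."""
    lines = Path(path).read_text(encoding="utf8").splitlines()
    return [line.split()[0] for line in lines if line.strip()]
```

In the patched text.py, wiki_word_frequency could then be set to word_frequency_path(). Note that load_words skips blank lines, which the original generator expression does not do.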
- Version conflicts caused by improper installation
Fix for "Couldn't invoke ptxas --version"
This error comes from version conflicts between the installed Python libraries. dalle-mini requires jax and jaxlib to be version 0.3.25, but the jaxlib version installed by the pip install dalle-mini command is 0.4.13. A plain pip install jaxlib cannot find version 0.3.25 of jaxlib, and changing the version causes incompatibilities with other libraries such as flax and orbax-checkpoint… After every other attempt to downgrade jaxlib had failed, I found that the answer was already in the ipynb, namely: pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Lesson learned: read the official documentation first and you will avoid many detours!
- Easter Egg: A very strange error:
The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/newdata/SD/inference_dalle-mini.py", line 130, in <module>
    decoded_images = p_decode(encoded_images, vqgan_params)
ValueError: pmap got inconsistent sizes for array axes to be mapped:
  * most axes (101 of them) had size 512, e.g. axis 0 of argument params['decoder']['conv_in']['bias'] of type float32[512];
  * some axes (71 of them) had size 3, e.g. axis 0 of argument params['decoder']['conv_in']['kernel'] of type float32[3,3,256,512];
  * some axes (69 of them) had size 256, e.g. axis 0 of argument params['decoder']['up_1']['block_0']['norm1']['bias'] of type float32[256];
  * some axes (67 of them) had size 128, e.g. axis 0 of argument params['decoder']['norm_out']['bias'] of type float32[128];
  * some axes (35 of them) had size 1, e.g. axis 0 of argument indices of type int32[1,2,256];
  * one axis had size 16384: axis 0 of argument params['quantize']['embedding']['embedding'] of type float32[16384,256]
Later I found out that it happened because I had accidentally commented out the following line of code during debugging, so the VQGAN parameters no longer had the leading device axis that pmap maps over. This bug was the hardest to track down, and I was quite speechless:
vqgan_params = replicate(vqgan_params)
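The error message makes more sense once you know that pmap maps over axis 0 of every array it receives, so each leaf of the parameter tree needs a leading axis equal to the number of devices, which is what replicate adds. A small pre-flight check like the sketch below could point to the offending tree before pmap is called. It uses only the standard library and treats anything with a .shape attribute as an array; leading_axis_sizes and check_mapped_axis are hypothetical helpers of my own, not JAX or flax functions.

```python
def leading_axis_sizes(tree, prefix=""):
    """Yield (path, size of axis 0) for every leaf of a nested dict.

    A leaf is any non-dict object; its `.shape` attribute is read if present.
    Leaves without a shape, or with an empty shape, are reported with size None.
    """
    if isinstance(tree, dict):
        for name, child in tree.items():
            yield from leading_axis_sizes(child, f"{prefix}['{name}']")
    else:
        shape = tuple(getattr(tree, "shape", ()))
        yield prefix, (shape[0] if shape else None)


def check_mapped_axis(tree, n_devices):
    """Return the (path, size) pairs whose leading axis differs from `n_devices`."""
    return [
        (path, size)
        for path, size in leading_axis_sizes(tree)
        if size != n_devices
    ]
```

Calling check_mapped_axis(vqgan_params, jax.local_device_count()) just before p_decode should return an empty list when the parameters have been replicated; a long list of mismatches suggests that the replicate call was skipped.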
- The graphics card cannot be used properly due to version restrictions
Since dalle-mini pins jax and jaxlib to version 0.3.25, these two packages cannot be updated to the latest versions. I suspect, but have not confirmed, that this is why the following error messages appear:
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: Couldn't get ptxas version string: INTERNAL: Couldn't invoke ptxas --version
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to load in-memory CUBIN: CUDA_ERROR_INVALID_IMAGE: device kernel image is invalid
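- There are some warnings while the program runs. The file paths in the warnings (external/org_tensorflow/...) show that this jaxlib build includes XLA code from the TensorFlow source tree. The log also shows why my program did not detect the graphics card and could only run on the CPU: libcudart.so.11.0 could not be loaded, and the kernel driver version 525.53.0 does not match the DSO version 530.41.3.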
2023-11-07 11:30:35.139851: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-11-07 11:30:35.257514: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-11-07 11:30:35.258648: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-11-07 11:30:35.628768: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_SYSTEM_DRIVER_MISMATCH: system has unsupported display driver / cuda driver combination
2023-11-07 11:30:35.628915: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:313] kernel version 525.53.0 does not match DSO version 530.41.3 -- cannot find working devices in this configuration
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Prompts: ['sunset over a lake in the mountains', 'the Eiffel tower landing on the moon']
0%| | 0/8 [00:00<?, ?it/s]
/root/anaconda3/envs/dalle/lib/python3.8/site-packages/jax/_src/ops/scatter.py:87: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=float16 to dtype=float32. In future JAX releases this will result in an error.
  warnings.warn("scatter inputs have incompatible types: cannot safely cast "
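The first three warnings say that the dynamic loader cannot open libcudart.so.11.0. Whether that library is visible to the current process can be tested without starting JAX at all. The following is a rough sketch using ctypes from the standard library; can_load is a hypothetical helper name, and the check only tells you whether the loader finds the library, not whether the driver versions match.

```python
import ctypes


def can_load(library_name):
    """Return True if the dynamic loader can open `library_name`."""
    try:
        ctypes.CDLL(library_name)
        return True
    except OSError:
        return False


if __name__ == "__main__":
    # Library name taken from the warning in the log above.
    name = "libcudart.so.11.0"
    print(f"{name}: {'found' if can_load(name) else 'NOT found'}")
```

If the library is reported as not found even though CUDA is installed, the directory that contains it is probably not on the loader search path (for example LD_LIBRARY_PATH) in the shell that launches the script.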
Postscript: This was my first contact with a program written with the JAX framework, and it felt quite new and a little different from PyTorch. I learned that JAX relies on the same XLA compiler as TensorFlow, although it is a separate library rather than a lightweight version of TensorFlow. If anything in this post reflects a misunderstanding, corrections are welcome!
Reference links
- The use of Path in Python pathlib (solving path problems across operating systems), CSDN Blog
- python – vmap gives inconsistent shape error when trying to calculate gradient per sample – Stack Overflow
- https://github.com/google/jax/issues/9933