first commit

commit 4f9296236a

LICENSE (21 lines, Normal file)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 Or Patashnik, Zongze Wu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md (287 lines, Normal file)
@@ -0,0 +1,287 @@
# StyleCLIP: Text-Driven Manipulation of StyleGAN Imagery (ICCV 2021 Oral)

[Run this model on Replicate](https://replicate.ai/orpatashnik/styleclip)

Optimization: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/notebooks/optimization_playground.ipynb)
Mapper: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/notebooks/mapper_playground.ipynb)

Global directions Torch: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/notebooks/StyleCLIP_global_torch.ipynb)
Global directions TF1: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/notebooks/StyleCLIP_global.ipynb)

<p align="center">
  <a href="https://www.youtube.com/watch?v=5icI0NgALnQ"><img src='https://github.com/orpatashnik/StyleCLIP/blob/main/img/StyleCLIP_gif.gif' width=600 ></a>

Full Demo Video: <a href="https://www.youtube.com/watch?v=5icI0NgALnQ"><img src="https://img.shields.io/badge/-YouTube-red?&style=for-the-badge&logo=youtube&logoColor=white" height=20></a> &nbsp;&nbsp;&nbsp; ICCV Video <a href="https://www.youtube.com/watch?v=PhR1gpXDu0w"><img src="https://img.shields.io/badge/-YouTube-red?&style=for-the-badge&logo=youtube&logoColor=white" height=20></a>

</p>
> **StyleCLIP: Text-Driven Manipulation of StyleGAN Imagery**<br>
> Or Patashnik*, Zongze Wu*, Eli Shechtman, Daniel Cohen-Or, Dani Lischinski <br>
> *Equal contribution, ordered alphabetically <br>
> https://arxiv.org/abs/2103.17249 <br>
>
> **Abstract:** Inspired by the ability of StyleGAN to generate highly realistic images in a variety of domains, much recent work has focused on understanding how to use the latent spaces of StyleGAN to manipulate generated and real images. However, discovering semantically meaningful latent manipulations typically involves painstaking human examination of the many degrees of freedom, or an annotated collection of images for each desired manipulation. In this work, we explore leveraging the power of recently introduced Contrastive Language-Image Pre-training (CLIP) models in order to develop a text-based interface for StyleGAN image manipulation that does not require such manual effort. We first introduce an optimization scheme that utilizes a CLIP-based loss to modify an input latent vector in response to a user-provided text prompt. Next, we describe a latent mapper that infers a text-guided latent manipulation step for a given input image, allowing faster and more stable text-based manipulation. Finally, we present a method for mapping text prompts to input-agnostic directions in StyleGAN's style space, enabling interactive text-driven image manipulation. Extensive results and comparisons demonstrate the effectiveness of our approaches.
## Description
Official implementation of StyleCLIP, a method for manipulating images using a driving text.
Our method uses the generative power of a pretrained StyleGAN generator together with the visual-language power of CLIP.
In the paper we present three methods:
- Latent vector optimization.
- A latent mapper, trained to manipulate latent vectors according to a specific text description.
- Global directions in the StyleSpace.

## Updates
**31/10/2022** Add support for global directions with a torch implementation

**15/8/2021** Add support for StyleSpace in the optimization and latent mapper methods

**6/4/2021** Add mapper training and inference code (including a Jupyter notebook)

**6/4/2021** Add support for custom StyleGAN2 and StyleGAN2-ada models, as well as custom images

**2/4/2021** Add the global directions code (a local GUI and a Colab notebook)

**31/3/2021** Upload paper to arXiv, and video to YouTube

**14/2/2021** Initial version
## Setup (for all three methods)
For all the methods described in the paper, it is required to have:
- Anaconda
- [CLIP](https://github.com/openai/CLIP)

Specific requirements for each method are described in its section.
To install CLIP please run the following commands:
```shell script
conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=<CUDA_VERSION>
pip install ftfy regex tqdm gdown
pip install git+https://github.com/openai/CLIP.git
```
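To quickly verify that CLIP is installed correctly, a minimal check along these lines should work (the prompt here is arbitrary):

```python
import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
with torch.no_grad():
    text_features = model.encode_text(clip.tokenize(["a face"]).to(device))
print(text_features.shape)  # torch.Size([1, 512])
```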
## Editing via Latent Vector Optimization

### Setup

Here, the code relies on the [Rosinality](https://github.com/rosinality/stylegan2-pytorch/) pytorch implementation of StyleGAN2.
Some parts of the StyleGAN implementation were modified, so that the whole implementation is native pytorch.

In addition to the requirements mentioned before, a pretrained StyleGAN2 generator is downloaded automatically (or can be downloaded manually from [here](https://drive.google.com/file/d/1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT/view?usp=sharing)).
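For example, the generator weights can be fetched from the command line with `gdown` (installed in the setup above); the output filename here is an assumption, adjust it to wherever your configuration expects the weights:

```shell script
gdown "https://drive.google.com/uc?id=1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT" -O stylegan2-ffhq-config-f.pt
```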
### Usage

Given a textual description, one can either edit a given image or generate a random image that best fits the description.
Both operations can be done through the `main.py` script, or the `optimization_playground.ipynb` notebook ([![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/notebooks/optimization_playground.ipynb)).

#### Editing
To edit an image, set `--mode=edit`. Editing can be done both on a provided latent vector and on a random latent vector from StyleGAN's latent space.
It is recommended to adjust `--l2_lambda` according to the desired edit.

#### Generating Free-style Images
To generate a free-style image, set `--mode=free_generation`.
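For example, an editing run might look as follows (a sketch: apart from `--mode` and `--l2_lambda` described above, the flag names and values are assumptions; see `main.py` for the actual argument list):

```shell script
python main.py --mode=edit --description "a person with purple hair" --l2_lambda 0.008
```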
## Editing via Latent Mapper
Here, we provide the code for the latent mapper. The mapper is trained to learn *residuals* from a given latent vector, according to the driving text.
The code for the mapper is in `mapper/`.

### Setup
As in the optimization, the code relies on the [Rosinality](https://github.com/rosinality/stylegan2-pytorch/) pytorch implementation of StyleGAN2.
In addition to the StyleGAN weights, it is necessary to have weights for the facial recognition network used in the ID loss.
The weights can be downloaded from [here](https://drive.google.com/file/d/1KW7bjndL3QG3sxBbZxreGHigcCCpsDgn/view?usp=sharing).

The mapper is trained on latent vectors, and it is recommended to train on *inverted real images*.
To this end, we provide the CelebA-HQ latent vectors that were inverted by e4e:
[train set](https://drive.google.com/file/d/1gof8kYc_gDLUT4wQlmUdAtPnQIlCO26q/view?usp=sharing), [test set](https://drive.google.com/file/d/1j7RIfmrCoisxx3t-r-KC02Qc8barBecr/view?usp=sharing).
### Usage

#### Training
- The main training script is placed in `mapper/scripts/train.py`.
- Training arguments can be found at `mapper/options/train_options.py`.
- Intermediate training results are saved to `opts.exp_dir`. This includes checkpoints, train outputs, and test outputs. Additionally, if you have tensorboard installed, you can visualize tensorboard logs in `opts.exp_dir/logs`.

Note that:
- To resume a training, please provide `--checkpoint_path`.
- `--description` is where you provide the driving text.
- If you perform an edit that is not supposed to change "colors" in the image, it is recommended to use the flag `--no_fine_mapper`.

Example for training a mapper for the mohawk hairstyle:
```bash
cd mapper
python train.py --exp_dir ../results/mohawk_hairstyle --no_fine_mapper --description "mohawk hairstyle"
```
All configurations for the examples shown in the paper are provided in `mapper/options/train_options.py`.
#### Inference
- The main inference script is placed in `mapper/scripts/inference.py`.
- Inference arguments can be found at `mapper/options/test_options.py`.
- Adding the flag `--couple_outputs` will save an image containing the input and output images side-by-side.
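For example, a sketch of an inference run (the checkpoint filename and invocation path are assumptions, use the checkpoint saved under your `--exp_dir`):

```bash
cd mapper
python scripts/inference.py --exp_dir ../results/mohawk_hairstyle \
    --checkpoint_path ../results/mohawk_hairstyle/checkpoints/best_model.pt --couple_outputs
```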
Pretrained models for various edits are provided. Please refer to `utils.py` for the complete list of links.

We also provide a notebook for performing inference with the mapper: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/notebooks/mapper_playground.ipynb)
## Editing via Global Direction

Here we provide a GUI for editing images with the global directions.
We provide both a Jupyter notebook [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/notebooks/StyleCLIP_global.ipynb),
and the GUI used in the [video](https://www.youtube.com/watch?v=5icI0NgALnQ).
For both, the linear directions are computed in **real time**.
The code is located at `global_directions/`.

### Setup
Here, we rely on the [official](https://github.com/NVlabs/stylegan2) TensorFlow implementation of StyleGAN2.

It is required to have TensorFlow, version 1.14 or 1.15 (`conda install -c anaconda tensorflow-gpu==1.14`).
### Usage

#### Local GUI

To start the local GUI please run the following commands:

```shell script
cd global_directions

# input dataset name
dataset_name='ffhq'

# a pretrained StyleGAN2 model from the standard NVlabs implementation (https://github.com/NVlabs/stylegan2) will be downloaded automatically.
# a pretrained StyleGAN2-ada model can be downloaded from https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/ .
# for a custom StyleGAN2 or StyleGAN2-ada model, please place the model under the ./StyleCLIP/global_directions/model/ folder.

# prepare input data
python GetCode.py --dataset_name $dataset_name --code_type 'w'
python GetCode.py --dataset_name $dataset_name --code_type 's'
python GetCode.py --dataset_name $dataset_name --code_type 's_mean_std'

# preprocess (this may take a few hours).
# we precompute the results for StyleGAN2 on ffhq, and for StyleGAN2-ada on afhqdog and afhqcat; for these models the preprocessing step can be skipped.
python SingleChannel.py --dataset_name $dataset_name

# generate the images to be manipulated.
# this operation will generate and replace the w_plus.npy and the .jpg images in the './data/dataset_name/' folder.
# if you want to keep the original data, please rename the original folder.
# to use custom images, please use the e4e encoder to generate latents.pt, place it in the './data/dataset_name/' folder, and add the --real flag while running this script.
# you may skip this step if you want to manipulate the real human faces we prepare in the ./data/ffhq/ folder.
python GetGUIData.py --dataset_name $dataset_name

# interactive manipulation
python PlayInteractively.py --dataset_name $dataset_name
```
As shown in the video, editing an image requires writing a _neutral text_ and a _target text_.
To operate the GUI, please do the following:
- Maximize the window size.
- Double click on the left square to choose an image. The images are taken from `global_directions/data/ffhq`, and the corresponding latent vectors are in `global_directions/data/ffhq/w_plus.npy`.
- Type a neutral text, then press enter.
- Modify the target text so that it contains the target edit, then press enter.

You can now play with:
- *Manipulation strength* - positive values correspond to moving along the target direction.
- *Disentanglement threshold* - a large value means a more disentangled edit: only a few channels are manipulated, so only the target attribute changes (for example, grey hair). A small value means a less disentangled edit: a large number of channels are manipulated, so related attributes change as well (such as wrinkles, skin color, glasses).
##### Examples:

| Edit  | Neutral Text | Target Text |
| --- | --- | --- |
| Smile  | face  | smiling face |
| Gender  | female face  | male face |
| Blonde hair | face with hair | face with blonde hair |
| Hi-top fade | face with hair | face with Hi-top fade hair |
| Blue eyes | face with eyes | face with blue eyes |

More examples can be found in the [video](https://www.youtube.com/watch?v=5icI0NgALnQ) and in the paper.
##### Practice Tips:
In the terminal, the number of channels being manipulated is printed for every manipulation (the number is controlled by the attribute pair (neutral, target) and by the disentanglement threshold).

1. For a color transformation, 10-20 channels are usually enough. For a large structural change (for example, Hi-top fade), 100-200 channels are usually required.
2. For an attribute pair (neutral, target), if only a few channels (<20) are manipulated even at a low disentanglement threshold, this is usually not enough to perform the desired edit.
#### Notebook
Open the notebook in Colab and run all the cells. In the last cell you can play with the image.

`beta` corresponds to the *disentanglement threshold*, and `alpha` to the *manipulation strength*.

After you set the desired parameters, run the last cell again to generate the image.
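In terms of the helpers used elsewhere in this repository (see `cog_predict.py`), the last cell essentially performs the following sketch:

```python
dt = GetDt([target, neutral], model)          # CLIP-space direction between the two prompts
M.alpha = [alpha]                             # manipulation strength
boundary, num_c = GetBoundary(fs3, dt, M, threshold=beta)  # beta is the disentanglement threshold
codes = M.MSCode(dlatent_tmp, boundary)       # shift the style codes along the direction
out = M.GenerateImg(codes)                    # decode the manipulated codes into an image
```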
## Editing Examples

In the following, we show some results obtained with our methods.
All images are real, and were inverted into StyleGAN's latent space using [e4e](https://github.com/omertov/encoder4editing).
The driving text that was used for each edit appears below or above each image.

#### Latent Optimization

#### Latent Mapper

#### Global Directions
## Related Works

The global directions we find for editing are directions in the _S Space_, which was introduced and analyzed in [StyleSpace](https://arxiv.org/abs/2011.12799) (Wu et al.).

To edit real images, we invert them into StyleGAN's latent space using [e4e](https://arxiv.org/abs/2102.02766) (Tov et al.).

The code structure of the mapper is heavily based on [pSp](https://github.com/eladrich/pixel2style2pixel).
## Citation

If you use this code for your research, please cite our paper:

```
@InProceedings{Patashnik_2021_ICCV,
    author    = {Patashnik, Or and Wu, Zongze and Shechtman, Eli and Cohen-Or, Daniel and Lischinski, Dani},
    title     = {StyleCLIP: Text-Driven Manipulation of StyleGAN Imagery},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {2085-2094}
}
```
cog.yaml (34 lines, Normal file)
@@ -0,0 +1,34 @@
build:
  gpu: true
  system_packages:
    - libgl1-mesa-glx
    - libglib2.0-0
    - cmake
    - zip
  python_version: 3.7
  python_packages:
    - torch==1.7.1
    - tensorflow==1.15.0
    - torchvision==0.8.2
    - torchaudio==0.7.2
    - ftfy==5.9
    - regex==2021.4.4
    - tqdm==4.59.0
    - requests==2.25.1
    - matplotlib==3.4.1
    - opencv-python==4.3.0.38
    - dlib==19.18.0
    - scipy==1.6.3
    - "git+git://github.com/openai/CLIP.git@8a665a683d791ed3491fedadcb3c91878f9eb78d"
  pre_install:
    - "mkdir /content"
    - "git clone https://github.com/omertov/encoder4editing.git /content/encoder4editing"
    - "wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip"
    - "unzip ninja-linux.zip -d /usr/local/bin/"
    - "update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force"
    - "wget -O /content/shape_predictor_68_face_landmarks.dat.bz2 http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2"
    - "cd /content && bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2"
    - "echo > /content/encoder4editing/__init__.py"
    - |
      sed -i 's/img = PIL.Image.open(filepath)/img = PIL.Image.open(filepath).convert(\"RGB\")/' /content/encoder4editing/utils/alignment.py
predict: cog_predict.py:Predictor
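# Usage sketch: the input names below match cog_predict.py; the image path is an assumption.
#   cog predict -i input=@face.jpg -i neutral="a face" -i target="a smiling face" \
#     -i manipulation_strength=4.1 -i disentanglement_threshold=0.15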
cog_predict.py (196 lines, Normal file)
@@ -0,0 +1,196 @@
import copy
import os
import pickle
import sys
import tempfile
import time
from argparse import Namespace
from pathlib import Path

import clip
import cog
import dlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import torch
import torchvision.transforms as transforms
from PIL import Image

sys.path.insert(0, "/content")
sys.path.insert(0, "/content/encoder4editing")

from encoder4editing.models.psp import pSp
from encoder4editing.utils.alignment import align_face
from encoder4editing.utils.common import tensor2im

os.chdir("global_directions")
sys.path.insert(0, ".")

from dnnlib import tflib
from manipulate import Manipulator
from MapTS import GetBoundary, GetDt, GetFs
class Predictor(cog.Predictor):
    def setup(self):

        print("starting setup")

        # CLIP model for scoring the text prompts
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = clip.load(
            "ViT-B/32", device=self.device, jit=False
        )

        # TF1 graph/session for the TensorFlow StyleGAN2 backend
        self.graph = tf.get_default_graph()
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
        self.sess = tf.Session(
            graph=self.graph, config=tf.ConfigProto(gpu_options=gpu_options)
        )

        # e4e encoder for inverting real images into W+
        self.experiment_args = {"model_path": "e4e_ffhq_encode.pt"}
        self.experiment_args["transform"] = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        self.resize_dims = (256, 256)

        model_path = self.experiment_args["model_path"]

        ckpt = torch.load(model_path, map_location="cpu")
        opts = ckpt["opts"]
        # pprint.pprint(opts)  # Display full options used
        # update the training options
        opts["checkpoint_path"] = model_path
        opts = Namespace(**opts)

        self.net = pSp(opts)
        self.net.eval()
        self.net.cuda()

        # dlib landmark model used for face alignment
        self.shape_predictor = dlib.shape_predictor(
            "/content/shape_predictor_68_face_landmarks.dat"
        )

        with self.graph.as_default(), self.sess.as_default():
            #tflib.init_tf()

            # StyleGAN2 manipulator and the precomputed per-channel CLIP directions (fs3)
            self.M = Manipulator(dataset_name="ffhq", sess=self.sess)
            self.fs3 = np.load("npy/ffhq/fs3.npy")
            np.set_printoptions(suppress=True)

        print("setup complete")
|  |     @cog.input("input", type=Path, help="Input image") | ||||||
|  |     @cog.input("neutral", type=str, help="Neutral image description") | ||||||
|  |     @cog.input("target", type=str, help="Target image description") | ||||||
|  |     @cog.input( | ||||||
|  |         "manipulation_strength", | ||||||
|  |         type=float, | ||||||
|  |         min=-10, | ||||||
|  |         max=10, | ||||||
|  |         default=4.1, | ||||||
|  |         help="The higher the manipulation strength, the closer the generated image becomes to the target description. Negative values moves the generated image further from the target description", | ||||||
|  |     ) | ||||||
|  |     @cog.input( | ||||||
|  |         "disentanglement_threshold", | ||||||
|  |         type=float, | ||||||
|  |         min=0.08, | ||||||
|  |         max=0.3, | ||||||
|  |         default=0.15, | ||||||
|  |         help="The higher the disentanglement threshold, the more specific the changes are to the target attribute. Lower values mean that broader changes are made to the input image", | ||||||
|  |     ) | ||||||
|  |     def predict( | ||||||
|  |         self, | ||||||
|  |         input, | ||||||
|  |         neutral, | ||||||
|  |         target, | ||||||
|  |         manipulation_strength, | ||||||
|  |         disentanglement_threshold, | ||||||
|  |     ): | ||||||
|  | 
 | ||||||
        # align and resize the input face
        #original_image = Image.open(str(input))
        #original_image = original_image.convert("RGB")
        input_image = self.run_alignment(str(input))
        #input_image = original_image
        input_image = input_image.resize(self.resize_dims)

        img_transforms = self.experiment_args["transform"]
        transformed_image = img_transforms(input_image)

        # invert the image into W+ with the e4e encoder
        with torch.no_grad():
            images, latents = self.run_on_batch(transformed_image.unsqueeze(0))
            result_image, latent = images[0], latents[0]

            print("latents", latents)

        print(transformed_image.shape, result_image.shape)

        # convert W+ to the channelwise style codes (S space)
        w_plus = latents.cpu().detach().numpy()
        with self.graph.as_default(), self.sess.as_default():
            dlatents_loaded = self.M.W2S(w_plus)

        #print("w_plus, dlatents_loaded", w_plus, dlatents_loaded)

        img_index = 0
        img_indexs = [img_index]
        dlatent_tmp = [tmp[img_indexs] for tmp in dlatents_loaded]
        with self.graph.as_default(), self.sess.as_default():
            self.M.num_images = len(img_indexs)
            self.M.alpha = [0]
            self.M.manipulate_layers = [0]

        # reconstruct the unmanipulated image once (alpha = 0)
        with self.graph.as_default(), self.sess.as_default():
            codes, out = self.M.EditOneC(0, dlatent_tmp)

        original = Image.fromarray(out[0, 0]).resize((512, 512))
        with self.graph.as_default(), self.sess.as_default():
            self.M.manipulate_layers = None

        # CLIP direction between the target and neutral prompts
        classnames = [target, neutral]
        dt = GetDt(classnames, self.model)

        with self.graph.as_default(), self.sess.as_default():
            self.M.alpha = [manipulation_strength]
            boundary_tmp2, c = GetBoundary(
                self.fs3, dt, self.M, threshold=disentanglement_threshold
            )
            codes = self.M.MSCode(dlatent_tmp, boundary_tmp2)
            out = self.M.GenerateImg(codes)
        generated = Image.fromarray(out[0, 0])  # .resize((512,512))

        out_path = Path(tempfile.mkdtemp()) / "out.jpg"
        generated.save(str(out_path))

        return out_path
    def run_alignment(self, image_path):
        aligned_image = align_face(filepath=image_path, predictor=self.shape_predictor)
        print("Aligned image has shape: {}".format(aligned_image.size))
        return aligned_image

    def run_on_batch(self, inputs):
        images, latents = self.net(
            inputs.to("cuda").float(), randomize_noise=False, return_latents=True
        )
        return images, latents


def concat_images(*images):
    width = 0
    for im in images:
        width += im.width
    height = max([im.height for im in images])
    concat = Image.new("RGB", (width, height))
    offset = 0
    for im in images:
        concat.paste(im, (offset, 0))
        offset += im.width
    return concat
criteria/__init__.py (0 lines, Normal file)

criteria/clip_loss.py (17 lines, Normal file)
@@ -0,0 +1,17 @@
import torch
import clip


class CLIPLoss(torch.nn.Module):

    def __init__(self, opts):
        super(CLIPLoss, self).__init__()
        self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
        # upsample by 7x and then average-pool so that a stylegan_size image
        # ends up at the 224x224 resolution CLIP expects
        self.upsample = torch.nn.Upsample(scale_factor=7)
        self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)

    def forward(self, image, text):
        image = self.avg_pool(self.upsample(image))
        # CLIP logits are 100 * cosine similarity; map them to a loss in [0, 2]
        similarity = 1 - self.model(image, text)[0] / 100
        return similarity
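# Usage sketch (assumes a generator that outputs opts.stylegan_size images):
#   clip_loss = CLIPLoss(opts)
#   text = clip.tokenize(["a person with purple hair"]).cuda()
#   loss = clip_loss(img_gen, text)  # lower means the image matches the text better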
criteria/id_loss.py (40 lines, Normal file)
@@ -0,0 +1,40 @@
import torch
from torch import nn

from models.facial_recognition.model_irse import Backbone


class IDLoss(nn.Module):
    def __init__(self, opts):
        super(IDLoss, self).__init__()
        print('Loading ResNet ArcFace')
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        self.facenet.load_state_dict(torch.load(opts.ir_se50_weights))
        self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()
        self.facenet.cuda()
        self.opts = opts

    def extract_feats(self, x):
        if x.shape[2] != 256:
            x = self.pool(x)
        x = x[:, :, 35:223, 32:220]  # crop the face region
        x = self.face_pool(x)
        x_feats = self.facenet(x)
        return x_feats

    def forward(self, y_hat, y):
        n_samples = y.shape[0]
        y_feats = self.extract_feats(y)
        y_hat_feats = self.extract_feats(y_hat)
        y_feats = y_feats.detach()
        loss = 0
        sim_improvement = 0
        count = 0
        for i in range(n_samples):
            # cosine similarity between the identity embeddings (features are L2-normalized)
            diff_target = y_hat_feats[i].dot(y_feats[i])
            loss += 1 - diff_target
            count += 1

        return loss / count, sim_improvement / count
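# Usage sketch (opts.ir_se50_weights must point to the downloaded ArcFace weights):
#   id_loss = IDLoss(opts)
#   loss, _ = id_loss(y_hat, y)  # y: source image, y_hat: edited image, both NCHW tensors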
global_torch/SingleChannel.py (127 lines, Normal file)
@@ -0,0 +1,127 @@
import numpy as np
import torch

from PIL import Image
import copy
from manipulate import Manipulator
import argparse

import sys
sys.path.append('/cs/labs/danix/wuzongze/Tansformer_Manipulation/CLIP/')
import clip
def GetImgF(out, model, preprocess):
    # encode every generated image with CLIP; returns features of shape
    # (num_images, step, 512)
    imgs = out
    imgs1 = imgs.reshape([-1] + list(imgs.shape[2:]))

    tmp = []
    for i in range(len(imgs1)):
        img = Image.fromarray(imgs1[i])
        image = preprocess(img).unsqueeze(0).to(device)
        tmp.append(image)

    image = torch.cat(tmp)
    with torch.no_grad():
        image_features = model.encode_image(image)

    image_features1 = image_features.cpu().numpy()
    image_features1 = image_features1.reshape(list(imgs.shape[:2]) + [512])

    return image_features1

def GetFs(fs):
    # normalize the per-image features, take the difference between the two
    # alpha steps, and average over images to get one CLIP direction per channel
    tmp = np.linalg.norm(fs, axis=-1)
    fs1 = fs / tmp[:, :, :, None]
    fs2 = fs1[:, :, 1, :] - fs1[:, :, 0, :]  # 5*sigma - (-5)*sigma
    fs3 = fs2 / np.linalg.norm(fs2, axis=-1)[:, :, None]
    fs3 = fs3.mean(axis=1)
    fs3 = fs3 / np.linalg.norm(fs3, axis=-1)[:, None]
    return fs3
#%%
if __name__ == "__main__":
    '''
    parser = argparse.ArgumentParser(description='Process some integers.')

    parser.add_argument('--dataset_name',type=str,default='cat',
                    help='name of dataset, for example, ffhq')
    args = parser.parse_args()
    dataset_name=args.dataset_name
    '''
    #%%
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
    #%%

    network_pkl = '/cs/labs/danix/wuzongze/Gan_Manipulation/stylegan2/model/stylegan2-human-config-f.pkl'
    device = torch.device('cuda')
    M = Manipulator()
    M.device = device
    G = M.LoadModel(network_pkl, device)
    M.G = G
    M.SetGParameters()
    num_img = 100_000
    M.GenerateS(num_img=num_img)
    M.GetCodeMS()

    # M=Manipulator(dataset_name=dataset_name)
    np.set_printoptions(suppress=True)
    # print(M.dataset_name)
    #%%
    img_sindex = 0
    num_images = 100
    dlatents_o = []
    tmp = img_sindex * num_images
    for i in range(len(M.dlatents)):
        tmp1 = M.dlatents[i][tmp:(tmp + num_images)]
        dlatents_o.append(tmp1)
    #%%

    all_f = []
    M.alpha = [-5, 5]  # ffhq 5
    M.step = 2
    M.num_images = num_images
    select = np.array(M.mindexs) <= 16  # below or equal to 128 resolution
    mindexs2 = np.array(M.mindexs)[select]
    for lindex in mindexs2:  # ignore ToRGB layers
        print(lindex)
        num_c = M.dlatents[lindex].shape[1]
        for cindex in range(num_c):

            M.dlatents = copy.copy(dlatents_o)
            M.dlatents[lindex][:, cindex] = M.code_mean[lindex][cindex]

            M.manipulate_layers = [lindex]
            codes, out = M.EditOneC(cindex)
            image_features1 = GetImgF(out, model, preprocess)
            all_f.append(image_features1)

    all_f = np.array(all_f)

    fs3 = GetFs(all_f)

    #%%
    # file_path='./npy/'+M.dataset_name+'/'
    file_path = '/cs/labs/danix/wuzongze/Gan_Manipulation/stylegan2/results/npy/human/'
    np.save(file_path + 'fs3', fs3)
global_torch/StyleCLIP.py (246 lines, Normal file)
@@ -0,0 +1,246 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 14 09:40:28 2022

@author: wuzongze
"""

import os

import sys
import numpy as np
import torch

from PIL import Image
import pickle
import copy
import matplotlib.pyplot as plt
from manipulate import Manipulator

import clip


def SplitS(ds_p, M, if_std):
    # split the flat direction vector back into per-layer chunks, skipping
    # ToRGB layers, and optionally rescale by the per-channel std of S
    all_ds = []
    start = 0
    for i in M.mindexs:
        tmp = M.dlatents[i].shape[1]
        end = start + tmp
        tmp = ds_p[start:end]
#        tmp=tmp*M.code_std[i]

        all_ds.append(tmp)
        start = end

    all_ds2 = []
    tmp_index = 0
    for i in range(len(M.s_names)):
        if (not 'RGB' in M.s_names[i]) and (not len(all_ds[tmp_index]) == 0):

            if if_std:
                tmp = all_ds[tmp_index] * M.code_std[i]
            else:
                tmp = all_ds[tmp_index]

            all_ds2.append(tmp)
            tmp_index += 1
        else:
            tmp = np.zeros(len(M.dlatents[i][0]))
            all_ds2.append(tmp)
    return all_ds2
|  | 
 | ||||||
|  | imagenet_templates = [ | ||||||
|  |     'a bad photo of a {}.', | ||||||
|  | #    'a photo of many {}.', | ||||||
|  |     'a sculpture of a {}.', | ||||||
|  |     'a photo of the hard to see {}.', | ||||||
|  |     'a low resolution photo of the {}.', | ||||||
|  |     'a rendering of a {}.', | ||||||
|  |     'graffiti of a {}.', | ||||||
|  |     'a bad photo of the {}.', | ||||||
|  |     'a cropped photo of the {}.', | ||||||
|  |     'a tattoo of a {}.', | ||||||
|  |     'the embroidered {}.', | ||||||
|  |     'a photo of a hard to see {}.', | ||||||
|  |     'a bright photo of a {}.', | ||||||
|  |     'a photo of a clean {}.', | ||||||
|  |     'a photo of a dirty {}.', | ||||||
|  |     'a dark photo of the {}.', | ||||||
|  |     'a drawing of a {}.', | ||||||
|  |     'a photo of my {}.', | ||||||
|  |     'the plastic {}.', | ||||||
|  |     'a photo of the cool {}.', | ||||||
|  |     'a close-up photo of a {}.', | ||||||
|  |     'a black and white photo of the {}.', | ||||||
|  |     'a painting of the {}.', | ||||||
|  |     'a painting of a {}.', | ||||||
|  |     'a pixelated photo of the {}.', | ||||||
|  |     'a sculpture of the {}.', | ||||||
|  |     'a bright photo of the {}.', | ||||||
|  |     'a cropped photo of a {}.', | ||||||
|  |     'a plastic {}.', | ||||||
|  |     'a photo of the dirty {}.', | ||||||
|  |     'a jpeg corrupted photo of a {}.', | ||||||
|  |     'a blurry photo of the {}.', | ||||||
|  |     'a photo of the {}.', | ||||||
|  |     'a good photo of the {}.', | ||||||
|  |     'a rendering of the {}.', | ||||||
|  |     'a {} in a video game.', | ||||||
|  |     'a photo of one {}.', | ||||||
|  |     'a doodle of a {}.', | ||||||
|  |     'a close-up photo of the {}.', | ||||||
|  |     'a photo of a {}.', | ||||||
|  |     'the origami {}.', | ||||||
|  |     'the {} in a video game.', | ||||||
|  |     'a sketch of a {}.', | ||||||
|  |     'a doodle of the {}.', | ||||||
|  |     'a origami {}.', | ||||||
|  |     'a low resolution photo of a {}.', | ||||||
|  |     'the toy {}.', | ||||||
|  |     'a rendition of the {}.', | ||||||
|  |     'a photo of the clean {}.', | ||||||
|  |     'a photo of a large {}.', | ||||||
|  |     'a rendition of a {}.', | ||||||
|  |     'a photo of a nice {}.', | ||||||
|  |     'a photo of a weird {}.', | ||||||
|  |     'a blurry photo of a {}.', | ||||||
|  |     'a cartoon {}.', | ||||||
|  |     'art of a {}.', | ||||||
|  |     'a sketch of the {}.', | ||||||
|  |     'a embroidered {}.', | ||||||
|  |     'a pixelated photo of a {}.', | ||||||
|  |     'itap of the {}.', | ||||||
|  |     'a jpeg corrupted photo of the {}.', | ||||||
|  |     'a good photo of a {}.', | ||||||
|  |     'a plushie {}.', | ||||||
|  |     'a photo of the nice {}.', | ||||||
|  |     'a photo of the small {}.', | ||||||
|  |     'a photo of the weird {}.', | ||||||
|  |     'the cartoon {}.', | ||||||
|  |     'art of the {}.', | ||||||
|  |     'a drawing of the {}.', | ||||||
|  |     'a photo of the large {}.', | ||||||
|  |     'a black and white photo of a {}.', | ||||||
|  |     'the plushie {}.', | ||||||
|  |     'a dark photo of a {}.', | ||||||
|  |     'itap of a {}.', | ||||||
|  |     'graffiti of the {}.', | ||||||
|  |     'a toy {}.', | ||||||
|  |     'itap of my {}.', | ||||||
|  |     'a photo of a cool {}.', | ||||||
|  |     'a photo of a small {}.', | ||||||
|  |     'a tattoo of the {}.', | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
def zeroshot_classifier(classnames, templates, model):
    # average the CLIP text embeddings over the prompt templates for each class
    with torch.no_grad():
        zeroshot_weights = []
        for classname in classnames:
            texts = [template.format(classname) for template in templates]  # format with class
            texts = clip.tokenize(texts).cuda()  # tokenize
            class_embeddings = model.encode_text(texts)  # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights


def GetDt(classnames, model):
    # normalized CLIP-space direction pointing from the neutral to the target prompt
    text_features = zeroshot_classifier(classnames, imagenet_templates, model).t()

    dt = text_features[0] - text_features[1]
    dt = dt.cpu().numpy()

    print(np.linalg.norm(dt))
    dt = dt / np.linalg.norm(dt)
    return dt


def GetBoundary(fs3, dt, M, threshold):
    # keep only the style channels whose relevance |fs3 . dt| exceeds the threshold
    tmp = np.dot(fs3, dt)

    ds_imp = copy.copy(tmp)
    select = np.abs(tmp) < threshold
    num_c = np.sum(~select)

    ds_imp[select] = 0
    tmp = np.abs(ds_imp).max()
    ds_imp /= tmp

    boundary_tmp2 = SplitS(ds_imp, M, if_std=True)
    print('num of channels being manipulated:', num_c)
    return boundary_tmp2, num_c
#%%
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

    # pls download the checkpoint from https://drive.google.com/file/d/1FlAb1rYa0r_--Zj_ML8e6shmaF28hQb5/view
    network_pkl = '/cs/labs/danix/wuzongze/Gan_Manipulation/stylegan2/model/stylegan2-human-config-f.pkl'
    device = torch.device('cuda')
    M = Manipulator()
    M.device = device
    G = M.LoadModel(network_pkl, device)
    M.G = G
    M.SetGParameters()
    num_img = 100_000
    M.GenerateS(num_img=num_img)
    M.GetCodeMS()
    np.set_printoptions(suppress=True)
    #%%
    file_path = './npy/human/'
    fs3 = np.load(file_path + 'fs3.npy')
    #%%
    img_indexs = np.arange(20)

    dlatent_tmp = [tmp[img_indexs] for tmp in M.dlatents]
    M.num_images = len(img_indexs)
    #%%

    # one (neutral, target, beta, alpha) tuple per edit
    paras = [
            ['person', 'original', 0, 0],
            ['woman', 'man', 0.2, 3],
            ['person', 'person with T-shirt', 0.15, 4],
            ['person', 'person with jeans', 0.15, 4],
            ['person', 'person with jacket', 0.15, 4],
    ]
    paras = np.array(paras)
    #%%

    M.step = 1

    imgs = []
    all_b = []
    for i in range(len(paras)):

        neutral, target, beta, alpha = paras[i]
        beta = np.float32(beta)
        alpha = np.float32(alpha)
        M.alpha = [alpha]
        print()
        print(target)
        classnames = [target, neutral]
        dt = GetDt(classnames, model)
        boundary_tmp2, num_c = GetBoundary(fs3, dt, M, threshold=beta)
        all_b.append(boundary_tmp2)
        codes = M.MSCode(dlatent_tmp, boundary_tmp2)

        out = M.GenerateImg(codes)
        imgs.append(out)

    imgs = np.concatenate(imgs, axis=1)
    M.step = imgs.shape[1]
    M.Vis('real', '', imgs, colnames=list(paras[:, 1]), rownames=img_indexs, viz_size=1024)
global_torch/dnnlib/__init__.py (9 lines, Normal file)
@@ -0,0 +1,9 @@
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from .util import EasyDict, make_cache_dir_path
global_torch/dnnlib/util.py (477 lines, Normal file)
@@ -0,0 +1,477 @@
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Miscellaneous utility classes and functions."""

import ctypes
import fnmatch
import importlib
import inspect
import numpy as np
import os
import shutil
import sys
import types
import io
import pickle
import re
import requests
import html
import hashlib
import glob
import tempfile
import urllib
import urllib.request
import uuid

from distutils.util import strtobool
from typing import Any, List, Tuple, Union


# Util classes
# ------------------------------------------------------------------------------------------


class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]
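# Example: attribute access and item access are interchangeable:
#   d = EasyDict(lr=0.1)
#   d.batch_size = 32
#   assert d["batch_size"] == d.batch_size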
|  | class Logger(object): | ||||||
|  |     """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" | ||||||
|  | 
 | ||||||
|  |     def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): | ||||||
|  |         self.file = None | ||||||
|  | 
 | ||||||
|  |         if file_name is not None: | ||||||
|  |             self.file = open(file_name, file_mode) | ||||||
|  | 
 | ||||||
|  |         self.should_flush = should_flush | ||||||
|  |         self.stdout = sys.stdout | ||||||
|  |         self.stderr = sys.stderr | ||||||
|  | 
 | ||||||
|  |         sys.stdout = self | ||||||
|  |         sys.stderr = self | ||||||
|  | 
 | ||||||
|  |     def __enter__(self) -> "Logger": | ||||||
|  |         return self | ||||||
|  | 
 | ||||||
|  |     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: | ||||||
|  |         self.close() | ||||||
|  | 
 | ||||||
|  |     def write(self, text: Union[str, bytes]) -> None: | ||||||
|  |         """Write text to stdout (and a file) and optionally flush.""" | ||||||
|  |         if isinstance(text, bytes): | ||||||
|  |             text = text.decode() | ||||||
|  |         if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash | ||||||
|  |             return | ||||||
|  | 
 | ||||||
|  |         if self.file is not None: | ||||||
|  |             self.file.write(text) | ||||||
|  | 
 | ||||||
|  |         self.stdout.write(text) | ||||||
|  | 
 | ||||||
|  |         if self.should_flush: | ||||||
|  |             self.flush() | ||||||
|  | 
 | ||||||
|  |     def flush(self) -> None: | ||||||
|  |         """Flush written text to both stdout and a file, if open.""" | ||||||
|  |         if self.file is not None: | ||||||
|  |             self.file.flush() | ||||||
|  | 
 | ||||||
|  |         self.stdout.flush() | ||||||
|  | 
 | ||||||
|  |     def close(self) -> None: | ||||||
|  |         """Flush, close possible files, and remove stdout/stderr mirroring.""" | ||||||
|  |         self.flush() | ||||||
|  | 
 | ||||||
|  |         # if using multiple loggers, prevent closing in wrong order | ||||||
|  |         if sys.stdout is self: | ||||||
|  |             sys.stdout = self.stdout | ||||||
|  |         if sys.stderr is self: | ||||||
|  |             sys.stderr = self.stderr | ||||||
|  | 
 | ||||||
|  |         if self.file is not None: | ||||||
|  |             self.file.close() | ||||||
|  |             self.file = None | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
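|  | # Usage sketch (file name is illustrative): tee stdout/stderr to a log file | ||||||
|  | # for the duration of a 'with' block; __exit__ restores the original streams. | ||||||
|  | # | ||||||
|  | #   with Logger(file_name='run.log', should_flush=True): | ||||||
|  | #       print('visible on stdout and written to run.log') | ||||||
|  |  | ||||||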
|  | # Cache directories | ||||||
|  | # ------------------------------------------------------------------------------------------ | ||||||
|  | 
 | ||||||
|  | _dnnlib_cache_dir = None | ||||||
|  | 
 | ||||||
|  | def set_cache_dir(path: str) -> None: | ||||||
|  |     global _dnnlib_cache_dir | ||||||
|  |     _dnnlib_cache_dir = path | ||||||
|  | 
 | ||||||
|  | def make_cache_dir_path(*paths: str) -> str: | ||||||
|  |     if _dnnlib_cache_dir is not None: | ||||||
|  |         return os.path.join(_dnnlib_cache_dir, *paths) | ||||||
|  |     if 'DNNLIB_CACHE_DIR' in os.environ: | ||||||
|  |         return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) | ||||||
|  |     if 'HOME' in os.environ: | ||||||
|  |         return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) | ||||||
|  |     if 'USERPROFILE' in os.environ: | ||||||
|  |         return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) | ||||||
|  |     return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) | ||||||
|  | 
 | ||||||
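|  | # Resolution order, as implemented above: an explicit set_cache_dir() path, | ||||||
|  | # then $DNNLIB_CACHE_DIR, then $HOME/.cache/dnnlib (or the %USERPROFILE% | ||||||
|  | # equivalent), and finally the system temp directory. Illustrative sketch: | ||||||
|  | # | ||||||
|  | #   set_cache_dir('/data/cache')          # hypothetical path | ||||||
|  | #   make_cache_dir_path('downloads')      # -> '/data/cache/downloads' | ||||||
|  |  | ||||||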
|  | # Small util functions | ||||||
|  | # ------------------------------------------------------------------------------------------ | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def format_time(seconds: Union[int, float]) -> str: | ||||||
|  |     """Convert the seconds to human readable string with days, hours, minutes and seconds.""" | ||||||
|  |     s = int(np.rint(seconds)) | ||||||
|  | 
 | ||||||
|  |     if s < 60: | ||||||
|  |         return "{0}s".format(s) | ||||||
|  |     elif s < 60 * 60: | ||||||
|  |         return "{0}m {1:02}s".format(s // 60, s % 60) | ||||||
|  |     elif s < 24 * 60 * 60: | ||||||
|  |         return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) | ||||||
|  |     else: | ||||||
|  |         return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
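|  | # Illustrative outputs for each formatting tier: | ||||||
|  | # | ||||||
|  | #   format_time(59)      # -> '59s' | ||||||
|  | #   format_time(3661)    # -> '1h 01m 01s' | ||||||
|  | #   format_time(90000)   # -> '1d 01h 00m' | ||||||
|  |  | ||||||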
|  | def ask_yes_no(question: str) -> bool: | ||||||
|  |     """Ask the user the question until the user inputs a valid answer.""" | ||||||
|  |     while True: | ||||||
|  |         try: | ||||||
|  |             print("{0} [y/n]".format(question)) | ||||||
|  |             return strtobool(input().lower()) | ||||||
|  |         except ValueError: | ||||||
|  |             pass | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def tuple_product(t: Tuple) -> Any: | ||||||
|  |     """Calculate the product of the tuple elements.""" | ||||||
|  |     result = 1 | ||||||
|  | 
 | ||||||
|  |     for v in t: | ||||||
|  |         result *= v | ||||||
|  | 
 | ||||||
|  |     return result | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | _str_to_ctype = { | ||||||
|  |     "uint8": ctypes.c_ubyte, | ||||||
|  |     "uint16": ctypes.c_uint16, | ||||||
|  |     "uint32": ctypes.c_uint32, | ||||||
|  |     "uint64": ctypes.c_uint64, | ||||||
|  |     "int8": ctypes.c_byte, | ||||||
|  |     "int16": ctypes.c_int16, | ||||||
|  |     "int32": ctypes.c_int32, | ||||||
|  |     "int64": ctypes.c_int64, | ||||||
|  |     "float32": ctypes.c_float, | ||||||
|  |     "float64": ctypes.c_double | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: | ||||||
|  |     """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" | ||||||
|  |     type_str = None | ||||||
|  | 
 | ||||||
|  |     if isinstance(type_obj, str): | ||||||
|  |         type_str = type_obj | ||||||
|  |     elif hasattr(type_obj, "__name__"): | ||||||
|  |         type_str = type_obj.__name__ | ||||||
|  |     elif hasattr(type_obj, "name"): | ||||||
|  |         type_str = type_obj.name | ||||||
|  |     else: | ||||||
|  |         raise RuntimeError("Cannot infer type name from input") | ||||||
|  | 
 | ||||||
|  |     assert type_str in _str_to_ctype.keys() | ||||||
|  | 
 | ||||||
|  |     my_dtype = np.dtype(type_str) | ||||||
|  |     my_ctype = _str_to_ctype[type_str] | ||||||
|  | 
 | ||||||
|  |     assert my_dtype.itemsize == ctypes.sizeof(my_ctype) | ||||||
|  | 
 | ||||||
|  |     return my_dtype, my_ctype | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
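|  | # Illustrative example: 'float32' resolves to a matching 4-byte pair. | ||||||
|  | # | ||||||
|  | #   dtype, ctype = get_dtype_and_ctype('float32') | ||||||
|  | #   # dtype == np.dtype('float32'), ctype is ctypes.c_float, 4 bytes each | ||||||
|  |  | ||||||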
|  | def is_pickleable(obj: Any) -> bool: | ||||||
|  |     try: | ||||||
|  |         with io.BytesIO() as stream: | ||||||
|  |             pickle.dump(obj, stream) | ||||||
|  |         return True | ||||||
|  |     except: | ||||||
|  |         return False | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
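|  | # Illustrative examples: plain containers pickle; lambdas do not. | ||||||
|  | # | ||||||
|  | #   is_pickleable({'a': 1})      # -> True | ||||||
|  | #   is_pickleable(lambda x: x)   # -> False | ||||||
|  |  | ||||||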
|  | # Functionality to import modules/objects by name, and call functions by name | ||||||
|  | # ------------------------------------------------------------------------------------------ | ||||||
|  | 
 | ||||||
|  | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: | ||||||
|  |     """Searches for the underlying module behind the name to some python object. | ||||||
|  |     Returns the module and the object name (original name with module part removed).""" | ||||||
|  | 
 | ||||||
|  |     # allow convenience shorthands, substitute them by full names (dots are | ||||||
|  |     # escaped so that e.g. 'npx...' is not rewritten) | ||||||
|  |     obj_name = re.sub(r"^np\.", "numpy.", obj_name) | ||||||
|  |     obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name) | ||||||
|  | 
 | ||||||
|  |     # list alternatives for (module_name, local_obj_name) | ||||||
|  |     parts = obj_name.split(".") | ||||||
|  |     name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] | ||||||
|  | 
 | ||||||
|  |     # try each alternative in turn | ||||||
|  |     for module_name, local_obj_name in name_pairs: | ||||||
|  |         try: | ||||||
|  |             module = importlib.import_module(module_name) # may raise ImportError | ||||||
|  |             get_obj_from_module(module, local_obj_name) # may raise AttributeError | ||||||
|  |             return module, local_obj_name | ||||||
|  |         except: | ||||||
|  |             pass | ||||||
|  | 
 | ||||||
|  |     # maybe some of the modules themselves contain errors? | ||||||
|  |     for module_name, _local_obj_name in name_pairs: | ||||||
|  |         try: | ||||||
|  |             importlib.import_module(module_name) # may raise ImportError | ||||||
|  |         except ImportError: | ||||||
|  |             if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): | ||||||
|  |                 raise | ||||||
|  | 
 | ||||||
|  |     # maybe the requested attribute is missing? | ||||||
|  |     for module_name, local_obj_name in name_pairs: | ||||||
|  |         try: | ||||||
|  |             module = importlib.import_module(module_name) # may raise ImportError | ||||||
|  |             get_obj_from_module(module, local_obj_name) # may raise AttributeError | ||||||
|  |         except ImportError: | ||||||
|  |             pass | ||||||
|  | 
 | ||||||
|  |     # we are out of luck, but we have no idea why | ||||||
|  |     raise ImportError(obj_name) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: | ||||||
|  |     """Traverses the object name and returns the last (rightmost) python object.""" | ||||||
|  |     if obj_name == '': | ||||||
|  |         return module | ||||||
|  |     obj = module | ||||||
|  |     for part in obj_name.split("."): | ||||||
|  |         obj = getattr(obj, part) | ||||||
|  |     return obj | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_obj_by_name(name: str) -> Any: | ||||||
|  |     """Finds the python object with the given name.""" | ||||||
|  |     module, obj_name = get_module_from_obj_name(name) | ||||||
|  |     return get_obj_from_module(module, obj_name) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: | ||||||
|  |     """Finds the python object with the given name and calls it as a function.""" | ||||||
|  |     assert func_name is not None | ||||||
|  |     func_obj = get_obj_by_name(func_name) | ||||||
|  |     assert callable(func_obj) | ||||||
|  |     return func_obj(*args, **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: | ||||||
|  |     """Finds the python class with the given name and constructs it with the given arguments.""" | ||||||
|  |     return call_func_by_name(*args, func_name=class_name, **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
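|  | # Illustrative examples (the 'np.' shorthand expands to 'numpy.' above): | ||||||
|  | # | ||||||
|  | #   get_obj_by_name('np.mean')                      # -> numpy.mean | ||||||
|  | #   call_func_by_name(4.0, func_name='np.sqrt')     # -> 2.0 | ||||||
|  | #   construct_class_by_name(class_name='collections.OrderedDict')  # -> OrderedDict() | ||||||
|  |  | ||||||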
|  | def get_module_dir_by_obj_name(obj_name: str) -> str: | ||||||
|  |     """Get the directory path of the module containing the given object name.""" | ||||||
|  |     module, _ = get_module_from_obj_name(obj_name) | ||||||
|  |     return os.path.dirname(inspect.getfile(module)) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def is_top_level_function(obj: Any) -> bool: | ||||||
|  |     """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" | ||||||
|  |     return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_top_level_function_name(obj: Any) -> str: | ||||||
|  |     """Return the fully-qualified name of a top-level function.""" | ||||||
|  |     assert is_top_level_function(obj) | ||||||
|  |     module = obj.__module__ | ||||||
|  |     if module == '__main__': | ||||||
|  |         module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] | ||||||
|  |     return module + "." + obj.__name__ | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # File system helpers | ||||||
|  | # ------------------------------------------------------------------------------------------ | ||||||
|  | 
 | ||||||
|  | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: | ||||||
|  |     """List all files recursively in a given directory while ignoring given file and directory names. | ||||||
|  |     Returns list of tuples containing both absolute and relative paths.""" | ||||||
|  |     assert os.path.isdir(dir_path) | ||||||
|  |     base_name = os.path.basename(os.path.normpath(dir_path)) | ||||||
|  | 
 | ||||||
|  |     if ignores is None: | ||||||
|  |         ignores = [] | ||||||
|  | 
 | ||||||
|  |     result = [] | ||||||
|  | 
 | ||||||
|  |     for root, dirs, files in os.walk(dir_path, topdown=True): | ||||||
|  |         for ignore_ in ignores: | ||||||
|  |             dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] | ||||||
|  | 
 | ||||||
|  |             # dirs need to be edited in-place | ||||||
|  |             for d in dirs_to_remove: | ||||||
|  |                 dirs.remove(d) | ||||||
|  | 
 | ||||||
|  |             files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] | ||||||
|  | 
 | ||||||
|  |         absolute_paths = [os.path.join(root, f) for f in files] | ||||||
|  |         relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] | ||||||
|  | 
 | ||||||
|  |         if add_base_to_relative: | ||||||
|  |             relative_paths = [os.path.join(base_name, p) for p in relative_paths] | ||||||
|  | 
 | ||||||
|  |         assert len(absolute_paths) == len(relative_paths) | ||||||
|  |         result += zip(absolute_paths, relative_paths) | ||||||
|  | 
 | ||||||
|  |     return result | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
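|  | # Illustrative example (directory layout is hypothetical): skip caches while | ||||||
|  | # collecting (absolute, relative) path pairs. | ||||||
|  | # | ||||||
|  | #   pairs = list_dir_recursively_with_ignore('src', ignores=['__pycache__', '*.pyc']) | ||||||
|  | #   # e.g. [('/abs/src/a.py', 'a.py'), ('/abs/src/pkg/b.py', 'pkg/b.py')] | ||||||
|  |  | ||||||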
|  | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: | ||||||
|  |     """Takes in a list of tuples of (src, dst) paths and copies files. | ||||||
|  |     Will create all necessary directories.""" | ||||||
|  |     for file in files: | ||||||
|  |         target_dir_name = os.path.dirname(file[1]) | ||||||
|  | 
 | ||||||
|  |         # will create all intermediate-level directories | ||||||
|  |         if not os.path.exists(target_dir_name): | ||||||
|  |             os.makedirs(target_dir_name) | ||||||
|  | 
 | ||||||
|  |         shutil.copyfile(file[0], file[1]) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
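|  | # Illustrative example (paths hypothetical): mirror a tree by feeding the | ||||||
|  | # lister's output, rebased onto a destination root, into the copier. | ||||||
|  | # | ||||||
|  | #   pairs = list_dir_recursively_with_ignore('src', add_base_to_relative=True) | ||||||
|  | #   copy_files_and_create_dirs([(s, os.path.join('backup', r)) for s, r in pairs]) | ||||||
|  |  | ||||||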
|  | # URL helpers | ||||||
|  | # ------------------------------------------------------------------------------------------ | ||||||
|  | 
 | ||||||
|  | def is_url(obj: Any, allow_file_urls: bool = False) -> bool: | ||||||
|  |     """Determine whether the given object is a valid URL string.""" | ||||||
|  |     if not isinstance(obj, str) or "://" not in obj: | ||||||
|  |         return False | ||||||
|  |     if allow_file_urls and obj.startswith('file://'): | ||||||
|  |         return True | ||||||
|  |     try: | ||||||
|  |         res = requests.compat.urlparse(obj) | ||||||
|  |         if not res.scheme or not res.netloc or "." not in res.netloc: | ||||||
|  |             return False | ||||||
|  |         res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) | ||||||
|  |         if not res.scheme or not res.netloc or "." not in res.netloc: | ||||||
|  |             return False | ||||||
|  |     except: | ||||||
|  |         return False | ||||||
|  |     return True | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
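|  | # Illustrative examples (URLs hypothetical): | ||||||
|  | # | ||||||
|  | #   is_url('https://example.com/model.pkl')                 # -> True | ||||||
|  | #   is_url('file:///tmp/model.pkl', allow_file_urls=True)   # -> True | ||||||
|  | #   is_url('/tmp/model.pkl')                                # -> False (no scheme) | ||||||
|  |  | ||||||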
|  | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: | ||||||
|  |     """Download the given URL and return a binary-mode file object to access the data.""" | ||||||
|  |     assert num_attempts >= 1 | ||||||
|  |     assert not (return_filename and (not cache)) | ||||||
|  | 
 | ||||||
|  |     # Doesn't look like a URL, so interpret it as a local filename. | ||||||
|  |     if not re.match('^[a-z]+://', url): | ||||||
|  |         return url if return_filename else open(url, "rb") | ||||||
|  | 
 | ||||||
|  |     # Handle file URLs.  This code handles unusual file:// patterns that | ||||||
|  |     # arise on Windows: | ||||||
|  |     # | ||||||
|  |     # file:///c:/foo.txt | ||||||
|  |     # | ||||||
|  |     # which would translate to a local '/c:/foo.txt' filename that's | ||||||
|  |     # invalid.  Drop the forward slash for such pathnames. | ||||||
|  |     # | ||||||
|  |     # If you touch this code path, you should test it on both Linux and | ||||||
|  |     # Windows. | ||||||
|  |     # | ||||||
|  |     # Some internet resources suggest using urllib.request.url2pathname(), | ||||||
|  |     # but that converts forward slashes to backslashes and this causes | ||||||
|  |     # its own set of problems. | ||||||
|  |     if url.startswith('file://'): | ||||||
|  |         filename = urllib.parse.urlparse(url).path | ||||||
|  |         if re.match(r'^/[a-zA-Z]:', filename): | ||||||
|  |             filename = filename[1:] | ||||||
|  |         return filename if return_filename else open(filename, "rb") | ||||||
|  | 
 | ||||||
|  |     assert is_url(url) | ||||||
|  | 
 | ||||||
|  |     # Lookup from cache. | ||||||
|  |     if cache_dir is None: | ||||||
|  |         cache_dir = make_cache_dir_path('downloads') | ||||||
|  | 
 | ||||||
|  |     url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() | ||||||
|  |     if cache: | ||||||
|  |         cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) | ||||||
|  |         if len(cache_files) == 1: | ||||||
|  |             filename = cache_files[0] | ||||||
|  |             return filename if return_filename else open(filename, "rb") | ||||||
|  | 
 | ||||||
|  |     # Download. | ||||||
|  |     url_name = None | ||||||
|  |     url_data = None | ||||||
|  |     with requests.Session() as session: | ||||||
|  |         if verbose: | ||||||
|  |             print("Downloading %s ..." % url, end="", flush=True) | ||||||
|  |         for attempts_left in reversed(range(num_attempts)): | ||||||
|  |             try: | ||||||
|  |                 with session.get(url) as res: | ||||||
|  |                     res.raise_for_status() | ||||||
|  |                     if len(res.content) == 0: | ||||||
|  |                         raise IOError("No data received") | ||||||
|  | 
 | ||||||
|  |                     if len(res.content) < 8192: | ||||||
|  |                         content_str = res.content.decode("utf-8") | ||||||
|  |                         if "download_warning" in res.headers.get("Set-Cookie", ""): | ||||||
|  |                             links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] | ||||||
|  |                             if len(links) == 1: | ||||||
|  |                                 url = requests.compat.urljoin(url, links[0]) | ||||||
|  |                                 raise IOError("Google Drive virus checker nag") | ||||||
|  |                         if "Google Drive - Quota exceeded" in content_str: | ||||||
|  |                             raise IOError("Google Drive download quota exceeded -- please try again later") | ||||||
|  | 
 | ||||||
|  |                     match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) | ||||||
|  |                     url_name = match[1] if match else url | ||||||
|  |                     url_data = res.content | ||||||
|  |                     if verbose: | ||||||
|  |                         print(" done") | ||||||
|  |                     break | ||||||
|  |             except KeyboardInterrupt: | ||||||
|  |                 raise | ||||||
|  |             except: | ||||||
|  |                 if not attempts_left: | ||||||
|  |                     if verbose: | ||||||
|  |                         print(" failed") | ||||||
|  |                     raise | ||||||
|  |                 if verbose: | ||||||
|  |                     print(".", end="", flush=True) | ||||||
|  | 
 | ||||||
|  |     # Save to cache. | ||||||
|  |     if cache: | ||||||
|  |         safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) | ||||||
|  |         cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) | ||||||
|  |         temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) | ||||||
|  |         os.makedirs(cache_dir, exist_ok=True) | ||||||
|  |         with open(temp_file, "wb") as f: | ||||||
|  |             f.write(url_data) | ||||||
|  |         os.replace(temp_file, cache_file) # atomic | ||||||
|  |         if return_filename: | ||||||
|  |             return cache_file | ||||||
|  | 
 | ||||||
|  |     # Return data as file object. | ||||||
|  |     assert not return_filename | ||||||
|  |     return io.BytesIO(url_data) | ||||||
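|  |  | ||||||
|  | # Usage sketch (URL hypothetical): download once with on-disk caching, then | ||||||
|  | # read the payload from the returned binary file object. | ||||||
|  | # | ||||||
|  | #   with open_url('https://example.com/model.pkl') as f: | ||||||
|  | #       data = pickle.load(f) | ||||||
|  | #   path = open_url('https://example.com/model.pkl', return_filename=True) | ||||||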
							
								
								
									
99  global_torch/html/[6]_501_c.html  (Normal file; file diff suppressed because one or more lines are too long)
								
								
									
223  global_torch/html/real_.html  (Normal file; file diff suppressed because one or more lines are too long)
								
								
									
326  global_torch/legacy.py  (Normal file)
							| @ -0,0 +1,326 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | import click | ||||||
|  | import pickle | ||||||
|  | import re | ||||||
|  | import copy | ||||||
|  | import numpy as np | ||||||
|  | import torch | ||||||
|  | import dnnlib | ||||||
|  | from torch_utils import misc | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def load_network_pkl(f, force_fp16=False): | ||||||
|  |     data = _LegacyUnpickler(f).load() | ||||||
|  | 
 | ||||||
|  |     # Legacy TensorFlow pickle => convert. | ||||||
|  |     if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data): | ||||||
|  |         tf_G, tf_D, tf_Gs = data | ||||||
|  |         G = convert_tf_generator(tf_G) | ||||||
|  |         D = convert_tf_discriminator(tf_D) | ||||||
|  |         G_ema = convert_tf_generator(tf_Gs) | ||||||
|  |         data = dict(G=G, D=D, G_ema=G_ema) | ||||||
|  | 
 | ||||||
|  |     # Add missing fields. | ||||||
|  |     if 'training_set_kwargs' not in data: | ||||||
|  |         data['training_set_kwargs'] = None | ||||||
|  |     if 'augment_pipe' not in data: | ||||||
|  |         data['augment_pipe'] = None | ||||||
|  | 
 | ||||||
|  |     # Validate contents. | ||||||
|  |     assert isinstance(data['G'], torch.nn.Module) | ||||||
|  |     assert isinstance(data['D'], torch.nn.Module) | ||||||
|  |     assert isinstance(data['G_ema'], torch.nn.Module) | ||||||
|  |     assert isinstance(data['training_set_kwargs'], (dict, type(None))) | ||||||
|  |     assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None))) | ||||||
|  | 
 | ||||||
|  |     # Force FP16. | ||||||
|  |     if force_fp16: | ||||||
|  |         for key in ['G', 'D', 'G_ema']: | ||||||
|  |             old = data[key] | ||||||
|  |             kwargs = copy.deepcopy(old.init_kwargs) | ||||||
|  |             if key.startswith('G'): | ||||||
|  |                 kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {})) | ||||||
|  |                 kwargs.synthesis_kwargs.num_fp16_res = 4 | ||||||
|  |                 kwargs.synthesis_kwargs.conv_clamp = 256 | ||||||
|  |             if key.startswith('D'): | ||||||
|  |                 kwargs.num_fp16_res = 4 | ||||||
|  |                 kwargs.conv_clamp = 256 | ||||||
|  |             if kwargs != old.init_kwargs: | ||||||
|  |                 new = type(old)(**kwargs).eval().requires_grad_(False) | ||||||
|  |                 misc.copy_params_and_buffers(old, new, require_all=True) | ||||||
|  |                 data[key] = new | ||||||
|  |     return data | ||||||
|  | 
 | ||||||
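|  | # Usage sketch (pickle path is hypothetical): load a legacy TF or native | ||||||
|  | # PyTorch pickle and pull out the EMA generator. | ||||||
|  | # | ||||||
|  | #   with dnnlib.util.open_url('stylegan2-ffhq-config-f.pkl') as f: | ||||||
|  | #       G_ema = load_network_pkl(f)['G_ema']   # torch.nn.Module | ||||||
|  |  | ||||||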
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | class _TFNetworkStub(dnnlib.EasyDict): | ||||||
|  |     pass | ||||||
|  | 
 | ||||||
|  | class _LegacyUnpickler(pickle.Unpickler): | ||||||
|  |     def find_class(self, module, name): | ||||||
|  |         if module == 'dnnlib.tflib.network' and name == 'Network': | ||||||
|  |             return _TFNetworkStub | ||||||
|  |         return super().find_class(module, name) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def _collect_tf_params(tf_net): | ||||||
|  |     # pylint: disable=protected-access | ||||||
|  |     tf_params = dict() | ||||||
|  |     def recurse(prefix, tf_net): | ||||||
|  |         for name, value in tf_net.variables: | ||||||
|  |             tf_params[prefix + name] = value | ||||||
|  |         for name, comp in tf_net.components.items(): | ||||||
|  |             recurse(prefix + name + '/', comp) | ||||||
|  |     recurse('', tf_net) | ||||||
|  |     return tf_params | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def _populate_module_params(module, *patterns): | ||||||
|  |     for name, tensor in misc.named_params_and_buffers(module): | ||||||
|  |         found = False | ||||||
|  |         value = None | ||||||
|  |         for pattern, value_fn in zip(patterns[0::2], patterns[1::2]): | ||||||
|  |             match = re.fullmatch(pattern, name) | ||||||
|  |             if match: | ||||||
|  |                 found = True | ||||||
|  |                 if value_fn is not None: | ||||||
|  |                     value = value_fn(*match.groups()) | ||||||
|  |                 break | ||||||
|  |         try: | ||||||
|  |             assert found | ||||||
|  |             if value is not None: | ||||||
|  |                 tensor.copy_(torch.from_numpy(np.array(value))) | ||||||
|  |         except: | ||||||
|  |             print(name, list(tensor.shape)) | ||||||
|  |             raise | ||||||
|  | 
 | ||||||
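|  | # Illustrative call (net and tf_params are hypothetical): patterns alternate | ||||||
|  | # (regex, value_fn); regex groups become value_fn arguments, and a value_fn | ||||||
|  | # of None matches a parameter but leaves its current value in place. | ||||||
|  | # | ||||||
|  | #   _populate_module_params(net, | ||||||
|  | #       r'fc(\d+)\.weight',     lambda i: tf_params[f'Dense{i}/weight'].transpose(), | ||||||
|  | #       r'.*\.resample_filter', None) | ||||||
|  |  | ||||||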
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def convert_tf_generator(tf_G): | ||||||
|  |     if tf_G.version < 4: | ||||||
|  |         raise ValueError('TensorFlow pickle version too low') | ||||||
|  | 
 | ||||||
|  |     # Collect kwargs. | ||||||
|  |     tf_kwargs = tf_G.static_kwargs | ||||||
|  |     known_kwargs = set() | ||||||
|  |     def kwarg(tf_name, default=None, none=None): | ||||||
|  |         known_kwargs.add(tf_name) | ||||||
|  |         val = tf_kwargs.get(tf_name, default) | ||||||
|  |         return val if val is not None else none | ||||||
|  | 
 | ||||||
|  |     # Convert kwargs. | ||||||
|  |     kwargs = dnnlib.EasyDict( | ||||||
|  |         z_dim                   = kwarg('latent_size',          512), | ||||||
|  |         c_dim                   = kwarg('label_size',           0), | ||||||
|  |         w_dim                   = kwarg('dlatent_size',         512), | ||||||
|  |         img_resolution          = kwarg('resolution',           1024), | ||||||
|  |         img_channels            = kwarg('num_channels',         3), | ||||||
|  |         mapping_kwargs = dnnlib.EasyDict( | ||||||
|  |             num_layers          = kwarg('mapping_layers',       8), | ||||||
|  |             embed_features      = kwarg('label_fmaps',          None), | ||||||
|  |             layer_features      = kwarg('mapping_fmaps',        None), | ||||||
|  |             activation          = kwarg('mapping_nonlinearity', 'lrelu'), | ||||||
|  |             lr_multiplier       = kwarg('mapping_lrmul',        0.01), | ||||||
|  |             w_avg_beta          = kwarg('w_avg_beta',           0.995,  none=1), | ||||||
|  |         ), | ||||||
|  |         synthesis_kwargs = dnnlib.EasyDict( | ||||||
|  |             channel_base        = kwarg('fmap_base',            16384) * 2, | ||||||
|  |             channel_max         = kwarg('fmap_max',             512), | ||||||
|  |             num_fp16_res        = kwarg('num_fp16_res',         0), | ||||||
|  |             conv_clamp          = kwarg('conv_clamp',           None), | ||||||
|  |             architecture        = kwarg('architecture',         'skip'), | ||||||
|  |             resample_filter     = kwarg('resample_kernel',      [1,3,3,1]), | ||||||
|  |             use_noise           = kwarg('use_noise',            True), | ||||||
|  |             activation          = kwarg('nonlinearity',         'lrelu'), | ||||||
|  |         ), | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     # Check for unknown kwargs. | ||||||
|  |     kwarg('truncation_psi') | ||||||
|  |     kwarg('truncation_cutoff') | ||||||
|  |     kwarg('style_mixing_prob') | ||||||
|  |     kwarg('structure') | ||||||
|  |     if 'resolution_w' in tf_kwargs: | ||||||
|  |         tf_kwargs.pop('resolution_w', None) | ||||||
|  |         tf_kwargs.pop('resolution_h', None) | ||||||
|  |     unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs) | ||||||
|  |     if len(unknown_kwargs) > 0: | ||||||
|  |         raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0]) | ||||||
|  | 
 | ||||||
|  |     # Collect params. | ||||||
|  |     tf_params = _collect_tf_params(tf_G) | ||||||
|  |     for name, value in list(tf_params.items()): | ||||||
|  |         match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name) | ||||||
|  |         if match: | ||||||
|  |             r = kwargs.img_resolution // (2 ** int(match.group(1))) | ||||||
|  |             tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value | ||||||
|  |             kwargs.synthesis_kwargs.architecture = 'orig' | ||||||
|  |     #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}') | ||||||
|  | 
 | ||||||
|  |     # Convert params. | ||||||
|  |     from training import networks | ||||||
|  |     G = networks.Generator(**kwargs).eval().requires_grad_(False) | ||||||
|  |     # pylint: disable=unnecessary-lambda | ||||||
|  |     _populate_module_params(G, | ||||||
|  |         r'mapping\.w_avg',                                  lambda:     tf_params[f'dlatent_avg'], | ||||||
|  |         r'mapping\.embed\.weight',                          lambda:     tf_params[f'mapping/LabelEmbed/weight'].transpose(), | ||||||
|  |         r'mapping\.embed\.bias',                            lambda:     tf_params[f'mapping/LabelEmbed/bias'], | ||||||
|  |         r'mapping\.fc(\d+)\.weight',                        lambda i:   tf_params[f'mapping/Dense{i}/weight'].transpose(), | ||||||
|  |         r'mapping\.fc(\d+)\.bias',                          lambda i:   tf_params[f'mapping/Dense{i}/bias'], | ||||||
|  |         r'synthesis\.b4\.const',                            lambda:     tf_params[f'synthesis/4x4/Const/const'][0], | ||||||
|  |         r'synthesis\.b4\.conv1\.weight',                    lambda:     tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1), | ||||||
|  |         r'synthesis\.b4\.conv1\.bias',                      lambda:     tf_params[f'synthesis/4x4/Conv/bias'], | ||||||
|  |         r'synthesis\.b4\.conv1\.noise_const',               lambda:     tf_params[f'synthesis/noise0'][0, 0], | ||||||
|  |         r'synthesis\.b4\.conv1\.noise_strength',            lambda:     tf_params[f'synthesis/4x4/Conv/noise_strength'], | ||||||
|  |         r'synthesis\.b4\.conv1\.affine\.weight',            lambda:     tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(), | ||||||
|  |         r'synthesis\.b4\.conv1\.affine\.bias',              lambda:     tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1, | ||||||
|  |         r'synthesis\.b(\d+)\.conv0\.weight',                lambda r:   tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1), | ||||||
|  |         r'synthesis\.b(\d+)\.conv0\.bias',                  lambda r:   tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'], | ||||||
|  |         r'synthesis\.b(\d+)\.conv0\.noise_const',           lambda r:   tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0], | ||||||
|  |         r'synthesis\.b(\d+)\.conv0\.noise_strength',        lambda r:   tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'], | ||||||
|  |         r'synthesis\.b(\d+)\.conv0\.affine\.weight',        lambda r:   tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(), | ||||||
|  |         r'synthesis\.b(\d+)\.conv0\.affine\.bias',          lambda r:   tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1, | ||||||
|  |         r'synthesis\.b(\d+)\.conv1\.weight',                lambda r:   tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1), | ||||||
|  |         r'synthesis\.b(\d+)\.conv1\.bias',                  lambda r:   tf_params[f'synthesis/{r}x{r}/Conv1/bias'], | ||||||
|  |         r'synthesis\.b(\d+)\.conv1\.noise_const',           lambda r:   tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0], | ||||||
|  |         r'synthesis\.b(\d+)\.conv1\.noise_strength',        lambda r:   tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'], | ||||||
|  |         r'synthesis\.b(\d+)\.conv1\.affine\.weight',        lambda r:   tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(), | ||||||
|  |         r'synthesis\.b(\d+)\.conv1\.affine\.bias',          lambda r:   tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1, | ||||||
|  |         r'synthesis\.b(\d+)\.torgb\.weight',                lambda r:   tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1), | ||||||
|  |         r'synthesis\.b(\d+)\.torgb\.bias',                  lambda r:   tf_params[f'synthesis/{r}x{r}/ToRGB/bias'], | ||||||
|  |         r'synthesis\.b(\d+)\.torgb\.affine\.weight',        lambda r:   tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(), | ||||||
|  |         r'synthesis\.b(\d+)\.torgb\.affine\.bias',          lambda r:   tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1, | ||||||
|  |         r'synthesis\.b(\d+)\.skip\.weight',                 lambda r:   tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1), | ||||||
|  |         r'.*\.resample_filter',                             None, | ||||||
|  |     ) | ||||||
|  |     return G | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def convert_tf_discriminator(tf_D): | ||||||
|  |     if tf_D.version < 4: | ||||||
|  |         raise ValueError('TensorFlow pickle version too low') | ||||||
|  | 
 | ||||||
|  |     # Collect kwargs. | ||||||
|  |     tf_kwargs = tf_D.static_kwargs | ||||||
|  |     known_kwargs = set() | ||||||
|  |     def kwarg(tf_name, default=None): | ||||||
|  |         known_kwargs.add(tf_name) | ||||||
|  |         return tf_kwargs.get(tf_name, default) | ||||||
|  | 
 | ||||||
|  |     # Convert kwargs. | ||||||
|  |     kwargs = dnnlib.EasyDict( | ||||||
|  |         c_dim                   = kwarg('label_size',           0), | ||||||
|  |         img_resolution          = kwarg('resolution',           1024), | ||||||
|  |         img_channels            = kwarg('num_channels',         3), | ||||||
|  |         architecture            = kwarg('architecture',         'resnet'), | ||||||
|  |         channel_base            = kwarg('fmap_base',            16384) * 2, | ||||||
|  |         channel_max             = kwarg('fmap_max',             512), | ||||||
|  |         num_fp16_res            = kwarg('num_fp16_res',         0), | ||||||
|  |         conv_clamp              = kwarg('conv_clamp',           None), | ||||||
|  |         cmap_dim                = kwarg('mapping_fmaps',        None), | ||||||
|  |         block_kwargs = dnnlib.EasyDict( | ||||||
|  |             activation          = kwarg('nonlinearity',         'lrelu'), | ||||||
|  |             resample_filter     = kwarg('resample_kernel',      [1,3,3,1]), | ||||||
|  |             freeze_layers       = kwarg('freeze_layers',        0), | ||||||
|  |         ), | ||||||
|  |         mapping_kwargs = dnnlib.EasyDict( | ||||||
|  |             num_layers          = kwarg('mapping_layers',       0), | ||||||
|  |             embed_features      = kwarg('mapping_fmaps',        None), | ||||||
|  |             layer_features      = kwarg('mapping_fmaps',        None), | ||||||
|  |             activation          = kwarg('nonlinearity',         'lrelu'), | ||||||
|  |             lr_multiplier       = kwarg('mapping_lrmul',        0.1), | ||||||
|  |         ), | ||||||
|  |         epilogue_kwargs = dnnlib.EasyDict( | ||||||
|  |             mbstd_group_size    = kwarg('mbstd_group_size',     None), | ||||||
|  |             mbstd_num_channels  = kwarg('mbstd_num_features',   1), | ||||||
|  |             activation          = kwarg('nonlinearity',         'lrelu'), | ||||||
|  |         ), | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     # Check for unknown kwargs. | ||||||
|  |     kwarg('structure') | ||||||
|  |     if 'resolution_w' in tf_kwargs: | ||||||
|  |         tf_kwargs.pop('resolution_w', None) | ||||||
|  |         tf_kwargs.pop('resolution_h', None) | ||||||
|  |     unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs) | ||||||
|  |     if len(unknown_kwargs) > 0: | ||||||
|  |         raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0]) | ||||||
|  | 
 | ||||||
|  |     # Collect params. | ||||||
|  |     tf_params = _collect_tf_params(tf_D) | ||||||
|  |     for name, value in list(tf_params.items()): | ||||||
|  |         match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name) | ||||||
|  |         if match: | ||||||
|  |             r = kwargs.img_resolution // (2 ** int(match.group(1))) | ||||||
|  |             tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value | ||||||
|  |             kwargs.architecture = 'orig' | ||||||
|  |     #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}') | ||||||
|  | 
 | ||||||
|  |     # Convert params. | ||||||
|  |     from training import networks | ||||||
|  |     D = networks.Discriminator(**kwargs).eval().requires_grad_(False) | ||||||
|  |     # pylint: disable=unnecessary-lambda | ||||||
|  |     _populate_module_params(D, | ||||||
|  |         r'b(\d+)\.fromrgb\.weight',     lambda r:       tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1), | ||||||
|  |         r'b(\d+)\.fromrgb\.bias',       lambda r:       tf_params[f'{r}x{r}/FromRGB/bias'], | ||||||
|  |         r'b(\d+)\.conv(\d+)\.weight',   lambda r, i:    tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1), | ||||||
|  |         r'b(\d+)\.conv(\d+)\.bias',     lambda r, i:    tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'], | ||||||
|  |         r'b(\d+)\.skip\.weight',        lambda r:       tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1), | ||||||
|  |         r'mapping\.embed\.weight',      lambda:         tf_params[f'LabelEmbed/weight'].transpose(), | ||||||
|  |         r'mapping\.embed\.bias',        lambda:         tf_params[f'LabelEmbed/bias'], | ||||||
|  |         r'mapping\.fc(\d+)\.weight',    lambda i:       tf_params[f'Mapping{i}/weight'].transpose(), | ||||||
|  |         r'mapping\.fc(\d+)\.bias',      lambda i:       tf_params[f'Mapping{i}/bias'], | ||||||
|  |         r'b4\.conv\.weight',            lambda:         tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1), | ||||||
|  |         r'b4\.conv\.bias',              lambda:         tf_params[f'4x4/Conv/bias'], | ||||||
|  |         r'b4\.fc\.weight',              lambda:         tf_params[f'4x4/Dense0/weight'].transpose(), | ||||||
|  |         r'b4\.fc\.bias',                lambda:         tf_params[f'4x4/Dense0/bias'], | ||||||
|  |         r'b4\.out\.weight',             lambda:         tf_params[f'Output/weight'].transpose(), | ||||||
|  |         r'b4\.out\.bias',               lambda:         tf_params[f'Output/bias'], | ||||||
|  |         r'.*\.resample_filter',         None, | ||||||
|  |     ) | ||||||
|  |     return D | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | @click.command() | ||||||
|  | @click.option('--source', help='Input pickle', required=True, metavar='PATH') | ||||||
|  | @click.option('--dest', help='Output pickle', required=True, metavar='PATH') | ||||||
|  | @click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True) | ||||||
|  | def convert_network_pickle(source, dest, force_fp16): | ||||||
|  |     """Convert legacy network pickle into the native PyTorch format. | ||||||
|  | 
 | ||||||
|  |     The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA. | ||||||
|  |     It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks. | ||||||
|  | 
 | ||||||
|  |     Example: | ||||||
|  | 
 | ||||||
|  |     \b | ||||||
|  |     python legacy.py \\ | ||||||
|  |         --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\ | ||||||
|  |         --dest=stylegan2-cat-config-f.pkl | ||||||
|  |     """ | ||||||
|  |     print(f'Loading "{source}"...') | ||||||
|  |     with dnnlib.util.open_url(source) as f: | ||||||
|  |         data = load_network_pkl(f, force_fp16=force_fp16) | ||||||
|  |     print(f'Saving "{dest}"...') | ||||||
|  |     with open(dest, 'wb') as f: | ||||||
|  |         pickle.dump(data, f) | ||||||
|  |     print('Done.') | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     convert_network_pickle() # pylint: disable=no-value-for-parameter | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
							
								
								
									
383  global_torch/manipulate.py  (Normal file)
							| @ -0,0 +1,383 @@ | |||||||
|  | #!/usr/bin/env python3 | ||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  | """ | ||||||
|  | Created on Mon Jul 19 21:03:58 2021 | ||||||
|  | 
 | ||||||
|  | @author: wuzongze | ||||||
|  | """ | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | import sys | ||||||
|  | 
 | ||||||
|  | import copy | ||||||
|  | import os | ||||||
|  | from time import perf_counter | ||||||
|  | 
 | ||||||
|  | import click | ||||||
|  | import imageio | ||||||
|  | import numpy as np | ||||||
|  | import PIL.Image | ||||||
|  | import torch | ||||||
|  | import torch.nn.functional as F | ||||||
|  | from PIL import Image | ||||||
|  | 
 | ||||||
|  | import dnnlib | ||||||
|  | import legacy | ||||||
|  | import pickle | ||||||
|  | from visualizer import HtmlPageVisualizer | ||||||
|  | 
 | ||||||
|  | from torch_utils import misc | ||||||
|  | import types | ||||||
|  | from training.networks import SynthesisNetwork,SynthesisBlock,SynthesisLayer,ToRGBLayer | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def change_style_code(codes, layer, channel, step): | ||||||
|  |     codes[layer][:, channel] += step | ||||||
|  |     return codes | ||||||
|  | 
 | ||||||
|  | def Vis(bname,suffix,out,rownames=None,colnames=None,save_path=None,viz_size=256): | ||||||
|  |      | ||||||
|  |     if save_path is None: | ||||||
|  |         save_path='./html/' | ||||||
|  |      | ||||||
|  |      | ||||||
|  |     num_images=out.shape[0] | ||||||
|  |     step=out.shape[1] | ||||||
|  |      | ||||||
|  |     if colnames is None: | ||||||
|  |         colnames=[f'Step {i:02d}' for i in range(1, step + 1)] | ||||||
|  |     if rownames is None: | ||||||
|  |         rownames=[str(i) for i in range(num_images)] | ||||||
|  |      | ||||||
|  |      | ||||||
|  |     visualizer = HtmlPageVisualizer( | ||||||
|  |       num_rows=num_images, num_cols=step + 1, viz_size=viz_size) | ||||||
|  |     visualizer.set_headers( | ||||||
|  |       ['Name'] +colnames) | ||||||
|  |      | ||||||
|  |     for i in range(num_images): | ||||||
|  |         visualizer.set_cell(i, 0, text=rownames[i]) | ||||||
|  |      | ||||||
|  |     for i in range(num_images): | ||||||
|  |         for k in range(step): | ||||||
|  |             image=out[i,k,:,:,:] | ||||||
|  |             visualizer.set_cell(i, 1+k, image=image) | ||||||
|  |      | ||||||
|  |     visualizer.save(save_path+bname+'_'+suffix+'.html') | ||||||
|  | 
 | ||||||
|  | def LoadModel(network_pkl,device): | ||||||
|  |     with dnnlib.util.open_url(network_pkl) as fp: | ||||||
|  |         G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore | ||||||
|  |      | ||||||
|  |     G.synthesis.forward=types.MethodType(SynthesisNetwork.forward,G.synthesis) | ||||||
|  |     G.synthesis.W2S=types.MethodType(SynthesisNetwork.W2S,G.synthesis) | ||||||
|  |      | ||||||
|  |     for res in G.synthesis.block_resolutions: | ||||||
|  |         block = getattr(G.synthesis, f'b{res}') | ||||||
|  |         # print(block) | ||||||
|  |         block.forward=types.MethodType(SynthesisBlock.forward,block) | ||||||
|  |          | ||||||
|  |         if res!=4: | ||||||
|  |             layer=block.conv0 | ||||||
|  |             layer.forward=types.MethodType(SynthesisLayer.forward,layer) | ||||||
|  |             layer.name='conv0_resolution_'+str(res) | ||||||
|  |              | ||||||
|  |         layer=block.conv1 | ||||||
|  |         layer.forward=types.MethodType(SynthesisLayer.forward,layer) | ||||||
|  |         layer.name='conv1_resolution_'+str(res) | ||||||
|  |          | ||||||
|  |         layer=block.torgb | ||||||
|  |         layer.forward=types.MethodType(ToRGBLayer.forward,layer) | ||||||
|  |         layer.name='toRGB_resolution_'+str(res) | ||||||
|  |          | ||||||
|  |      | ||||||
|  |     return G | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
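|  | # Usage sketch (pickle path is hypothetical): after LoadModel patches the | ||||||
|  | # synthesis forwards, generation can be driven by per-layer style codes. | ||||||
|  | # | ||||||
|  | #   device = torch.device('cuda') | ||||||
|  | #   G = LoadModel('stylegan2-ffhq-config-f.pkl', device) | ||||||
|  | #   z = torch.randn(1, G.z_dim, device=device) | ||||||
|  | #   ws = G.mapping(z=z, c=None, truncation_psi=0.7) | ||||||
|  | #   styles = G.synthesis.W2S(ws)   # dict: layer name -> style code | ||||||
|  | #   img = G.synthesis(None, encoded_styles=styles, noise_mode='const') | ||||||
|  |  | ||||||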
|  | def S2List(encoded_styles): | ||||||
|  |     all_s=[] | ||||||
|  |     for name in encoded_styles.keys(): | ||||||
|  |         tmp=encoded_styles[name].cpu().numpy() | ||||||
|  |         all_s.append(tmp) | ||||||
|  |     return all_s | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Manipulator(): | ||||||
|  |     def __init__(self,dataset_name='ffhq'): | ||||||
|  |          | ||||||
|  |         self.alpha=[0] # manipulation strength | ||||||
|  |         self.num_images=10 | ||||||
|  |         self.img_index=0 # index of the first image to edit | ||||||
|  |         # self.viz_size=256 | ||||||
|  |         self.manipulate_layers=None # which layers to manipulate (list of indices) | ||||||
|  |         self.truncation_psi=0.7 | ||||||
|  |         self.truncation_cutoff=8 | ||||||
|  |          | ||||||
|  | #        self.G=LoadModel(self.model_path,self.model_name) | ||||||
|  |          | ||||||
|  |         self.LoadModel=LoadModel | ||||||
|  |         self.Vis=Vis | ||||||
|  |         self.S2List=S2List | ||||||
|  |          | ||||||
|  |         fmaps=[512, 512, 512, 512, 512, 256, 128,  64, 32] | ||||||
|  |         self.fmaps=np.repeat(fmaps,3) | ||||||
|  |          | ||||||
|  |      | ||||||
|  |     def GetSName(self): | ||||||
|  |         s_names=[] | ||||||
|  |         for res in self.G.synthesis.block_resolutions: | ||||||
|  |             if res==4:  | ||||||
|  |                 tmp=f'conv1_resolution_{res}' | ||||||
|  |                 s_names.append(tmp) | ||||||
|  |                  | ||||||
|  |                 tmp=f'toRGB_resolution_{res}' | ||||||
|  |                 s_names.append(tmp) | ||||||
|  |             else: | ||||||
|  |                 tmp=f'conv0_resolution_{res}' | ||||||
|  |                 s_names.append(tmp) | ||||||
|  |                  | ||||||
|  |                 tmp=f'conv1_resolution_{res}' | ||||||
|  |                 s_names.append(tmp) | ||||||
|  |                  | ||||||
|  |                 tmp=f'toRGB_resolution_{res}' | ||||||
|  |                 s_names.append(tmp) | ||||||
|  |                  | ||||||
|  |         return s_names | ||||||
|  |      | ||||||
|  |     def SL2D(self,tmp_code): | ||||||
|  |         encoded_styles={} | ||||||
|  |         for i in range(len(self.s_names)): | ||||||
|  |             encoded_styles[self.s_names[i]]=torch.from_numpy(tmp_code[i]).to(self.device) | ||||||
|  |          | ||||||
|  |         return encoded_styles | ||||||
|  |          | ||||||
|  |      | ||||||
|  |      | ||||||
|  |     def GenerateS(self,num_img=100): | ||||||
|  |         seed=5 | ||||||
|  |         with torch.no_grad():  | ||||||
|  |             z = torch.from_numpy(np.random.RandomState(seed).randn(num_img, self.G.z_dim)).to(self.device) | ||||||
|  |             ws = self.G.mapping(z=z,c=None,truncation_psi=self.truncation_psi,truncation_cutoff=self.truncation_cutoff) | ||||||
|  |             encoded_styles=self.G.synthesis.W2S(ws) | ||||||
|  | #            encoded_styles=encoded_styles.cpu().numpy() | ||||||
|  |              | ||||||
|  |         self.dlatents=S2List(encoded_styles) | ||||||
|  |      | ||||||
|  |     def GenerateImg(self,codes): | ||||||
|  |          | ||||||
|  |         num_images,step=codes[0].shape[:2] | ||||||
|  |         out=np.zeros((num_images,step,self.img_size,self.img_size,3),dtype='uint8') | ||||||
|  |         for i in range(num_images): | ||||||
|  |             for k in range(step): | ||||||
|  |                  | ||||||
|  |                 tmp_code=[] | ||||||
|  |                 for m in range(len(self.s_names)): | ||||||
|  |                     tmp=codes[m][i,k][None,:] | ||||||
|  |                     tmp_code.append(tmp) | ||||||
|  |                      | ||||||
|  |                 encoded_styles=self.SL2D(tmp_code) | ||||||
|  |                  | ||||||
|  |                 with torch.no_grad():  | ||||||
|  |                     img = self.G.synthesis(None, encoded_styles=encoded_styles,noise_mode='const') | ||||||
|  |                     img = (img + 1) * (255/2) | ||||||
|  |                     img = img.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy() | ||||||
|  |                  | ||||||
|  |                  | ||||||
|  |                  | ||||||
|  |                 if img.shape[1]==img.shape[0]: | ||||||
|  |                     out[i,k,:,:,:]=img | ||||||
|  |                 else: | ||||||
|  |                     tmp=img.shape[1] | ||||||
|  |                     tmp1=int((img.shape[0]-tmp)/2) | ||||||
|  |                     out[i,k,:,tmp1:tmp1+tmp,:]=img | ||||||
|  |         return out | ||||||
|  |      | ||||||
|  |     def ShowImg(self,num_img=10): | ||||||
|  |          | ||||||
|  |         codes=[] | ||||||
|  |         for i in range(len(self.dlatents)): | ||||||
|  |             # print(i) | ||||||
|  |             tmp=self.dlatents[i][:num_img,None,:] | ||||||
|  |             codes.append(tmp) | ||||||
|  |         out=self.GenerateImg(codes) | ||||||
|  |         return out | ||||||
|  |      | ||||||
|  |     def SetGParameters(self): | ||||||
|  |         self.num_layers=self.G.synthesis.num_ws | ||||||
|  |         self.img_size=self.G.synthesis.img_resolution | ||||||
|  |         self.s_names=self.GetSName() | ||||||
|  |          | ||||||
|  |         self.img_size=self.G.synthesis.block_resolutions[-1] | ||||||
|  |          | ||||||
|  |         self.mindexs=[0, 2, 3, 5, 6, 8, 9, 11, 12, 14, 15, 17, 18, 20, 21,23,24] | ||||||
|  |          | ||||||
|  |          | ||||||
|  |      | ||||||
|  |     def MSCode(self,dlatent_tmp,boundary_tmp): | ||||||
|  |          | ||||||
|  |         step=len(self.alpha) | ||||||
|  |         dlatent_tmp1=[tmp.reshape((self.num_images,-1)) for tmp in dlatent_tmp] | ||||||
|  |         dlatent_tmp2=[np.tile(tmp[:,None],(1,step,1)) for tmp in dlatent_tmp1] # (10, 7, 512) | ||||||
|  | 
 | ||||||
|  |         l=np.array(self.alpha) | ||||||
|  |         l=l.reshape( | ||||||
|  |                     [step if axis == 1 else 1 for axis in range(dlatent_tmp2[0].ndim)]) | ||||||
|  |          | ||||||
|  |         if isinstance(self.manipulate_layers, int): | ||||||
|  |             tmp=[self.manipulate_layers] | ||||||
|  |         elif isinstance(self.manipulate_layers, list): | ||||||
|  |             tmp=self.manipulate_layers | ||||||
|  |         elif self.manipulate_layers is None: | ||||||
|  |             tmp=np.arange(len(boundary_tmp)) | ||||||
|  |         else: | ||||||
|  |             raise ValueError('manipulate_layers must be an int, a list, or None') | ||||||
|  |              | ||||||
|  |         for i in tmp: | ||||||
|  |             dlatent_tmp2[i]+=l*boundary_tmp[i] | ||||||
|  |          | ||||||
|  |         codes=[] | ||||||
|  |         for i in range(len(dlatent_tmp2)): | ||||||
|  |             tmp=list(dlatent_tmp[i].shape) | ||||||
|  |             tmp.insert(1,step) | ||||||
|  |             codes.append(dlatent_tmp2[i].reshape(tmp)) | ||||||
|  |         return codes | ||||||
|  |      | ||||||
|  |      | ||||||
|  |     def EditOne(self,bname,dlatent_tmp=None): | ||||||
|  |         if dlatent_tmp is None: | ||||||
|  |             dlatent_tmp=[tmp[self.img_index:(self.img_index+self.num_images)] for tmp in self.dlatents] | ||||||
|  |          | ||||||
|  |         boundary_tmp=[] | ||||||
|  |         for i in range(len(self.boundary)): | ||||||
|  |             tmp=self.boundary[i] | ||||||
|  |             if len(tmp)<=bname: | ||||||
|  |                 boundary_tmp.append([]) | ||||||
|  |             else: | ||||||
|  |                 boundary_tmp.append(tmp[bname]) | ||||||
|  |          | ||||||
|  |         codes=self.MSCode(dlatent_tmp,boundary_tmp) | ||||||
|  |              | ||||||
|  |         out=self.GenerateImg(codes) | ||||||
|  |         return codes,out | ||||||
|  |      | ||||||
|  |     def EditOneC(self,cindex,dlatent_tmp=None):  | ||||||
|  |         if dlatent_tmp is None: | ||||||
|  |             dlatent_tmp=[tmp[self.img_index:(self.img_index+self.num_images)] for tmp in self.dlatents] | ||||||
|  |          | ||||||
|  |         boundary_tmp=[[] for i in range(len(self.dlatents))] | ||||||
|  |          | ||||||
|  |         # only manipulate one layer and one channel | ||||||
|  |         assert len(self.manipulate_layers)==1 | ||||||
|  |          | ||||||
|  |         ml=self.manipulate_layers[0] | ||||||
|  |         tmp=dlatent_tmp[ml].shape[1] # number of style channels in this layer | ||||||
|  |         tmp1=np.zeros(tmp) | ||||||
|  |         tmp1[cindex]=self.code_std[ml][cindex] # step size: one std of this channel | ||||||
|  |         boundary_tmp[ml]=tmp1 | ||||||
|  |          | ||||||
|  |         codes=self.MSCode(dlatent_tmp,boundary_tmp) | ||||||
|  |         out=self.GenerateImg(codes) | ||||||
|  |         return codes,out | ||||||
|  |      | ||||||
|  |     def GetFindex(self,lindex,cindex,ignore_RGB=False): | ||||||
|  |          | ||||||
|  |         if ignore_RGB: | ||||||
|  |             tmp=np.array(self.mindexs)<lindex | ||||||
|  |             tmp=np.sum(tmp) | ||||||
|  |         else: | ||||||
|  |             tmp=lindex | ||||||
|  |         findex=np.sum(self.fmaps[:tmp])+cindex | ||||||
|  |         return findex  | ||||||
|  |      | ||||||
|  |     def GetLCIndex(self,findex): | ||||||
|  |         # Convert flat feature indices back to (layer, channel) pairs. | ||||||
|  |         l_p=[] | ||||||
|  |         cfmaps=np.cumsum(self.fmaps) | ||||||
|  |         for i in range(len(findex)): | ||||||
|  |             tmp_index=findex[i] | ||||||
|  |             tmp=tmp_index-cfmaps | ||||||
|  |             tmp=tmp[tmp>0] | ||||||
|  |             lindex=len(tmp) | ||||||
|  |             if lindex==0: | ||||||
|  |                 cindex=tmp_index | ||||||
|  |             else: | ||||||
|  |                 cindex=tmp[-1] | ||||||
|  |              | ||||||
|  |             # An index on a layer boundary rolls over to channel 0 of the next layer. | ||||||
|  |             if cindex==self.fmaps[lindex]: | ||||||
|  |                 cindex=0 | ||||||
|  |                 lindex+=1 | ||||||
|  |             l_p.append([lindex,cindex]) | ||||||
|  |         l_p=np.array(l_p) | ||||||
|  |         return l_p | ||||||
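|  |  | ||||||
|  |     # Sketch (illustrative; assumes the first six entries of fmaps are all 512): | ||||||
|  |     #   findex = M.GetFindex(lindex=6, cindex=501)   # -> 6*512 + 501 = 3573 | ||||||
|  |     #   M.GetLCIndex(np.array([3573]))               # -> array([[6, 501]]) | ||||||
|  |     # i.e. GetLCIndex inverts GetFindex as long as fmaps matches. | ||||||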
|  |  | ||||||
|  |     def GetLCIndex2(self,findex):  # input findex is counted without ToRGB layers | ||||||
|  |         fmaps_o=copy.copy(self.fmaps) | ||||||
|  |         mindexs=[0, 2, 3, 5, 6, 8, 9, 11, 12, 14, 15, 17, 18, 20, 21, 23, 24]  # positions of modulated layers | ||||||
|  |         self.fmaps=fmaps_o[mindexs] | ||||||
|  |          | ||||||
|  |         l_p=self.GetLCIndex(findex) | ||||||
|  |          | ||||||
|  |         l=l_p[:,0] | ||||||
|  |         l2=np.array(mindexs)[l] | ||||||
|  |         l_p[:,0]=l2 | ||||||
|  |         self.fmaps=fmaps_o | ||||||
|  |         return l_p | ||||||
|  |      | ||||||
|  |     def GetCodeMS(self): | ||||||
|  |         # Per-layer, per-channel mean and std of the generated style codes. | ||||||
|  |         m=[] | ||||||
|  |         std=[] | ||||||
|  |         for i in range(len(self.dlatents)): | ||||||
|  |             tmp=self.dlatents[i] | ||||||
|  |             m.append(tmp.mean(axis=0)) | ||||||
|  |             std.append(tmp.std(axis=0)) | ||||||
|  |          | ||||||
|  |         self.code_mean=m | ||||||
|  |         self.code_std=std | ||||||
|  |      | ||||||
|  |      | ||||||
|  | #%% | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     network_pkl='/cs/labs/danix/wuzongze/Gan_Manipulation/stylegan2/model/stylegan2-ffhq-config-f.pkl' | ||||||
|  |     device = torch.device('cuda') | ||||||
|  |     M=Manipulator() | ||||||
|  |     M.device=device | ||||||
|  |     G=M.LoadModel(network_pkl,device) | ||||||
|  |     M.G=G | ||||||
|  |     M.SetGParameters() | ||||||
|  |     num_img=100_000 | ||||||
|  |     M.GenerateS(num_img=num_img) | ||||||
|  |     M.GetCodeMS() | ||||||
|  |     np.set_printoptions(suppress=True) | ||||||
|  |      | ||||||
|  |     #%% | ||||||
|  |     M.alpha=[24,16,8,0,-8,-16,-24] | ||||||
|  |     M.step=len(M.alpha) | ||||||
|  |     M.img_index=0 | ||||||
|  |     M.num_images=10 | ||||||
|  |     lindex,cindex=6,501  # manipulate channel 501 of layer 6 | ||||||
|  |     M.manipulate_layers=[lindex] | ||||||
|  |     codes,out=M.EditOneC(cindex) | ||||||
|  |     tmp=str(M.manipulate_layers)+'_'+str(cindex) | ||||||
|  |     M.Vis(tmp,'c',out) | ||||||
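|  |     # With the settings above this edits 10 images across the 7 alpha steps, | ||||||
|  |     # so Vis receives a 10 x 7 grid of manipulated outputs (the exact layout | ||||||
|  |     # depends on the Vis implementation, which is not shown in this hunk). | ||||||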
BIN  global_torch/npy/ffhq/fs3.npy   (new binary file; contents not shown)
BIN  global_torch/npy/human/fs3.npy  (new binary file; contents not shown)
9    global_torch/torch_utils/__init__.py  (new file)
							| @ -0,0 +1,9 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | # empty | ||||||
126  global_torch/torch_utils/custom_ops.py  (new file)
							| @ -0,0 +1,126 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | import os | ||||||
|  | import glob | ||||||
|  | import torch | ||||||
|  | import torch.utils.cpp_extension | ||||||
|  | import importlib | ||||||
|  | import hashlib | ||||||
|  | import shutil | ||||||
|  | from pathlib import Path | ||||||
|  | 
 | ||||||
|  | from torch.utils.file_baton import FileBaton | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Global options. | ||||||
|  | 
 | ||||||
|  | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Internal helper funcs. | ||||||
|  | 
 | ||||||
|  | def _find_compiler_bindir(): | ||||||
|  |     patterns = [ | ||||||
|  |         'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', | ||||||
|  |         'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', | ||||||
|  |         'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', | ||||||
|  |         'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', | ||||||
|  |     ] | ||||||
|  |     for pattern in patterns: | ||||||
|  |         matches = sorted(glob.glob(pattern)) | ||||||
|  |         if len(matches): | ||||||
|  |             return matches[-1] | ||||||
|  |     return None | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Main entry point for compiling and loading C++/CUDA plugins. | ||||||
|  | 
 | ||||||
|  | _cached_plugins = dict() | ||||||
|  | 
 | ||||||
|  | def get_plugin(module_name, sources, **build_kwargs): | ||||||
|  |     assert verbosity in ['none', 'brief', 'full'] | ||||||
|  | 
 | ||||||
|  |     # Already cached? | ||||||
|  |     if module_name in _cached_plugins: | ||||||
|  |         return _cached_plugins[module_name] | ||||||
|  | 
 | ||||||
|  |     # Print status. | ||||||
|  |     if verbosity == 'full': | ||||||
|  |         print(f'Setting up PyTorch plugin "{module_name}"...') | ||||||
|  |     elif verbosity == 'brief': | ||||||
|  |         print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) | ||||||
|  | 
 | ||||||
|  |     try: # pylint: disable=too-many-nested-blocks | ||||||
|  |         # Make sure we can find the necessary compiler binaries. | ||||||
|  |         if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: | ||||||
|  |             compiler_bindir = _find_compiler_bindir() | ||||||
|  |             if compiler_bindir is None: | ||||||
|  |                 raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') | ||||||
|  |             os.environ['PATH'] += ';' + compiler_bindir | ||||||
|  | 
 | ||||||
|  |         # Compile and load. | ||||||
|  |         verbose_build = (verbosity == 'full') | ||||||
|  | 
 | ||||||
|  |         # Incremental build md5sum trickery.  Copies all the input source files | ||||||
|  |         # into a cached build directory under a combined md5 digest of the input | ||||||
|  |         # source files.  Copying is done only if the combined digest has changed. | ||||||
|  |         # This keeps input file timestamps and filenames the same as in previous | ||||||
|  |         # extension builds, allowing for fast incremental rebuilds. | ||||||
|  |         # | ||||||
|  |         # This optimization is done only in case all the source files reside in | ||||||
|  |         # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR | ||||||
|  |         # environment variable is set (we take this as a signal that the user | ||||||
|  |         # actually cares about this.) | ||||||
|  |         source_dirs_set = set(os.path.dirname(source) for source in sources) | ||||||
|  |         if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): | ||||||
|  |             all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) | ||||||
|  | 
 | ||||||
|  |             # Compute a combined hash digest for all source files in the same | ||||||
|  |             # custom op directory (usually .cu, .cpp, .py and .h files). | ||||||
|  |             hash_md5 = hashlib.md5() | ||||||
|  |             for src in all_source_files: | ||||||
|  |                 with open(src, 'rb') as f: | ||||||
|  |                     hash_md5.update(f.read()) | ||||||
|  |             build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access | ||||||
|  |             digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) | ||||||
|  | 
 | ||||||
|  |             if not os.path.isdir(digest_build_dir): | ||||||
|  |                 os.makedirs(digest_build_dir, exist_ok=True) | ||||||
|  |                 baton = FileBaton(os.path.join(digest_build_dir, 'lock')) | ||||||
|  |                 if baton.try_acquire(): | ||||||
|  |                     try: | ||||||
|  |                         for src in all_source_files: | ||||||
|  |                             shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) | ||||||
|  |                     finally: | ||||||
|  |                         baton.release() | ||||||
|  |                 else: | ||||||
|  |                     # Someone else is copying source files under the digest dir, | ||||||
|  |                     # wait until done and continue. | ||||||
|  |                     baton.wait() | ||||||
|  |             digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] | ||||||
|  |             torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, | ||||||
|  |                 verbose=verbose_build, sources=digest_sources, **build_kwargs) | ||||||
|  |         else: | ||||||
|  |             torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) | ||||||
|  |         module = importlib.import_module(module_name) | ||||||
|  | 
 | ||||||
|  |     except: | ||||||
|  |         if verbosity == 'brief': | ||||||
|  |             print('Failed!') | ||||||
|  |         raise | ||||||
|  | 
 | ||||||
|  |     # Print status and add to cache. | ||||||
|  |     if verbosity == 'full': | ||||||
|  |         print(f'Done setting up PyTorch plugin "{module_name}".') | ||||||
|  |     elif verbosity == 'brief': | ||||||
|  |         print('Done.') | ||||||
|  |     _cached_plugins[module_name] = module | ||||||
|  |     return module | ||||||
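|  |  | ||||||
|  | # Usage sketch (mirrors the call made by ops/bias_act.py in this repo): | ||||||
|  | #   _plugin = get_plugin('bias_act_plugin', sources=[...], extra_cuda_cflags=['--use_fast_math']) | ||||||
|  | #   _plugin.bias_act(...)  # the compiled extension is imported as a regular module | ||||||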
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
262  global_torch/torch_utils/misc.py  (new file)
							| @ -0,0 +1,262 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | import re | ||||||
|  | import contextlib | ||||||
|  | import numpy as np | ||||||
|  | import torch | ||||||
|  | import warnings | ||||||
|  | import dnnlib | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the | ||||||
|  | # same constant is used multiple times. | ||||||
|  | 
 | ||||||
|  | _constant_cache = dict() | ||||||
|  | 
 | ||||||
|  | def constant(value, shape=None, dtype=None, device=None, memory_format=None): | ||||||
|  |     value = np.asarray(value) | ||||||
|  |     if shape is not None: | ||||||
|  |         shape = tuple(shape) | ||||||
|  |     if dtype is None: | ||||||
|  |         dtype = torch.get_default_dtype() | ||||||
|  |     if device is None: | ||||||
|  |         device = torch.device('cpu') | ||||||
|  |     if memory_format is None: | ||||||
|  |         memory_format = torch.contiguous_format | ||||||
|  | 
 | ||||||
|  |     key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) | ||||||
|  |     tensor = _constant_cache.get(key, None) | ||||||
|  |     if tensor is None: | ||||||
|  |         tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) | ||||||
|  |         if shape is not None: | ||||||
|  |             tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) | ||||||
|  |         tensor = tensor.contiguous(memory_format=memory_format) | ||||||
|  |         _constant_cache[key] = tensor | ||||||
|  |     return tensor | ||||||
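|  |  | ||||||
|  | # Sketch (names assumed): repeated calls with identical value/shape/dtype/device | ||||||
|  | # return the cached tensor, e.g. ones = constant(1, shape=[batch], device=device). | ||||||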
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Replace NaN/Inf with specified numerical values. | ||||||
|  | 
 | ||||||
|  | try: | ||||||
|  |     nan_to_num = torch.nan_to_num # 1.8.0a0 | ||||||
|  | except AttributeError: | ||||||
|  |     def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin | ||||||
|  |         assert isinstance(input, torch.Tensor) | ||||||
|  |         if posinf is None: | ||||||
|  |             posinf = torch.finfo(input.dtype).max | ||||||
|  |         if neginf is None: | ||||||
|  |             neginf = torch.finfo(input.dtype).min | ||||||
|  |         assert nan == 0 | ||||||
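|  |         # nansum over an added singleton dim maps NaN to 0; clamp handles +/-Inf. | ||||||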
|  |         return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Symbolic assert. | ||||||
|  | 
 | ||||||
|  | try: | ||||||
|  |     symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access | ||||||
|  | except AttributeError: | ||||||
|  |     symbolic_assert = torch.Assert # 1.7.0 | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Context manager to suppress known warnings in torch.jit.trace(). | ||||||
|  | 
 | ||||||
|  | class suppress_tracer_warnings(warnings.catch_warnings): | ||||||
|  |     def __enter__(self): | ||||||
|  |         super().__enter__() | ||||||
|  |         warnings.simplefilter('ignore', category=torch.jit.TracerWarning) | ||||||
|  |         return self | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Assert that the shape of a tensor matches the given list of integers. | ||||||
|  | # None indicates that the size of a dimension is allowed to vary. | ||||||
|  | # Performs symbolic assertion when used in torch.jit.trace(). | ||||||
|  | 
 | ||||||
|  | def assert_shape(tensor, ref_shape): | ||||||
|  |     if tensor.ndim != len(ref_shape): | ||||||
|  |         raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') | ||||||
|  |     for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): | ||||||
|  |         if ref_size is None: | ||||||
|  |             pass | ||||||
|  |         elif isinstance(ref_size, torch.Tensor): | ||||||
|  |             with suppress_tracer_warnings(): # as_tensor results are registered as constants | ||||||
|  |                 symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') | ||||||
|  |         elif isinstance(size, torch.Tensor): | ||||||
|  |             with suppress_tracer_warnings(): # as_tensor results are registered as constants | ||||||
|  |                 symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') | ||||||
|  |         elif size != ref_size: | ||||||
|  |             raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') | ||||||
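|  |  | ||||||
|  | # Sketch: assert_shape(x, [None, C, H, W]) accepts any batch size but pins the | ||||||
|  | # remaining dims (C, H, W stand in for concrete integers here). | ||||||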
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Function decorator that calls torch.autograd.profiler.record_function(). | ||||||
|  | 
 | ||||||
|  | def profiled_function(fn): | ||||||
|  |     def decorator(*args, **kwargs): | ||||||
|  |         with torch.autograd.profiler.record_function(fn.__name__): | ||||||
|  |             return fn(*args, **kwargs) | ||||||
|  |     decorator.__name__ = fn.__name__ | ||||||
|  |     return decorator | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Sampler for torch.utils.data.DataLoader that loops over the dataset | ||||||
|  | # indefinitely, shuffling items as it goes. | ||||||
|  | 
 | ||||||
|  | class InfiniteSampler(torch.utils.data.Sampler): | ||||||
|  |     def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): | ||||||
|  |         assert len(dataset) > 0 | ||||||
|  |         assert num_replicas > 0 | ||||||
|  |         assert 0 <= rank < num_replicas | ||||||
|  |         assert 0 <= window_size <= 1 | ||||||
|  |         super().__init__(dataset) | ||||||
|  |         self.dataset = dataset | ||||||
|  |         self.rank = rank | ||||||
|  |         self.num_replicas = num_replicas | ||||||
|  |         self.shuffle = shuffle | ||||||
|  |         self.seed = seed | ||||||
|  |         self.window_size = window_size | ||||||
|  | 
 | ||||||
|  |     def __iter__(self): | ||||||
|  |         order = np.arange(len(self.dataset)) | ||||||
|  |         rnd = None | ||||||
|  |         window = 0 | ||||||
|  |         if self.shuffle: | ||||||
|  |             rnd = np.random.RandomState(self.seed) | ||||||
|  |             rnd.shuffle(order) | ||||||
|  |             window = int(np.rint(order.size * self.window_size)) | ||||||
|  | 
 | ||||||
|  |         idx = 0 | ||||||
|  |         while True: | ||||||
|  |             i = idx % order.size | ||||||
|  |             if idx % self.num_replicas == self.rank: | ||||||
|  |                 yield order[i] | ||||||
|  |             if window >= 2: | ||||||
|  |                 j = (i - rnd.randint(window)) % order.size | ||||||
|  |                 order[i], order[j] = order[j], order[i] | ||||||
|  |             idx += 1 | ||||||
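|  |  | ||||||
|  | # Usage sketch (rank/world_size/bs assumed): an epoch-less training loop can use | ||||||
|  | #   sampler = InfiniteSampler(dataset, rank=rank, num_replicas=world_size) | ||||||
|  | #   loader = torch.utils.data.DataLoader(dataset, batch_size=bs, sampler=sampler) | ||||||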
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Utilities for operating with torch.nn.Module parameters and buffers. | ||||||
|  | 
 | ||||||
|  | def params_and_buffers(module): | ||||||
|  |     assert isinstance(module, torch.nn.Module) | ||||||
|  |     return list(module.parameters()) + list(module.buffers()) | ||||||
|  | 
 | ||||||
|  | def named_params_and_buffers(module): | ||||||
|  |     assert isinstance(module, torch.nn.Module) | ||||||
|  |     return list(module.named_parameters()) + list(module.named_buffers()) | ||||||
|  | 
 | ||||||
|  | def copy_params_and_buffers(src_module, dst_module, require_all=False): | ||||||
|  |     assert isinstance(src_module, torch.nn.Module) | ||||||
|  |     assert isinstance(dst_module, torch.nn.Module) | ||||||
|  |     src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)} | ||||||
|  |     for name, tensor in named_params_and_buffers(dst_module): | ||||||
|  |         assert (name in src_tensors) or (not require_all) | ||||||
|  |         if name in src_tensors: | ||||||
|  |             tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Context manager for easily enabling/disabling DistributedDataParallel | ||||||
|  | # synchronization. | ||||||
|  | 
 | ||||||
|  | @contextlib.contextmanager | ||||||
|  | def ddp_sync(module, sync): | ||||||
|  |     assert isinstance(module, torch.nn.Module) | ||||||
|  |     if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): | ||||||
|  |         yield | ||||||
|  |     else: | ||||||
|  |         with module.no_sync(): | ||||||
|  |             yield | ||||||
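|  |  | ||||||
|  | # Usage sketch: skip gradient all-reduce on all but the last accumulation round. | ||||||
|  | #   with ddp_sync(ddp_module, sync=(round_idx == num_rounds - 1)): | ||||||
|  | #       loss.backward() | ||||||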
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Check DistributedDataParallel consistency across processes. | ||||||
|  | 
 | ||||||
|  | def check_ddp_consistency(module, ignore_regex=None): | ||||||
|  |     assert isinstance(module, torch.nn.Module) | ||||||
|  |     for name, tensor in named_params_and_buffers(module): | ||||||
|  |         fullname = type(module).__name__ + '.' + name | ||||||
|  |         if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): | ||||||
|  |             continue | ||||||
|  |         tensor = tensor.detach() | ||||||
|  |         other = tensor.clone() | ||||||
|  |         torch.distributed.broadcast(tensor=other, src=0) | ||||||
|  |         assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Print summary table of module hierarchy. | ||||||
|  | 
 | ||||||
|  | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): | ||||||
|  |     assert isinstance(module, torch.nn.Module) | ||||||
|  |     assert not isinstance(module, torch.jit.ScriptModule) | ||||||
|  |     assert isinstance(inputs, (tuple, list)) | ||||||
|  | 
 | ||||||
|  |     # Register hooks. | ||||||
|  |     entries = [] | ||||||
|  |     nesting = [0] | ||||||
|  |     def pre_hook(_mod, _inputs): | ||||||
|  |         nesting[0] += 1 | ||||||
|  |     def post_hook(mod, _inputs, outputs): | ||||||
|  |         nesting[0] -= 1 | ||||||
|  |         if nesting[0] <= max_nesting: | ||||||
|  |             outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] | ||||||
|  |             outputs = [t for t in outputs if isinstance(t, torch.Tensor)] | ||||||
|  |             entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) | ||||||
|  |     hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] | ||||||
|  |     hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] | ||||||
|  | 
 | ||||||
|  |     # Run module. | ||||||
|  |     outputs = module(*inputs) | ||||||
|  |     for hook in hooks: | ||||||
|  |         hook.remove() | ||||||
|  | 
 | ||||||
|  |     # Identify unique outputs, parameters, and buffers. | ||||||
|  |     tensors_seen = set() | ||||||
|  |     for e in entries: | ||||||
|  |         e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] | ||||||
|  |         e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] | ||||||
|  |         e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] | ||||||
|  |         tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} | ||||||
|  | 
 | ||||||
|  |     # Filter out redundant entries. | ||||||
|  |     if skip_redundant: | ||||||
|  |         entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] | ||||||
|  | 
 | ||||||
|  |     # Construct table. | ||||||
|  |     rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] | ||||||
|  |     rows += [['---'] * len(rows[0])] | ||||||
|  |     param_total = 0 | ||||||
|  |     buffer_total = 0 | ||||||
|  |     submodule_names = {mod: name for name, mod in module.named_modules()} | ||||||
|  |     for e in entries: | ||||||
|  |         name = '<top-level>' if e.mod is module else submodule_names[e.mod] | ||||||
|  |         param_size = sum(t.numel() for t in e.unique_params) | ||||||
|  |         buffer_size = sum(t.numel() for t in e.unique_buffers) | ||||||
|  |         output_shapes = [str(list(t.shape)) for t in e.outputs] | ||||||
|  |         output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] | ||||||
|  |         rows += [[ | ||||||
|  |             name + (':0' if len(e.outputs) >= 2 else ''), | ||||||
|  |             str(param_size) if param_size else '-', | ||||||
|  |             str(buffer_size) if buffer_size else '-', | ||||||
|  |             (output_shapes + ['-'])[0], | ||||||
|  |             (output_dtypes + ['-'])[0], | ||||||
|  |         ]] | ||||||
|  |         for idx in range(1, len(e.outputs)): | ||||||
|  |             rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] | ||||||
|  |         param_total += param_size | ||||||
|  |         buffer_total += buffer_size | ||||||
|  |     rows += [['---'] * len(rows[0])] | ||||||
|  |     rows += [['Total', str(param_total), str(buffer_total), '-', '-']] | ||||||
|  | 
 | ||||||
|  |     # Print table. | ||||||
|  |     widths = [max(len(cell) for cell in column) for column in zip(*rows)] | ||||||
|  |     print() | ||||||
|  |     for row in rows: | ||||||
|  |         print('  '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) | ||||||
|  |     print() | ||||||
|  |     return outputs | ||||||
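|  |  | ||||||
|  | # Sketch: for a generator G taking (z, c), a summarized forward pass would be | ||||||
|  | #   imgs = print_module_summary(G, [z, c])   # z, c assumed to be valid inputs | ||||||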
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
9    global_torch/torch_utils/ops/__init__.py  (new file)
							| @ -0,0 +1,9 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | # empty | ||||||
99   global_torch/torch_utils/ops/bias_act.cpp  (new file)
							| @ -0,0 +1,99 @@ | |||||||
|  | // Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | // | ||||||
|  | // NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | // and proprietary rights in and to this software, related documentation | ||||||
|  | // and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | // distribution of this software and related documentation without an express | ||||||
|  | // license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  |  | ||||||
|  | #include <torch/extension.h> | ||||||
|  | #include <ATen/cuda/CUDAContext.h> | ||||||
|  | #include <c10/cuda/CUDAGuard.h> | ||||||
|  | #include "bias_act.h" | ||||||
|  |  | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  |  | ||||||
|  | static bool has_same_layout(torch::Tensor x, torch::Tensor y) | ||||||
|  | { | ||||||
|  |     if (x.dim() != y.dim()) | ||||||
|  |         return false; | ||||||
|  |     for (int64_t i = 0; i < x.dim(); i++) | ||||||
|  |     { | ||||||
|  |         if (x.size(i) != y.size(i)) | ||||||
|  |             return false; | ||||||
|  |         if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) | ||||||
|  |             return false; | ||||||
|  |     } | ||||||
|  |     return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  |  | ||||||
|  | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) | ||||||
|  | { | ||||||
|  |     // Validate arguments. | ||||||
|  |     TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); | ||||||
|  |     TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); | ||||||
|  |     TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); | ||||||
|  |     TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); | ||||||
|  |     TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same shape, dtype, and device as x"); | ||||||
|  |     TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); | ||||||
|  |     TORCH_CHECK(b.dim() == 1, "b must have rank 1"); | ||||||
|  |     TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); | ||||||
|  |     TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); | ||||||
|  |     TORCH_CHECK(grad >= 0, "grad must be non-negative"); | ||||||
|  |  | ||||||
|  |     // Validate layout. | ||||||
|  |     TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); | ||||||
|  |     TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); | ||||||
|  |     TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); | ||||||
|  |     TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); | ||||||
|  |     TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); | ||||||
|  |  | ||||||
|  |     // Create output tensor. | ||||||
|  |     const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); | ||||||
|  |     torch::Tensor y = torch::empty_like(x); | ||||||
|  |     TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); | ||||||
|  |  | ||||||
|  |     // Initialize CUDA kernel parameters. | ||||||
|  |     bias_act_kernel_params p; | ||||||
|  |     p.x     = x.data_ptr(); | ||||||
|  |     p.b     = (b.numel()) ? b.data_ptr() : NULL; | ||||||
|  |     p.xref  = (xref.numel()) ? xref.data_ptr() : NULL; | ||||||
|  |     p.yref  = (yref.numel()) ? yref.data_ptr() : NULL; | ||||||
|  |     p.dy    = (dy.numel()) ? dy.data_ptr() : NULL; | ||||||
|  |     p.y     = y.data_ptr(); | ||||||
|  |     p.grad  = grad; | ||||||
|  |     p.act   = act; | ||||||
|  |     p.alpha = alpha; | ||||||
|  |     p.gain  = gain; | ||||||
|  |     p.clamp = clamp; | ||||||
|  |     p.sizeX = (int)x.numel(); | ||||||
|  |     p.sizeB = (int)b.numel(); | ||||||
|  |     p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; | ||||||
|  |  | ||||||
|  |     // Choose CUDA kernel. | ||||||
|  |     void* kernel; | ||||||
|  |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bias_act_cuda", [&] | ||||||
|  |     { | ||||||
|  |         kernel = choose_bias_act_kernel<scalar_t>(p); | ||||||
|  |     }); | ||||||
|  |     TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); | ||||||
|  |  | ||||||
|  |     // Launch CUDA kernel. | ||||||
|  |     p.loopX = 4; | ||||||
|  |     int blockSize = 4 * 32; | ||||||
|  |     int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; // ceil(sizeX / (loopX * blockSize)) | ||||||
|  |     void* args[] = {&p}; | ||||||
|  |     AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); | ||||||
|  |     return y; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  |  | ||||||
|  | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) | ||||||
|  | { | ||||||
|  |     m.def("bias_act", &bias_act); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | //------------------------------------------------------------------------ | ||||||
173  global_torch/torch_utils/ops/bias_act.cu  (new file)
							| @ -0,0 +1,173 @@ | |||||||
|  | // Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | // | ||||||
|  | // NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | // and proprietary rights in and to this software, related documentation | ||||||
|  | // and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | // distribution of this software and related documentation without an express | ||||||
|  | // license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | #include <c10/util/Half.h> | ||||||
|  | #include "bias_act.h" | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // Helpers. | ||||||
|  | 
 | ||||||
|  | template <class T> struct InternalType; | ||||||
|  | template <> struct InternalType<double>     { typedef double scalar_t; }; | ||||||
|  | template <> struct InternalType<float>      { typedef float  scalar_t; }; | ||||||
|  | template <> struct InternalType<c10::Half>  { typedef float  scalar_t; }; | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // CUDA kernel. | ||||||
|  | 
 | ||||||
|  | template <class T, int A> | ||||||
|  | __global__ void bias_act_kernel(bias_act_kernel_params p) | ||||||
|  | { | ||||||
|  |     typedef typename InternalType<T>::scalar_t scalar_t; | ||||||
|  |     int G                 = p.grad; | ||||||
|  |     scalar_t alpha        = (scalar_t)p.alpha; | ||||||
|  |     scalar_t gain         = (scalar_t)p.gain; | ||||||
|  |     scalar_t clamp        = (scalar_t)p.clamp; | ||||||
|  |     scalar_t one          = (scalar_t)1; | ||||||
|  |     scalar_t two          = (scalar_t)2; | ||||||
|  |     scalar_t expRange     = (scalar_t)80; | ||||||
|  |     scalar_t halfExpRange = (scalar_t)40; | ||||||
|  |     scalar_t seluScale    = (scalar_t)1.0507009873554804934193349852946; | ||||||
|  |     scalar_t seluAlpha    = (scalar_t)1.6732632423543772848170429916717; | ||||||
|  | 
 | ||||||
|  |     // Loop over elements. | ||||||
|  |     int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; | ||||||
|  |     for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) | ||||||
|  |     { | ||||||
|  |         // Load. | ||||||
|  |         scalar_t x = (scalar_t)((const T*)p.x)[xi]; | ||||||
|  |         scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; | ||||||
|  |         scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; | ||||||
|  |         scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; | ||||||
|  |         scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; | ||||||
|  |         scalar_t yy = (gain != 0) ? yref / gain : 0; | ||||||
|  |         scalar_t y = 0; | ||||||
|  | 
 | ||||||
|  |         // Apply bias. | ||||||
|  |         ((G == 0) ? x : xref) += b; | ||||||
|  | 
 | ||||||
|  |         // linear | ||||||
|  |         if (A == 1) | ||||||
|  |         { | ||||||
|  |             if (G == 0) y = x; | ||||||
|  |             if (G == 1) y = x; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // relu | ||||||
|  |         if (A == 2) | ||||||
|  |         { | ||||||
|  |             if (G == 0) y = (x > 0) ? x : 0; | ||||||
|  |             if (G == 1) y = (yy > 0) ? x : 0; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // lrelu | ||||||
|  |         if (A == 3) | ||||||
|  |         { | ||||||
|  |             if (G == 0) y = (x > 0) ? x : x * alpha; | ||||||
|  |             if (G == 1) y = (yy > 0) ? x : x * alpha; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // tanh | ||||||
|  |         if (A == 4) | ||||||
|  |         { | ||||||
|  |             if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } | ||||||
|  |             if (G == 1) y = x * (one - yy * yy); | ||||||
|  |             if (G == 2) y = x * (one - yy * yy) * (-two * yy); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // sigmoid | ||||||
|  |         if (A == 5) | ||||||
|  |         { | ||||||
|  |             if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); | ||||||
|  |             if (G == 1) y = x * yy * (one - yy); | ||||||
|  |             if (G == 2) y = x * yy * (one - yy) * (one - two * yy); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // elu | ||||||
|  |         if (A == 6) | ||||||
|  |         { | ||||||
|  |             if (G == 0) y = (x >= 0) ? x : exp(x) - one; | ||||||
|  |             if (G == 1) y = (yy >= 0) ? x : x * (yy + one); | ||||||
|  |             if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // selu | ||||||
|  |         if (A == 7) | ||||||
|  |         { | ||||||
|  |             if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); | ||||||
|  |             if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); | ||||||
|  |             if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // softplus | ||||||
|  |         if (A == 8) | ||||||
|  |         { | ||||||
|  |             if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); | ||||||
|  |             if (G == 1) y = x * (one - exp(-yy)); | ||||||
|  |             if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // swish | ||||||
|  |         if (A == 9) | ||||||
|  |         { | ||||||
|  |             if (G == 0) | ||||||
|  |                 y = (x < -expRange) ? 0 : x / (exp(-x) + one); | ||||||
|  |             else | ||||||
|  |             { | ||||||
|  |                 scalar_t c = exp(xref); | ||||||
|  |                 scalar_t d = c + one; | ||||||
|  |                 if (G == 1) | ||||||
|  |                     y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); | ||||||
|  |                 else | ||||||
|  |                     y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); | ||||||
|  |                 yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Apply gain. | ||||||
|  |         y *= gain * dy; | ||||||
|  | 
 | ||||||
|  |         // Clamp. | ||||||
|  |         if (clamp >= 0) | ||||||
|  |         { | ||||||
|  |             if (G == 0) | ||||||
|  |                 y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; | ||||||
|  |             else | ||||||
|  |                 y = (yref > -clamp & yref < clamp) ? y : 0; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Store. | ||||||
|  |         ((T*)p.y)[xi] = (T)y; | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // CUDA kernel selection. | ||||||
|  | 
 | ||||||
|  | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p) | ||||||
|  | { | ||||||
|  |     if (p.act == 1) return (void*)bias_act_kernel<T, 1>; | ||||||
|  |     if (p.act == 2) return (void*)bias_act_kernel<T, 2>; | ||||||
|  |     if (p.act == 3) return (void*)bias_act_kernel<T, 3>; | ||||||
|  |     if (p.act == 4) return (void*)bias_act_kernel<T, 4>; | ||||||
|  |     if (p.act == 5) return (void*)bias_act_kernel<T, 5>; | ||||||
|  |     if (p.act == 6) return (void*)bias_act_kernel<T, 6>; | ||||||
|  |     if (p.act == 7) return (void*)bias_act_kernel<T, 7>; | ||||||
|  |     if (p.act == 8) return (void*)bias_act_kernel<T, 8>; | ||||||
|  |     if (p.act == 9) return (void*)bias_act_kernel<T, 9>; | ||||||
|  |     return NULL; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // Template specializations. | ||||||
|  | 
 | ||||||
|  | template void* choose_bias_act_kernel<double>       (const bias_act_kernel_params& p); | ||||||
|  | template void* choose_bias_act_kernel<float>        (const bias_act_kernel_params& p); | ||||||
|  | template void* choose_bias_act_kernel<c10::Half>    (const bias_act_kernel_params& p); | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------ | ||||||
38   global_torch/torch_utils/ops/bias_act.h  (new file)
							| @ -0,0 +1,38 @@ | |||||||
|  | // Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | // | ||||||
|  | // NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | // and proprietary rights in and to this software, related documentation | ||||||
|  | // and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | // distribution of this software and related documentation without an express | ||||||
|  | // license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  |  | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // CUDA kernel parameters. | ||||||
|  |  | ||||||
|  | struct bias_act_kernel_params | ||||||
|  | { | ||||||
|  |     const void* x;      // [sizeX] | ||||||
|  |     const void* b;      // [sizeB] or NULL | ||||||
|  |     const void* xref;   // [sizeX] or NULL | ||||||
|  |     const void* yref;   // [sizeX] or NULL | ||||||
|  |     const void* dy;     // [sizeX] or NULL | ||||||
|  |     void*       y;      // [sizeX] | ||||||
|  |  | ||||||
|  |     int         grad; | ||||||
|  |     int         act; | ||||||
|  |     float       alpha; | ||||||
|  |     float       gain; | ||||||
|  |     float       clamp; | ||||||
|  |  | ||||||
|  |     int         sizeX; | ||||||
|  |     int         sizeB; | ||||||
|  |     int         stepB; | ||||||
|  |     int         loopX; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // CUDA kernel selection. | ||||||
|  |  | ||||||
|  | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p); | ||||||
|  |  | ||||||
|  | //------------------------------------------------------------------------ | ||||||
212  global_torch/torch_utils/ops/bias_act.py  (new file)
							| @ -0,0 +1,212 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | """Custom PyTorch ops for efficient bias and activation.""" | ||||||
|  | 
 | ||||||
|  | import os | ||||||
|  | import warnings | ||||||
|  | import numpy as np | ||||||
|  | import torch | ||||||
|  | import dnnlib | ||||||
|  | import traceback | ||||||
|  | 
 | ||||||
|  | from .. import custom_ops | ||||||
|  | from .. import misc | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | activation_funcs = { | ||||||
|  |     'linear':   dnnlib.EasyDict(func=lambda x, **_:         x,                                          def_alpha=0,    def_gain=1,             cuda_idx=1, ref='',  has_2nd_grad=False), | ||||||
|  |     'relu':     dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.relu(x),                def_alpha=0,    def_gain=np.sqrt(2),    cuda_idx=2, ref='y', has_2nd_grad=False), | ||||||
|  |     'lrelu':    dnnlib.EasyDict(func=lambda x, alpha, **_:  torch.nn.functional.leaky_relu(x, alpha),   def_alpha=0.2,  def_gain=np.sqrt(2),    cuda_idx=3, ref='y', has_2nd_grad=False), | ||||||
|  |     'tanh':     dnnlib.EasyDict(func=lambda x, **_:         torch.tanh(x),                              def_alpha=0,    def_gain=1,             cuda_idx=4, ref='y', has_2nd_grad=True), | ||||||
|  |     'sigmoid':  dnnlib.EasyDict(func=lambda x, **_:         torch.sigmoid(x),                           def_alpha=0,    def_gain=1,             cuda_idx=5, ref='y', has_2nd_grad=True), | ||||||
|  |     'elu':      dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.elu(x),                 def_alpha=0,    def_gain=1,             cuda_idx=6, ref='y', has_2nd_grad=True), | ||||||
|  |     'selu':     dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.selu(x),                def_alpha=0,    def_gain=1,             cuda_idx=7, ref='y', has_2nd_grad=True), | ||||||
|  |     'softplus': dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.softplus(x),            def_alpha=0,    def_gain=1,             cuda_idx=8, ref='y', has_2nd_grad=True), | ||||||
|  |     'swish':    dnnlib.EasyDict(func=lambda x, **_:         torch.sigmoid(x) * x,                       def_alpha=0,    def_gain=np.sqrt(2),    cuda_idx=9, ref='x', has_2nd_grad=True), | ||||||
|  | } | ||||||
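|  |  | ||||||
|  | # Note: cuda_idx selects the matching bias_act_kernel<T, A> template in | ||||||
|  | # bias_act.cu (A = 1 for linear through 9 for swish); ref records which tensor | ||||||
|  | # ('x' or 'y') the backward pass must keep, and def_gain is the default scale. | ||||||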
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | _inited = False | ||||||
|  | _plugin = None | ||||||
|  | _null_tensor = torch.empty([0]) | ||||||
|  | 
 | ||||||
|  | def _init(): | ||||||
|  |     global _inited, _plugin | ||||||
|  |     if not _inited: | ||||||
|  |         _inited = True | ||||||
|  |         sources = ['bias_act.cpp', 'bias_act.cu'] | ||||||
|  |         sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] | ||||||
|  |         try: | ||||||
|  |             _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) | ||||||
|  |         except: | ||||||
|  |             warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) | ||||||
|  |     return _plugin is not None | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): | ||||||
|  |     r"""Fused bias and activation function. | ||||||
|  | 
 | ||||||
|  |     Adds bias `b` to activation tensor `x`, evaluates activation function `act`, | ||||||
|  |     and scales the result by `gain`. Each of the steps is optional. In most cases, | ||||||
|  |     the fused op is considerably more efficient than performing the same calculation | ||||||
|  |     using standard PyTorch ops. It supports first and second order gradients, | ||||||
|  |     but not third order gradients. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         x:      Input activation tensor. Can be of any shape. | ||||||
|  |         b:      Bias vector, or `None` to disable. Must be a 1D tensor of the same type | ||||||
|  |                 as `x`. The shape must be known, and it must match the dimension of `x` | ||||||
|  |                 corresponding to `dim`. | ||||||
|  |         dim:    The dimension in `x` corresponding to the elements of `b`. | ||||||
|  |                 The value of `dim` is ignored if `b` is not specified. | ||||||
|  |         act:    Name of the activation function to evaluate, or `"linear"` to disable. | ||||||
|  |                 Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. | ||||||
|  |                 See `activation_funcs` for a full list. `None` is not allowed. | ||||||
|  |         alpha:  Shape parameter for the activation function, or `None` to use the default. | ||||||
|  |         gain:   Scaling factor for the output tensor, or `None` to use default. | ||||||
|  |                 See `activation_funcs` for the default scaling of each activation function. | ||||||
|  |                 If unsure, consider specifying 1. | ||||||
|  |         clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable | ||||||
|  |                 the clamping (default). | ||||||
|  |         impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         Tensor of the same shape and datatype as `x`. | ||||||
|  |     """ | ||||||
|  |     assert isinstance(x, torch.Tensor) | ||||||
|  |     assert impl in ['ref', 'cuda'] | ||||||
|  |     if impl == 'cuda' and x.device.type == 'cuda' and _init(): | ||||||
|  |         return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) | ||||||
|  |     return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) | ||||||
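|  |  | ||||||
|  | # Usage sketch (x, b assumed CUDA tensors): fused bias + leaky ReLU with its | ||||||
|  | # default sqrt(2) gain: | ||||||
|  | #   y = bias_act(x, b, act='lrelu') | ||||||
|  | #   y_ref = bias_act(x, b, act='lrelu', impl='ref')  # slow fallback, same result | ||||||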
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | @misc.profiled_function | ||||||
|  | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): | ||||||
|  |     """Slow reference implementation of `bias_act()` using standard TensorFlow ops. | ||||||
|  |     """ | ||||||
|  |     assert isinstance(x, torch.Tensor) | ||||||
|  |     assert clamp is None or clamp >= 0 | ||||||
|  |     spec = activation_funcs[act] | ||||||
|  |     alpha = float(alpha if alpha is not None else spec.def_alpha) | ||||||
|  |     gain = float(gain if gain is not None else spec.def_gain) | ||||||
|  |     clamp = float(clamp if clamp is not None else -1) | ||||||
|  | 
 | ||||||
|  |     # Add bias. | ||||||
|  |     if b is not None: | ||||||
|  |         assert isinstance(b, torch.Tensor) and b.ndim == 1 | ||||||
|  |         assert 0 <= dim < x.ndim | ||||||
|  |         assert b.shape[0] == x.shape[dim] | ||||||
|  |         x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) | ||||||
|  | 
 | ||||||
|  |     # Evaluate activation function. | ||||||
|  |     x = spec.func(x, alpha=alpha) | ||||||
|  |  | ||||||
|  |     # Scale by gain. | ||||||
|  |     if gain != 1: | ||||||
|  |         x = x * gain | ||||||
|  | 
 | ||||||
|  |     # Clamp. | ||||||
|  |     if clamp >= 0: | ||||||
|  |         x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type | ||||||
|  |     return x | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | _bias_act_cuda_cache = dict() | ||||||
|  | 
 | ||||||
|  | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): | ||||||
|  |     """Fast CUDA implementation of `bias_act()` using custom ops. | ||||||
|  |     """ | ||||||
|  |     # Parse arguments. | ||||||
|  |     assert clamp is None or clamp >= 0 | ||||||
|  |     spec = activation_funcs[act] | ||||||
|  |     alpha = float(alpha if alpha is not None else spec.def_alpha) | ||||||
|  |     gain = float(gain if gain is not None else spec.def_gain) | ||||||
|  |     clamp = float(clamp if clamp is not None else -1) | ||||||
|  | 
 | ||||||
|  |     # Lookup from cache. | ||||||
|  |     key = (dim, act, alpha, gain, clamp) | ||||||
|  |     if key in _bias_act_cuda_cache: | ||||||
|  |         return _bias_act_cuda_cache[key] | ||||||
|  | 
 | ||||||
|  |     # Forward op. | ||||||
|  |     class BiasActCuda(torch.autograd.Function): | ||||||
|  |         @staticmethod | ||||||
|  |         def forward(ctx, x, b): # pylint: disable=arguments-differ | ||||||
|  |             ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format | ||||||
|  |             x = x.contiguous(memory_format=ctx.memory_format) | ||||||
|  |             b = b.contiguous() if b is not None else _null_tensor | ||||||
|  |             y = x | ||||||
|  |             if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: | ||||||
|  |                 y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) | ||||||
|  |             ctx.save_for_backward( | ||||||
|  |                 x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, | ||||||
|  |                 b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, | ||||||
|  |                 y if 'y' in spec.ref else _null_tensor) | ||||||
|  |             return y | ||||||
|  | 
 | ||||||
|  |         @staticmethod | ||||||
|  |         def backward(ctx, dy): # pylint: disable=arguments-differ | ||||||
|  |             dy = dy.contiguous(memory_format=ctx.memory_format) | ||||||
|  |             x, b, y = ctx.saved_tensors | ||||||
|  |             dx = None | ||||||
|  |             db = None | ||||||
|  | 
 | ||||||
|  |             if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: | ||||||
|  |                 dx = dy | ||||||
|  |                 if act != 'linear' or gain != 1 or clamp >= 0: | ||||||
|  |                     dx = BiasActCudaGrad.apply(dy, x, b, y) | ||||||
|  | 
 | ||||||
|  |             if ctx.needs_input_grad[1]: | ||||||
|  |                 db = dx.sum([i for i in range(dx.ndim) if i != dim]) | ||||||
|  | 
 | ||||||
|  |             return dx, db | ||||||
|  | 
 | ||||||
|  |     # Backward op. | ||||||
|  |     class BiasActCudaGrad(torch.autograd.Function): | ||||||
|  |         @staticmethod | ||||||
|  |         def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ | ||||||
|  |             ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format | ||||||
|  |             dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) | ||||||
|  |             ctx.save_for_backward( | ||||||
|  |                 dy if spec.has_2nd_grad else _null_tensor, | ||||||
|  |                 x, b, y) | ||||||
|  |             return dx | ||||||
|  | 
 | ||||||
|  |         @staticmethod | ||||||
|  |         def backward(ctx, d_dx): # pylint: disable=arguments-differ | ||||||
|  |             d_dx = d_dx.contiguous(memory_format=ctx.memory_format) | ||||||
|  |             dy, x, b, y = ctx.saved_tensors | ||||||
|  |             d_dy = None | ||||||
|  |             d_x = None | ||||||
|  |             d_b = None | ||||||
|  |             d_y = None | ||||||
|  | 
 | ||||||
|  |             if ctx.needs_input_grad[0]: | ||||||
|  |                 d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) | ||||||
|  | 
 | ||||||
|  |             if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): | ||||||
|  |                 d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) | ||||||
|  | 
 | ||||||
|  |             if spec.has_2nd_grad and ctx.needs_input_grad[2]: | ||||||
|  |                 d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) | ||||||
|  | 
 | ||||||
|  |             return d_dy, d_x, d_b, d_y | ||||||
|  | 
 | ||||||
|  |     # Add to cache. | ||||||
|  |     _bias_act_cuda_cache[key] = BiasActCuda | ||||||
|  |     return BiasActCuda | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
							
								
								
									
170  global_torch/torch_utils/ops/conv2d_gradfix.py  Normal file
							| @ -0,0 +1,170 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | """Custom replacement for `torch.nn.functional.conv2d` that supports | ||||||
|  | arbitrarily high order gradients with zero performance penalty.""" | ||||||
|  | 
 | ||||||
|  | import warnings | ||||||
|  | import contextlib | ||||||
|  | import torch | ||||||
|  | 
 | ||||||
|  | # pylint: disable=redefined-builtin | ||||||
|  | # pylint: disable=arguments-differ | ||||||
|  | # pylint: disable=protected-access | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | enabled = False                     # Enable the custom op by setting this to true. | ||||||
|  | weight_gradients_disabled = False   # Forcefully disable computation of gradients with respect to the weights. | ||||||
|  | 
 | ||||||
|  | @contextlib.contextmanager | ||||||
|  | def no_weight_gradients(): | ||||||
|  |     global weight_gradients_disabled | ||||||
|  |     old = weight_gradients_disabled | ||||||
|  |     weight_gradients_disabled = True | ||||||
|  |     yield | ||||||
|  |     weight_gradients_disabled = old | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): | ||||||
|  |     if _should_use_custom_op(input): | ||||||
|  |         return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) | ||||||
|  |     return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) | ||||||
|  | 
 | ||||||
|  | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): | ||||||
|  |     if _should_use_custom_op(input): | ||||||
|  |         return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) | ||||||
|  |     return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
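The point of these wrappers is training objectives that differentiate through gradients, e.g. R1-style penalties. A hedged usage sketch, assuming the `conv2d` wrapper above is in scope; with `enabled = False` it simply exercises the `torch.nn.functional` fallback, which also supports double backward on recent PyTorch:

import torch

x = torch.randn(2, 3, 8, 8, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)
y = conv2d(x, w, padding=1)
(gx,) = torch.autograd.grad(y.sum(), x, create_graph=True)  # keep the graph
gx.square().sum().backward()                                # second-order grads reach w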
|  | def _should_use_custom_op(input): | ||||||
|  |     assert isinstance(input, torch.Tensor) | ||||||
|  |     if (not enabled) or (not torch.backends.cudnn.enabled): | ||||||
|  |         return False | ||||||
|  |     if input.device.type != 'cuda': | ||||||
|  |         return False | ||||||
|  |     if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): | ||||||
|  |         return True | ||||||
|  |     warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') | ||||||
|  |     return False | ||||||
|  | 
 | ||||||
|  | def _tuple_of_ints(xs, ndim): | ||||||
|  |     xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim | ||||||
|  |     assert len(xs) == ndim | ||||||
|  |     assert all(isinstance(x, int) for x in xs) | ||||||
|  |     return xs | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | _conv2d_gradfix_cache = dict() | ||||||
|  | 
 | ||||||
|  | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): | ||||||
|  |     # Parse arguments. | ||||||
|  |     ndim = 2 | ||||||
|  |     weight_shape = tuple(weight_shape) | ||||||
|  |     stride = _tuple_of_ints(stride, ndim) | ||||||
|  |     padding = _tuple_of_ints(padding, ndim) | ||||||
|  |     output_padding = _tuple_of_ints(output_padding, ndim) | ||||||
|  |     dilation = _tuple_of_ints(dilation, ndim) | ||||||
|  | 
 | ||||||
|  |     # Lookup from cache. | ||||||
|  |     key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) | ||||||
|  |     if key in _conv2d_gradfix_cache: | ||||||
|  |         return _conv2d_gradfix_cache[key] | ||||||
|  | 
 | ||||||
|  |     # Validate arguments. | ||||||
|  |     assert groups >= 1 | ||||||
|  |     assert len(weight_shape) == ndim + 2 | ||||||
|  |     assert all(stride[i] >= 1 for i in range(ndim)) | ||||||
|  |     assert all(padding[i] >= 0 for i in range(ndim)) | ||||||
|  |     assert all(dilation[i] >= 0 for i in range(ndim)) | ||||||
|  |     if not transpose: | ||||||
|  |         assert all(output_padding[i] == 0 for i in range(ndim)) | ||||||
|  |     else: # transpose | ||||||
|  |         assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) | ||||||
|  | 
 | ||||||
|  |     # Helpers. | ||||||
|  |     common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) | ||||||
|  |     def calc_output_padding(input_shape, output_shape): | ||||||
|  |         if transpose: | ||||||
|  |             return [0, 0] | ||||||
|  |         return [ | ||||||
|  |             input_shape[i + 2] | ||||||
|  |             - (output_shape[i + 2] - 1) * stride[i] | ||||||
|  |             - (1 - 2 * padding[i]) | ||||||
|  |             - dilation[i] * (weight_shape[i + 2] - 1) | ||||||
|  |             for i in range(ndim) | ||||||
|  |         ] | ||||||
|  | 
 | ||||||
|  |     # Forward & backward. | ||||||
|  |     class Conv2d(torch.autograd.Function): | ||||||
|  |         @staticmethod | ||||||
|  |         def forward(ctx, input, weight, bias): | ||||||
|  |             assert weight.shape == weight_shape | ||||||
|  |             if not transpose: | ||||||
|  |                 output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) | ||||||
|  |             else: # transpose | ||||||
|  |                 output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) | ||||||
|  |             ctx.save_for_backward(input, weight) | ||||||
|  |             return output | ||||||
|  | 
 | ||||||
|  |         @staticmethod | ||||||
|  |         def backward(ctx, grad_output): | ||||||
|  |             input, weight = ctx.saved_tensors | ||||||
|  |             grad_input = None | ||||||
|  |             grad_weight = None | ||||||
|  |             grad_bias = None | ||||||
|  | 
 | ||||||
|  |             if ctx.needs_input_grad[0]: | ||||||
|  |                 p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) | ||||||
|  |                 grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) | ||||||
|  |                 assert grad_input.shape == input.shape | ||||||
|  | 
 | ||||||
|  |             if ctx.needs_input_grad[1] and not weight_gradients_disabled: | ||||||
|  |                 grad_weight = Conv2dGradWeight.apply(grad_output, input) | ||||||
|  |                 assert grad_weight.shape == weight_shape | ||||||
|  | 
 | ||||||
|  |             if ctx.needs_input_grad[2]: | ||||||
|  |                 grad_bias = grad_output.sum([0, 2, 3]) | ||||||
|  | 
 | ||||||
|  |             return grad_input, grad_weight, grad_bias | ||||||
|  | 
 | ||||||
|  |     # Gradient with respect to the weights. | ||||||
|  |     class Conv2dGradWeight(torch.autograd.Function): | ||||||
|  |         @staticmethod | ||||||
|  |         def forward(ctx, grad_output, input): | ||||||
|  |             op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') | ||||||
|  |             flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] | ||||||
|  |             grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) | ||||||
|  |             assert grad_weight.shape == weight_shape | ||||||
|  |             ctx.save_for_backward(grad_output, input) | ||||||
|  |             return grad_weight | ||||||
|  | 
 | ||||||
|  |         @staticmethod | ||||||
|  |         def backward(ctx, grad2_grad_weight): | ||||||
|  |             grad_output, input = ctx.saved_tensors | ||||||
|  |             grad2_grad_output = None | ||||||
|  |             grad2_input = None | ||||||
|  | 
 | ||||||
|  |             if ctx.needs_input_grad[0]: | ||||||
|  |                 grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) | ||||||
|  |                 assert grad2_grad_output.shape == grad_output.shape | ||||||
|  | 
 | ||||||
|  |             if ctx.needs_input_grad[1]: | ||||||
|  |                 p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) | ||||||
|  |                 grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) | ||||||
|  |                 assert grad2_input.shape == input.shape | ||||||
|  | 
 | ||||||
|  |             return grad2_grad_output, grad2_input | ||||||
|  | 
 | ||||||
|  |     _conv2d_gradfix_cache[key] = Conv2d | ||||||
|  |     return Conv2d | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
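`calc_output_padding` resolves the shape ambiguity of strided convolutions in the backward pass. Worked instance with illustrative numbers: a forward conv with in=8, k=3, stride=2, pad=1 gives out = (8 + 2*1 - 3)//2 + 1 = 4; the transposed conv used for the input gradient yields (4-1)*2 - 2*1 + 3 + output_padding, and the formula above gives output_padding = 8 - (4-1)*2 - (1 - 2*1) - 1*(3-1) = 1, restoring the original size of 8. A quick shape check via plain autograd:

import torch

x = torch.randn(1, 1, 8, 8, requires_grad=True)
w = torch.randn(1, 1, 3, 3)
y = torch.nn.functional.conv2d(x, w, stride=2, padding=1)   # 8x8 -> 4x4
(g,) = torch.autograd.grad(y.sum(), x)                      # gradient is 8x8 again
assert y.shape[-1] == 4 and g.shape[-1] == 8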
							
								
								
									
156  global_torch/torch_utils/ops/conv2d_resample.py  Normal file
							| @ -0,0 +1,156 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | """2D convolution with optional up/downsampling.""" | ||||||
|  | 
 | ||||||
|  | import torch | ||||||
|  | 
 | ||||||
|  | from .. import misc | ||||||
|  | from . import conv2d_gradfix | ||||||
|  | from . import upfirdn2d | ||||||
|  | from .upfirdn2d import _parse_padding | ||||||
|  | from .upfirdn2d import _get_filter_size | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def _get_weight_shape(w): | ||||||
|  |     with misc.suppress_tracer_warnings(): # this value will be treated as a constant | ||||||
|  |         shape = [int(sz) for sz in w.shape] | ||||||
|  |     misc.assert_shape(w, shape) | ||||||
|  |     return shape | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): | ||||||
|  |     """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. | ||||||
|  |     """ | ||||||
|  |     out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) | ||||||
|  | 
 | ||||||
|  |     # Flip weight if requested. | ||||||
|  |     if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). | ||||||
|  |         w = w.flip([2, 3]) | ||||||
|  | 
 | ||||||
|  |     # Work around a performance pitfall in cuDNN 8.0.5, triggered when using a | ||||||
|  |     # 1x1 kernel + memory_format=channels_last + fewer than 64 channels. | ||||||
|  |     if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: | ||||||
|  |         if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: | ||||||
|  |             if out_channels <= 4 and groups == 1: | ||||||
|  |                 in_shape = x.shape | ||||||
|  |                 x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) | ||||||
|  |                 x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) | ||||||
|  |             else: | ||||||
|  |                 x = x.to(memory_format=torch.contiguous_format) | ||||||
|  |                 w = w.to(memory_format=torch.contiguous_format) | ||||||
|  |                 x = conv2d_gradfix.conv2d(x, w, groups=groups) | ||||||
|  |             return x.to(memory_format=torch.channels_last) | ||||||
|  | 
 | ||||||
|  |     # Otherwise => execute using conv2d_gradfix. | ||||||
|  |     op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d | ||||||
|  |     return op(x, w, stride=stride, padding=padding, groups=groups) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
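The cuDNN workaround above leans on the identity that a 1x1 convolution is a per-pixel matrix multiply. A quick check of that identity with illustrative shapes:

import torch

x = torch.randn(2, 16, 8, 8)
w = torch.randn(4, 16, 1, 1)
ref = torch.nn.functional.conv2d(x, w)
alt = (w.squeeze(3).squeeze(2) @ x.reshape(2, 16, -1)).reshape(2, 4, 8, 8)
assert torch.allclose(ref, alt, atol=1e-5)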
|  | @misc.profiled_function | ||||||
|  | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): | ||||||
|  |     r"""2D convolution with optional up/downsampling. | ||||||
|  | 
 | ||||||
|  |     Padding is performed only once at the beginning, not between the operations. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         x:              Input tensor of shape | ||||||
|  |                         `[batch_size, in_channels, in_height, in_width]`. | ||||||
|  |         w:              Weight tensor of shape | ||||||
|  |                         `[out_channels, in_channels//groups, kernel_height, kernel_width]`. | ||||||
|  |         f:              Low-pass filter for up/downsampling. Must be prepared beforehand by | ||||||
|  |                         calling upfirdn2d.setup_filter(). None = identity (default). | ||||||
|  |         up:             Integer upsampling factor (default: 1). | ||||||
|  |         down:           Integer downsampling factor (default: 1). | ||||||
|  |         padding:        Padding with respect to the upsampled image. Can be a single number | ||||||
|  |                         or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` | ||||||
|  |                         (default: 0). | ||||||
|  |         groups:         Split input channels into N groups (default: 1). | ||||||
|  |         flip_weight:    False = convolution, True = correlation (default: True). | ||||||
|  |         flip_filter:    False = convolution, True = correlation (default: False). | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. | ||||||
|  |     """ | ||||||
|  |     # Validate arguments. | ||||||
|  |     assert isinstance(x, torch.Tensor) and (x.ndim == 4) | ||||||
|  |     assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) | ||||||
|  |     assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) | ||||||
|  |     assert isinstance(up, int) and (up >= 1) | ||||||
|  |     assert isinstance(down, int) and (down >= 1) | ||||||
|  |     assert isinstance(groups, int) and (groups >= 1) | ||||||
|  |     out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) | ||||||
|  |     fw, fh = _get_filter_size(f) | ||||||
|  |     px0, px1, py0, py1 = _parse_padding(padding) | ||||||
|  | 
 | ||||||
|  |     # Adjust padding to account for up/downsampling. | ||||||
|  |     if up > 1: | ||||||
|  |         px0 += (fw + up - 1) // 2 | ||||||
|  |         px1 += (fw - up) // 2 | ||||||
|  |         py0 += (fh + up - 1) // 2 | ||||||
|  |         py1 += (fh - up) // 2 | ||||||
|  |     if down > 1: | ||||||
|  |         px0 += (fw - down + 1) // 2 | ||||||
|  |         px1 += (fw - down) // 2 | ||||||
|  |         py0 += (fh - down + 1) // 2 | ||||||
|  |         py1 += (fh - down) // 2 | ||||||
|  | 
 | ||||||
|  |     # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. | ||||||
|  |     if kw == 1 and kh == 1 and (down > 1 and up == 1): | ||||||
|  |         x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) | ||||||
|  |         x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  |     # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. | ||||||
|  |     if kw == 1 and kh == 1 and (up > 1 and down == 1): | ||||||
|  |         x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) | ||||||
|  |         x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  |     # Fast path: downsampling only => use strided convolution. | ||||||
|  |     if down > 1 and up == 1: | ||||||
|  |         x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) | ||||||
|  |         x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  |     # Fast path: upsampling with optional downsampling => use transpose strided convolution. | ||||||
|  |     if up > 1: | ||||||
|  |         if groups == 1: | ||||||
|  |             w = w.transpose(0, 1) | ||||||
|  |         else: | ||||||
|  |             w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) | ||||||
|  |             w = w.transpose(1, 2) | ||||||
|  |             w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) | ||||||
|  |         px0 -= kw - 1 | ||||||
|  |         px1 -= kw - up | ||||||
|  |         py0 -= kh - 1 | ||||||
|  |         py1 -= kh - up | ||||||
|  |         pxt = max(min(-px0, -px1), 0) | ||||||
|  |         pyt = max(min(-py0, -py1), 0) | ||||||
|  |         x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) | ||||||
|  |         x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) | ||||||
|  |         if down > 1: | ||||||
|  |             x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  |     # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. | ||||||
|  |     if up == 1 and down == 1: | ||||||
|  |         if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: | ||||||
|  |             return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) | ||||||
|  | 
 | ||||||
|  |     # Fallback: Generic reference implementation. | ||||||
|  |     x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) | ||||||
|  |     x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) | ||||||
|  |     if down > 1: | ||||||
|  |         x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) | ||||||
|  |     return x | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
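A hedged usage sketch of `conv2d_resample()`, assuming the sibling `upfirdn2d` module for filter setup: a 3x3 convolution fused with 2x upsampling. Since `padding=1` is taken with respect to the upsampled grid, a 32x32 input should come out as 64x64:

import torch

f = upfirdn2d.setup_filter([1, 3, 3, 1])        # low-pass filter, prepared as required above
x = torch.randn(1, 16, 32, 32)
w = torch.randn(32, 16, 3, 3)
y = conv2d_resample(x, w, f=f, up=2, padding=1)
assert y.shape == (1, 32, 64, 64)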
							
								
								
									
60  global_torch/torch_utils/ops/fma.py  Normal file
							| @ -0,0 +1,60 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" | ||||||
|  | 
 | ||||||
|  | import torch | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def fma(a, b, c): # => a * b + c | ||||||
|  |     return _FusedMultiplyAdd.apply(a, b, c) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c | ||||||
|  |     @staticmethod | ||||||
|  |     def forward(ctx, a, b, c): # pylint: disable=arguments-differ | ||||||
|  |         out = torch.addcmul(c, a, b) | ||||||
|  |         ctx.save_for_backward(a, b) | ||||||
|  |         ctx.c_shape = c.shape | ||||||
|  |         return out | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def backward(ctx, dout): # pylint: disable=arguments-differ | ||||||
|  |         a, b = ctx.saved_tensors | ||||||
|  |         c_shape = ctx.c_shape | ||||||
|  |         da = None | ||||||
|  |         db = None | ||||||
|  |         dc = None | ||||||
|  | 
 | ||||||
|  |         if ctx.needs_input_grad[0]: | ||||||
|  |             da = _unbroadcast(dout * b, a.shape) | ||||||
|  | 
 | ||||||
|  |         if ctx.needs_input_grad[1]: | ||||||
|  |             db = _unbroadcast(dout * a, b.shape) | ||||||
|  | 
 | ||||||
|  |         if ctx.needs_input_grad[2]: | ||||||
|  |             dc = _unbroadcast(dout, c_shape) | ||||||
|  | 
 | ||||||
|  |         return da, db, dc | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def _unbroadcast(x, shape): | ||||||
|  |     extra_dims = x.ndim - len(shape) | ||||||
|  |     assert extra_dims >= 0 | ||||||
|  |     dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] | ||||||
|  |     if len(dim): | ||||||
|  |         x = x.sum(dim=dim, keepdim=True) | ||||||
|  |     if extra_dims: | ||||||
|  |         x = x.reshape(-1, *x.shape[extra_dims+1:]) | ||||||
|  |     assert x.shape == shape | ||||||
|  |     return x | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
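A quick sanity check of the contract above, including the broadcasting that `_unbroadcast` undoes in the backward pass (a sketch, with `fma` as defined in this file):

import torch

a = torch.randn(4, 1, requires_grad=True)
b = torch.randn(1, 3, requires_grad=True)
c = torch.randn(4, 3, requires_grad=True)
out = fma(a, b, c)
assert torch.allclose(out, a * b + c)
out.sum().backward()
assert a.grad.shape == a.shape and b.grad.shape == b.shape  # grads un-broadcast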
							
								
								
									
34  global_torch/torch_utils/ops/fused_act.py  Normal file
							| @ -0,0 +1,34 @@ | |||||||
|  | import torch | ||||||
|  | from torch import nn | ||||||
|  | from torch.nn import functional as F | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class FusedLeakyReLU(nn.Module): | ||||||
|  |     def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.bias = nn.Parameter(torch.zeros(channel)) | ||||||
|  |         self.negative_slope = negative_slope | ||||||
|  |         self.scale = scale | ||||||
|  | 
 | ||||||
|  |     def forward(self, input): | ||||||
|  |         return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): | ||||||
|  |     rest_dim = [1] * (input.ndim - bias.ndim - 1) | ||||||
|  |     input = input.cuda() | ||||||
|  |     return ( | ||||||
|  |         F.leaky_relu( | ||||||
|  |             input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope | ||||||
|  |         ) | ||||||
|  |         * scale | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
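Numerically the function above is just `scale * leaky_relu(x + bias)`, where the default `scale = sqrt(2)` roughly restores unit variance after the nonlinearity. A CPU-side sketch of the same math (the version above additionally moves its input to CUDA):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4, 4)
bias = torch.zeros(8)
# matches fused_leaky_relu(x, bias) up to device placement
ref = F.leaky_relu(x + bias.view(1, -1, 1, 1), negative_slope=0.2) * (2 ** 0.5)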
							
								
								
									
83  global_torch/torch_utils/ops/grid_sample_gradfix.py  Normal file
							| @ -0,0 +1,83 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | """Custom replacement for `torch.nn.functional.grid_sample` that | ||||||
|  | supports arbitrarily high order gradients between the input and output. | ||||||
|  | Only works on 2D images and assumes | ||||||
|  | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" | ||||||
|  | 
 | ||||||
|  | import warnings | ||||||
|  | import torch | ||||||
|  | 
 | ||||||
|  | # pylint: disable=redefined-builtin | ||||||
|  | # pylint: disable=arguments-differ | ||||||
|  | # pylint: disable=protected-access | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | enabled = False  # Enable the custom op by setting this to true. | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def grid_sample(input, grid): | ||||||
|  |     if _should_use_custom_op(): | ||||||
|  |         return _GridSample2dForward.apply(input, grid) | ||||||
|  |     return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def _should_use_custom_op(): | ||||||
|  |     if not enabled: | ||||||
|  |         return False | ||||||
|  |     if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): | ||||||
|  |         return True | ||||||
|  |     warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().') | ||||||
|  |     return False | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | class _GridSample2dForward(torch.autograd.Function): | ||||||
|  |     @staticmethod | ||||||
|  |     def forward(ctx, input, grid): | ||||||
|  |         assert input.ndim == 4 | ||||||
|  |         assert grid.ndim == 4 | ||||||
|  |         output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) | ||||||
|  |         ctx.save_for_backward(input, grid) | ||||||
|  |         return output | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def backward(ctx, grad_output): | ||||||
|  |         input, grid = ctx.saved_tensors | ||||||
|  |         grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) | ||||||
|  |         return grad_input, grad_grid | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | class _GridSample2dBackward(torch.autograd.Function): | ||||||
|  |     @staticmethod | ||||||
|  |     def forward(ctx, grad_output, input, grid): | ||||||
|  |         op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') | ||||||
|  |         grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) | ||||||
|  |         ctx.save_for_backward(grid) | ||||||
|  |         return grad_input, grad_grid | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def backward(ctx, grad2_grad_input, grad2_grad_grid): | ||||||
|  |         _ = grad2_grad_grid # unused | ||||||
|  |         grid, = ctx.saved_tensors | ||||||
|  |         grad2_grad_output = None | ||||||
|  |         grad2_input = None | ||||||
|  |         grad2_grid = None | ||||||
|  | 
 | ||||||
|  |         if ctx.needs_input_grad[0]: | ||||||
|  |             grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) | ||||||
|  | 
 | ||||||
|  |         assert not ctx.needs_input_grad[2] | ||||||
|  |         return grad2_grad_output, grad2_input, grad2_grid | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
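A hedged sketch of why the wrapper exists: differentiating twice through sampling, as gradient-based penalty terms require. With `enabled = False` (the default above) this runs via the plain PyTorch fallback:

import torch

inp = torch.randn(1, 3, 8, 8, requires_grad=True)
theta = torch.eye(2, 3).unsqueeze(0)
g = torch.nn.functional.affine_grid(theta, [1, 3, 8, 8], align_corners=False)
out = grid_sample(inp, g)                       # the wrapper defined above
(gi,) = torch.autograd.grad(out.square().sum(), inp, create_graph=True)
gi.square().sum().backward()                    # second-order grads reach inp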
							
								
								
									
103  global_torch/torch_utils/ops/upfirdn2d.cpp  Normal file
							| @ -0,0 +1,103 @@ | |||||||
|  | // Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
 | ||||||
|  | //
 | ||||||
|  | // NVIDIA CORPORATION and its licensors retain all intellectual property
 | ||||||
|  | // and proprietary rights in and to this software, related documentation
 | ||||||
|  | // and any modifications thereto.  Any use, reproduction, disclosure or
 | ||||||
|  | // distribution of this software and related documentation without an express
 | ||||||
|  | // license agreement from NVIDIA CORPORATION is strictly prohibited.
 | ||||||
|  | 
 | ||||||
|  | #include <torch/extension.h> | ||||||
|  | #include <ATen/cuda/CUDAContext.h> | ||||||
|  | #include <c10/cuda/CUDAGuard.h> | ||||||
|  | #include "upfirdn2d.h" | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------
 | ||||||
|  | 
 | ||||||
|  | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) | ||||||
|  | { | ||||||
|  |     // Validate arguments.
 | ||||||
|  |     TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); | ||||||
|  |     TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); | ||||||
|  |     TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); | ||||||
|  |     TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); | ||||||
|  |     TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); | ||||||
|  |     TORCH_CHECK(x.dim() == 4, "x must be rank 4"); | ||||||
|  |     TORCH_CHECK(f.dim() == 2, "f must be rank 2"); | ||||||
|  |     TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); | ||||||
|  |     TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); | ||||||
|  |     TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); | ||||||
|  | 
 | ||||||
|  |     // Create output tensor.
 | ||||||
|  |     const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); | ||||||
|  |     int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; | ||||||
|  |     int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; | ||||||
|  |     TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); | ||||||
|  |     torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); | ||||||
|  |     TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); | ||||||
|  | 
 | ||||||
|  |     // Initialize CUDA kernel parameters.
 | ||||||
|  |     upfirdn2d_kernel_params p; | ||||||
|  |     p.x             = x.data_ptr(); | ||||||
|  |     p.f             = f.data_ptr<float>(); | ||||||
|  |     p.y             = y.data_ptr(); | ||||||
|  |     p.up            = make_int2(upx, upy); | ||||||
|  |     p.down          = make_int2(downx, downy); | ||||||
|  |     p.pad0          = make_int2(padx0, pady0); | ||||||
|  |     p.flip          = (flip) ? 1 : 0; | ||||||
|  |     p.gain          = gain; | ||||||
|  |     p.inSize        = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); | ||||||
|  |     p.inStride      = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); | ||||||
|  |     p.filterSize    = make_int2((int)f.size(1), (int)f.size(0)); | ||||||
|  |     p.filterStride  = make_int2((int)f.stride(1), (int)f.stride(0)); | ||||||
|  |     p.outSize       = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); | ||||||
|  |     p.outStride     = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); | ||||||
|  |     p.sizeMajor     = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; | ||||||
|  |     p.sizeMinor     = (p.inStride.z == 1) ? p.inSize.z : 1; | ||||||
|  | 
 | ||||||
|  |     // Choose CUDA kernel.
 | ||||||
|  |     upfirdn2d_kernel_spec spec; | ||||||
|  |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] | ||||||
|  |     { | ||||||
|  |         spec = choose_upfirdn2d_kernel<scalar_t>(p); | ||||||
|  |     }); | ||||||
|  | 
 | ||||||
|  |     // Set looping options.
 | ||||||
|  |     p.loopMajor     = (p.sizeMajor - 1) / 16384 + 1; | ||||||
|  |     p.loopMinor     = spec.loopMinor; | ||||||
|  |     p.loopX         = spec.loopX; | ||||||
|  |     p.launchMinor   = (p.sizeMinor - 1) / p.loopMinor + 1; | ||||||
|  |     p.launchMajor   = (p.sizeMajor - 1) / p.loopMajor + 1; | ||||||
|  | 
 | ||||||
|  |     // Compute grid size.
 | ||||||
|  |     dim3 blockSize, gridSize; | ||||||
|  |     if (spec.tileOutW < 0) // large
 | ||||||
|  |     { | ||||||
|  |         blockSize = dim3(4, 32, 1); | ||||||
|  |         gridSize = dim3( | ||||||
|  |             ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, | ||||||
|  |             (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, | ||||||
|  |             p.launchMajor); | ||||||
|  |     } | ||||||
|  |     else // small
 | ||||||
|  |     { | ||||||
|  |         blockSize = dim3(256, 1, 1); | ||||||
|  |         gridSize = dim3( | ||||||
|  |             ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, | ||||||
|  |             (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, | ||||||
|  |             p.launchMajor); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Launch CUDA kernel.
 | ||||||
|  |     void* args[] = {&p}; | ||||||
|  |     AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); | ||||||
|  |     return y; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------
 | ||||||
|  | 
 | ||||||
|  | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) | ||||||
|  | { | ||||||
|  |     m.def("upfirdn2d", &upfirdn2d); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------
 | ||||||
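The output-size rule in the validation above is the standard upfirdn arithmetic: upsample by `upx`, pad, filter, then keep every `downx`-th sample. A worked instance with illustrative numbers:

# outW = (inW*upx + padx0 + padx1 - filterW + downx) // downx
inW, upx, padx0, padx1, fw, downx = 32, 2, 2, 1, 4, 1
assert (inW * upx + padx0 + padx1 - fw + downx) // downx == 64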
							
								
								
									
350  global_torch/torch_utils/ops/upfirdn2d.cu  Normal file
							| @ -0,0 +1,350 @@ | |||||||
|  | // Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | // | ||||||
|  | // NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | // and proprietary rights in and to this software, related documentation | ||||||
|  | // and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | // distribution of this software and related documentation without an express | ||||||
|  | // license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | #include <c10/util/Half.h> | ||||||
|  | #include "upfirdn2d.h" | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // Helpers. | ||||||
|  | 
 | ||||||
|  | template <class T> struct InternalType; | ||||||
|  | template <> struct InternalType<double>     { typedef double scalar_t; }; | ||||||
|  | template <> struct InternalType<float>      { typedef float  scalar_t; }; | ||||||
|  | template <> struct InternalType<c10::Half>  { typedef float  scalar_t; }; | ||||||
|  | 
 | ||||||
|  | static __device__ __forceinline__ int floor_div(int a, int b) | ||||||
|  | { | ||||||
|  |     int t = 1 - a / b; | ||||||
|  |     return (a + t * b) / b - t; | ||||||
|  | } | ||||||
|  | 
 | ||||||
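`floor_div` compensates for C integer division truncating toward zero; the receptive-field math below needs flooring toward minus infinity for negative coordinates. Worked instance: for a = -3, b = 2, plain C gives -3/2 == -1, whereas t = 1 - (-1) = 2 and (a + t*b)/b - t = (-3 + 4)/2 - 2 = -2, i.e. Python's floor division:

assert -3 // 2 == -2   # the floored result floor_div reproduces in CUDA C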
|  | //------------------------------------------------------------------------ | ||||||
|  | // Generic CUDA implementation for large filters. | ||||||
|  | 
 | ||||||
|  | template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) | ||||||
|  | { | ||||||
|  |     typedef typename InternalType<T>::scalar_t scalar_t; | ||||||
|  | 
 | ||||||
|  |     // Calculate thread index. | ||||||
|  |     int minorBase = blockIdx.x * blockDim.x + threadIdx.x; | ||||||
|  |     int outY = minorBase / p.launchMinor; | ||||||
|  |     minorBase -= outY * p.launchMinor; | ||||||
|  |     int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; | ||||||
|  |     int majorBase = blockIdx.z * p.loopMajor; | ||||||
|  |     if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) | ||||||
|  |         return; | ||||||
|  | 
 | ||||||
|  |     // Setup Y receptive field. | ||||||
|  |     int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; | ||||||
|  |     int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); | ||||||
|  |     int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; | ||||||
|  |     int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; | ||||||
|  |     if (p.flip) | ||||||
|  |         filterY = p.filterSize.y - 1 - filterY; | ||||||
|  | 
 | ||||||
|  |     // Loop over major, minor, and X. | ||||||
|  |     for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) | ||||||
|  |     for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor) | ||||||
|  |     { | ||||||
|  |         int nc = major * p.sizeMinor + minor; | ||||||
|  |         int n = nc / p.inSize.z; | ||||||
|  |         int c = nc - n * p.inSize.z; | ||||||
|  |         for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y) | ||||||
|  |         { | ||||||
|  |             // Setup X receptive field. | ||||||
|  |             int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; | ||||||
|  |             int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); | ||||||
|  |             int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX; | ||||||
|  |             int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; | ||||||
|  |             if (p.flip) | ||||||
|  |                 filterX = p.filterSize.x - 1 - filterX; | ||||||
|  | 
 | ||||||
|  |             // Initialize pointers. | ||||||
|  |             const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; | ||||||
|  |             const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; | ||||||
|  |             int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; | ||||||
|  |             int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; | ||||||
|  | 
 | ||||||
|  |             // Inner loop. | ||||||
|  |             scalar_t v = 0; | ||||||
|  |             for (int y = 0; y < h; y++) | ||||||
|  |             { | ||||||
|  |                 for (int x = 0; x < w; x++) | ||||||
|  |                 { | ||||||
|  |                     v += (scalar_t)(*xp) * (scalar_t)(*fp); | ||||||
|  |                     xp += p.inStride.x; | ||||||
|  |                     fp += filterStepX; | ||||||
|  |                 } | ||||||
|  |                 xp += p.inStride.y - w * p.inStride.x; | ||||||
|  |                 fp += filterStepY - w * filterStepX; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // Store result. | ||||||
|  |             v *= p.gain; | ||||||
|  |             ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // Specialized CUDA implementation for small filters. | ||||||
|  | 
 | ||||||
|  | template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor> | ||||||
|  | static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) | ||||||
|  | { | ||||||
|  |     typedef typename InternalType<T>::scalar_t scalar_t; | ||||||
|  |     const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; | ||||||
|  |     const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; | ||||||
|  |     __shared__ volatile scalar_t sf[filterH][filterW]; | ||||||
|  |     __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; | ||||||
|  | 
 | ||||||
|  |     // Calculate tile index. | ||||||
|  |     int minorBase = blockIdx.x; | ||||||
|  |     int tileOutY = minorBase / p.launchMinor; | ||||||
|  |     minorBase -= tileOutY * p.launchMinor; | ||||||
|  |     minorBase *= loopMinor; | ||||||
|  |     tileOutY *= tileOutH; | ||||||
|  |     int tileOutXBase = blockIdx.y * p.loopX * tileOutW; | ||||||
|  |     int majorBase = blockIdx.z * p.loopMajor; | ||||||
|  |     if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor) | ||||||
|  |         return; | ||||||
|  | 
 | ||||||
|  |     // Load filter (flipped). | ||||||
|  |     for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x) | ||||||
|  |     { | ||||||
|  |         int fy = tapIdx / filterW; | ||||||
|  |         int fx = tapIdx - fy * filterW; | ||||||
|  |         scalar_t v = 0; | ||||||
|  |         if (fx < p.filterSize.x & fy < p.filterSize.y) | ||||||
|  |         { | ||||||
|  |             int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx; | ||||||
|  |             int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; | ||||||
|  |             v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; | ||||||
|  |         } | ||||||
|  |         sf[fy][fx] = v; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Loop over major and X. | ||||||
|  |     for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) | ||||||
|  |     { | ||||||
|  |         int baseNC = major * p.sizeMinor + minorBase; | ||||||
|  |         int n = baseNC / p.inSize.z; | ||||||
|  |         int baseC = baseNC - n * p.inSize.z; | ||||||
|  |         for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW) | ||||||
|  |         { | ||||||
|  |             // Load input pixels. | ||||||
|  |             int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; | ||||||
|  |             int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; | ||||||
|  |             int tileInX = floor_div(tileMidX, upx); | ||||||
|  |             int tileInY = floor_div(tileMidY, upy); | ||||||
|  |             __syncthreads(); | ||||||
|  |             for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x) | ||||||
|  |             { | ||||||
|  |                 int relC = inIdx; | ||||||
|  |                 int relInX = relC / loopMinor; | ||||||
|  |                 int relInY = relInX / tileInW; | ||||||
|  |                 relC -= relInX * loopMinor; | ||||||
|  |                 relInX -= relInY * tileInW; | ||||||
|  |                 int c = baseC + relC; | ||||||
|  |                 int inX = tileInX + relInX; | ||||||
|  |                 int inY = tileInY + relInY; | ||||||
|  |                 scalar_t v = 0; | ||||||
|  |                 if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z) | ||||||
|  |                     v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; | ||||||
|  |                 sx[relInY][relInX][relC] = v; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // Loop over output pixels. | ||||||
|  |             __syncthreads(); | ||||||
|  |             for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x) | ||||||
|  |             { | ||||||
|  |                 int relC = outIdx; | ||||||
|  |                 int relOutX = relC / loopMinor; | ||||||
|  |                 int relOutY = relOutX / tileOutW; | ||||||
|  |                 relC -= relOutX * loopMinor; | ||||||
|  |                 relOutX -= relOutY * tileOutW; | ||||||
|  |                 int c = baseC + relC; | ||||||
|  |                 int outX = tileOutX + relOutX; | ||||||
|  |                 int outY = tileOutY + relOutY; | ||||||
|  | 
 | ||||||
|  |                 // Setup receptive field. | ||||||
|  |                 int midX = tileMidX + relOutX * downx; | ||||||
|  |                 int midY = tileMidY + relOutY * downy; | ||||||
|  |                 int inX = floor_div(midX, upx); | ||||||
|  |                 int inY = floor_div(midY, upy); | ||||||
|  |                 int relInX = inX - tileInX; | ||||||
|  |                 int relInY = inY - tileInY; | ||||||
|  |                 int filterX = (inX + 1) * upx - midX - 1; // flipped | ||||||
|  |                 int filterY = (inY + 1) * upy - midY - 1; // flipped | ||||||
|  | 
 | ||||||
|  |                 // Inner loop. | ||||||
|  |                 if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) | ||||||
|  |                 { | ||||||
|  |                     scalar_t v = 0; | ||||||
|  |                     #pragma unroll | ||||||
|  |                     for (int y = 0; y < filterH / upy; y++) | ||||||
|  |                         #pragma unroll | ||||||
|  |                         for (int x = 0; x < filterW / upx; x++) | ||||||
|  |                             v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx]; | ||||||
|  |                     v *= p.gain; | ||||||
|  |                     ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // CUDA kernel selection. | ||||||
|  | 
 | ||||||
|  | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p) | ||||||
|  | { | ||||||
|  |     int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y; | ||||||
|  | 
 | ||||||
|  |     upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous | ||||||
|  |     if (s == 1)           spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last | ||||||
|  | 
 | ||||||
|  |     if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous | ||||||
|  |     { | ||||||
|  |         if (fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7,  64,16,1>, 64,16,1, 1}; | ||||||
|  |         if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6,  64,16,1>, 64,16,1, 1}; | ||||||
|  |         if (fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5,  64,16,1>, 64,16,1, 1}; | ||||||
|  |         if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4,  64,16,1>, 64,16,1, 1}; | ||||||
|  |         if (fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3,  64,16,1>, 64,16,1, 1}; | ||||||
|  |         if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1}; | ||||||
|  |         if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1}; | ||||||
|  |         if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1}; | ||||||
|  |         if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1}; | ||||||
|  |         if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1,  128,8,1>, 128,8,1, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8,  32,32,1>, 32,32,1, 1}; | ||||||
|  |     } | ||||||
|  |     if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last | ||||||
|  |     { | ||||||
|  |         if (fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7,  16,16,8>,  16,16,8,  1}; | ||||||
|  |         if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6,  16,16,8>,  16,16,8,  1}; | ||||||
|  |         if (fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5,  16,16,8>,  16,16,8,  1}; | ||||||
|  |         if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4,  16,16,8>,  16,16,8,  1}; | ||||||
|  |         if (fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3,  16,16,8>,  16,16,8,  1}; | ||||||
|  |         if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1}; | ||||||
|  |         if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1}; | ||||||
|  |         if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1}; | ||||||
|  |         if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1}; | ||||||
|  |         if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1,  128,1,16>, 128,1,16, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8,  1,128,16>, 1,128,16, 1}; | ||||||
|  |     } | ||||||
|  |     if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous | ||||||
|  |     { | ||||||
|  |         if (fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1}; | ||||||
|  |         if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1}; | ||||||
|  |         if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1}; | ||||||
|  |         if (fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1}; | ||||||
|  |     } | ||||||
|  |     if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last | ||||||
|  |     { | ||||||
|  |         if (fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1}; | ||||||
|  |         if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1}; | ||||||
|  |         if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1}; | ||||||
|  |         if (fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1}; | ||||||
|  |     } | ||||||
|  |     if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous | ||||||
|  |     { | ||||||
|  |         if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1}; | ||||||
|  |         if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1}; | ||||||
|  |         if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1}; | ||||||
|  |         if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1}; | ||||||
|  |         if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1,  128,8,1>, 128,8,1, 1}; | ||||||
|  |     } | ||||||
|  |     if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last | ||||||
|  |     { | ||||||
|  |         if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1}; | ||||||
|  |         if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1}; | ||||||
|  |         if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1}; | ||||||
|  |         if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1}; | ||||||
|  |         if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1,  128,1,16>, 128,1,16, 1}; | ||||||
|  |     } | ||||||
|  |     if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous | ||||||
|  |     { | ||||||
|  |         if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8,  32,32,1>, 32,32,1, 1}; | ||||||
|  |     } | ||||||
|  |     if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last | ||||||
|  |     { | ||||||
|  |         if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8,  1,128,16>, 1,128,16, 1}; | ||||||
|  |     } | ||||||
|  |     if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous | ||||||
|  |     { | ||||||
|  |         if (fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8,  32,8,1>, 32,8,1, 1}; | ||||||
|  |         if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6,  32,8,1>, 32,8,1, 1}; | ||||||
|  |         if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4,  32,8,1>, 32,8,1, 1}; | ||||||
|  |         if (fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2,  32,8,1>, 32,8,1, 1}; | ||||||
|  |     } | ||||||
|  |     if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last | ||||||
|  |     { | ||||||
|  |         if (fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8,  8,8,8>, 8,8,8, 1}; | ||||||
|  |         if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6,  8,8,8>, 8,8,8, 1}; | ||||||
|  |         if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4,  8,8,8>, 8,8,8, 1}; | ||||||
|  |         if (fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2,  8,8,8>, 8,8,8, 1}; | ||||||
|  |     } | ||||||
|  |     if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous | ||||||
|  |     { | ||||||
|  |         if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1}; | ||||||
|  |         if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1}; | ||||||
|  |         if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1}; | ||||||
|  |         if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1}; | ||||||
|  |         if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1,  64,8,1>, 64,8,1, 1}; | ||||||
|  |     } | ||||||
|  |     if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last | ||||||
|  |     { | ||||||
|  |         if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1}; | ||||||
|  |         if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1}; | ||||||
|  |         if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1}; | ||||||
|  |         if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1}; | ||||||
|  |         if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1,  64,1,8>, 64,1,8, 1}; | ||||||
|  |     } | ||||||
|  |     if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous | ||||||
|  |     { | ||||||
|  |         if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8,  32,16,1>, 32,16,1, 1}; | ||||||
|  |     } | ||||||
|  |     if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last | ||||||
|  |     { | ||||||
|  |         if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1}; | ||||||
|  |         if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8,  1,64,8>, 1,64,8, 1}; | ||||||
|  |     } | ||||||
|  |     return spec; | ||||||
|  | } | ||||||
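
The chain of `if` statements above relies on later, tighter matches overwriting earlier ones: `spec` starts from the generic fallback chosen earlier in the function, and the last entry whose filter-size bound still covers the actual filter wins. A minimal Python model of that selection rule (editorial sketch; the `specs` table and names below are hypothetical, not part of this file):

def choose_spec(fx, fy, specs):
    """Keep the LAST table entry that fits, i.e. the tightest specialization."""
    chosen = 'fallback_large_kernel'   # stand-in for the generic kernel
    for (max_fx, max_fy), name in specs:
        if fx <= max_fx and fy <= max_fy:
            chosen = name
    return chosen

specs = [((7, 7), 'small_7x7'), ((6, 6), 'small_6x6'), ((5, 5), 'small_5x5'),
         ((4, 4), 'small_4x4'), ((3, 3), 'small_3x3')]
assert choose_spec(4, 4, specs) == 'small_4x4'            # tightest fit wins, not the loosest
assert choose_spec(9, 9, specs) == 'fallback_large_kernel'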
|  | 
 | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // Template specializations. | ||||||
|  | 
 | ||||||
|  | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double>   (const upfirdn2d_kernel_params& p); | ||||||
|  | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float>    (const upfirdn2d_kernel_params& p); | ||||||
|  | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p); | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------ | ||||||
							
								
								
									
global_torch/torch_utils/ops/upfirdn2d.h (new file, 59 lines)
							| @ -0,0 +1,59 @@ | |||||||
|  | // Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | // | ||||||
|  | // NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | // and proprietary rights in and to this software, related documentation | ||||||
|  | // and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | // distribution of this software and related documentation without an express | ||||||
|  | // license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
|  | #include <cuda_runtime.h> | ||||||
|  | 
|  | //------------------------------------------------------------------------ | ||||||
|  | // CUDA kernel parameters. | ||||||
|  | 
|  | struct upfirdn2d_kernel_params | ||||||
|  | { | ||||||
|  |     const void*     x; | ||||||
|  |     const float*    f; | ||||||
|  |     void*           y; | ||||||
|  | 
|  |     int2            up; | ||||||
|  |     int2            down; | ||||||
|  |     int2            pad0; | ||||||
|  |     int             flip; | ||||||
|  |     float           gain; | ||||||
|  | 
|  |     int4            inSize;         // [width, height, channel, batch] | ||||||
|  |     int4            inStride; | ||||||
|  |     int2            filterSize;     // [width, height] | ||||||
|  |     int2            filterStride; | ||||||
|  |     int4            outSize;        // [width, height, channel, batch] | ||||||
|  |     int4            outStride; | ||||||
|  |     int             sizeMinor; | ||||||
|  |     int             sizeMajor; | ||||||
|  | 
|  |     int             loopMinor; | ||||||
|  |     int             loopMajor; | ||||||
|  |     int             loopX; | ||||||
|  |     int             launchMinor; | ||||||
|  |     int             launchMajor; | ||||||
|  | }; | ||||||
|  | 
|  | //------------------------------------------------------------------------ | ||||||
|  | // CUDA kernel specialization. | ||||||
|  | 
|  | struct upfirdn2d_kernel_spec | ||||||
|  | { | ||||||
|  |     void*   kernel; | ||||||
|  |     int     tileOutW; | ||||||
|  |     int     tileOutH; | ||||||
|  |     int     loopMinor; | ||||||
|  |     int     loopX; | ||||||
|  | }; | ||||||
|  | 
|  | //------------------------------------------------------------------------ | ||||||
|  | // CUDA kernel selection. | ||||||
|  | 
|  | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); | ||||||
|  | 
|  | //------------------------------------------------------------------------ | ||||||
							
								
								
									
global_torch/torch_utils/ops/upfirdn2d.py (new file, 384 lines)
							| @ -0,0 +1,384 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | """Custom PyTorch ops for efficient resampling of 2D images.""" | ||||||
|  | 
 | ||||||
|  | import os | ||||||
|  | import warnings | ||||||
|  | import numpy as np | ||||||
|  | import torch | ||||||
|  | import traceback | ||||||
|  | 
 | ||||||
|  | from .. import custom_ops | ||||||
|  | from .. import misc | ||||||
|  | from . import conv2d_gradfix | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | _inited = False | ||||||
|  | _plugin = None | ||||||
|  | 
 | ||||||
|  | def _init(): | ||||||
|  |     global _inited, _plugin | ||||||
|  |     if not _inited: | ||||||
|  |         sources = ['upfirdn2d.cpp', 'upfirdn2d.cu'] | ||||||
|  |         sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] | ||||||
|  |         try: | ||||||
|  |             _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) | ||||||
|  |         except Exception: | ||||||
|  |             warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) | ||||||
|  |     return _plugin is not None | ||||||
|  | 
 | ||||||
|  | def _parse_scaling(scaling): | ||||||
|  |     if isinstance(scaling, int): | ||||||
|  |         scaling = [scaling, scaling] | ||||||
|  |     assert isinstance(scaling, (list, tuple)) | ||||||
|  |     assert all(isinstance(x, int) for x in scaling) | ||||||
|  |     sx, sy = scaling | ||||||
|  |     assert sx >= 1 and sy >= 1 | ||||||
|  |     return sx, sy | ||||||
|  | 
 | ||||||
|  | def _parse_padding(padding): | ||||||
|  |     if isinstance(padding, int): | ||||||
|  |         padding = [padding, padding] | ||||||
|  |     assert isinstance(padding, (list, tuple)) | ||||||
|  |     assert all(isinstance(x, int) for x in padding) | ||||||
|  |     if len(padding) == 2: | ||||||
|  |         padx, pady = padding | ||||||
|  |         padding = [padx, padx, pady, pady] | ||||||
|  |     padx0, padx1, pady0, pady1 = padding | ||||||
|  |     return padx0, padx1, pady0, pady1 | ||||||
|  | 
 | ||||||
|  | def _get_filter_size(f): | ||||||
|  |     if f is None: | ||||||
|  |         return 1, 1 | ||||||
|  |     assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] | ||||||
|  |     fw = f.shape[-1] | ||||||
|  |     fh = f.shape[0] | ||||||
|  |     with misc.suppress_tracer_warnings(): | ||||||
|  |         fw = int(fw) | ||||||
|  |         fh = int(fh) | ||||||
|  |     misc.assert_shape(f, [fh, fw][:f.ndim]) | ||||||
|  |     assert fw >= 1 and fh >= 1 | ||||||
|  |     return fw, fh | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): | ||||||
|  |     r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         f:           Torch tensor, numpy array, or python list of the shape | ||||||
|  |                      `[filter_height, filter_width]` (non-separable), | ||||||
|  |                      `[filter_taps]` (separable), | ||||||
|  |                      `[]` (impulse), or | ||||||
|  |                      `None` (identity). | ||||||
|  |         device:      Result device (default: cpu). | ||||||
|  |         normalize:   Normalize the filter so that it retains the magnitude | ||||||
|  |                      for constant input signal (DC)? (default: True). | ||||||
|  |         flip_filter: Flip the filter? (default: False). | ||||||
|  |         gain:        Overall scaling factor for signal magnitude (default: 1). | ||||||
|  |         separable:   Return a separable filter? (default: select automatically). | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         Float32 tensor of the shape | ||||||
|  |         `[filter_height, filter_width]` (non-separable) or | ||||||
|  |         `[filter_taps]` (separable). | ||||||
|  |     """ | ||||||
|  |     # Validate. | ||||||
|  |     if f is None: | ||||||
|  |         f = 1 | ||||||
|  |     f = torch.as_tensor(f, dtype=torch.float32) | ||||||
|  |     assert f.ndim in [0, 1, 2] | ||||||
|  |     assert f.numel() > 0 | ||||||
|  |     if f.ndim == 0: | ||||||
|  |         f = f[np.newaxis] | ||||||
|  | 
 | ||||||
|  |     # Separable? | ||||||
|  |     if separable is None: | ||||||
|  |         separable = (f.ndim == 1 and f.numel() >= 8) | ||||||
|  |     if f.ndim == 1 and not separable: | ||||||
|  |         f = f.ger(f) | ||||||
|  |     assert f.ndim == (1 if separable else 2) | ||||||
|  | 
 | ||||||
|  |     # Apply normalize, flip, gain, and device. | ||||||
|  |     if normalize: | ||||||
|  |         f /= f.sum() | ||||||
|  |     if flip_filter: | ||||||
|  |         f = f.flip(list(range(f.ndim))) | ||||||
|  |     f = f * (gain ** (f.ndim / 2)) | ||||||
|  |     f = f.to(device=device) | ||||||
|  |     return f | ||||||
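
A short usage sketch of `setup_filter()` (editorial, assuming it is called from code that imports this module). Note that a 4-tap filter such as StyleGAN2's `[1, 3, 3, 1]` falls below the 8-tap threshold for automatic separability, so it is expanded to a 4x4 outer product unless `separable=True` is passed explicitly:

f = setup_filter([1, 3, 3, 1])                       # 4 taps < 8 -> non-separable 4x4 outer product
assert f.shape == (4, 4)
assert abs(float(f.sum()) - 1.0) < 1e-6              # normalized to unit gain for DC input
f_sep = setup_filter([1, 3, 3, 1], separable=True)   # force the 1-D separable form
assert f_sep.shape == (4,)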
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): | ||||||
|  |     r"""Pad, upsample, filter, and downsample a batch of 2D images. | ||||||
|  | 
 | ||||||
|  |     Performs the following sequence of operations for each channel: | ||||||
|  | 
 | ||||||
|  |     1. Upsample the image by inserting N-1 zeros after each pixel (`up`). | ||||||
|  | 
 | ||||||
|  |     2. Pad the image with the specified number of zeros on each side (`padding`). | ||||||
|  |        Negative padding corresponds to cropping the image. | ||||||
|  | 
 | ||||||
|  |     3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it | ||||||
|  |        so that the footprint of all output pixels lies within the input image. | ||||||
|  | 
 | ||||||
|  |     4. Downsample the image by keeping every Nth pixel (`down`). | ||||||
|  | 
 | ||||||
|  |     This sequence of operations bears close resemblance to scipy.signal.upfirdn(). | ||||||
|  |     The fused op is considerably more efficient than performing the same calculation | ||||||
|  |     using standard PyTorch ops. It supports gradients of arbitrary order. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         x:           Float32/float64/float16 input tensor of the shape | ||||||
|  |                      `[batch_size, num_channels, in_height, in_width]`. | ||||||
|  |         f:           Float32 FIR filter of the shape | ||||||
|  |                      `[filter_height, filter_width]` (non-separable), | ||||||
|  |                      `[filter_taps]` (separable), or | ||||||
|  |                      `None` (identity). | ||||||
|  |         up:          Integer upsampling factor. Can be a single int or a list/tuple | ||||||
|  |                      `[x, y]` (default: 1). | ||||||
|  |         down:        Integer downsampling factor. Can be a single int or a list/tuple | ||||||
|  |                      `[x, y]` (default: 1). | ||||||
|  |         padding:     Padding with respect to the upsampled image. Can be a single number | ||||||
|  |                      or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` | ||||||
|  |                      (default: 0). | ||||||
|  |         flip_filter: False = convolution, True = correlation (default: False). | ||||||
|  |         gain:        Overall scaling factor for signal magnitude (default: 1). | ||||||
|  |         impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. | ||||||
|  |     """ | ||||||
|  |     assert isinstance(x, torch.Tensor) | ||||||
|  |     assert impl in ['ref', 'cuda'] | ||||||
|  |     if impl == 'cuda' and x.device.type == 'cuda' and _init(): | ||||||
|  |         return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) | ||||||
|  |     return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) | ||||||
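
A minimal end-to-end sketch of `upfirdn2d()` (editorial; uses the slow reference path so no CUDA build is needed, and the padding values mirror what `upsample2d()` below computes for a 4-tap filter):

import torch

x = torch.randn(1, 3, 16, 16)
f = setup_filter([1, 3, 3, 1], separable=True)       # 4-tap lowpass, unit DC gain
y = upfirdn2d(x, f, up=2, padding=[2, 1, 2, 1], gain=4, impl='ref')
assert y.shape == (1, 3, 32, 32)                     # 16*2 + 2 + 1 - 4 + 1 = 32 per axis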
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | @misc.profiled_function | ||||||
|  | def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): | ||||||
|  |     """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. | ||||||
|  |     """ | ||||||
|  |     # Validate arguments. | ||||||
|  |     assert isinstance(x, torch.Tensor) and x.ndim == 4 | ||||||
|  |     if f is None: | ||||||
|  |         f = torch.ones([1, 1], dtype=torch.float32, device=x.device) | ||||||
|  |     assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] | ||||||
|  |     assert f.dtype == torch.float32 and not f.requires_grad | ||||||
|  |     batch_size, num_channels, in_height, in_width = x.shape | ||||||
|  |     upx, upy = _parse_scaling(up) | ||||||
|  |     downx, downy = _parse_scaling(down) | ||||||
|  |     padx0, padx1, pady0, pady1 = _parse_padding(padding) | ||||||
|  | 
 | ||||||
|  |     # Upsample by inserting zeros. | ||||||
|  |     x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) | ||||||
|  |     x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) | ||||||
|  |     x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) | ||||||
|  | 
 | ||||||
|  |     # Pad or crop. | ||||||
|  |     x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) | ||||||
|  |     x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] | ||||||
|  | 
 | ||||||
|  |     # Setup filter. | ||||||
|  |     f = f * (gain ** (f.ndim / 2)) | ||||||
|  |     f = f.to(x.dtype) | ||||||
|  |     if not flip_filter: | ||||||
|  |         f = f.flip(list(range(f.ndim))) | ||||||
|  | 
 | ||||||
|  |     # Convolve with the filter. | ||||||
|  |     f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) | ||||||
|  |     if f.ndim == 4: | ||||||
|  |         x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) | ||||||
|  |     else: | ||||||
|  |         x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) | ||||||
|  |         x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) | ||||||
|  | 
 | ||||||
|  |     # Downsample by throwing away pixels. | ||||||
|  |     x = x[:, :, ::downy, ::downx] | ||||||
|  |     return x | ||||||
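
The zero-insertion step above uses a reshape/pad/reshape idiom rather than indexed writes; a tiny standalone demonstration (editorial sketch):

import torch

x = torch.arange(1., 5.).reshape(1, 1, 2, 2)                  # [[1, 2], [3, 4]]
up = 2
y = x.reshape(1, 1, 2, 1, 2, 1)                               # add a singleton after each spatial dim
y = torch.nn.functional.pad(y, [0, up - 1, 0, 0, 0, up - 1])  # pad the singletons with zeros
y = y.reshape(1, 1, 2 * up, 2 * up)                           # flatten back: zeros interleaved
# y[0, 0] == [[1, 0, 2, 0], [0, 0, 0, 0], [3, 0, 4, 0], [0, 0, 0, 0]]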
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | _upfirdn2d_cuda_cache = dict() | ||||||
|  | 
 | ||||||
|  | def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): | ||||||
|  |     """Fast CUDA implementation of `upfirdn2d()` using custom ops. | ||||||
|  |     """ | ||||||
|  |     # Parse arguments. | ||||||
|  |     upx, upy = _parse_scaling(up) | ||||||
|  |     downx, downy = _parse_scaling(down) | ||||||
|  |     padx0, padx1, pady0, pady1 = _parse_padding(padding) | ||||||
|  | 
 | ||||||
|  |     # Lookup from cache. | ||||||
|  |     key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) | ||||||
|  |     if key in _upfirdn2d_cuda_cache: | ||||||
|  |         return _upfirdn2d_cuda_cache[key] | ||||||
|  | 
 | ||||||
|  |     # Forward op. | ||||||
|  |     class Upfirdn2dCuda(torch.autograd.Function): | ||||||
|  |         @staticmethod | ||||||
|  |         def forward(ctx, x, f): # pylint: disable=arguments-differ | ||||||
|  |             assert isinstance(x, torch.Tensor) and x.ndim == 4 | ||||||
|  |             if f is None: | ||||||
|  |                 f = torch.ones([1, 1], dtype=torch.float32, device=x.device) | ||||||
|  |             assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] | ||||||
|  |             y = x | ||||||
|  |             if f.ndim == 2: | ||||||
|  |                 y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) | ||||||
|  |             else: | ||||||
|  |                 y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain)) | ||||||
|  |                 y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain)) | ||||||
|  |             ctx.save_for_backward(f) | ||||||
|  |             ctx.x_shape = x.shape | ||||||
|  |             return y | ||||||
|  | 
 | ||||||
|  |         @staticmethod | ||||||
|  |         def backward(ctx, dy): # pylint: disable=arguments-differ | ||||||
|  |             f, = ctx.saved_tensors | ||||||
|  |             _, _, ih, iw = ctx.x_shape | ||||||
|  |             _, _, oh, ow = dy.shape | ||||||
|  |             fw, fh = _get_filter_size(f) | ||||||
|  |             p = [ | ||||||
|  |                 fw - padx0 - 1, | ||||||
|  |                 iw * upx - ow * downx + padx0 - upx + 1, | ||||||
|  |                 fh - pady0 - 1, | ||||||
|  |                 ih * upy - oh * downy + pady0 - upy + 1, | ||||||
|  |             ] | ||||||
|  |             dx = None | ||||||
|  |             df = None | ||||||
|  | 
 | ||||||
|  |             if ctx.needs_input_grad[0]: | ||||||
|  |                 dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) | ||||||
|  | 
 | ||||||
|  |             assert not ctx.needs_input_grad[1] | ||||||
|  |             return dx, df | ||||||
|  | 
 | ||||||
|  |     # Add to cache. | ||||||
|  |     _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda | ||||||
|  |     return Upfirdn2dCuda | ||||||
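
The cache is keyed on the parsed hyperparameters, so every call site with the same configuration shares one `autograd.Function` class. An editorial sketch (building the class needs no GPU; only `.apply()` touches the compiled plugin):

op_a = _upfirdn2d_cuda(up=2, padding=[2, 1, 2, 1], gain=4)
op_b = _upfirdn2d_cuda(up=2, padding=[2, 1, 2, 1], gain=4)
assert op_a is op_b                                  # cache hit: the same class object
# y = op_a.apply(x, f)                               # would invoke the compiled CUDA plugin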
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): | ||||||
|  |     r"""Filter a batch of 2D images using the given 2D FIR filter. | ||||||
|  | 
 | ||||||
|  |     By default, the result is padded so that its shape matches the input. | ||||||
|  |     User-specified padding is applied on top of that, with negative values | ||||||
|  |     indicating cropping. Pixels outside the image are assumed to be zero. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         x:           Float32/float64/float16 input tensor of the shape | ||||||
|  |                      `[batch_size, num_channels, in_height, in_width]`. | ||||||
|  |         f:           Float32 FIR filter of the shape | ||||||
|  |                      `[filter_height, filter_width]` (non-separable), | ||||||
|  |                      `[filter_taps]` (separable), or | ||||||
|  |                      `None` (identity). | ||||||
|  |         padding:     Padding with respect to the output. Can be a single number or a | ||||||
|  |                      list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` | ||||||
|  |                      (default: 0). | ||||||
|  |         flip_filter: False = convolution, True = correlation (default: False). | ||||||
|  |         gain:        Overall scaling factor for signal magnitude (default: 1). | ||||||
|  |         impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. | ||||||
|  |     """ | ||||||
|  |     padx0, padx1, pady0, pady1 = _parse_padding(padding) | ||||||
|  |     fw, fh = _get_filter_size(f) | ||||||
|  |     p = [ | ||||||
|  |         padx0 + fw // 2, | ||||||
|  |         padx1 + (fw - 1) // 2, | ||||||
|  |         pady0 + fh // 2, | ||||||
|  |         pady1 + (fh - 1) // 2, | ||||||
|  |     ] | ||||||
|  |     return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) | ||||||
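
With default padding the output shape matches the input: `p` adds `fw - 1` columns and `fh - 1` rows, exactly what the `fw x fh` convolution consumes. A quick editorial check using the reference path:

import torch

x = torch.randn(1, 3, 16, 16)
f = setup_filter([1, 2, 1])                          # 3 taps -> 3x3 non-separable kernel
y = filter2d(x, f, impl='ref')
assert y.shape == x.shape                            # padding [1, 1, 1, 1] preserves 16x16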
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): | ||||||
|  |     r"""Upsample a batch of 2D images using the given 2D FIR filter. | ||||||
|  | 
 | ||||||
|  |     By default, the result is padded so that its shape is a multiple of the input. | ||||||
|  |     User-specified padding is applied on top of that, with negative values | ||||||
|  |     indicating cropping. Pixels outside the image are assumed to be zero. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         x:           Float32/float64/float16 input tensor of the shape | ||||||
|  |                      `[batch_size, num_channels, in_height, in_width]`. | ||||||
|  |         f:           Float32 FIR filter of the shape | ||||||
|  |                      `[filter_height, filter_width]` (non-separable), | ||||||
|  |                      `[filter_taps]` (separable), or | ||||||
|  |                      `None` (identity). | ||||||
|  |         up:          Integer upsampling factor. Can be a single int or a list/tuple | ||||||
|  |                      `[x, y]` (default: 2). | ||||||
|  |         padding:     Padding with respect to the output. Can be a single number or a | ||||||
|  |                      list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` | ||||||
|  |                      (default: 0). | ||||||
|  |         flip_filter: False = convolution, True = correlation (default: False). | ||||||
|  |         gain:        Overall scaling factor for signal magnitude (default: 1). | ||||||
|  |         impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. | ||||||
|  |     """ | ||||||
|  |     upx, upy = _parse_scaling(up) | ||||||
|  |     padx0, padx1, pady0, pady1 = _parse_padding(padding) | ||||||
|  |     fw, fh = _get_filter_size(f) | ||||||
|  |     p = [ | ||||||
|  |         padx0 + (fw + upx - 1) // 2, | ||||||
|  |         padx1 + (fw - upx) // 2, | ||||||
|  |         pady0 + (fh + upy - 1) // 2, | ||||||
|  |         pady1 + (fh - upy) // 2, | ||||||
|  |     ] | ||||||
|  |     return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): | ||||||
|  |     r"""Downsample a batch of 2D images using the given 2D FIR filter. | ||||||
|  | 
 | ||||||
|  |     By default, the result is padded so that its shape is a fraction of the input. | ||||||
|  |     User-specified padding is applied on top of that, with negative values | ||||||
|  |     indicating cropping. Pixels outside the image are assumed to be zero. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         x:           Float32/float64/float16 input tensor of the shape | ||||||
|  |                      `[batch_size, num_channels, in_height, in_width]`. | ||||||
|  |         f:           Float32 FIR filter of the shape | ||||||
|  |                      `[filter_height, filter_width]` (non-separable), | ||||||
|  |                      `[filter_taps]` (separable), or | ||||||
|  |                      `None` (identity). | ||||||
|  |         down:        Integer downsampling factor. Can be a single int or a list/tuple | ||||||
|  |                      `[x, y]` (default: 2). | ||||||
|  |         padding:     Padding with respect to the input. Can be a single number or a | ||||||
|  |                      list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` | ||||||
|  |                      (default: 0). | ||||||
|  |         flip_filter: False = convolution, True = correlation (default: False). | ||||||
|  |         gain:        Overall scaling factor for signal magnitude (default: 1). | ||||||
|  |         impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. | ||||||
|  |     """ | ||||||
|  |     downx, downy = _parse_scaling(down) | ||||||
|  |     padx0, padx1, pady0, pady1 = _parse_padding(padding) | ||||||
|  |     fw, fh = _get_filter_size(f) | ||||||
|  |     p = [ | ||||||
|  |         padx0 + (fw - downx + 1) // 2, | ||||||
|  |         padx1 + (fw - downx) // 2, | ||||||
|  |         pady0 + (fh - downy + 1) // 2, | ||||||
|  |         pady1 + (fh - downy) // 2, | ||||||
|  |     ] | ||||||
|  |     return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) | ||||||
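
For matching factors, `upsample2d()` and `downsample2d()` are shape inverses (the values differ, of course, since each pass applies the lowpass filter); an editorial round-trip sketch on the reference path:

import torch

x = torch.randn(1, 3, 16, 16)
f = setup_filter([1, 3, 3, 1])
hi = upsample2d(x, f, up=2, impl='ref')              # 16x16 -> 32x32
lo = downsample2d(hi, f, down=2, impl='ref')         # 32x32 -> 16x16
assert hi.shape == (1, 3, 32, 32) and lo.shape == x.shape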
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
							
								
								
									
global_torch/torch_utils/persistence.py (new file, 251 lines)
							| @ -0,0 +1,251 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | """Facilities for pickling Python code alongside other data. | ||||||
|  | 
 | ||||||
|  | The pickled code is automatically imported into a separate Python module | ||||||
|  | during unpickling. This way, any previously exported pickles will remain | ||||||
|  | usable even if the original code is no longer available, or if the current | ||||||
|  | version of the code is not consistent with what was originally pickled.""" | ||||||
|  | 
 | ||||||
|  | import sys | ||||||
|  | import pickle | ||||||
|  | import io | ||||||
|  | import inspect | ||||||
|  | import copy | ||||||
|  | import uuid | ||||||
|  | import types | ||||||
|  | import dnnlib | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | _version            = 6         # internal version number | ||||||
|  | _decorators         = set()     # {decorator_class, ...} | ||||||
|  | _import_hooks       = []        # [hook_function, ...] | ||||||
|  | _module_to_src_dict = dict()    # {module: src, ...} | ||||||
|  | _src_to_module_dict = dict()    # {src: module, ...} | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def persistent_class(orig_class): | ||||||
|  |     r"""Class decorator that extends a given class to save its source code | ||||||
|  |     when pickled. | ||||||
|  | 
 | ||||||
|  |     Example: | ||||||
|  | 
 | ||||||
|  |         from torch_utils import persistence | ||||||
|  | 
 | ||||||
|  |         @persistence.persistent_class | ||||||
|  |         class MyNetwork(torch.nn.Module): | ||||||
|  |             def __init__(self, num_inputs, num_outputs): | ||||||
|  |                 super().__init__() | ||||||
|  |                 self.fc = MyLayer(num_inputs, num_outputs) | ||||||
|  |                 ... | ||||||
|  | 
 | ||||||
|  |         @persistence.persistent_class | ||||||
|  |         class MyLayer(torch.nn.Module): | ||||||
|  |             ... | ||||||
|  | 
 | ||||||
|  |     When pickled, any instance of `MyNetwork` and `MyLayer` will save its | ||||||
|  |     source code alongside other internal state (e.g., parameters, buffers, | ||||||
|  |     and submodules). This way, any previously exported pickle will remain | ||||||
|  |     usable even if the class definitions have been modified or are no | ||||||
|  |     longer available. | ||||||
|  | 
 | ||||||
|  |     The decorator saves the source code of the entire Python module | ||||||
|  |     containing the decorated class. It does *not* save the source code of | ||||||
|  |     any imported modules. Thus, the imported modules must be available | ||||||
|  |     during unpickling, including `torch_utils.persistence` itself. | ||||||
|  | 
 | ||||||
|  |     It is ok to call functions defined in the same module from the | ||||||
|  |     decorated class. However, if the decorated class depends on other | ||||||
|  |     classes defined in the same module, they must be decorated as well. | ||||||
|  |     This is illustrated in the above example in the case of `MyLayer`. | ||||||
|  | 
 | ||||||
|  |     It is also possible to employ the decorator just-in-time before | ||||||
|  |     calling the constructor. For example: | ||||||
|  | 
 | ||||||
|  |         cls = MyLayer | ||||||
|  |         if want_to_make_it_persistent: | ||||||
|  |             cls = persistence.persistent_class(cls) | ||||||
|  |         layer = cls(num_inputs, num_outputs) | ||||||
|  | 
 | ||||||
|  |     As an additional feature, the decorator also keeps track of the | ||||||
|  |     arguments that were used to construct each instance of the decorated | ||||||
|  |     class. The arguments can be queried via `obj.init_args` and | ||||||
|  |     `obj.init_kwargs`, and they are automatically pickled alongside other | ||||||
|  |     object state. A typical use case is to first unpickle a previous | ||||||
|  |     instance of a persistent class, and then upgrade it to use the latest | ||||||
|  |     version of the source code: | ||||||
|  | 
 | ||||||
|  |         with open('old_pickle.pkl', 'rb') as f: | ||||||
|  |             old_net = pickle.load(f) | ||||||
|  |         new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs) | ||||||
|  |         misc.copy_params_and_buffers(old_net, new_net, require_all=True) | ||||||
|  |     """ | ||||||
|  |     assert isinstance(orig_class, type) | ||||||
|  |     if is_persistent(orig_class): | ||||||
|  |         return orig_class | ||||||
|  | 
 | ||||||
|  |     assert orig_class.__module__ in sys.modules | ||||||
|  |     orig_module = sys.modules[orig_class.__module__] | ||||||
|  |     orig_module_src = _module_to_src(orig_module) | ||||||
|  | 
 | ||||||
|  |     class Decorator(orig_class): | ||||||
|  |         _orig_module_src = orig_module_src | ||||||
|  |         _orig_class_name = orig_class.__name__ | ||||||
|  | 
 | ||||||
|  |         def __init__(self, *args, **kwargs): | ||||||
|  |             super().__init__(*args, **kwargs) | ||||||
|  |             self._init_args = copy.deepcopy(args) | ||||||
|  |             self._init_kwargs = copy.deepcopy(kwargs) | ||||||
|  |             assert orig_class.__name__ in orig_module.__dict__ | ||||||
|  |             _check_pickleable(self.__reduce__()) | ||||||
|  | 
 | ||||||
|  |         @property | ||||||
|  |         def init_args(self): | ||||||
|  |             return copy.deepcopy(self._init_args) | ||||||
|  | 
 | ||||||
|  |         @property | ||||||
|  |         def init_kwargs(self): | ||||||
|  |             return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) | ||||||
|  | 
 | ||||||
|  |         def __reduce__(self): | ||||||
|  |             fields = list(super().__reduce__()) | ||||||
|  |             fields += [None] * max(3 - len(fields), 0) | ||||||
|  |             if fields[0] is not _reconstruct_persistent_obj: | ||||||
|  |                 meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) | ||||||
|  |                 fields[0] = _reconstruct_persistent_obj # reconstruct func | ||||||
|  |                 fields[1] = (meta,) # reconstruct args | ||||||
|  |                 fields[2] = None # state dict | ||||||
|  |             return tuple(fields) | ||||||
|  | 
 | ||||||
|  |     Decorator.__name__ = orig_class.__name__ | ||||||
|  |     _decorators.add(Decorator) | ||||||
|  |     return Decorator | ||||||
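
An editorial round-trip sketch (assumes `torch_utils` and its `dnnlib` dependency are importable, and that the class lives in a source file on disk, since the decorator reads the defining module's source with `inspect`):

import pickle
import torch
from torch_utils import persistence

@persistence.persistent_class
class Net(torch.nn.Module):
    def __init__(self, n):
        super().__init__()
        self.fc = torch.nn.Linear(n, n)

net = Net(4)
data = pickle.dumps(net)            # embeds the defining module's source code
net2 = pickle.loads(data)           # reconstructs Net even if its code later changes
assert net2.init_args == (4,) and net2.init_kwargs == {}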
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def is_persistent(obj): | ||||||
|  |     r"""Test whether the given object or class is persistent, i.e., | ||||||
|  |     whether it will save its source code when pickled. | ||||||
|  |     """ | ||||||
|  |     try: | ||||||
|  |         if obj in _decorators: | ||||||
|  |             return True | ||||||
|  |     except TypeError: | ||||||
|  |         pass | ||||||
|  |     return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def import_hook(hook): | ||||||
|  |     r"""Register an import hook that is called whenever a persistent object | ||||||
|  |     is being unpickled. A typical use case is to patch the pickled source | ||||||
|  |     code to avoid errors and inconsistencies when the API of some imported | ||||||
|  |     module has changed. | ||||||
|  | 
 | ||||||
|  |     The hook should have the following signature: | ||||||
|  | 
 | ||||||
|  |         hook(meta) -> modified meta | ||||||
|  | 
 | ||||||
|  |     `meta` is an instance of `dnnlib.EasyDict` with the following fields: | ||||||
|  | 
 | ||||||
|  |         type:       Type of the persistent object, e.g. `'class'`. | ||||||
|  |         version:    Internal version number of `torch_utils.persistence`. | ||||||
|  |         module_src: Original source code of the Python module. | ||||||
|  |         class_name: Class name in the original Python module. | ||||||
|  |         state:      Internal state of the object. | ||||||
|  | 
 | ||||||
|  |     Example: | ||||||
|  | 
 | ||||||
|  |         @persistence.import_hook | ||||||
|  |         def wreck_my_network(meta): | ||||||
|  |             if meta.class_name == 'MyNetwork': | ||||||
|  |                 print('MyNetwork is being imported. I will wreck it!') | ||||||
|  |                 meta.module_src = meta.module_src.replace("True", "False") | ||||||
|  |             return meta | ||||||
|  |     """ | ||||||
|  |     assert callable(hook) | ||||||
|  |     _import_hooks.append(hook) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def _reconstruct_persistent_obj(meta): | ||||||
|  |     r"""Hook that is called internally by the `pickle` module to unpickle | ||||||
|  |     a persistent object. | ||||||
|  |     """ | ||||||
|  |     meta = dnnlib.EasyDict(meta) | ||||||
|  |     meta.state = dnnlib.EasyDict(meta.state) | ||||||
|  |     for hook in _import_hooks: | ||||||
|  |         meta = hook(meta) | ||||||
|  |         assert meta is not None | ||||||
|  | 
 | ||||||
|  |     assert meta.version == _version | ||||||
|  |     module = _src_to_module(meta.module_src) | ||||||
|  | 
 | ||||||
|  |     assert meta.type == 'class' | ||||||
|  |     orig_class = module.__dict__[meta.class_name] | ||||||
|  |     decorator_class = persistent_class(orig_class) | ||||||
|  |     obj = decorator_class.__new__(decorator_class) | ||||||
|  | 
 | ||||||
|  |     setstate = getattr(obj, '__setstate__', None) | ||||||
|  |     if callable(setstate): | ||||||
|  |         setstate(meta.state) # pylint: disable=not-callable | ||||||
|  |     else: | ||||||
|  |         obj.__dict__.update(meta.state) | ||||||
|  |     return obj | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def _module_to_src(module): | ||||||
|  |     r"""Query the source code of a given Python module. | ||||||
|  |     """ | ||||||
|  |     src = _module_to_src_dict.get(module, None) | ||||||
|  |     if src is None: | ||||||
|  |         src = inspect.getsource(module) | ||||||
|  |         _module_to_src_dict[module] = src | ||||||
|  |         _src_to_module_dict[src] = module | ||||||
|  |     return src | ||||||
|  | 
 | ||||||
|  | def _src_to_module(src): | ||||||
|  |     r"""Get or create a Python module for the given source code. | ||||||
|  |     """ | ||||||
|  |     module = _src_to_module_dict.get(src, None) | ||||||
|  |     if module is None: | ||||||
|  |         module_name = "_imported_module_" + uuid.uuid4().hex | ||||||
|  |         module = types.ModuleType(module_name) | ||||||
|  |         sys.modules[module_name] = module | ||||||
|  |         _module_to_src_dict[module] = src | ||||||
|  |         _src_to_module_dict[src] = module | ||||||
|  |         exec(src, module.__dict__) # pylint: disable=exec-used | ||||||
|  |     return module | ||||||
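
The two caches make the mapping bidirectional and idempotent; an editorial sketch from within this module (the toy `src` string is hypothetical):

src = "VALUE = 42\n"
mod = _src_to_module(src)           # creates a synthetic module and exec()s the source
assert mod.VALUE == 42
assert _module_to_src(mod) == src   # reverse lookup served from the cache
assert _src_to_module(src) is mod   # repeated lookup returns the same module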
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def _check_pickleable(obj): | ||||||
|  |     r"""Check that the given object is pickleable, raising an exception if | ||||||
|  |     it is not. This function is expected to be considerably more efficient | ||||||
|  |     than actually pickling the object. | ||||||
|  |     """ | ||||||
|  |     def recurse(obj): | ||||||
|  |         if isinstance(obj, (list, tuple, set)): | ||||||
|  |             return [recurse(x) for x in obj] | ||||||
|  |         if isinstance(obj, dict): | ||||||
|  |             return [[recurse(x), recurse(y)] for x, y in obj.items()] | ||||||
|  |         if isinstance(obj, (str, int, float, bool, bytes, bytearray)): | ||||||
|  |             return None # Python primitive types are pickleable. | ||||||
|  |         if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']: | ||||||
|  |             return None # NumPy arrays and PyTorch tensors are pickleable. | ||||||
|  |         if is_persistent(obj): | ||||||
|  |             return None # Persistent objects are pickleable, by virtue of the constructor check. | ||||||
|  |         return obj | ||||||
|  |     with io.BytesIO() as f: | ||||||
|  |         pickle.dump(recurse(obj), f) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
							
								
								
									
global_torch/torch_utils/training_stats.py (new file, 268 lines)
							| @ -0,0 +1,268 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | """Facilities for reporting and collecting training statistics across | ||||||
|  | multiple processes and devices. The interface is designed to minimize | ||||||
|  | synchronization overhead as well as the amount of boilerplate in user | ||||||
|  | code.""" | ||||||
|  | 
 | ||||||
|  | import re | ||||||
|  | import numpy as np | ||||||
|  | import torch | ||||||
|  | import dnnlib | ||||||
|  | 
 | ||||||
|  | from . import misc | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | _num_moments    = 3             # [num_scalars, sum_of_scalars, sum_of_squares] | ||||||
|  | _reduce_dtype   = torch.float32 # Data type to use for initial per-tensor reduction. | ||||||
|  | _counter_dtype  = torch.float64 # Data type to use for the internal counters. | ||||||
|  | _rank           = 0             # Rank of the current process. | ||||||
|  | _sync_device    = None          # Device to use for multiprocess communication. None = single-process. | ||||||
|  | _sync_called    = False         # Has _sync() been called yet? | ||||||
|  | _counters       = dict()        # Running counters on each device, updated by report(): name => device => torch.Tensor | ||||||
|  | _cumulative     = dict()        # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def init_multiprocessing(rank, sync_device): | ||||||
|  |     r"""Initializes `torch_utils.training_stats` for collecting statistics | ||||||
|  |     across multiple processes. | ||||||
|  | 
 | ||||||
|  |     This function must be called after | ||||||
|  |     `torch.distributed.init_process_group()` and before `Collector.update()`. | ||||||
|  |     The call is not necessary if multi-process collection is not needed. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         rank:           Rank of the current process. | ||||||
|  |         sync_device:    PyTorch device to use for inter-process | ||||||
|  |                         communication, or None to disable multi-process | ||||||
|  |                         collection. Typically `torch.device('cuda', rank)`. | ||||||
|  |     """ | ||||||
|  |     global _rank, _sync_device | ||||||
|  |     assert not _sync_called | ||||||
|  |     _rank = rank | ||||||
|  |     _sync_device = sync_device | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | @misc.profiled_function | ||||||
|  | def report(name, value): | ||||||
|  |     r"""Broadcasts the given set of scalars to all interested instances of | ||||||
|  |     `Collector`, across device and process boundaries. | ||||||
|  | 
 | ||||||
|  |     This function is expected to be extremely cheap and can be safely | ||||||
|  |     called from anywhere in the training loop, loss function, or inside a | ||||||
|  |     `torch.nn.Module`. | ||||||
|  | 
 | ||||||
|  |     Warning: The current implementation expects the set of unique names to | ||||||
|  |     be consistent across processes. Please make sure that `report()` is | ||||||
|  |     called at least once for each unique name by each process, and in the | ||||||
|  |     same order. If a given process has no scalars to broadcast, it can do | ||||||
|  |     `report(name, [])` (empty list). | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         name:   Arbitrary string specifying the name of the statistic. | ||||||
|  |                 Averages are accumulated separately for each unique name. | ||||||
|  |         value:  Arbitrary set of scalars. Can be a list, tuple, | ||||||
|  |                 NumPy array, PyTorch tensor, or Python scalar. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         The same `value` that was passed in. | ||||||
|  |     """ | ||||||
|  |     if name not in _counters: | ||||||
|  |         _counters[name] = dict() | ||||||
|  | 
 | ||||||
|  |     elems = torch.as_tensor(value) | ||||||
|  |     if elems.numel() == 0: | ||||||
|  |         return value | ||||||
|  | 
 | ||||||
|  |     elems = elems.detach().flatten().to(_reduce_dtype) | ||||||
|  |     moments = torch.stack([ | ||||||
|  |         torch.ones_like(elems).sum(), | ||||||
|  |         elems.sum(), | ||||||
|  |         elems.square().sum(), | ||||||
|  |     ]) | ||||||
|  |     assert moments.ndim == 1 and moments.shape[0] == _num_moments | ||||||
|  |     moments = moments.to(_counter_dtype) | ||||||
|  | 
 | ||||||
|  |     device = moments.device | ||||||
|  |     if device not in _counters[name]: | ||||||
|  |         _counters[name][device] = torch.zeros_like(moments) | ||||||
|  |     _counters[name][device].add_(moments) | ||||||
|  |     return value | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def report0(name, value): | ||||||
|  |     r"""Broadcasts the given set of scalars by the first process (`rank = 0`), | ||||||
|  |     but ignores any scalars provided by the other processes. | ||||||
|  |     See `report()` for further details. | ||||||
|  |     """ | ||||||
|  |     report(name, value if _rank == 0 else []) | ||||||
|  |     return value | ||||||
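
An editorial usage sketch tying `report()` to the `Collector` class defined below (single-process; assumes `torch_utils` is importable):

import torch
from torch_utils import training_stats

collector = training_stats.Collector(regex='Loss/.*')
for step in range(4):
    loss = torch.rand([8])                           # stand-in for a real per-sample loss
    training_stats.report('Loss/total', loss)
    collector.update()                               # fold counters into the visible averages
    print(step, collector.num('Loss/total'), collector.mean('Loss/total'))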
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | class Collector: | ||||||
|  |     r"""Collects the scalars broadcasted by `report()` and `report0()` and | ||||||
|  |     computes their long-term averages (mean and standard deviation) over | ||||||
|  |     user-defined periods of time. | ||||||
|  | 
 | ||||||
|  |     The averages are first collected into internal counters that are not | ||||||
|  |     directly visible to the user. They are then copied to the user-visible | ||||||
|  |     state as a result of calling `update()` and can then be queried using | ||||||
|  |     `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the | ||||||
|  |     internal counters for the next round, so that the user-visible state | ||||||
|  |     effectively reflects averages collected between the last two calls to | ||||||
|  |     `update()`. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         regex:          Regular expression defining which statistics to | ||||||
|  |                         collect. The default is to collect everything. | ||||||
|  |         keep_previous:  Whether to retain the previous averages if no | ||||||
|  |                         scalars were collected on a given round | ||||||
|  |                         (default: True). | ||||||
|  |     """ | ||||||
|  |     def __init__(self, regex='.*', keep_previous=True): | ||||||
|  |         self._regex = re.compile(regex) | ||||||
|  |         self._keep_previous = keep_previous | ||||||
|  |         self._cumulative = dict() | ||||||
|  |         self._moments = dict() | ||||||
|  |         self.update() | ||||||
|  |         self._moments.clear() | ||||||
|  | 
 | ||||||
|  |     def names(self): | ||||||
|  |         r"""Returns the names of all statistics broadcasted so far that | ||||||
|  |         match the regular expression specified at construction time. | ||||||
|  |         """ | ||||||
|  |         return [name for name in _counters if self._regex.fullmatch(name)] | ||||||
|  | 
 | ||||||
|  |     def update(self): | ||||||
|  |         r"""Copies current values of the internal counters to the | ||||||
|  |         user-visible state and resets them for the next round. | ||||||
|  | 
 | ||||||
|  |         If `keep_previous=True` was specified at construction time, the | ||||||
|  |         operation is skipped for statistics that have received no scalars | ||||||
|  |         since the last update, retaining their previous averages. | ||||||
|  | 
 | ||||||
|  |         This method performs a number of GPU-to-CPU transfers and one | ||||||
|  |         `torch.distributed.all_reduce()`. It is intended to be called | ||||||
|  |         periodically in the main training loop, typically once every | ||||||
|  |         N training steps. | ||||||
|  |         """ | ||||||
|  |         if not self._keep_previous: | ||||||
|  |             self._moments.clear() | ||||||
|  |         for name, cumulative in _sync(self.names()): | ||||||
|  |             if name not in self._cumulative: | ||||||
|  |                 self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) | ||||||
|  |             delta = cumulative - self._cumulative[name] | ||||||
|  |             self._cumulative[name].copy_(cumulative) | ||||||
|  |             if float(delta[0]) != 0: | ||||||
|  |                 self._moments[name] = delta | ||||||
|  | 
 | ||||||
|  |     def _get_delta(self, name): | ||||||
|  |         r"""Returns the raw moments that were accumulated for the given | ||||||
|  |         statistic between the last two calls to `update()`, or zero if | ||||||
|  |         no scalars were collected. | ||||||
|  |         """ | ||||||
|  |         assert self._regex.fullmatch(name) | ||||||
|  |         if name not in self._moments: | ||||||
|  |             self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) | ||||||
|  |         return self._moments[name] | ||||||
|  | 
 | ||||||
|  |     def num(self, name): | ||||||
|  |         r"""Returns the number of scalars that were accumulated for the given | ||||||
|  |         statistic between the last two calls to `update()`, or zero if | ||||||
|  |         no scalars were collected. | ||||||
|  |         """ | ||||||
|  |         delta = self._get_delta(name) | ||||||
|  |         return int(delta[0]) | ||||||
|  | 
 | ||||||
|  |     def mean(self, name): | ||||||
|  |         r"""Returns the mean of the scalars that were accumulated for the | ||||||
|  |         given statistic between the last two calls to `update()`, or NaN if | ||||||
|  |         no scalars were collected. | ||||||
|  |         """ | ||||||
|  |         delta = self._get_delta(name) | ||||||
|  |         if int(delta[0]) == 0: | ||||||
|  |             return float('nan') | ||||||
|  |         return float(delta[1] / delta[0]) | ||||||
|  | 
 | ||||||
|  |     def std(self, name): | ||||||
|  |         r"""Returns the standard deviation of the scalars that were | ||||||
|  |         accumulated for the given statistic between the last two calls to | ||||||
|  |         `update()`, or NaN if no scalars were collected. | ||||||
|  |         """ | ||||||
|  |         delta = self._get_delta(name) | ||||||
|  |         if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): | ||||||
|  |             return float('nan') | ||||||
|  |         if int(delta[0]) == 1: | ||||||
|  |             return float(0) | ||||||
|  |         mean = float(delta[1] / delta[0]) | ||||||
|  |         raw_var = float(delta[2] / delta[0]) | ||||||
|  |         return np.sqrt(max(raw_var - np.square(mean), 0)) | ||||||
|  | 
 | ||||||
|  |     def as_dict(self): | ||||||
|  |         r"""Returns the averages accumulated between the last two calls to | ||||||
|  |         `update()` as a `dnnlib.EasyDict`. The contents are as follows: | ||||||
|  | 
 | ||||||
|  |             dnnlib.EasyDict( | ||||||
|  |                 NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), | ||||||
|  |                 ... | ||||||
|  |             ) | ||||||
|  |         """ | ||||||
|  |         stats = dnnlib.EasyDict() | ||||||
|  |         for name in self.names(): | ||||||
|  |             stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) | ||||||
|  |         return stats | ||||||
|  | 
 | ||||||
|  |     def __getitem__(self, name): | ||||||
|  |         r"""Convenience getter. | ||||||
|  |         `collector[name]` is a synonym for `collector.mean(name)`. | ||||||
|  |         """ | ||||||
|  |         return self.mean(name) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
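The intended workflow is report-then-collect: training code calls `report()` from anywhere, and a single `Collector` periodically snapshots the accumulated counters. A minimal single-process sketch, assuming this file is importable as `torch_utils.training_stats`:

```python
# Minimal single-process sketch of the report/Collector workflow.
# Assumes this module is importable as `torch_utils.training_stats`.
import torch
from torch_utils import training_stats

collector = training_stats.Collector(regex='Loss/.*')
for step in range(100):
    loss = torch.rand([8])                       # stand-in for a real loss tensor
    training_stats.report('Loss/G/main', loss)
    if step % 10 == 9:
        collector.update()                       # snapshot counters, then reset them
        print(step, collector.mean('Loss/G/main'), collector.std('Loss/G/main'))
```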
|  | 
 | ||||||
|  | def _sync(names): | ||||||
|  |     r"""Synchronize the global cumulative counters across devices and | ||||||
|  |     processes. Called internally by `Collector.update()`. | ||||||
|  |     """ | ||||||
|  |     if len(names) == 0: | ||||||
|  |         return [] | ||||||
|  |     global _sync_called | ||||||
|  |     _sync_called = True | ||||||
|  | 
 | ||||||
|  |     # Collect deltas within current rank. | ||||||
|  |     deltas = [] | ||||||
|  |     device = _sync_device if _sync_device is not None else torch.device('cpu') | ||||||
|  |     for name in names: | ||||||
|  |         delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) | ||||||
|  |         for counter in _counters[name].values(): | ||||||
|  |             delta.add_(counter.to(device)) | ||||||
|  |             counter.copy_(torch.zeros_like(counter)) | ||||||
|  |         deltas.append(delta) | ||||||
|  |     deltas = torch.stack(deltas) | ||||||
|  | 
 | ||||||
|  |     # Sum deltas across ranks. | ||||||
|  |     if _sync_device is not None: | ||||||
|  |         torch.distributed.all_reduce(deltas) | ||||||
|  | 
 | ||||||
|  |     # Update cumulative values. | ||||||
|  |     deltas = deltas.cpu() | ||||||
|  |     for idx, name in enumerate(names): | ||||||
|  |         if name not in _cumulative: | ||||||
|  |             _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) | ||||||
|  |         _cumulative[name].add_(deltas[idx]) | ||||||
|  | 
 | ||||||
|  |     # Return name-value pairs. | ||||||
|  |     return [(name, _cumulative[name]) for name in names] | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
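For reference, each counter in this file holds three moments per statistic: the count, the sum, and the sum of squares. `Collector.mean()` and `Collector.std()` above recover the statistics from a delta of these moments; a worked sketch with the scalars 1..4:

```python
# Worked example of the moment bookkeeping (illustration only).
# delta = [m0, m1, m2] = [count, sum, sum of squares]
import numpy as np

delta = np.array([4.0, 10.0, 30.0])     # from the scalars [1, 2, 3, 4]
mean = delta[1] / delta[0]              # 2.5
var = delta[2] / delta[0] - mean ** 2   # E[x^2] - E[x]^2 = 1.25
std = np.sqrt(max(var, 0.0))            # ~1.118, matching Collector.std()
```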
9  global_torch/training/__init__.py  Normal file
							| @ -0,0 +1,9 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | # empty | ||||||
809  global_torch/training/networks.py  Normal file
							| @ -0,0 +1,809 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | import numpy as np | ||||||
|  | import torch | ||||||
|  | from torch_utils import misc | ||||||
|  | from torch_utils import persistence | ||||||
|  | from torch_utils.ops import conv2d_resample | ||||||
|  | from torch_utils.ops import upfirdn2d | ||||||
|  | from torch_utils.ops import bias_act | ||||||
|  | from torch_utils.ops import fma | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | @misc.profiled_function | ||||||
|  | def normalize_2nd_moment(x, dim=1, eps=1e-8): | ||||||
|  |     return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
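`normalize_2nd_moment` rescales each sample so its second moment along `dim` equals 1 (an RMS-style normalization); it is applied below to Z vectors and label embeddings before the mapping layers. A quick sketch of the invariant:

```python
# Sketch: after normalization, mean(x^2) along the chosen dim is ~1.
import torch

x = torch.randn([4, 512]) * 3.7
y = x * (x.square().mean(dim=1, keepdim=True) + 1e-8).rsqrt()
print(y.square().mean(dim=1))   # ~tensor([1., 1., 1., 1.])
```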
|  | 
 | ||||||
|  | @misc.profiled_function | ||||||
|  | def modulated_conv2d( | ||||||
|  |     x,                          # Input tensor of shape [batch_size, in_channels, in_height, in_width]. | ||||||
|  |     weight,                     # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width]. | ||||||
|  |     styles,                     # Modulation coefficients of shape [batch_size, in_channels]. | ||||||
|  |     noise           = None,     # Optional noise tensor to add to the output activations. | ||||||
|  |     up              = 1,        # Integer upsampling factor. | ||||||
|  |     down            = 1,        # Integer downsampling factor. | ||||||
|  |     padding         = 0,        # Padding with respect to the upsampled image. | ||||||
|  |     resample_filter = None,     # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter(). | ||||||
|  |     demodulate      = True,     # Apply weight demodulation? | ||||||
|  |     flip_weight     = True,     # False = convolution, True = correlation (matches torch.nn.functional.conv2d). | ||||||
|  |     fused_modconv   = True,     # Perform modulation, convolution, and demodulation as a single fused operation? | ||||||
|  | ): | ||||||
|  |     batch_size = x.shape[0] | ||||||
|  |     out_channels, in_channels, kh, kw = weight.shape | ||||||
|  |     misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk] | ||||||
|  |     misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW] | ||||||
|  |     misc.assert_shape(styles, [batch_size, in_channels]) # [NI] | ||||||
|  | 
 | ||||||
|  |     # Pre-normalize inputs to avoid FP16 overflow. | ||||||
|  |     if x.dtype == torch.float16 and demodulate: | ||||||
|  |         weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk | ||||||
|  |         styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I | ||||||
|  | 
 | ||||||
|  |     # Calculate per-sample weights and demodulation coefficients. | ||||||
|  |     w = None | ||||||
|  |     dcoefs = None | ||||||
|  |     if demodulate or fused_modconv: | ||||||
|  |         w = weight.unsqueeze(0) # [NOIkk] | ||||||
|  |         w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] | ||||||
|  |     if demodulate: | ||||||
|  |         dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO] | ||||||
|  |     if demodulate and fused_modconv: | ||||||
|  |         w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] | ||||||
|  | 
 | ||||||
|  |     # Execute by scaling the activations before and after the convolution. | ||||||
|  |     if not fused_modconv: | ||||||
|  |         x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1) | ||||||
|  |         x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight) | ||||||
|  |         if demodulate and noise is not None: | ||||||
|  |             x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype)) | ||||||
|  |         elif demodulate: | ||||||
|  |             x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1) | ||||||
|  |         elif noise is not None: | ||||||
|  |             x = x.add_(noise.to(x.dtype)) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  |     # Execute as one fused op using grouped convolution. | ||||||
|  |     with misc.suppress_tracer_warnings(): # this value will be treated as a constant | ||||||
|  |         batch_size = int(batch_size) | ||||||
|  |     misc.assert_shape(x, [batch_size, in_channels, None, None]) | ||||||
|  |     x = x.reshape(1, -1, *x.shape[2:]) | ||||||
|  |     w = w.reshape(-1, in_channels, kh, kw) | ||||||
|  |     x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight) | ||||||
|  |     x = x.reshape(batch_size, -1, *x.shape[2:]) | ||||||
|  |     if noise is not None: | ||||||
|  |         x = x.add_(noise) | ||||||
|  |     return x | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
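In the fused path above, modulation scales each input channel of the per-sample weight by the style, and demodulation rescales each output channel so the convolution preserves roughly unit activation variance. A minimal sketch of the coefficient math on toy shapes (no convolution involved):

```python
# Sketch of modulation/demodulation coefficients on toy shapes.
import torch

N, O, I, k = 2, 8, 4, 3
weight = torch.randn([O, I, k, k])
styles = torch.randn([N, I])

w = weight.unsqueeze(0) * styles.reshape(N, 1, I, 1, 1)   # modulate: [NOIkk]
dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()   # [NO]
w = w * dcoefs.reshape(N, O, 1, 1, 1)                     # demodulate
print(w.square().sum(dim=[2, 3, 4]))                      # ~ones([N, O])
```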
|  | 
 | ||||||
|  | @persistence.persistent_class | ||||||
|  | class FullyConnectedLayer(torch.nn.Module): | ||||||
|  |     def __init__(self, | ||||||
|  |         in_features,                # Number of input features. | ||||||
|  |         out_features,               # Number of output features. | ||||||
|  |         bias            = True,     # Apply additive bias before the activation function? | ||||||
|  |         activation      = 'linear', # Activation function: 'relu', 'lrelu', etc. | ||||||
|  |         lr_multiplier   = 1,        # Learning rate multiplier. | ||||||
|  |         bias_init       = 0,        # Initial value for the additive bias. | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  |         self.activation = activation | ||||||
|  |         self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) | ||||||
|  |         self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None | ||||||
|  |         self.weight_gain = lr_multiplier / np.sqrt(in_features) | ||||||
|  |         self.bias_gain = lr_multiplier | ||||||
|  | 
 | ||||||
|  |     def forward(self, x): | ||||||
|  |         w = self.weight.to(x.dtype) * self.weight_gain | ||||||
|  |         b = self.bias | ||||||
|  |         if b is not None: | ||||||
|  |             b = b.to(x.dtype) | ||||||
|  |             if self.bias_gain != 1: | ||||||
|  |                 b = b * self.bias_gain | ||||||
|  | 
 | ||||||
|  |         if self.activation == 'linear' and b is not None: | ||||||
|  |             x = torch.addmm(b.unsqueeze(0), x, w.t()) | ||||||
|  |         else: | ||||||
|  |             x = x.matmul(w.t()) | ||||||
|  |             x = bias_act.bias_act(x, b, act=self.activation) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
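`FullyConnectedLayer` implements the equalized-learning-rate trick: weights are initialized as `randn / lr_multiplier` and rescaled at runtime by `lr_multiplier / sqrt(in_features)`, so the effective weight is `randn / sqrt(in_features)` while gradient updates are scaled by `lr_multiplier`. A usage sketch, assuming the repo layout makes this file importable as `training.networks`:

```python
# Usage sketch (the import path is an assumption about the repo layout).
import torch
from training.networks import FullyConnectedLayer

fc = FullyConnectedLayer(512, 512, activation='lrelu', lr_multiplier=0.01)
x = torch.randn([4, 512])
print(fc(x).shape)        # torch.Size([4, 512])
print(fc.weight_gain)     # 0.01 / sqrt(512): applied at runtime, not stored
```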
|  | 
 | ||||||
|  | @persistence.persistent_class | ||||||
|  | class Conv2dLayer(torch.nn.Module): | ||||||
|  |     def __init__(self, | ||||||
|  |         in_channels,                    # Number of input channels. | ||||||
|  |         out_channels,                   # Number of output channels. | ||||||
|  |         kernel_size,                    # Width and height of the convolution kernel. | ||||||
|  |         bias            = True,         # Apply additive bias before the activation function? | ||||||
|  |         activation      = 'linear',     # Activation function: 'relu', 'lrelu', etc. | ||||||
|  |         up              = 1,            # Integer upsampling factor. | ||||||
|  |         down            = 1,            # Integer downsampling factor. | ||||||
|  |         resample_filter = [1,3,3,1],    # Low-pass filter to apply when resampling activations. | ||||||
|  |         conv_clamp      = None,         # Clamp the output to +-X, None = disable clamping. | ||||||
|  |         channels_last   = False,        # Expect the input to have memory_format=channels_last? | ||||||
|  |         trainable       = True,         # Update the weights of this layer during training? | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  |         self.activation = activation | ||||||
|  |         self.up = up | ||||||
|  |         self.down = down | ||||||
|  |         self.conv_clamp = conv_clamp | ||||||
|  |         self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) | ||||||
|  |         self.padding = kernel_size // 2 | ||||||
|  |         self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) | ||||||
|  |         self.act_gain = bias_act.activation_funcs[activation].def_gain | ||||||
|  | 
 | ||||||
|  |         memory_format = torch.channels_last if channels_last else torch.contiguous_format | ||||||
|  |         weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format) | ||||||
|  |         bias = torch.zeros([out_channels]) if bias else None | ||||||
|  |         if trainable: | ||||||
|  |             self.weight = torch.nn.Parameter(weight) | ||||||
|  |             self.bias = torch.nn.Parameter(bias) if bias is not None else None | ||||||
|  |         else: | ||||||
|  |             self.register_buffer('weight', weight) | ||||||
|  |             if bias is not None: | ||||||
|  |                 self.register_buffer('bias', bias) | ||||||
|  |             else: | ||||||
|  |                 self.bias = None | ||||||
|  | 
 | ||||||
|  |     def forward(self, x, gain=1): | ||||||
|  |         w = self.weight * self.weight_gain | ||||||
|  |         b = self.bias.to(x.dtype) if self.bias is not None else None | ||||||
|  |         flip_weight = (self.up == 1) # slightly faster | ||||||
|  |         x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight) | ||||||
|  | 
 | ||||||
|  |         act_gain = self.act_gain * gain | ||||||
|  |         act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None | ||||||
|  |         x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | @persistence.persistent_class | ||||||
|  | class MappingNetwork(torch.nn.Module): | ||||||
|  |     def __init__(self, | ||||||
|  |         z_dim,                      # Input latent (Z) dimensionality, 0 = no latent. | ||||||
|  |         c_dim,                      # Conditioning label (C) dimensionality, 0 = no label. | ||||||
|  |         w_dim,                      # Intermediate latent (W) dimensionality. | ||||||
|  |         num_ws,                     # Number of intermediate latents to output, None = do not broadcast. | ||||||
|  |         num_layers      = 8,        # Number of mapping layers. | ||||||
|  |         embed_features  = None,     # Label embedding dimensionality, None = same as w_dim. | ||||||
|  |         layer_features  = None,     # Number of intermediate features in the mapping layers, None = same as w_dim. | ||||||
|  |         activation      = 'lrelu',  # Activation function: 'relu', 'lrelu', etc. | ||||||
|  |         lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers. | ||||||
|  |         w_avg_beta      = 0.995,    # Decay for tracking the moving average of W during training, None = do not track. | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  |         self.z_dim = z_dim | ||||||
|  |         self.c_dim = c_dim | ||||||
|  |         self.w_dim = w_dim | ||||||
|  |         self.num_ws = num_ws | ||||||
|  |         self.num_layers = num_layers | ||||||
|  |         self.w_avg_beta = w_avg_beta | ||||||
|  | 
 | ||||||
|  |         if embed_features is None: | ||||||
|  |             embed_features = w_dim | ||||||
|  |         if c_dim == 0: | ||||||
|  |             embed_features = 0 | ||||||
|  |         if layer_features is None: | ||||||
|  |             layer_features = w_dim | ||||||
|  |         features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] | ||||||
|  | 
 | ||||||
|  |         if c_dim > 0: | ||||||
|  |             self.embed = FullyConnectedLayer(c_dim, embed_features) | ||||||
|  |         for idx in range(num_layers): | ||||||
|  |             in_features = features_list[idx] | ||||||
|  |             out_features = features_list[idx + 1] | ||||||
|  |             layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier) | ||||||
|  |             setattr(self, f'fc{idx}', layer) | ||||||
|  | 
 | ||||||
|  |         if num_ws is not None and w_avg_beta is not None: | ||||||
|  |             self.register_buffer('w_avg', torch.zeros([w_dim])) | ||||||
|  | 
 | ||||||
|  |     def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False): | ||||||
|  |         # Embed, normalize, and concat inputs. | ||||||
|  |         x = None | ||||||
|  |         with torch.autograd.profiler.record_function('input'): | ||||||
|  |             if self.z_dim > 0: | ||||||
|  |                 misc.assert_shape(z, [None, self.z_dim]) | ||||||
|  |                 x = normalize_2nd_moment(z.to(torch.float32)) | ||||||
|  |             if self.c_dim > 0: | ||||||
|  |                 misc.assert_shape(c, [None, self.c_dim]) | ||||||
|  |                 y = normalize_2nd_moment(self.embed(c.to(torch.float32))) | ||||||
|  |                 x = torch.cat([x, y], dim=1) if x is not None else y | ||||||
|  | 
 | ||||||
|  |         # Main layers. | ||||||
|  |         for idx in range(self.num_layers): | ||||||
|  |             layer = getattr(self, f'fc{idx}') | ||||||
|  |             x = layer(x) | ||||||
|  | 
 | ||||||
|  |         # Update moving average of W. | ||||||
|  |         if self.w_avg_beta is not None and self.training and not skip_w_avg_update: | ||||||
|  |             with torch.autograd.profiler.record_function('update_w_avg'): | ||||||
|  |                 self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) | ||||||
|  | 
 | ||||||
|  |         # Broadcast. | ||||||
|  |         if self.num_ws is not None: | ||||||
|  |             with torch.autograd.profiler.record_function('broadcast'): | ||||||
|  |                 x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) | ||||||
|  | 
 | ||||||
|  |         # Apply truncation. | ||||||
|  |         if truncation_psi != 1: | ||||||
|  |             with torch.autograd.profiler.record_function('truncate'): | ||||||
|  |                 assert self.w_avg_beta is not None | ||||||
|  |                 if self.num_ws is None or truncation_cutoff is None: | ||||||
|  |                     x = self.w_avg.lerp(x, truncation_psi) | ||||||
|  |                 else: | ||||||
|  |                     x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
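The truncation step at the end of `forward()` is a plain lerp towards the tracked average latent: `psi = 1` leaves W untouched, `psi = 0` collapses every sample onto `w_avg`, and `truncation_cutoff` restricts the effect to the earliest layers. A sketch of the core operation:

```python
# Sketch of the truncation trick: lerp W towards the running average.
import torch

w_avg = torch.zeros([512])                   # tracked by MappingNetwork
ws = torch.randn([4, 18, 512])               # [batch, num_ws, w_dim]
psi, cutoff = 0.7, 8
ws[:, :cutoff] = w_avg.lerp(ws[:, :cutoff], psi)   # trade diversity for fidelity
```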
|  | 
 | ||||||
|  | @persistence.persistent_class | ||||||
|  | class SynthesisLayer(torch.nn.Module): | ||||||
|  |     def __init__(self, | ||||||
|  |         in_channels,                    # Number of input channels. | ||||||
|  |         out_channels,                   # Number of output channels. | ||||||
|  |         w_dim,                          # Intermediate latent (W) dimensionality. | ||||||
|  |         resolution,                     # Resolution of this layer. | ||||||
|  |         kernel_size     = 3,            # Convolution kernel size. | ||||||
|  |         up              = 1,            # Integer upsampling factor. | ||||||
|  |         use_noise       = True,         # Enable noise input? | ||||||
|  |         activation      = 'lrelu',      # Activation function: 'relu', 'lrelu', etc. | ||||||
|  |         resample_filter = [1,3,3,1],    # Low-pass filter to apply when resampling activations. | ||||||
|  |         conv_clamp      = None,         # Clamp the output of convolution layers to +-X, None = disable clamping. | ||||||
|  |         channels_last   = False,        # Use channels_last format for the weights? | ||||||
|  |         name            = '',          # Layer name; used as the key into encoded_styles. | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  |         self.resolution = resolution | ||||||
|  |         self.up = up | ||||||
|  |         self.use_noise = use_noise | ||||||
|  |         self.activation = activation | ||||||
|  |         self.conv_clamp = conv_clamp | ||||||
|  |         self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) | ||||||
|  |         self.padding = kernel_size // 2 | ||||||
|  |         self.act_gain = bias_act.activation_funcs[activation].def_gain | ||||||
|  |         self.name = name | ||||||
|  |         self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) | ||||||
|  |         memory_format = torch.channels_last if channels_last else torch.contiguous_format | ||||||
|  |         self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) | ||||||
|  |         if use_noise: | ||||||
|  |             self.register_buffer('noise_const', torch.randn([resolution, resolution])) | ||||||
|  |             self.noise_strength = torch.nn.Parameter(torch.zeros([])) | ||||||
|  |         self.bias = torch.nn.Parameter(torch.zeros([out_channels])) | ||||||
|  |         print(f"name:{name} Resolution: {resolution}, InC: {in_channels}, OutC:{out_channels}, w_dim: {w_dim}") | ||||||
|  | 
 | ||||||
|  |     def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1, encoded_styles=None): | ||||||
|  |         assert noise_mode in ['random', 'const', 'none'] | ||||||
|  |         in_resolution = self.resolution // self.up | ||||||
|  |         # misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution]) # disabled: the input need not be square | ||||||
|  |         if encoded_styles is None: | ||||||
|  |             styles = self.affine(w) | ||||||
|  |         else: | ||||||
|  |             styles = encoded_styles[self.name] | ||||||
|  | 
 | ||||||
|  |         noise = None | ||||||
|  |         if self.use_noise and noise_mode == 'random': | ||||||
|  |             noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength | ||||||
|  |         if self.use_noise and noise_mode == 'const': | ||||||
|  |             noise = self.noise_const * self.noise_strength | ||||||
|  | 
 | ||||||
|  |         flip_weight = (self.up == 1) # slightly faster | ||||||
|  |         x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up, | ||||||
|  |             padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv) | ||||||
|  | 
 | ||||||
|  |         act_gain = self.act_gain * gain | ||||||
|  |         act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None | ||||||
|  |         x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | @persistence.persistent_class | ||||||
|  | class ToRGBLayer(torch.nn.Module): | ||||||
|  |     def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False, name=''): | ||||||
|  |         super().__init__() | ||||||
|  |         self.conv_clamp = conv_clamp | ||||||
|  |         self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) | ||||||
|  |         memory_format = torch.channels_last if channels_last else torch.contiguous_format | ||||||
|  |         self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) | ||||||
|  |         self.bias = torch.nn.Parameter(torch.zeros([out_channels])) | ||||||
|  |         self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) | ||||||
|  |         self.name = name | ||||||
|  |         print(f"name:{name} InC: {in_channels}, OutC:{out_channels}, w_dim: {w_dim}") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     def forward(self, x, w, fused_modconv=True, encoded_styles=None): | ||||||
|  |         if encoded_styles is None: | ||||||
|  |             styles = self.affine(w) | ||||||
|  |         else: | ||||||
|  |             styles = encoded_styles[self.name] | ||||||
|  |         # weight_gain is applied after the lookup rather than folded into affine, | ||||||
|  |         # so encoded_styles holds the raw affine outputs (the style space S). | ||||||
|  |         tmp_s = styles * self.weight_gain | ||||||
|  |         x = modulated_conv2d(x=x, weight=self.weight, styles=tmp_s, demodulate=False, fused_modconv=fused_modconv) | ||||||
|  |         x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | @persistence.persistent_class | ||||||
|  | class SynthesisBlock(torch.nn.Module): | ||||||
|  |     def __init__(self, | ||||||
|  |         in_channels,                        # Number of input channels, 0 = first block. | ||||||
|  |         out_channels,                       # Number of output channels. | ||||||
|  |         w_dim,                              # Intermediate latent (W) dimensionality. | ||||||
|  |         resolution,                         # Resolution of this block. | ||||||
|  |         img_channels,                       # Number of output color channels. | ||||||
|  |         is_last,                            # Is this the last block? | ||||||
|  |         architecture        = 'skip',       # Architecture: 'orig', 'skip', 'resnet'. | ||||||
|  |         resample_filter     = [1,3,3,1],    # Low-pass filter to apply when resampling activations. | ||||||
|  |         conv_clamp          = None,         # Clamp the output of convolution layers to +-X, None = disable clamping. | ||||||
|  |         use_fp16            = False,        # Use FP16 for this block? | ||||||
|  |         fp16_channels_last  = False,        # Use channels-last memory format with FP16? | ||||||
|  |         **layer_kwargs,                     # Arguments for SynthesisLayer. | ||||||
|  |     ): | ||||||
|  |         assert architecture in ['orig', 'skip', 'resnet'] | ||||||
|  |         super().__init__() | ||||||
|  |         self.in_channels = in_channels | ||||||
|  |         self.w_dim = w_dim | ||||||
|  |         self.resolution = resolution | ||||||
|  |         self.img_channels = img_channels | ||||||
|  |         self.is_last = is_last | ||||||
|  |         self.architecture = architecture | ||||||
|  |         self.use_fp16 = use_fp16 | ||||||
|  |         self.channels_last = (use_fp16 and fp16_channels_last) | ||||||
|  |         self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) | ||||||
|  |         self.num_conv = 0 | ||||||
|  |         self.num_torgb = 0 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |         if in_channels == 0: | ||||||
|  |             self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution])) | ||||||
|  | 
 | ||||||
|  |         if in_channels != 0: | ||||||
|  |             self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2, | ||||||
|  |                 resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, name=f'conv0_resolution_{resolution}', **layer_kwargs) | ||||||
|  |             self.num_conv += 1 | ||||||
|  | 
 | ||||||
|  |         self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution, | ||||||
|  |             conv_clamp=conv_clamp, channels_last=self.channels_last, name=f'conv1_resolution_{resolution}', **layer_kwargs) | ||||||
|  |         self.num_conv += 1 | ||||||
|  | 
 | ||||||
|  |         if is_last or architecture == 'skip': | ||||||
|  |             self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim, | ||||||
|  |                 conv_clamp=conv_clamp, channels_last=self.channels_last, name=f'toRGB_resolution_{resolution}') | ||||||
|  |             self.num_torgb += 1 | ||||||
|  | 
 | ||||||
|  |         if in_channels != 0 and architecture == 'resnet': | ||||||
|  |             self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2, | ||||||
|  |                 resample_filter=resample_filter, channels_last=self.channels_last) | ||||||
|  | 
 | ||||||
|  |     def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, encoded_styles=None, **layer_kwargs): | ||||||
|  | 
 | ||||||
|  |         class NoneIter: | ||||||
|  |             # Yields None forever; per-layer styles come from encoded_styles instead of ws. | ||||||
|  |             def __iter__(self): | ||||||
|  |                 return self | ||||||
|  |             def __next__(self): | ||||||
|  |                 return None | ||||||
|  | 
 | ||||||
|  |         if encoded_styles is None: | ||||||
|  |             misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim]) | ||||||
|  |             w_iter = iter(ws.unbind(dim=1)) | ||||||
|  |         else: | ||||||
|  |             w_iter = iter(NoneIter()) | ||||||
|  | 
 | ||||||
|  |         dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 | ||||||
|  |         memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format | ||||||
|  |         if fused_modconv is None: | ||||||
|  |             with misc.suppress_tracer_warnings(): # this value will be treated as a constant | ||||||
|  |                 fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1) | ||||||
|  | 
 | ||||||
|  |         # Input. | ||||||
|  |         if self.in_channels == 0: | ||||||
|  |             x = self.const.to(dtype=dtype, memory_format=memory_format) | ||||||
|  |             if encoded_styles is None: | ||||||
|  |                 x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1]) | ||||||
|  |             else: | ||||||
|  |                 x = x.unsqueeze(0).repeat([encoded_styles['conv1_resolution_4'].shape[0], 1, 1, 1]) | ||||||
|  |         else: | ||||||
|  |             # misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2]) # disabled: the input need not be square | ||||||
|  |             x = x.to(dtype=dtype, memory_format=memory_format) | ||||||
|  | 
 | ||||||
|  |         # Main layers. | ||||||
|  |         if self.in_channels == 0: | ||||||
|  |             x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, encoded_styles=encoded_styles, **layer_kwargs) | ||||||
|  |         elif self.architecture == 'resnet': | ||||||
|  |             y = self.skip(x, gain=np.sqrt(0.5)) | ||||||
|  |             x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, encoded_styles=encoded_styles, **layer_kwargs) | ||||||
|  |             x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, encoded_styles=encoded_styles, gain=np.sqrt(0.5), **layer_kwargs) | ||||||
|  |             x = y.add_(x) | ||||||
|  |         else: | ||||||
|  |             x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, encoded_styles=encoded_styles, **layer_kwargs) | ||||||
|  |             x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, encoded_styles=encoded_styles, **layer_kwargs) | ||||||
|  | 
 | ||||||
|  |         # ToRGB. | ||||||
|  |         if img is not None: | ||||||
|  |             # misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2]) # disabled: the image need not be square | ||||||
|  |             img = upfirdn2d.upsample2d(img, self.resample_filter) | ||||||
|  |         if self.is_last or self.architecture == 'skip': | ||||||
|  |             y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv, encoded_styles=encoded_styles) | ||||||
|  |             y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format) | ||||||
|  |             img = img.add_(y) if img is not None else y | ||||||
|  | 
 | ||||||
|  |         assert x.dtype == dtype | ||||||
|  |         assert img is None or img.dtype == torch.float32 | ||||||
|  |         return x, img | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
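Each `SynthesisBlock` records how many W vectors it consumes (`num_conv`, plus `num_torgb` on the last block); `SynthesisNetwork` below sums these into `num_ws`, and its `forward()` hands each ToRGB the same w as the next block's conv0. For a 1024x1024 skip-architecture generator the count works out to the familiar 18, as a sketch:

```python
# Counting W vectors for a 1024x1024 skip-architecture generator (sketch).
resolutions = [2 ** i for i in range(2, 11)]                # 4, 8, ..., 1024
num_ws = sum(1 if res == 4 else 2 for res in resolutions)   # num_conv per block
num_ws += 1                                                 # torgb of the last block
print(num_ws)                                               # 18
```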
|  | 
 | ||||||
|  | @persistence.persistent_class | ||||||
|  | class SynthesisNetwork(torch.nn.Module): | ||||||
|  |     def __init__(self, | ||||||
|  |         w_dim,                      # Intermediate latent (W) dimensionality. | ||||||
|  |         img_resolution,             # Output image resolution. | ||||||
|  |         img_channels,               # Number of color channels. | ||||||
|  |         channel_base    = 32768,    # Overall multiplier for the number of channels. | ||||||
|  |         channel_max     = 512,      # Maximum number of channels in any layer. | ||||||
|  |         num_fp16_res    = 0,        # Use FP16 for the N highest resolutions. | ||||||
|  |         **block_kwargs,             # Arguments for SynthesisBlock. | ||||||
|  |     ): | ||||||
|  |         assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0 | ||||||
|  |         super().__init__() | ||||||
|  |         self.w_dim = w_dim | ||||||
|  |         self.img_resolution = img_resolution | ||||||
|  |         self.img_resolution_log2 = int(np.log2(img_resolution)) | ||||||
|  |         self.img_channels = img_channels | ||||||
|  |         self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)] | ||||||
|  |         channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions} | ||||||
|  |         fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) | ||||||
|  | 
 | ||||||
|  |         self.num_ws = 0 | ||||||
|  |         for res in self.block_resolutions: | ||||||
|  |             in_channels = channels_dict[res // 2] if res > 4 else 0 | ||||||
|  |             out_channels = channels_dict[res] | ||||||
|  |             use_fp16 = (res >= fp16_resolution) | ||||||
|  |             is_last = (res == self.img_resolution) | ||||||
|  |             block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res, | ||||||
|  |                 img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs) | ||||||
|  |             self.num_ws += block.num_conv | ||||||
|  |             if is_last: | ||||||
|  |                 self.num_ws += block.num_torgb | ||||||
|  |             setattr(self, f'b{res}', block) | ||||||
|  | 
 | ||||||
|  |     def forward(self, ws, encoded_styles=None, **block_kwargs): | ||||||
|  |         if encoded_styles is None: | ||||||
|  |             block_ws = [] | ||||||
|  |             with torch.autograd.profiler.record_function('split_ws'): | ||||||
|  |                 misc.assert_shape(ws, [None, self.num_ws, self.w_dim]) | ||||||
|  |                 ws = ws.to(torch.float32) | ||||||
|  |                 w_idx = 0 | ||||||
|  |                 for res in self.block_resolutions: | ||||||
|  |                     block = getattr(self, f'b{res}') | ||||||
|  |                     block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb)) | ||||||
|  |                     w_idx += block.num_conv | ||||||
|  | 
 | ||||||
|  |             x = img = None | ||||||
|  |             for res, cur_ws in zip(self.block_resolutions, block_ws): | ||||||
|  |                 block = getattr(self, f'b{res}') | ||||||
|  |                 x, img = block(x, img, cur_ws, encoded_styles=encoded_styles, **block_kwargs) | ||||||
|  |         else: | ||||||
|  |             x = img = None | ||||||
|  |             for res in self.block_resolutions: | ||||||
|  |                 block = getattr(self, f'b{res}') | ||||||
|  |                 x, img = block(x, img, None, encoded_styles=encoded_styles, **block_kwargs) | ||||||
|  |         return img | ||||||
|  |  | ||||||
|  |     def W2S(self, ws): | ||||||
|  |         # Convert W+ latents of shape [N, num_ws, w_dim] into style space S: | ||||||
|  |         # one affine output per layer, keyed by layer name. ToRGB codes are | ||||||
|  |         # stored without weight_gain, which ToRGBLayer.forward() applies later. | ||||||
|  |         i = 0 | ||||||
|  |         encoded_styles = {} | ||||||
|  |         for res in self.block_resolutions: | ||||||
|  |             block = getattr(self, f'b{res}') | ||||||
|  |             if res == 4: | ||||||
|  |                 encoded_styles[f'conv1_resolution_{res}'] = block.conv1.affine(ws[:, i]) | ||||||
|  |                 i += 1 | ||||||
|  |                 encoded_styles[f'toRGB_resolution_{res}'] = block.torgb.affine(ws[:, i]) | ||||||
|  |             else: | ||||||
|  |                 encoded_styles[f'conv0_resolution_{res}'] = block.conv0.affine(ws[:, i]) | ||||||
|  |                 i += 1 | ||||||
|  |                 encoded_styles[f'conv1_resolution_{res}'] = block.conv1.affine(ws[:, i]) | ||||||
|  |                 i += 1 | ||||||
|  |                 # toRGB and the next block's conv0 use the same w, so i is not advanced here. | ||||||
|  |                 encoded_styles[f'toRGB_resolution_{res}'] = block.torgb.affine(ws[:, i]) | ||||||
|  |         return encoded_styles | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
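`W2S()` is the hook that StyleCLIP's global directions build on: it maps a batch of W+ latents into the style space S, keyed by layer name, which can be edited channel-wise and fed back through `encoded_styles`. A usage sketch, assuming `G` is an already-loaded `Generator`:

```python
# Sketch: W+ -> S -> image, assuming `G` is a loaded Generator instance.
import torch

z = torch.randn([1, G.z_dim])
ws = G.mapping(z, None)                          # [1, num_ws, w_dim]
encoded_styles = G.synthesis.W2S(ws)             # {layer name: style code}
# Keys look like 'conv1_resolution_4', 'toRGB_resolution_4', 'conv0_resolution_8', ...
encoded_styles['conv1_resolution_8'][:, 0] += 5  # hypothetical channel edit
img = G.synthesis(None, encoded_styles=encoded_styles)
```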
|  | 
 | ||||||
|  | @persistence.persistent_class | ||||||
|  | class Generator(torch.nn.Module): | ||||||
|  |     def __init__(self, | ||||||
|  |         z_dim,                      # Input latent (Z) dimensionality. | ||||||
|  |         c_dim,                      # Conditioning label (C) dimensionality. | ||||||
|  |         w_dim,                      # Intermediate latent (W) dimensionality. | ||||||
|  |         img_resolution,             # Output resolution. | ||||||
|  |         img_channels,               # Number of output color channels. | ||||||
|  |         mapping_kwargs      = {},   # Arguments for MappingNetwork. | ||||||
|  |         synthesis_kwargs    = {},   # Arguments for SynthesisNetwork. | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  |         self.z_dim = z_dim | ||||||
|  |         self.c_dim = c_dim | ||||||
|  |         self.w_dim = w_dim | ||||||
|  |         self.img_resolution = img_resolution | ||||||
|  |         self.img_channels = img_channels | ||||||
|  |         self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs) | ||||||
|  |         self.num_ws = self.synthesis.num_ws | ||||||
|  |         self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs) | ||||||
|  | 
 | ||||||
|  |     def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, encoded_styles=None, **synthesis_kwargs): | ||||||
|  |         if encoded_styles is None: | ||||||
|  |             ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff) | ||||||
|  |         else: | ||||||
|  |             ws = None | ||||||
|  |         img = self.synthesis(ws, encoded_styles=encoded_styles, **synthesis_kwargs) | ||||||
|  |         return img | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
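End-to-end, the generator composes the mapping and synthesis networks. A construction-and-sampling sketch with dimensions matching the common FFHQ configuration (the import path is an assumption):

```python
# Sketch: build and sample the generator (FFHQ-like dimensions).
import torch
from training.networks import Generator

G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=1024, img_channels=3)
z = torch.randn([2, G.z_dim])
img = G(z, None, truncation_psi=0.7)   # [2, 3, 1024, 1024], nominally in [-1, 1]
```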
|  | 
 | ||||||
|  | @persistence.persistent_class | ||||||
|  | class DiscriminatorBlock(torch.nn.Module): | ||||||
|  |     def __init__(self, | ||||||
|  |         in_channels,                        # Number of input channels, 0 = first block. | ||||||
|  |         tmp_channels,                       # Number of intermediate channels. | ||||||
|  |         out_channels,                       # Number of output channels. | ||||||
|  |         resolution,                         # Resolution of this block. | ||||||
|  |         img_channels,                       # Number of input color channels. | ||||||
|  |         first_layer_idx,                    # Index of the first layer. | ||||||
|  |         architecture        = 'resnet',     # Architecture: 'orig', 'skip', 'resnet'. | ||||||
|  |         activation          = 'lrelu',      # Activation function: 'relu', 'lrelu', etc. | ||||||
|  |         resample_filter     = [1,3,3,1],    # Low-pass filter to apply when resampling activations. | ||||||
|  |         conv_clamp          = None,         # Clamp the output of convolution layers to +-X, None = disable clamping. | ||||||
|  |         use_fp16            = False,        # Use FP16 for this block? | ||||||
|  |         fp16_channels_last  = False,        # Use channels-last memory format with FP16? | ||||||
|  |         freeze_layers       = 0,            # Freeze-D: Number of layers to freeze. | ||||||
|  |     ): | ||||||
|  |         assert in_channels in [0, tmp_channels] | ||||||
|  |         assert architecture in ['orig', 'skip', 'resnet'] | ||||||
|  |         super().__init__() | ||||||
|  |         self.in_channels = in_channels | ||||||
|  |         self.resolution = resolution | ||||||
|  |         self.img_channels = img_channels | ||||||
|  |         self.first_layer_idx = first_layer_idx | ||||||
|  |         self.architecture = architecture | ||||||
|  |         self.use_fp16 = use_fp16 | ||||||
|  |         self.channels_last = (use_fp16 and fp16_channels_last) | ||||||
|  |         self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) | ||||||
|  | 
 | ||||||
|  |         self.num_layers = 0 | ||||||
|  |         def trainable_gen(): | ||||||
|  |             while True: | ||||||
|  |                 layer_idx = self.first_layer_idx + self.num_layers | ||||||
|  |                 trainable = (layer_idx >= freeze_layers) | ||||||
|  |                 self.num_layers += 1 | ||||||
|  |                 yield trainable | ||||||
|  |         trainable_iter = trainable_gen() | ||||||
|  | 
 | ||||||
|  |         if in_channels == 0 or architecture == 'skip': | ||||||
|  |             self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation, | ||||||
|  |                 trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last) | ||||||
|  | 
 | ||||||
|  |         self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation, | ||||||
|  |             trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last) | ||||||
|  | 
 | ||||||
|  |         self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2, | ||||||
|  |             trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last) | ||||||
|  | 
 | ||||||
|  |         if architecture == 'resnet': | ||||||
|  |             self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2, | ||||||
|  |                 trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last) | ||||||
|  | 
 | ||||||
|  |     def forward(self, x, img, force_fp32=False): | ||||||
|  |         dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 | ||||||
|  |         memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format | ||||||
|  | 
 | ||||||
|  |         # Input. | ||||||
|  |         if x is not None: | ||||||
|  |             misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) | ||||||
|  |             x = x.to(dtype=dtype, memory_format=memory_format) | ||||||
|  | 
 | ||||||
|  |         # FromRGB. | ||||||
|  |         if self.in_channels == 0 or self.architecture == 'skip': | ||||||
|  |             misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution]) | ||||||
|  |             img = img.to(dtype=dtype, memory_format=memory_format) | ||||||
|  |             y = self.fromrgb(img) | ||||||
|  |             x = x + y if x is not None else y | ||||||
|  |             img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None | ||||||
|  | 
 | ||||||
|  |         # Main layers. | ||||||
|  |         if self.architecture == 'resnet': | ||||||
|  |             y = self.skip(x, gain=np.sqrt(0.5)) | ||||||
|  |             x = self.conv0(x) | ||||||
|  |             x = self.conv1(x, gain=np.sqrt(0.5)) | ||||||
|  |             x = y.add_(x) | ||||||
|  |         else: | ||||||
|  |             x = self.conv0(x) | ||||||
|  |             x = self.conv1(x) | ||||||
|  | 
 | ||||||
|  |         assert x.dtype == dtype | ||||||
|  |         return x, img | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | @persistence.persistent_class | ||||||
|  | class MinibatchStdLayer(torch.nn.Module): | ||||||
|  |     def __init__(self, group_size, num_channels=1): | ||||||
|  |         super().__init__() | ||||||
|  |         self.group_size = group_size | ||||||
|  |         self.num_channels = num_channels | ||||||
|  | 
 | ||||||
|  |     def forward(self, x): | ||||||
|  |         N, C, H, W = x.shape | ||||||
|  |         with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants | ||||||
|  |             G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N | ||||||
|  |         F = self.num_channels | ||||||
|  |         c = C // F | ||||||
|  | 
 | ||||||
|  |         y = x.reshape(G, -1, F, c, H, W)    # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c. | ||||||
|  |         y = y - y.mean(dim=0)               # [GnFcHW] Subtract mean over group. | ||||||
|  |         y = y.square().mean(dim=0)          # [nFcHW]  Calc variance over group. | ||||||
|  |         y = (y + 1e-8).sqrt()               # [nFcHW]  Calc stddev over group. | ||||||
|  |         y = y.mean(dim=[2,3,4])             # [nF]     Take average over channels and pixels. | ||||||
|  |         y = y.reshape(-1, F, 1, 1)          # [nF11]   Add missing dimensions. | ||||||
|  |         y = y.repeat(G, 1, H, W)            # [NFHW]   Replicate over group and pixels. | ||||||
|  |         x = torch.cat([x, y], dim=1)        # [NCHW]   Append to input as new channels. | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
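The appended stddev channel is constant over each group and measures how much samples differ, which lets the discriminator penalize low-variety (mode-collapsed) batches. The effect is easy to see on a toy batch of identical samples:

```python
# Sketch: identical samples produce a ~zero minibatch-stddev channel.
import torch
from training.networks import MinibatchStdLayer

mbstd = MinibatchStdLayer(group_size=4)
x = torch.randn([1, 8, 4, 4]).repeat(4, 1, 1, 1)   # four identical samples
y = mbstd(x)
print(y.shape)                # torch.Size([4, 9, 4, 4]): one extra channel
print(y[:, 8].abs().max())    # ~1e-4 (just the eps), since the group has no variance
```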
|  | 
 | ||||||
|  | @persistence.persistent_class | ||||||
|  | class DiscriminatorEpilogue(torch.nn.Module): | ||||||
|  |     def __init__(self, | ||||||
|  |         in_channels,                    # Number of input channels. | ||||||
|  |         cmap_dim,                       # Dimensionality of mapped conditioning label, 0 = no label. | ||||||
|  |         resolution,                     # Resolution of this block. | ||||||
|  |         img_channels,                   # Number of input color channels. | ||||||
|  |         architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. | ||||||
|  |         mbstd_group_size    = 4,        # Group size for the minibatch standard deviation layer, None = entire minibatch. | ||||||
|  |         mbstd_num_channels  = 1,        # Number of features for the minibatch standard deviation layer, 0 = disable. | ||||||
|  |         activation          = 'lrelu',  # Activation function: 'relu', 'lrelu', etc. | ||||||
|  |         conv_clamp          = None,     # Clamp the output of convolution layers to +-X, None = disable clamping. | ||||||
|  |     ): | ||||||
|  |         assert architecture in ['orig', 'skip', 'resnet'] | ||||||
|  |         super().__init__() | ||||||
|  |         self.in_channels = in_channels | ||||||
|  |         self.cmap_dim = cmap_dim | ||||||
|  |         self.resolution = resolution | ||||||
|  |         self.img_channels = img_channels | ||||||
|  |         self.architecture = architecture | ||||||
|  | 
 | ||||||
|  |         if architecture == 'skip': | ||||||
|  |             self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation) | ||||||
|  |         self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None | ||||||
|  |         self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp) | ||||||
|  |         self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation) | ||||||
|  |         self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim) | ||||||
|  | 
 | ||||||
|  |     def forward(self, x, img, cmap, force_fp32=False): | ||||||
|  |         misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW] | ||||||
|  |         _ = force_fp32 # unused | ||||||
|  |         dtype = torch.float32 | ||||||
|  |         memory_format = torch.contiguous_format | ||||||
|  | 
 | ||||||
|  |         # FromRGB. | ||||||
|  |         x = x.to(dtype=dtype, memory_format=memory_format) | ||||||
|  |         if self.architecture == 'skip': | ||||||
|  |             misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution]) | ||||||
|  |             img = img.to(dtype=dtype, memory_format=memory_format) | ||||||
|  |             x = x + self.fromrgb(img) | ||||||
|  | 
 | ||||||
|  |         # Main layers. | ||||||
|  |         if self.mbstd is not None: | ||||||
|  |             x = self.mbstd(x) | ||||||
|  |         x = self.conv(x) | ||||||
|  |         x = self.fc(x.flatten(1)) | ||||||
|  |         x = self.out(x) | ||||||
|  | 
 | ||||||
|  |         # Conditioning. | ||||||
|  |         if self.cmap_dim > 0: | ||||||
|  |             misc.assert_shape(cmap, [None, self.cmap_dim]) | ||||||
|  |             x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) | ||||||
|  | 
 | ||||||
|  |         assert x.dtype == dtype | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | @persistence.persistent_class | ||||||
|  | class Discriminator(torch.nn.Module): | ||||||
|  |     def __init__(self, | ||||||
|  |         c_dim,                          # Conditioning label (C) dimensionality. | ||||||
|  |         img_resolution,                 # Input resolution. | ||||||
|  |         img_channels,                   # Number of input color channels. | ||||||
|  |         architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. | ||||||
|  |         channel_base        = 32768,    # Overall multiplier for the number of channels. | ||||||
|  |         channel_max         = 512,      # Maximum number of channels in any layer. | ||||||
|  |         num_fp16_res        = 0,        # Use FP16 for the N highest resolutions. | ||||||
|  |         conv_clamp          = None,     # Clamp the output of convolution layers to +-X, None = disable clamping. | ||||||
|  |         cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default. | ||||||
|  |         block_kwargs        = {},       # Arguments for DiscriminatorBlock. | ||||||
|  |         mapping_kwargs      = {},       # Arguments for MappingNetwork. | ||||||
|  |         epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue. | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  |         self.c_dim = c_dim | ||||||
|  |         self.img_resolution = img_resolution | ||||||
|  |         self.img_resolution_log2 = int(np.log2(img_resolution)) | ||||||
|  |         self.img_channels = img_channels | ||||||
|  |         self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] | ||||||
|  |         channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} | ||||||
|  |         fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) | ||||||
|  | 
 | ||||||
|  |         if cmap_dim is None: | ||||||
|  |             cmap_dim = channels_dict[4] | ||||||
|  |         if c_dim == 0: | ||||||
|  |             cmap_dim = 0 | ||||||
|  | 
 | ||||||
|  |         common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp) | ||||||
|  |         cur_layer_idx = 0 | ||||||
|  |         for res in self.block_resolutions: | ||||||
|  |             in_channels = channels_dict[res] if res < img_resolution else 0 | ||||||
|  |             tmp_channels = channels_dict[res] | ||||||
|  |             out_channels = channels_dict[res // 2] | ||||||
|  |             use_fp16 = (res >= fp16_resolution) | ||||||
|  |             block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, | ||||||
|  |                 first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs) | ||||||
|  |             setattr(self, f'b{res}', block) | ||||||
|  |             cur_layer_idx += block.num_layers | ||||||
|  |         if c_dim > 0: | ||||||
|  |             self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs) | ||||||
|  |         self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs) | ||||||
|  | 
 | ||||||
|  |     def forward(self, img, c, **block_kwargs): | ||||||
|  |         x = None | ||||||
|  |         for res in self.block_resolutions: | ||||||
|  |             block = getattr(self, f'b{res}') | ||||||
|  |             x, img = block(x, img, **block_kwargs) | ||||||
|  | 
 | ||||||
|  |         cmap = None | ||||||
|  |         if self.c_dim > 0: | ||||||
|  |             cmap = self.mapping(None, c) | ||||||
|  |         x = self.b4(x, img, cmap) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
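|  | # Usage sketch (not part of the original file): an unconditional 1024x1024 RGB | ||||||
|  | # discriminator maps a batch of images to one realness score per sample; the | ||||||
|  | # constructor arguments mirror the signature above. | ||||||
|  | #   D = Discriminator(c_dim=0, img_resolution=1024, img_channels=3) | ||||||
|  | #   scores = D(torch.randn(4, 3, 1024, 1024), c=None)  # -> shape [4, 1] | ||||||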
|  | #---------------------------------------------------------------------------- | ||||||
605 global_torch/visualizer.py Normal file
							| @ -0,0 +1,605 @@ | |||||||
|  | # python 3.7 | ||||||
|  | """Utility functions for visualizing results on html page.""" | ||||||
|  | 
 | ||||||
|  | import base64 | ||||||
|  | import os.path | ||||||
|  | import cv2 | ||||||
|  | import numpy as np | ||||||
|  | 
 | ||||||
|  | __all__ = [ | ||||||
|  |     'get_grid_shape', 'get_blank_image', 'load_image', 'save_image', | ||||||
|  |     'resize_image', 'add_text_to_image', 'fuse_images', 'HtmlPageVisualizer', | ||||||
|  |     'VideoReader', 'VideoWriter', 'adjust_pixel_range' | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def adjust_pixel_range(images, min_val=-1.0, max_val=1.0, channel_order='NCHW'): | ||||||
|  |   """Adjusts the pixel range of the input images. | ||||||
|  | 
 | ||||||
|  |   This function assumes the input array (image batch) is with shape [batch_size, | ||||||
|  |   channel, height, width] if `channel_order = NCHW`, or with shape [batch_size, | ||||||
|  |   height, width, channel] if `channel_order = NHWC`. The returned images are | ||||||
|  |   with shape [batch_size, height, width, channel] and pixel range [0, 255]. | ||||||
|  | 
|  |   NOTE: The color channel order (e.g. RGB vs. BGR) of the output images remains | ||||||
|  |   the same as that of the input. | ||||||
|  | 
 | ||||||
|  |   Args: | ||||||
|  |     images: Input images to adjust pixel range. | ||||||
|  |     min_val: Min value of the input images. (default: -1.0) | ||||||
|  |     max_val: Max value of the input images. (default: 1.0) | ||||||
|  |     channel_order: Channel order of the input array. (default: NCHW) | ||||||
|  | 
 | ||||||
|  |   Returns: | ||||||
|  |     The postprocessed images with dtype `numpy.uint8` and range [0, 255]. | ||||||
|  | 
 | ||||||
|  |   Raises: | ||||||
|  |     ValueError: If the input `images` are not with type `numpy.ndarray` or the | ||||||
|  |       shape is invalid according to `channel_order`. | ||||||
|  |   """ | ||||||
|  |   if not isinstance(images, np.ndarray): | ||||||
|  |     raise ValueError('Input `images` should be of type `numpy.ndarray`!') | ||||||
|  | 
 | ||||||
|  |   channel_order = channel_order.upper() | ||||||
|  |   if channel_order not in ['NCHW', 'NHWC']: | ||||||
|  |     raise ValueError(f'Invalid channel order `{channel_order}`!') | ||||||
|  | 
 | ||||||
|  |   if images.ndim != 4: | ||||||
|  |     raise ValueError(f'Input images are expected to be with shape `NCHW` or ' | ||||||
|  |                      f'`NHWC`, but `{images.shape}` is received!') | ||||||
|  |   if channel_order == 'NCHW' and images.shape[1] not in [1, 3]: | ||||||
|  |     raise ValueError(f'Input images should have 1 or 3 channels under `NCHW` ' | ||||||
|  |                      f'channel order!') | ||||||
|  |   if channel_order == 'NHWC' and images.shape[3] not in [1, 3]: | ||||||
|  |     raise ValueError(f'Input images should have 1 or 3 channels under `NHWC` ' | ||||||
|  |                      f'channel order!') | ||||||
|  | 
 | ||||||
|  |   images = images.astype(np.float32) | ||||||
|  |   images = (images - min_val) * 255 / (max_val - min_val) | ||||||
|  |   images = np.clip(images + 0.5, 0, 255).astype(np.uint8) | ||||||
|  |   if channel_order == 'NCHW': | ||||||
|  |     images = images.transpose(0, 2, 3, 1) | ||||||
|  | 
 | ||||||
|  |   return images | ||||||
|  | 
 | ||||||
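|  | # Example sketch: `fake` stands in for a hypothetical generator output in NCHW | ||||||
|  | # layout with values in [-1, 1]. | ||||||
|  | #   fake = np.random.uniform(-1.0, 1.0, size=(4, 3, 256, 256)) | ||||||
|  | #   images = adjust_pixel_range(fake)  # -> shape (4, 256, 256, 3), dtype uint8 | ||||||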
|  | 
 | ||||||
|  | def get_grid_shape(size, row=0, col=0, is_portrait=False): | ||||||
|  |   """Gets the shape of a grid based on the size. | ||||||
|  | 
 | ||||||
|  |   This function makes its best effort to make the output grid close to square | ||||||
|  |   if neither `row` nor `col` is set. If `is_portrait` is set to `False`, the | ||||||
|  |   height will always be equal to or smaller than the width. For example, if | ||||||
|  |   input `size = 16`, the output shape will be `(4, 4)`; if input `size = 15`, | ||||||
|  |   the output shape will be `(3, 5)`. Otherwise, the height will always be equal | ||||||
|  |   to or larger than the width. | ||||||
|  | 
|  |   Args: | ||||||
|  |     size: Size (height * width) of the target grid. | ||||||
|  |     row: Number of rows, or `0` to assign automatically. (default: 0) | ||||||
|  |     col: Number of columns, or `0` to assign automatically. (default: 0) | ||||||
|  |     is_portrait: Whether to return a portrait size or a landscape size. | ||||||
|  |       (default: False) | ||||||
|  | 
 | ||||||
|  |   Returns: | ||||||
|  |     A two-element tuple, representing height and width respectively. | ||||||
|  |   """ | ||||||
|  |   assert isinstance(size, int) | ||||||
|  |   assert isinstance(row, int) | ||||||
|  |   assert isinstance(col, int) | ||||||
|  |   if size == 0: | ||||||
|  |     return (0, 0) | ||||||
|  | 
 | ||||||
|  |   if row > 0 and col > 0 and row * col != size: | ||||||
|  |     row = 0 | ||||||
|  |     col = 0 | ||||||
|  | 
 | ||||||
|  |   if row > 0 and size % row == 0: | ||||||
|  |     return (row, size // row) | ||||||
|  |   if col > 0 and size % col == 0: | ||||||
|  |     return (size // col, col) | ||||||
|  | 
 | ||||||
|  |   row = int(np.sqrt(size)) | ||||||
|  |   while row > 0: | ||||||
|  |     if size % row == 0: | ||||||
|  |       col = size // row | ||||||
|  |       break | ||||||
|  |     row = row - 1 | ||||||
|  | 
 | ||||||
|  |   return (col, row) if is_portrait else (row, col) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_blank_image(height, width, channels=3, is_black=True): | ||||||
|  |   """Gets a blank image, either white of black. | ||||||
|  | 
 | ||||||
|  |   NOTE: This function will always return an image with `RGB` channel order for | ||||||
|  |   color image and pixel range [0, 255]. | ||||||
|  | 
 | ||||||
|  |   Args: | ||||||
|  |     height: Height of the returned image. | ||||||
|  |     width: Width of the returned image. | ||||||
|  |     channels: Number of channels. (default: 3) | ||||||
|  |     is_black: Whether to return a black image or white image. (default: True) | ||||||
|  |   """ | ||||||
|  |   shape = (height, width, channels) | ||||||
|  |   if is_black: | ||||||
|  |     return np.zeros(shape, dtype=np.uint8) | ||||||
|  |   return np.ones(shape, dtype=np.uint8) * 255 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def load_image(path): | ||||||
|  |   """Loads an image from disk. | ||||||
|  | 
 | ||||||
|  |   NOTE: This function will always return an image with `RGB` channel order for | ||||||
|  |   color image and pixel range [0, 255]. | ||||||
|  | 
 | ||||||
|  |   Args: | ||||||
|  |     path: Path to load the image from. | ||||||
|  | 
 | ||||||
|  |   Returns: | ||||||
|  |     The loaded image as a `numpy.ndarray`, or `None` if input `path` does not | ||||||
|  |     exist. | ||||||
|  |   """ | ||||||
|  |   if not os.path.isfile(path): | ||||||
|  |     return None | ||||||
|  | 
 | ||||||
|  |   image = cv2.imread(path) | ||||||
|  |   return image[:, :, ::-1] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def save_image(path, image): | ||||||
|  |   """Saves an image to disk. | ||||||
|  | 
 | ||||||
|  |   NOTE: The input image (if colorful) is assumed to be with `RGB` channel order | ||||||
|  |   and pixel range [0, 255]. | ||||||
|  | 
 | ||||||
|  |   Args: | ||||||
|  |     path: Path to save the image to. | ||||||
|  |     image: Image to save. | ||||||
|  |   """ | ||||||
|  |   if image is None: | ||||||
|  |     return | ||||||
|  | 
 | ||||||
|  |   assert len(image.shape) == 3 and image.shape[2] in [1, 3] | ||||||
|  |   cv2.imwrite(path, image[:, :, ::-1]) | ||||||
|  | 
 | ||||||
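|  | # Round-trip sketch ('face.png' is a hypothetical path): `load_image()` returns | ||||||
|  | # RGB and `save_image()` expects RGB; both convert to/from OpenCV's BGR | ||||||
|  | # internally. | ||||||
|  | #   image = load_image('face.png') | ||||||
|  | #   save_image('face_copy.png', image) | ||||||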
|  | 
 | ||||||
|  | def resize_image(image, *args, **kwargs): | ||||||
|  |   """Resizes image. | ||||||
|  | 
 | ||||||
|  |   This is a wrap of `cv2.resize()`. | ||||||
|  | 
 | ||||||
|  |   NOTE: The channel order of the input image will not be changed. | ||||||
|  | 
 | ||||||
|  |   Args: | ||||||
|  |     image: Image to resize. | ||||||
|  |   """ | ||||||
|  |   if image is None: | ||||||
|  |     return None | ||||||
|  | 
 | ||||||
|  |   assert image.ndim == 3 and image.shape[2] in [1, 3] | ||||||
|  |   image = cv2.resize(image, *args, **kwargs) | ||||||
|  |   if image.ndim == 2: | ||||||
|  |     return image[:, :, np.newaxis] | ||||||
|  |   return image | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def add_text_to_image(image, | ||||||
|  |                       text='', | ||||||
|  |                       position=None, | ||||||
|  |                       font=cv2.FONT_HERSHEY_TRIPLEX, | ||||||
|  |                       font_size=1.0, | ||||||
|  |                       line_type=cv2.LINE_8, | ||||||
|  |                       line_width=1, | ||||||
|  |                       color=(255, 255, 255)): | ||||||
|  |   """Overlays text on given image. | ||||||
|  | 
 | ||||||
|  |   NOTE: The input image is assumed to be with `RGB` channel order. | ||||||
|  | 
 | ||||||
|  |   Args: | ||||||
|  |     image: The image to overlay text on. | ||||||
|  |     text: Text content to overlay on the image. (default: '') | ||||||
|  |     position: Target position (bottom-left corner) to add text. If not set, | ||||||
|  |       center of the image will be used by default. (default: None) | ||||||
|  |     font: Font of the text added. (default: cv2.FONT_HERSHEY_TRIPLEX) | ||||||
|  |     font_size: Font size of the text added. (default: 1.0) | ||||||
|  |     line_type: Line type used to depict the text. (default: cv2.LINE_8) | ||||||
|  |     line_width: Line width used to depict the text. (default: 1) | ||||||
|  |     color: Color of the text added in `RGB` channel order. (default: | ||||||
|  |       (255, 255, 255)) | ||||||
|  | 
 | ||||||
|  |   Returns: | ||||||
|  |     The image with the target text overlaid. | ||||||
|  |   """ | ||||||
|  |   if image is None or not text: | ||||||
|  |     return image | ||||||
|  | 
 | ||||||
|  |   cv2.putText(img=image, | ||||||
|  |               text=text, | ||||||
|  |               org=position, | ||||||
|  |               fontFace=font, | ||||||
|  |               fontScale=font_size, | ||||||
|  |               color=color, | ||||||
|  |               thickness=line_width, | ||||||
|  |               lineType=line_type, | ||||||
|  |               bottomLeftOrigin=False) | ||||||
|  | 
 | ||||||
|  |   return image | ||||||
|  | 
 | ||||||
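|  | # Example sketch: stamp a caption near the top-left corner of an RGB image | ||||||
|  | # (`image` as returned by `load_image()` above). | ||||||
|  | #   image = add_text_to_image(image, text='sample 0', position=(10, 30)) | ||||||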
|  | 
 | ||||||
|  | def fuse_images(images, | ||||||
|  |                 image_size=None, | ||||||
|  |                 row=0, | ||||||
|  |                 col=0, | ||||||
|  |                 is_row_major=True, | ||||||
|  |                 is_portrait=False, | ||||||
|  |                 row_spacing=0, | ||||||
|  |                 col_spacing=0, | ||||||
|  |                 border_left=0, | ||||||
|  |                 border_right=0, | ||||||
|  |                 border_top=0, | ||||||
|  |                 border_bottom=0, | ||||||
|  |                 black_background=True): | ||||||
|  |   """Fuses a collection of images into an entire image. | ||||||
|  | 
 | ||||||
|  |   Args: | ||||||
|  |     images: A collection of images to fuse. Should be with shape [num, height, | ||||||
|  |       width, channels]. | ||||||
|  |     image_size: Int or two-element tuple. This field is used to resize the image | ||||||
|  |       before fusing. `None` disables resizing. (default: None) | ||||||
|  |     row: Number of rows used for image fusion. If not set, this field will be | ||||||
|  |       automatically assigned based on `col` and total number of images. | ||||||
|  |       (default: 0) | ||||||
|  |     col: Number of columns used for image fusion. If not set, this field will be | ||||||
|  |       automatically assigned based on `row` and total number of images. | ||||||
|  |       (default: 0) | ||||||
|  |     is_row_major: Whether the input images should be arranged row-major or | ||||||
|  |       column-major. (default: True) | ||||||
|  |     is_portrait: Only active when both `row` and `col` should be assigned | ||||||
|  |       automatically. (default: False) | ||||||
|  |     row_spacing: Space between rows. (default: 0) | ||||||
|  |     col_spacing: Space between columns. (default: 0) | ||||||
|  |     border_left: Width of left border. (default: 0) | ||||||
|  |     border_right: Width of right border. (default: 0) | ||||||
|  |     border_top: Width of top border. (default: 0) | ||||||
|  |     border_bottom: Width of bottom border. (default: 0) | ||||||
|  | 
 | ||||||
|  |   Returns: | ||||||
|  |     The fused image. | ||||||
|  | 
 | ||||||
|  |   Raises: | ||||||
|  |     ValueError: If the input `images` is not with shape [num, height, width, | ||||||
|  |       channels]. | ||||||
|  |   """ | ||||||
|  |   if images is None: | ||||||
|  |     return images | ||||||
|  | 
 | ||||||
|  |   if not images.ndim == 4: | ||||||
|  |     raise ValueError(f'Input `images` should be with shape [num, height, ' | ||||||
|  |                      f'width, channels], but {images.shape} is received!') | ||||||
|  | 
 | ||||||
|  |   num, image_height, image_width, channels = images.shape | ||||||
|  |   if image_size is not None: | ||||||
|  |     if isinstance(image_size, int): | ||||||
|  |       image_size = (image_size, image_size) | ||||||
|  |     assert isinstance(image_size, (list, tuple)) and len(image_size) == 2 | ||||||
|  |     width, height = image_size | ||||||
|  |   else: | ||||||
|  |     height, width = image_height, image_width | ||||||
|  |   row, col = get_grid_shape(num, row=row, col=col, is_portrait=is_portrait) | ||||||
|  |   fused_height = ( | ||||||
|  |       height * row + row_spacing * (row - 1) + border_top + border_bottom) | ||||||
|  |   fused_width = ( | ||||||
|  |       width * col + col_spacing * (col - 1) + border_left + border_right) | ||||||
|  |   fused_image = get_blank_image( | ||||||
|  |       fused_height, fused_width, channels=channels, is_black=black_background) | ||||||
|  |   if is_row_major: | ||||||
|  |     images = images.reshape(row, col, image_height, image_width, channels) | ||||||
|  |   else: | ||||||
|  |     # Column-major: reshape to [col, row, ...] first, then swap the grid axes, | ||||||
|  |     # so the indexing below stays valid when `row != col`. | ||||||
|  |     images = images.reshape(col, row, image_height, image_width, channels) | ||||||
|  |     images = images.transpose(1, 0, 2, 3, 4) | ||||||
|  | 
 | ||||||
|  |   for i in range(row): | ||||||
|  |     y = border_top + i * (height + row_spacing) | ||||||
|  |     for j in range(col): | ||||||
|  |       x = border_left + j * (width + col_spacing) | ||||||
|  |       if image_size is not None: | ||||||
|  |         image = cv2.resize(images[i, j], image_size) | ||||||
|  |       else: | ||||||
|  |         image = images[i, j] | ||||||
|  |       fused_image[y:y + height, x:x + width] = image | ||||||
|  | 
 | ||||||
|  |   return fused_image | ||||||
|  | 
 | ||||||
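|  | # Example sketch: tile 16 images into a 4x4 grid with 2-pixel spacing | ||||||
|  | # (`images` is any uint8 array of shape [16, height, width, 3]). | ||||||
|  | #   grid = fuse_images(images, row=4, col=4, row_spacing=2, col_spacing=2) | ||||||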
|  | 
 | ||||||
|  | def get_sortable_html_header(column_name_list, sort_by_ascending=False): | ||||||
|  |   """Gets header for sortable html page. | ||||||
|  | 
 | ||||||
|  |   Basically, the html page contains a sortable table, where user can sort the | ||||||
|  |   rows by a particular column by clicking the column head. | ||||||
|  | 
 | ||||||
|  |   Example: | ||||||
|  | 
 | ||||||
|  |   column_name_list = [name_1, name_2, name_3] | ||||||
|  |   header = get_sortable_html_header(column_name_list) | ||||||
|  |   footer = get_sortable_html_footer() | ||||||
|  |   sortable_table = ... | ||||||
|  |   html_page = header + sortable_table + footer | ||||||
|  | 
 | ||||||
|  |   Args: | ||||||
|  |     column_name_list: List of column header names. | ||||||
|  |     sort_by_ascending: Default sorting order. If set as `True`, the html page | ||||||
|  |       will be sorted by ascending order when the header is clicked for the first | ||||||
|  |       time. | ||||||
|  | 
 | ||||||
|  |   Returns: | ||||||
|  |     A string representing the header of a sortable html page. | ||||||
|  |   """ | ||||||
|  |   header = '\n'.join([ | ||||||
|  |       '<script type="text/javascript">', | ||||||
|  |       'var column_idx;', | ||||||
|  |       'var sort_by_ascending = ' + str(sort_by_ascending).lower() + ';', | ||||||
|  |       '', | ||||||
|  |       'function sorting(tbody, column_idx){', | ||||||
|  |       '  this.column_idx = column_idx;', | ||||||
|  |       '  Array.from(tbody.rows)', | ||||||
|  |       '       .sort(compareCells)', | ||||||
|  |       '       .forEach(function(row) { tbody.appendChild(row); })', | ||||||
|  |       '  sort_by_ascending = !sort_by_ascending;', | ||||||
|  |       '}', | ||||||
|  |       '', | ||||||
|  |       'function compareCells(row_a, row_b) {', | ||||||
|  |       '  var val_a = row_a.cells[column_idx].innerText;', | ||||||
|  |       '  var val_b = row_b.cells[column_idx].innerText;', | ||||||
|  |       '  var flag = sort_by_ascending ? 1 : -1;', | ||||||
|  |       '  return flag * (val_a > val_b ? 1 : -1);', | ||||||
|  |       '}', | ||||||
|  |       '</script>', | ||||||
|  |       '', | ||||||
|  |       '<html>', | ||||||
|  |       '', | ||||||
|  |       '<head>', | ||||||
|  |       '<style>', | ||||||
|  |       '  table {', | ||||||
|  |       '    border-spacing: 0;', | ||||||
|  |       '    border: 1px solid black;', | ||||||
|  |       '  }', | ||||||
|  |       '  th {', | ||||||
|  |       '    cursor: pointer;', | ||||||
|  |       '  }', | ||||||
|  |       '  th, td {', | ||||||
|  |       '    text-align: left;', | ||||||
|  |       '    vertical-align: middle;', | ||||||
|  |       '    border-collapse: collapse;', | ||||||
|  |       '    border: 0.5px solid black;', | ||||||
|  |       '    padding: 8px;', | ||||||
|  |       '  }', | ||||||
|  |       '  tr:nth-child(even) {', | ||||||
|  |       '    background-color: #d2d2d2;', | ||||||
|  |       '  }', | ||||||
|  |       '</style>', | ||||||
|  |       '</head>', | ||||||
|  |       '', | ||||||
|  |       '<body>', | ||||||
|  |       '', | ||||||
|  |       '<table>', | ||||||
|  |       '<thead>', | ||||||
|  |       '<tr>', | ||||||
|  |       '']) | ||||||
|  |   for idx, column_name in enumerate(column_name_list): | ||||||
|  |     header += f'  <th onclick="sorting(tbody, {idx})">{column_name}</th>\n' | ||||||
|  |   header += '</tr>\n' | ||||||
|  |   header += '</thead>\n' | ||||||
|  |   header += '<tbody id="tbody">\n' | ||||||
|  | 
 | ||||||
|  |   return header | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_sortable_html_footer(): | ||||||
|  |   """Gets footer for sortable html page. | ||||||
|  | 
 | ||||||
|  |   Check function `get_sortable_html_header()` for more details. | ||||||
|  |   """ | ||||||
|  |   return '</tbody>\n</table>\n\n</body>\n</html>\n' | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def encode_image_to_html_str(image, image_size=None): | ||||||
|  |   """Encodes an image to html language. | ||||||
|  | 
 | ||||||
|  |   Args: | ||||||
|  |     image: The input image to encode. Should be with `RGB` channel order. | ||||||
|  |     image_size: Int or two-element tuple. This field is used to resize the image | ||||||
|  |       before encoding. `None` disables resizing. (default: None) | ||||||
|  | 
 | ||||||
|  |   Returns: | ||||||
|  |     A string which represents the encoded image. | ||||||
|  |   """ | ||||||
|  |   if image is None: | ||||||
|  |     return '' | ||||||
|  | 
 | ||||||
|  |   assert len(image.shape) == 3 and image.shape[2] in [1, 3] | ||||||
|  | 
 | ||||||
|  |   # Change channel order to `BGR`, which is opencv-friendly. | ||||||
|  |   image = image[:, :, ::-1] | ||||||
|  | 
 | ||||||
|  |   # Resize the image if needed. | ||||||
|  |   if image_size is not None: | ||||||
|  |     if isinstance(image_size, int): | ||||||
|  |       image_size = (image_size, image_size) | ||||||
|  |     assert isinstance(image_size, (list, tuple)) and len(image_size) == 2 | ||||||
|  |     image = cv2.resize(image, image_size) | ||||||
|  | 
 | ||||||
|  |   # Encode the image to html-format string. | ||||||
|  |   # `tobytes()` replaces the long-deprecated `ndarray.tostring()`. | ||||||
|  |   encoded_image = cv2.imencode(".jpg", image)[1].tobytes() | ||||||
|  |   encoded_image_base64 = base64.b64encode(encoded_image).decode('utf-8') | ||||||
|  |   html_str = f'<img src="data:image/jpeg;base64, {encoded_image_base64}"/>' | ||||||
|  | 
 | ||||||
|  |   return html_str | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class HtmlPageVisualizer(object): | ||||||
|  |   """Defines the html page visualizer. | ||||||
|  | 
 | ||||||
|  |   This class can be used to visualize image results as html page. Basically, it | ||||||
|  |   is based on an html-format sorted table with helper functions | ||||||
|  |   `get_sortable_html_header()`, `get_sortable_html_footer()`, and | ||||||
|  |   `encode_image_to_html_str()`. To simplify the usage, specifying the following | ||||||
|  |   fields is enough to create a visualization page: | ||||||
|  | 
 | ||||||
|  |   (1) num_rows: Number of rows of the table (header-row exclusive). | ||||||
|  |   (2) num_cols: Number of columns of the table. | ||||||
|  |   (3) header contents (optional): Title of each column. | ||||||
|  | 
 | ||||||
|  |   NOTE: `grid_size` can be used to assign `num_rows` and `num_cols` | ||||||
|  |   automatically. | ||||||
|  | 
 | ||||||
|  |   Example: | ||||||
|  | 
 | ||||||
|  |   html = HtmlPageVisualizer(num_rows, num_cols) | ||||||
|  |   html.set_headers([...]) | ||||||
|  |   for i in range(num_rows): | ||||||
|  |     for j in range(num_cols): | ||||||
|  |       html.set_cell(i, j, text=..., image=...) | ||||||
|  |   html.save('visualize.html') | ||||||
|  |   """ | ||||||
|  | 
 | ||||||
|  |   def __init__(self, | ||||||
|  |                num_rows=0, | ||||||
|  |                num_cols=0, | ||||||
|  |                grid_size=0, | ||||||
|  |                is_portrait=False, | ||||||
|  |                viz_size=None): | ||||||
|  |     if grid_size > 0: | ||||||
|  |       num_rows, num_cols = get_grid_shape( | ||||||
|  |           grid_size, row=num_rows, col=num_cols, is_portrait=is_portrait) | ||||||
|  |     assert num_rows > 0 and num_cols > 0 | ||||||
|  | 
 | ||||||
|  |     self.num_rows = num_rows | ||||||
|  |     self.num_cols = num_cols | ||||||
|  |     self.viz_size = viz_size | ||||||
|  |     self.headers = ['' for _ in range(self.num_cols)] | ||||||
|  |     self.cells = [[{ | ||||||
|  |         'text': '', | ||||||
|  |         'image': '', | ||||||
|  |     } for _ in range(self.num_cols)] for _ in range(self.num_rows)] | ||||||
|  | 
 | ||||||
|  |   def set_header(self, column_idx, content): | ||||||
|  |     """Sets the content of a particular header by column index.""" | ||||||
|  |     self.headers[column_idx] = content | ||||||
|  | 
 | ||||||
|  |   def set_headers(self, contents): | ||||||
|  |     """Sets the contents of all headers.""" | ||||||
|  |     if isinstance(contents, str): | ||||||
|  |       contents = [contents] | ||||||
|  |     assert isinstance(contents, (list, tuple)) | ||||||
|  |     assert len(contents) == self.num_cols | ||||||
|  |     for column_idx, content in enumerate(contents): | ||||||
|  |       self.set_header(column_idx, content) | ||||||
|  | 
 | ||||||
|  |   def set_cell(self, row_idx, column_idx, text='', image=None): | ||||||
|  |     """Sets the content of a particular cell. | ||||||
|  | 
 | ||||||
|  |     Basically, a cell contains some text as well as an image. Both text and | ||||||
|  |     image can be empty. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |       row_idx: Row index of the cell to edit. | ||||||
|  |       column_idx: Column index of the cell to edit. | ||||||
|  |       text: Text to add into the target cell. | ||||||
|  |       image: Image to show in the target cell. Should be with `RGB` channel | ||||||
|  |         order. | ||||||
|  |     """ | ||||||
|  |     self.cells[row_idx][column_idx]['text'] = text | ||||||
|  |     self.cells[row_idx][column_idx]['image'] = encode_image_to_html_str( | ||||||
|  |         image, self.viz_size) | ||||||
|  | 
 | ||||||
|  |   def save(self, save_path): | ||||||
|  |     """Saves the html page.""" | ||||||
|  |     html = '' | ||||||
|  |     for i in range(self.num_rows): | ||||||
|  |       html += f'<tr>\n' | ||||||
|  |       for j in range(self.num_cols): | ||||||
|  |         text = self.cells[i][j]['text'] | ||||||
|  |         image = self.cells[i][j]['image'] | ||||||
|  |         if text: | ||||||
|  |           html += f'  <td>{text}<br><br>{image}</td>\n' | ||||||
|  |         else: | ||||||
|  |           html += f'  <td>{image}</td>\n' | ||||||
|  |       html += f'</tr>\n' | ||||||
|  | 
 | ||||||
|  |     header = get_sortable_html_header(self.headers) | ||||||
|  |     footer = get_sortable_html_footer() | ||||||
|  | 
 | ||||||
|  |     with open(save_path, 'w') as f: | ||||||
|  |       f.write(header + html + footer) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class VideoReader(object): | ||||||
|  |   """Defines the video reader. | ||||||
|  | 
 | ||||||
|  |   This class can be used to read frames from a given video. | ||||||
|  |   """ | ||||||
|  | 
 | ||||||
|  |   def __init__(self, path): | ||||||
|  |     """Initializes the video reader by loading the video from disk.""" | ||||||
|  |     if not os.path.isfile(path): | ||||||
|  |       raise ValueError(f'Video `{path}` does not exist!') | ||||||
|  | 
 | ||||||
|  |     self.path = path | ||||||
|  |     self.video = cv2.VideoCapture(path) | ||||||
|  |     assert self.video.isOpened() | ||||||
|  |     self.position = 0 | ||||||
|  | 
 | ||||||
|  |     self.length = int(self.video.get(cv2.CAP_PROP_FRAME_COUNT)) | ||||||
|  |     self.frame_height = int(self.video.get(cv2.CAP_PROP_FRAME_HEIGHT)) | ||||||
|  |     self.frame_width = int(self.video.get(cv2.CAP_PROP_FRAME_WIDTH)) | ||||||
|  |     self.fps = self.video.get(cv2.CAP_PROP_FPS) | ||||||
|  | 
 | ||||||
|  |   def __del__(self): | ||||||
|  |     """Releases the opened video.""" | ||||||
|  |     self.video.release() | ||||||
|  | 
 | ||||||
|  |   def read(self, position=None): | ||||||
|  |     """Reads a certain frame. | ||||||
|  | 
 | ||||||
|  |     NOTE: The returned frame is assumed to be with `RGB` channel order. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |       position: Optional. If set, the reader will read frames from the exact | ||||||
|  |         position. Otherwise, the reader will read next frames. (default: None) | ||||||
|  |     """ | ||||||
|  |     if position is not None and position < self.length: | ||||||
|  |       self.video.set(cv2.CAP_PROP_POS_FRAMES, position) | ||||||
|  |       self.position = position | ||||||
|  | 
 | ||||||
|  |     success, frame = self.video.read() | ||||||
|  |     self.position = self.position + 1 | ||||||
|  | 
 | ||||||
|  |     return frame[:, :, ::-1] if success else None | ||||||
|  | 
 | ||||||
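|  | # Example sketch: iterate over every frame of a video on disk ('clip.avi' is a | ||||||
|  | # hypothetical path). | ||||||
|  | #   reader = VideoReader('clip.avi') | ||||||
|  | #   for _ in range(reader.length): | ||||||
|  | #     frame = reader.read()  # RGB frame, or None if reading fails | ||||||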
|  | 
 | ||||||
|  | class VideoWriter(object): | ||||||
|  |   """Defines the video writer. | ||||||
|  | 
 | ||||||
|  |   This class can be used to create a video. | ||||||
|  | 
 | ||||||
|  |   NOTE: `.avi` files with the `DIVX` codec are the recommended format, since | ||||||
|  |   they do not rely on extra dependencies. | ||||||
|  |   """ | ||||||
|  | 
 | ||||||
|  |   def __init__(self, path, frame_height, frame_width, fps=24, codec='DIVX'): | ||||||
|  |     """Creates the video writer.""" | ||||||
|  |     self.path = path | ||||||
|  |     self.frame_height = frame_height | ||||||
|  |     self.frame_width = frame_width | ||||||
|  |     self.fps = fps | ||||||
|  |     self.codec = codec | ||||||
|  | 
 | ||||||
|  |     self.video = cv2.VideoWriter(filename=path, | ||||||
|  |                                  fourcc=cv2.VideoWriter_fourcc(*codec), | ||||||
|  |                                  fps=fps, | ||||||
|  |                                  frameSize=(frame_width, frame_height)) | ||||||
|  | 
 | ||||||
|  |   def __del__(self): | ||||||
|  |     """Releases the opened video.""" | ||||||
|  |     self.video.release() | ||||||
|  | 
 | ||||||
|  |   def write(self, frame): | ||||||
|  |     """Writes a target frame. | ||||||
|  | 
 | ||||||
|  |     NOTE: The input frame is assumed to be with `RGB` channel order. | ||||||
|  |     """ | ||||||
|  |     self.video.write(frame[:, :, ::-1]) | ||||||
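|  | 
|  | # Example sketch pairing the reader and writer above to re-encode a video | ||||||
|  | # ('in.avi' / 'out.avi' are hypothetical paths). | ||||||
|  | #   reader = VideoReader('in.avi') | ||||||
|  | #   writer = VideoWriter('out.avi', reader.frame_height, reader.frame_width, | ||||||
|  | #                        fps=reader.fps) | ||||||
|  | #   for _ in range(reader.length): | ||||||
|  | #     frame = reader.read() | ||||||
|  | #     if frame is None: | ||||||
|  | #       break | ||||||
|  | #     writer.write(frame) | ||||||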
BIN latents_test/example_celebs.pt Normal file (binary file not shown)
21 licenses/LICENSE-CLIP Normal file
							| @ -0,0 +1,21 @@ | |||||||
|  | MIT License | ||||||
|  | 
 | ||||||
|  | Copyright (c) 2021 OpenAI | ||||||
|  | 
 | ||||||
|  | Permission is hereby granted, free of charge, to any person obtaining a copy | ||||||
|  | of this software and associated documentation files (the "Software"), to deal | ||||||
|  | in the Software without restriction, including without limitation the rights | ||||||
|  | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||||
|  | copies of the Software, and to permit persons to whom the Software is | ||||||
|  | furnished to do so, subject to the following conditions: | ||||||
|  | 
 | ||||||
|  | The above copyright notice and this permission notice shall be included in all | ||||||
|  | copies or substantial portions of the Software. | ||||||
|  | 
 | ||||||
|  | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||||
|  | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||||
|  | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||||
|  | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||||
|  | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||||||
|  | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||||||
|  | SOFTWARE. | ||||||
21 licenses/LICENSE-stylegan2-pytorch Normal file
							| @ -0,0 +1,21 @@ | |||||||
|  | MIT License | ||||||
|  | 
 | ||||||
|  | Copyright (c) 2019 Kim Seonghyeon | ||||||
|  | 
 | ||||||
|  | Permission is hereby granted, free of charge, to any person obtaining a copy | ||||||
|  | of this software and associated documentation files (the "Software"), to deal | ||||||
|  | in the Software without restriction, including without limitation the rights | ||||||
|  | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||||
|  | copies of the Software, and to permit persons to whom the Software is | ||||||
|  | furnished to do so, subject to the following conditions: | ||||||
|  | 
 | ||||||
|  | The above copyright notice and this permission notice shall be included in all | ||||||
|  | copies or substantial portions of the Software. | ||||||
|  | 
 | ||||||
|  | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||||
|  | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||||
|  | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||||
|  | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||||
|  | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||||||
|  | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||||||
|  | SOFTWARE. | ||||||
0 models/__init__.py Normal file
0 models/facial_recognition/__init__.py Normal file
119 models/facial_recognition/helpers.py Normal file
							| @ -0,0 +1,119 @@ | |||||||
|  | from collections import namedtuple | ||||||
|  | import torch | ||||||
|  | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module | ||||||
|  | 
 | ||||||
|  | """ | ||||||
|  | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) | ||||||
|  | """ | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Flatten(Module): | ||||||
|  | 	def forward(self, input): | ||||||
|  | 		return input.view(input.size(0), -1) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def l2_norm(input, axis=1): | ||||||
|  | 	norm = torch.norm(input, 2, axis, True) | ||||||
|  | 	output = torch.div(input, norm) | ||||||
|  | 	return output | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): | ||||||
|  | 	""" A named tuple describing a ResNet block. """ | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_block(in_channel, depth, num_units, stride=2): | ||||||
|  | 	return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for _ in range(num_units - 1)] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_blocks(num_layers): | ||||||
|  | 	if num_layers == 50: | ||||||
|  | 		blocks = [ | ||||||
|  | 			get_block(in_channel=64, depth=64, num_units=3), | ||||||
|  | 			get_block(in_channel=64, depth=128, num_units=4), | ||||||
|  | 			get_block(in_channel=128, depth=256, num_units=14), | ||||||
|  | 			get_block(in_channel=256, depth=512, num_units=3) | ||||||
|  | 		] | ||||||
|  | 	elif num_layers == 100: | ||||||
|  | 		blocks = [ | ||||||
|  | 			get_block(in_channel=64, depth=64, num_units=3), | ||||||
|  | 			get_block(in_channel=64, depth=128, num_units=13), | ||||||
|  | 			get_block(in_channel=128, depth=256, num_units=30), | ||||||
|  | 			get_block(in_channel=256, depth=512, num_units=3) | ||||||
|  | 		] | ||||||
|  | 	elif num_layers == 152: | ||||||
|  | 		blocks = [ | ||||||
|  | 			get_block(in_channel=64, depth=64, num_units=3), | ||||||
|  | 			get_block(in_channel=64, depth=128, num_units=8), | ||||||
|  | 			get_block(in_channel=128, depth=256, num_units=36), | ||||||
|  | 			get_block(in_channel=256, depth=512, num_units=3) | ||||||
|  | 		] | ||||||
|  | 	else: | ||||||
|  | 		raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) | ||||||
|  | 	return blocks | ||||||
|  | 
 | ||||||
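|  | # Example sketch: the 50-layer configuration expands to 3 + 4 + 14 + 3 = 24 | ||||||
|  | # bottleneck units, a ResNet-50-style backbone. | ||||||
|  | #   blocks = get_blocks(50) | ||||||
|  | #   assert sum(len(b) for b in blocks) == 24 | ||||||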
|  | 
 | ||||||
|  | class SEModule(Module): | ||||||
|  | 	def __init__(self, channels, reduction): | ||||||
|  | 		super(SEModule, self).__init__() | ||||||
|  | 		self.avg_pool = AdaptiveAvgPool2d(1) | ||||||
|  | 		self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) | ||||||
|  | 		self.relu = ReLU(inplace=True) | ||||||
|  | 		self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) | ||||||
|  | 		self.sigmoid = Sigmoid() | ||||||
|  | 
 | ||||||
|  | 	def forward(self, x): | ||||||
|  | 		module_input = x | ||||||
|  | 		x = self.avg_pool(x) | ||||||
|  | 		x = self.fc1(x) | ||||||
|  | 		x = self.relu(x) | ||||||
|  | 		x = self.fc2(x) | ||||||
|  | 		x = self.sigmoid(x) | ||||||
|  | 		return module_input * x | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class bottleneck_IR(Module): | ||||||
|  | 	def __init__(self, in_channel, depth, stride): | ||||||
|  | 		super(bottleneck_IR, self).__init__() | ||||||
|  | 		if in_channel == depth: | ||||||
|  | 			self.shortcut_layer = MaxPool2d(1, stride) | ||||||
|  | 		else: | ||||||
|  | 			self.shortcut_layer = Sequential( | ||||||
|  | 				Conv2d(in_channel, depth, (1, 1), stride, bias=False), | ||||||
|  | 				BatchNorm2d(depth) | ||||||
|  | 			) | ||||||
|  | 		self.res_layer = Sequential( | ||||||
|  | 			BatchNorm2d(in_channel), | ||||||
|  | 			Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), | ||||||
|  | 			Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) | ||||||
|  | 		) | ||||||
|  | 
 | ||||||
|  | 	def forward(self, x): | ||||||
|  | 		shortcut = self.shortcut_layer(x) | ||||||
|  | 		res = self.res_layer(x) | ||||||
|  | 		return res + shortcut | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class bottleneck_IR_SE(Module): | ||||||
|  | 	def __init__(self, in_channel, depth, stride): | ||||||
|  | 		super(bottleneck_IR_SE, self).__init__() | ||||||
|  | 		if in_channel == depth: | ||||||
|  | 			self.shortcut_layer = MaxPool2d(1, stride) | ||||||
|  | 		else: | ||||||
|  | 			self.shortcut_layer = Sequential( | ||||||
|  | 				Conv2d(in_channel, depth, (1, 1), stride, bias=False), | ||||||
|  | 				BatchNorm2d(depth) | ||||||
|  | 			) | ||||||
|  | 		self.res_layer = Sequential( | ||||||
|  | 			BatchNorm2d(in_channel), | ||||||
|  | 			Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), | ||||||
|  | 			PReLU(depth), | ||||||
|  | 			Conv2d(depth, depth, (3, 3), stride, 1, bias=False), | ||||||
|  | 			BatchNorm2d(depth), | ||||||
|  | 			SEModule(depth, 16) | ||||||
|  | 		) | ||||||
|  | 
 | ||||||
|  | 	def forward(self, x): | ||||||
|  | 		shortcut = self.shortcut_layer(x) | ||||||
|  | 		res = self.res_layer(x) | ||||||
|  | 		return res + shortcut | ||||||
86 models/facial_recognition/model_irse.py Normal file
							| @ -0,0 +1,86 @@ | |||||||
|  | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module | ||||||
|  | # Import via the package path instead of appending a machine-specific | ||||||
|  | # absolute directory to sys.path. | ||||||
|  | from models.facial_recognition.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm | ||||||
|  | 
 | ||||||
|  | """ | ||||||
|  | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) | ||||||
|  | """ | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Backbone(Module): | ||||||
|  | 	def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): | ||||||
|  | 		super(Backbone, self).__init__() | ||||||
|  | 		assert input_size in [112, 224], "input_size should be 112 or 224" | ||||||
|  | 		assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" | ||||||
|  | 		assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" | ||||||
|  | 		blocks = get_blocks(num_layers) | ||||||
|  | 		if mode == 'ir': | ||||||
|  | 			unit_module = bottleneck_IR | ||||||
|  | 		elif mode == 'ir_se': | ||||||
|  | 			unit_module = bottleneck_IR_SE | ||||||
|  | 		self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), | ||||||
|  | 									  BatchNorm2d(64), | ||||||
|  | 									  PReLU(64)) | ||||||
|  | 		if input_size == 112: | ||||||
|  | 			self.output_layer = Sequential(BatchNorm2d(512), | ||||||
|  | 			                               Dropout(drop_ratio), | ||||||
|  | 			                               Flatten(), | ||||||
|  | 			                               Linear(512 * 7 * 7, 512), | ||||||
|  | 			                               BatchNorm1d(512, affine=affine)) | ||||||
|  | 		else: | ||||||
|  | 			self.output_layer = Sequential(BatchNorm2d(512), | ||||||
|  | 			                               Dropout(drop_ratio), | ||||||
|  | 			                               Flatten(), | ||||||
|  | 			                               Linear(512 * 14 * 14, 512), | ||||||
|  | 			                               BatchNorm1d(512, affine=affine)) | ||||||
|  | 
 | ||||||
|  | 		modules = [] | ||||||
|  | 		for block in blocks: | ||||||
|  | 			for bottleneck in block: | ||||||
|  | 				modules.append(unit_module(bottleneck.in_channel, | ||||||
|  | 										   bottleneck.depth, | ||||||
|  | 										   bottleneck.stride)) | ||||||
|  | 		self.body = Sequential(*modules) | ||||||
|  | 
 | ||||||
|  | 	def forward(self, x): | ||||||
|  | 		x = self.input_layer(x) | ||||||
|  | 		x = self.body(x) | ||||||
|  | 		x = self.output_layer(x) | ||||||
|  | 		return l2_norm(x) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def IR_50(input_size): | ||||||
|  | 	"""Constructs a ir-50 model.""" | ||||||
|  | 	model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) | ||||||
|  | 	return model | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def IR_101(input_size): | ||||||
|  | 	"""Constructs a ir-101 model.""" | ||||||
|  | 	model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) | ||||||
|  | 	return model | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def IR_152(input_size): | ||||||
|  | 	"""Constructs a ir-152 model.""" | ||||||
|  | 	model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) | ||||||
|  | 	return model | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def IR_SE_50(input_size): | ||||||
|  | 	"""Constructs a ir_se-50 model.""" | ||||||
|  | 	model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) | ||||||
|  | 	return model | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def IR_SE_101(input_size): | ||||||
|  | 	"""Constructs a ir_se-101 model.""" | ||||||
|  | 	model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) | ||||||
|  | 	return model | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def IR_SE_152(input_size): | ||||||
|  | 	"""Constructs a ir_se-152 model.""" | ||||||
|  | 	model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) | ||||||
|  | 	return model | ||||||
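|  | 
|  | # Usage sketch (an illustration, not part of the original file): the 112-input | ||||||
|  | # `ir_se`-50 variant is the backbone commonly loaded for ArcFace-style | ||||||
|  | # identity losses. | ||||||
|  | #   facenet = IR_SE_50(input_size=112).eval() | ||||||
|  | #   embedding = facenet(torch.randn(1, 3, 112, 112))  # L2-normalized, 512-d | ||||||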
0 models/stylegan2/__init__.py Normal file
715 models/stylegan2/model.py Normal file
							| @ -0,0 +1,715 @@ | |||||||
|  | import math | ||||||
|  | import random | ||||||
|  | 
 | ||||||
|  | import torch | ||||||
|  | from torch import nn | ||||||
|  | from torch.nn import functional as F | ||||||
|  | 
 | ||||||
|  | from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class PixelNorm(nn.Module): | ||||||
|  |     def __init__(self): | ||||||
|  |         super().__init__() | ||||||
|  |     # Normalizes each feature vector to roughly unit length, preventing signal | ||||||
|  |     # magnitudes from gradually spiraling out of control during training. | ||||||
|  |     def forward(self, input): | ||||||
|  |         return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def make_kernel(k): | ||||||
|  |     k = torch.tensor(k, dtype=torch.float32) | ||||||
|  | 
 | ||||||
|  |     if k.ndim == 1: | ||||||
|  |         k = k[None, :] * k[:, None] | ||||||
|  | 
 | ||||||
|  |     k /= k.sum() | ||||||
|  | 
 | ||||||
|  |     return k | ||||||
|  | 
 | ||||||
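|  | # Example sketch: a 1-D kernel is expanded into a 2-D separable filter via an | ||||||
|  | # outer product and normalized to sum to 1. | ||||||
|  | #   k = make_kernel([1, 3, 3, 1])  # binomial blur kernel, shape (4, 4), sum 1.0 | ||||||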
|  | 
 | ||||||
|  | class Upsample(nn.Module): | ||||||
|  |     def __init__(self, kernel, factor=2): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.factor = factor | ||||||
|  |         kernel = make_kernel(kernel) * (factor ** 2) | ||||||
|  |         self.register_buffer('kernel', kernel) | ||||||
|  | 
 | ||||||
|  |         p = kernel.shape[0] - factor | ||||||
|  | 
 | ||||||
|  |         pad0 = (p + 1) // 2 + factor - 1 | ||||||
|  |         pad1 = p // 2 | ||||||
|  | 
 | ||||||
|  |         self.pad = (pad0, pad1) | ||||||
|  | 
 | ||||||
|  |     def forward(self, input): | ||||||
|  |         out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) | ||||||
|  | 
 | ||||||
|  |         return out | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Downsample(nn.Module): | ||||||
|  |     def __init__(self, kernel, factor=2): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.factor = factor | ||||||
|  |         kernel = make_kernel(kernel) | ||||||
|  |         self.register_buffer('kernel', kernel) | ||||||
|  | 
 | ||||||
|  |         p = kernel.shape[0] - factor | ||||||
|  | 
 | ||||||
|  |         pad0 = (p + 1) // 2 | ||||||
|  |         pad1 = p // 2 | ||||||
|  | 
 | ||||||
|  |         self.pad = (pad0, pad1) | ||||||
|  | 
 | ||||||
|  |     def forward(self, input): | ||||||
|  |         out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) | ||||||
|  | 
 | ||||||
|  |         return out | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Blur(nn.Module): | ||||||
|  |     def __init__(self, kernel, pad, upsample_factor=1): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         kernel = make_kernel(kernel) | ||||||
|  | 
 | ||||||
|  |         if upsample_factor > 1: | ||||||
|  |             kernel = kernel * (upsample_factor ** 2) | ||||||
|  | 
 | ||||||
|  |         self.register_buffer('kernel', kernel) | ||||||
|  | 
 | ||||||
|  |         self.pad = pad | ||||||
|  | 
 | ||||||
|  |     def forward(self, input): | ||||||
|  |         out = upfirdn2d(input, self.kernel, pad=self.pad) | ||||||
|  | 
 | ||||||
|  |         return out | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class EqualConv2d(nn.Module): | ||||||
|  |     def __init__( | ||||||
|  |         self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.weight = nn.Parameter( | ||||||
|  |             torch.randn(out_channel, in_channel, kernel_size, kernel_size) | ||||||
|  |         ) | ||||||
|  |         self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) | ||||||
|  | 
 | ||||||
|  |         self.stride = stride | ||||||
|  |         self.padding = padding | ||||||
|  | 
 | ||||||
|  |         if bias: | ||||||
|  |             self.bias = nn.Parameter(torch.zeros(out_channel)) | ||||||
|  | 
 | ||||||
|  |         else: | ||||||
|  |             self.bias = None | ||||||
|  | 
 | ||||||
|  |     def forward(self, input): | ||||||
|  |         out = F.conv2d( | ||||||
|  |             input, | ||||||
|  |             self.weight * self.scale, | ||||||
|  |             bias=self.bias, | ||||||
|  |             stride=self.stride, | ||||||
|  |             padding=self.padding, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         return out | ||||||
|  | 
 | ||||||
|  |     def __repr__(self): | ||||||
|  |         return ( | ||||||
|  |             f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' | ||||||
|  |             f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  | # A linear layer with equalized learning rate and an optional fused activation. | ||||||
|  | class EqualLinear(nn.Module): | ||||||
|  |     def __init__( | ||||||
|  |         self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) | ||||||
|  | 
 | ||||||
|  |         if bias: | ||||||
|  |             self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) | ||||||
|  | 
 | ||||||
|  |         else: | ||||||
|  |             self.bias = None | ||||||
|  | 
 | ||||||
|  |         self.activation = activation | ||||||
|  | 
 | ||||||
|  |         self.scale = (1 / math.sqrt(in_dim)) * lr_mul | ||||||
|  |         self.lr_mul = lr_mul | ||||||
|  | 
 | ||||||
|  |     def forward(self, input): | ||||||
|  | 
 | ||||||
|  |         if self.activation: | ||||||
|  |             out = F.linear(input, self.weight * self.scale) | ||||||
|  |             out = fused_leaky_relu(out, self.bias * self.lr_mul) | ||||||
|  | 
 | ||||||
|  |         else: | ||||||
|  |             out = F.linear( | ||||||
|  |                 input, self.weight * self.scale, bias=self.bias * self.lr_mul | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         return out | ||||||
|  | 
 | ||||||
|  |     def __repr__(self): | ||||||
|  |         return ( | ||||||
|  |             f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ScaledLeakyReLU(nn.Module): | ||||||
|  |     def __init__(self, negative_slope=0.2): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.negative_slope = negative_slope | ||||||
|  | 
 | ||||||
|  |     def forward(self, input): | ||||||
|  |         out = F.leaky_relu(input, negative_slope=self.negative_slope) | ||||||
|  | 
 | ||||||
|  |         return out * math.sqrt(2) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ModulatedConv2d(nn.Module): | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         in_channel, | ||||||
|  |         out_channel, | ||||||
|  |         kernel_size, | ||||||
|  |         style_dim, | ||||||
|  |         demodulate=True, | ||||||
|  |         upsample=False, | ||||||
|  |         # (style modulation rescales the convolution kernel per input channel) | ||||||
|  |         downsample=False, | ||||||
|  |         blur_kernel=[1, 3, 3, 1], | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.eps = 1e-8 | ||||||
|  |         self.kernel_size = kernel_size | ||||||
|  |         self.in_channel = in_channel | ||||||
|  |         self.out_channel = out_channel | ||||||
|  |         self.upsample = upsample | ||||||
|  |         self.downsample = downsample | ||||||
|  | 
 | ||||||
|  |         if upsample: | ||||||
|  |             factor = 2 | ||||||
|  |             p = (len(blur_kernel) - factor) - (kernel_size - 1) | ||||||
|  |             pad0 = (p + 1) // 2 + factor - 1 | ||||||
|  |             pad1 = p // 2 + 1 | ||||||
|  | 
 | ||||||
|  |             self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) | ||||||
|  | 
 | ||||||
|  |         if downsample: | ||||||
|  |             factor = 2 | ||||||
|  |             p = (len(blur_kernel) - factor) + (kernel_size - 1) | ||||||
|  |             pad0 = (p + 1) // 2 | ||||||
|  |             pad1 = p // 2 | ||||||
|  | 
 | ||||||
|  |             self.blur = Blur(blur_kernel, pad=(pad0, pad1)) | ||||||
|  | 
 | ||||||
|  |         fan_in = in_channel * kernel_size ** 2 | ||||||
|  |         self.scale = 1 / math.sqrt(fan_in) | ||||||
|  |         self.padding = kernel_size // 2 | ||||||
|  | 
 | ||||||
|  |         self.weight = nn.Parameter( | ||||||
|  |             torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) | ||||||
|  | 
 | ||||||
|  |         self.demodulate = demodulate | ||||||
|  | 
 | ||||||
|  |     def __repr__(self): | ||||||
|  |         # Returns a string describing the layer's configuration. | ||||||
|  |         return ( | ||||||
|  |             f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' | ||||||
|  |             f'upsample={self.upsample}, downsample={self.downsample})' | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |     def forward(self, input, style, input_is_stylespace=False): | ||||||
|  |         batch, in_channel, height, width = input.shape | ||||||
|  | 
 | ||||||
|  |         if not input_is_stylespace: | ||||||
|  |             style = self.modulation(style).view(batch, 1, in_channel, 1, 1) | ||||||
|  |         weight = self.scale * self.weight * style | ||||||
|  | 
 | ||||||
|  |         # demodulate the weights | ||||||
|  |         if self.demodulate: | ||||||
|  |             # akin to a standard-deviation computation (square, sum, inverse square root): one demodulation factor per output filter | ||||||
|  |             demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) | ||||||
|  |             # reshape demod so it broadcasts element-wise against the weight tensor | ||||||
|  |             weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) | ||||||
|  | 
 | ||||||
|  |         weight = weight.view( | ||||||
|  |             batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         if self.upsample: | ||||||
|  |             input = input.view(1, batch * in_channel, height, width) | ||||||
|  |             weight = weight.view( | ||||||
|  |                 batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size | ||||||
|  |             ) | ||||||
|  |             weight = weight.transpose(1, 2).reshape( | ||||||
|  |                 batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size | ||||||
|  |             ) | ||||||
|  |             out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) | ||||||
|  |             _, _, height, width = out.shape | ||||||
|  |             out = out.view(batch, self.out_channel, height, width) | ||||||
|  |             out = self.blur(out) | ||||||
|  | 
 | ||||||
|  |         elif self.downsample: | ||||||
|  |             input = self.blur(input) | ||||||
|  |             _, _, height, width = input.shape | ||||||
|  |             input = input.view(1, batch * in_channel, height, width) | ||||||
|  |             out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) | ||||||
|  |             _, _, height, width = out.shape | ||||||
|  |             out = out.view(batch, self.out_channel, height, width) | ||||||
|  | 
 | ||||||
|  |         else: | ||||||
|  |             input = input.view(1, batch * in_channel, height, width) | ||||||
|  |             out = F.conv2d(input, weight, padding=self.padding, groups=batch) | ||||||
|  |             _, _, height, width = out.shape | ||||||
|  |             out = out.view(batch, self.out_channel, height, width) | ||||||
|  | 
 | ||||||
|  |         return out, style | ||||||
|  | 
 | ||||||
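|  | # --- Editor's illustrative sketch, not part of the original commit --- | ||||||
|  | # Demodulation in isolation: after scaling by demod, each per-sample output | ||||||
|  | # filter has roughly unit L2 norm, keeping activation magnitudes stable | ||||||
|  | # regardless of the style vector. | ||||||
|  | def _demo_demodulation(batch=2, out_channel=8, in_channel=4, k=3): | ||||||
|  |     w = torch.randn(batch, out_channel, in_channel, k, k)  # modulated weights | ||||||
|  |     demod = torch.rsqrt(w.pow(2).sum([2, 3, 4]) + 1e-8)    # [batch, out_channel] | ||||||
|  |     w = w * demod.view(batch, out_channel, 1, 1, 1) | ||||||
|  |     return w.pow(2).sum([2, 3, 4]).sqrt()                  # ~1.0 everywhere | ||||||
|  | 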
|  | # Noise injection: per-pixel noise drives fine details such as hair strands, wrinkles, and skin texture. | ||||||
|  | class NoiseInjection(nn.Module): | ||||||
|  |     def __init__(self): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.weight = nn.Parameter(torch.zeros(1)) | ||||||
|  | 
 | ||||||
|  |     def forward(self, image, noise=None): | ||||||
|  |         if noise is None: | ||||||
|  |             batch, _, height, width = image.shape | ||||||
|  |             noise = image.new_empty(batch, 1, height, width).normal_() | ||||||
|  | 
 | ||||||
|  |         return image + self.weight * noise | ||||||
|  | 
 | ||||||
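|  | # --- Editor's illustrative sketch, not part of the original commit --- | ||||||
|  | # One noise map per sample is broadcast across all channels; self.weight | ||||||
|  | # starts at zero, so the noise contribution is learned during training. | ||||||
|  | def _demo_noise_injection(): | ||||||
|  |     inject = NoiseInjection() | ||||||
|  |     image = torch.randn(2, 64, 8, 8) | ||||||
|  |     noise = torch.randn(2, 1, 8, 8)  # one map per sample, shared by channels | ||||||
|  |     return inject(image, noise).shape  # torch.Size([2, 64, 8, 8]) | ||||||
|  | 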
|  | 
 | ||||||
|  | class ConstantInput(nn.Module): | ||||||
|  |     def __init__(self, channel, size=4): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.input = nn.Parameter(torch.randn(1, channel, size, size)) | ||||||
|  | 
 | ||||||
|  |     def forward(self, input): | ||||||
|  |         batch = input.shape[0] | ||||||
|  |         out = self.input.repeat(batch, 1, 1, 1) | ||||||
|  | 
 | ||||||
|  |         return out | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class StyledConv(nn.Module): | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         in_channel, | ||||||
|  |         out_channel, | ||||||
|  |         kernel_size, | ||||||
|  |         style_dim, | ||||||
|  |         upsample=False, | ||||||
|  |         blur_kernel=[1, 3, 3, 1], | ||||||
|  |         demodulate=True, | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.conv = ModulatedConv2d( | ||||||
|  |             in_channel, | ||||||
|  |             out_channel, | ||||||
|  |             kernel_size, | ||||||
|  |             style_dim, | ||||||
|  |             upsample=upsample, | ||||||
|  |             blur_kernel=blur_kernel, | ||||||
|  |             demodulate=demodulate, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         self.noise = NoiseInjection() | ||||||
|  |         # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) | ||||||
|  |         # self.activate = ScaledLeakyReLU(0.2) | ||||||
|  |         self.activate = FusedLeakyReLU(out_channel) | ||||||
|  | 
 | ||||||
|  |     def forward(self, input, style, noise=None, input_is_stylespace=False): | ||||||
|  |         out, style = self.conv(input, style, input_is_stylespace=input_is_stylespace) | ||||||
|  |         out = self.noise(out, noise=noise) | ||||||
|  |         # out = out + self.bias | ||||||
|  |         out = self.activate(out) | ||||||
|  | 
 | ||||||
|  |         return out, style | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ToRGB(nn.Module): | ||||||
|  |     def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         if upsample: | ||||||
|  |             self.upsample = Upsample(blur_kernel) | ||||||
|  | 
 | ||||||
|  |         # the ToRGB layer applies no demodulation | ||||||
|  |         self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) | ||||||
|  |         self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) | ||||||
|  | 
 | ||||||
|  |     def forward(self, input, style, skip=None, input_is_stylespace=False): | ||||||
|  |         out, style = self.conv(input, style, input_is_stylespace=input_is_stylespace) | ||||||
|  |         out = out + self.bias | ||||||
|  | 
 | ||||||
|  |         if skip is not None: | ||||||
|  |             skip = self.upsample(skip) | ||||||
|  | 
 | ||||||
|  |             out = out + skip | ||||||
|  | 
 | ||||||
|  |         return out, style | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Generator(nn.Module): | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         size, | ||||||
|  |         style_dim, | ||||||
|  |         n_mlp, | ||||||
|  |         channel_multiplier=2, | ||||||
|  |         blur_kernel=[1, 3, 3, 1], | ||||||
|  |         lr_mlp=0.01, | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.size = size | ||||||
|  | 
 | ||||||
|  |         self.style_dim = style_dim | ||||||
|  | 
 | ||||||
|  |         layers = [PixelNorm()] | ||||||
|  | 
 | ||||||
|  |         for i in range(n_mlp): | ||||||
|  |             layers.append( | ||||||
|  |                 EqualLinear( | ||||||
|  |                     style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         self.style = nn.Sequential(*layers) | ||||||
|  | 
 | ||||||
|  |         self.channels = { | ||||||
|  |             4: 512, | ||||||
|  |             8: 512, | ||||||
|  |             16: 512, | ||||||
|  |             32: 512, | ||||||
|  |             64: 256 * channel_multiplier, | ||||||
|  |             128: 128 * channel_multiplier, | ||||||
|  |             256: 64 * channel_multiplier, | ||||||
|  |             512: 32 * channel_multiplier, | ||||||
|  |             1024: 16 * channel_multiplier, | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         self.input = ConstantInput(self.channels[4]) | ||||||
|  |         self.conv1 = StyledConv( | ||||||
|  |             self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel | ||||||
|  |         ) | ||||||
|  |         self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) | ||||||
|  | 
 | ||||||
|  |         self.log_size = int(math.log(size, 2))  # e.g. log2(1024) = 10 | ||||||
|  |         self.num_layers = (self.log_size - 2) * 2 + 1 | ||||||
|  | 
 | ||||||
|  |         self.convs = nn.ModuleList() | ||||||
|  |         self.upsamples = nn.ModuleList() | ||||||
|  |         self.to_rgbs = nn.ModuleList() | ||||||
|  |         self.noises = nn.Module() | ||||||
|  | 
 | ||||||
|  |         in_channel = self.channels[4] | ||||||
|  | 
 | ||||||
|  |         for layer_idx in range(self.num_layers): | ||||||
|  |             res = (layer_idx + 5) // 2 | ||||||
|  |             shape = [1, 1, 2 ** res, 2 ** res] | ||||||
|  |             self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape)) | ||||||
|  | 
 | ||||||
|  |         for i in range(3, self.log_size + 1): | ||||||
|  |             out_channel = self.channels[2 ** i] | ||||||
|  | 
 | ||||||
|  |             self.convs.append( | ||||||
|  |                 StyledConv( | ||||||
|  |                     in_channel, | ||||||
|  |                     out_channel, | ||||||
|  |                     3, | ||||||
|  |                     style_dim, | ||||||
|  |                     upsample=True, | ||||||
|  |                     blur_kernel=blur_kernel, | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |             self.convs.append( | ||||||
|  |                 StyledConv( | ||||||
|  |                     out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |             self.to_rgbs.append(ToRGB(out_channel, style_dim)) | ||||||
|  | 
 | ||||||
|  |             in_channel = out_channel | ||||||
|  |         # W+ repeat count: 2 * log2(size) - 2, e.g. 18 for size 1024 (1 + 8*2 + 1: the first 4x4 layer takes a single style, each of the 8 upsampling blocks injects a style twice, and the final to_rgb consumes one more) | ||||||
|  |         self.n_latent = self.log_size * 2 - 2 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
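|  |     def _demo_n_latent(self): | ||||||
|  |         # Editor's illustrative sketch, not part of the original commit: | ||||||
|  |         # n_latent = 2 * log2(size) - 2, i.e. 14 for 256, 16 for 512, 18 for 1024. | ||||||
|  |         return {size: int(math.log(size, 2)) * 2 - 2 for size in (256, 512, 1024)} | ||||||
|  | 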
|  |     def make_noise(self): | ||||||
|  |         device = self.input.input.device | ||||||
|  | 
 | ||||||
|  |         noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] | ||||||
|  | 
 | ||||||
|  |         for i in range(3, self.log_size + 1): | ||||||
|  |             for _ in range(2): | ||||||
|  |                 noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) | ||||||
|  | 
 | ||||||
|  |         return noises | ||||||
|  | 
 | ||||||
|  |     def mean_latent(self, n_latent): | ||||||
|  |         latent_in = torch.randn( | ||||||
|  |             n_latent, self.style_dim, device=self.input.input.device | ||||||
|  |         ) | ||||||
|  |         latent = self.style(latent_in).mean(0, keepdim=True) | ||||||
|  | 
 | ||||||
|  |         return latent | ||||||
|  | 
 | ||||||
|  |     def get_latent(self, input): | ||||||
|  |         return self.style(input) | ||||||
|  | 
 | ||||||
|  |     def forward( | ||||||
|  |         self, | ||||||
|  |         styles, | ||||||
|  |         return_latents=False, | ||||||
|  |         inject_index=None, | ||||||
|  |         truncation=1, | ||||||
|  |         truncation_latent=None, | ||||||
|  |         input_is_latent=False, | ||||||
|  |         input_is_stylespace=False, | ||||||
|  |         noise=None, | ||||||
|  |         randomize_noise=True, | ||||||
|  |     ): | ||||||
|  |         if not input_is_latent and not input_is_stylespace: | ||||||
|  |             styles = [self.style(s) for s in styles] | ||||||
|  | 
 | ||||||
|  |         if noise is None: | ||||||
|  |             if randomize_noise: | ||||||
|  |                 noise = [None] * self.num_layers | ||||||
|  |             else: | ||||||
|  |                 noise = [ | ||||||
|  |                     getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) | ||||||
|  |                 ] | ||||||
|  | 
 | ||||||
|  |         if truncation < 1 and not input_is_stylespace: | ||||||
|  |             style_t = [] | ||||||
|  | 
 | ||||||
|  |             for style in styles: | ||||||
|  |                 style_t.append( | ||||||
|  |                     truncation_latent + truncation * (style - truncation_latent) | ||||||
|  |                 ) | ||||||
|  | 
 | ||||||
|  |             styles = style_t | ||||||
|  | 
 | ||||||
|  |         if input_is_stylespace: | ||||||
|  |             latent = styles[0] | ||||||
|  |         elif len(styles) < 2: | ||||||
|  |             inject_index = self.n_latent | ||||||
|  | 
 | ||||||
|  |             if styles[0].ndim < 3: | ||||||
|  |                 latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) | ||||||
|  | 
 | ||||||
|  |             else: | ||||||
|  |                 latent = styles[0] | ||||||
|  | 
 | ||||||
|  |         else: | ||||||
|  |             if inject_index is None: | ||||||
|  |                 inject_index = random.randint(1, self.n_latent - 1) | ||||||
|  | 
 | ||||||
|  |             latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) | ||||||
|  |             latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) | ||||||
|  | 
 | ||||||
|  |             latent = torch.cat([latent, latent2], 1) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |         style_vector = [] | ||||||
|  | 
 | ||||||
|  |         if not input_is_stylespace: | ||||||
|  |             out = self.input(latent) | ||||||
|  |             # print('latent:', latent.shape)  # torch.Size([1, 18, 512]) | ||||||
|  |             out, out_style = self.conv1(out, latent[:, 0], noise=noise[0]) | ||||||
|  |             style_vector.append(out_style) | ||||||
|  | 
 | ||||||
|  |             skip, out_style = self.to_rgb1(out, latent[:, 1]) | ||||||
|  |             style_vector.append(out_style) | ||||||
|  | 
 | ||||||
|  |             i = 1 | ||||||
|  |         else: | ||||||
|  |             out = self.input(latent[0]) | ||||||
|  |             out, out_style = self.conv1(out, latent[0], noise=noise[0], input_is_stylespace=input_is_stylespace) | ||||||
|  |             style_vector.append(out_style) | ||||||
|  | 
 | ||||||
|  |             skip, out_style = self.to_rgb1(out, latent[1], input_is_stylespace=input_is_stylespace) | ||||||
|  |             style_vector.append(out_style) | ||||||
|  | 
 | ||||||
|  |             i = 2 | ||||||
|  | 
 | ||||||
|  |         for conv1, conv2, noise1, noise2, to_rgb in zip( | ||||||
|  |             self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs | ||||||
|  |         ): | ||||||
|  |             if not input_is_stylespace: | ||||||
|  |                 out, out_style1 = conv1(out, latent[:, i], noise=noise1) | ||||||
|  |                 out, out_style2 = conv2(out, latent[:, i + 1], noise=noise2) | ||||||
|  |                 skip, rgb_style = to_rgb(out, latent[:, i + 2], skip) | ||||||
|  | 
 | ||||||
|  |                 style_vector.extend([out_style1, out_style2, rgb_style]) | ||||||
|  | 
 | ||||||
|  |                 i += 2 | ||||||
|  |             else: | ||||||
|  |                 out, out_style1 = conv1(out, latent[i], noise=noise1, input_is_stylespace=input_is_stylespace) | ||||||
|  |                 out, out_style2 = conv2(out, latent[i + 1], noise=noise2, input_is_stylespace=input_is_stylespace) | ||||||
|  |                 skip, rgb_style = to_rgb(out, latent[i + 2], skip, input_is_stylespace=input_is_stylespace) | ||||||
|  | 
 | ||||||
|  |                 style_vector.extend([out_style1, out_style2, rgb_style]) | ||||||
|  | 
 | ||||||
|  |                 i += 3 | ||||||
|  | 
 | ||||||
|  |         image = skip | ||||||
|  | 
 | ||||||
|  |         if return_latents: | ||||||
|  |             return image, latent, style_vector | ||||||
|  | 
 | ||||||
|  |         else: | ||||||
|  |             return image, None | ||||||
|  | 
 | ||||||
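|  | # --- Editor's illustrative usage sketch, not part of the original commit --- | ||||||
|  | # Truncation pulls each latent toward the mean: w' = w_mean + psi * (w - w_mean). | ||||||
|  | # `g` is assumed to be a Generator instance, e.g. Generator(1024, 512, 8). | ||||||
|  | def _demo_truncation(g, psi=0.7): | ||||||
|  |     z = torch.randn(1, g.style_dim) | ||||||
|  |     w_mean = g.mean_latent(4096) | ||||||
|  |     image, _ = g([z], truncation=psi, truncation_latent=w_mean) | ||||||
|  |     return image  # [1, 3, g.size, g.size] | ||||||
|  | 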
|  | 
 | ||||||
|  | class ConvLayer(nn.Sequential): | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         in_channel, | ||||||
|  |         out_channel, | ||||||
|  |         kernel_size, | ||||||
|  |         downsample=False, | ||||||
|  |         blur_kernel=[1, 3, 3, 1], | ||||||
|  |         bias=True, | ||||||
|  |         activate=True, | ||||||
|  |     ): | ||||||
|  |         layers = [] | ||||||
|  | 
 | ||||||
|  |         if downsample: | ||||||
|  |             factor = 2 | ||||||
|  |             p = (len(blur_kernel) - factor) + (kernel_size - 1) | ||||||
|  |             pad0 = (p + 1) // 2 | ||||||
|  |             pad1 = p // 2 | ||||||
|  | 
 | ||||||
|  |             layers.append(Blur(blur_kernel, pad=(pad0, pad1))) | ||||||
|  | 
 | ||||||
|  |             stride = 2 | ||||||
|  |             self.padding = 0 | ||||||
|  | 
 | ||||||
|  |         else: | ||||||
|  |             stride = 1 | ||||||
|  |             self.padding = kernel_size // 2 | ||||||
|  | 
 | ||||||
|  |         layers.append( | ||||||
|  |             EqualConv2d( | ||||||
|  |                 in_channel, | ||||||
|  |                 out_channel, | ||||||
|  |                 kernel_size, | ||||||
|  |                 padding=self.padding, | ||||||
|  |                 stride=stride, | ||||||
|  |                 bias=bias and not activate, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         if activate: | ||||||
|  |             if bias: | ||||||
|  |                 layers.append(FusedLeakyReLU(out_channel)) | ||||||
|  | 
 | ||||||
|  |             else: | ||||||
|  |                 layers.append(ScaledLeakyReLU(0.2)) | ||||||
|  | 
 | ||||||
|  |         super().__init__(*layers) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ResBlock(nn.Module): | ||||||
|  |     def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.conv1 = ConvLayer(in_channel, in_channel, 3) | ||||||
|  |         self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) | ||||||
|  | 
 | ||||||
|  |         self.skip = ConvLayer( | ||||||
|  |             in_channel, out_channel, 1, downsample=True, activate=False, bias=False | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |     def forward(self, input): | ||||||
|  |         out = self.conv1(input) | ||||||
|  |         out = self.conv2(out) | ||||||
|  | 
 | ||||||
|  |         skip = self.skip(input) | ||||||
|  |         out = (out + skip) / math.sqrt(2) | ||||||
|  | 
 | ||||||
|  |         return out | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
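|  | # --- Editor's illustrative sketch, not part of the original commit --- | ||||||
|  | # Dividing the residual sum by sqrt(2) keeps the variance of the sum of two | ||||||
|  | # roughly independent unit-variance branches near 1. | ||||||
|  | def _demo_resblock(): | ||||||
|  |     block = ResBlock(64, 128) | ||||||
|  |     x = torch.randn(2, 64, 16, 16) | ||||||
|  |     return block(x).shape  # torch.Size([2, 128, 8, 8]) | ||||||
|  | 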
|  | class Discriminator(nn.Module): | ||||||
|  |     def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         channels = { | ||||||
|  |             4: 512, | ||||||
|  |             8: 512, | ||||||
|  |             16: 512, | ||||||
|  |             32: 512, | ||||||
|  |             64: 256 * channel_multiplier, | ||||||
|  |             128: 128 * channel_multiplier, | ||||||
|  |             256: 64 * channel_multiplier, | ||||||
|  |             512: 32 * channel_multiplier, | ||||||
|  |             1024: 16 * channel_multiplier, | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         convs = [ConvLayer(3, channels[size], 1)] | ||||||
|  | 
 | ||||||
|  |         log_size = int(math.log(size, 2)) | ||||||
|  | 
 | ||||||
|  |         in_channel = channels[size] | ||||||
|  | 
 | ||||||
|  |         # eight residual blocks (for size 1024) that downsample the feature maps from 1024x1024 to 4x4 | ||||||
|  |         for i in range(log_size, 2, -1): | ||||||
|  |             out_channel = channels[2 ** (i - 1)] | ||||||
|  | 
 | ||||||
|  |             convs.append(ResBlock(in_channel, out_channel, blur_kernel)) | ||||||
|  | 
 | ||||||
|  |             in_channel = out_channel | ||||||
|  | 
 | ||||||
|  |         self.convs = nn.Sequential(*convs) | ||||||
|  | 
 | ||||||
|  |         self.stddev_group = 4 | ||||||
|  |         self.stddev_feat = 1 | ||||||
|  | 
 | ||||||
|  |         self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) | ||||||
|  |         self.final_linear = nn.Sequential( | ||||||
|  |             EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), | ||||||
|  |             EqualLinear(channels[4], 1), | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |     def forward(self, input): | ||||||
|  |         out = self.convs(input) | ||||||
|  | 
 | ||||||
|  |         batch, channel, height, width = out.shape | ||||||
|  |         group = min(batch, self.stddev_group) | ||||||
|  |         stddev = out.view( | ||||||
|  |             group, -1, self.stddev_feat, channel // self.stddev_feat, height, width | ||||||
|  |         ) | ||||||
|  |         stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) | ||||||
|  |         stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2) | ||||||
|  |         stddev = stddev.repeat(group, 1, height, width) | ||||||
|  |         out = torch.cat([out, stddev], 1) | ||||||
|  | 
 | ||||||
|  |         out = self.final_conv(out) | ||||||
|  | 
 | ||||||
|  |         out = out.view(batch, -1) | ||||||
|  |         out = self.final_linear(out) | ||||||
|  | 
 | ||||||
|  |         return out | ||||||
|  | 
 | ||||||
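|  | # --- Editor's illustrative sketch, not part of the original commit --- | ||||||
|  | # Minibatch standard deviation, mirroring the forward pass above: append one | ||||||
|  | # feature map holding the average per-group std, so the discriminator can | ||||||
|  | # sense how diverse the batch is (assumes batch divisible by the group size). | ||||||
|  | def _demo_minibatch_stddev(out, group_size=4): | ||||||
|  |     batch, channel, height, width = out.shape | ||||||
|  |     group = min(batch, group_size) | ||||||
|  |     stddev = out.view(group, -1, 1, channel, height, width) | ||||||
|  |     stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) | ||||||
|  |     stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2) | ||||||
|  |     stddev = stddev.repeat(group, 1, height, width) | ||||||
|  |     return torch.cat([out, stddev], 1)  # [batch, channel + 1, h, w] | ||||||
|  | 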
							
								
								
									
2  models/stylegan2/op/__init__.py  (new file)
							| @ -0,0 +1,2 @@ | |||||||
|  | from .fused_act import FusedLeakyReLU, fused_leaky_relu | ||||||
|  | from .upfirdn2d import upfirdn2d | ||||||
							
								
								
									
										
BIN  models/stylegan2/op/__pycache__/__init__.cpython-310.pyc  (new file, binary not shown)
BIN  models/stylegan2/op/__pycache__/fused_act.cpython-310.pyc  (new file, binary not shown)
BIN  models/stylegan2/op/__pycache__/upfirdn2d.cpython-310.pyc  (new file, binary not shown)
							
								
								
									
40  models/stylegan2/op/fused_act.py  (new file)
							| @ -0,0 +1,40 @@ | |||||||
|  | import os | ||||||
|  | 
 | ||||||
|  | import torch | ||||||
|  | from torch import nn | ||||||
|  | from torch.nn import functional as F | ||||||
|  | 
 | ||||||
|  | module_path = os.path.dirname(__file__) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class FusedLeakyReLU(nn.Module): | ||||||
|  |     def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.bias = nn.Parameter(torch.zeros(channel)) | ||||||
|  |         self.negative_slope = negative_slope | ||||||
|  |         self.scale = scale | ||||||
|  | 
 | ||||||
|  |     def forward(self, input): | ||||||
|  |         return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): | ||||||
|  |     rest_dim = [1] * (input.ndim - bias.ndim - 1) | ||||||
|  |     bias = bias.to(input.device)  # follow the input's device rather than forcing CUDA | ||||||
|  |     if input.ndim == 3: | ||||||
|  |         return ( | ||||||
|  |             F.leaky_relu( | ||||||
|  |                 input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope | ||||||
|  |             ) | ||||||
|  |             * scale  # gain: the steady-state ratio of output to input amplitude; multiplying after the activation keeps signal magnitudes stable | ||||||
|  |         ) | ||||||
|  |     else: | ||||||
|  |         return ( | ||||||
|  |             F.leaky_relu( | ||||||
|  |                 input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope | ||||||
|  |             ) | ||||||
|  |             * scale | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
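|  | # --- Editor's illustrative sketch, not part of the original commit --- | ||||||
|  | # Reference form of the op: add the per-channel bias, apply leaky ReLU, then | ||||||
|  | # multiply by sqrt(2) to restore the signal magnitude. | ||||||
|  | def _demo_fused_equivalence(): | ||||||
|  |     x = torch.randn(2, 8, 4, 4) | ||||||
|  |     b = torch.randn(8) | ||||||
|  |     ref = F.leaky_relu(x + b.view(1, 8, 1, 1), negative_slope=0.2) * (2 ** 0.5) | ||||||
|  |     return torch.allclose(fused_leaky_relu(x, b), ref)  # True | ||||||
|  | 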
							
								
								
									
60  models/stylegan2/op/upfirdn2d.py  (new file)
							| @ -0,0 +1,60 @@ | |||||||
|  | import os | ||||||
|  | 
 | ||||||
|  | import torch | ||||||
|  | from torch.nn import functional as F | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | module_path = os.path.dirname(__file__) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): | ||||||
|  |     out = upfirdn2d_native( | ||||||
|  |         input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     return out | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def upfirdn2d_native( | ||||||
|  |     input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 | ||||||
|  | ): | ||||||
|  |     _, channel, in_h, in_w = input.shape | ||||||
|  |     input = input.reshape(-1, in_h, in_w, 1) | ||||||
|  | 
 | ||||||
|  |     _, in_h, in_w, minor = input.shape | ||||||
|  |     kernel_h, kernel_w = kernel.shape | ||||||
|  | 
 | ||||||
|  |     out = input.view(-1, in_h, 1, in_w, 1, minor) | ||||||
|  |     out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) | ||||||
|  |     out = out.view(-1, in_h * up_y, in_w * up_x, minor) | ||||||
|  | 
 | ||||||
|  |     out = F.pad( | ||||||
|  |         out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] | ||||||
|  |     ) | ||||||
|  |     out = out[ | ||||||
|  |         :, | ||||||
|  |         max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), | ||||||
|  |         max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), | ||||||
|  |         :, | ||||||
|  |     ] | ||||||
|  | 
 | ||||||
|  |     out = out.permute(0, 3, 1, 2) | ||||||
|  |     out = out.reshape( | ||||||
|  |         [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] | ||||||
|  |     ) | ||||||
|  |     w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) | ||||||
|  |     out = F.conv2d(out, w) | ||||||
|  |     out = out.reshape( | ||||||
|  |         -1, | ||||||
|  |         minor, | ||||||
|  |         in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, | ||||||
|  |         in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, | ||||||
|  |     ) | ||||||
|  |     out = out.permute(0, 2, 3, 1) | ||||||
|  |     out = out[:, ::down_y, ::down_x, :] | ||||||
|  | 
 | ||||||
|  |     out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 | ||||||
|  |     out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 | ||||||
|  | 
 | ||||||
|  |     return out.view(-1, channel, out_h, out_w) | ||||||
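|  | 
|  | # --- Editor's illustrative sketch, not part of the original commit --- | ||||||
|  | # Per axis, the output size is (in * up + pad0 + pad1 - kernel) // down + 1; | ||||||
|  | # e.g. 2x upsampling a 16x16 map with a 4x4 blur kernel and pad=(2, 1) gives 32x32. | ||||||
|  | def _demo_upfirdn2d_shape(): | ||||||
|  |     k = torch.tensor([1.0, 3.0, 3.0, 1.0]) | ||||||
|  |     kernel = (k[None, :] * k[:, None]) / k.sum() ** 2  # normalized 4x4 blur | ||||||
|  |     x = torch.randn(1, 3, 16, 16) | ||||||
|  |     return upfirdn2d(x, kernel, up=2, down=1, pad=(2, 1)).shape  # [1, 3, 32, 32] | ||||||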
							
								
								
									
9  models/stylegan3/dnnlib/__init__.py  (new file)
							| @ -0,0 +1,9 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | from .util import EasyDict, make_cache_dir_path | ||||||
							
								
								
									
										
BIN  models/stylegan3/dnnlib/__pycache__/__init__.cpython-310.pyc  (new file, binary not shown)
BIN  models/stylegan3/dnnlib/__pycache__/util.cpython-310.pyc  (new file, binary not shown)
							
								
								
									
491  models/stylegan3/dnnlib/util.py  (new file)
							| @ -0,0 +1,491 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | """Miscellaneous utility classes and functions.""" | ||||||
|  | 
 | ||||||
|  | import ctypes | ||||||
|  | import fnmatch | ||||||
|  | import importlib | ||||||
|  | import inspect | ||||||
|  | import numpy as np | ||||||
|  | import os | ||||||
|  | import shutil | ||||||
|  | import sys | ||||||
|  | import types | ||||||
|  | import io | ||||||
|  | import pickle | ||||||
|  | import re | ||||||
|  | import requests | ||||||
|  | import html | ||||||
|  | import hashlib | ||||||
|  | import glob | ||||||
|  | import tempfile | ||||||
|  | import urllib | ||||||
|  | import urllib.request | ||||||
|  | import uuid | ||||||
|  | 
 | ||||||
|  | from distutils.util import strtobool | ||||||
|  | from typing import Any, List, Tuple, Union | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # Util classes | ||||||
|  | # ------------------------------------------------------------------------------------------ | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class EasyDict(dict): | ||||||
|  |     """Convenience class that behaves like a dict but allows access with the attribute syntax.""" | ||||||
|  | 
 | ||||||
|  |     def __getattr__(self, name: str) -> Any: | ||||||
|  |         try: | ||||||
|  |             return self[name] | ||||||
|  |         except KeyError: | ||||||
|  |             raise AttributeError(name) | ||||||
|  | 
 | ||||||
|  |     def __setattr__(self, name: str, value: Any) -> None: | ||||||
|  |         self[name] = value | ||||||
|  | 
 | ||||||
|  |     def __delattr__(self, name: str) -> None: | ||||||
|  |         del self[name] | ||||||
|  | 
 | ||||||
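|  | # --- Editor's illustrative sketch, not part of the original commit --- | ||||||
|  | # Attribute syntax and item syntax address the same underlying dict. | ||||||
|  | def _demo_easydict(): | ||||||
|  |     cfg = EasyDict(lr=0.002, betas=(0, 0.99)) | ||||||
|  |     cfg.gamma = 10  # same as cfg['gamma'] = 10 | ||||||
|  |     return cfg.lr == cfg['lr'] and cfg['gamma'] == 10  # True | ||||||
|  | 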
|  | 
 | ||||||
|  | class Logger(object): | ||||||
|  |     """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" | ||||||
|  | 
 | ||||||
|  |     def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): | ||||||
|  |         self.file = None | ||||||
|  | 
 | ||||||
|  |         if file_name is not None: | ||||||
|  |             self.file = open(file_name, file_mode) | ||||||
|  | 
 | ||||||
|  |         self.should_flush = should_flush | ||||||
|  |         self.stdout = sys.stdout | ||||||
|  |         self.stderr = sys.stderr | ||||||
|  | 
 | ||||||
|  |         sys.stdout = self | ||||||
|  |         sys.stderr = self | ||||||
|  | 
 | ||||||
|  |     def __enter__(self) -> "Logger": | ||||||
|  |         return self | ||||||
|  | 
 | ||||||
|  |     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: | ||||||
|  |         self.close() | ||||||
|  | 
 | ||||||
|  |     def write(self, text: Union[str, bytes]) -> None: | ||||||
|  |         """Write text to stdout (and a file) and optionally flush.""" | ||||||
|  |         if isinstance(text, bytes): | ||||||
|  |             text = text.decode() | ||||||
|  |         if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash | ||||||
|  |             return | ||||||
|  | 
 | ||||||
|  |         if self.file is not None: | ||||||
|  |             self.file.write(text) | ||||||
|  | 
 | ||||||
|  |         self.stdout.write(text) | ||||||
|  | 
 | ||||||
|  |         if self.should_flush: | ||||||
|  |             self.flush() | ||||||
|  | 
 | ||||||
|  |     def flush(self) -> None: | ||||||
|  |         """Flush written text to both stdout and a file, if open.""" | ||||||
|  |         if self.file is not None: | ||||||
|  |             self.file.flush() | ||||||
|  | 
 | ||||||
|  |         self.stdout.flush() | ||||||
|  | 
 | ||||||
|  |     def close(self) -> None: | ||||||
|  |         """Flush, close possible files, and remove stdout/stderr mirroring.""" | ||||||
|  |         self.flush() | ||||||
|  | 
 | ||||||
|  |         # if using multiple loggers, prevent closing in wrong order | ||||||
|  |         if sys.stdout is self: | ||||||
|  |             sys.stdout = self.stdout | ||||||
|  |         if sys.stderr is self: | ||||||
|  |             sys.stderr = self.stderr | ||||||
|  | 
 | ||||||
|  |         if self.file is not None: | ||||||
|  |             self.file.close() | ||||||
|  |             self.file = None | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # Cache directories | ||||||
|  | # ------------------------------------------------------------------------------------------ | ||||||
|  | 
 | ||||||
|  | _dnnlib_cache_dir = None | ||||||
|  | 
 | ||||||
|  | def set_cache_dir(path: str) -> None: | ||||||
|  |     global _dnnlib_cache_dir | ||||||
|  |     _dnnlib_cache_dir = path | ||||||
|  | 
 | ||||||
|  | def make_cache_dir_path(*paths: str) -> str: | ||||||
|  |     if _dnnlib_cache_dir is not None: | ||||||
|  |         return os.path.join(_dnnlib_cache_dir, *paths) | ||||||
|  |     if 'DNNLIB_CACHE_DIR' in os.environ: | ||||||
|  |         return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) | ||||||
|  |     if 'HOME' in os.environ: | ||||||
|  |         return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) | ||||||
|  |     if 'USERPROFILE' in os.environ: | ||||||
|  |         return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) | ||||||
|  |     return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) | ||||||
|  | 
 | ||||||
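|  | # --- Editor's illustrative sketch, not part of the original commit --- | ||||||
|  | # Resolution order: set_cache_dir() > $DNNLIB_CACHE_DIR > $HOME/.cache/dnnlib | ||||||
|  | # > %USERPROFILE%/.cache/dnnlib > the system temp directory. | ||||||
|  | def _demo_cache_dir(): | ||||||
|  |     return make_cache_dir_path('downloads')  # e.g. '/home/user/.cache/dnnlib/downloads' | ||||||
|  | 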
|  | # Small util functions | ||||||
|  | # ------------------------------------------------------------------------------------------ | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def format_time(seconds: Union[int, float]) -> str: | ||||||
|  |     """Convert the seconds to human readable string with days, hours, minutes and seconds.""" | ||||||
|  |     s = int(np.rint(seconds)) | ||||||
|  | 
 | ||||||
|  |     if s < 60: | ||||||
|  |         return "{0}s".format(s) | ||||||
|  |     elif s < 60 * 60: | ||||||
|  |         return "{0}m {1:02}s".format(s // 60, s % 60) | ||||||
|  |     elif s < 24 * 60 * 60: | ||||||
|  |         return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) | ||||||
|  |     else: | ||||||
|  |         return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def format_time_brief(seconds: Union[int, float]) -> str: | ||||||
|  |     """Convert the seconds to human readable string with days, hours, minutes and seconds.""" | ||||||
|  |     s = int(np.rint(seconds)) | ||||||
|  | 
 | ||||||
|  |     if s < 60: | ||||||
|  |         return "{0}s".format(s) | ||||||
|  |     elif s < 60 * 60: | ||||||
|  |         return "{0}m {1:02}s".format(s // 60, s % 60) | ||||||
|  |     elif s < 24 * 60 * 60: | ||||||
|  |         return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) | ||||||
|  |     else: | ||||||
|  |         return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) | ||||||
|  | 
 | ||||||
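|  | # --- Editor's illustrative sketch, not part of the original commit --- | ||||||
|  | def _demo_format_time(): | ||||||
|  |     return format_time(3661), format_time_brief(90061)  # ('1h 01m 01s', '1d 01h') | ||||||
|  | 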
|  | 
 | ||||||
|  | def ask_yes_no(question: str) -> bool: | ||||||
|  |     """Ask the user the question until the user inputs a valid answer.""" | ||||||
|  |     while True: | ||||||
|  |         try: | ||||||
|  |             print("{0} [y/n]".format(question)) | ||||||
|  |             return strtobool(input().lower()) | ||||||
|  |         except ValueError: | ||||||
|  |             pass | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def tuple_product(t: Tuple) -> Any: | ||||||
|  |     """Calculate the product of the tuple elements.""" | ||||||
|  |     result = 1 | ||||||
|  | 
 | ||||||
|  |     for v in t: | ||||||
|  |         result *= v | ||||||
|  | 
 | ||||||
|  |     return result | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | _str_to_ctype = { | ||||||
|  |     "uint8": ctypes.c_ubyte, | ||||||
|  |     "uint16": ctypes.c_uint16, | ||||||
|  |     "uint32": ctypes.c_uint32, | ||||||
|  |     "uint64": ctypes.c_uint64, | ||||||
|  |     "int8": ctypes.c_byte, | ||||||
|  |     "int16": ctypes.c_int16, | ||||||
|  |     "int32": ctypes.c_int32, | ||||||
|  |     "int64": ctypes.c_int64, | ||||||
|  |     "float32": ctypes.c_float, | ||||||
|  |     "float64": ctypes.c_double | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: | ||||||
|  |     """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" | ||||||
|  |     type_str = None | ||||||
|  | 
 | ||||||
|  |     if isinstance(type_obj, str): | ||||||
|  |         type_str = type_obj | ||||||
|  |     elif hasattr(type_obj, "__name__"): | ||||||
|  |         type_str = type_obj.__name__ | ||||||
|  |     elif hasattr(type_obj, "name"): | ||||||
|  |         type_str = type_obj.name | ||||||
|  |     else: | ||||||
|  |         raise RuntimeError("Cannot infer type name from input") | ||||||
|  | 
 | ||||||
|  |     assert type_str in _str_to_ctype.keys() | ||||||
|  | 
 | ||||||
|  |     my_dtype = np.dtype(type_str) | ||||||
|  |     my_ctype = _str_to_ctype[type_str] | ||||||
|  | 
 | ||||||
|  |     assert my_dtype.itemsize == ctypes.sizeof(my_ctype) | ||||||
|  | 
 | ||||||
|  |     return my_dtype, my_ctype | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def is_pickleable(obj: Any) -> bool: | ||||||
|  |     try: | ||||||
|  |         with io.BytesIO() as stream: | ||||||
|  |             pickle.dump(obj, stream) | ||||||
|  |         return True | ||||||
|  |     except: | ||||||
|  |         return False | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # Functionality to import modules/objects by name, and call functions by name | ||||||
|  | # ------------------------------------------------------------------------------------------ | ||||||
|  | 
 | ||||||
|  | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: | ||||||
|  |     """Searches for the underlying module behind the name to some python object. | ||||||
|  |     Returns the module and the object name (original name with module part removed).""" | ||||||
|  | 
 | ||||||
|  |     # allow convenience shorthands, substitute them by full names | ||||||
|  |     obj_name = re.sub("^np.", "numpy.", obj_name) | ||||||
|  |     obj_name = re.sub("^tf.", "tensorflow.", obj_name) | ||||||
|  | 
 | ||||||
|  |     # list alternatives for (module_name, local_obj_name) | ||||||
|  |     parts = obj_name.split(".") | ||||||
|  |     name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] | ||||||
|  | 
 | ||||||
|  |     # try each alternative in turn | ||||||
|  |     for module_name, local_obj_name in name_pairs: | ||||||
|  |         try: | ||||||
|  |             module = importlib.import_module(module_name) # may raise ImportError | ||||||
|  |             get_obj_from_module(module, local_obj_name) # may raise AttributeError | ||||||
|  |             return module, local_obj_name | ||||||
|  |         except: | ||||||
|  |             pass | ||||||
|  | 
 | ||||||
|  |     # maybe some of the modules themselves contain errors? | ||||||
|  |     for module_name, _local_obj_name in name_pairs: | ||||||
|  |         try: | ||||||
|  |             importlib.import_module(module_name) # may raise ImportError | ||||||
|  |         except ImportError: | ||||||
|  |             if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): | ||||||
|  |                 raise | ||||||
|  | 
 | ||||||
|  |     # maybe the requested attribute is missing? | ||||||
|  |     for module_name, local_obj_name in name_pairs: | ||||||
|  |         try: | ||||||
|  |             module = importlib.import_module(module_name) # may raise ImportError | ||||||
|  |             get_obj_from_module(module, local_obj_name) # may raise AttributeError | ||||||
|  |         except ImportError: | ||||||
|  |             pass | ||||||
|  | 
 | ||||||
|  |     # we are out of luck, but we have no idea why | ||||||
|  |     raise ImportError(obj_name) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: | ||||||
|  |     """Traverses the object name and returns the last (rightmost) python object.""" | ||||||
|  |     if obj_name == '': | ||||||
|  |         return module | ||||||
|  |     obj = module | ||||||
|  |     for part in obj_name.split("."): | ||||||
|  |         obj = getattr(obj, part) | ||||||
|  |     return obj | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_obj_by_name(name: str) -> Any: | ||||||
|  |     """Finds the python object with the given name.""" | ||||||
|  |     module, obj_name = get_module_from_obj_name(name) | ||||||
|  |     return get_obj_from_module(module, obj_name) | ||||||
|  | 
 | ||||||
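|  | # --- Editor's illustrative sketch, not part of the original commit --- | ||||||
|  | # The 'np.' shorthand is expanded to 'numpy.' before the module search. | ||||||
|  | def _demo_get_obj_by_name(): | ||||||
|  |     mean_fn = get_obj_by_name('np.mean') | ||||||
|  |     return mean_fn([1, 2, 3])  # 2.0 | ||||||
|  | 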
|  | 
 | ||||||
|  | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: | ||||||
|  |     """Finds the python object with the given name and calls it as a function.""" | ||||||
|  |     assert func_name is not None | ||||||
|  |     func_obj = get_obj_by_name(func_name) | ||||||
|  |     assert callable(func_obj) | ||||||
|  |     return func_obj(*args, **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: | ||||||
|  |     """Finds the python class with the given name and constructs it with the given arguments.""" | ||||||
|  |     return call_func_by_name(*args, func_name=class_name, **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_module_dir_by_obj_name(obj_name: str) -> str: | ||||||
|  |     """Get the directory path of the module containing the given object name.""" | ||||||
|  |     module, _ = get_module_from_obj_name(obj_name) | ||||||
|  |     return os.path.dirname(inspect.getfile(module)) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def is_top_level_function(obj: Any) -> bool: | ||||||
|  |     """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" | ||||||
|  |     return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_top_level_function_name(obj: Any) -> str: | ||||||
|  |     """Return the fully-qualified name of a top-level function.""" | ||||||
|  |     assert is_top_level_function(obj) | ||||||
|  |     module = obj.__module__ | ||||||
|  |     if module == '__main__': | ||||||
|  |         module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] | ||||||
|  |     return module + "." + obj.__name__ | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # File system helpers | ||||||
|  | # ------------------------------------------------------------------------------------------ | ||||||
|  | 
 | ||||||
|  | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: | ||||||
|  |     """List all files recursively in a given directory while ignoring given file and directory names. | ||||||
|  |     Returns list of tuples containing both absolute and relative paths.""" | ||||||
|  |     assert os.path.isdir(dir_path) | ||||||
|  |     base_name = os.path.basename(os.path.normpath(dir_path)) | ||||||
|  | 
 | ||||||
|  |     if ignores is None: | ||||||
|  |         ignores = [] | ||||||
|  | 
 | ||||||
|  |     result = [] | ||||||
|  | 
 | ||||||
|  |     for root, dirs, files in os.walk(dir_path, topdown=True): | ||||||
|  |         for ignore_ in ignores: | ||||||
|  |             dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] | ||||||
|  | 
 | ||||||
|  |             # dirs need to be edited in-place | ||||||
|  |             for d in dirs_to_remove: | ||||||
|  |                 dirs.remove(d) | ||||||
|  | 
 | ||||||
|  |             files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] | ||||||
|  | 
 | ||||||
|  |         absolute_paths = [os.path.join(root, f) for f in files] | ||||||
|  |         relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] | ||||||
|  | 
 | ||||||
|  |         if add_base_to_relative: | ||||||
|  |             relative_paths = [os.path.join(base_name, p) for p in relative_paths] | ||||||
|  | 
 | ||||||
|  |         assert len(absolute_paths) == len(relative_paths) | ||||||
|  |         result += zip(absolute_paths, relative_paths) | ||||||
|  | 
 | ||||||
|  |     return result | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: | ||||||
|  |     """Takes in a list of tuples of (src, dst) paths and copies files. | ||||||
|  |     Will create all necessary directories.""" | ||||||
|  |     for file in files: | ||||||
|  |         target_dir_name = os.path.dirname(file[1]) | ||||||
|  | 
 | ||||||
|  |         # will create all intermediate-level directories | ||||||
|  |         if not os.path.exists(target_dir_name): | ||||||
|  |             os.makedirs(target_dir_name) | ||||||
|  | 
 | ||||||
|  |         shutil.copyfile(file[0], file[1]) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # URL helpers | ||||||
|  | # ------------------------------------------------------------------------------------------ | ||||||
|  | 
 | ||||||
|  | def is_url(obj: Any, allow_file_urls: bool = False) -> bool: | ||||||
|  |     """Determine whether the given object is a valid URL string.""" | ||||||
|  |     if not isinstance(obj, str) or not "://" in obj: | ||||||
|  |         return False | ||||||
|  |     if allow_file_urls and obj.startswith('file://'): | ||||||
|  |         return True | ||||||
|  |     try: | ||||||
|  |         res = requests.compat.urlparse(obj) | ||||||
|  |         if not res.scheme or not res.netloc or not "." in res.netloc: | ||||||
|  |             return False | ||||||
|  |         res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) | ||||||
|  |         if not res.scheme or not res.netloc or not "." in res.netloc: | ||||||
|  |             return False | ||||||
|  |     except: | ||||||
|  |         return False | ||||||
|  |     return True | ||||||
|  | 
 | ||||||
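|  | # --- Editor's illustrative sketch, not part of the original commit --- | ||||||
|  | def _demo_is_url():
|  |     return (is_url('https://example.com/model.pkl'),                # True | ||||||
|  |             is_url('file:///tmp/model.pkl', allow_file_urls=True),  # True | ||||||
|  |             is_url('not a url'))                                    # False | ||||||
|  | 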
|  | 
 | ||||||
|  | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: | ||||||
|  |     """Download the given URL and return a binary-mode file object to access the data.""" | ||||||
|  |     assert num_attempts >= 1 | ||||||
|  |     assert not (return_filename and (not cache)) | ||||||
|  | 
 | ||||||
|  |     # Doesn't look like a URL scheme, so interpret it as a local filename. | ||||||
|  |     if not re.match('^[a-z]+://', url): | ||||||
|  |         return url if return_filename else open(url, "rb") | ||||||
|  | 
 | ||||||
|  |     # Handle file URLs.  This code handles unusual file:// patterns that | ||||||
|  |     # arise on Windows: | ||||||
|  |     # | ||||||
|  |     # file:///c:/foo.txt | ||||||
|  |     # | ||||||
|  |     # which would translate to a local '/c:/foo.txt' filename that's | ||||||
|  |     # invalid.  Drop the forward slash for such pathnames. | ||||||
|  |     # | ||||||
|  |     # If you touch this code path, you should test it on both Linux and | ||||||
|  |     # Windows. | ||||||
|  |     # | ||||||
|  |     # Some internet resources suggest using urllib.request.url2pathname(), | ||||||
|  |     # but that converts forward slashes to backslashes and this causes | ||||||
|  |     # its own set of problems. | ||||||
|  |     if url.startswith('file://'): | ||||||
|  |         filename = urllib.parse.urlparse(url).path | ||||||
|  |         if re.match(r'^/[a-zA-Z]:', filename): | ||||||
|  |             filename = filename[1:] | ||||||
|  |         return filename if return_filename else open(filename, "rb") | ||||||
|  | 
 | ||||||
|  |     assert is_url(url) | ||||||
|  | 
 | ||||||
|  |     # Lookup from cache. | ||||||
|  |     if cache_dir is None: | ||||||
|  |         cache_dir = make_cache_dir_path('downloads') | ||||||
|  | 
 | ||||||
|  |     url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() | ||||||
|  |     if cache: | ||||||
|  |         cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) | ||||||
|  |         if len(cache_files) == 1: | ||||||
|  |             filename = cache_files[0] | ||||||
|  |             return filename if return_filename else open(filename, "rb") | ||||||
|  | 
 | ||||||
|  |     # Download. | ||||||
|  |     url_name = None | ||||||
|  |     url_data = None | ||||||
|  |     with requests.Session() as session: | ||||||
|  |         if verbose: | ||||||
|  |             print("Downloading %s ..." % url, end="", flush=True) | ||||||
|  |         for attempts_left in reversed(range(num_attempts)): | ||||||
|  |             try: | ||||||
|  |                 with session.get(url) as res: | ||||||
|  |                     res.raise_for_status() | ||||||
|  |                     if len(res.content) == 0: | ||||||
|  |                         raise IOError("No data received") | ||||||
|  | 
 | ||||||
|  |                     if len(res.content) < 8192: | ||||||
|  |                         content_str = res.content.decode("utf-8") | ||||||
|  |                         if "download_warning" in res.headers.get("Set-Cookie", ""): | ||||||
|  |                             links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] | ||||||
|  |                             if len(links) == 1: | ||||||
|  |                                 url = requests.compat.urljoin(url, links[0]) | ||||||
|  |                                 raise IOError("Google Drive virus checker nag") | ||||||
|  |                         if "Google Drive - Quota exceeded" in content_str: | ||||||
|  |                             raise IOError("Google Drive download quota exceeded -- please try again later") | ||||||
|  | 
 | ||||||
|  |                     match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) | ||||||
|  |                     url_name = match[1] if match else url | ||||||
|  |                     url_data = res.content | ||||||
|  |                     if verbose: | ||||||
|  |                         print(" done") | ||||||
|  |                     break | ||||||
|  |             except KeyboardInterrupt: | ||||||
|  |                 raise | ||||||
|  |             except: | ||||||
|  |                 if not attempts_left: | ||||||
|  |                     if verbose: | ||||||
|  |                         print(" failed") | ||||||
|  |                     raise | ||||||
|  |                 if verbose: | ||||||
|  |                     print(".", end="", flush=True) | ||||||
|  | 
 | ||||||
|  |     # Save to cache. | ||||||
|  |     if cache: | ||||||
|  |         safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) | ||||||
|  |         cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) | ||||||
|  |         temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) | ||||||
|  |         os.makedirs(cache_dir, exist_ok=True) | ||||||
|  |         with open(temp_file, "wb") as f: | ||||||
|  |             f.write(url_data) | ||||||
|  |         os.replace(temp_file, cache_file) # atomic | ||||||
|  |         if return_filename: | ||||||
|  |             return cache_file | ||||||
|  | 
 | ||||||
|  |     # Return data as file object. | ||||||
|  |     assert not return_filename | ||||||
|  |     return io.BytesIO(url_data) | ||||||
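|  | 
|  | 
|  | # --- Editor's illustrative usage sketch, not part of the original commit; | ||||||
|  | # the URL below is a hypothetical placeholder --- | ||||||
|  | def _demo_open_url(url='https://example.com/checkpoint.pkl'): | ||||||
|  |     with open_url(url) as f:  # downloads with retries and on-disk caching | ||||||
|  |         return len(f.read()) | ||||||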
							
								
								
									
529  models/stylegan3/model_3.py  (new file)
							| @ -0,0 +1,529 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | """Generator architecture from the paper | ||||||
|  | "Alias-Free Generative Adversarial Networks".""" | ||||||
|  | 
 | ||||||
|  | import numpy as np | ||||||
|  | import scipy.signal | ||||||
|  | import scipy.optimize | ||||||
|  | import torch | ||||||
|  | from torch_utils import misc | ||||||
|  | from torch_utils import persistence | ||||||
|  | from torch_utils.ops import conv2d_gradfix | ||||||
|  | from torch_utils.ops import filtered_lrelu | ||||||
|  | from torch_utils.ops import bias_act | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | @misc.profiled_function | ||||||
|  | def modulated_conv2d( | ||||||
|  |     x,                  # Input tensor: [batch_size, in_channels, in_height, in_width] | ||||||
|  |     w,                  # Weight tensor: [out_channels, in_channels, kernel_height, kernel_width] | ||||||
|  |     s,                  # Style tensor: [batch_size, in_channels] | ||||||
|  |     demodulate  = True, # Apply weight demodulation? | ||||||
|  |     padding     = 0,    # Padding: int or [padH, padW] | ||||||
|  |     input_gain  = None, # Optional scale factors for the input channels: [], [in_channels], or [batch_size, in_channels] | ||||||
|  | ): | ||||||
|  |     with misc.suppress_tracer_warnings(): # this value will be treated as a constant | ||||||
|  |         batch_size = int(x.shape[0]) | ||||||
|  |     out_channels, in_channels, kh, kw = w.shape | ||||||
|  |     misc.assert_shape(w, [out_channels, in_channels, kh, kw]) # [OIkk] | ||||||
|  |     misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW] | ||||||
|  |     misc.assert_shape(s, [batch_size, in_channels]) # [NI] | ||||||
|  | 
 | ||||||
|  |     # Pre-normalize inputs. | ||||||
|  |     if demodulate: | ||||||
|  |         w = w * w.square().mean([1,2,3], keepdim=True).rsqrt() | ||||||
|  |         s = s * s.square().mean().rsqrt() | ||||||
|  | 
 | ||||||
|  |     # Modulate weights. | ||||||
|  |     w = w.unsqueeze(0) # [NOIkk] | ||||||
|  |     w = w * s.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk] | ||||||
|  | 
 | ||||||
|  |     # Demodulate weights. | ||||||
|  |     if demodulate: | ||||||
|  |         dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO] | ||||||
|  |         w = w * dcoefs.unsqueeze(2).unsqueeze(3).unsqueeze(4) # [NOIkk] | ||||||
|  | 
 | ||||||
|  |     # Apply input scaling. | ||||||
|  |     if input_gain is not None: | ||||||
|  |         input_gain = input_gain.expand(batch_size, in_channels) # [NI] | ||||||
|  |         w = w * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk] | ||||||
|  | 
 | ||||||
|  |     # Execute as one fused op using grouped convolution. | ||||||
|  |     x = x.reshape(1, -1, *x.shape[2:]) | ||||||
|  |     w = w.reshape(-1, in_channels, kh, kw) | ||||||
|  |     x = conv2d_gradfix.conv2d(input=x, weight=w.to(x.dtype), padding=padding, groups=batch_size) | ||||||
|  |     x = x.reshape(batch_size, -1, *x.shape[2:]) | ||||||
|  |     return x | ||||||
|  | 
 | ||||||
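|  | # Usage sketch (hypothetical shapes, assuming the torch_utils ops resolve | ||||||
|  | # their CUDA/fallback kernels): | ||||||
|  | #   x = torch.randn([4, 64, 32, 32])           # batch of input feature maps | ||||||
|  | #   w = torch.randn([128, 64, 3, 3])           # shared convolution weights | ||||||
|  | #   s = torch.randn([4, 64])                   # per-sample styles from an affine layer | ||||||
|  | #   y = modulated_conv2d(x, w, s, padding=1)   # -> [4, 128, 32, 32] | ||||||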
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | @persistence.persistent_class | ||||||
|  | class FullyConnectedLayer(torch.nn.Module): | ||||||
|  |     def __init__(self, | ||||||
|  |         in_features,                # Number of input features. | ||||||
|  |         out_features,               # Number of output features. | ||||||
|  |         activation      = 'linear', # Activation function: 'relu', 'lrelu', etc. | ||||||
|  |         bias            = True,     # Apply additive bias before the activation function? | ||||||
|  |         lr_multiplier   = 1,        # Learning rate multiplier. | ||||||
|  |         weight_init     = 1,        # Initial standard deviation of the weight tensor. | ||||||
|  |         bias_init       = 0,        # Initial value of the additive bias. | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  |         self.in_features = in_features | ||||||
|  |         self.out_features = out_features | ||||||
|  |         self.activation = activation | ||||||
|  |         self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * (weight_init / lr_multiplier)) | ||||||
|  |         bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features]) | ||||||
|  |         self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None | ||||||
|  |         self.weight_gain = lr_multiplier / np.sqrt(in_features) | ||||||
|  |         self.bias_gain = lr_multiplier | ||||||
|  | 
 | ||||||
|  |     def forward(self, x): | ||||||
|  |         w = self.weight.to(x.dtype) * self.weight_gain | ||||||
|  |         b = self.bias | ||||||
|  |         if b is not None: | ||||||
|  |             b = b.to(x.dtype) | ||||||
|  |             if self.bias_gain != 1: | ||||||
|  |                 b = b * self.bias_gain | ||||||
|  |         if self.activation == 'linear' and b is not None: | ||||||
|  |             x = torch.addmm(b.unsqueeze(0), x, w.t()) | ||||||
|  |         else: | ||||||
|  |             x = x.matmul(w.t()) | ||||||
|  |             x = bias_act.bias_act(x, b, act=self.activation) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  |     def extra_repr(self): | ||||||
|  |         return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' | ||||||
|  | 
 | ||||||
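|  | # Note: weight_gain/bias_gain implement the equalized learning-rate trick: | ||||||
|  | # weights are stored rescaled by 1/lr_multiplier and multiplied by | ||||||
|  | # lr_multiplier / sqrt(in_features) at run time, so a uniform optimizer step | ||||||
|  | # acts comparably on every layer. | ||||||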
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | @persistence.persistent_class | ||||||
|  | class MappingNetwork(torch.nn.Module): | ||||||
|  |     def __init__(self, | ||||||
|  |         z_dim,                      # Input latent (Z) dimensionality. | ||||||
|  |         c_dim,                      # Conditioning label (C) dimensionality, 0 = no labels. | ||||||
|  |         w_dim,                      # Intermediate latent (W) dimensionality. | ||||||
|  |         num_ws,                     # Number of intermediate latents to output. | ||||||
|  |         num_layers      = 2,        # Number of mapping layers. | ||||||
|  |         lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers. | ||||||
|  |         w_avg_beta      = 0.998,    # Decay for tracking the moving average of W during training. | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  |         self.z_dim = z_dim | ||||||
|  |         self.c_dim = c_dim | ||||||
|  |         self.w_dim = w_dim | ||||||
|  |         self.num_ws = num_ws | ||||||
|  |         self.num_layers = num_layers | ||||||
|  |         self.w_avg_beta = w_avg_beta | ||||||
|  | 
 | ||||||
|  |         # Construct layers. | ||||||
|  |         self.embed = FullyConnectedLayer(self.c_dim, self.w_dim) if self.c_dim > 0 else None | ||||||
|  |         features = [self.z_dim + (self.w_dim if self.c_dim > 0 else 0)] + [self.w_dim] * self.num_layers | ||||||
|  |         for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]): | ||||||
|  |             layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier) | ||||||
|  |             setattr(self, f'fc{idx}', layer) | ||||||
|  |         self.register_buffer('w_avg', torch.zeros([w_dim])) | ||||||
|  | 
 | ||||||
|  |     def forward(self, z, c=0, truncation_psi=1, truncation_cutoff=None, update_emas=False): | ||||||
|  |         # Converting the incoming z from a list to a tensor; the change | ||||||
|  |         # didn't seem right, so it is left disabled: | ||||||
|  |         # z = torch.tensor( [item.cpu().detach().numpy() for item in z] ) | ||||||
|  |         misc.assert_shape(z, [None, self.z_dim]) | ||||||
|  |         if truncation_cutoff is None: | ||||||
|  |             truncation_cutoff = self.num_ws | ||||||
|  | 
 | ||||||
|  |         # Embed, normalize, and concatenate inputs. | ||||||
|  |         x = z.to(torch.float32) | ||||||
|  |         x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt() | ||||||
|  |         if self.c_dim > 0: | ||||||
|  |             misc.assert_shape(c, [None, self.c_dim]) | ||||||
|  |             y = self.embed(c.to(torch.float32)) | ||||||
|  |             y = y * (y.square().mean(1, keepdim=True) + 1e-8).rsqrt() | ||||||
|  |             x = torch.cat([x, y], dim=1) if x is not None else y | ||||||
|  | 
 | ||||||
|  |         # Execute layers. | ||||||
|  |         for idx in range(self.num_layers): | ||||||
|  |             x = getattr(self, f'fc{idx}')(x) | ||||||
|  | 
 | ||||||
|  |         # Update moving average of W. | ||||||
|  |         if update_emas: | ||||||
|  |             self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) | ||||||
|  | 
 | ||||||
|  |         # Broadcast and apply truncation. | ||||||
|  |         x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) | ||||||
|  |         if truncation_psi != 1: | ||||||
|  |             x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  |     def extra_repr(self): | ||||||
|  |         return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}' | ||||||
|  | 
 | ||||||
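|  | # Truncation above is a lerp toward the tracked average latent: | ||||||
|  | #   w' = w_avg + psi * (w - w_avg) | ||||||
|  | # so psi=1 leaves w unchanged and psi=0 collapses every sample onto w_avg. | ||||||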
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | @persistence.persistent_class | ||||||
|  | class SynthesisInput(torch.nn.Module): | ||||||
|  |     def __init__(self, | ||||||
|  |         w_dim,          # Intermediate latent (W) dimensionality. | ||||||
|  |         channels,       # Number of output channels. | ||||||
|  |         size,           # Output spatial size: int or [width, height]. | ||||||
|  |         sampling_rate,  # Output sampling rate. | ||||||
|  |         bandwidth,      # Output bandwidth. | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  |         self.w_dim = w_dim | ||||||
|  |         self.channels = channels | ||||||
|  |         self.size = np.broadcast_to(np.asarray(size), [2]) | ||||||
|  |         self.sampling_rate = sampling_rate | ||||||
|  |         self.bandwidth = bandwidth | ||||||
|  | 
 | ||||||
|  |         # Draw random frequencies from uniform 2D disc. | ||||||
|  |         freqs = torch.randn([self.channels, 2]) | ||||||
|  |         radii = freqs.square().sum(dim=1, keepdim=True).sqrt() | ||||||
|  |         freqs /= radii * radii.square().exp().pow(0.25) | ||||||
|  |         freqs *= bandwidth | ||||||
|  |         phases = torch.rand([self.channels]) - 0.5 | ||||||
|  | 
 | ||||||
|  |         # Setup parameters and buffers. | ||||||
|  |         self.weight = torch.nn.Parameter(torch.randn([self.channels, self.channels])) | ||||||
|  |         self.affine = FullyConnectedLayer(w_dim, 4, weight_init=0, bias_init=[1,0,0,0]) | ||||||
|  |         self.register_buffer('transform', torch.eye(3, 3)) # User-specified inverse transform wrt. resulting image. | ||||||
|  |         self.register_buffer('freqs', freqs) | ||||||
|  |         self.register_buffer('phases', phases) | ||||||
|  | 
 | ||||||
|  |     def forward(self, w): | ||||||
|  |         # Introduce batch dimension. | ||||||
|  |         transforms = self.transform.unsqueeze(0) # [batch, row, col] | ||||||
|  |         freqs = self.freqs.unsqueeze(0) # [batch, channel, xy] | ||||||
|  |         phases = self.phases.unsqueeze(0) # [batch, channel] | ||||||
|  | 
 | ||||||
|  |         # Apply learned transformation. | ||||||
|  |         t = self.affine(w) # t = (r_c, r_s, t_x, t_y) | ||||||
|  |         t = t / t[:, :2].norm(dim=1, keepdim=True) # t' = (r'_c, r'_s, t'_x, t'_y) | ||||||
|  |         m_r = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse rotation wrt. resulting image. | ||||||
|  |         m_r[:, 0, 0] = t[:, 0]  # r'_c | ||||||
|  |         m_r[:, 0, 1] = -t[:, 1] # r'_s | ||||||
|  |         m_r[:, 1, 0] = t[:, 1]  # r'_s | ||||||
|  |         m_r[:, 1, 1] = t[:, 0]  # r'_c | ||||||
|  |         m_t = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse translation wrt. resulting image. | ||||||
|  |         m_t[:, 0, 2] = -t[:, 2] # t'_x | ||||||
|  |         m_t[:, 1, 2] = -t[:, 3] # t'_y | ||||||
|  |         transforms = m_r @ m_t @ transforms # First rotate resulting image, then translate, and finally apply user-specified transform. | ||||||
|  | 
 | ||||||
|  |         # Transform frequencies. | ||||||
|  |         phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2) | ||||||
|  |         freqs = freqs @ transforms[:, :2, :2] | ||||||
|  | 
 | ||||||
|  |         # Dampen out-of-band frequencies that may occur due to the user-specified transform. | ||||||
|  |         amplitudes = (1 - (freqs.norm(dim=2) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1) | ||||||
|  | 
 | ||||||
|  |         # Construct sampling grid. | ||||||
|  |         theta = torch.eye(2, 3, device=w.device) | ||||||
|  |         theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate | ||||||
|  |         theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate | ||||||
|  |         grids = torch.nn.functional.affine_grid(theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]], align_corners=False) | ||||||
|  | 
 | ||||||
|  |         # Compute Fourier features. | ||||||
|  |         x = (grids.unsqueeze(3) @ freqs.permute(0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(3) # [batch, height, width, channel] | ||||||
|  |         x = x + phases.unsqueeze(1).unsqueeze(2) | ||||||
|  |         x = torch.sin(x * (np.pi * 2)) | ||||||
|  |         x = x * amplitudes.unsqueeze(1).unsqueeze(2) | ||||||
|  | 
 | ||||||
|  |         # Apply trainable mapping. | ||||||
|  |         weight = self.weight / np.sqrt(self.channels) | ||||||
|  |         x = x @ weight.t() | ||||||
|  | 
 | ||||||
|  |         # Ensure correct shape. | ||||||
|  |         x = x.permute(0, 3, 1, 2) # [batch, channel, height, width] | ||||||
|  |         misc.assert_shape(x, [w.shape[0], self.channels, int(self.size[1]), int(self.size[0])]) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  |     def extra_repr(self): | ||||||
|  |         return '\n'.join([ | ||||||
|  |             f'w_dim={self.w_dim:d}, channels={self.channels:d}, size={list(self.size)},', | ||||||
|  |             f'sampling_rate={self.sampling_rate:g}, bandwidth={self.bandwidth:g}']) | ||||||
|  | 
 | ||||||
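|  | # SynthesisInput evaluates Fourier features sin(2*pi*(f.p + phase)) on the | ||||||
|  | # transformed sampling grid; frequencies pushed out of band by the user | ||||||
|  | # transform are attenuated linearly between the bandwidth and the Nyquist | ||||||
|  | # limit (sampling_rate / 2), as in the amplitude computation above. | ||||||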
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | @persistence.persistent_class | ||||||
|  | class SynthesisLayer(torch.nn.Module): | ||||||
|  |     def __init__(self, | ||||||
|  |         w_dim,                          # Intermediate latent (W) dimensionality. | ||||||
|  |         is_torgb,                       # Is this the final ToRGB layer? | ||||||
|  |         is_critically_sampled,          # Does this layer use critical sampling? | ||||||
|  |         use_fp16,                       # Does this layer use FP16? | ||||||
|  | 
 | ||||||
|  |         # Input & output specifications. | ||||||
|  |         in_channels,                    # Number of input channels. | ||||||
|  |         out_channels,                   # Number of output channels. | ||||||
|  |         in_size,                        # Input spatial size: int or [width, height]. | ||||||
|  |         out_size,                       # Output spatial size: int or [width, height]. | ||||||
|  |         in_sampling_rate,               # Input sampling rate (s). | ||||||
|  |         out_sampling_rate,              # Output sampling rate (s). | ||||||
|  |         in_cutoff,                      # Input cutoff frequency (f_c). | ||||||
|  |         out_cutoff,                     # Output cutoff frequency (f_c). | ||||||
|  |         in_half_width,                  # Input transition band half-width (f_h). | ||||||
|  |         out_half_width,                 # Output transition band half-width (f_h). | ||||||
|  | 
 | ||||||
|  |         # Hyperparameters. | ||||||
|  |         conv_kernel         = 3,        # Convolution kernel size. Ignored for the final ToRGB layer. | ||||||
|  |         filter_size         = 6,        # Low-pass filter size relative to the lower resolution when up/downsampling. | ||||||
|  |         lrelu_upsampling    = 2,        # Relative sampling rate for leaky ReLU. Ignored for the final ToRGB layer. | ||||||
|  |         use_radial_filters  = False,    # Use radially symmetric downsampling filter? Ignored for critically sampled layers. | ||||||
|  |         conv_clamp          = 256,      # Clamp the output to [-X, +X], None = disable clamping. | ||||||
|  |         magnitude_ema_beta  = 0.999,    # Decay rate for the moving average of input magnitudes. | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  |         self.w_dim = w_dim | ||||||
|  |         self.is_torgb = is_torgb | ||||||
|  |         self.is_critically_sampled = is_critically_sampled | ||||||
|  |         self.use_fp16 = use_fp16 | ||||||
|  |         self.in_channels = in_channels | ||||||
|  |         self.out_channels = out_channels | ||||||
|  |         self.in_size = np.broadcast_to(np.asarray(in_size), [2]) | ||||||
|  |         self.out_size = np.broadcast_to(np.asarray(out_size), [2]) | ||||||
|  |         self.in_sampling_rate = in_sampling_rate | ||||||
|  |         self.out_sampling_rate = out_sampling_rate | ||||||
|  |         self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) | ||||||
|  |         self.in_cutoff = in_cutoff | ||||||
|  |         self.out_cutoff = out_cutoff | ||||||
|  |         self.in_half_width = in_half_width | ||||||
|  |         self.out_half_width = out_half_width | ||||||
|  |         self.conv_kernel = 1 if is_torgb else conv_kernel | ||||||
|  |         self.conv_clamp = conv_clamp | ||||||
|  |         self.magnitude_ema_beta = magnitude_ema_beta | ||||||
|  | 
 | ||||||
|  |         # Setup parameters and buffers. | ||||||
|  |         self.affine = FullyConnectedLayer(self.w_dim, self.in_channels, bias_init=1) | ||||||
|  |         self.weight = torch.nn.Parameter(torch.randn([self.out_channels, self.in_channels, self.conv_kernel, self.conv_kernel])) | ||||||
|  |         self.bias = torch.nn.Parameter(torch.zeros([self.out_channels])) | ||||||
|  |         self.register_buffer('magnitude_ema', torch.ones([])) | ||||||
|  | 
 | ||||||
|  |         # Design upsampling filter. | ||||||
|  |         self.up_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) | ||||||
|  |         assert self.in_sampling_rate * self.up_factor == self.tmp_sampling_rate | ||||||
|  |         self.up_taps = filter_size * self.up_factor if self.up_factor > 1 and not self.is_torgb else 1 | ||||||
|  |         self.register_buffer('up_filter', self.design_lowpass_filter( | ||||||
|  |             numtaps=self.up_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate)) | ||||||
|  | 
 | ||||||
|  |         # Design downsampling filter. | ||||||
|  |         self.down_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) | ||||||
|  |         assert self.out_sampling_rate * self.down_factor == self.tmp_sampling_rate | ||||||
|  |         self.down_taps = filter_size * self.down_factor if self.down_factor > 1 and not self.is_torgb else 1 | ||||||
|  |         self.down_radial = use_radial_filters and not self.is_critically_sampled | ||||||
|  |         self.register_buffer('down_filter', self.design_lowpass_filter( | ||||||
|  |             numtaps=self.down_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate, radial=self.down_radial)) | ||||||
|  | 
 | ||||||
|  |         # Compute padding. | ||||||
|  |         pad_total = (self.out_size - 1) * self.down_factor + 1 # Desired output size before downsampling. | ||||||
|  |         pad_total -= (self.in_size + self.conv_kernel - 1) * self.up_factor # Input size after upsampling. | ||||||
|  |         pad_total += self.up_taps + self.down_taps - 2 # Size reduction caused by the filters. | ||||||
|  |         pad_lo = (pad_total + self.up_factor) // 2 # Shift sample locations according to the symmetric interpretation (Appendix C.3). | ||||||
|  |         pad_hi = pad_total - pad_lo | ||||||
|  |         self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] | ||||||
|  | 
 | ||||||
|  |     def forward(self, x, w, noise_mode='random', force_fp32=False, update_emas=False): | ||||||
|  |         assert noise_mode in ['random', 'const', 'none'] # unused | ||||||
|  |         misc.assert_shape(x, [None, self.in_channels, int(self.in_size[1]), int(self.in_size[0])]) | ||||||
|  |         misc.assert_shape(w, [x.shape[0], self.w_dim]) | ||||||
|  | 
 | ||||||
|  |         # Track input magnitude. | ||||||
|  |         if update_emas: | ||||||
|  |             with torch.autograd.profiler.record_function('update_magnitude_ema'): | ||||||
|  |                 magnitude_cur = x.detach().to(torch.float32).square().mean() | ||||||
|  |                 self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema, self.magnitude_ema_beta)) | ||||||
|  |         input_gain = self.magnitude_ema.rsqrt() | ||||||
|  | 
 | ||||||
|  |         # Execute affine layer. | ||||||
|  |         styles = self.affine(w) | ||||||
|  |         if self.is_torgb: | ||||||
|  |             weight_gain = 1 / np.sqrt(self.in_channels * (self.conv_kernel ** 2)) | ||||||
|  |             styles = styles * weight_gain | ||||||
|  | 
 | ||||||
|  |         # Execute modulated conv2d. | ||||||
|  |         dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32 | ||||||
|  |         x = modulated_conv2d(x=x.to(dtype), w=self.weight, s=styles, | ||||||
|  |             padding=self.conv_kernel-1, demodulate=(not self.is_torgb), input_gain=input_gain) | ||||||
|  | 
 | ||||||
|  |         # Execute bias, filtered leaky ReLU, and clamping. | ||||||
|  |         gain = 1 if self.is_torgb else np.sqrt(2) | ||||||
|  |         slope = 1 if self.is_torgb else 0.2 | ||||||
|  |         x = filtered_lrelu.filtered_lrelu(x=x, fu=self.up_filter, fd=self.down_filter, b=self.bias.to(x.dtype), | ||||||
|  |             up=self.up_factor, down=self.down_factor, padding=self.padding, gain=gain, slope=slope, clamp=self.conv_clamp) | ||||||
|  | 
 | ||||||
|  |         # Ensure correct shape and dtype. | ||||||
|  |         misc.assert_shape(x, [None, self.out_channels, int(self.out_size[1]), int(self.out_size[0])]) | ||||||
|  |         assert x.dtype == dtype | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False): | ||||||
|  |         assert numtaps >= 1 | ||||||
|  | 
 | ||||||
|  |         # Identity filter. | ||||||
|  |         if numtaps == 1: | ||||||
|  |             return None | ||||||
|  | 
 | ||||||
|  |         # Separable Kaiser low-pass filter. | ||||||
|  |         if not radial: | ||||||
|  |             f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) | ||||||
|  |             return torch.as_tensor(f, dtype=torch.float32) | ||||||
|  | 
 | ||||||
|  |         # Radially symmetric jinc-based filter. | ||||||
|  |         x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs | ||||||
|  |         r = np.hypot(*np.meshgrid(x, x)) | ||||||
|  |         f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) | ||||||
|  |         beta = scipy.signal.kaiser_beta(scipy.signal.kaiser_atten(numtaps, width / (fs / 2))) | ||||||
|  |         w = np.kaiser(numtaps, beta) | ||||||
|  |         f *= np.outer(w, w) | ||||||
|  |         f /= np.sum(f) | ||||||
|  |         return torch.as_tensor(f, dtype=torch.float32) | ||||||
|  | 
 | ||||||
|  |     def extra_repr(self): | ||||||
|  |         return '\n'.join([ | ||||||
|  |             f'w_dim={self.w_dim:d}, is_torgb={self.is_torgb},', | ||||||
|  |             f'is_critically_sampled={self.is_critically_sampled}, use_fp16={self.use_fp16},', | ||||||
|  |             f'in_sampling_rate={self.in_sampling_rate:g}, out_sampling_rate={self.out_sampling_rate:g},', | ||||||
|  |             f'in_cutoff={self.in_cutoff:g}, out_cutoff={self.out_cutoff:g},', | ||||||
|  |             f'in_half_width={self.in_half_width:g}, out_half_width={self.out_half_width:g},', | ||||||
|  |             f'in_size={list(self.in_size)}, out_size={list(self.out_size)},', | ||||||
|  |             f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}']) | ||||||
|  | 
 | ||||||
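|  | # Each SynthesisLayer follows the alias-free recipe: modulated conv at the | ||||||
|  | # input rate, upsample with a Kaiser low-pass filter, apply bias and leaky | ||||||
|  | # ReLU at the higher temporary rate, then filter and downsample to the output | ||||||
|  | # rate (the ToRGB layer uses a 1x1 kernel and no resampling). | ||||||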
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | @persistence.persistent_class | ||||||
|  | class SynthesisNetwork(torch.nn.Module): | ||||||
|  |     def __init__(self, | ||||||
|  |         w_dim,                          # Intermediate latent (W) dimensionality. 512 | ||||||
|  |         img_resolution,                 # Output image resolution. 1024 | ||||||
|  |         img_channels,                   # Number of color channels. 3 | ||||||
|  |         channel_base        = 32768,    # Overall multiplier for the number of channels. | ||||||
|  |         channel_max         = 512,      # Maximum number of channels in any layer. | ||||||
|  |         num_layers          = 14,       # Total number of layers, excluding Fourier features and ToRGB. | ||||||
|  |         num_critical        = 2,        # Number of critically sampled layers at the end. | ||||||
|  |         first_cutoff        = 2,        # Cutoff frequency of the first layer (f_{c,0}). | ||||||
|  |         first_stopband      = 2**2.1,   # Minimum stopband of the first layer (f_{t,0}). | ||||||
|  |         last_stopband_rel   = 2**0.3,   # Minimum stopband of the last layer, expressed relative to the cutoff. | ||||||
|  |         margin_size         = 10,       # Number of additional pixels outside the image. | ||||||
|  |         output_scale        = 0.25,     # Scale factor for the output image. | ||||||
|  |         num_fp16_res        = 4,        # Use FP16 for the N highest resolutions. | ||||||
|  |         **layer_kwargs,                 # Arguments for SynthesisLayer. | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  |         self.w_dim = w_dim | ||||||
|  |         self.num_ws = num_layers + 2 | ||||||
|  |         self.img_resolution = img_resolution | ||||||
|  |         self.img_channels = img_channels | ||||||
|  |         self.num_layers = num_layers | ||||||
|  |         self.num_critical = num_critical | ||||||
|  |         self.margin_size = margin_size | ||||||
|  |         self.output_scale = output_scale | ||||||
|  |         self.num_fp16_res = num_fp16_res | ||||||
|  | 
 | ||||||
|  |         # Geometric progression of layer cutoffs and min. stopbands. | ||||||
|  |         last_cutoff = self.img_resolution / 2 # f_{c,N} | ||||||
|  |         last_stopband = last_cutoff * last_stopband_rel # f_{t,N} | ||||||
|  |         exponents = np.minimum(np.arange(self.num_layers + 1) / (self.num_layers - self.num_critical), 1) | ||||||
|  |         cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents # f_c[i], e.g. [2, 3.17, 5.04, 8, 12.70, 20.16, 32, 50.80, 80.63, 128, 203.19, 322.54, 512, 512, 512] | ||||||
|  |         stopbands = first_stopband * (last_stopband / first_stopband) ** exponents # f_t[i] | ||||||
|  | 
 | ||||||
|  |         # Compute remaining layer parameters. | ||||||
|  |         sampling_rates = np.exp2(np.ceil(np.log2(np.minimum(stopbands * 2, self.img_resolution)))) # s[i] | ||||||
|  |         half_widths = np.maximum(stopbands, sampling_rates / 2) - cutoffs # f_h[i] | ||||||
|  |         sizes = sampling_rates + self.margin_size * 2 | ||||||
|  |         sizes[-2:] = self.img_resolution | ||||||
|  |         channels = np.rint(np.minimum((channel_base / 2) / cutoffs, channel_max)) | ||||||
|  |         channels[-1] = self.img_channels | ||||||
|  | 
 | ||||||
|  |         # Construct layers. | ||||||
|  |         self.input = SynthesisInput( | ||||||
|  |             w_dim=self.w_dim, channels=int(channels[0]), size=int(sizes[0]),  # sizes: [36, 36, 52, 52, 84, 148, 148, 276, 276, 532, 1044, 1044, 1044, 1024, 1024] | ||||||
|  |             sampling_rate=sampling_rates[0], bandwidth=cutoffs[0])  # sampling_rates: [16, 16, 32, 32, 64, 128, 128, 256, 256, 512, 1024, 1024, 1024, 1024, 1024] | ||||||
|  |         self.layer_names = [] | ||||||
|  |         for idx in range(self.num_layers + 1): | ||||||
|  |             prev = max(idx - 1, 0) | ||||||
|  |             is_torgb = (idx == self.num_layers) | ||||||
|  |             is_critically_sampled = (idx >= self.num_layers - self.num_critical) | ||||||
|  |             use_fp16 = (sampling_rates[idx] * (2 ** self.num_fp16_res) > self.img_resolution) | ||||||
|  |             layer = SynthesisLayer( | ||||||
|  |                 w_dim=self.w_dim, is_torgb=is_torgb, is_critically_sampled=is_critically_sampled, use_fp16=use_fp16, | ||||||
|  |                 in_channels=int(channels[prev]), out_channels= int(channels[idx]), | ||||||
|  |                 in_size=int(sizes[prev]), out_size=int(sizes[idx]), | ||||||
|  |                 in_sampling_rate=int(sampling_rates[prev]), out_sampling_rate=int(sampling_rates[idx]), | ||||||
|  |                 in_cutoff=cutoffs[prev], out_cutoff=cutoffs[idx], | ||||||
|  |                 in_half_width=half_widths[prev], out_half_width=half_widths[idx], | ||||||
|  |                 **layer_kwargs) | ||||||
|  |             name = f'L{idx}_{layer.out_size[0]}_{layer.out_channels}' | ||||||
|  |             setattr(self, name, layer) | ||||||
|  |             self.layer_names.append(name) | ||||||
|  | 
 | ||||||
|  |     def forward(self, ws, **layer_kwargs): | ||||||
|  |         misc.assert_shape(ws, [None, self.num_ws, self.w_dim]) | ||||||
|  |         ws = ws.to(torch.float32).unbind(dim=1) | ||||||
|  | 
 | ||||||
|  |         # Execute layers. | ||||||
|  |         x = self.input(ws[0]) | ||||||
|  |         for name, w in zip(self.layer_names, ws[1:]): | ||||||
|  |             x = getattr(self, name)(x, w, **layer_kwargs) | ||||||
|  |         if self.output_scale != 1: | ||||||
|  |             x = x * self.output_scale | ||||||
|  | 
 | ||||||
|  |         # Ensure correct shape and dtype. | ||||||
|  |         misc.assert_shape(x, [None, self.img_channels, self.img_resolution, self.img_resolution]) | ||||||
|  |         x = x.to(torch.float32) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  |     def extra_repr(self): | ||||||
|  |         return '\n'.join([ | ||||||
|  |             f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},', | ||||||
|  |             f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},', | ||||||
|  |             f'num_layers={self.num_layers:d}, num_critical={self.num_critical:d},', | ||||||
|  |             f'margin_size={self.margin_size:d}, num_fp16_res={self.num_fp16_res:d}']) | ||||||
|  | 
 | ||||||
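|  | # Cutoff schedule in closed form, matching the code above: | ||||||
|  | #   f_c[i] = first_cutoff * (last_cutoff / first_cutoff) ** min(i / (N - K), 1) | ||||||
|  | # with N = num_layers and K = num_critical, so the last K+1 layers sit at the | ||||||
|  | # critical cutoff img_resolution / 2 (512 for a 1024x1024 generator). | ||||||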
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | @persistence.persistent_class | ||||||
|  | class Generator(torch.nn.Module): | ||||||
|  |     def __init__(self, | ||||||
|  |         z_dim,                      # Input latent (Z) dimensionality. | ||||||
|  |         c_dim,                      # Conditioning label (C) dimensionality. | ||||||
|  |         w_dim,                      # Intermediate latent (W) dimensionality. | ||||||
|  |         img_resolution,             # Output resolution. | ||||||
|  |         img_channels,               # Number of output color channels. | ||||||
|  |         mapping_kwargs      = {},   # Arguments for MappingNetwork. | ||||||
|  |         **synthesis_kwargs,         # Arguments for SynthesisNetwork. | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  |         self.z_dim = z_dim    #512 | ||||||
|  |         self.c_dim = c_dim    #0 | ||||||
|  |         self.w_dim = w_dim    #512 | ||||||
|  |         self.img_resolution = img_resolution | ||||||
|  |         self.img_channels = img_channels | ||||||
|  |         self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs) | ||||||
|  |         self.num_ws = self.synthesis.num_ws   #16 | ||||||
|  |         self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs) | ||||||
|  | 
 | ||||||
|  |     # def mean_latent(self, n_latent): | ||||||
|  |     #     latent_in = torch.randn( | ||||||
|  |     #         # style_dim here should match w_dim | ||||||
|  |     #         n_latent, self.w_dim, device=self.synthesis.input.weight.device | ||||||
|  |     #     ) | ||||||
|  |     #     latent = self.synthesis.styles(latent_in).mean(0, keepdim=True) | ||||||
|  |     # | ||||||
|  |     #     return latent | ||||||
|  | 
 | ||||||
|  |     def forward(self, z, c=None, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs): | ||||||
|  |         # print("-----------------------------------") | ||||||
|  |         # print(z) | ||||||
|  |         # print("-----------------------------------") | ||||||
|  |         ws = self.mapping(z, c=None, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)  # c is ignored: this build is unconditional (c_dim=0) | ||||||
|  |         img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs) | ||||||
|  |         return img | ||||||
|  | 
 | ||||||
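|  | # Sampling sketch (unconditional model, c_dim=0), matching the usage in | ||||||
|  | # show_pkl.py later in this commit: | ||||||
|  | #   G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=1024, img_channels=3).cuda() | ||||||
|  | #   z = torch.randn([1, G.z_dim]).cuda() | ||||||
|  | #   img = G(z)   # NCHW float32, dynamic range roughly [-1, +1] | ||||||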
|  | #---------------------------------------------------------------------------- | ||||||
							
								
								
									
267  models/stylegan3/run_optimization3.py  Normal file
							| @ -0,0 +1,267 @@ | |||||||
|  | import argparse | ||||||
|  | import math | ||||||
|  | import os | ||||||
|  | import pickle | ||||||
|  | 
 | ||||||
|  | import torchvision | ||||||
|  | from torch import optim | ||||||
|  | from tqdm import tqdm | ||||||
|  | 
 | ||||||
|  | import torch | ||||||
|  | import clip | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class CLIPLoss(torch.nn.Module): | ||||||
|  | 
 | ||||||
|  |     def __init__(self, opts): | ||||||
|  |         super(CLIPLoss, self).__init__() | ||||||
|  |         self.model, self.preprocess = clip.load("ViT-B/32", device="cuda") | ||||||
|  |         self.upsample = torch.nn.Upsample(scale_factor=7) | ||||||
|  |         self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32) | ||||||
|  | 
 | ||||||
|  |     def forward(self, image, text): | ||||||
|  |         image = self.avg_pool(self.upsample(image)) | ||||||
|  |         similarity = 1 - self.model(image, text)[0] / 100 | ||||||
|  |         return similarity | ||||||
|  | 
 | ||||||
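|  | # The upsample/avg-pool pair above is a fixed resize to CLIP's 224x224 input: | ||||||
|  | # for stylegan_size=1024, 1024 * 7 / (1024 // 32) = 224. The CLIP model | ||||||
|  | # returns cosine-similarity logits scaled by roughly 100, hence the | ||||||
|  | # "1 - logits / 100" loss. | ||||||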
|  | 
 | ||||||
|  | from torch import nn | ||||||
|  | import sys | ||||||
|  | sys.path.append('/home/ly/StyleCLIP-main/models/facial_recognition') | ||||||
|  | from model_irse import Backbone | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class IDLoss(nn.Module): | ||||||
|  |     def __init__(self, opts): | ||||||
|  |         super(IDLoss, self).__init__() | ||||||
|  |         print('Loading ResNet ArcFace') | ||||||
|  |         self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') | ||||||
|  |         self.facenet.load_state_dict(torch.load(opts.ir_se50_weights)) | ||||||
|  |         self.pool = torch.nn.AdaptiveAvgPool2d((256, 256)) | ||||||
|  |         self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) | ||||||
|  |         self.facenet.eval() | ||||||
|  |         self.facenet.cuda() | ||||||
|  |         self.opts = opts | ||||||
|  | 
 | ||||||
|  |     def extract_feats(self, x): | ||||||
|  |         if x.shape[2] != 256: | ||||||
|  |             x = self.pool(x) | ||||||
|  |         x = x[:, :, 35:223, 32:220]  # Crop interesting region | ||||||
|  |         x = self.face_pool(x) | ||||||
|  |         x_feats = self.facenet(x) | ||||||
|  |         return x_feats | ||||||
|  | 
 | ||||||
|  |     def forward(self, y_hat, y): | ||||||
|  |         n_samples = y.shape[0] | ||||||
|  |         y_feats = self.extract_feats(y)  # Otherwise use the feature from there | ||||||
|  |         y_hat_feats = self.extract_feats(y_hat) | ||||||
|  |         y_feats = y_feats.detach() | ||||||
|  |         loss = 0 | ||||||
|  |         sim_improvement = 0 | ||||||
|  |         count = 0 | ||||||
|  |         for i in range(n_samples): | ||||||
|  |             diff_target = y_hat_feats[i].dot(y_feats[i]) | ||||||
|  |             loss += 1 - diff_target | ||||||
|  |             count += 1 | ||||||
|  | 
 | ||||||
|  |         return loss / count, sim_improvement / count | ||||||
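|  | # IDLoss scores identity preservation as 1 minus the dot product of ArcFace | ||||||
|  | # embeddings of the edited and original faces; assuming the ir_se50 backbone | ||||||
|  | # L2-normalizes its output, this dot product is a cosine similarity. | ||||||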
|  | sys.path.append('/home/ly/StyleCLIP-main/mapper/training') | ||||||
|  | from train_utils import STYLESPACE_DIMENSIONS | ||||||
|  | from model_3 import Generator | ||||||
|  | from model_3 import SynthesisNetwork | ||||||
|  | from model_3 import SynthesisLayer | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | sys.path.append('/home/ly/StyleCLIP-main') | ||||||
|  | from utils import ensure_checkpoint_exists | ||||||
|  | 
 | ||||||
|  | STYLESPACE_INDICES_WITHOUT_TORGB = [i for i in range(len(STYLESPACE_DIMENSIONS)) if i not in list(range(1, len(STYLESPACE_DIMENSIONS), 3))] | ||||||
|  | 
 | ||||||
|  | def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): | ||||||
|  |     lr_ramp = min(1, (1 - t) / rampdown) | ||||||
|  |     lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) | ||||||
|  |     lr_ramp = lr_ramp * min(1, t / rampup) | ||||||
|  | 
 | ||||||
|  |     return initial_lr * lr_ramp | ||||||
|  | 
 | ||||||
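|  | # get_lr ramps the learning rate up linearly over the first 5% of steps, | ||||||
|  | # holds it at initial_lr in the middle, and decays it with a half-cosine over | ||||||
|  | # the final 25%; e.g. at t=0.9: min(1, 0.1/0.25)=0.4, 0.5-0.5*cos(0.4*pi)~0.35. | ||||||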
|  | 
 | ||||||
|  | def main(args): | ||||||
|  |     ensure_checkpoint_exists(args.ckpt) | ||||||
|  |     # Tokenize the description for the pretrained CLIP model. | ||||||
|  |     text_inputs = torch.cat([clip.tokenize(args.description)]).cuda() | ||||||
|  |     # print('text_input: ', text_inputs) | ||||||
|  |     # CLIP's tokenizer splits the prompt according to its BPE rules and | ||||||
|  |     # vocabulary size; the examples below show the resulting token ids. | ||||||
|  |     ''' | ||||||
|  |      --description "a person with purple hair"  | ||||||
|  |      tensor([[49406,   320,  2533,   593,  5496,  2225, 49407,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0]], device='cuda:0', | ||||||
|  |        dtype=torch.int32) | ||||||
|  |        --description "a person with red hair"  | ||||||
|  |         tensor([[49406,   320,  2533,   593,   736,  2225, 49407,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0]], device='cuda:0', | ||||||
|  |        dtype=torch.int32) | ||||||
|  |     ''' | ||||||
|  | 
 | ||||||
|  |     os.makedirs(args.results_dir, exist_ok=True) | ||||||
|  |     # Switch the inputs over to StyleGAN3. | ||||||
|  | 
 | ||||||
|  |     # with open('/home/ly/StyleCLIP-main/models/stylegan3/torch_utils/stylegan3-r-afhqv2-512x512.pkl', 'rb') as f: | ||||||
|  |     #     G = pickle.load(f)['G_ema'].cuda()  # torch.nn.Module | ||||||
|  |     # z = torch.randn([1, G.z_dim]).cuda()  # latent codes | ||||||
|  |     # c = None  # class labels (not used in this example) | ||||||
|  |     # img = G(z, c)  # NCHW, float32, dynamic range [-1, +1], no truncation | ||||||
|  |      | ||||||
|  |     # g_ema = Generator(512, 0, 512,args.stylegan_size, 3)   #512,0,512,1024,3 | ||||||
|  |     # with open('/home/ly/StyleCLIP-main/models/stylegan3/torch_utils/stylegan3-r-afhqv2-512x512.pkl', 'rb') as f: | ||||||
|  |     # stylegan3-r-ffhqu-1024x1024.pkl: poor image quality -- don't use | ||||||
|  |     # stylegan3-t-ffhq-1024x1024.pkl:  average images, good loss values | ||||||
|  |     # stylegan3-r-ffhq-1024x1024.pkl:  a compromise between the two | ||||||
|  |     # stylegan3-t-ffhqu-1024x1024.pkl: decent images, worse loss | ||||||
|  |     with open('/home/ly/StyleCLIP-main/pretrained_models/stylegan3-t-ffhq-1024x1024.pkl', 'rb') as f:   #stylespace_dimensions [512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 256, 256, 256, 128, 128, 128, 64, 64, 64, 32, 32] | ||||||
|  |         # new_p = pickle.load(f) | ||||||
|  |         # print(new_p) | ||||||
|  |         # print("new_p") | ||||||
|  |         # print(new_p.keys()) | ||||||
|  |         # G_ema.load_state_dict(pickle.load(f)['G_ema'].cuda(), strict=False)  # the model fails to load this way | ||||||
|  |         g_ema = pickle.load(f)['G_ema'].cuda()  # torch.nn.Module; loaded this way, a 300-step run averages about 4 minutes per image | ||||||
|  |     z = torch.randn([1, g_ema.z_dim]).cuda()  # latent codes | ||||||
|  |     c = None  # class labels (not used in this example) | ||||||
|  |     #g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False) | ||||||
|  |     # Put the model in evaluation mode. | ||||||
|  |     g_ema.eval() | ||||||
|  |     # Change the CUDA device index here if needed. | ||||||
|  |     g_ema = g_ema.cuda() | ||||||
|  |     # device = torch.cuda.current_device() | ||||||
|  |     # print('cuda:',device) | ||||||
|  |     mean_latent = torch.randn([1, g_ema.z_dim]).cuda()  # NOTE: a random z, not an averaged latent, despite the name | ||||||
|  |     torch.save(mean_latent,'/home/ly/StyleCLIP-main/pretrained_models/latent_code/style3.pt') | ||||||
|  |     # print('mean_latent: ', mean_latent) | ||||||
|  | 
 | ||||||
|  |     if args.latent_path: | ||||||
|  |         latent_code_init = torch.load(args.latent_path).cuda() | ||||||
|  |     # elif args.mode == "edit": | ||||||
|  |     #     latent_code_init_not_trunc = torch.randn(1, 512).cuda() | ||||||
|  |     #     with torch.no_grad(): | ||||||
|  |     #         _, latent_code_init, _ = g_ema([latent_code_init_not_trunc], return_latents=True, | ||||||
|  |     #                                     truncation=args.truncation, truncation_latent=mean_latent) | ||||||
|  |     else: | ||||||
|  |         # latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1)  # repeat 18 times along dim 1 | ||||||
|  |         latent_code_init = mean_latent.detach().clone() | ||||||
|  |     # def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs): | ||||||
|  |     with torch.no_grad(): | ||||||
|  |         print("mean_latent ", mean_latent.shape) | ||||||
|  |         # img_orig, _ = g_ema([latent_code_init], c, input_is_latent=True, randomize_noise=False) | ||||||
|  |         img_orig = g_ema(latent_code_init, c) | ||||||
|  | 
 | ||||||
|  |     if args.work_in_stylespace: | ||||||
|  |         # NOTE: this branch still uses the StyleGAN2 call signature and will | ||||||
|  |         # not run against the StyleGAN3 generator loaded above. | ||||||
|  |         with torch.no_grad(): | ||||||
|  |             _, _, latent_code_init = g_ema([latent_code_init], input_is_latent=True, return_latents=True) | ||||||
|  |         latent = [s.detach().clone() for s in latent_code_init] | ||||||
|  |         for c, s in enumerate(latent): | ||||||
|  |             if c in STYLESPACE_INDICES_WITHOUT_TORGB: | ||||||
|  |                 s.requires_grad = True | ||||||
|  |     else: | ||||||
|  |         latent = latent_code_init.detach().clone() | ||||||
|  |         latent.requires_grad = True | ||||||
|  | 
 | ||||||
|  |     clip_loss = CLIPLoss(args) | ||||||
|  |     id_loss = IDLoss(args) | ||||||
|  | 
 | ||||||
|  |     if args.work_in_stylespace: | ||||||
|  |         optimizer = optim.Adam(latent, lr=args.lr) | ||||||
|  |     else: | ||||||
|  |         optimizer = optim.Adam([latent], lr=args.lr) | ||||||
|  | 
 | ||||||
|  |     pbar = tqdm(range(args.step)) | ||||||
|  | 
 | ||||||
|  |     for i in pbar: | ||||||
|  |         t = i / args.step | ||||||
|  |         lr = get_lr(t, args.lr) | ||||||
|  |         optimizer.param_groups[0]["lr"] = lr | ||||||
|  | 
 | ||||||
|  |         img_gen = g_ema(latent,c) | ||||||
|  | 
 | ||||||
|  |         c_loss = clip_loss(img_gen, text_inputs) | ||||||
|  | 
 | ||||||
|  |         if args.id_lambda > 0: | ||||||
|  |             # identity loss | ||||||
|  |             i_loss = id_loss(img_gen, img_orig)[0] | ||||||
|  |         else: | ||||||
|  |             i_loss = 0 | ||||||
|  | 
 | ||||||
|  |         if args.mode == "edit": | ||||||
|  |             if args.work_in_stylespace: | ||||||
|  |                 l2_loss = sum([((latent_code_init[c] - latent[c]) ** 2).sum() for c in range(len(latent_code_init))]) | ||||||
|  |             else: | ||||||
|  |                 # L2 distance in latent space | ||||||
|  |                 l2_loss = ((latent_code_init - latent) ** 2).sum() | ||||||
|  |             loss = c_loss + args.l2_lambda * l2_loss + args.id_lambda * i_loss | ||||||
|  |         else: | ||||||
|  |             loss = c_loss | ||||||
|  | 
 | ||||||
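|  |         # In edit mode the objective follows StyleCLIP's latent optimization: | ||||||
|  |         # loss = D_CLIP(G(w), t) + l2_lambda * ||w - w_init||^2 + id_lambda * L_ID | ||||||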
|  |         optimizer.zero_grad() | ||||||
|  |         loss.backward() | ||||||
|  |         optimizer.step() | ||||||
|  | 
 | ||||||
|  |         pbar.set_description( | ||||||
|  |             ( | ||||||
|  |                 f"loss: {loss.item():.4f};" | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         if args.save_intermediate_image_every > 0 and i % args.save_intermediate_image_every == 0: | ||||||
|  |             with torch.no_grad(): | ||||||
|  |                 img_gen = g_ema(latent, c) | ||||||
|  | 
 | ||||||
|  |             torchvision.utils.save_image(img_gen, f"results/stygan3Clip/{str(i).zfill(5)}.jpg", normalize=True, range=(-1, 1))  # NOTE: intermediate frames go to this hard-coded directory, not args.results_dir | ||||||
|  | 
 | ||||||
|  |     if args.mode == "edit": | ||||||
|  |         final_result = torch.cat([img_orig, img_gen]) | ||||||
|  |     else: | ||||||
|  |         final_result = img_gen | ||||||
|  | 
 | ||||||
|  |     return final_result | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     parser = argparse.ArgumentParser() | ||||||
|  |     parser.add_argument("--description", type=str, default="a person with purple hair", help="the text that guides the editing/generation") | ||||||
|  |     parser.add_argument("--ckpt", type=str, default="../pretrained_models/stylegan2-ffhq-config-f.pt", help="pretrained StyleGAN2 weights") | ||||||
|  |     parser.add_argument("--stylegan_size", type=int, default=1024, help="StyleGAN resolution") | ||||||
|  |     parser.add_argument("--lr_rampup", type=float, default=0.05) | ||||||
|  |     parser.add_argument("--lr", type=float, default=0.1) | ||||||
|  |     parser.add_argument("--step", type=int, default=300, help="number of optimization steps") | ||||||
|  |     parser.add_argument("--mode", type=str, default="edit", choices=["edit", "free_generation"], help="choose between edit an image an generate a free one") | ||||||
|  |     parser.add_argument("--l2_lambda", type=float, default=0.008, help="weight of the latent distance (used for editing only)") | ||||||
|  |     parser.add_argument("--id_lambda", type=float, default=0.000, help="weight of id loss (used for editing only)") | ||||||
|  |     parser.add_argument("--latent_path", type=str, default=None, help="starts the optimization from the given latent code if provided. Otherwose, starts from" | ||||||
|  |                                                                       "the mean latent in a free generation, and from a random one in editing. " | ||||||
|  |                                                                       "Expects a .pt format") | ||||||
|  |     parser.add_argument("--truncation", type=float, default=1, help="used only for the initial latent vector, and only when a latent code path is" | ||||||
|  |                                                                       "not provided") | ||||||
|  |     parser.add_argument('--work_in_stylespace', default=False, action='store_true') | ||||||
|  |     parser.add_argument("--save_intermediate_image_every", type=int, default=20, help="if > 0 then saves intermidate results during the optimization") | ||||||
|  |     parser.add_argument("--results_dir", type=str, default="results") | ||||||
|  |     parser.add_argument('--ir_se50_weights', default='../pretrained_models/model_ir_se50.pth', type=str, | ||||||
|  |                              help="Path to facial recognition network used in ID loss") | ||||||
|  | 
 | ||||||
|  |     args = parser.parse_args() | ||||||
|  | 
 | ||||||
|  |     result_image = main(args) | ||||||
|  | 
 | ||||||
|  |     torchvision.utils.save_image(result_image.detach().cpu(), os.path.join(args.results_dir, "final_result.jpg"), normalize=True, scale_each=True, range=(-1, 1)) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
							
								
								
									
194  models/stylegan3/show_pkl.py  Normal file
							| @ -0,0 +1,194 @@ | |||||||
|  | # show_pkl.py | ||||||
|  | 
 | ||||||
|  | import pickle | ||||||
|  | import sys | ||||||
|  | import torch | ||||||
|  | sys.path.append('/home/ly/StyleCLIP-main/models/stylegan3/torch_utils') | ||||||
|  | 
 | ||||||
|  | # | ||||||
|  | # path = '/home/ly/StyleCLIP-main/models/stylegan3/torch_utils/stylegan3-r-afhqv2-512x512.pkl'  # path='/root/……/aus_openface.pkl'   path to the .pkl file | ||||||
|  | # | ||||||
|  | # f = open(path, 'rb') | ||||||
|  | # data = pickle.load(f) | ||||||
|  | # | ||||||
|  | # print(data) | ||||||
|  | # print(len(data)) | ||||||
|  | # print(data.shape) | ||||||
|  | 
 | ||||||
|  | with open('/home/ly/StyleCLIP-main/models/stylegan3/torch_utils/stylegan3-r-afhqv2-512x512.pkl', 'rb') as f: | ||||||
|  |     G = pickle.load(f)['G_ema'].cuda()  # torch.nn.Module | ||||||
|  | z = torch.randn([1, G.z_dim]).cuda()    # latent codes | ||||||
|  | c = None                                # class labels (not used in this example) | ||||||
|  | img = G(z, c)                           # NCHW, float32, dynamic range [-1, +1], no truncation | ||||||
|  | print(G) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # Output: | ||||||
|  | # Generator( | ||||||
|  | #   (synthesis): SynthesisNetwork( | ||||||
|  | #     w_dim=512, num_ws=16, | ||||||
|  | #     img_resolution=512, img_channels=3, | ||||||
|  | #     num_layers=14, num_critical=2, | ||||||
|  | #     margin_size=10, num_fp16_res=4 | ||||||
|  | #     (input): SynthesisInput( | ||||||
|  | #       w_dim=512, channels=1024, size=[36, 36], | ||||||
|  | #       sampling_rate=16, bandwidth=2 | ||||||
|  | #       (affine): FullyConnectedLayer(in_features=512, out_features=4, activation=linear) | ||||||
|  | #     ) | ||||||
|  | #     (L0_36_1024): SynthesisLayer( | ||||||
|  | #       w_dim=512, is_torgb=False, | ||||||
|  | #       is_critically_sampled=False, use_fp16=False, | ||||||
|  | #       in_sampling_rate=16, out_sampling_rate=16, | ||||||
|  | #       in_cutoff=2, out_cutoff=2, | ||||||
|  | #       in_half_width=6, out_half_width=6, | ||||||
|  | #       in_size=[36, 36], out_size=[36, 36], | ||||||
|  | #       in_channels=1024, out_channels=1024 | ||||||
|  | #       (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear) | ||||||
|  | #     ) | ||||||
|  | #     (L1_36_1024): SynthesisLayer( | ||||||
|  | #       w_dim=512, is_torgb=False, | ||||||
|  | #       is_critically_sampled=False, use_fp16=False, | ||||||
|  | #       in_sampling_rate=16, out_sampling_rate=16, | ||||||
|  | #       in_cutoff=2, out_cutoff=2.99661, | ||||||
|  | #       in_half_width=6, out_half_width=5.00339, | ||||||
|  | #       in_size=[36, 36], out_size=[36, 36], | ||||||
|  | #       in_channels=1024, out_channels=1024 | ||||||
|  | #       (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear) | ||||||
|  | #     ) | ||||||
|  | #     (L2_52_1024): SynthesisLayer( | ||||||
|  | #       w_dim=512, is_torgb=False, | ||||||
|  | #       is_critically_sampled=False, use_fp16=False, | ||||||
|  | #       in_sampling_rate=16, out_sampling_rate=32, | ||||||
|  | #       in_cutoff=2.99661, out_cutoff=4.48985, | ||||||
|  | #       in_half_width=5.00339, out_half_width=11.5102, | ||||||
|  | #       in_size=[36, 36], out_size=[52, 52], | ||||||
|  | #       in_channels=1024, out_channels=1024 | ||||||
|  | #       (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear) | ||||||
|  | #     ) | ||||||
|  | #     (L3_52_1024): SynthesisLayer( | ||||||
|  | #       w_dim=512, is_torgb=False, | ||||||
|  | #       is_critically_sampled=False, use_fp16=False, | ||||||
|  | #       in_sampling_rate=32, out_sampling_rate=32, | ||||||
|  | #       in_cutoff=4.48985, out_cutoff=6.72717, | ||||||
|  | #       in_half_width=11.5102, out_half_width=9.27283, | ||||||
|  | #       in_size=[52, 52], out_size=[52, 52], | ||||||
|  | #       in_channels=1024, out_channels=1024 | ||||||
|  | #       (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear) | ||||||
|  | #     ) | ||||||
|  | #     (L4_84_1024): SynthesisLayer( | ||||||
|  | #       w_dim=512, is_torgb=False, | ||||||
|  | #       is_critically_sampled=False, use_fp16=True, | ||||||
|  | #       in_sampling_rate=32, out_sampling_rate=64, | ||||||
|  | #       in_cutoff=6.72717, out_cutoff=10.0794, | ||||||
|  | #       in_half_width=9.27283, out_half_width=21.9206, | ||||||
|  | #       in_size=[52, 52], out_size=[84, 84], | ||||||
|  | #       in_channels=1024, out_channels=1024 | ||||||
|  | #       (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear) | ||||||
|  | #     ) | ||||||
|  | #     (L5_84_1024): SynthesisLayer( | ||||||
|  | #       w_dim=512, is_torgb=False, | ||||||
|  | #       is_critically_sampled=False, use_fp16=True, | ||||||
|  | #       in_sampling_rate=64, out_sampling_rate=64, | ||||||
|  | #       in_cutoff=10.0794, out_cutoff=15.102, | ||||||
|  | #       in_half_width=21.9206, out_half_width=16.898, | ||||||
|  | #       in_size=[84, 84], out_size=[84, 84], | ||||||
|  | #       in_channels=1024, out_channels=1024 | ||||||
|  | #       (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear) | ||||||
|  | #     ) | ||||||
|  | #     (L6_148_1024): SynthesisLayer( | ||||||
|  | #       w_dim=512, is_torgb=False, | ||||||
|  | #       is_critically_sampled=False, use_fp16=True, | ||||||
|  | #       in_sampling_rate=64, out_sampling_rate=128, | ||||||
|  | #       in_cutoff=15.102, out_cutoff=22.6274, | ||||||
|  | #       in_half_width=16.898, out_half_width=41.3726, | ||||||
|  | #       in_size=[84, 84], out_size=[148, 148], | ||||||
|  | #       in_channels=1024, out_channels=1024 | ||||||
|  | #       (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear) | ||||||
|  | #     ) | ||||||
|  | #     (L7_148_967): SynthesisLayer( | ||||||
|  | #       w_dim=512, is_torgb=False, | ||||||
|  | #       is_critically_sampled=False, use_fp16=True, | ||||||
|  | #       in_sampling_rate=128, out_sampling_rate=128, | ||||||
|  | #       in_cutoff=22.6274, out_cutoff=33.9028, | ||||||
|  | #       in_half_width=41.3726, out_half_width=30.0972, | ||||||
|  | #       in_size=[148, 148], out_size=[148, 148], | ||||||
|  | #       in_channels=1024, out_channels=967 | ||||||
|  | #       (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear) | ||||||
|  | #     ) | ||||||
|  | #     (L8_276_645): SynthesisLayer( | ||||||
|  | #       w_dim=512, is_torgb=False, | ||||||
|  | #       is_critically_sampled=False, use_fp16=True, | ||||||
|  | #       in_sampling_rate=128, out_sampling_rate=256, | ||||||
|  | #       in_cutoff=33.9028, out_cutoff=50.7968, | ||||||
|  | #       in_half_width=30.0972, out_half_width=77.2032, | ||||||
|  | #       in_size=[148, 148], out_size=[276, 276], | ||||||
|  | #       in_channels=967, out_channels=645 | ||||||
|  | #       (affine): FullyConnectedLayer(in_features=512, out_features=967, activation=linear) | ||||||
|  | #     ) | ||||||
|  | #     (L9_276_431): SynthesisLayer( | ||||||
|  | #       w_dim=512, is_torgb=False, | ||||||
|  | #       is_critically_sampled=False, use_fp16=True, | ||||||
|  | #       in_sampling_rate=256, out_sampling_rate=256, | ||||||
|  | #       in_cutoff=50.7968, out_cutoff=76.1093, | ||||||
|  | #       in_half_width=77.2032, out_half_width=51.8907, | ||||||
|  | #       in_size=[276, 276], out_size=[276, 276], | ||||||
|  | #       in_channels=645, out_channels=431 | ||||||
|  | #       (affine): FullyConnectedLayer(in_features=512, out_features=645, activation=linear) | ||||||
|  | #     ) | ||||||
|  | #     (L10_532_287): SynthesisLayer( | ||||||
|  | #       w_dim=512, is_torgb=False, | ||||||
|  | #       is_critically_sampled=False, use_fp16=True, | ||||||
|  | #       in_sampling_rate=256, out_sampling_rate=512, | ||||||
|  | #       in_cutoff=76.1093, out_cutoff=114.035, | ||||||
|  | #       in_half_width=51.8907, out_half_width=141.965, | ||||||
|  | #       in_size=[276, 276], out_size=[532, 532], | ||||||
|  | #       in_channels=431, out_channels=287 | ||||||
|  | #       (affine): FullyConnectedLayer(in_features=512, out_features=431, activation=linear) | ||||||
|  | #     ) | ||||||
|  | #     (L11_532_192): SynthesisLayer( | ||||||
|  | #       w_dim=512, is_torgb=False, | ||||||
|  | #       is_critically_sampled=False, use_fp16=True, | ||||||
|  | #       in_sampling_rate=512, out_sampling_rate=512, | ||||||
|  | #       in_cutoff=114.035, out_cutoff=170.86, | ||||||
|  | #       in_half_width=141.965, out_half_width=85.1405, | ||||||
|  | #       in_size=[532, 532], out_size=[532, 532], | ||||||
|  | #       in_channels=287, out_channels=192 | ||||||
|  | #       (affine): FullyConnectedLayer(in_features=512, out_features=287, activation=linear) | ||||||
|  | #     ) | ||||||
|  | #     (L12_532_128): SynthesisLayer( | ||||||
|  | #       w_dim=512, is_torgb=False, | ||||||
|  | #       is_critically_sampled=True, use_fp16=True, | ||||||
|  | #       in_sampling_rate=512, out_sampling_rate=512, | ||||||
|  | #       in_cutoff=170.86, out_cutoff=256, | ||||||
|  | #       in_half_width=85.1405, out_half_width=59.173, | ||||||
|  | #       in_size=[532, 532], out_size=[532, 532], | ||||||
|  | #       in_channels=192, out_channels=128 | ||||||
|  | #       (affine): FullyConnectedLayer(in_features=512, out_features=192, activation=linear) | ||||||
|  | #     ) | ||||||
|  | #     (L13_512_128): SynthesisLayer( | ||||||
|  | #       w_dim=512, is_torgb=False, | ||||||
|  | #       is_critically_sampled=True, use_fp16=True, | ||||||
|  | #       in_sampling_rate=512, out_sampling_rate=512, | ||||||
|  | #       in_cutoff=256, out_cutoff=256, | ||||||
|  | #       in_half_width=59.173, out_half_width=59.173, | ||||||
|  | #       in_size=[532, 532], out_size=[512, 512], | ||||||
|  | #       in_channels=128, out_channels=128 | ||||||
|  | #       (affine): FullyConnectedLayer(in_features=512, out_features=128, activation=linear) | ||||||
|  | #     ) | ||||||
|  | #     (L14_512_3): SynthesisLayer( | ||||||
|  | #       w_dim=512, is_torgb=True, | ||||||
|  | #       is_critically_sampled=True, use_fp16=True, | ||||||
|  | #       in_sampling_rate=512, out_sampling_rate=512, | ||||||
|  | #       in_cutoff=256, out_cutoff=256, | ||||||
|  | #       in_half_width=59.173, out_half_width=59.173, | ||||||
|  | #       in_size=[512, 512], out_size=[512, 512], | ||||||
|  | #       in_channels=128, out_channels=3 | ||||||
|  | #       (affine): FullyConnectedLayer(in_features=512, out_features=128, activation=linear) | ||||||
|  | #     ) | ||||||
|  | #   ) | ||||||
|  | #   (mapping): MappingNetwork( | ||||||
|  | #     z_dim=512, c_dim=0, w_dim=512, num_ws=16 | ||||||
|  | #     (fc0): FullyConnectedLayer(in_features=512, out_features=512, activation=lrelu) | ||||||
|  | #     (fc1): FullyConnectedLayer(in_features=512, out_features=512, activation=lrelu) | ||||||
|  | #   ) | ||||||
|  | # ) | ||||||
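The commented block above is the printed `repr` of the loaded StyleGAN3 generator. For reference, a minimal sketch of how such a dump can be produced, assuming a StyleGAN3 network pickle and the repo's `dnnlib`/`torch_utils` packages on `sys.path` (the loading pattern follows the NVlabs stylegan3 README; note the repr shown here corresponds to a 512x512 generator, while the checkpoint path below is the author's 1024x1024 one):

```python
# Sketch only, not part of this commit: reproduce an architecture dump like the
# comment above. The pickle path is the author's default from test001_s3.py.
import pickle

with open('/home/ly/StyleCLIP-main/pretrained_models/stylegan3-r-ffhqu-1024x1024.pkl', 'rb') as f:
    G = pickle.load(f)['G_ema']   # Generator module exposing .mapping and .synthesis
print(G)                          # emits the SynthesisNetwork / MappingNetwork repr
```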
							
								
								
									
models/stylegan3/test001_s3.py (37 lines, new file)
| @ -0,0 +1,37 @@ | |||||||
|  | import torchvision | ||||||
|  | import argparse | ||||||
|  | from argparse import Namespace | ||||||
|  | from run_optimization3 import main | ||||||
|  | 
 | ||||||
|  | parser = argparse.ArgumentParser() | ||||||
|  | parser.add_argument("--description", type=str, default="a person with purple hair", | ||||||
|  |                     help="the text that guides the editing/generation") | ||||||
|  | parser.add_argument("--ckpt", type=str, default="/home/ly/StyleCLIP-main/pretrained_models/stylegan3-r-ffhqu-1024x1024.pkl", | ||||||
|  |                     help="pretrained StyleGAN3 weights") | ||||||
|  | parser.add_argument("--stylegan_size", type=int, default=1024, help="StyleGAN resolution") | ||||||
|  | parser.add_argument("--lr_rampup", type=float, default=0.05) | ||||||
|  | parser.add_argument("--lr", type=float, default=0.1) | ||||||
|  | parser.add_argument("--step", type=int, default=300, help="number of optimization steps") | ||||||
|  | parser.add_argument("--mode", type=str, default="edit", choices=["edit", "free_generation"], | ||||||
|  |                     help="choose between editing an existing image and generating one freely") | ||||||
|  | parser.add_argument("--l2_lambda", type=float, default=0.008, | ||||||
|  |                     help="weight of the latent distance (used for editing only)") | ||||||
|  | parser.add_argument("--latent_path", type=str, default=None,            #"/home/ly/StyleCLIP-main/latents_test/example_celebs.pt" | ||||||
|  |                     help="starts the optimization from the given latent code if provided. Otherwise, starts from " | ||||||
|  |                          "the mean latent in a free generation, and from a random one in editing. " | ||||||
|  |                          "Expects a .pt format") | ||||||
|  | parser.add_argument("--truncation", type=float, default=0.5, | ||||||
|  |                     help="used only for the initial latent vector, and only when a latent code path is " | ||||||
|  |                          "not provided") | ||||||
|  | parser.add_argument("--save_intermediate_image_every", type=int, default=20, | ||||||
|  |                     help="if > 0 then saves intermediate results during the optimization") | ||||||
|  | parser.add_argument("--results_dir", type=str, default="/home/ly/StyleCLIP-main/results/stygan3Clip") | ||||||
|  | parser.add_argument('--work_in_stylespace', default=False, action='store_true', help="performs the optimization in style space S instead of W+") | ||||||
|  | parser.add_argument('--ir_se50_weights', default='/home/ly/StyleCLIP-main/pretrained_models/model_ir_se50.pth', type=str, help="Path to facial recognition network used in ID loss") | ||||||
|  | parser.add_argument('--id_lambda', default=0.10, type=float, help='ID loss multiplier factor') | ||||||
|  | 
 | ||||||
|  | args = vars(parser.parse_args()) | ||||||
|  | result_image = main(Namespace(**args)) | ||||||
|  | # note: the `range` kwarg was renamed to `value_range` in newer torchvision releases | ||||||
|  | torchvision.utils.save_image(result_image.detach().cpu(), "/home/ly/StyleCLIP-main/results/stygan3Clip/final_result.png", normalize=True, scale_each=True, | ||||||
|  |                              range=(-1, 1)) | ||||||
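For readers who prefer to drive the optimization from Python rather than the CLI, a hedged sketch of the equivalent call; it reuses only the `main(Namespace(...))` contract and the default values already visible in the script above (`run_optimization3` itself is not shown in this commit):

```python
# Sketch: programmatic equivalent of running test001_s3.py, overriding one flag.
from argparse import Namespace
from run_optimization3 import main  # same import as the script above

args = Namespace(
    description="a person with blue eyes",  # override of the default prompt
    ckpt="/home/ly/StyleCLIP-main/pretrained_models/stylegan3-r-ffhqu-1024x1024.pkl",
    stylegan_size=1024, lr_rampup=0.05, lr=0.1, step=300, mode="edit",
    l2_lambda=0.008, latent_path=None, truncation=0.5,
    save_intermediate_image_every=20,
    results_dir="/home/ly/StyleCLIP-main/results/stygan3Clip",
    work_in_stylespace=False,
    ir_se50_weights="/home/ly/StyleCLIP-main/pretrained_models/model_ir_se50.pth",
    id_lambda=0.10,
)
result_image = main(args)  # returns the final image tensor, saved as above
```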
							
								
								
									
models/stylegan3/torch_utils/__init__.py (9 lines, new file)
| @ -0,0 +1,9 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | # empty | ||||||
							
								
								
									
models/stylegan3/torch_utils/custom_ops.py (157 lines, new file)
| @ -0,0 +1,157 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | import glob | ||||||
|  | import hashlib | ||||||
|  | import importlib | ||||||
|  | import os | ||||||
|  | import re | ||||||
|  | import shutil | ||||||
|  | import uuid | ||||||
|  | 
 | ||||||
|  | import torch | ||||||
|  | import torch.utils.cpp_extension | ||||||
|  | from torch.utils.file_baton import FileBaton | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Global options. | ||||||
|  | 
 | ||||||
|  | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Internal helper funcs. | ||||||
|  | 
 | ||||||
|  | def _find_compiler_bindir(): | ||||||
|  |     patterns = [ | ||||||
|  |         'C:/Program Files*/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', | ||||||
|  |         'C:/Program Files*/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', | ||||||
|  |         'C:/Program Files*/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', | ||||||
|  |         'C:/Program Files*/Microsoft Visual Studio */vc/bin', | ||||||
|  |     ] | ||||||
|  |     for pattern in patterns: | ||||||
|  |         matches = sorted(glob.glob(pattern)) | ||||||
|  |         if len(matches): | ||||||
|  |             return matches[-1] | ||||||
|  |     return None | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def _get_mangled_gpu_name(): | ||||||
|  |     name = torch.cuda.get_device_name().lower() | ||||||
|  |     out = [] | ||||||
|  |     for c in name: | ||||||
|  |         if re.match('[a-z0-9_-]+', c): | ||||||
|  |             out.append(c) | ||||||
|  |         else: | ||||||
|  |             out.append('-') | ||||||
|  |     return ''.join(out) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Main entry point for compiling and loading C++/CUDA plugins. | ||||||
|  | 
 | ||||||
|  | _cached_plugins = dict() | ||||||
|  | 
 | ||||||
|  | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): | ||||||
|  |     assert verbosity in ['none', 'brief', 'full'] | ||||||
|  |     if headers is None: | ||||||
|  |         headers = [] | ||||||
|  |     if source_dir is not None: | ||||||
|  |         sources = [os.path.join(source_dir, fname) for fname in sources] | ||||||
|  |         headers = [os.path.join(source_dir, fname) for fname in headers] | ||||||
|  | 
 | ||||||
|  |     # Already cached? | ||||||
|  |     if module_name in _cached_plugins: | ||||||
|  |         return _cached_plugins[module_name] | ||||||
|  | 
 | ||||||
|  |     # Print status. | ||||||
|  |     if verbosity == 'full': | ||||||
|  |         print(f'Setting up PyTorch plugin "{module_name}"...') | ||||||
|  |     elif verbosity == 'brief': | ||||||
|  |         print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) | ||||||
|  |     verbose_build = (verbosity == 'full') | ||||||
|  | 
 | ||||||
|  |     # Compile and load. | ||||||
|  |     try: # pylint: disable=too-many-nested-blocks | ||||||
|  |         # Make sure we can find the necessary compiler binaries. | ||||||
|  |         if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: | ||||||
|  |             compiler_bindir = _find_compiler_bindir() | ||||||
|  |             if compiler_bindir is None: | ||||||
|  |                 raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') | ||||||
|  |             os.environ['PATH'] += ';' + compiler_bindir | ||||||
|  | 
 | ||||||
|  |         # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either | ||||||
|  |         # break the build or unnecessarily restrict what's available to nvcc. | ||||||
|  |         # Unset it to let nvcc decide based on what's available on the | ||||||
|  |         # machine. | ||||||
|  |         os.environ['TORCH_CUDA_ARCH_LIST'] = '' | ||||||
|  | 
 | ||||||
|  |         # Incremental build md5sum trickery.  Copies all the input source files | ||||||
|  |         # into a cached build directory under a combined md5 digest of the input | ||||||
|  |         # source files.  Copying is done only if the combined digest has changed. | ||||||
|  |         # This keeps input file timestamps and filenames the same as in previous | ||||||
|  |         # extension builds, allowing for fast incremental rebuilds. | ||||||
|  |         # | ||||||
|  |         # This optimization is done only in case all the source files reside in | ||||||
|  |         # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR | ||||||
|  |         # environment variable is set (we take this as a signal that the user | ||||||
|  |         # actually cares about this.) | ||||||
|  |         # | ||||||
|  |         # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work | ||||||
|  |         # around the *.cu dependency bug in ninja config. | ||||||
|  |         # | ||||||
|  |         all_source_files = sorted(sources + headers) | ||||||
|  |         all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) | ||||||
|  |         if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): | ||||||
|  | 
 | ||||||
|  |             # Compute combined hash digest for all source files. | ||||||
|  |             hash_md5 = hashlib.md5() | ||||||
|  |             for src in all_source_files: | ||||||
|  |                 with open(src, 'rb') as f: | ||||||
|  |                     hash_md5.update(f.read()) | ||||||
|  | 
 | ||||||
|  |             # Select cached build directory name. | ||||||
|  |             source_digest = hash_md5.hexdigest() | ||||||
|  |             build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access | ||||||
|  |             cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') | ||||||
|  | 
 | ||||||
|  |             if not os.path.isdir(cached_build_dir): | ||||||
|  |                 tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' | ||||||
|  |                 os.makedirs(tmpdir) | ||||||
|  |                 for src in all_source_files: | ||||||
|  |                     shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) | ||||||
|  |                 try: | ||||||
|  |                     os.replace(tmpdir, cached_build_dir) # atomic | ||||||
|  |                 except OSError: | ||||||
|  |                     # source directory already exists, delete tmpdir and its contents. | ||||||
|  |                     shutil.rmtree(tmpdir) | ||||||
|  |                     if not os.path.isdir(cached_build_dir): raise | ||||||
|  | 
 | ||||||
|  |             # Compile. | ||||||
|  |             cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] | ||||||
|  |             torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, | ||||||
|  |                 verbose=verbose_build, sources=cached_sources, **build_kwargs) | ||||||
|  |         else: | ||||||
|  |             torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) | ||||||
|  | 
 | ||||||
|  |         # Load. | ||||||
|  |         module = importlib.import_module(module_name) | ||||||
|  | 
 | ||||||
|  |     except: | ||||||
|  |         if verbosity == 'brief': | ||||||
|  |             print('Failed!') | ||||||
|  |         raise | ||||||
|  | 
 | ||||||
|  |     # Print status and add to cache dict. | ||||||
|  |     if verbosity == 'full': | ||||||
|  |         print(f'Done setting up PyTorch plugin "{module_name}".') | ||||||
|  |     elif verbosity == 'brief': | ||||||
|  |         print('Done.') | ||||||
|  |     _cached_plugins[module_name] = module | ||||||
|  |     return module | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
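`get_plugin` is the JIT entry point for every custom op in this tree; `bias_act.py` further down in this commit calls it essentially as sketched here (names are taken from that file, nothing is invented):

```python
# Mirrors the call made in torch_utils/ops/bias_act.py below. The .cpp/.cu/.h
# sources live in the ops/ directory next to this module.
import os
from torch_utils import custom_ops  # assumes models/stylegan3 is on sys.path

plugin = custom_ops.get_plugin(
    module_name='bias_act_plugin',
    sources=['bias_act.cpp', 'bias_act.cu'],
    headers=['bias_act.h'],
    source_dir=os.path.join(os.path.dirname(custom_ops.__file__), 'ops'),
    extra_cuda_cflags=['--use_fast_math'],
)
# plugin.bias_act(...) is then the compiled CUDA op defined in bias_act.cpp below.
```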
							
								
								
									
models/stylegan3/torch_utils/misc.py (267 lines, new file)
| @ -0,0 +1,267 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | import re | ||||||
|  | import contextlib | ||||||
|  | import numpy as np | ||||||
|  | import torch | ||||||
|  | import warnings | ||||||
|  | import dnnlib | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the | ||||||
|  | # same constant is used multiple times. | ||||||
|  | 
 | ||||||
|  | _constant_cache = dict() | ||||||
|  | 
 | ||||||
|  | def constant(value, shape=None, dtype=None, device=None, memory_format=None): | ||||||
|  |     value = np.asarray(value) | ||||||
|  |     if shape is not None: | ||||||
|  |         shape = tuple(shape) | ||||||
|  |     if dtype is None: | ||||||
|  |         dtype = torch.get_default_dtype() | ||||||
|  |     if device is None: | ||||||
|  |         device = torch.device('cpu') | ||||||
|  |     if memory_format is None: | ||||||
|  |         memory_format = torch.contiguous_format | ||||||
|  | 
 | ||||||
|  |     key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) | ||||||
|  |     tensor = _constant_cache.get(key, None) | ||||||
|  |     if tensor is None: | ||||||
|  |         tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) | ||||||
|  |         if shape is not None: | ||||||
|  |             tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) | ||||||
|  |         tensor = tensor.contiguous(memory_format=memory_format) | ||||||
|  |         _constant_cache[key] = tensor | ||||||
|  |     return tensor | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Replace NaN/Inf with specified numerical values. | ||||||
|  | 
 | ||||||
|  | try: | ||||||
|  |     nan_to_num = torch.nan_to_num # 1.8.0a0 | ||||||
|  | except AttributeError: | ||||||
|  |     def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin | ||||||
|  |         assert isinstance(input, torch.Tensor) | ||||||
|  |         if posinf is None: | ||||||
|  |             posinf = torch.finfo(input.dtype).max | ||||||
|  |         if neginf is None: | ||||||
|  |             neginf = torch.finfo(input.dtype).min | ||||||
|  |         assert nan == 0 | ||||||
|  |         return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Symbolic assert. | ||||||
|  | 
 | ||||||
|  | try: | ||||||
|  |     symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access | ||||||
|  | except AttributeError: | ||||||
|  |     symbolic_assert = torch.Assert # 1.7.0 | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Context manager to temporarily suppress known warnings in torch.jit.trace(). | ||||||
|  | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 | ||||||
|  | 
 | ||||||
|  | @contextlib.contextmanager | ||||||
|  | def suppress_tracer_warnings(): | ||||||
|  |     flt = ('ignore', None, torch.jit.TracerWarning, None, 0) | ||||||
|  |     warnings.filters.insert(0, flt) | ||||||
|  |     yield | ||||||
|  |     warnings.filters.remove(flt) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Assert that the shape of a tensor matches the given list of integers. | ||||||
|  | # None indicates that the size of a dimension is allowed to vary. | ||||||
|  | # Performs symbolic assertion when used in torch.jit.trace(). | ||||||
|  | 
 | ||||||
|  | def assert_shape(tensor, ref_shape): | ||||||
|  |     # note (translated): calling .ndim on a plain list raised AttributeError: 'list' object has no attribute 'ndim' | ||||||
|  |     if tensor.ndim != len(ref_shape): | ||||||
|  |         raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') | ||||||
|  |     for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): | ||||||
|  |         if ref_size is None: | ||||||
|  |             pass | ||||||
|  |         elif isinstance(ref_size, torch.Tensor): | ||||||
|  |             with suppress_tracer_warnings(): # as_tensor results are registered as constants | ||||||
|  |                 symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') | ||||||
|  |         elif isinstance(size, torch.Tensor): | ||||||
|  |             with suppress_tracer_warnings(): # as_tensor results are registered as constants | ||||||
|  |                 symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') | ||||||
|  |         elif size != ref_size: | ||||||
|  |             raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Function decorator that calls torch.autograd.profiler.record_function(). | ||||||
|  | 
 | ||||||
|  | def profiled_function(fn): | ||||||
|  |     def decorator(*args, **kwargs): | ||||||
|  |         with torch.autograd.profiler.record_function(fn.__name__): | ||||||
|  |             return fn(*args, **kwargs) | ||||||
|  |     decorator.__name__ = fn.__name__ | ||||||
|  |     return decorator | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Sampler for torch.utils.data.DataLoader that loops over the dataset | ||||||
|  | # indefinitely, shuffling items as it goes. | ||||||
|  | 
 | ||||||
|  | class InfiniteSampler(torch.utils.data.Sampler): | ||||||
|  |     def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): | ||||||
|  |         assert len(dataset) > 0 | ||||||
|  |         assert num_replicas > 0 | ||||||
|  |         assert 0 <= rank < num_replicas | ||||||
|  |         assert 0 <= window_size <= 1 | ||||||
|  |         super().__init__(dataset) | ||||||
|  |         self.dataset = dataset | ||||||
|  |         self.rank = rank | ||||||
|  |         self.num_replicas = num_replicas | ||||||
|  |         self.shuffle = shuffle | ||||||
|  |         self.seed = seed | ||||||
|  |         self.window_size = window_size | ||||||
|  | 
 | ||||||
|  |     def __iter__(self): | ||||||
|  |         order = np.arange(len(self.dataset)) | ||||||
|  |         rnd = None | ||||||
|  |         window = 0 | ||||||
|  |         if self.shuffle: | ||||||
|  |             rnd = np.random.RandomState(self.seed) | ||||||
|  |             rnd.shuffle(order) | ||||||
|  |             window = int(np.rint(order.size * self.window_size)) | ||||||
|  | 
 | ||||||
|  |         idx = 0 | ||||||
|  |         while True: | ||||||
|  |             i = idx % order.size | ||||||
|  |             if idx % self.num_replicas == self.rank: | ||||||
|  |                 yield order[i] | ||||||
|  |             if window >= 2: | ||||||
|  |                 j = (i - rnd.randint(window)) % order.size | ||||||
|  |                 order[i], order[j] = order[j], order[i] | ||||||
|  |             idx += 1 | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Utilities for operating with torch.nn.Module parameters and buffers. | ||||||
|  | 
 | ||||||
|  | def params_and_buffers(module): | ||||||
|  |     assert isinstance(module, torch.nn.Module) | ||||||
|  |     return list(module.parameters()) + list(module.buffers()) | ||||||
|  | 
 | ||||||
|  | def named_params_and_buffers(module): | ||||||
|  |     assert isinstance(module, torch.nn.Module) | ||||||
|  |     return list(module.named_parameters()) + list(module.named_buffers()) | ||||||
|  | 
 | ||||||
|  | def copy_params_and_buffers(src_module, dst_module, require_all=False): | ||||||
|  |     assert isinstance(src_module, torch.nn.Module) | ||||||
|  |     assert isinstance(dst_module, torch.nn.Module) | ||||||
|  |     src_tensors = dict(named_params_and_buffers(src_module)) | ||||||
|  |     for name, tensor in named_params_and_buffers(dst_module): | ||||||
|  |         assert (name in src_tensors) or (not require_all) | ||||||
|  |         if name in src_tensors: | ||||||
|  |             tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Context manager for easily enabling/disabling DistributedDataParallel | ||||||
|  | # synchronization. | ||||||
|  | 
 | ||||||
|  | @contextlib.contextmanager | ||||||
|  | def ddp_sync(module, sync): | ||||||
|  |     assert isinstance(module, torch.nn.Module) | ||||||
|  |     if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): | ||||||
|  |         yield | ||||||
|  |     else: | ||||||
|  |         with module.no_sync(): | ||||||
|  |             yield | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Check DistributedDataParallel consistency across processes. | ||||||
|  | 
 | ||||||
|  | def check_ddp_consistency(module, ignore_regex=None): | ||||||
|  |     assert isinstance(module, torch.nn.Module) | ||||||
|  |     for name, tensor in named_params_and_buffers(module): | ||||||
|  |         fullname = type(module).__name__ + '.' + name | ||||||
|  |         if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): | ||||||
|  |             continue | ||||||
|  |         tensor = tensor.detach() | ||||||
|  |         if tensor.is_floating_point(): | ||||||
|  |             tensor = nan_to_num(tensor) | ||||||
|  |         other = tensor.clone() | ||||||
|  |         torch.distributed.broadcast(tensor=other, src=0) | ||||||
|  |         assert (tensor == other).all(), fullname | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | # Print summary table of module hierarchy. | ||||||
|  | 
 | ||||||
|  | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): | ||||||
|  |     assert isinstance(module, torch.nn.Module) | ||||||
|  |     assert not isinstance(module, torch.jit.ScriptModule) | ||||||
|  |     assert isinstance(inputs, (tuple, list)) | ||||||
|  | 
 | ||||||
|  |     # Register hooks. | ||||||
|  |     entries = [] | ||||||
|  |     nesting = [0] | ||||||
|  |     def pre_hook(_mod, _inputs): | ||||||
|  |         nesting[0] += 1 | ||||||
|  |     def post_hook(mod, _inputs, outputs): | ||||||
|  |         nesting[0] -= 1 | ||||||
|  |         if nesting[0] <= max_nesting: | ||||||
|  |             outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] | ||||||
|  |             outputs = [t for t in outputs if isinstance(t, torch.Tensor)] | ||||||
|  |             entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) | ||||||
|  |     hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] | ||||||
|  |     hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] | ||||||
|  | 
 | ||||||
|  |     # Run module. | ||||||
|  |     outputs = module(*inputs) | ||||||
|  |     for hook in hooks: | ||||||
|  |         hook.remove() | ||||||
|  | 
 | ||||||
|  |     # Identify unique outputs, parameters, and buffers. | ||||||
|  |     tensors_seen = set() | ||||||
|  |     for e in entries: | ||||||
|  |         e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] | ||||||
|  |         e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] | ||||||
|  |         e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] | ||||||
|  |         tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} | ||||||
|  | 
 | ||||||
|  |     # Filter out redundant entries. | ||||||
|  |     if skip_redundant: | ||||||
|  |         entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] | ||||||
|  | 
 | ||||||
|  |     # Construct table. | ||||||
|  |     rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] | ||||||
|  |     rows += [['---'] * len(rows[0])] | ||||||
|  |     param_total = 0 | ||||||
|  |     buffer_total = 0 | ||||||
|  |     submodule_names = {mod: name for name, mod in module.named_modules()} | ||||||
|  |     for e in entries: | ||||||
|  |         name = '<top-level>' if e.mod is module else submodule_names[e.mod] | ||||||
|  |         param_size = sum(t.numel() for t in e.unique_params) | ||||||
|  |         buffer_size = sum(t.numel() for t in e.unique_buffers) | ||||||
|  |         output_shapes = [str(list(t.shape)) for t in e.outputs] | ||||||
|  |         output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] | ||||||
|  |         rows += [[ | ||||||
|  |             name + (':0' if len(e.outputs) >= 2 else ''), | ||||||
|  |             str(param_size) if param_size else '-', | ||||||
|  |             str(buffer_size) if buffer_size else '-', | ||||||
|  |             (output_shapes + ['-'])[0], | ||||||
|  |             (output_dtypes + ['-'])[0], | ||||||
|  |         ]] | ||||||
|  |         for idx in range(1, len(e.outputs)): | ||||||
|  |             rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] | ||||||
|  |         param_total += param_size | ||||||
|  |         buffer_total += buffer_size | ||||||
|  |     rows += [['---'] * len(rows[0])] | ||||||
|  |     rows += [['Total', str(param_total), str(buffer_total), '-', '-']] | ||||||
|  | 
 | ||||||
|  |     # Print table. | ||||||
|  |     widths = [max(len(cell) for cell in column) for column in zip(*rows)] | ||||||
|  |     print() | ||||||
|  |     for row in rows: | ||||||
|  |         print('  '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) | ||||||
|  |     print() | ||||||
|  |     return outputs | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
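Two of these helpers appear throughout the synthesis code; a short sketch of their contracts, using only the semantics defined in this file:

```python
# Sketch of misc.assert_shape and misc.constant as defined above.
import torch
from torch_utils import misc  # assumes models/stylegan3 is on sys.path

x = torch.zeros(4, 512)
misc.assert_shape(x, [None, 512])   # None = wildcard dimension; passes
# misc.assert_shape(x, [4, 256])    # would raise AssertionError for dimension 1

a = misc.constant([1.0, 2.0])
b = misc.constant([1.0, 2.0])
assert a is b                       # identical args hit the _constant_cache entry
```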
							
								
								
									
models/stylegan3/torch_utils/ops/__init__.py (9 lines, new file)
| @ -0,0 +1,9 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | # empty | ||||||
										
BIN several compiled __pycache__ *.pyc binaries (binary files not shown)
BIN models/stylegan3/torch_utils/ops/__pycache__/fma.cpython-310.pyc (new file; binary not shown)
							
								
								
									
models/stylegan3/torch_utils/ops/bias_act.cpp (99 lines, new file)
| @ -0,0 +1,99 @@ | |||||||
|  | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | // | ||||||
|  | // NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | // and proprietary rights in and to this software, related documentation | ||||||
|  | // and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | // distribution of this software and related documentation without an express | ||||||
|  | // license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  |  | ||||||
|  | #include <torch/extension.h> | ||||||
|  | #include <ATen/cuda/CUDAContext.h> | ||||||
|  | #include <c10/cuda/CUDAGuard.h> | ||||||
|  | #include "bias_act.h" | ||||||
|  |  | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  |  | ||||||
|  | static bool has_same_layout(torch::Tensor x, torch::Tensor y) | ||||||
|  | { | ||||||
|  |     if (x.dim() != y.dim()) | ||||||
|  |         return false; | ||||||
|  |     for (int64_t i = 0; i < x.dim(); i++) | ||||||
|  |     { | ||||||
|  |         if (x.size(i) != y.size(i)) | ||||||
|  |             return false; | ||||||
|  |         if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) | ||||||
|  |             return false; | ||||||
|  |     } | ||||||
|  |     return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  |  | ||||||
|  | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) | ||||||
|  | { | ||||||
|  |     // Validate arguments. | ||||||
|  |     TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); | ||||||
|  |     TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); | ||||||
|  |     TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); | ||||||
|  |     TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); | ||||||
|  |     TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same shape, dtype, and device as x"); | ||||||
|  |     TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); | ||||||
|  |     TORCH_CHECK(b.dim() == 1, "b must have rank 1"); | ||||||
|  |     TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); | ||||||
|  |     TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); | ||||||
|  |     TORCH_CHECK(grad >= 0, "grad must be non-negative"); | ||||||
|  |  | ||||||
|  |     // Validate layout. | ||||||
|  |     TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); | ||||||
|  |     TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); | ||||||
|  |     TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); | ||||||
|  |     TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); | ||||||
|  |     TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); | ||||||
|  |  | ||||||
|  |     // Create output tensor. | ||||||
|  |     const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); | ||||||
|  |     torch::Tensor y = torch::empty_like(x); | ||||||
|  |     TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); | ||||||
|  |  | ||||||
|  |     // Initialize CUDA kernel parameters. | ||||||
|  |     bias_act_kernel_params p; | ||||||
|  |     p.x     = x.data_ptr(); | ||||||
|  |     p.b     = (b.numel()) ? b.data_ptr() : NULL; | ||||||
|  |     p.xref  = (xref.numel()) ? xref.data_ptr() : NULL; | ||||||
|  |     p.yref  = (yref.numel()) ? yref.data_ptr() : NULL; | ||||||
|  |     p.dy    = (dy.numel()) ? dy.data_ptr() : NULL; | ||||||
|  |     p.y     = y.data_ptr(); | ||||||
|  |     p.grad  = grad; | ||||||
|  |     p.act   = act; | ||||||
|  |     p.alpha = alpha; | ||||||
|  |     p.gain  = gain; | ||||||
|  |     p.clamp = clamp; | ||||||
|  |     p.sizeX = (int)x.numel(); | ||||||
|  |     p.sizeB = (int)b.numel(); | ||||||
|  |     p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; | ||||||
|  |  | ||||||
|  |     // Choose CUDA kernel. | ||||||
|  |     void* kernel; | ||||||
|  |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bias_act_cuda", [&] | ||||||
|  |     { | ||||||
|  |         kernel = choose_bias_act_kernel<scalar_t>(p); | ||||||
|  |     }); | ||||||
|  |     TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); | ||||||
|  |  | ||||||
|  |     // Launch CUDA kernel. | ||||||
|  |     p.loopX = 4; | ||||||
|  |     int blockSize = 4 * 32; | ||||||
|  |     int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; | ||||||
|  |     void* args[] = {&p}; | ||||||
|  |     AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); | ||||||
|  |     return y; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  |  | ||||||
|  | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) | ||||||
|  | { | ||||||
|  |     m.def("bias_act", &bias_act); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | //------------------------------------------------------------------------ | ||||||
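On the Python side this translation unit is compiled by `custom_ops.get_plugin` (above) and surfaces as `_plugin.bias_act`, with arguments in exactly the order of the C++ signature. A hedged sketch of a raw forward call, bypassing the autograd wrapper that `bias_act.py` (below) normally provides:

```python
# Sketch only: direct call into the compiled op, argument order per the C++
# signature (x, b, xref, yref, dy, grad, dim, act, alpha, gain, clamp).
# Real code should call torch_utils.ops.bias_act.bias_act() instead.
import torch
from torch_utils.ops import bias_act as ba

ba._init()                                  # builds/loads bias_act_plugin via get_plugin
x = torch.randn(8, 512, device='cuda')
b = torch.zeros(512, device='cuda')
empty = torch.empty(0, device='cuda')       # stands in for the optional xref/yref/dy
y = ba._plugin.bias_act(x, b, empty, empty, empty,
                        0,         # grad: 0 selects the forward pass
                        1,         # dim: bias runs along dimension 1
                        3,         # act: cuda_idx 3 = lrelu (see activation_funcs below)
                        0.2,       # alpha: lrelu negative slope
                        2 ** 0.5,  # gain
                        -1.0)      # clamp: negative value disables clamping
```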
							
								
								
									
models/stylegan3/torch_utils/ops/bias_act.cu (173 lines, new file)
| @ -0,0 +1,173 @@ | |||||||
|  | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | // | ||||||
|  | // NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | // and proprietary rights in and to this software, related documentation | ||||||
|  | // and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | // distribution of this software and related documentation without an express | ||||||
|  | // license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | #include <c10/util/Half.h> | ||||||
|  | #include "bias_act.h" | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // Helpers. | ||||||
|  | 
 | ||||||
|  | template <class T> struct InternalType; | ||||||
|  | template <> struct InternalType<double>     { typedef double scalar_t; }; | ||||||
|  | template <> struct InternalType<float>      { typedef float  scalar_t; }; | ||||||
|  | template <> struct InternalType<c10::Half>  { typedef float  scalar_t; }; | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // CUDA kernel. | ||||||
|  | 
 | ||||||
|  | template <class T, int A> | ||||||
|  | __global__ void bias_act_kernel(bias_act_kernel_params p) | ||||||
|  | { | ||||||
|  |     typedef typename InternalType<T>::scalar_t scalar_t; | ||||||
|  |     int G                 = p.grad; | ||||||
|  |     scalar_t alpha        = (scalar_t)p.alpha; | ||||||
|  |     scalar_t gain         = (scalar_t)p.gain; | ||||||
|  |     scalar_t clamp        = (scalar_t)p.clamp; | ||||||
|  |     scalar_t one          = (scalar_t)1; | ||||||
|  |     scalar_t two          = (scalar_t)2; | ||||||
|  |     scalar_t expRange     = (scalar_t)80; | ||||||
|  |     scalar_t halfExpRange = (scalar_t)40; | ||||||
|  |     scalar_t seluScale    = (scalar_t)1.0507009873554804934193349852946; | ||||||
|  |     scalar_t seluAlpha    = (scalar_t)1.6732632423543772848170429916717; | ||||||
|  | 
 | ||||||
|  |     // Loop over elements. | ||||||
|  |     int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; | ||||||
|  |     for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) | ||||||
|  |     { | ||||||
|  |         // Load. | ||||||
|  |         scalar_t x = (scalar_t)((const T*)p.x)[xi]; | ||||||
|  |         scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; | ||||||
|  |         scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; | ||||||
|  |         scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; | ||||||
|  |         scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; | ||||||
|  |         scalar_t yy = (gain != 0) ? yref / gain : 0; | ||||||
|  |         scalar_t y = 0; | ||||||
|  | 
 | ||||||
|  |         // Apply bias. | ||||||
|  |         ((G == 0) ? x : xref) += b; | ||||||
|  | 
 | ||||||
|  |         // linear | ||||||
|  |         if (A == 1) | ||||||
|  |         { | ||||||
|  |             if (G == 0) y = x; | ||||||
|  |             if (G == 1) y = x; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // relu | ||||||
|  |         if (A == 2) | ||||||
|  |         { | ||||||
|  |             if (G == 0) y = (x > 0) ? x : 0; | ||||||
|  |             if (G == 1) y = (yy > 0) ? x : 0; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // lrelu | ||||||
|  |         if (A == 3) | ||||||
|  |         { | ||||||
|  |             if (G == 0) y = (x > 0) ? x : x * alpha; | ||||||
|  |             if (G == 1) y = (yy > 0) ? x : x * alpha; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // tanh | ||||||
|  |         if (A == 4) | ||||||
|  |         { | ||||||
|  |             if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } | ||||||
|  |             if (G == 1) y = x * (one - yy * yy); | ||||||
|  |             if (G == 2) y = x * (one - yy * yy) * (-two * yy); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // sigmoid | ||||||
|  |         if (A == 5) | ||||||
|  |         { | ||||||
|  |             if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); | ||||||
|  |             if (G == 1) y = x * yy * (one - yy); | ||||||
|  |             if (G == 2) y = x * yy * (one - yy) * (one - two * yy); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // elu | ||||||
|  |         if (A == 6) | ||||||
|  |         { | ||||||
|  |             if (G == 0) y = (x >= 0) ? x : exp(x) - one; | ||||||
|  |             if (G == 1) y = (yy >= 0) ? x : x * (yy + one); | ||||||
|  |             if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // selu | ||||||
|  |         if (A == 7) | ||||||
|  |         { | ||||||
|  |             if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); | ||||||
|  |             if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); | ||||||
|  |             if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // softplus | ||||||
|  |         if (A == 8) | ||||||
|  |         { | ||||||
|  |             if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); | ||||||
|  |             if (G == 1) y = x * (one - exp(-yy)); | ||||||
|  |             if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // swish | ||||||
|  |         if (A == 9) | ||||||
|  |         { | ||||||
|  |             if (G == 0) | ||||||
|  |                 y = (x < -expRange) ? 0 : x / (exp(-x) + one); | ||||||
|  |             else | ||||||
|  |             { | ||||||
|  |                 scalar_t c = exp(xref); | ||||||
|  |                 scalar_t d = c + one; | ||||||
|  |                 if (G == 1) | ||||||
|  |                     y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); | ||||||
|  |                 else | ||||||
|  |                     y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); | ||||||
|  |                 yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Apply gain. | ||||||
|  |         y *= gain * dy; | ||||||
|  | 
 | ||||||
|  |         // Clamp. | ||||||
|  |         if (clamp >= 0) | ||||||
|  |         { | ||||||
|  |             if (G == 0) | ||||||
|  |                 y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; | ||||||
|  |             else | ||||||
|  |                 y = (yref > -clamp & yref < clamp) ? y : 0; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Store. | ||||||
|  |         ((T*)p.y)[xi] = (T)y; | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // CUDA kernel selection. | ||||||
|  | 
 | ||||||
|  | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p) | ||||||
|  | { | ||||||
|  |     if (p.act == 1) return (void*)bias_act_kernel<T, 1>; | ||||||
|  |     if (p.act == 2) return (void*)bias_act_kernel<T, 2>; | ||||||
|  |     if (p.act == 3) return (void*)bias_act_kernel<T, 3>; | ||||||
|  |     if (p.act == 4) return (void*)bias_act_kernel<T, 4>; | ||||||
|  |     if (p.act == 5) return (void*)bias_act_kernel<T, 5>; | ||||||
|  |     if (p.act == 6) return (void*)bias_act_kernel<T, 6>; | ||||||
|  |     if (p.act == 7) return (void*)bias_act_kernel<T, 7>; | ||||||
|  |     if (p.act == 8) return (void*)bias_act_kernel<T, 8>; | ||||||
|  |     if (p.act == 9) return (void*)bias_act_kernel<T, 9>; | ||||||
|  |     return NULL; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // Template specializations. | ||||||
|  | 
 | ||||||
|  | template void* choose_bias_act_kernel<double>       (const bias_act_kernel_params& p); | ||||||
|  | template void* choose_bias_act_kernel<float>        (const bias_act_kernel_params& p); | ||||||
|  | template void* choose_bias_act_kernel<c10::Half>    (const bias_act_kernel_params& p); | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------ | ||||||
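In the kernel above, `G` selects forward (0), first derivative (1), or second derivative (2); for lrelu the `G == 1` branch, after the `y *= gain * dy` line, amounts to `dy * gain * (1 or alpha)` depending on the sign of the pre-gain output, which for the monotone lrelu matches the sign of the input. A quick autograd cross-check of that identity, independent of the CUDA build:

```python
# Sketch: verify the lrelu G==1 rule against PyTorch autograd (CPU only).
import torch

alpha, gain = 0.2, 2 ** 0.5
x = torch.randn(1000, requires_grad=True)
y = torch.nn.functional.leaky_relu(x, alpha) * gain
y.backward(torch.ones_like(x))              # dy = 1
manual = torch.where(x > 0, torch.full_like(x, gain), torch.full_like(x, gain * alpha))
assert torch.allclose(x.grad, manual)
```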
							
								
								
									
models/stylegan3/torch_utils/ops/bias_act.h (38 lines, new file)
| @ -0,0 +1,38 @@ | |||||||
|  | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | // | ||||||
|  | // NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | // and proprietary rights in and to this software, related documentation | ||||||
|  | // and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | // distribution of this software and related documentation without an express | ||||||
|  | // license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  |  | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // CUDA kernel parameters. | ||||||
|  |  | ||||||
|  | struct bias_act_kernel_params | ||||||
|  | { | ||||||
|  |     const void* x;      // [sizeX] | ||||||
|  |     const void* b;      // [sizeB] or NULL | ||||||
|  |     const void* xref;   // [sizeX] or NULL | ||||||
|  |     const void* yref;   // [sizeX] or NULL | ||||||
|  |     const void* dy;     // [sizeX] or NULL | ||||||
|  |     void*       y;      // [sizeX] | ||||||
|  |  | ||||||
|  |     int         grad; | ||||||
|  |     int         act; | ||||||
|  |     float       alpha; | ||||||
|  |     float       gain; | ||||||
|  |     float       clamp; | ||||||
|  |  | ||||||
|  |     int         sizeX; | ||||||
|  |     int         sizeB; | ||||||
|  |     int         stepB; | ||||||
|  |     int         loopX; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // CUDA kernel selection. | ||||||
|  |  | ||||||
|  | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p); | ||||||
|  |  | ||||||
|  | //------------------------------------------------------------------------ | ||||||
							
								
								
									
models/stylegan3/torch_utils/ops/bias_act.py (209 lines, new file)
| @ -0,0 +1,209 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | """Custom PyTorch ops for efficient bias and activation.""" | ||||||
|  | 
 | ||||||
|  | import os | ||||||
|  | import numpy as np | ||||||
|  | import torch | ||||||
|  | import dnnlib | ||||||
|  | 
 | ||||||
|  | from .. import custom_ops | ||||||
|  | from .. import misc | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | activation_funcs = { | ||||||
|  |     'linear':   dnnlib.EasyDict(func=lambda x, **_:         x,                                          def_alpha=0,    def_gain=1,             cuda_idx=1, ref='',  has_2nd_grad=False), | ||||||
|  |     'relu':     dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.relu(x),                def_alpha=0,    def_gain=np.sqrt(2),    cuda_idx=2, ref='y', has_2nd_grad=False), | ||||||
|  |     'lrelu':    dnnlib.EasyDict(func=lambda x, alpha, **_:  torch.nn.functional.leaky_relu(x, alpha),   def_alpha=0.2,  def_gain=np.sqrt(2),    cuda_idx=3, ref='y', has_2nd_grad=False), | ||||||
|  |     'tanh':     dnnlib.EasyDict(func=lambda x, **_:         torch.tanh(x),                              def_alpha=0,    def_gain=1,             cuda_idx=4, ref='y', has_2nd_grad=True), | ||||||
|  |     'sigmoid':  dnnlib.EasyDict(func=lambda x, **_:         torch.sigmoid(x),                           def_alpha=0,    def_gain=1,             cuda_idx=5, ref='y', has_2nd_grad=True), | ||||||
|  |     'elu':      dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.elu(x),                 def_alpha=0,    def_gain=1,             cuda_idx=6, ref='y', has_2nd_grad=True), | ||||||
|  |     'selu':     dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.selu(x),                def_alpha=0,    def_gain=1,             cuda_idx=7, ref='y', has_2nd_grad=True), | ||||||
|  |     'softplus': dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.softplus(x),            def_alpha=0,    def_gain=1,             cuda_idx=8, ref='y', has_2nd_grad=True), | ||||||
|  |     'swish':    dnnlib.EasyDict(func=lambda x, **_:         torch.sigmoid(x) * x,                       def_alpha=0,    def_gain=np.sqrt(2),    cuda_idx=9, ref='x', has_2nd_grad=True), | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | _plugin = None | ||||||
|  | _null_tensor = torch.empty([0]) | ||||||
|  | 
 | ||||||
|  | def _init(): | ||||||
|  |     global _plugin | ||||||
|  |     if _plugin is None: | ||||||
|  |         _plugin = custom_ops.get_plugin( | ||||||
|  |             module_name='bias_act_plugin', | ||||||
|  |             sources=['bias_act.cpp', 'bias_act.cu'], | ||||||
|  |             headers=['bias_act.h'], | ||||||
|  |             source_dir=os.path.dirname(__file__), | ||||||
|  |             extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'], | ||||||
|  |         ) | ||||||
|  |     return True | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): | ||||||
|  |     r"""Fused bias and activation function. | ||||||
|  | 
 | ||||||
|  |     Adds bias `b` to activation tensor `x`, evaluates activation function `act`, | ||||||
|  |     and scales the result by `gain`. Each of the steps is optional. In most cases, | ||||||
|  |     the fused op is considerably more efficient than performing the same calculation | ||||||
|  |     using standard PyTorch ops. It supports first and second order gradients, | ||||||
|  |     but not third order gradients. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         x:      Input activation tensor. Can be of any shape. | ||||||
|  |         b:      Bias vector, or `None` to disable. Must be a 1D tensor of the same type | ||||||
|  |                 as `x`. The shape must be known, and it must match the dimension of `x` | ||||||
|  |                 corresponding to `dim`. | ||||||
|  |         dim:    The dimension in `x` corresponding to the elements of `b`. | ||||||
|  |                 The value of `dim` is ignored if `b` is not specified. | ||||||
|  |         act:    Name of the activation function to evaluate, or `"linear"` to disable. | ||||||
|  |                 Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. | ||||||
|  |                 See `activation_funcs` for a full list. `None` is not allowed. | ||||||
|  |         alpha:  Shape parameter for the activation function, or `None` to use the default. | ||||||
|  |         gain:   Scaling factor for the output tensor, or `None` to use default. | ||||||
|  |                 See `activation_funcs` for the default scaling of each activation function. | ||||||
|  |                 If unsure, consider specifying 1. | ||||||
|  |         clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable | ||||||
|  |                 the clamping (default). | ||||||
|  |         impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         Tensor of the same shape and datatype as `x`. | ||||||
|  |     """ | ||||||
|  |     assert isinstance(x, torch.Tensor) | ||||||
|  |     assert impl in ['ref', 'cuda'] | ||||||
|  |     if impl == 'cuda' and x.device.type == 'cuda' and _init(): | ||||||
|  |         return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) | ||||||
|  |     return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
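Editorial aside, not part of the original file: a minimal usage sketch of `bias_act()`. It assumes a CUDA device and a StyleGAN-style NCHW feature map; shapes and values are made up for illustration.

```python
import torch

# Hypothetical NCHW activations; the bias length must match the channel dim (dim=1).
x = torch.randn(4, 512, 16, 16, device='cuda')
b = torch.zeros(512, device='cuda')

# Fused bias + leaky ReLU; 'lrelu' defaults to alpha=0.2 and gain=sqrt(2).
y = bias_act(x, b, act='lrelu')

# Activation only, with an explicit gain and output clamping.
z = bias_act(x, act='tanh', gain=1, clamp=256)
```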
|  | 
 | ||||||
|  | @misc.profiled_function | ||||||
|  | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): | ||||||
|  |     """Slow reference implementation of `bias_act()` using standard TensorFlow ops. | ||||||
|  |     """ | ||||||
|  |     assert isinstance(x, torch.Tensor) | ||||||
|  |     assert clamp is None or clamp >= 0 | ||||||
|  |     spec = activation_funcs[act] | ||||||
|  |     alpha = float(alpha if alpha is not None else spec.def_alpha) | ||||||
|  |     gain = float(gain if gain is not None else spec.def_gain) | ||||||
|  |     clamp = float(clamp if clamp is not None else -1) | ||||||
|  | 
 | ||||||
|  |     # Add bias. | ||||||
|  |     if b is not None: | ||||||
|  |         assert isinstance(b, torch.Tensor) and b.ndim == 1 | ||||||
|  |         assert 0 <= dim < x.ndim | ||||||
|  |         assert b.shape[0] == x.shape[dim] | ||||||
|  |         x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) | ||||||
|  | 
 | ||||||
|  |     # Evaluate activation function. | ||||||
|  |     alpha = float(alpha) | ||||||
|  |     x = spec.func(x, alpha=alpha) | ||||||
|  | 
 | ||||||
|  |     # Scale by gain. | ||||||
|  |     gain = float(gain) | ||||||
|  |     if gain != 1: | ||||||
|  |         x = x * gain | ||||||
|  | 
 | ||||||
|  |     # Clamp. | ||||||
|  |     if clamp >= 0: | ||||||
|  |         x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type | ||||||
|  |     return x | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
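Since the reference path above mirrors the CUDA kernel, a quick sanity check (an editorial sketch, assuming a CUDA device) is to run both implementations on the same random data and compare:

```python
import torch

x = torch.randn(2, 8, 4, 4, device='cuda')
b = torch.randn(8, device='cuda')
y_cuda = bias_act(x, b, act='lrelu', clamp=256, impl='cuda')
y_ref  = bias_act(x, b, act='lrelu', clamp=256, impl='ref')
assert torch.allclose(y_cuda, y_ref, atol=1e-5)  # --use_fast_math allows small drift
```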
|  | 
 | ||||||
|  | _bias_act_cuda_cache = dict() | ||||||
|  | 
 | ||||||
|  | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): | ||||||
|  |     """Fast CUDA implementation of `bias_act()` using custom ops. | ||||||
|  |     """ | ||||||
|  |     # Parse arguments. | ||||||
|  |     assert clamp is None or clamp >= 0 | ||||||
|  |     spec = activation_funcs[act] | ||||||
|  |     alpha = float(alpha if alpha is not None else spec.def_alpha) | ||||||
|  |     gain = float(gain if gain is not None else spec.def_gain) | ||||||
|  |     clamp = float(clamp if clamp is not None else -1) | ||||||
|  | 
 | ||||||
|  |     # Lookup from cache. | ||||||
|  |     key = (dim, act, alpha, gain, clamp) | ||||||
|  |     if key in _bias_act_cuda_cache: | ||||||
|  |         return _bias_act_cuda_cache[key] | ||||||
|  | 
 | ||||||
|  |     # Forward op. | ||||||
|  |     class BiasActCuda(torch.autograd.Function): | ||||||
|  |         @staticmethod | ||||||
|  |         def forward(ctx, x, b): # pylint: disable=arguments-differ | ||||||
|  |             ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format | ||||||
|  |             x = x.contiguous(memory_format=ctx.memory_format) | ||||||
|  |             b = b.contiguous() if b is not None else _null_tensor | ||||||
|  |             y = x | ||||||
|  |             if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: | ||||||
|  |                 y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) | ||||||
|  |             ctx.save_for_backward( | ||||||
|  |                 x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, | ||||||
|  |                 b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, | ||||||
|  |                 y if 'y' in spec.ref else _null_tensor) | ||||||
|  |             return y | ||||||
|  | 
 | ||||||
|  |         @staticmethod | ||||||
|  |         def backward(ctx, dy): # pylint: disable=arguments-differ | ||||||
|  |             dy = dy.contiguous(memory_format=ctx.memory_format) | ||||||
|  |             x, b, y = ctx.saved_tensors | ||||||
|  |             dx = None | ||||||
|  |             db = None | ||||||
|  | 
 | ||||||
|  |             if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: | ||||||
|  |                 dx = dy | ||||||
|  |                 if act != 'linear' or gain != 1 or clamp >= 0: | ||||||
|  |                     dx = BiasActCudaGrad.apply(dy, x, b, y) | ||||||
|  | 
 | ||||||
|  |             if ctx.needs_input_grad[1]: | ||||||
|  |                 db = dx.sum([i for i in range(dx.ndim) if i != dim]) | ||||||
|  | 
 | ||||||
|  |             return dx, db | ||||||
|  | 
 | ||||||
|  |     # Backward op. | ||||||
|  |     class BiasActCudaGrad(torch.autograd.Function): | ||||||
|  |         @staticmethod | ||||||
|  |         def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ | ||||||
|  |             ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format | ||||||
|  |             dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) | ||||||
|  |             ctx.save_for_backward( | ||||||
|  |                 dy if spec.has_2nd_grad else _null_tensor, | ||||||
|  |                 x, b, y) | ||||||
|  |             return dx | ||||||
|  | 
 | ||||||
|  |         @staticmethod | ||||||
|  |         def backward(ctx, d_dx): # pylint: disable=arguments-differ | ||||||
|  |             d_dx = d_dx.contiguous(memory_format=ctx.memory_format) | ||||||
|  |             dy, x, b, y = ctx.saved_tensors | ||||||
|  |             d_dy = None | ||||||
|  |             d_x = None | ||||||
|  |             d_b = None | ||||||
|  |             d_y = None | ||||||
|  | 
 | ||||||
|  |             if ctx.needs_input_grad[0]: | ||||||
|  |                 d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) | ||||||
|  | 
 | ||||||
|  |             if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): | ||||||
|  |                 d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) | ||||||
|  | 
 | ||||||
|  |             if spec.has_2nd_grad and ctx.needs_input_grad[2]: | ||||||
|  |                 d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) | ||||||
|  | 
 | ||||||
|  |             return d_dy, d_x, d_b, d_y | ||||||
|  | 
 | ||||||
|  |     # Add to cache. | ||||||
|  |     _bias_act_cuda_cache[key] = BiasActCuda | ||||||
|  |     return BiasActCuda | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
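Editorial note on the caching above: `_bias_act_cuda()` builds one autograd `Function` subclass per parameter combination and memoizes it, so repeated calls with the same configuration reuse the specialized class instead of re-parsing arguments on every `apply()`. A small sketch:

```python
fn_a = _bias_act_cuda(act='lrelu', gain=1.0)
fn_b = _bias_act_cuda(act='lrelu', gain=1.0)
assert fn_a is fn_b          # same cached class for the same (dim, act, alpha, gain, clamp) key
fn_c = _bias_act_cuda(act='lrelu', gain=2.0)
assert fn_c is not fn_a      # a different gain yields a different specialization
```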
							
								
								
									
models/stylegan3/torch_utils/ops/conv2d_gradfix.py (new file, 203 lines)
@@ -0,0 +1,203 @@
|  | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | """Custom replacement for `torch.nn.functional.conv2d` that supports | ||||||
|  | arbitrarily high order gradients with zero performance penalty.""" | ||||||
|  | 
 | ||||||
|  | import contextlib | ||||||
|  | import torch | ||||||
|  | from pkg_resources import parse_version | ||||||
|  | 
 | ||||||
|  | # pylint: disable=redefined-builtin | ||||||
|  | # pylint: disable=arguments-differ | ||||||
|  | # pylint: disable=protected-access | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | enabled = False                     # Enable the custom op by setting this to true. | ||||||
|  | weight_gradients_disabled = False   # Forcefully disable computation of gradients with respect to the weights. | ||||||
|  | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 | ||||||
|  | 
 | ||||||
|  | @contextlib.contextmanager | ||||||
|  | def no_weight_gradients(disable=True): | ||||||
|  |     global weight_gradients_disabled | ||||||
|  |     old = weight_gradients_disabled | ||||||
|  |     if disable: | ||||||
|  |         weight_gradients_disabled = True | ||||||
|  |     yield | ||||||
|  |     weight_gradients_disabled = old | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
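A usage sketch (editorial addition): the custom op is opt-in via the module-level `enabled` flag, and `no_weight_gradients()` temporarily skips weight gradients, as StyleGAN training does when differentiating a regularization term with respect to the input image only. The import path assumes `models/stylegan3` is on `sys.path`; shapes are illustrative.

```python
import torch
from torch_utils.ops import conv2d_gradfix

conv2d_gradfix.enabled = True  # otherwise conv2d() falls through to torch.nn.functional

x = torch.randn(1, 3, 64, 64, device='cuda', requires_grad=True)
w = torch.randn(8, 3, 3, 3, device='cuda', requires_grad=True)
y = conv2d_gradfix.conv2d(x, w, padding=1)

# Differentiate w.r.t. the input only; weight gradients are skipped inside the block.
with conv2d_gradfix.no_weight_gradients():
    (grad_x,) = torch.autograd.grad(y.sum(), x, create_graph=True)
```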
|  | 
 | ||||||
|  | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): | ||||||
|  |     if _should_use_custom_op(input): | ||||||
|  |         return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) | ||||||
|  |     return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) | ||||||
|  | 
 | ||||||
|  | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): | ||||||
|  |     if _should_use_custom_op(input): | ||||||
|  |         return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) | ||||||
|  |     return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def _should_use_custom_op(input): | ||||||
|  |     assert isinstance(input, torch.Tensor) | ||||||
|  |     if (not enabled) or (not torch.backends.cudnn.enabled): | ||||||
|  |         return False | ||||||
|  |     if _use_pytorch_1_11_api: | ||||||
|  |         # The work-around code doesn't work from PyTorch 1.11.0 onwards | ||||||
|  |         return False | ||||||
|  |     if input.device.type != 'cuda': | ||||||
|  |         return False | ||||||
|  |     return True | ||||||
|  | 
 | ||||||
|  | def _tuple_of_ints(xs, ndim): | ||||||
|  |     xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim | ||||||
|  |     assert len(xs) == ndim | ||||||
|  |     assert all(isinstance(x, int) for x in xs) | ||||||
|  |     return xs | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | _conv2d_gradfix_cache = dict() | ||||||
|  | _null_tensor = torch.empty([0]) | ||||||
|  | 
 | ||||||
|  | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): | ||||||
|  |     # Parse arguments. | ||||||
|  |     ndim = 2 | ||||||
|  |     weight_shape = tuple(weight_shape) | ||||||
|  |     stride = _tuple_of_ints(stride, ndim) | ||||||
|  |     padding = _tuple_of_ints(padding, ndim) | ||||||
|  |     output_padding = _tuple_of_ints(output_padding, ndim) | ||||||
|  |     dilation = _tuple_of_ints(dilation, ndim) | ||||||
|  | 
 | ||||||
|  |     # Lookup from cache. | ||||||
|  |     key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) | ||||||
|  |     if key in _conv2d_gradfix_cache: | ||||||
|  |         return _conv2d_gradfix_cache[key] | ||||||
|  | 
 | ||||||
|  |     # Validate arguments. | ||||||
|  |     assert groups >= 1 | ||||||
|  |     assert len(weight_shape) == ndim + 2 | ||||||
|  |     assert all(stride[i] >= 1 for i in range(ndim)) | ||||||
|  |     assert all(padding[i] >= 0 for i in range(ndim)) | ||||||
|  |     assert all(dilation[i] >= 1 for i in range(ndim)) | ||||||
|  |     if not transpose: | ||||||
|  |         assert all(output_padding[i] == 0 for i in range(ndim)) | ||||||
|  |     else: # transpose | ||||||
|  |         assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) | ||||||
|  | 
 | ||||||
|  |     # Helpers. | ||||||
|  |     common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) | ||||||
|  |     def calc_output_padding(input_shape, output_shape): | ||||||
|  |         if transpose: | ||||||
|  |             return [0, 0] | ||||||
|  |         return [ | ||||||
|  |             input_shape[i + 2] | ||||||
|  |             - (output_shape[i + 2] - 1) * stride[i] | ||||||
|  |             - (1 - 2 * padding[i]) | ||||||
|  |             - dilation[i] * (weight_shape[i + 2] - 1) | ||||||
|  |             for i in range(ndim) | ||||||
|  |         ] | ||||||
|  | 
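Why this formula (editorial note): `backward()` below re-applies the op with `transpose` flipped, and a transposed convolution yields `(out - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1 + output_padding` elements per spatial dimension. `calc_output_padding` solves that expression for `output_padding` so the gradient exactly recovers `input_shape`. A quick numeric check with made-up sizes:

```python
# Forward: input 64 -> conv(kernel=3, stride=2, padding=1, dilation=1) -> output 32.
in_sz, out_sz, k, s, p, d = 64, 32, 3, 2, 1, 1
op = in_sz - (out_sz - 1) * s - (1 - 2 * p) - d * (k - 1)   # -> 1
# The transposed conv with this output_padding recovers the input size exactly:
assert (out_sz - 1) * s - 2 * p + d * (k - 1) + 1 + op == in_sz
```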
 | ||||||
|  |     # Forward & backward. | ||||||
|  |     class Conv2d(torch.autograd.Function): | ||||||
|  |         @staticmethod | ||||||
|  |         def forward(ctx, input, weight, bias): | ||||||
|  |             assert weight.shape == weight_shape | ||||||
|  |             ctx.save_for_backward( | ||||||
|  |                 input if weight.requires_grad else _null_tensor, | ||||||
|  |                 weight if input.requires_grad else _null_tensor, | ||||||
|  |             ) | ||||||
|  |             ctx.input_shape = input.shape | ||||||
|  | 
 | ||||||
|  |             # Simple 1x1 convolution => cuBLAS (pre-Ampere only, i.e. compute capability < 8.0). | ||||||
|  |             if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0): | ||||||
|  |                 a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1]) | ||||||
|  |                 b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1) | ||||||
|  |                 c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2) | ||||||
|  |                 c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1) | ||||||
|  |                 c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) | ||||||
|  |                 return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) | ||||||
|  | 
 | ||||||
|  |             # General case => cuDNN. | ||||||
|  |             if transpose: | ||||||
|  |                 return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) | ||||||
|  |             return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) | ||||||
|  | 
 | ||||||
|  |         @staticmethod | ||||||
|  |         def backward(ctx, grad_output): | ||||||
|  |             input, weight = ctx.saved_tensors | ||||||
|  |             input_shape = ctx.input_shape | ||||||
|  |             grad_input = None | ||||||
|  |             grad_weight = None | ||||||
|  |             grad_bias = None | ||||||
|  | 
 | ||||||
|  |             if ctx.needs_input_grad[0]: | ||||||
|  |                 p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape) | ||||||
|  |                 op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) | ||||||
|  |                 grad_input = op.apply(grad_output, weight, None) | ||||||
|  |                 assert grad_input.shape == input_shape | ||||||
|  | 
 | ||||||
|  |             if ctx.needs_input_grad[1] and not weight_gradients_disabled: | ||||||
|  |                 grad_weight = Conv2dGradWeight.apply(grad_output, input) | ||||||
|  |                 assert grad_weight.shape == weight_shape | ||||||
|  | 
 | ||||||
|  |             if ctx.needs_input_grad[2]: | ||||||
|  |                 grad_bias = grad_output.sum([0, 2, 3]) | ||||||
|  | 
 | ||||||
|  |             return grad_input, grad_weight, grad_bias | ||||||
|  | 
 | ||||||
|  |     # Gradient with respect to the weights. | ||||||
|  |     class Conv2dGradWeight(torch.autograd.Function): | ||||||
|  |         @staticmethod | ||||||
|  |         def forward(ctx, grad_output, input): | ||||||
|  |             ctx.save_for_backward( | ||||||
|  |                 grad_output if input.requires_grad else _null_tensor, | ||||||
|  |                 input if grad_output.requires_grad else _null_tensor, | ||||||
|  |             ) | ||||||
|  |             ctx.grad_output_shape = grad_output.shape | ||||||
|  |             ctx.input_shape = input.shape | ||||||
|  | 
 | ||||||
|  |             # Simple 1x1 convolution => cuBLAS (on all supported GPU generations). | ||||||
|  |             if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0): | ||||||
|  |                 a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) | ||||||
|  |                 b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) | ||||||
|  |                 c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape) | ||||||
|  |                 return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) | ||||||
|  | 
 | ||||||
|  |             # General case => cuDNN. | ||||||
|  |             name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight' | ||||||
|  |             flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] | ||||||
|  |             return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) | ||||||
|  | 
 | ||||||
|  |         @staticmethod | ||||||
|  |         def backward(ctx, grad2_grad_weight): | ||||||
|  |             grad_output, input = ctx.saved_tensors | ||||||
|  |             grad_output_shape = ctx.grad_output_shape | ||||||
|  |             input_shape = ctx.input_shape | ||||||
|  |             grad2_grad_output = None | ||||||
|  |             grad2_input = None | ||||||
|  | 
 | ||||||
|  |             if ctx.needs_input_grad[0]: | ||||||
|  |                 grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) | ||||||
|  |                 assert grad2_grad_output.shape == grad_output_shape | ||||||
|  | 
 | ||||||
|  |             if ctx.needs_input_grad[1]: | ||||||
|  |                 p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape) | ||||||
|  |                 op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) | ||||||
|  |                 grad2_input = op.apply(grad_output, grad2_grad_weight, None) | ||||||
|  |                 assert grad2_input.shape == input_shape | ||||||
|  | 
 | ||||||
|  |             return grad2_grad_output, grad2_input | ||||||
|  | 
 | ||||||
|  |     _conv2d_gradfix_cache[key] = Conv2d | ||||||
|  |     return Conv2d | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
							
								
								
									
models/stylegan3/torch_utils/ops/conv2d_resample.py (new file, 143 lines)
@@ -0,0 +1,143 @@
|  | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | """2D convolution with optional up/downsampling.""" | ||||||
|  | 
 | ||||||
|  | import torch | ||||||
|  | 
 | ||||||
|  | from .. import misc | ||||||
|  | from . import conv2d_gradfix | ||||||
|  | from . import upfirdn2d | ||||||
|  | from .upfirdn2d import _parse_padding | ||||||
|  | from .upfirdn2d import _get_filter_size | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def _get_weight_shape(w): | ||||||
|  |     with misc.suppress_tracer_warnings(): # this value will be treated as a constant | ||||||
|  |         shape = [int(sz) for sz in w.shape] | ||||||
|  |     misc.assert_shape(w, shape) | ||||||
|  |     return shape | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): | ||||||
|  |     """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. | ||||||
|  |     """ | ||||||
|  |     _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) | ||||||
|  | 
 | ||||||
|  |     # Flip weight if requested. | ||||||
|  |     # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). | ||||||
|  |     if not flip_weight and (kw > 1 or kh > 1): | ||||||
|  |         w = w.flip([2, 3]) | ||||||
|  | 
 | ||||||
|  |     # Execute using conv2d_gradfix. | ||||||
|  |     op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d | ||||||
|  |     return op(x, w, stride=stride, padding=padding, groups=groups) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | @misc.profiled_function | ||||||
|  | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): | ||||||
|  |     r"""2D convolution with optional up/downsampling. | ||||||
|  | 
 | ||||||
|  |     Padding is performed only once at the beginning, not between the operations. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         x:              Input tensor of shape | ||||||
|  |                         `[batch_size, in_channels, in_height, in_width]`. | ||||||
|  |         w:              Weight tensor of shape | ||||||
|  |                         `[out_channels, in_channels//groups, kernel_height, kernel_width]`. | ||||||
|  |         f:              Low-pass filter for up/downsampling. Must be prepared beforehand by | ||||||
|  |                         calling upfirdn2d.setup_filter(). None = identity (default). | ||||||
|  |         up:             Integer upsampling factor (default: 1). | ||||||
|  |         down:           Integer downsampling factor (default: 1). | ||||||
|  |         padding:        Padding with respect to the upsampled image. Can be a single number | ||||||
|  |                         or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` | ||||||
|  |                         (default: 0). | ||||||
|  |         groups:         Split input channels into N groups (default: 1). | ||||||
|  |         flip_weight:    False = convolution, True = correlation (default: True). | ||||||
|  |         flip_filter:    False = convolution, True = correlation (default: False). | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. | ||||||
|  |     """ | ||||||
|  |     # Validate arguments. | ||||||
|  |     assert isinstance(x, torch.Tensor) and (x.ndim == 4) | ||||||
|  |     assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) | ||||||
|  |     assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) | ||||||
|  |     assert isinstance(up, int) and (up >= 1) | ||||||
|  |     assert isinstance(down, int) and (down >= 1) | ||||||
|  |     assert isinstance(groups, int) and (groups >= 1) | ||||||
|  |     out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) | ||||||
|  |     fw, fh = _get_filter_size(f) | ||||||
|  |     px0, px1, py0, py1 = _parse_padding(padding) | ||||||
|  | 
 | ||||||
|  |     # Adjust padding to account for up/downsampling. | ||||||
|  |     if up > 1: | ||||||
|  |         px0 += (fw + up - 1) // 2 | ||||||
|  |         px1 += (fw - up) // 2 | ||||||
|  |         py0 += (fh + up - 1) // 2 | ||||||
|  |         py1 += (fh - up) // 2 | ||||||
|  |     if down > 1: | ||||||
|  |         px0 += (fw - down + 1) // 2 | ||||||
|  |         px1 += (fw - down) // 2 | ||||||
|  |         py0 += (fh - down + 1) // 2 | ||||||
|  |         py1 += (fh - down) // 2 | ||||||
|  | 
 | ||||||
|  |     # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. | ||||||
|  |     if kw == 1 and kh == 1 and (down > 1 and up == 1): | ||||||
|  |         x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) | ||||||
|  |         x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  |     # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. | ||||||
|  |     if kw == 1 and kh == 1 and (up > 1 and down == 1): | ||||||
|  |         x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) | ||||||
|  |         x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  |     # Fast path: downsampling only => use strided convolution. | ||||||
|  |     if down > 1 and up == 1: | ||||||
|  |         x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) | ||||||
|  |         x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  |     # Fast path: upsampling with optional downsampling => use transpose strided convolution. | ||||||
|  |     if up > 1: | ||||||
|  |         if groups == 1: | ||||||
|  |             w = w.transpose(0, 1) | ||||||
|  |         else: | ||||||
|  |             w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) | ||||||
|  |             w = w.transpose(1, 2) | ||||||
|  |             w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) | ||||||
|  |         px0 -= kw - 1 | ||||||
|  |         px1 -= kw - up | ||||||
|  |         py0 -= kh - 1 | ||||||
|  |         py1 -= kh - up | ||||||
|  |         pxt = max(min(-px0, -px1), 0) | ||||||
|  |         pyt = max(min(-py0, -py1), 0) | ||||||
|  |         x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) | ||||||
|  |         x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) | ||||||
|  |         if down > 1: | ||||||
|  |             x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  |     # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. | ||||||
|  |     if up == 1 and down == 1: | ||||||
|  |         if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: | ||||||
|  |             return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) | ||||||
|  | 
 | ||||||
|  |     # Fallback: Generic reference implementation. | ||||||
|  |     x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) | ||||||
|  |     x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) | ||||||
|  |     if down > 1: | ||||||
|  |         x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) | ||||||
|  |     return x | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
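A usage sketch (editorial addition, assuming `models/stylegan3` is on `sys.path` and a CUDA device): a 2x upsampling fused with a 3x3 convolution, with the low-pass filter prepared via `upfirdn2d.setup_filter()` as the docstring requires.

```python
import torch
from torch_utils.ops import conv2d_resample, upfirdn2d

x = torch.randn(1, 64, 32, 32, device='cuda')
w = torch.randn(32, 64, 3, 3, device='cuda')
f = upfirdn2d.setup_filter([1, 3, 3, 1], device='cuda')  # standard StyleGAN low-pass

# 2x upsample + 3x3 conv; padding is specified w.r.t. the upsampled image.
y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
print(y.shape)  # -> torch.Size([1, 32, 64, 64])
```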
							
								
								
									
models/stylegan3/torch_utils/ops/filtered_lrelu.cpp (new file, 300 lines)
@@ -0,0 +1,300 @@
|  | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
 | ||||||
|  | //
 | ||||||
|  | // NVIDIA CORPORATION and its licensors retain all intellectual property
 | ||||||
|  | // and proprietary rights in and to this software, related documentation
 | ||||||
|  | // and any modifications thereto.  Any use, reproduction, disclosure or
 | ||||||
|  | // distribution of this software and related documentation without an express
 | ||||||
|  | // license agreement from NVIDIA CORPORATION is strictly prohibited.
 | ||||||
|  | 
 | ||||||
|  | #include <torch/extension.h> | ||||||
|  | #include <ATen/cuda/CUDAContext.h> | ||||||
|  | #include <c10/cuda/CUDAGuard.h> | ||||||
|  | #include "filtered_lrelu.h" | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------
 | ||||||
|  | 
 | ||||||
|  | static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu( | ||||||
|  |     torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si, | ||||||
|  |     int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns) | ||||||
|  | { | ||||||
|  |     // Set CUDA device.
 | ||||||
|  |     TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); | ||||||
|  |     const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); | ||||||
|  | 
 | ||||||
|  |     // Validate arguments.
 | ||||||
|  |     TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device"); | ||||||
|  |     TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32"); | ||||||
|  |     TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype"); | ||||||
|  |     TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32"); | ||||||
|  |     TORCH_CHECK(x.dim() == 4, "x must be rank 4"); | ||||||
|  |     TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large"); | ||||||
|  |     TORCH_CHECK(x.numel() > 0, "x is empty"); | ||||||
|  |     TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2"); | ||||||
|  |     TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large"); | ||||||
|  |     TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large"); | ||||||
|  |     TORCH_CHECK(fu.numel() > 0, "fu is empty"); | ||||||
|  |     TORCH_CHECK(fd.numel() > 0, "fd is empty"); | ||||||
|  |     TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x"); | ||||||
|  |     TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1"); | ||||||
|  | 
 | ||||||
|  |     // Figure out how much shared memory is available on the device.
 | ||||||
|  |     int maxSharedBytes = 0; | ||||||
|  |     AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index())); | ||||||
|  |     int sharedKB = maxSharedBytes >> 10; | ||||||
|  | 
 | ||||||
|  |     // Populate enough launch parameters to check if a CUDA kernel exists.
 | ||||||
|  |     filtered_lrelu_kernel_params p; | ||||||
|  |     p.up      = up; | ||||||
|  |     p.down    = down; | ||||||
|  |     p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
 | ||||||
|  |     p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0); | ||||||
|  |     filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB); | ||||||
|  |     if (!test_spec.exec) | ||||||
|  |     { | ||||||
|  |         // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
 | ||||||
|  |         return std::make_tuple(torch::Tensor(), torch::Tensor(), -1); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Input/output element size.
 | ||||||
|  |     int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4; | ||||||
|  | 
 | ||||||
|  |     // Input sizes.
 | ||||||
|  |     int64_t xw = (int)x.size(3); | ||||||
|  |     int64_t xh = (int)x.size(2); | ||||||
|  |     int64_t fut_w = (int)fu.size(-1) - 1; | ||||||
|  |     int64_t fut_h = (int)fu.size(0)  - 1; | ||||||
|  |     int64_t fdt_w = (int)fd.size(-1) - 1; | ||||||
|  |     int64_t fdt_h = (int)fd.size(0)  - 1; | ||||||
|  | 
 | ||||||
|  |     // Logical size of upsampled buffer.
 | ||||||
|  |     int64_t cw = xw * up + (px0 + px1) - fut_w; | ||||||
|  |     int64_t ch = xh * up + (py0 + py1) - fut_h; | ||||||
|  |     TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of the downsampling filter"); | ||||||
|  |     TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large"); | ||||||
|  | 
 | ||||||
|  |     // Compute output size and allocate.
 | ||||||
|  |     int64_t yw = (cw - fdt_w + (down - 1)) / down; | ||||||
|  |     int64_t yh = (ch - fdt_h + (down - 1)) / down; | ||||||
|  |     TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1"); | ||||||
|  |     TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large"); | ||||||
|  |     torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format()); | ||||||
|  | 
 | ||||||
|  |     // Allocate sign tensor.
 | ||||||
|  |     torch::Tensor so; | ||||||
|  |     torch::Tensor s = si; | ||||||
|  |     bool readSigns = !!s.numel(); | ||||||
|  |     int64_t sw_active = 0; // Active width of sign tensor.
 | ||||||
|  |     if (writeSigns) | ||||||
|  |     { | ||||||
|  |         sw_active = yw * down - (down - 1) + fdt_w;     // Active width in elements.
 | ||||||
|  |         int64_t sh = yh * down - (down - 1) + fdt_h;    // Height = active height.
 | ||||||
|  |         int64_t sw = (sw_active + 15) & ~15;            // Width  = active width in elements, rounded up to multiple of 16.
 | ||||||
|  |         TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large"); | ||||||
|  |         s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous); | ||||||
|  |     } | ||||||
|  |     else if (readSigns) | ||||||
|  |         sw_active = s.size(3) << 2; | ||||||
|  | 
 | ||||||
|  |     // Validate sign tensor if in use.
 | ||||||
|  |     if (readSigns || writeSigns) | ||||||
|  |     { | ||||||
|  |         TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); | ||||||
|  |         TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); | ||||||
|  |         TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x"); | ||||||
|  |         TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); | ||||||
|  |         TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x"); | ||||||
|  |         TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large"); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Populate rest of CUDA kernel parameters.
 | ||||||
|  |     p.x         = x.data_ptr(); | ||||||
|  |     p.y         = y.data_ptr(); | ||||||
|  |     p.b         = b.data_ptr(); | ||||||
|  |     p.s         = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0; | ||||||
|  |     p.fu        = fu.data_ptr<float>(); | ||||||
|  |     p.fd        = fd.data_ptr<float>(); | ||||||
|  |     p.pad0      = make_int2(px0, py0); | ||||||
|  |     p.gain      = gain; | ||||||
|  |     p.slope     = slope; | ||||||
|  |     p.clamp     = clamp; | ||||||
|  |     p.flip      = (flip_filters) ? 1 : 0; | ||||||
|  |     p.xShape    = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); | ||||||
|  |     p.yShape    = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); | ||||||
|  |     p.sShape    = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
 | ||||||
|  |     p.sOfs      = make_int2(sx, sy); | ||||||
|  |     p.swLimit   = (sw_active + 3) >> 2; // Rounded up to bytes.
 | ||||||
|  | 
 | ||||||
|  |     // x, y, b strides are in bytes.
 | ||||||
|  |     p.xStride   = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0)); | ||||||
|  |     p.yStride   = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0)); | ||||||
|  |     p.bStride   = sz * b.stride(0); | ||||||
|  | 
 | ||||||
|  |     // fu, fd strides are in elements.
 | ||||||
|  |     p.fuStride  = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0); | ||||||
|  |     p.fdStride  = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0); | ||||||
|  | 
 | ||||||
|  |     // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
 | ||||||
|  |     bool index64b = false; | ||||||
|  |     if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true; | ||||||
|  |     if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true; | ||||||
|  |     if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) >  INT_MAX) index64b = true; | ||||||
|  |     if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true; | ||||||
|  |     if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) >  INT_MAX) index64b = true; | ||||||
|  |     if (s.numel() > INT_MAX) index64b = true; | ||||||
|  | 
 | ||||||
|  |     // Choose CUDA kernel.
 | ||||||
|  |     filtered_lrelu_kernel_spec spec = { 0 }; | ||||||
|  |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&] | ||||||
|  |     { | ||||||
|  |         if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
 | ||||||
|  |         { | ||||||
|  |             // Choose kernel based on index type, datatype and sign read/write modes.
 | ||||||
|  |             if      (!index64b &&  writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true,  false>(p, sharedKB); | ||||||
|  |             else if (!index64b && !writeSigns &&  readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true >(p, sharedKB); | ||||||
|  |             else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB); | ||||||
|  |             else if ( index64b &&  writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true,  false>(p, sharedKB); | ||||||
|  |             else if ( index64b && !writeSigns &&  readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true >(p, sharedKB); | ||||||
|  |             else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB); | ||||||
|  |         } | ||||||
|  |     }); | ||||||
|  |     TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found"); // This should not happen because we verified earlier that the kernel exists.
 | ||||||
|  | 
 | ||||||
|  |     // Launch CUDA kernel.
 | ||||||
|  |     void* args[] = {&p}; | ||||||
|  |     int bx = spec.numWarps * 32; | ||||||
|  |     int gx = (p.yShape.x - 1) / spec.tileOut.x + 1; | ||||||
|  |     int gy = (p.yShape.y - 1) / spec.tileOut.y + 1; | ||||||
|  |     int gz = p.yShape.z * p.yShape.w; | ||||||
|  | 
 | ||||||
|  |     // Repeat multiple horizontal tiles in a CTA?
 | ||||||
|  |     if (spec.xrep) | ||||||
|  |     { | ||||||
|  |         p.tilesXrep = spec.xrep; | ||||||
|  |         p.tilesXdim = gx; | ||||||
|  | 
 | ||||||
|  |         gx = (gx + p.tilesXrep - 1) / p.tilesXrep; | ||||||
|  |         std::swap(gx, gy); | ||||||
|  |     } | ||||||
|  |     else | ||||||
|  |     { | ||||||
|  |         p.tilesXrep = 0; | ||||||
|  |         p.tilesXdim = 0; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Launch filter setup kernel.
 | ||||||
|  |     AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream())); | ||||||
|  | 
 | ||||||
|  |     // Copy kernels to constant memory.
 | ||||||
|  |     if      ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true,  false>(at::cuda::getCurrentCUDAStream()))); | ||||||
|  |     else if (!writeSigns &&  readSigns) AT_CUDA_CHECK((copy_filters<false, true >(at::cuda::getCurrentCUDAStream()))); | ||||||
|  |     else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream()))); | ||||||
|  | 
 | ||||||
|  |     // Set cache and shared memory configurations for main kernel.
 | ||||||
|  |     AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared)); | ||||||
|  |     if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
 | ||||||
|  |         AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10)); | ||||||
|  |     AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte)); | ||||||
|  | 
 | ||||||
|  |     // Launch main kernel.
 | ||||||
|  |     const int maxSubGz = 65535; // CUDA maximum for block z dimension.
 | ||||||
|  |     for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
 | ||||||
|  |     { | ||||||
|  |         p.blockZofs = zofs; | ||||||
|  |         int subGz = std::min(maxSubGz, gz - zofs); | ||||||
|  |         AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream())); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Done.
 | ||||||
|  |     return std::make_tuple(y, so, 0); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------
 | ||||||
|  | 
 | ||||||
|  | static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns) | ||||||
|  | { | ||||||
|  |     // Set CUDA device.
 | ||||||
|  |     TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); | ||||||
|  |     const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); | ||||||
|  | 
 | ||||||
|  |     // Validate arguments.
 | ||||||
|  |     TORCH_CHECK(x.dim() == 4, "x must be rank 4"); | ||||||
|  |     TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large"); | ||||||
|  |     TORCH_CHECK(x.numel() > 0, "x is empty"); | ||||||
|  |     TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64"); | ||||||
|  | 
 | ||||||
|  |     // Output signs if we don't have sign input.
 | ||||||
|  |     torch::Tensor so; | ||||||
|  |     torch::Tensor s = si; | ||||||
|  |     bool readSigns = !!s.numel(); | ||||||
|  |     if (writeSigns) | ||||||
|  |     { | ||||||
|  |         int64_t sw = x.size(3); | ||||||
|  |         sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
 | ||||||
|  |         s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Validate sign tensor if in use.
 | ||||||
|  |     if (readSigns || writeSigns) | ||||||
|  |     { | ||||||
|  |         TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); | ||||||
|  |         TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); | ||||||
|  |         TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x"); | ||||||
|  |         TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); | ||||||
|  |         TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x"); | ||||||
|  |         TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large"); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Initialize CUDA kernel parameters.
 | ||||||
|  |     filtered_lrelu_act_kernel_params p; | ||||||
|  |     p.x         = x.data_ptr(); | ||||||
|  |     p.s         = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0; | ||||||
|  |     p.gain      = gain; | ||||||
|  |     p.slope     = slope; | ||||||
|  |     p.clamp     = clamp; | ||||||
|  |     p.xShape    = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); | ||||||
|  |     p.xStride   = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0)); | ||||||
|  |     p.sShape    = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
 | ||||||
|  |     p.sOfs      = make_int2(sx, sy); | ||||||
|  | 
 | ||||||
|  |     // Choose CUDA kernel.
 | ||||||
|  |     void* func = 0; | ||||||
|  |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&] | ||||||
|  |     { | ||||||
|  |         if (writeSigns) | ||||||
|  |             func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>(); | ||||||
|  |         else if (readSigns) | ||||||
|  |             func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>(); | ||||||
|  |         else | ||||||
|  |             func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>(); | ||||||
|  |     }); | ||||||
|  |     TORCH_CHECK(func, "internal error - CUDA kernel not found"); | ||||||
|  | 
 | ||||||
|  |     // Launch CUDA kernel.
 | ||||||
|  |     void* args[] = {&p}; | ||||||
|  |     int bx = 128; // 4 warps per block.
 | ||||||
|  | 
 | ||||||
|  |     // Logical size of launch = writeSigns ? p.s : p.x
 | ||||||
|  |     uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x; | ||||||
|  |     uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y; | ||||||
|  |     uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
 | ||||||
|  |     gx = (gx - 1) / bx + 1; | ||||||
|  | 
 | ||||||
|  |     // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
 | ||||||
|  |     const uint32_t gmax = 65535; | ||||||
|  |     gy = std::min(gy, gmax); | ||||||
|  |     gz = std::min(gz, gmax); | ||||||
|  | 
 | ||||||
|  |     // Launch.
 | ||||||
|  |     AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream())); | ||||||
|  |     return so; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------
 | ||||||
|  | 
 | ||||||
|  | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) | ||||||
|  | { | ||||||
|  |     m.def("filtered_lrelu",      &filtered_lrelu);      // The whole thing.
 | ||||||
|  |     m.def("filtered_lrelu_act_", &filtered_lrelu_act);  // Activation and sign tensor handling only. Modifies data tensor in-place.
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------
 | ||||||
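Editorial note on the sign tensor allocated above: each element's state is packed into 2 bits (four elements per uint8 byte, hence the `>> 2`), with the active width rounded up to a multiple of 16 elements for coalesced access. A sketch of the shape arithmetic for the in-place activation variant, where the sign width derives directly from `x.size(3)`:

```python
def sign_tensor_shape(n, c, h, w):
    # Round the width up to a multiple of 16 elements, then pack
    # four 2-bit sign values into each uint8 byte.
    sw = (w + 15) & ~15
    return (n, c, h, sw >> 2)

print(sign_tensor_shape(4, 512, 36, 36))  # -> (4, 512, 36, 12)
```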
							
								
								
									
models/stylegan3/torch_utils/ops/filtered_lrelu.cu (new file, 1284 lines; diff suppressed because the file is too large)
											
										
									
								
							
							
								
								
									
models/stylegan3/torch_utils/ops/filtered_lrelu.h (new file, 90 lines)
@@ -0,0 +1,90 @@
|  | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
 | ||||||
|  | //
 | ||||||
|  | // NVIDIA CORPORATION and its licensors retain all intellectual property
 | ||||||
|  | // and proprietary rights in and to this software, related documentation
 | ||||||
|  | // and any modifications thereto.  Any use, reproduction, disclosure or
 | ||||||
|  | // distribution of this software and related documentation without an express
 | ||||||
|  | // license agreement from NVIDIA CORPORATION is strictly prohibited.
 | ||||||
|  | 
 | ||||||
|  | #include <cuda_runtime.h> | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------
 | ||||||
|  | // CUDA kernel parameters.
 | ||||||
|  | 
 | ||||||
|  | struct filtered_lrelu_kernel_params | ||||||
|  | { | ||||||
|  |     // These parameters decide which kernel to use.
 | ||||||
|  |     int             up;         // upsampling ratio (1, 2, 4)
 | ||||||
|  |     int             down;       // downsampling ratio (1, 2, 4)
 | ||||||
|  |     int2            fuShape;    // [size, 1] | [size, size]
 | ||||||
|  |     int2            fdShape;    // [size, 1] | [size, size]
 | ||||||
|  | 
 | ||||||
|  |     int             _dummy;     // Alignment.
 | ||||||
|  | 
 | ||||||
|  |     // Rest of the parameters.
 | ||||||
|  |     const void*     x;          // Input tensor.
 | ||||||
|  |     void*           y;          // Output tensor.
 | ||||||
|  |     const void*     b;          // Bias tensor.
 | ||||||
|  |     unsigned char*  s;          // Sign tensor in/out. NULL if unused.
 | ||||||
|  |     const float*    fu;         // Upsampling filter.
 | ||||||
|  |     const float*    fd;         // Downsampling filter.
 | ||||||
|  | 
 | ||||||
|  |     int2            pad0;       // Left/top padding.
 | ||||||
|  |     float           gain;       // Additional gain factor.
 | ||||||
|  |     float           slope;      // Leaky ReLU slope on negative side.
 | ||||||
|  |     float           clamp;      // Clamp after nonlinearity.
 | ||||||
|  |     int             flip;       // Filter kernel flip for gradient computation.
 | ||||||
|  | 
 | ||||||
|  |     int             tilesXdim;  // Original number of horizontal output tiles.
 | ||||||
|  |     int             tilesXrep;  // Number of horizontal tiles per CTA.
 | ||||||
|  |     int             blockZofs;  // Block z offset to support large minibatch, channel dimensions.
 | ||||||
|  | 
 | ||||||
|  |     int4            xShape;     // [width, height, channel, batch]
 | ||||||
|  |     int4            yShape;     // [width, height, channel, batch]
 | ||||||
|  |     int2            sShape;     // [width, height] - width is in bytes. Contiguous. Zeros if unused.
 | ||||||
|  |     int2            sOfs;       // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
 | ||||||
|  |     int             swLimit;    // Active width of sign tensor in bytes.
 | ||||||
|  | 
 | ||||||
|  |     longlong4       xStride;    // Strides of all tensors except signs, same component order as shapes.
 | ||||||
|  |     longlong4       yStride;    //
 | ||||||
|  |     int64_t         bStride;    //
 | ||||||
|  |     longlong3       fuStride;   //
 | ||||||
|  |     longlong3       fdStride;   //
 | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | struct filtered_lrelu_act_kernel_params | ||||||
|  | { | ||||||
|  |     void*           x;          // Input/output, modified in-place.
 | ||||||
|  |     unsigned char*  s;          // Sign tensor in/out. NULL if unused.
 | ||||||
|  | 
 | ||||||
|  |     float           gain;       // Additional gain factor.
 | ||||||
|  |     float           slope;      // Leaky ReLU slope on negative side.
 | ||||||
|  |     float           clamp;      // Clamp after nonlinearity.
 | ||||||
|  | 
 | ||||||
|  |     int4            xShape;     // [width, height, channel, batch]
 | ||||||
|  |     longlong4       xStride;    // Input/output tensor strides, same order as in shape.
 | ||||||
|  |     int2            sShape;     // [width, height] - width is in elements. Contiguous. Zeros if unused.
 | ||||||
|  |     int2            sOfs;       // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
 | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------
 | ||||||
|  | // CUDA kernel specialization.
 | ||||||
|  | 
 | ||||||
|  | struct filtered_lrelu_kernel_spec | ||||||
|  | { | ||||||
|  |     void*   setup;              // Function for filter kernel setup.
 | ||||||
|  |     void*   exec;               // Function for main operation.
 | ||||||
|  |     int2    tileOut;            // Width/height of launch tile.
 | ||||||
|  |     int     numWarps;           // Number of warps per thread block, determines launch block size.
 | ||||||
|  |     int     xrep;               // For processing multiple horizontal tiles per thread block.
 | ||||||
|  |     int     dynamicSharedKB;    // How much dynamic shared memory the exec kernel wants.
 | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------
 | ||||||
|  | // CUDA kernel selection.
 | ||||||
|  | 
 | ||||||
|  | template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); | ||||||
|  | template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void); | ||||||
|  | template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream); | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------
 | ||||||
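For orientation: the sign tensor declared above is written on the forward pass, read on the backward pass, and skipped entirely at inference, and the three `.cu` files added below each compile one `(signWrite, signRead)` combination. A minimal routing sketch under that assumption (the function here is illustrative, not part of the commit):

```python
# Illustrative only: maps a call's needs onto the (signWrite, signRead)
# template arguments declared in filtered_lrelu.h.
def pick_specialization(needs_grad: bool, has_sign_input: bool) -> str:
    if has_sign_input:
        return "filtered_lrelu_rd.cu"  # signRead=True: backward replays stored signs
    if needs_grad:
        return "filtered_lrelu_wr.cu"  # signWrite=True: forward records signs
    return "filtered_lrelu_ns.cu"      # no signs: inference, nothing retained

assert pick_specialization(needs_grad=False, has_sign_input=False) == "filtered_lrelu_ns.cu"
```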
							
								
								
									
models/stylegan3/torch_utils/ops/filtered_lrelu.py (new file, 274 lines)
| @@ -0,0 +1,274 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | import os | ||||||
|  | import numpy as np | ||||||
|  | import torch | ||||||
|  | import warnings | ||||||
|  | 
 | ||||||
|  | from .. import custom_ops | ||||||
|  | from .. import misc | ||||||
|  | from . import upfirdn2d | ||||||
|  | from . import bias_act | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | _plugin = None | ||||||
|  | 
 | ||||||
|  | def _init(): | ||||||
|  |     global _plugin | ||||||
|  |     if _plugin is None: | ||||||
|  |         _plugin = custom_ops.get_plugin( | ||||||
|  |             module_name='filtered_lrelu_plugin', | ||||||
|  |             sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'], | ||||||
|  |             headers=['filtered_lrelu.h', 'filtered_lrelu.cu'], | ||||||
|  |             source_dir=os.path.dirname(__file__), | ||||||
|  |             extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'], | ||||||
|  |         ) | ||||||
|  |     return True | ||||||
|  | 
 | ||||||
|  | def _get_filter_size(f): | ||||||
|  |     if f is None: | ||||||
|  |         return 1, 1 | ||||||
|  |     assert isinstance(f, torch.Tensor) | ||||||
|  |     assert 1 <= f.ndim <= 2 | ||||||
|  |     return f.shape[-1], f.shape[0] # width, height | ||||||
|  | 
 | ||||||
|  | def _parse_padding(padding): | ||||||
|  |     if isinstance(padding, int): | ||||||
|  |         padding = [padding, padding] | ||||||
|  |     assert isinstance(padding, (list, tuple)) | ||||||
|  |     assert all(isinstance(x, (int, np.integer)) for x in padding) | ||||||
|  |     padding = [int(x) for x in padding] | ||||||
|  |     if len(padding) == 2: | ||||||
|  |         px, py = padding | ||||||
|  |         padding = [px, px, py, py] | ||||||
|  |     px0, px1, py0, py1 = padding | ||||||
|  |     return px0, px1, py0, py1 | ||||||
|  | 
 | ||||||
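As a quick sanity check of the helper above: every accepted padding form canonicalizes to per-edge `(x_before, x_after, y_before, y_after)` values (a minimal sketch, assuming `_parse_padding` is importable from this module):

```python
# Each accepted form expands to (px0, px1, py0, py1).
assert _parse_padding(3)            == (3, 3, 3, 3)
assert _parse_padding([2, 5])       == (2, 2, 5, 5)
assert _parse_padding([1, 2, 3, 4]) == (1, 2, 3, 4)
```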
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'): | ||||||
|  |     r"""Filtered leaky ReLU for a batch of 2D images. | ||||||
|  | 
 | ||||||
|  |     Performs the following sequence of operations for each channel: | ||||||
|  | 
 | ||||||
|  |     1. Add channel-specific bias if provided (`b`). | ||||||
|  | 
 | ||||||
|  |     2. Upsample the image by inserting N-1 zeros after each pixel (`up`). | ||||||
|  | 
 | ||||||
|  |     3. Pad the image with the specified number of zeros on each side (`padding`). | ||||||
|  |        Negative padding corresponds to cropping the image. | ||||||
|  | 
 | ||||||
|  |     4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it | ||||||
|  |        so that the footprint of all output pixels lies within the input image. | ||||||
|  | 
 | ||||||
|  |     5. Multiply each value by the provided gain factor (`gain`). | ||||||
|  | 
 | ||||||
|  |     6. Apply leaky ReLU activation function to each value. | ||||||
|  | 
 | ||||||
|  |     7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided. | ||||||
|  | 
 | ||||||
|  |     8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking | ||||||
|  |        it so that the footprint of all output pixels lies within the input image. | ||||||
|  | 
 | ||||||
|  |     9. Downsample the image by keeping every Nth pixel (`down`). | ||||||
|  | 
 | ||||||
|  |     The fused op is considerably more efficient than performing the same calculation | ||||||
|  |     using standard PyTorch ops. It supports gradients of arbitrary order. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         x:           Float32/float16/float64 input tensor of the shape | ||||||
|  |                      `[batch_size, num_channels, in_height, in_width]`. | ||||||
|  |         fu:          Float32 upsampling FIR filter of the shape | ||||||
|  |                      `[filter_height, filter_width]` (non-separable), | ||||||
|  |                      `[filter_taps]` (separable), or | ||||||
|  |                      `None` (identity). | ||||||
|  |         fd:          Float32 downsampling FIR filter of the shape | ||||||
|  |                      `[filter_height, filter_width]` (non-separable), | ||||||
|  |                      `[filter_taps]` (separable), or | ||||||
|  |                      `None` (identity). | ||||||
|  |         b:           Bias vector, or `None` to disable. Must be a 1D tensor of the same type | ||||||
|  |                      as `x`. The length of the vector must match the channel dimension of `x`. | ||||||
|  |         up:          Integer upsampling factor (default: 1). | ||||||
|  |         down:        Integer downsampling factor (default: 1). | ||||||
|  |         padding:     Padding with respect to the upsampled image. Can be a single number | ||||||
|  |                      or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` | ||||||
|  |                      (default: 0). | ||||||
|  |         gain:        Overall scaling factor for signal magnitude (default: sqrt(2)). | ||||||
|  |         slope:       Slope on the negative side of leaky ReLU (default: 0.2). | ||||||
|  |         clamp:       Maximum magnitude for leaky ReLU output (default: None). | ||||||
|  |         flip_filter: False = convolution, True = correlation (default: False). | ||||||
|  |         impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. | ||||||
|  |     """ | ||||||
|  |     assert isinstance(x, torch.Tensor) | ||||||
|  |     assert impl in ['ref', 'cuda'] | ||||||
|  |     if impl == 'cuda' and x.device.type == 'cuda' and _init(): | ||||||
|  |         return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0) | ||||||
|  |     return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter) | ||||||
|  | 
 | ||||||
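A minimal usage sketch of the public op above. On a CPU tensor the default `impl='cuda'` falls through to the reference path, so this runs without the compiled plugin; the import path is assumed from the layout added in this commit:

```python
import torch
from torch_utils.ops.filtered_lrelu import filtered_lrelu  # assumed import path

x = torch.randn(1, 3, 16, 16)            # NCHW input
f = torch.tensor([1., 3., 3., 1.]) / 8.  # separable 4-tap filter, sums to 1
y = filtered_lrelu(x, fu=f, fd=f, up=2, down=2, padding=3)
print(y.shape)  # torch.Size([1, 3, 16, 16]): padding=3 per edge keeps up=down=2 size-neutral
```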
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | @misc.profiled_function | ||||||
|  | def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): | ||||||
|  |     """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using | ||||||
|  |     existing `upfirdn2d()` and `bias_act()` ops. | ||||||
|  |     """ | ||||||
|  |     assert isinstance(x, torch.Tensor) and x.ndim == 4 | ||||||
|  |     fu_w, fu_h = _get_filter_size(fu) | ||||||
|  |     fd_w, fd_h = _get_filter_size(fd) | ||||||
|  |     if b is not None: | ||||||
|  |         assert isinstance(b, torch.Tensor) and b.dtype == x.dtype | ||||||
|  |         misc.assert_shape(b, [x.shape[1]]) | ||||||
|  |     assert isinstance(up, int) and up >= 1 | ||||||
|  |     assert isinstance(down, int) and down >= 1 | ||||||
|  |     px0, px1, py0, py1 = _parse_padding(padding) | ||||||
|  |     assert gain == float(gain) and gain > 0 | ||||||
|  |     assert slope == float(slope) and slope >= 0 | ||||||
|  |     assert clamp is None or (clamp == float(clamp) and clamp >= 0) | ||||||
|  | 
 | ||||||
|  |     # Calculate output size. | ||||||
|  |     batch_size, channels, in_h, in_w = x.shape | ||||||
|  |     in_dtype = x.dtype | ||||||
|  |     out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down | ||||||
|  |     out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down | ||||||
|  | 
 | ||||||
|  |     # Compute using existing ops. | ||||||
|  |     x = bias_act.bias_act(x=x, b=b) # Apply bias. | ||||||
|  |     x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample. | ||||||
|  |     x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp. | ||||||
|  |     x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample. | ||||||
|  | 
 | ||||||
|  |     # Check output shape & dtype. | ||||||
|  |     misc.assert_shape(x, [batch_size, channels, out_h, out_w]) | ||||||
|  |     assert x.dtype == in_dtype | ||||||
|  |     return x | ||||||
|  | 
 | ||||||
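The output-size arithmetic above deserves a worked instance (numbers illustrative):

```python
# out = (in*up + pad - (fu_taps-1) - (fd_taps-1) + (down-1)) // down
in_w, up, down = 32, 2, 2
px0 = px1 = 3                  # per-edge padding, in upsampled coordinates
fu_w = fd_w = 4                # 4-tap separable filters
out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
print(out_w)  # 32: with 4-tap filters, padding=3 per edge makes up=down=2 size-preserving
```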
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | _filtered_lrelu_cuda_cache = dict() | ||||||
|  | 
 | ||||||
|  | def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): | ||||||
|  |     """Fast CUDA implementation of `filtered_lrelu()` using custom ops. | ||||||
|  |     """ | ||||||
|  |     assert isinstance(up, int) and up >= 1 | ||||||
|  |     assert isinstance(down, int) and down >= 1 | ||||||
|  |     px0, px1, py0, py1 = _parse_padding(padding) | ||||||
|  |     assert gain == float(gain) and gain > 0 | ||||||
|  |     gain = float(gain) | ||||||
|  |     assert slope == float(slope) and slope >= 0 | ||||||
|  |     slope = float(slope) | ||||||
|  |     assert clamp is None or (clamp == float(clamp) and clamp >= 0) | ||||||
|  |     clamp = float(clamp if clamp is not None else 'inf') | ||||||
|  | 
 | ||||||
|  |     # Lookup from cache. | ||||||
|  |     key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter) | ||||||
|  |     if key in _filtered_lrelu_cuda_cache: | ||||||
|  |         return _filtered_lrelu_cuda_cache[key] | ||||||
|  | 
 | ||||||
|  |     # Forward op. | ||||||
|  |     class FilteredLReluCuda(torch.autograd.Function): | ||||||
|  |         @staticmethod | ||||||
|  |         def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ | ||||||
|  |             assert isinstance(x, torch.Tensor) and x.ndim == 4 | ||||||
|  | 
 | ||||||
|  |             # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable). | ||||||
|  |             if fu is None: | ||||||
|  |                 fu = torch.ones([1, 1], dtype=torch.float32, device=x.device) | ||||||
|  |             if fd is None: | ||||||
|  |                 fd = torch.ones([1, 1], dtype=torch.float32, device=x.device) | ||||||
|  |             assert 1 <= fu.ndim <= 2 | ||||||
|  |             assert 1 <= fd.ndim <= 2 | ||||||
|  | 
 | ||||||
|  |             # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1. | ||||||
|  |             if up == 1 and fu.ndim == 1 and fu.shape[0] == 1: | ||||||
|  |                 fu = fu.square()[None] | ||||||
|  |             if down == 1 and fd.ndim == 1 and fd.shape[0] == 1: | ||||||
|  |                 fd = fd.square()[None] | ||||||
|  | 
 | ||||||
|  |             # Missing sign input tensor. | ||||||
|  |             if si is None: | ||||||
|  |                 si = torch.empty([0]) | ||||||
|  | 
 | ||||||
|  |             # Missing bias tensor. | ||||||
|  |             if b is None: | ||||||
|  |                 b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device) | ||||||
|  | 
 | ||||||
|  |             # Construct internal sign tensor only if gradients are needed. | ||||||
|  |             write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad) | ||||||
|  | 
 | ||||||
|  |             # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout. | ||||||
|  |             strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1] | ||||||
|  |             if any(a < b for a, b in zip(strides[:-1], strides[1:])): | ||||||
|  |                 warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning) | ||||||
|  | 
 | ||||||
|  |             # Call C++/Cuda plugin if datatype is supported. | ||||||
|  |             if x.dtype in [torch.float16, torch.float32]: | ||||||
|  |                 if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device): | ||||||
|  |                     warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning) | ||||||
|  |                 y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs) | ||||||
|  |             else: | ||||||
|  |                 return_code = -1 | ||||||
|  | 
 | ||||||
|  |             # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because | ||||||
|  |             # only the bit-packed sign tensor is retained for gradient computation. | ||||||
|  |             if return_code < 0: | ||||||
|  |                 warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning) | ||||||
|  | 
 | ||||||
|  |                 y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias. | ||||||
|  |                 y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample. | ||||||
|  |                 so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place. | ||||||
|  |                 y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample. | ||||||
|  | 
 | ||||||
|  |             # Prepare for gradient computation. | ||||||
|  |             ctx.save_for_backward(fu, fd, (si if si.numel() else so)) | ||||||
|  |             ctx.x_shape = x.shape | ||||||
|  |             ctx.y_shape = y.shape | ||||||
|  |             ctx.s_ofs = sx, sy | ||||||
|  |             return y | ||||||
|  | 
 | ||||||
|  |         @staticmethod | ||||||
|  |         def backward(ctx, dy): # pylint: disable=arguments-differ | ||||||
|  |             fu, fd, si = ctx.saved_tensors | ||||||
|  |             _, _, xh, xw = ctx.x_shape | ||||||
|  |             _, _, yh, yw = ctx.y_shape | ||||||
|  |             sx, sy = ctx.s_ofs | ||||||
|  |             dx  = None # 0 | ||||||
|  |             dfu = None; assert not ctx.needs_input_grad[1] | ||||||
|  |             dfd = None; assert not ctx.needs_input_grad[2] | ||||||
|  |             db  = None # 3 | ||||||
|  |             dsi = None; assert not ctx.needs_input_grad[4] | ||||||
|  |             dsx = None; assert not ctx.needs_input_grad[5] | ||||||
|  |             dsy = None; assert not ctx.needs_input_grad[6] | ||||||
|  | 
 | ||||||
|  |             if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]: | ||||||
|  |                 pp = [ | ||||||
|  |                     (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0, | ||||||
|  |                     xw * up - yw * down + px0 - (up - 1), | ||||||
|  |                     (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0, | ||||||
|  |                     xh * up - yh * down + py0 - (up - 1), | ||||||
|  |                 ] | ||||||
|  |                 gg = gain * (up ** 2) / (down ** 2) | ||||||
|  |                 ff = (not flip_filter) | ||||||
|  |                 sx = sx - (fu.shape[-1] - 1) + px0 | ||||||
|  |                 sy = sy - (fu.shape[0]  - 1) + py0 | ||||||
|  |                 dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy) | ||||||
|  | 
 | ||||||
|  |             if ctx.needs_input_grad[3]: | ||||||
|  |                 db = dx.sum([0, 2, 3]) | ||||||
|  | 
 | ||||||
|  |             return dx, dfu, dfd, db, dsi, dsx, dsy | ||||||
|  | 
 | ||||||
|  |     # Add to cache. | ||||||
|  |     _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda | ||||||
|  |     return FilteredLReluCuda | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
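One way to see why the generic fallback above still beats the reference path on memory: only one sign bit per upsampled activation is kept for backward, packed eight to a byte (the kernel header measures `sShape` width in bytes). A back-of-envelope sketch with that packing assumption stated in the comments:

```python
def packed_sign_bytes(n, c, h, w):
    # Assumption: one sign bit per element, each row padded up to whole bytes.
    return n * c * h * ((w + 7) // 8)

n, c, h, w = 8, 512, 64, 64
float32_bytes = n * c * h * w * 4            # retaining full activations instead
sign_bytes    = packed_sign_bytes(n, c, h, w)
print(float32_bytes // sign_bytes)           # 32: packed signs are 32x smaller
```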
							
								
								
									
models/stylegan3/torch_utils/ops/filtered_lrelu_ns.cu (new file, 27 lines)
| @@ -0,0 +1,27 @@ | |||||||
|  | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | // | ||||||
|  | // NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | // and proprietary rights in and to this software, related documentation | ||||||
|  | // and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | // distribution of this software and related documentation without an express | ||||||
|  | // license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | #include "filtered_lrelu.cu" | ||||||
|  | 
 | ||||||
|  | // Template/kernel specializations for no signs mode (no gradients required). | ||||||
|  | 
 | ||||||
|  | // Full op, 32-bit indexing. | ||||||
|  | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB); | ||||||
|  | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB); | ||||||
|  | 
 | ||||||
|  | // Full op, 64-bit indexing. | ||||||
|  | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB); | ||||||
|  | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB); | ||||||
|  | 
 | ||||||
|  | // Activation/signs only for generic variant. 64-bit indexing. | ||||||
|  | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void); | ||||||
|  | template void* choose_filtered_lrelu_act_kernel<float,     false, false>(void); | ||||||
|  | template void* choose_filtered_lrelu_act_kernel<double,    false, false>(void); | ||||||
|  | 
 | ||||||
|  | // Copy filters to constant memory. | ||||||
|  | template cudaError_t copy_filters<false, false>(cudaStream_t stream); | ||||||
							
								
								
									
models/stylegan3/torch_utils/ops/filtered_lrelu_rd.cu (new file, 27 lines)
| @@ -0,0 +1,27 @@ | |||||||
|  | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | // | ||||||
|  | // NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | // and proprietary rights in and to this software, related documentation | ||||||
|  | // and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | // distribution of this software and related documentation without an express | ||||||
|  | // license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | #include "filtered_lrelu.cu" | ||||||
|  | 
 | ||||||
|  | // Template/kernel specializations for sign read mode. | ||||||
|  | 
 | ||||||
|  | // Full op, 32-bit indexing. | ||||||
|  | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB); | ||||||
|  | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB); | ||||||
|  | 
 | ||||||
|  | // Full op, 64-bit indexing. | ||||||
|  | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB); | ||||||
|  | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB); | ||||||
|  | 
 | ||||||
|  | // Activation/signs only for generic variant. 64-bit indexing. | ||||||
|  | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void); | ||||||
|  | template void* choose_filtered_lrelu_act_kernel<float,     false, true>(void); | ||||||
|  | template void* choose_filtered_lrelu_act_kernel<double,    false, true>(void); | ||||||
|  | 
 | ||||||
|  | // Copy filters to constant memory. | ||||||
|  | template cudaError_t copy_filters<false, true>(cudaStream_t stream); | ||||||
							
								
								
									
models/stylegan3/torch_utils/ops/filtered_lrelu_wr.cu (new file, 27 lines)
| @@ -0,0 +1,27 @@ | |||||||
|  | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | // | ||||||
|  | // NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | // and proprietary rights in and to this software, related documentation | ||||||
|  | // and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | // distribution of this software and related documentation without an express | ||||||
|  | // license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | #include "filtered_lrelu.cu" | ||||||
|  | 
 | ||||||
|  | // Template/kernel specializations for sign write mode. | ||||||
|  | 
 | ||||||
|  | // Full op, 32-bit indexing. | ||||||
|  | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB); | ||||||
|  | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB); | ||||||
|  | 
 | ||||||
|  | // Full op, 64-bit indexing. | ||||||
|  | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB); | ||||||
|  | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB); | ||||||
|  | 
 | ||||||
|  | // Activation/signs only for generic variant. 64-bit indexing. | ||||||
|  | template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void); | ||||||
|  | template void* choose_filtered_lrelu_act_kernel<float,     true, false>(void); | ||||||
|  | template void* choose_filtered_lrelu_act_kernel<double,    true, false>(void); | ||||||
|  | 
 | ||||||
|  | // Copy filters to constant memory. | ||||||
|  | template cudaError_t copy_filters<true, false>(cudaStream_t stream); | ||||||
							
								
								
									
models/stylegan3/torch_utils/ops/fma.py (new file, 60 lines)
| @@ -0,0 +1,60 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" | ||||||
|  | 
 | ||||||
|  | import torch | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def fma(a, b, c): # => a * b + c | ||||||
|  |     return _FusedMultiplyAdd.apply(a, b, c) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c | ||||||
|  |     @staticmethod | ||||||
|  |     def forward(ctx, a, b, c): # pylint: disable=arguments-differ | ||||||
|  |         out = torch.addcmul(c, a, b) | ||||||
|  |         ctx.save_for_backward(a, b) | ||||||
|  |         ctx.c_shape = c.shape | ||||||
|  |         return out | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def backward(ctx, dout): # pylint: disable=arguments-differ | ||||||
|  |         a, b = ctx.saved_tensors | ||||||
|  |         c_shape = ctx.c_shape | ||||||
|  |         da = None | ||||||
|  |         db = None | ||||||
|  |         dc = None | ||||||
|  | 
 | ||||||
|  |         if ctx.needs_input_grad[0]: | ||||||
|  |             da = _unbroadcast(dout * b, a.shape) | ||||||
|  | 
 | ||||||
|  |         if ctx.needs_input_grad[1]: | ||||||
|  |             db = _unbroadcast(dout * a, b.shape) | ||||||
|  | 
 | ||||||
|  |         if ctx.needs_input_grad[2]: | ||||||
|  |             dc = _unbroadcast(dout, c_shape) | ||||||
|  | 
 | ||||||
|  |         return da, db, dc | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def _unbroadcast(x, shape): | ||||||
|  |     extra_dims = x.ndim - len(shape) | ||||||
|  |     assert extra_dims >= 0 | ||||||
|  |     dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] | ||||||
|  |     if len(dim): | ||||||
|  |         x = x.sum(dim=dim, keepdim=True) | ||||||
|  |     if extra_dims: | ||||||
|  |         x = x.reshape(-1, *x.shape[extra_dims+1:]) | ||||||
|  |     assert x.shape == shape | ||||||
|  |     return x | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
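A short sketch of `fma()` above in use; note how `_unbroadcast` folds broadcasted gradients back to each operand's original shape (import path assumed from this commit):

```python
import torch
from torch_utils.ops.fma import fma  # assumed import path

a = torch.randn(4, 1, requires_grad=True)
b = torch.randn(1, 3, requires_grad=True)
c = torch.randn(4, 3, requires_grad=True)

out = fma(a, b, c)   # a * b + c, broadcast to (4, 3)
out.sum().backward()
print(a.grad.shape, b.grad.shape, c.grad.shape)  # (4, 1) (1, 3) (4, 3)
```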
							
								
								
									
models/stylegan3/torch_utils/ops/grid_sample_gradfix.py (new file, 86 lines)
| @@ -0,0 +1,86 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | """Custom replacement for `torch.nn.functional.grid_sample` that | ||||||
|  | supports arbitrarily high order gradients between the input and output. | ||||||
|  | Only works on 2D images and assumes | ||||||
|  | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" | ||||||
|  | 
 | ||||||
|  | import torch | ||||||
|  | from pkg_resources import parse_version | ||||||
|  | 
 | ||||||
|  | # pylint: disable=redefined-builtin | ||||||
|  | # pylint: disable=arguments-differ | ||||||
|  | # pylint: disable=protected-access | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | enabled = False  # Enable the custom op by setting this to True. | ||||||
|  | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 | ||||||
|  | _use_pytorch_1_12_api = parse_version(torch.__version__) >= parse_version('1.12.0a') # Allow prerelease builds of 1.12 | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def grid_sample(input, grid): | ||||||
|  |     if _should_use_custom_op(): | ||||||
|  |         return _GridSample2dForward.apply(input, grid) | ||||||
|  |     return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def _should_use_custom_op(): | ||||||
|  |     return enabled | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | class _GridSample2dForward(torch.autograd.Function): | ||||||
|  |     @staticmethod | ||||||
|  |     def forward(ctx, input, grid): | ||||||
|  |         assert input.ndim == 4 | ||||||
|  |         assert grid.ndim == 4 | ||||||
|  |         output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) | ||||||
|  |         ctx.save_for_backward(input, grid) | ||||||
|  |         return output | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def backward(ctx, grad_output): | ||||||
|  |         input, grid = ctx.saved_tensors | ||||||
|  |         grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) | ||||||
|  |         return grad_input, grad_grid | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | class _GridSample2dBackward(torch.autograd.Function): | ||||||
|  |     @staticmethod | ||||||
|  |     def forward(ctx, grad_output, input, grid): | ||||||
|  |         op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') | ||||||
|  |         if _use_pytorch_1_12_api: | ||||||
|  |             op = op[0] | ||||||
|  |         if _use_pytorch_1_11_api: | ||||||
|  |             output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2]) | ||||||
|  |             grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask) | ||||||
|  |         else: | ||||||
|  |             grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) | ||||||
|  |         ctx.save_for_backward(grid) | ||||||
|  |         return grad_input, grad_grid | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def backward(ctx, grad2_grad_input, grad2_grad_grid): | ||||||
|  |         _ = grad2_grad_grid # unused | ||||||
|  |         grid, = ctx.saved_tensors | ||||||
|  |         grad2_grad_output = None | ||||||
|  |         grad2_input = None | ||||||
|  |         grad2_grid = None | ||||||
|  | 
 | ||||||
|  |         if ctx.needs_input_grad[0]: | ||||||
|  |             grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) | ||||||
|  | 
 | ||||||
|  |         assert not ctx.needs_input_grad[2] | ||||||
|  |         return grad2_grad_output, grad2_input, grad2_grid | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
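The point of the wrapper above is higher-order gradients through `grid_sample`, per its docstring. A minimal R1-style double-backward sketch (shapes and loss are illustrative; import path assumed from this commit):

```python
import torch
from torch_utils.ops import grid_sample_gradfix  # assumed import path

grid_sample_gradfix.enabled = True     # opt in to the custom op

x    = torch.randn(1, 3, 8, 8, requires_grad=True)
grid = torch.zeros(1, 4, 4, 2)         # sampling grid; no gradient needed here
y    = grid_sample_gradfix.grid_sample(x, grid)

g, = torch.autograd.grad(y.square().sum(), x, create_graph=True)
g.square().sum().backward()            # the second-order step this file enables
print(x.grad.shape)                    # torch.Size([1, 3, 8, 8])
```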
							
								
								
									
models/stylegan3/torch_utils/ops/upfirdn2d.cpp (new file, 107 lines)
| @@ -0,0 +1,107 @@ | |||||||
|  | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
 | ||||||
|  | //
 | ||||||
|  | // NVIDIA CORPORATION and its licensors retain all intellectual property
 | ||||||
|  | // and proprietary rights in and to this software, related documentation
 | ||||||
|  | // and any modifications thereto.  Any use, reproduction, disclosure or
 | ||||||
|  | // distribution of this software and related documentation without an express
 | ||||||
|  | // license agreement from NVIDIA CORPORATION is strictly prohibited.
 | ||||||
|  | 
 | ||||||
|  | #include <torch/extension.h> | ||||||
|  | #include <ATen/cuda/CUDAContext.h> | ||||||
|  | #include <c10/cuda/CUDAGuard.h> | ||||||
|  | #include "upfirdn2d.h" | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------
 | ||||||
|  | 
 | ||||||
|  | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) | ||||||
|  | { | ||||||
|  |     // Validate arguments.
 | ||||||
|  |     TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); | ||||||
|  |     TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); | ||||||
|  |     TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); | ||||||
|  |     TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); | ||||||
|  |     TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); | ||||||
|  |     TORCH_CHECK(x.numel() > 0, "x has zero size"); | ||||||
|  |     TORCH_CHECK(f.numel() > 0, "f has zero size"); | ||||||
|  |     TORCH_CHECK(x.dim() == 4, "x must be rank 4"); | ||||||
|  |     TORCH_CHECK(f.dim() == 2, "f must be rank 2"); | ||||||
|  |     TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); | ||||||
|  |     TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); | ||||||
|  |     TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); | ||||||
|  |     TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); | ||||||
|  | 
 | ||||||
|  |     // Create output tensor.
 | ||||||
|  |     const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); | ||||||
|  |     int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; | ||||||
|  |     int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; | ||||||
|  |     TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); | ||||||
|  |     torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); | ||||||
|  |     TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); | ||||||
|  |     TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); | ||||||
|  | 
 | ||||||
|  |     // Initialize CUDA kernel parameters.
 | ||||||
|  |     upfirdn2d_kernel_params p; | ||||||
|  |     p.x             = x.data_ptr(); | ||||||
|  |     p.f             = f.data_ptr<float>(); | ||||||
|  |     p.y             = y.data_ptr(); | ||||||
|  |     p.up            = make_int2(upx, upy); | ||||||
|  |     p.down          = make_int2(downx, downy); | ||||||
|  |     p.pad0          = make_int2(padx0, pady0); | ||||||
|  |     p.flip          = (flip) ? 1 : 0; | ||||||
|  |     p.gain          = gain; | ||||||
|  |     p.inSize        = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); | ||||||
|  |     p.inStride      = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); | ||||||
|  |     p.filterSize    = make_int2((int)f.size(1), (int)f.size(0)); | ||||||
|  |     p.filterStride  = make_int2((int)f.stride(1), (int)f.stride(0)); | ||||||
|  |     p.outSize       = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); | ||||||
|  |     p.outStride     = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); | ||||||
|  |     p.sizeMajor     = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; | ||||||
|  |     p.sizeMinor     = (p.inStride.z == 1) ? p.inSize.z : 1; | ||||||
|  | 
 | ||||||
|  |     // Choose CUDA kernel.
 | ||||||
|  |     upfirdn2d_kernel_spec spec; | ||||||
|  |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] | ||||||
|  |     { | ||||||
|  |         spec = choose_upfirdn2d_kernel<scalar_t>(p); | ||||||
|  |     }); | ||||||
|  | 
 | ||||||
|  |     // Set looping options.
 | ||||||
|  |     p.loopMajor     = (p.sizeMajor - 1) / 16384 + 1; | ||||||
|  |     p.loopMinor     = spec.loopMinor; | ||||||
|  |     p.loopX         = spec.loopX; | ||||||
|  |     p.launchMinor   = (p.sizeMinor - 1) / p.loopMinor + 1; | ||||||
|  |     p.launchMajor   = (p.sizeMajor - 1) / p.loopMajor + 1; | ||||||
|  | 
 | ||||||
|  |     // Compute grid size.
 | ||||||
|  |     dim3 blockSize, gridSize; | ||||||
|  |     if (spec.tileOutW < 0) // large
 | ||||||
|  |     { | ||||||
|  |         blockSize = dim3(4, 32, 1); | ||||||
|  |         gridSize = dim3( | ||||||
|  |             ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, | ||||||
|  |             (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, | ||||||
|  |             p.launchMajor); | ||||||
|  |     } | ||||||
|  |     else // small
 | ||||||
|  |     { | ||||||
|  |         blockSize = dim3(256, 1, 1); | ||||||
|  |         gridSize = dim3( | ||||||
|  |             ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, | ||||||
|  |             (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, | ||||||
|  |             p.launchMajor); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Launch CUDA kernel.
 | ||||||
|  |     void* args[] = {&p}; | ||||||
|  |     AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); | ||||||
|  |     return y; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------
 | ||||||
|  | 
 | ||||||
|  | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) | ||||||
|  | { | ||||||
|  |     m.def("upfirdn2d", &upfirdn2d); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------
 | ||||||
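For reference, the wrapper's output-extent computation above, restated in Python (numbers illustrative):

```python
def upfirdn2d_out_size(in_size, up, pad0, pad1, f_taps, down):
    # Mirrors: outW = (x.size(3) * upx + padx0 + padx1 - f.size(1) + downx) / downx
    return (in_size * up + pad0 + pad1 - f_taps + down) // down

print(upfirdn2d_out_size(64, up=2, pad0=2, pad1=1, f_taps=4, down=1))  # 128: clean 2x upsample
```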
							
								
								
									
models/stylegan3/torch_utils/ops/upfirdn2d.cu (new file, 384 lines)
| @@ -0,0 +1,384 @@ | |||||||
|  | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | // | ||||||
|  | // NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | // and proprietary rights in and to this software, related documentation | ||||||
|  | // and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | // distribution of this software and related documentation without an express | ||||||
|  | // license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | #include <c10/util/Half.h> | ||||||
|  | #include "upfirdn2d.h" | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // Helpers. | ||||||
|  | 
 | ||||||
|  | template <class T> struct InternalType; | ||||||
|  | template <> struct InternalType<double>     { typedef double scalar_t; }; | ||||||
|  | template <> struct InternalType<float>      { typedef float  scalar_t; }; | ||||||
|  | template <> struct InternalType<c10::Half>  { typedef float  scalar_t; }; | ||||||
|  | 
 | ||||||
|  | static __device__ __forceinline__ int floor_div(int a, int b) | ||||||
|  | { | ||||||
|  |     int t = 1 - a / b; | ||||||
|  |     return (a + t * b) / b - t; | ||||||
|  | } | ||||||
|  | 
 | ||||||
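`floor_div` above compensates for C/CUDA integer division rounding toward zero; receptive-field coordinates can go negative under padding and need true flooring. A Python transcription (Python's `//` already floors, so truncation is emulated explicitly):

```python
def trunc_div(a, b):
    # What a / b does for ints in C/CUDA: round toward zero.
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b >= 0) else -q

def floor_div(a, b):
    # Transcription of the device helper, using truncating division throughout.
    t = 1 - trunc_div(a, b)
    return trunc_div(a + t * b, b) - t

print(trunc_div(-3, 2), floor_div(-3, 2))  # -1 vs -2: only the latter is a true floor
```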
|  | //------------------------------------------------------------------------ | ||||||
|  | // Generic CUDA implementation for large filters. | ||||||
|  | 
 | ||||||
|  | template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) | ||||||
|  | { | ||||||
|  |     typedef typename InternalType<T>::scalar_t scalar_t; | ||||||
|  | 
 | ||||||
|  |     // Calculate thread index. | ||||||
|  |     int minorBase = blockIdx.x * blockDim.x + threadIdx.x; | ||||||
|  |     int outY = minorBase / p.launchMinor; | ||||||
|  |     minorBase -= outY * p.launchMinor; | ||||||
|  |     int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; | ||||||
|  |     int majorBase = blockIdx.z * p.loopMajor; | ||||||
|  |     if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) | ||||||
|  |         return; | ||||||
|  | 
 | ||||||
|  |     // Setup Y receptive field. | ||||||
|  |     int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; | ||||||
|  |     int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); | ||||||
|  |     int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; | ||||||
|  |     int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; | ||||||
|  |     if (p.flip) | ||||||
|  |         filterY = p.filterSize.y - 1 - filterY; | ||||||
|  | 
 | ||||||
|  |     // Loop over major, minor, and X. | ||||||
|  |     for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) | ||||||
|  |     for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor) | ||||||
|  |     { | ||||||
|  |         int nc = major * p.sizeMinor + minor; | ||||||
|  |         int n = nc / p.inSize.z; | ||||||
|  |         int c = nc - n * p.inSize.z; | ||||||
|  |         for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y) | ||||||
|  |         { | ||||||
|  |             // Setup X receptive field. | ||||||
|  |             int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; | ||||||
|  |             int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); | ||||||
|  |             int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX; | ||||||
|  |             int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; | ||||||
|  |             if (p.flip) | ||||||
|  |                 filterX = p.filterSize.x - 1 - filterX; | ||||||
|  | 
 | ||||||
|  |             // Initialize pointers. | ||||||
|  |             const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; | ||||||
|  |             const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; | ||||||
|  |             int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; | ||||||
|  |             int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; | ||||||
|  | 
 | ||||||
|  |             // Inner loop. | ||||||
|  |             scalar_t v = 0; | ||||||
|  |             for (int y = 0; y < h; y++) | ||||||
|  |             { | ||||||
|  |                 for (int x = 0; x < w; x++) | ||||||
|  |                 { | ||||||
|  |                     v += (scalar_t)(*xp) * (scalar_t)(*fp); | ||||||
|  |                     xp += p.inStride.x; | ||||||
|  |                     fp += filterStepX; | ||||||
|  |                 } | ||||||
|  |                 xp += p.inStride.y - w * p.inStride.x; | ||||||
|  |                 fp += filterStepY - w * filterStepX; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // Store result. | ||||||
|  |             v *= p.gain; | ||||||
|  |             ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // Specialized CUDA implementation for small filters. | ||||||
|  | 
 | ||||||
|  | template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor> | ||||||
|  | static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) | ||||||
|  | { | ||||||
|  |     typedef typename InternalType<T>::scalar_t scalar_t; | ||||||
|  |     const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; | ||||||
|  |     const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; | ||||||
|  |     __shared__ volatile scalar_t sf[filterH][filterW]; | ||||||
|  |     __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; | ||||||
|  | 
 | ||||||
|  |     // Calculate tile index. | ||||||
|  |     int minorBase = blockIdx.x; | ||||||
|  |     int tileOutY = minorBase / p.launchMinor; | ||||||
|  |     minorBase -= tileOutY * p.launchMinor; | ||||||
|  |     minorBase *= loopMinor; | ||||||
|  |     tileOutY *= tileOutH; | ||||||
|  |     int tileOutXBase = blockIdx.y * p.loopX * tileOutW; | ||||||
|  |     int majorBase = blockIdx.z * p.loopMajor; | ||||||
|  |     if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor) | ||||||
|  |         return; | ||||||
|  | 
 | ||||||
|  |     // Load filter (flipped). | ||||||
|  |     for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x) | ||||||
|  |     { | ||||||
|  |         int fy = tapIdx / filterW; | ||||||
|  |         int fx = tapIdx - fy * filterW; | ||||||
|  |         scalar_t v = 0; | ||||||
|  |         if (fx < p.filterSize.x & fy < p.filterSize.y) | ||||||
|  |         { | ||||||
|  |             int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx; | ||||||
|  |             int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; | ||||||
|  |             v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; | ||||||
|  |         } | ||||||
|  |         sf[fy][fx] = v; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Loop over major and X. | ||||||
|  |     for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) | ||||||
|  |     { | ||||||
|  |         int baseNC = major * p.sizeMinor + minorBase; | ||||||
|  |         int n = baseNC / p.inSize.z; | ||||||
|  |         int baseC = baseNC - n * p.inSize.z; | ||||||
|  |         for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW) | ||||||
|  |         { | ||||||
|  |             // Load input pixels. | ||||||
|  |             int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; | ||||||
|  |             int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; | ||||||
|  |             int tileInX = floor_div(tileMidX, upx); | ||||||
|  |             int tileInY = floor_div(tileMidY, upy); | ||||||
|  |             __syncthreads(); | ||||||
|  |             for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x) | ||||||
|  |             { | ||||||
|  |                 int relC = inIdx; | ||||||
|  |                 int relInX = relC / loopMinor; | ||||||
|  |                 int relInY = relInX / tileInW; | ||||||
|  |                 relC -= relInX * loopMinor; | ||||||
|  |                 relInX -= relInY * tileInW; | ||||||
|  |                 int c = baseC + relC; | ||||||
|  |                 int inX = tileInX + relInX; | ||||||
|  |                 int inY = tileInY + relInY; | ||||||
|  |                 scalar_t v = 0; | ||||||
|  |                 if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z) | ||||||
|  |                     v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; | ||||||
|  |                 sx[relInY][relInX][relC] = v; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // Loop over output pixels. | ||||||
|  |             __syncthreads(); | ||||||
|  |             for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x) | ||||||
|  |             { | ||||||
|  |                 int relC = outIdx; | ||||||
|  |                 int relOutX = relC / loopMinor; | ||||||
|  |                 int relOutY = relOutX / tileOutW; | ||||||
|  |                 relC -= relOutX * loopMinor; | ||||||
|  |                 relOutX -= relOutY * tileOutW; | ||||||
|  |                 int c = baseC + relC; | ||||||
|  |                 int outX = tileOutX + relOutX; | ||||||
|  |                 int outY = tileOutY + relOutY; | ||||||
|  | 
 | ||||||
|  |                 // Setup receptive field. | ||||||
|  |                 int midX = tileMidX + relOutX * downx; | ||||||
|  |                 int midY = tileMidY + relOutY * downy; | ||||||
|  |                 int inX = floor_div(midX, upx); | ||||||
|  |                 int inY = floor_div(midY, upy); | ||||||
|  |                 int relInX = inX - tileInX; | ||||||
|  |                 int relInY = inY - tileInY; | ||||||
|  |                 int filterX = (inX + 1) * upx - midX - 1; // flipped | ||||||
|  |                 int filterY = (inY + 1) * upy - midY - 1; // flipped | ||||||
|  | 
 | ||||||
|  |                 // Inner loop. | ||||||
|  |                 if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) | ||||||
|  |                 { | ||||||
|  |                     scalar_t v = 0; | ||||||
|  |                     #pragma unroll | ||||||
|  |                     for (int y = 0; y < filterH / upy; y++) | ||||||
|  |                         #pragma unroll | ||||||
|  |                         for (int x = 0; x < filterW / upx; x++) | ||||||
|  |                             v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx]; | ||||||
|  |                     v *= p.gain; | ||||||
|  |                     ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
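The shared-memory footprint of the small-filter kernel follows from the `tileInW`/`tileInH` expressions at its top; a worked instance (numbers illustrative, matching one of the 2x-upsampling instantiations below):

```python
def tile_in(tile_out, down, f_taps, up):
    # Input pixels one output tile touches: ((tileOut-1)*down + taps - 1) // up + 1
    return ((tile_out - 1) * down + f_taps - 1) // up + 1

print(tile_in(64, down=1, f_taps=16, up=2))  # 40 input columns staged in sx[] per 64-wide tile
```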
|  | //------------------------------------------------------------------------ | ||||||
|  | // CUDA kernel selection. | ||||||
|  | 
 | ||||||
|  | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p) | ||||||
|  | { | ||||||
|  |     int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y; | ||||||
|  |     upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous | ||||||
|  |     if (s == 1)           spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last | ||||||
|  | 
 | ||||||
|  |     // No up/downsampling. | ||||||
|  |     if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) | ||||||
|  |     { | ||||||
|  |         // contiguous | ||||||
|  |         if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 64,32,1>, 64,32,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 64,32,1>, 64,32,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7,   64,16,1>, 64,16,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6,   64,16,1>, 64,16,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5,   64,16,1>, 64,16,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4,   64,16,1>, 64,16,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3,   64,16,1>, 64,16,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1,  128,8,1>, 128,8,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1,  128,8,1>, 128,8,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1,   128,8,1>, 128,8,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24,  32,32,1>, 32,32,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16,  32,32,1>, 32,32,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8,   32,32,1>, 32,32,1, 1}; | ||||||
|  |         // channels_last | ||||||
|  |         if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 32,32,1>,  32,32,1,  1}; | ||||||
|  |         if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 32,32,1>,  32,32,1,  1}; | ||||||
|  |         if (s == 1 && fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7,   16,16,8>,  16,16,8,  1}; | ||||||
|  |         if (s == 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6,   16,16,8>,  16,16,8,  1}; | ||||||
|  |         if (s == 1 && fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5,   16,16,8>,  16,16,8,  1}; | ||||||
|  |         if (s == 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4,   16,16,8>,  16,16,8,  1}; | ||||||
|  |         if (s == 1 && fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3,   16,16,8>,  16,16,8,  1}; | ||||||
|  |         if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1,  128,1,16>, 128,1,16, 1}; | ||||||
|  |         if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1,  128,1,16>, 128,1,16, 1}; | ||||||
|  |         if (s == 1 && fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1,   128,1,16>, 128,1,16, 1}; | ||||||
|  |         if (s == 1 && fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24,  1,128,16>, 1,128,16, 1}; | ||||||
|  |         if (s == 1 && fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16,  1,128,16>, 1,128,16, 1}; | ||||||
|  |         if (s == 1 && fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8,   1,128,16>, 1,128,16, 1}; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // 2x upsampling. | ||||||
|  |     if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) | ||||||
|  |     { | ||||||
|  |         // contiguous | ||||||
|  |         if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 64,32,1>, 64,32,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 64,32,1>, 64,32,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8,   64,16,1>, 64,16,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6,   64,16,1>, 64,16,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4,   64,16,1>, 64,16,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2,   64,16,1>, 64,16,1, 1}; | ||||||
|  |         // channels_last | ||||||
|  |         if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 32,32,1>, 32,32,1, 1}; | ||||||
|  |         if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 32,32,1>, 32,32,1, 1}; | ||||||
|  |         if (s == 1 && fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8,   16,16,8>, 16,16,8, 1}; | ||||||
|  |         if (s == 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6,   16,16,8>, 16,16,8, 1}; | ||||||
|  |         if (s == 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4,   16,16,8>, 16,16,8, 1}; | ||||||
|  |         if (s == 1 && fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2,   16,16,8>, 16,16,8, 1}; | ||||||
|  |     } | ||||||
|  |     if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) | ||||||
|  |     { | ||||||
|  |         // contiguous | ||||||
|  |         if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1,  128,8,1>, 128,8,1, 1}; | ||||||
|  |         // channels_last | ||||||
|  |         if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1}; | ||||||
|  |         if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1}; | ||||||
|  |         if (s == 1 && fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1,  128,1,16>, 128,1,16, 1}; | ||||||
|  |     } | ||||||
|  |     if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) | ||||||
|  |     { | ||||||
|  |         // contiguous | ||||||
|  |         if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8,  32,32,1>, 32,32,1, 1}; | ||||||
|  |         // channels_last | ||||||
|  |         if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1}; | ||||||
|  |         if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1}; | ||||||
|  |         if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8,  1,128,16>, 1,128,16, 1}; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // 2x downsampling. | ||||||
|  |     if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) | ||||||
|  |     { | ||||||
|  |         // contiguous | ||||||
|  |         if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 32,16,1>, 32,16,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 32,16,1>, 32,16,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8,   32,8,1>,  32,8,1,  1}; | ||||||
|  |         if (s != 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6,   32,8,1>,  32,8,1,  1}; | ||||||
|  |         if (s != 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4,   32,8,1>,  32,8,1,  1}; | ||||||
|  |         if (s != 1 && fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2,   32,8,1>,  32,8,1,  1}; | ||||||
|  |         // channels_last | ||||||
|  |         if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 16,16,1>, 16,16,1, 1}; | ||||||
|  |         if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 16,16,1>, 16,16,1, 1}; | ||||||
|  |         if (s == 1 && fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8,   8,8,8>,   8,8,8,   1}; | ||||||
|  |         if (s == 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6,   8,8,8>,   8,8,8,   1}; | ||||||
|  |         if (s == 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4,   8,8,8>,   8,8,8,   1}; | ||||||
|  |         if (s == 1 && fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2,   8,8,8>,   8,8,8,   1}; | ||||||
|  |     } | ||||||
|  |     if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) | ||||||
|  |     { | ||||||
|  |         // contiguous | ||||||
|  |         if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1,  64,8,1>, 64,8,1, 1}; | ||||||
|  |         // channels_last | ||||||
|  |         if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1}; | ||||||
|  |         if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1}; | ||||||
|  |         if (s == 1 && fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1,  64,1,8>, 64,1,8, 1}; | ||||||
|  |     } | ||||||
|  |     if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) | ||||||
|  |     { | ||||||
|  |         // contiguous | ||||||
|  |         if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8,  32,16,1>, 32,16,1, 1}; | ||||||
|  |         // channels_last | ||||||
|  |         if (s == 1 && fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1}; | ||||||
|  |         if (s == 1 && fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1}; | ||||||
|  |         if (s == 1 && fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8,  1,64,8>, 1,64,8, 1}; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // 4x upsampling. | ||||||
|  |     if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) | ||||||
|  |     { | ||||||
|  |         // contiguous | ||||||
|  |         if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 64,32,1>, 64,32,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 64,32,1>, 64,32,1, 1}; | ||||||
|  |         // channels_last | ||||||
|  |         if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 32,32,1>, 32,32,1, 1}; | ||||||
|  |         if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 32,32,1>, 32,32,1, 1}; | ||||||
|  |     } | ||||||
|  |     if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) | ||||||
|  |     { | ||||||
|  |         // contiguous | ||||||
|  |         if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,8,1>, 128,8,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,8,1>, 128,8,1, 1}; | ||||||
|  |         // channels_last | ||||||
|  |         if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,1,16>, 128,1,16, 1}; | ||||||
|  |         if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,1,16>, 128,1,16, 1}; | ||||||
|  |     } | ||||||
|  |     if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) | ||||||
|  |     { | ||||||
|  |         // contiguous | ||||||
|  |         if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 32,32,1>, 32,32,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 32,32,1>, 32,32,1, 1}; | ||||||
|  |         // channels_last | ||||||
|  |         if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 1,128,16>, 1,128,16, 1}; | ||||||
|  |         if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 1,128,16>, 1,128,16, 1}; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // 4x downsampling (inefficient). | ||||||
|  |     if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1) | ||||||
|  |     { | ||||||
|  |         // contiguous | ||||||
|  |         if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,8,1>, 32,8,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,8,1>, 32,8,1, 1}; | ||||||
|  |         // channels_last | ||||||
|  |         if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,1,8>, 32,1,8, 1}; | ||||||
|  |         if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,1,8>, 32,1,8, 1}; | ||||||
|  |     } | ||||||
|  |     if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4) | ||||||
|  |     { | ||||||
|  |         // contiguous | ||||||
|  |         if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 32,8,1>, 32,8,1, 1}; | ||||||
|  |         if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 32,8,1>, 32,8,1, 1}; | ||||||
|  |         // channels_last | ||||||
|  |         if (s == 1 && fx <= 1  && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 1,32,8>, 1,32,8, 1}; | ||||||
|  |         if (s == 1 && fx <= 1  && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 1,32,8>, 1,32,8, 1}; | ||||||
|  |     } | ||||||
|  |     return spec; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // Template specializations. | ||||||
|  | 
 | ||||||
|  | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double>   (const upfirdn2d_kernel_params& p); | ||||||
|  | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float>    (const upfirdn2d_kernel_params& p); | ||||||
|  | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p); | ||||||
|  | 
 | ||||||
|  | //------------------------------------------------------------------------ | ||||||
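Read in plain terms, the selector above starts from the generic `upfirdn2d_kernel_large` fallback and lets progressively tighter `upfirdn2d_kernel_small` instantiations overwrite the choice; because the `if` chain runs from the loosest filter bound to the tightest, the last applicable match wins. Below is a rough Python paraphrase of that dispatch rule, restricted to the up=1/down=1 case; `choose_kernel_sketch` and its bounds table are illustrative names, not part of this codebase.

```python
# Illustrative paraphrase of choose_upfirdn2d_kernel() for up=1, down=1.
def choose_kernel_sketch(channels_last, fw, fh):
    spec = "large"  # generic fallback that handles any configuration
    # Bounds mirror the template instantiations above, ordered loosest to
    # tightest; each later match overwrites the previous one.
    for bw, bh in [(24, 24), (16, 16), (7, 7), (6, 6), (5, 5), (4, 4), (3, 3),
                   (24, 1), (16, 1), (8, 1), (1, 24), (1, 16), (1, 8)]:
        if fw <= bw and fh <= bh:
            layout = "channels_last" if channels_last else "contiguous"
            spec = f"small<{bw}x{bh}, {layout}>"
    return spec

print(choose_kernel_sketch(False, 4, 4))  # small<4x4, contiguous>
```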
models/stylegan3/torch_utils/ops/upfirdn2d.h (new file, 59 lines)
							| @ -0,0 +1,59 @@ | |||||||
|  | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | // | ||||||
|  | // NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | // and proprietary rights in and to this software, related documentation | ||||||
|  | // and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | // distribution of this software and related documentation without an express | ||||||
|  | // license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  |  | ||||||
|  | #include <cuda_runtime.h> | ||||||
|  |  | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // CUDA kernel parameters. | ||||||
|  |  | ||||||
|  | struct upfirdn2d_kernel_params | ||||||
|  | { | ||||||
|  |     const void*     x; | ||||||
|  |     const float*    f; | ||||||
|  |     void*           y; | ||||||
|  |  | ||||||
|  |     int2            up; | ||||||
|  |     int2            down; | ||||||
|  |     int2            pad0; | ||||||
|  |     int             flip; | ||||||
|  |     float           gain; | ||||||
|  |  | ||||||
|  |     int4            inSize;         // [width, height, channel, batch] | ||||||
|  |     int4            inStride; | ||||||
|  |     int2            filterSize;     // [width, height] | ||||||
|  |     int2            filterStride; | ||||||
|  |     int4            outSize;        // [width, height, channel, batch] | ||||||
|  |     int4            outStride; | ||||||
|  |     int             sizeMinor; | ||||||
|  |     int             sizeMajor; | ||||||
|  |  | ||||||
|  |     int             loopMinor; | ||||||
|  |     int             loopMajor; | ||||||
|  |     int             loopX; | ||||||
|  |     int             launchMinor; | ||||||
|  |     int             launchMajor; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // CUDA kernel specialization. | ||||||
|  |  | ||||||
|  | struct upfirdn2d_kernel_spec | ||||||
|  | { | ||||||
|  |     void*   kernel; | ||||||
|  |     int     tileOutW; | ||||||
|  |     int     tileOutH; | ||||||
|  |     int     loopMinor; | ||||||
|  |     int     loopX; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | //------------------------------------------------------------------------ | ||||||
|  | // CUDA kernel selection. | ||||||
|  |  | ||||||
|  | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); | ||||||
|  |  | ||||||
|  | //------------------------------------------------------------------------ | ||||||
models/stylegan3/torch_utils/ops/upfirdn2d.py (new file, 389 lines)
							| @ -0,0 +1,389 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | """Custom PyTorch ops for efficient resampling of 2D images.""" | ||||||
|  | 
 | ||||||
|  | import os | ||||||
|  | import numpy as np | ||||||
|  | import torch | ||||||
|  | 
 | ||||||
|  | from .. import custom_ops | ||||||
|  | from .. import misc | ||||||
|  | from . import conv2d_gradfix | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | _plugin = None | ||||||
|  | 
 | ||||||
|  | def _init(): | ||||||
|  |     global _plugin | ||||||
|  |     if _plugin is None: | ||||||
|  |         _plugin = custom_ops.get_plugin( | ||||||
|  |             module_name='upfirdn2d_plugin', | ||||||
|  |             sources=['upfirdn2d.cpp', 'upfirdn2d.cu'], | ||||||
|  |             headers=['upfirdn2d.h'], | ||||||
|  |             source_dir=os.path.dirname(__file__), | ||||||
|  |             extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'], | ||||||
|  |         ) | ||||||
|  |     return True | ||||||
|  | 
 | ||||||
|  | def _parse_scaling(scaling): | ||||||
|  |     if isinstance(scaling, int): | ||||||
|  |         scaling = [scaling, scaling] | ||||||
|  |     assert isinstance(scaling, (list, tuple)) | ||||||
|  |     assert all(isinstance(x, int) for x in scaling) | ||||||
|  |     sx, sy = scaling | ||||||
|  |     assert sx >= 1 and sy >= 1 | ||||||
|  |     return sx, sy | ||||||
|  | 
 | ||||||
|  | def _parse_padding(padding): | ||||||
|  |     if isinstance(padding, int): | ||||||
|  |         padding = [padding, padding] | ||||||
|  |     assert isinstance(padding, (list, tuple)) | ||||||
|  |     assert all(isinstance(x, int) for x in padding) | ||||||
|  |     if len(padding) == 2: | ||||||
|  |         padx, pady = padding | ||||||
|  |         padding = [padx, padx, pady, pady] | ||||||
|  |     padx0, padx1, pady0, pady1 = padding | ||||||
|  |     return padx0, padx1, pady0, pady1 | ||||||
|  | 
 | ||||||
|  | def _get_filter_size(f): | ||||||
|  |     if f is None: | ||||||
|  |         return 1, 1 | ||||||
|  |     assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] | ||||||
|  |     fw = f.shape[-1] | ||||||
|  |     fh = f.shape[0] | ||||||
|  |     with misc.suppress_tracer_warnings(): | ||||||
|  |         fw = int(fw) | ||||||
|  |         fh = int(fh) | ||||||
|  |     misc.assert_shape(f, [fh, fw][:f.ndim]) | ||||||
|  |     assert fw >= 1 and fh >= 1 | ||||||
|  |     return fw, fh | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
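To make the conventions above concrete, here is how the private parsing helpers normalize their arguments. The values are traced by hand from the code above; these are internal functions shown only for clarity.

```python
_parse_scaling(2)        # -> (2, 2)   a single int is broadcast to both axes
_parse_scaling([2, 1])   # -> (2, 1)
_parse_padding(3)        # -> (3, 3, 3, 3)
_parse_padding([1, 2])   # -> (1, 1, 2, 2)   [x, y] expands to [x0, x1, y0, y1]
```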
|  | 
 | ||||||
|  | def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): | ||||||
|  |     r"""Convenience function to set up a 2D FIR filter for `upfirdn2d()`. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         f:           Torch tensor, numpy array, or python list of the shape | ||||||
|  |                      `[filter_height, filter_width]` (non-separable), | ||||||
|  |                      `[filter_taps]` (separable), | ||||||
|  |                      `[]` (impulse), or | ||||||
|  |                      `None` (identity). | ||||||
|  |         device:      Result device (default: cpu). | ||||||
|  |         normalize:   Normalize the filter so that it retains the magnitude | ||||||
|  |                      for constant input signal (DC)? (default: True). | ||||||
|  |         flip_filter: Flip the filter? (default: False). | ||||||
|  |         gain:        Overall scaling factor for signal magnitude (default: 1). | ||||||
|  |         separable:   Return a separable filter? (default: select automatically). | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         Float32 tensor of the shape | ||||||
|  |         `[filter_height, filter_width]` (non-separable) or | ||||||
|  |         `[filter_taps]` (separable). | ||||||
|  |     """ | ||||||
|  |     # Validate. | ||||||
|  |     if f is None: | ||||||
|  |         f = 1 | ||||||
|  |     f = torch.as_tensor(f, dtype=torch.float32) | ||||||
|  |     assert f.ndim in [0, 1, 2] | ||||||
|  |     assert f.numel() > 0 | ||||||
|  |     if f.ndim == 0: | ||||||
|  |         f = f[np.newaxis] | ||||||
|  | 
 | ||||||
|  |     # Separable? | ||||||
|  |     if separable is None: | ||||||
|  |         separable = (f.ndim == 1 and f.numel() >= 8) | ||||||
|  |     if f.ndim == 1 and not separable: | ||||||
|  |         f = f.ger(f) | ||||||
|  |     assert f.ndim == (1 if separable else 2) | ||||||
|  | 
 | ||||||
|  |     # Apply normalize, flip, gain, and device. | ||||||
|  |     if normalize: | ||||||
|  |         f /= f.sum() | ||||||
|  |     if flip_filter: | ||||||
|  |         f = f.flip(list(range(f.ndim))) | ||||||
|  |     f = f * (gain ** (f.ndim / 2)) | ||||||
|  |     f = f.to(device=device) | ||||||
|  |     return f | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
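A quick illustration of the separability rule, as a usage sketch that assumes the module is importable under the path used in this repository:

```python
import torch
from torch_utils.ops import upfirdn2d

f4 = upfirdn2d.setup_filter([1, 3, 3, 1])
f4.shape   # torch.Size([4, 4]): fewer than 8 taps, so expanded via outer product
f4.sum()   # tensor(1.), because normalize=True runs after the expansion

f8 = upfirdn2d.setup_filter([1, 2, 3, 4, 4, 3, 2, 1])
f8.shape   # torch.Size([8]): >= 8 taps, kept separable by default
```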
|  | 
 | ||||||
|  | def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): | ||||||
|  |     r"""Pad, upsample, filter, and downsample a batch of 2D images. | ||||||
|  | 
 | ||||||
|  |     Performs the following sequence of operations for each channel: | ||||||
|  | 
 | ||||||
|  |     1. Upsample the image by inserting N-1 zeros after each pixel (`up`). | ||||||
|  | 
 | ||||||
|  |     2. Pad the image with the specified number of zeros on each side (`padding`). | ||||||
|  |        Negative padding corresponds to cropping the image. | ||||||
|  | 
 | ||||||
|  |     3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it | ||||||
|  |        so that the footprint of all output pixels lies within the input image. | ||||||
|  | 
 | ||||||
|  |     4. Downsample the image by keeping every Nth pixel (`down`). | ||||||
|  | 
 | ||||||
|  |     This sequence of operations bears close resemblance to scipy.signal.upfirdn(). | ||||||
|  |     The fused op is considerably more efficient than performing the same calculation | ||||||
|  |     using standard PyTorch ops. It supports gradients of arbitrary order. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         x:           Float32/float64/float16 input tensor of the shape | ||||||
|  |                      `[batch_size, num_channels, in_height, in_width]`. | ||||||
|  |         f:           Float32 FIR filter of the shape | ||||||
|  |                      `[filter_height, filter_width]` (non-separable), | ||||||
|  |                      `[filter_taps]` (separable), or | ||||||
|  |                      `None` (identity). | ||||||
|  |         up:          Integer upsampling factor. Can be a single int or a list/tuple | ||||||
|  |                      `[x, y]` (default: 1). | ||||||
|  |         down:        Integer downsampling factor. Can be a single int or a list/tuple | ||||||
|  |                      `[x, y]` (default: 1). | ||||||
|  |         padding:     Padding with respect to the upsampled image. Can be a single number | ||||||
|  |                      or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` | ||||||
|  |                      (default: 0). | ||||||
|  |         flip_filter: False = convolution, True = correlation (default: False). | ||||||
|  |         gain:        Overall scaling factor for signal magnitude (default: 1). | ||||||
|  |         impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. | ||||||
|  |     """ | ||||||
|  |     assert isinstance(x, torch.Tensor) | ||||||
|  |     assert impl in ['ref', 'cuda'] | ||||||
|  |     if impl == 'cuda' and x.device.type == 'cuda' and _init(): | ||||||
|  |         return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) | ||||||
|  |     return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
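Putting the four steps together, the output size per axis works out to `out = (in*up + pad0 + pad1 - taps) // down + 1`. A minimal usage sketch follows; `impl='ref'` is used so the snippet runs without building the CUDA plugin, and `gain=4` compensates for the zeros inserted by 2x upsampling in both axes:

```python
import torch
from torch_utils.ops import upfirdn2d

x = torch.randn(1, 3, 8, 8)
f = upfirdn2d.setup_filter([1, 3, 3, 1])   # 4x4 after the outer product

# 2x upsample: (8*2 + 2 + 1 - 4) // 1 + 1 = 16 per axis.
y = upfirdn2d.upfirdn2d(x, f, up=2, padding=[2, 1, 2, 1], gain=4, impl='ref')
print(y.shape)   # torch.Size([1, 3, 16, 16])
```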
|  | 
 | ||||||
|  | @misc.profiled_function | ||||||
|  | def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): | ||||||
|  |     """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. | ||||||
|  |     """ | ||||||
|  |     # Validate arguments. | ||||||
|  |     assert isinstance(x, torch.Tensor) and x.ndim == 4 | ||||||
|  |     if f is None: | ||||||
|  |         f = torch.ones([1, 1], dtype=torch.float32, device=x.device) | ||||||
|  |     assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] | ||||||
|  |     assert f.dtype == torch.float32 and not f.requires_grad | ||||||
|  |     batch_size, num_channels, in_height, in_width = x.shape | ||||||
|  |     upx, upy = _parse_scaling(up) | ||||||
|  |     downx, downy = _parse_scaling(down) | ||||||
|  |     padx0, padx1, pady0, pady1 = _parse_padding(padding) | ||||||
|  | 
 | ||||||
|  |     # Check that upsampled buffer is not smaller than the filter. | ||||||
|  |     upW = in_width * upx + padx0 + padx1 | ||||||
|  |     upH = in_height * upy + pady0 + pady1 | ||||||
|  |     assert upW >= f.shape[-1] and upH >= f.shape[0] | ||||||
|  | 
 | ||||||
|  |     # Upsample by inserting zeros. | ||||||
|  |     x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) | ||||||
|  |     x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) | ||||||
|  |     x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) | ||||||
|  | 
 | ||||||
|  |     # Pad or crop. | ||||||
|  |     x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) | ||||||
|  |     x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] | ||||||
|  | 
 | ||||||
|  |     # Setup filter. | ||||||
|  |     f = f * (gain ** (f.ndim / 2)) | ||||||
|  |     f = f.to(x.dtype) | ||||||
|  |     if not flip_filter: | ||||||
|  |         f = f.flip(list(range(f.ndim))) | ||||||
|  | 
 | ||||||
|  |     # Convolve with the filter. | ||||||
|  |     f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) | ||||||
|  |     if f.ndim == 4: | ||||||
|  |         x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) | ||||||
|  |     else: | ||||||
|  |         x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) | ||||||
|  |         x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) | ||||||
|  | 
 | ||||||
|  |     # Downsample by throwing away pixels. | ||||||
|  |     x = x[:, :, ::downy, ::downx] | ||||||
|  |     return x | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
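Because the reference path is built entirely from standard differentiable PyTorch ops, it can be used to sanity-check gradients without compiling the plugin. A sketch, with arbitrary tolerances:

```python
import torch
from torch_utils.ops import upfirdn2d

x = torch.randn(1, 1, 6, 6, dtype=torch.float64, requires_grad=True)
f = torch.ones([3, 3]) / 9   # float32 box filter, as the ref path requires

# Returns True if the analytic and numeric gradients agree.
torch.autograd.gradcheck(
    lambda t: upfirdn2d.upfirdn2d(t, f, up=2, down=2, padding=1, impl='ref'),
    x, eps=1e-4, atol=1e-3)
```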
|  | 
 | ||||||
|  | _upfirdn2d_cuda_cache = dict() | ||||||
|  | 
 | ||||||
|  | def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): | ||||||
|  |     """Fast CUDA implementation of `upfirdn2d()` using custom ops. | ||||||
|  |     """ | ||||||
|  |     # Parse arguments. | ||||||
|  |     upx, upy = _parse_scaling(up) | ||||||
|  |     downx, downy = _parse_scaling(down) | ||||||
|  |     padx0, padx1, pady0, pady1 = _parse_padding(padding) | ||||||
|  | 
 | ||||||
|  |     # Lookup from cache. | ||||||
|  |     key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) | ||||||
|  |     if key in _upfirdn2d_cuda_cache: | ||||||
|  |         return _upfirdn2d_cuda_cache[key] | ||||||
|  | 
 | ||||||
|  |     # Forward op. | ||||||
|  |     class Upfirdn2dCuda(torch.autograd.Function): | ||||||
|  |         @staticmethod | ||||||
|  |         def forward(ctx, x, f): # pylint: disable=arguments-differ | ||||||
|  |             assert isinstance(x, torch.Tensor) and x.ndim == 4 | ||||||
|  |             if f is None: | ||||||
|  |                 f = torch.ones([1, 1], dtype=torch.float32, device=x.device) | ||||||
|  |             if f.ndim == 1 and f.shape[0] == 1: | ||||||
|  |                 f = f.square().unsqueeze(0) # Convert separable-1 into full-1x1. | ||||||
|  |             assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] | ||||||
|  |             y = x | ||||||
|  |             if f.ndim == 2: | ||||||
|  |                 y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) | ||||||
|  |             else: | ||||||
|  |                 y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, 1.0) | ||||||
|  |                 y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, gain) | ||||||
|  |             ctx.save_for_backward(f) | ||||||
|  |             ctx.x_shape = x.shape | ||||||
|  |             return y | ||||||
|  | 
 | ||||||
|  |         @staticmethod | ||||||
|  |         def backward(ctx, dy): # pylint: disable=arguments-differ | ||||||
|  |             f, = ctx.saved_tensors | ||||||
|  |             _, _, ih, iw = ctx.x_shape | ||||||
|  |             _, _, oh, ow = dy.shape | ||||||
|  |             fw, fh = _get_filter_size(f) | ||||||
|  |             p = [ | ||||||
|  |                 fw - padx0 - 1, | ||||||
|  |                 iw * upx - ow * downx + padx0 - upx + 1, | ||||||
|  |                 fh - pady0 - 1, | ||||||
|  |                 ih * upy - oh * downy + pady0 - upy + 1, | ||||||
|  |             ] | ||||||
|  |             dx = None | ||||||
|  |             df = None | ||||||
|  | 
 | ||||||
|  |             if ctx.needs_input_grad[0]: | ||||||
|  |                 dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) | ||||||
|  | 
 | ||||||
|  |             assert not ctx.needs_input_grad[1] | ||||||
|  |             return dx, df | ||||||
|  | 
 | ||||||
|  |     # Add to cache. | ||||||
|  |     _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda | ||||||
|  |     return Upfirdn2dCuda | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
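One consequence of the cache above is that repeated calls with identical settings return the very same `autograd.Function` subclass, so constructing the op inside a training loop is cheap. An illustrative use of this private helper (defining the class does not touch the CUDA plugin; that only happens in `apply()`):

```python
from torch_utils.ops import upfirdn2d

op_a = upfirdn2d._upfirdn2d_cuda(up=2, padding=[2, 1, 2, 1])
op_b = upfirdn2d._upfirdn2d_cuda(up=2, padding=[2, 1, 2, 1])
assert op_a is op_b   # same cache key -> same cached class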
|  | 
 | ||||||
|  | def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): | ||||||
|  |     r"""Filter a batch of 2D images using the given 2D FIR filter. | ||||||
|  | 
 | ||||||
|  |     By default, the result is padded so that its shape matches the input. | ||||||
|  |     User-specified padding is applied on top of that, with negative values | ||||||
|  |     indicating cropping. Pixels outside the image are assumed to be zero. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         x:           Float32/float64/float16 input tensor of the shape | ||||||
|  |                      `[batch_size, num_channels, in_height, in_width]`. | ||||||
|  |         f:           Float32 FIR filter of the shape | ||||||
|  |                      `[filter_height, filter_width]` (non-separable), | ||||||
|  |                      `[filter_taps]` (separable), or | ||||||
|  |                      `None` (identity). | ||||||
|  |         padding:     Padding with respect to the output. Can be a single number or a | ||||||
|  |                      list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` | ||||||
|  |                      (default: 0). | ||||||
|  |         flip_filter: False = convolution, True = correlation (default: False). | ||||||
|  |         gain:        Overall scaling factor for signal magnitude (default: 1). | ||||||
|  |         impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. | ||||||
|  |     """ | ||||||
|  |     padx0, padx1, pady0, pady1 = _parse_padding(padding) | ||||||
|  |     fw, fh = _get_filter_size(f) | ||||||
|  |     p = [ | ||||||
|  |         padx0 + fw // 2, | ||||||
|  |         padx1 + (fw - 1) // 2, | ||||||
|  |         pady0 + fh // 2, | ||||||
|  |         pady1 + (fh - 1) // 2, | ||||||
|  |     ] | ||||||
|  |     return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
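For example, the default padding keeps the spatial size intact, as this sketch shows:

```python
import torch
from torch_utils.ops import upfirdn2d

x = torch.randn(2, 3, 32, 32)
f = upfirdn2d.setup_filter(torch.ones([3, 3]))   # normalized 3x3 box blur
y = upfirdn2d.filter2d(x, f, impl='ref')
print(y.shape)   # torch.Size([2, 3, 32, 32]): fw//2 and (fw-1)//2 padding preserve the size
```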
|  | 
 | ||||||
|  | def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): | ||||||
|  |     r"""Upsample a batch of 2D images using the given 2D FIR filter. | ||||||
|  | 
 | ||||||
|  |     By default, the result is padded so that its shape is `up` times the input shape. | ||||||
|  |     User-specified padding is applied on top of that, with negative values | ||||||
|  |     indicating cropping. Pixels outside the image are assumed to be zero. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         x:           Float32/float64/float16 input tensor of the shape | ||||||
|  |                      `[batch_size, num_channels, in_height, in_width]`. | ||||||
|  |         f:           Float32 FIR filter of the shape | ||||||
|  |                      `[filter_height, filter_width]` (non-separable), | ||||||
|  |                      `[filter_taps]` (separable), or | ||||||
|  |                      `None` (identity). | ||||||
|  |         up:          Integer upsampling factor. Can be a single int or a list/tuple | ||||||
|  |                      `[x, y]` (default: 2). | ||||||
|  |         padding:     Padding with respect to the output. Can be a single number or a | ||||||
|  |                      list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` | ||||||
|  |                      (default: 0). | ||||||
|  |         flip_filter: False = convolution, True = correlation (default: False). | ||||||
|  |         gain:        Overall scaling factor for signal magnitude (default: 1). | ||||||
|  |         impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. | ||||||
|  |     """ | ||||||
|  |     upx, upy = _parse_scaling(up) | ||||||
|  |     padx0, padx1, pady0, pady1 = _parse_padding(padding) | ||||||
|  |     fw, fh = _get_filter_size(f) | ||||||
|  |     p = [ | ||||||
|  |         padx0 + (fw + upx - 1) // 2, | ||||||
|  |         padx1 + (fw - upx) // 2, | ||||||
|  |         pady0 + (fh + upy - 1) // 2, | ||||||
|  |         pady1 + (fh - upy) // 2, | ||||||
|  |     ] | ||||||
|  |     return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
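A usage sketch (again on the plugin-free reference path):

```python
import torch
from torch_utils.ops import upfirdn2d

x = torch.randn(1, 3, 16, 16)
f = upfirdn2d.setup_filter([1, 3, 3, 1])
hi = upfirdn2d.upsample2d(x, f, up=2, impl='ref')
print(hi.shape)   # torch.Size([1, 3, 32, 32]); gain is scaled by up*up internally
```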
|  | 
 | ||||||
|  | def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): | ||||||
|  |     r"""Downsample a batch of 2D images using the given 2D FIR filter. | ||||||
|  | 
 | ||||||
|  |     By default, the result is padded so that its shape is the input shape divided by `down`. | ||||||
|  |     User-specified padding is applied on top of that, with negative values | ||||||
|  |     indicating cropping. Pixels outside the image are assumed to be zero. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         x:           Float32/float64/float16 input tensor of the shape | ||||||
|  |                      `[batch_size, num_channels, in_height, in_width]`. | ||||||
|  |         f:           Float32 FIR filter of the shape | ||||||
|  |                      `[filter_height, filter_width]` (non-separable), | ||||||
|  |                      `[filter_taps]` (separable), or | ||||||
|  |                      `None` (identity). | ||||||
|  |         down:        Integer downsampling factor. Can be a single int or a list/tuple | ||||||
|  |                      `[x, y]` (default: 2). | ||||||
|  |         padding:     Padding with respect to the input. Can be a single number or a | ||||||
|  |                      list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` | ||||||
|  |                      (default: 0). | ||||||
|  |         flip_filter: False = convolution, True = correlation (default: False). | ||||||
|  |         gain:        Overall scaling factor for signal magnitude (default: 1). | ||||||
|  |         impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. | ||||||
|  |     """ | ||||||
|  |     downx, downy = _parse_scaling(down) | ||||||
|  |     padx0, padx1, pady0, pady1 = _parse_padding(padding) | ||||||
|  |     fw, fh = _get_filter_size(f) | ||||||
|  |     p = [ | ||||||
|  |         padx0 + (fw - downx + 1) // 2, | ||||||
|  |         padx1 + (fw - downx) // 2, | ||||||
|  |         pady0 + (fh - downy + 1) // 2, | ||||||
|  |         pady1 + (fh - downy) // 2, | ||||||
|  |     ] | ||||||
|  |     return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
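And the matching downsampling direction, sketched with the same filter:

```python
import torch
from torch_utils.ops import upfirdn2d

f = upfirdn2d.setup_filter([1, 3, 3, 1])
hi = torch.randn(1, 3, 32, 32)
lo = upfirdn2d.downsample2d(hi, f, down=2, impl='ref')
print(lo.shape)   # torch.Size([1, 3, 16, 16])
```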
models/stylegan3/torch_utils/persistence.py (new file, 251 lines)
							| @ -0,0 +1,251 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | """Facilities for pickling Python code alongside other data. | ||||||
|  | 
 | ||||||
|  | The pickled code is automatically imported into a separate Python module | ||||||
|  | during unpickling. This way, any previously exported pickles will remain | ||||||
|  | usable even if the original code is no longer available, or if the current | ||||||
|  | version of the code is not consistent with what was originally pickled.""" | ||||||
|  | 
 | ||||||
|  | import sys | ||||||
|  | import pickle | ||||||
|  | import io | ||||||
|  | import inspect | ||||||
|  | import copy | ||||||
|  | import uuid | ||||||
|  | import types | ||||||
|  | import dnnlib | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | _version            = 6         # internal version number | ||||||
|  | _decorators         = set()     # {decorator_class, ...} | ||||||
|  | _import_hooks       = []        # [hook_function, ...] | ||||||
|  | _module_to_src_dict = dict()    # {module: src, ...} | ||||||
|  | _src_to_module_dict = dict()    # {src: module, ...} | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def persistent_class(orig_class): | ||||||
|  |     r"""Class decorator that extends a given class to save its source code | ||||||
|  |     when pickled. | ||||||
|  | 
 | ||||||
|  |     Example: | ||||||
|  | 
 | ||||||
|  |         from torch_utils import persistence | ||||||
|  | 
 | ||||||
|  |         @persistence.persistent_class | ||||||
|  |         class MyNetwork(torch.nn.Module): | ||||||
|  |             def __init__(self, num_inputs, num_outputs): | ||||||
|  |                 super().__init__() | ||||||
|  |                 self.fc = MyLayer(num_inputs, num_outputs) | ||||||
|  |                 ... | ||||||
|  | 
 | ||||||
|  |         @persistence.persistent_class | ||||||
|  |         class MyLayer(torch.nn.Module): | ||||||
|  |             ... | ||||||
|  | 
 | ||||||
|  |     When pickled, any instance of `MyNetwork` and `MyLayer` will save its | ||||||
|  |     source code alongside other internal state (e.g., parameters, buffers, | ||||||
|  |     and submodules). This way, any previously exported pickle will remain | ||||||
|  |     usable even if the class definitions have been modified or are no | ||||||
|  |     longer available. | ||||||
|  | 
 | ||||||
|  |     The decorator saves the source code of the entire Python module | ||||||
|  |     containing the decorated class. It does *not* save the source code of | ||||||
|  |     any imported modules. Thus, the imported modules must be available | ||||||
|  |     during unpickling, including `torch_utils.persistence` itself. | ||||||
|  | 
 | ||||||
|  |     It is ok to call functions defined in the same module from the | ||||||
|  |     decorated class. However, if the decorated class depends on other | ||||||
|  |     classes defined in the same module, they must be decorated as well. | ||||||
|  |     This is illustrated in the above example in the case of `MyLayer`. | ||||||
|  | 
 | ||||||
|  |     It is also possible to employ the decorator just-in-time before | ||||||
|  |     calling the constructor. For example: | ||||||
|  | 
 | ||||||
|  |         cls = MyLayer | ||||||
|  |         if want_to_make_it_persistent: | ||||||
|  |             cls = persistence.persistent_class(cls) | ||||||
|  |         layer = cls(num_inputs, num_outputs) | ||||||
|  | 
 | ||||||
|  |     As an additional feature, the decorator also keeps track of the | ||||||
|  |     arguments that were used to construct each instance of the decorated | ||||||
|  |     class. The arguments can be queried via `obj.init_args` and | ||||||
|  |     `obj.init_kwargs`, and they are automatically pickled alongside other | ||||||
|  |     object state. A typical use case is to first unpickle a previous | ||||||
|  |     instance of a persistent class, and then upgrade it to use the latest | ||||||
|  |     version of the source code: | ||||||
|  | 
 | ||||||
|  |         with open('old_pickle.pkl', 'rb') as f: | ||||||
|  |             old_net = pickle.load(f) | ||||||
|  |         new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs) | ||||||
|  |         misc.copy_params_and_buffers(old_net, new_net, require_all=True) | ||||||
|  |     """ | ||||||
|  |     assert isinstance(orig_class, type) | ||||||
|  |     if is_persistent(orig_class): | ||||||
|  |         return orig_class | ||||||
|  | 
 | ||||||
|  |     assert orig_class.__module__ in sys.modules | ||||||
|  |     orig_module = sys.modules[orig_class.__module__] | ||||||
|  |     orig_module_src = _module_to_src(orig_module) | ||||||
|  | 
 | ||||||
|  |     class Decorator(orig_class): | ||||||
|  |         _orig_module_src = orig_module_src | ||||||
|  |         _orig_class_name = orig_class.__name__ | ||||||
|  | 
 | ||||||
|  |         def __init__(self, *args, **kwargs): | ||||||
|  |             super().__init__(*args, **kwargs) | ||||||
|  |             self._init_args = copy.deepcopy(args) | ||||||
|  |             self._init_kwargs = copy.deepcopy(kwargs) | ||||||
|  |             assert orig_class.__name__ in orig_module.__dict__ | ||||||
|  |             _check_pickleable(self.__reduce__()) | ||||||
|  | 
 | ||||||
|  |         @property | ||||||
|  |         def init_args(self): | ||||||
|  |             return copy.deepcopy(self._init_args) | ||||||
|  | 
 | ||||||
|  |         @property | ||||||
|  |         def init_kwargs(self): | ||||||
|  |             return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) | ||||||
|  | 
 | ||||||
|  |         def __reduce__(self): | ||||||
|  |             fields = list(super().__reduce__()) | ||||||
|  |             fields += [None] * max(3 - len(fields), 0) | ||||||
|  |             if fields[0] is not _reconstruct_persistent_obj: | ||||||
|  |                 meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) | ||||||
|  |                 fields[0] = _reconstruct_persistent_obj # reconstruct func | ||||||
|  |                 fields[1] = (meta,) # reconstruct args | ||||||
|  |                 fields[2] = None # state dict | ||||||
|  |             return tuple(fields) | ||||||
|  | 
 | ||||||
|  |     Decorator.__name__ = orig_class.__name__ | ||||||
|  |     _decorators.add(Decorator) | ||||||
|  |     return Decorator | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
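A minimal round-trip sketch. It assumes the class is defined in an importable module file (so `inspect.getsource()` can see it) and that `dnnlib` is on the path, as in this repository's layout:

```python
import pickle
import torch
from torch_utils import persistence

@persistence.persistent_class
class Scaler(torch.nn.Module):
    def __init__(self, gain):
        super().__init__()
        self.gain = gain
    def forward(self, x):
        return x * self.gain

blob  = pickle.dumps(Scaler(2.0))
clone = pickle.loads(blob)   # carries its own source; survives later edits to Scaler
print(clone.init_args)       # (2.0,)
```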
|  | 
 | ||||||
|  | def is_persistent(obj): | ||||||
|  |     r"""Test whether the given object or class is persistent, i.e., | ||||||
|  |     whether it will save its source code when pickled. | ||||||
|  |     """ | ||||||
|  |     try: | ||||||
|  |         if obj in _decorators: | ||||||
|  |             return True | ||||||
|  |     except TypeError: | ||||||
|  |         pass | ||||||
|  |     return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def import_hook(hook): | ||||||
|  |     r"""Register an import hook that is called whenever a persistent object | ||||||
|  |     is being unpickled. A typical use case is to patch the pickled source | ||||||
|  |     code to avoid errors and inconsistencies when the API of some imported | ||||||
|  |     module has changed. | ||||||
|  | 
 | ||||||
|  |     The hook should have the following signature: | ||||||
|  | 
 | ||||||
|  |         hook(meta) -> modified meta | ||||||
|  | 
 | ||||||
|  |     `meta` is an instance of `dnnlib.EasyDict` with the following fields: | ||||||
|  | 
 | ||||||
|  |         type:       Type of the persistent object, e.g. `'class'`. | ||||||
|  |         version:    Internal version number of `torch_utils.persistence`. | ||||||
|  |         module_src: Original source code of the Python module. | ||||||
|  |         class_name: Class name in the original Python module. | ||||||
|  |         state:      Internal state of the object. | ||||||
|  | 
 | ||||||
|  |     Example: | ||||||
|  | 
 | ||||||
|  |         @persistence.import_hook | ||||||
|  |         def wreck_my_network(meta): | ||||||
|  |             if meta.class_name == 'MyNetwork': | ||||||
|  |                 print('MyNetwork is being imported. I will wreck it!') | ||||||
|  |                 meta.module_src = meta.module_src.replace("True", "False") | ||||||
|  |             return meta | ||||||
|  |     """ | ||||||
|  |     assert callable(hook) | ||||||
|  |     _import_hooks.append(hook) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def _reconstruct_persistent_obj(meta): | ||||||
|  |     r"""Hook that is called internally by the `pickle` module to unpickle | ||||||
|  |     a persistent object. | ||||||
|  |     """ | ||||||
|  |     meta = dnnlib.EasyDict(meta) | ||||||
|  |     meta.state = dnnlib.EasyDict(meta.state) | ||||||
|  |     for hook in _import_hooks: | ||||||
|  |         meta = hook(meta) | ||||||
|  |         assert meta is not None | ||||||
|  | 
 | ||||||
|  |     assert meta.version == _version | ||||||
|  |     module = _src_to_module(meta.module_src) | ||||||
|  | 
 | ||||||
|  |     assert meta.type == 'class' | ||||||
|  |     orig_class = module.__dict__[meta.class_name] | ||||||
|  |     decorator_class = persistent_class(orig_class) | ||||||
|  |     obj = decorator_class.__new__(decorator_class) | ||||||
|  | 
 | ||||||
|  |     setstate = getattr(obj, '__setstate__', None) | ||||||
|  |     if callable(setstate): | ||||||
|  |         setstate(meta.state) # pylint: disable=not-callable | ||||||
|  |     else: | ||||||
|  |         obj.__dict__.update(meta.state) | ||||||
|  |     return obj | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def _module_to_src(module): | ||||||
|  |     r"""Query the source code of a given Python module. | ||||||
|  |     """ | ||||||
|  |     src = _module_to_src_dict.get(module, None) | ||||||
|  |     if src is None: | ||||||
|  |         src = inspect.getsource(module) | ||||||
|  |         _module_to_src_dict[module] = src | ||||||
|  |         _src_to_module_dict[src] = module | ||||||
|  |     return src | ||||||
|  | 
 | ||||||
|  | def _src_to_module(src): | ||||||
|  |     r"""Get or create a Python module for the given source code. | ||||||
|  |     """ | ||||||
|  |     module = _src_to_module_dict.get(src, None) | ||||||
|  |     if module is None: | ||||||
|  |         module_name = "_imported_module_" + uuid.uuid4().hex | ||||||
|  |         module = types.ModuleType(module_name) | ||||||
|  |         sys.modules[module_name] = module | ||||||
|  |         _module_to_src_dict[module] = src | ||||||
|  |         _src_to_module_dict[src] = module | ||||||
|  |         exec(src, module.__dict__) # pylint: disable=exec-used | ||||||
|  |     return module | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def _check_pickleable(obj): | ||||||
|  |     r"""Check that the given object is pickleable, raising an exception if | ||||||
|  |     it is not. This function is expected to be considerably more efficient | ||||||
|  |     than actually pickling the object. | ||||||
|  |     """ | ||||||
|  |     def recurse(obj): | ||||||
|  |         if isinstance(obj, (list, tuple, set)): | ||||||
|  |             return [recurse(x) for x in obj] | ||||||
|  |         if isinstance(obj, dict): | ||||||
|  |             return [[recurse(x), recurse(y)] for x, y in obj.items()] | ||||||
|  |         if isinstance(obj, (str, int, float, bool, bytes, bytearray)): | ||||||
|  |             return None # Python primitive types are pickleable. | ||||||
|  |         if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: | ||||||
|  |             return None # NumPy arrays and PyTorch tensors are pickleable. | ||||||
|  |         if is_persistent(obj): | ||||||
|  |             return None # Persistent objects are pickleable, by virtue of the constructor check. | ||||||
|  |         return obj | ||||||
|  |     with io.BytesIO() as f: | ||||||
|  |         pickle.dump(recurse(obj), f) | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
models/stylegan3/torch_utils/training_stats.py (new file, 268 lines)
							| @ -0,0 +1,268 @@ | |||||||
|  | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved. | ||||||
|  | # | ||||||
|  | # NVIDIA CORPORATION and its licensors retain all intellectual property | ||||||
|  | # and proprietary rights in and to this software, related documentation | ||||||
|  | # and any modifications thereto.  Any use, reproduction, disclosure or | ||||||
|  | # distribution of this software and related documentation without an express | ||||||
|  | # license agreement from NVIDIA CORPORATION is strictly prohibited. | ||||||
|  | 
 | ||||||
|  | """Facilities for reporting and collecting training statistics across | ||||||
|  | multiple processes and devices. The interface is designed to minimize | ||||||
|  | synchronization overhead as well as the amount of boilerplate in user | ||||||
|  | code.""" | ||||||
|  | 
 | ||||||
|  | import re | ||||||
|  | import numpy as np | ||||||
|  | import torch | ||||||
|  | import dnnlib | ||||||
|  | 
 | ||||||
|  | from . import misc | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | _num_moments    = 3             # [num_scalars, sum_of_scalars, sum_of_squares] | ||||||
|  | _reduce_dtype   = torch.float32 # Data type to use for initial per-tensor reduction. | ||||||
|  | _counter_dtype  = torch.float64 # Data type to use for the internal counters. | ||||||
|  | _rank           = 0             # Rank of the current process. | ||||||
|  | _sync_device    = None          # Device to use for multiprocess communication. None = single-process. | ||||||
|  | _sync_called    = False         # Has _sync() been called yet? | ||||||
|  | _counters       = dict()        # Running counters on each device, updated by report(): name => device => torch.Tensor | ||||||
|  | _cumulative     = dict()        # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def init_multiprocessing(rank, sync_device): | ||||||
|  |     r"""Initializes `torch_utils.training_stats` for collecting statistics | ||||||
|  |     across multiple processes. | ||||||
|  | 
 | ||||||
|  |     This function must be called after | ||||||
|  |     `torch.distributed.init_process_group()` and before `Collector.update()`. | ||||||
|  |     The call is not necessary if multi-process collection is not needed. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         rank:           Rank of the current process. | ||||||
|  |         sync_device:    PyTorch device to use for inter-process | ||||||
|  |                         communication, or None to disable multi-process | ||||||
|  |                         collection. Typically `torch.device('cuda', rank)`. | ||||||
|  |     """ | ||||||
|  |     global _rank, _sync_device | ||||||
|  |     assert not _sync_called | ||||||
|  |     _rank = rank | ||||||
|  |     _sync_device = sync_device | ||||||
|  | 
 | ||||||
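|  | # Usage sketch (editor's illustration; the backend and device choice are | ||||||
|  | # assumptions about a typical setup): | ||||||
|  | #   torch.distributed.init_process_group(backend='nccl', rank=rank, world_size=world_size) | ||||||
|  | #   init_multiprocessing(rank=rank, sync_device=torch.device('cuda', rank)) | ||||||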
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | @misc.profiled_function | ||||||
|  | def report(name, value): | ||||||
|  |     r"""Broadcasts the given set of scalars to all interested instances of | ||||||
|  |     `Collector`, across device and process boundaries. | ||||||
|  | 
 | ||||||
|  |     This function is expected to be extremely cheap and can be safely | ||||||
|  |     called from anywhere in the training loop, loss function, or inside a | ||||||
|  |     `torch.nn.Module`. | ||||||
|  | 
 | ||||||
|  |     Warning: The current implementation expects the set of unique names to | ||||||
|  |     be consistent across processes. Please make sure that `report()` is | ||||||
|  |     called at least once for each unique name by each process, and in the | ||||||
|  |     same order. If a given process has no scalars to broadcast, it can do | ||||||
|  |     `report(name, [])` (empty list). | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         name:   Arbitrary string specifying the name of the statistic. | ||||||
|  |                 Averages are accumulated separately for each unique name. | ||||||
|  |         value:  Arbitrary set of scalars. Can be a list, tuple, | ||||||
|  |                 NumPy array, PyTorch tensor, or Python scalar. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         The same `value` that was passed in. | ||||||
|  |     """ | ||||||
|  |     if name not in _counters: | ||||||
|  |         _counters[name] = dict() | ||||||
|  | 
 | ||||||
|  |     elems = torch.as_tensor(value) | ||||||
|  |     if elems.numel() == 0: | ||||||
|  |         return value | ||||||
|  | 
 | ||||||
|  |     elems = elems.detach().flatten().to(_reduce_dtype) | ||||||
|  |     moments = torch.stack([ | ||||||
|  |         torch.ones_like(elems).sum(), | ||||||
|  |         elems.sum(), | ||||||
|  |         elems.square().sum(), | ||||||
|  |     ]) | ||||||
|  |     assert moments.ndim == 1 and moments.shape[0] == _num_moments | ||||||
|  |     moments = moments.to(_counter_dtype) | ||||||
|  | 
 | ||||||
|  |     device = moments.device | ||||||
|  |     if device not in _counters[name]: | ||||||
|  |         _counters[name][device] = torch.zeros_like(moments) | ||||||
|  |     _counters[name][device].add_(moments) | ||||||
|  |     return value | ||||||
|  | 
 | ||||||
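|  | # Example (editor's sketch; the statistic name is illustrative): report() | ||||||
|  | # returns its input, so it can wrap values inline in the training loop: | ||||||
|  | #   loss = report('Loss/G/total', loss) | ||||||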
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def report0(name, value): | ||||||
|  |     r"""Broadcasts the given set of scalars by the first process (`rank = 0`), | ||||||
|  |     but ignores any scalars provided by the other processes. | ||||||
|  |     See `report()` for further details. | ||||||
|  |     """ | ||||||
|  |     report(name, value if _rank == 0 else []) | ||||||
|  |     return value | ||||||
|  | 
 | ||||||
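|  | # Example (editor's sketch): every rank calls report0() with the same name, | ||||||
|  | # but only rank 0 contributes scalars; the other ranks pass an empty list. | ||||||
|  | #   report0('Timing/epoch_sec', epoch_time) | ||||||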
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | class Collector: | ||||||
|  |     r"""Collects the scalars broadcasted by `report()` and `report0()` and | ||||||
|  |     computes their long-term averages (mean and standard deviation) over | ||||||
|  |     user-defined periods of time. | ||||||
|  | 
 | ||||||
|  |     The averages are first collected into internal counters that are not | ||||||
|  |     directly visible to the user. They are then copied to the user-visible | ||||||
|  |     state as a result of calling `update()` and can then be queried using | ||||||
|  |     `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the | ||||||
|  |     internal counters for the next round, so that the user-visible state | ||||||
|  |     effectively reflects averages collected between the last two calls to | ||||||
|  |     `update()`. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         regex:          Regular expression defining which statistics to | ||||||
|  |                         collect. The default is to collect everything. | ||||||
|  |         keep_previous:  Whether to retain the previous averages if no | ||||||
|  |                         scalars were collected on a given round | ||||||
|  |                         (default: True). | ||||||
|  |     """ | ||||||
|  |     def __init__(self, regex='.*', keep_previous=True): | ||||||
|  |         self._regex = re.compile(regex) | ||||||
|  |         self._keep_previous = keep_previous | ||||||
|  |         self._cumulative = dict() | ||||||
|  |         self._moments = dict() | ||||||
|  |         self.update() | ||||||
|  |         self._moments.clear() | ||||||
|  | 
 | ||||||
|  |     def names(self): | ||||||
|  |         r"""Returns the names of all statistics broadcasted so far that | ||||||
|  |         match the regular expression specified at construction time. | ||||||
|  |         """ | ||||||
|  |         return [name for name in _counters if self._regex.fullmatch(name)] | ||||||
|  | 
 | ||||||
|  |     def update(self): | ||||||
|  |         r"""Copies current values of the internal counters to the | ||||||
|  |         user-visible state and resets them for the next round. | ||||||
|  | 
 | ||||||
|  |         If `keep_previous=True` was specified at construction time, the | ||||||
|  |         operation is skipped for statistics that have received no scalars | ||||||
|  |         since the last update, retaining their previous averages. | ||||||
|  | 
 | ||||||
|  |         This method performs a number of GPU-to-CPU transfers and one | ||||||
|  |         `torch.distributed.all_reduce()`. It is intended to be called | ||||||
|  |         periodically in the main training loop, typically once every | ||||||
|  |         N training steps. | ||||||
|  |         """ | ||||||
|  |         if not self._keep_previous: | ||||||
|  |             self._moments.clear() | ||||||
|  |         for name, cumulative in _sync(self.names()): | ||||||
|  |             if name not in self._cumulative: | ||||||
|  |                 self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) | ||||||
|  |             delta = cumulative - self._cumulative[name] | ||||||
|  |             self._cumulative[name].copy_(cumulative) | ||||||
|  |             if float(delta[0]) != 0: | ||||||
|  |                 self._moments[name] = delta | ||||||
|  | 
 | ||||||
|  |     def _get_delta(self, name): | ||||||
|  |         r"""Returns the raw moments that were accumulated for the given | ||||||
|  |         statistic between the last two calls to `update()`, or zero if | ||||||
|  |         no scalars were collected. | ||||||
|  |         """ | ||||||
|  |         assert self._regex.fullmatch(name) | ||||||
|  |         if name not in self._moments: | ||||||
|  |             self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) | ||||||
|  |         return self._moments[name] | ||||||
|  | 
 | ||||||
|  |     def num(self, name): | ||||||
|  |         r"""Returns the number of scalars that were accumulated for the given | ||||||
|  |         statistic between the last two calls to `update()`, or zero if | ||||||
|  |         no scalars were collected. | ||||||
|  |         """ | ||||||
|  |         delta = self._get_delta(name) | ||||||
|  |         return int(delta[0]) | ||||||
|  | 
 | ||||||
|  |     def mean(self, name): | ||||||
|  |         r"""Returns the mean of the scalars that were accumulated for the | ||||||
|  |         given statistic between the last two calls to `update()`, or NaN if | ||||||
|  |         no scalars were collected. | ||||||
|  |         """ | ||||||
|  |         delta = self._get_delta(name) | ||||||
|  |         if int(delta[0]) == 0: | ||||||
|  |             return float('nan') | ||||||
|  |         return float(delta[1] / delta[0]) | ||||||
|  | 
 | ||||||
|  |     def std(self, name): | ||||||
|  |         r"""Returns the standard deviation of the scalars that were | ||||||
|  |         accumulated for the given statistic between the last two calls to | ||||||
|  |         `update()`, or NaN if no scalars were collected. | ||||||
|  |         """ | ||||||
|  |         delta = self._get_delta(name) | ||||||
|  |         if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): | ||||||
|  |             return float('nan') | ||||||
|  |         if int(delta[0]) == 1: | ||||||
|  |             return float(0) | ||||||
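|  |         # Var[x] = E[x^2] - (E[x])^2, computed from the accumulated raw | ||||||
|  |         # moments; clamped at zero below to absorb floating-point error. | ||||||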
|  |         mean = float(delta[1] / delta[0]) | ||||||
|  |         raw_var = float(delta[2] / delta[0]) | ||||||
|  |         return np.sqrt(max(raw_var - np.square(mean), 0)) | ||||||
|  | 
 | ||||||
|  |     def as_dict(self): | ||||||
|  |         r"""Returns the averages accumulated between the last two calls to | ||||||
|  |         `update()` as a `dnnlib.EasyDict`. The contents are as follows: | ||||||
|  | 
 | ||||||
|  |             dnnlib.EasyDict( | ||||||
|  |                 NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), | ||||||
|  |                 ... | ||||||
|  |             ) | ||||||
|  |         """ | ||||||
|  |         stats = dnnlib.EasyDict() | ||||||
|  |         for name in self.names(): | ||||||
|  |             stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) | ||||||
|  |         return stats | ||||||
|  | 
 | ||||||
|  |     def __getitem__(self, name): | ||||||
|  |         r"""Convenience getter. | ||||||
|  |         `collector[name]` is a synonym for `collector.mean(name)`. | ||||||
|  |         """ | ||||||
|  |         return self.mean(name) | ||||||
|  | 
 | ||||||
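|  | # Usage sketch (editor's illustration; the regex and statistic names are | ||||||
|  | # assumptions): | ||||||
|  | #   collector = Collector(regex='Loss/.*') | ||||||
|  | #   collector.update()                    # call periodically in the training loop | ||||||
|  | #   cur_mean = collector.mean('Loss/G/total') | ||||||
|  | #   cur_std = collector.std('Loss/G/total') | ||||||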
|  | #---------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  | def _sync(names): | ||||||
|  |     r"""Synchronize the global cumulative counters across devices and | ||||||
|  |     processes. Called internally by `Collector.update()`. | ||||||
|  |     """ | ||||||
|  |     if len(names) == 0: | ||||||
|  |         return [] | ||||||
|  |     global _sync_called | ||||||
|  |     _sync_called = True | ||||||
|  | 
 | ||||||
|  |     # Collect deltas within current rank. | ||||||
|  |     deltas = [] | ||||||
|  |     device = _sync_device if _sync_device is not None else torch.device('cpu') | ||||||
|  |     for name in names: | ||||||
|  |         delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) | ||||||
|  |         for counter in _counters[name].values(): | ||||||
|  |             delta.add_(counter.to(device)) | ||||||
|  |             counter.copy_(torch.zeros_like(counter)) | ||||||
|  |         deltas.append(delta) | ||||||
|  |     deltas = torch.stack(deltas) | ||||||
|  | 
 | ||||||
|  |     # Sum deltas across ranks. | ||||||
|  |     if _sync_device is not None: | ||||||
|  |         torch.distributed.all_reduce(deltas) | ||||||
|  | 
 | ||||||
|  |     # Update cumulative values. | ||||||
|  |     deltas = deltas.cpu() | ||||||
|  |     for idx, name in enumerate(names): | ||||||
|  |         if name not in _cumulative: | ||||||
|  |             _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) | ||||||
|  |         _cumulative[name].add_(deltas[idx]) | ||||||
|  | 
 | ||||||
|  |     # Return name-value pairs. | ||||||
|  |     return [(name, _cumulative[name]) for name in names] | ||||||
|  | 
 | ||||||
|  | #---------------------------------------------------------------------------- | ||||||
							
								
								
									
optimization/run_optimization.py (181 lines, new file)
							| @ -0,0 +1,181 @@ | |||||||
|  | import argparse | ||||||
|  | import math | ||||||
|  | import os | ||||||
|  | 
 | ||||||
|  | import torch | ||||||
|  | import torchvision | ||||||
|  | from torch import optim | ||||||
|  | from tqdm import tqdm | ||||||
|  | 
 | ||||||
|  | from criteria.clip_loss import CLIPLoss | ||||||
|  | from criteria.id_loss import IDLoss | ||||||
|  | from mapper.training.train_utils import STYLESPACE_DIMENSIONS | ||||||
|  | from models.stylegan2.model import Generator | ||||||
|  | import clip | ||||||
|  | from utils import ensure_checkpoint_exists | ||||||
|  | 
 | ||||||
|  | STYLESPACE_INDICES_WITHOUT_TORGB = [i for i in range(len(STYLESPACE_DIMENSIONS)) if i not in list(range(1, len(STYLESPACE_DIMENSIONS), 3))] | ||||||
|  | 
 | ||||||
|  | def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): | ||||||
|  |     # t runs from 0 to 1 over the optimization: linear rampup over the first | ||||||
|  |     # `rampup` fraction, cosine rampdown over the final `rampdown` fraction. | ||||||
|  |     lr_ramp = min(1, (1 - t) / rampdown) | ||||||
|  |     lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) | ||||||
|  |     lr_ramp = lr_ramp * min(1, t / rampup) | ||||||
|  | 
 | ||||||
|  |     return initial_lr * lr_ramp | ||||||
|  | 
 | ||||||
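|  | # Spot-check of the schedule (editor's sketch), using the defaults | ||||||
|  | # rampdown=0.25 and rampup=0.05: | ||||||
|  | #   get_lr(0.0, 0.1) == 0.0    # start of the linear rampup | ||||||
|  | #   get_lr(0.5, 0.1) == 0.1    # plateau at the full learning rate | ||||||
|  | #   get_lr(0.9, 0.1)           # ~0.035, inside the cosine rampdown | ||||||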
|  | 
 | ||||||
|  | def main(args): | ||||||
|  |     ensure_checkpoint_exists(args.ckpt) | ||||||
|  |     # Tokenize the text description for the pretrained CLIP model. | ||||||
|  |     text_inputs = torch.cat([clip.tokenize(args.description)]).cuda() | ||||||
|  |     # print('text_inputs:', text_inputs) | ||||||
|  |     ''' | ||||||
|  |      --description "a person with purple hair"  | ||||||
|  |      tensor([[49406,   320,  2533,   593,  5496,  2225, 49407,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0]], device='cuda:0', | ||||||
|  |        dtype=torch.int32) | ||||||
|  |        --description "a person with red hair"  | ||||||
|  |         tensor([[49406,   320,  2533,   593,   736,  2225, 49407,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0,     0,     0,     0, | ||||||
|  |              0,     0,     0,     0,     0,     0,     0]], device='cuda:0', | ||||||
|  |        dtype=torch.int32) | ||||||
|  |     ''' | ||||||
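|  |     # Note: clip.tokenize pads every prompt to CLIP's fixed 77-token context | ||||||
|  |     # length, which is why the tensors above are mostly zeros. | ||||||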
|  | 
 | ||||||
|  |     os.makedirs(args.results_dir, exist_ok=True) | ||||||
|  | 
 | ||||||
|  |     g_ema = Generator(args.stylegan_size, 512, 8) | ||||||
|  |     g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False) | ||||||
|  |     # Set the generator to evaluation mode. | ||||||
|  |     g_ema.eval() | ||||||
|  |     # Move the generator to the GPU (change the CUDA device index here if needed). | ||||||
|  |     g_ema = g_ema.cuda() | ||||||
|  |     # device = torch.cuda.current_device() | ||||||
|  |     # print('cuda:',device) | ||||||
|  |     mean_latent = g_ema.mean_latent(4096) | ||||||
|  |     # print('mean_latent: ', mean_latent.shape )    #[1,512] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     if args.latent_path: | ||||||
|  |         latent_code_init = torch.load(args.latent_path).cuda() | ||||||
|  |         with torch.no_grad(): | ||||||
|  |             _, latent_code_init, _ = g_ema([latent_code_init], return_latents=True, | ||||||
|  |                                         truncation=args.truncation, truncation_latent=mean_latent) | ||||||
|  |     elif args.mode == "edit": | ||||||
|  |         latent_code_init_not_trunc = torch.randn(1, 512).cuda() | ||||||
|  |         with torch.no_grad(): | ||||||
|  |             _, latent_code_init, _ = g_ema([latent_code_init_not_trunc], return_latents=True, | ||||||
|  |                                         truncation=args.truncation, truncation_latent=mean_latent) | ||||||
|  |     else: | ||||||
|  |         latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1) | ||||||
|  |         print(latent_code_init)  # repeated 18 times along dim 1 -> torch.Size([1, 18, 512]) | ||||||
|  |     with torch.no_grad(): | ||||||
|  |         img_orig, _ = g_ema([latent_code_init], input_is_latent=True, randomize_noise=False) | ||||||
|  | 
 | ||||||
|  |     if args.work_in_stylespace: | ||||||
|  |         with torch.no_grad(): | ||||||
|  |             _, _, latent_code_init = g_ema([latent_code_init], input_is_latent=True, return_latents=True) | ||||||
|  |         latent = [s.detach().clone() for s in latent_code_init] | ||||||
|  |         for c, s in enumerate(latent): | ||||||
|  |             if c in STYLESPACE_INDICES_WITHOUT_TORGB: | ||||||
|  |                 s.requires_grad = True | ||||||
|  |     else: | ||||||
|  |         latent = latent_code_init.detach().clone() | ||||||
|  |         latent.requires_grad = True | ||||||
|  | 
 | ||||||
|  |     clip_loss = CLIPLoss(args) | ||||||
|  |     id_loss = IDLoss(args) | ||||||
|  | 
 | ||||||
|  |     if args.work_in_stylespace: | ||||||
|  |         optimizer = optim.Adam(latent, lr=args.lr) | ||||||
|  |     else: | ||||||
|  |         optimizer = optim.Adam([latent], lr=args.lr) | ||||||
|  | 
 | ||||||
|  |     pbar = tqdm(range(args.step)) | ||||||
|  | 
 | ||||||
|  |     for i in pbar: | ||||||
|  |         t = i / args.step | ||||||
|  |         lr = get_lr(t, args.lr) | ||||||
|  |         optimizer.param_groups[0]["lr"] = lr | ||||||
|  | 
 | ||||||
|  |         img_gen, _ = g_ema([latent], input_is_latent=True, randomize_noise=False, input_is_stylespace=args.work_in_stylespace) | ||||||
|  | 
 | ||||||
|  |         c_loss = clip_loss(img_gen, text_inputs) | ||||||
|  | 
 | ||||||
|  |         if args.id_lambda > 0: | ||||||
|  |             i_loss = id_loss(img_gen, img_orig)[0] | ||||||
|  |         else: | ||||||
|  |             i_loss = 0 | ||||||
|  | 
 | ||||||
|  |         if args.mode == "edit": | ||||||
|  |             if args.work_in_stylespace: | ||||||
|  |                 l2_loss = sum([((latent_code_init[c] - latent[c]) ** 2).sum() for c in range(len(latent_code_init))]) | ||||||
|  |             else: | ||||||
|  |                 l2_loss = ((latent_code_init - latent) ** 2).sum() | ||||||
|  |             loss = c_loss + args.l2_lambda * l2_loss + args.id_lambda * i_loss | ||||||
|  |         else: | ||||||
|  |             loss = c_loss | ||||||
|  | 
 | ||||||
|  |         optimizer.zero_grad() | ||||||
|  |         loss.backward() | ||||||
|  |         optimizer.step() | ||||||
|  | 
 | ||||||
|  |         pbar.set_description( | ||||||
|  |             ( | ||||||
|  |                 f"loss: {loss.item():.4f};" | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         if args.save_intermediate_image_every > 0 and i % args.save_intermediate_image_every == 0: | ||||||
|  |             with torch.no_grad(): | ||||||
|  |                 img_gen, _ = g_ema([latent], input_is_latent=True, randomize_noise=False, input_is_stylespace=args.work_in_stylespace) | ||||||
|  | 
 | ||||||
|  |             torchvision.utils.save_image(img_gen, os.path.join(args.results_dir, f"{str(i).zfill(5)}.jpg"), normalize=True, range=(-1, 1)) | ||||||
|  | 
 | ||||||
|  |     if args.mode == "edit": | ||||||
|  |         final_result = torch.cat([img_orig, img_gen]) | ||||||
|  |     else: | ||||||
|  |         final_result = img_gen | ||||||
|  | 
 | ||||||
|  |     return final_result | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     parser = argparse.ArgumentParser() | ||||||
|  |     parser.add_argument("--description", type=str, default="a person with purple hair", help="the text that guides the editing/generation") | ||||||
|  |     parser.add_argument("--ckpt", type=str, default="../pretrained_models/stylegan2-ffhq-config-f.pt", help="pretrained StyleGAN2 weights") | ||||||
|  |     parser.add_argument("--stylegan_size", type=int, default=1024, help="StyleGAN resolution") | ||||||
|  |     parser.add_argument("--lr_rampup", type=float, default=0.05) | ||||||
|  |     parser.add_argument("--lr", type=float, default=0.1) | ||||||
|  |     parser.add_argument("--step", type=int, default=300, help="number of optimization steps") | ||||||
|  |     parser.add_argument("--mode", type=str, default="edit", choices=["edit", "free_generation"], help="choose between edit an image an generate a free one") | ||||||
|  |     parser.add_argument("--l2_lambda", type=float, default=0.008, help="weight of the latent distance (used for editing only)") | ||||||
|  |     parser.add_argument("--id_lambda", type=float, default=0.000, help="weight of id loss (used for editing only)") | ||||||
|  |     parser.add_argument("--latent_path", type=str, default=None, help="starts the optimization from the given latent code if provided. Otherwose, starts from" | ||||||
|  |                                                                       "the mean latent in a free generation, and from a random one in editing. " | ||||||
|  |                                                                       "Expects a .pt format") | ||||||
|  |     parser.add_argument("--truncation", type=float, default=0.7, help="used only for the initial latent vector, and only when a latent code path is" | ||||||
|  |                                                                       "not provided") | ||||||
|  |     parser.add_argument('--work_in_stylespace', default=False, action='store_true') | ||||||
|  |     parser.add_argument("--save_intermediate_image_every", type=int, default=20, help="if > 0 then saves intermidate results during the optimization") | ||||||
|  |     parser.add_argument("--results_dir", type=str, default="results") | ||||||
|  |     parser.add_argument('--ir_se50_weights', default='../pretrained_models/model_ir_se50.pth', type=str, | ||||||
|  |                              help="Path to facial recognition network used in ID loss") | ||||||
|  | 
 | ||||||
|  |     args = parser.parse_args() | ||||||
|  | 
 | ||||||
|  |     result_image = main(args) | ||||||
|  | 
 | ||||||
|  |     torchvision.utils.save_image(result_image.detach().cpu(), os.path.join(args.results_dir, "final_result.jpg"), normalize=True, scale_each=True, range=(-1, 1)) | ||||||
|  | 
 | ||||||
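|  | # Example invocation (editor's sketch; the checkpoint path is an assumption | ||||||
|  | # about where the pretrained weights live): | ||||||
|  | #   python optimization/run_optimization.py --description "a person with purple hair" \ | ||||||
|  | #       --ckpt ../pretrained_models/stylegan2-ffhq-config-f.pt --mode edit --step 300 | ||||||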
|  | 
 | ||||||
							
								
								
									
test001.py (36 lines, new file)
							| @ -0,0 +1,36 @@ | |||||||
|  | import os | ||||||
|  | import torchvision | ||||||
|  | import argparse | ||||||
|  | from argparse import Namespace | ||||||
|  | from optimization.run_optimization import main | ||||||
|  | 
 | ||||||
|  | parser = argparse.ArgumentParser() | ||||||
|  | parser.add_argument("--description", type=str, default="a person with purple hair", | ||||||
|  |                     help="the text that guides the editing/generation") | ||||||
|  | parser.add_argument("--ckpt", type=str, default="./pretrained_models/stylegan2-ffhq-config-f.pt", | ||||||
|  |                     help="pretrained StyleGAN2 weights") | ||||||
|  | parser.add_argument("--stylegan_size", type=int, default=1024, help="StyleGAN resolution") | ||||||
|  | parser.add_argument("--lr_rampup", type=float, default=0.05) | ||||||
|  | parser.add_argument("--lr", type=float, default=0.1) | ||||||
|  | parser.add_argument("--step", type=int, default=300, help="number of optimization steps") | ||||||
|  | parser.add_argument("--mode", type=str, default="edit", choices=["edit", "free_generation"], | ||||||
|  |                     help="choose between edit an image an generate a free one") | ||||||
|  | parser.add_argument("--l2_lambda", type=float, default=0.008, | ||||||
|  |                     help="weight of the latent distance (used for editing only)") | ||||||
|  | parser.add_argument("--latent_path", type=str, default="/home/ly/StyleCLIP-main/pretrained_models/latent_code/style3.pt", | ||||||
|  |                     help="starts the optimization from the given latent code if provided. Otherwise, starts from" | ||||||
|  |                          "the mean latent in a free generation, and from a random one in editing. " | ||||||
|  |                          "Expects a .pt format") | ||||||
|  | parser.add_argument("--truncation", type=float, default=0.7, | ||||||
|  |                     help="used only for the initial latent vector, and only when a latent code path is" | ||||||
|  |                          "not provided") | ||||||
|  | parser.add_argument("--save_intermediate_image_every", type=int, default=20, | ||||||
|  |                     help="if > 0 then saves intermidate results during the optimization") | ||||||
|  | parser.add_argument("--results_dir", type=str, default="results") | ||||||
|  | parser.add_argument('--work_in_stylespace', default=False, action='store_true', help="performs the optimization in style space S instead of W+") | ||||||
|  | parser.add_argument('--ir_se50_weights', default='pretrained_models/model_ir_se50.pth', type=str, help="Path to facial recognition network used in ID loss") | ||||||
|  | parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor') | ||||||
|  | 
 | ||||||
|  | args = vars(parser.parse_args()) | ||||||
|  | result_image = main(Namespace(**args)) | ||||||
|  | torchvision.utils.save_image(result_image.detach().cpu(), os.path.join(args["results_dir"], "final_result.png"), normalize=True, scale_each=True, | ||||||
|  |                              range=(-1, 1)) | ||||||
							
								
								
									
test002.py (27 lines, new file)
							| @ -0,0 +1,27 @@ | |||||||
|  | import argparse | ||||||
|  | from argparse import Namespace | ||||||
|  | 
 | ||||||
|  | from mapper.scripts.inference import run | ||||||
|  | 
 | ||||||
|  | parser = argparse.ArgumentParser() | ||||||
|  | parser.add_argument('--exp_dir', default="./results", type=str, help='Path to experiment output directory') | ||||||
|  | parser.add_argument('--checkpoint_path', default="./pretrained_models/mapper/purple_hair.pt", type=str, | ||||||
|  |                     help='Path to model checkpoint') | ||||||
|  | parser.add_argument('--couple_outputs', default=True, action='store_true',  # note: with default=True this flag is effectively always on | ||||||
|  |                     help='Whether to also save inputs + outputs side-by-side') | ||||||
|  | parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use') | ||||||
|  | parser.add_argument('--no_coarse_mapper', default=False, action="store_true") | ||||||
|  | parser.add_argument('--no_medium_mapper', default=False, action="store_true") | ||||||
|  | parser.add_argument('--no_fine_mapper', default=False, action="store_true") | ||||||
|  | parser.add_argument('--stylegan_size', default=1024, type=int) | ||||||
|  | parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference') | ||||||
|  | parser.add_argument('--latents_test_path', default="./latents_test/example_celebs.pt", type=str, | ||||||
|  |                     help="The latents for the validation") | ||||||
|  | parser.add_argument('--test_workers', default=0, type=int, help='Number of test/inference dataloader workers') | ||||||
|  | parser.add_argument('--n_images', type=int, default=None, help='Number of images to output. If None, run on all data') | ||||||
|  | 
 | ||||||
|  | args = vars(parser.parse_args()) | ||||||
|  | run(Namespace(**args)) | ||||||
							
								
								
									
utils.py (49 lines, new file)
							| @ -0,0 +1,49 @@ | |||||||
|  | import os | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | google_drive_paths = { | ||||||
|  |     "stylegan2-ffhq-config-f.pt": "https://drive.google.com/uc?id=1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT", | ||||||
|  | 
 | ||||||
|  |     "mapper/pretrained/afro.pt": "https://drive.google.com/uc?id=1i5vAqo4z0I-Yon3FNft_YZOq7ClWayQJ", | ||||||
|  |     "mapper/pretrained/angry.pt": "https://drive.google.com/uc?id=1g82HEH0jFDrcbCtn3M22gesWKfzWV_ma", | ||||||
|  |     "mapper/pretrained/beyonce.pt": "https://drive.google.com/uc?id=1KJTc-h02LXs4zqCyo7pzCp0iWeO6T9fz", | ||||||
|  |     "mapper/pretrained/bobcut.pt": "https://drive.google.com/uc?id=1IvyqjZzKS-vNdq_OhwapAcwrxgLAY8UF", | ||||||
|  |     "mapper/pretrained/bowlcut.pt": "https://drive.google.com/uc?id=1xwdxI2YCewSt05dEHgkpmmzoauPjEnnZ", | ||||||
|  |     "mapper/pretrained/curly_hair.pt": "https://drive.google.com/uc?id=1xZ7fFB12Ci6rUbUfaHPpo44xUFzpWQ6M", | ||||||
|  |     "mapper/pretrained/depp.pt": "https://drive.google.com/uc?id=1FPiJkvFPG_y-bFanxLLP91wUKuy-l3IV", | ||||||
|  |     "mapper/pretrained/hilary_clinton.pt": "https://drive.google.com/uc?id=1X7U2zj2lt0KFifIsTfOOzVZXqYyCWVll", | ||||||
|  |     "mapper/pretrained/mohawk.pt": "https://drive.google.com/uc?id=1oMMPc8iQZ7dhyWavZ7VNWLwzf9aX4C09", | ||||||
|  |     "mapper/pretrained/purple_hair.pt": "https://drive.google.com/uc?id=14H0CGXWxePrrKIYmZnDD2Ccs65EEww75", | ||||||
|  |     "mapper/pretrained/surprised.pt": "https://drive.google.com/uc?id=1F-mPrhO-UeWrV1QYMZck63R43aLtPChI", | ||||||
|  |     "mapper/pretrained/taylor_swift.pt": "https://drive.google.com/uc?id=10jHuHsKKJxuf3N0vgQbX_SMEQgFHDrZa", | ||||||
|  |     "mapper/pretrained/trump.pt": "https://drive.google.com/uc?id=14v8D0uzy4tOyfBU3ca9T0AzTt3v-dNyh", | ||||||
|  |     "mapper/pretrained/zuckerberg.pt": "https://drive.google.com/uc?id=1NjDcMUL8G-pO3i_9N6EPpQNXeMc3Ar1r", | ||||||
|  | 
 | ||||||
|  |     "example_celebs.pt": "https://drive.google.com/uc?id=1VL3lP4avRhz75LxSza6jgDe-pHd2veQG" | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def ensure_checkpoint_exists(model_weights_filename): | ||||||
|  |     if not os.path.isfile(model_weights_filename) and ( | ||||||
|  |         model_weights_filename in google_drive_paths | ||||||
|  |     ): | ||||||
|  |         gdrive_url = google_drive_paths[model_weights_filename] | ||||||
|  |         try: | ||||||
|  |             from gdown import download as drive_download | ||||||
|  | 
 | ||||||
|  |             drive_download(gdrive_url, model_weights_filename, quiet=False) | ||||||
|  |         except ModuleNotFoundError: | ||||||
|  |             print( | ||||||
|  |                 "gdown module not found.", | ||||||
|  |                 "pip3 install gdown or, manually download the checkpoint file:", | ||||||
|  |                 gdrive_url | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |     if not os.path.isfile(model_weights_filename) and ( | ||||||
|  |         model_weights_filename not in google_drive_paths | ||||||
|  |     ): | ||||||
|  |         print( | ||||||
|  |             model_weights_filename, | ||||||
|  |             " not found, you may need to manually download the model weights." | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
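|  | # Example (editor's sketch): fetch a checkpoint by its registered name; the | ||||||
|  | # file is downloaded via gdown only when it is not already present locally. | ||||||
|  | #   ensure_checkpoint_exists("stylegan2-ffhq-config-f.pt") | ||||||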