first commit
This commit is contained in:
commit
4f9296236a
21
LICENSE
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2021 Or Patashnik, Zongze Wu
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
287
README.md
Normal file
287
README.md
Normal file
@ -0,0 +1,287 @@
|
||||
# StyleCLIP: Text-Driven Manipulation of StyleGAN Imagery (ICCV 2021 Oral)
|
||||
|
||||
[Run this model on Replicate](https://replicate.ai/orpatashnik/styleclip)
|
||||
|
||||
Optimization: [](http://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/notebooks/optimization_playground.ipynb)
|
||||
Mapper: [](https://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/notebooks/mapper_playground.ipynb)
|
||||
|
||||
Global directions Torch: [](https://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/notebooks/StyleCLIP_global_torch.ipynb)
|
||||
Global directions TF1: [](https://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/notebooks/StyleCLIP_global.ipynb)
|
||||
|
||||
|
||||
<p align="center">
|
||||
<a href="https://www.youtube.com/watch?v=5icI0NgALnQ"><img src='https://github.com/orpatashnik/StyleCLIP/blob/main/img/StyleCLIP_gif.gif' width=600 ></a>
|
||||
|
||||
Full Demo Video: <a href="https://www.youtube.com/watch?v=5icI0NgALnQ"><img src="https://img.shields.io/badge/-YouTube-red?&style=for-the-badge&logo=youtube&logoColor=white" height=20></a> ICCV Video <a href="https://www.youtube.com/watch?v=PhR1gpXDu0w"><img src="https://img.shields.io/badge/-YouTube-red?&style=for-the-badge&logo=youtube&logoColor=white" height=20></a>
|
||||
|
||||
</p>
|
||||
|
||||
|
||||
|
||||

|
||||
|
||||
> **StyleCLIP: Text-Driven Manipulation of StyleGAN Imagery**<br>
|
||||
> Or Patashnik*, Zongze Wu*, Eli Shechtman, Daniel Cohen-Or, Dani Lischinski <br>
|
||||
> *Equal contribution, ordered alphabetically <br>
|
||||
> https://arxiv.org/abs/2103.17249 <br>
|
||||
>
|
||||
>**Abstract:** Inspired by the ability of StyleGAN to generate highly realistic
|
||||
images in a variety of domains, much recent work has
|
||||
focused on understanding how to use the latent spaces of
|
||||
StyleGAN to manipulate generated and real images. However,
|
||||
discovering semantically meaningful latent manipulations
|
||||
typically involves painstaking human examination of
|
||||
the many degrees of freedom, or an annotated collection
|
||||
of images for each desired manipulation. In this work, we
|
||||
explore leveraging the power of recently introduced Contrastive
|
||||
Language-Image Pre-training (CLIP) models in order
|
||||
to develop a text-based interface for StyleGAN image
|
||||
manipulation that does not require such manual effort. We
|
||||
first introduce an optimization scheme that utilizes a CLIP-based
|
||||
loss to modify an input latent vector in response to a
|
||||
user-provided text prompt. Next, we describe a latent mapper
|
||||
that infers a text-guided latent manipulation step for
|
||||
a given input image, allowing faster and more stable textbased
|
||||
manipulation. Finally, we present a method for mapping
|
||||
a text prompts to input-agnostic directions in StyleGAN’s
|
||||
style space, enabling interactive text-driven image
|
||||
manipulation. Extensive results and comparisons demonstrate
|
||||
the effectiveness of our approaches.
|
||||
|
||||
|
||||
## Description
|
||||
Official Implementation of StyleCLIP, a method to manipulate images using a driving text.
|
||||
Our method uses the generative power of a pretrained StyleGAN generator, and the visual-language power of CLIP.
|
||||
In the paper we present three methods:
|
||||
- Latent vector optimization.
|
||||
- Latent mapper, trained to manipulate latent vectors according to a specific text description.
|
||||
- Global directions in the StyleSpace.
|
||||
|
||||
|
||||
## Updates
|
||||
**31/10/2022** Add support for global direction with torch implementation
|
||||
|
||||
**15/8/2021** Add support for StyleSpace in optimization and latent mapper methods
|
||||
|
||||
**6/4/2021** Add mapper training and inference (including a jupyter notebook) code
|
||||
|
||||
**6/4/2021** Add support for custom StyleGAN2 and StyleGAN2-ada models, and also custom images
|
||||
|
||||
**2/4/2021** Add the global directions code (a local GUI and a colab notebook)
|
||||
|
||||
**31/3/2021** Upload paper to arxiv, and video to YouTube
|
||||
|
||||
**14/2/2021** Initial version
|
||||
|
||||
## Setup (for all three methods)
|
||||
For all the methods described in the paper, is it required to have:
|
||||
- Anaconda
|
||||
- [CLIP](https://github.com/openai/CLIP)
|
||||
|
||||
Specific requirements for each method are described in its section.
|
||||
To install CLIP please run the following commands:
|
||||
```shell script
|
||||
conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=<CUDA_VERSION>
|
||||
pip install ftfy regex tqdm gdown
|
||||
pip install git+https://github.com/openai/CLIP.git
|
||||
```
|
||||
|
||||
|
||||
## Editing via Latent Vector Optimization
|
||||
|
||||
### Setup
|
||||
|
||||
Here, the code relies on the [Rosinality](https://github.com/rosinality/stylegan2-pytorch/) pytorch implementation of StyleGAN2.
|
||||
Some parts of the StyleGAN implementation were modified, so that the whole implementation is native pytorch.
|
||||
|
||||
In addition to the requirements mentioned before, a pretrained StyleGAN2 generator will attempt to be downloaded, (or manually download from [here](https://drive.google.com/file/d/1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT/view?usp=sharing)).
|
||||
|
||||
### Usage
|
||||
|
||||
Given a textual description, one can both edit a given image, or generate a random image that best fits to the description.
|
||||
Both operations can be done through the `main.py` script, or the `optimization_playground.ipynb` notebook ([](http://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/notebooks/optimization_playground.ipynb)).
|
||||
|
||||
#### Editing
|
||||
To edit an image set `--mode=edit`. Editing can be done on both provided latent vector, and on a random latent vector from StyleGAN's latent space.
|
||||
It is recommended to adjust the `--l2_lambda` according to the desired edit.
|
||||
|
||||
#### Generating Free-style Images
|
||||
To generate a free-style image set `--mode=free_generation`.
|
||||
|
||||
## Editing via Latent Mapper
|
||||
Here, we provide the code for the latent mapper. The mapper is trained to learn *residuals* from a given latent vector, according to the driving text.
|
||||
The code for the mapper is in `mapper/`.
|
||||
|
||||
### Setup
|
||||
As in the optimization, the code relies on [Rosinality](https://github.com/rosinality/stylegan2-pytorch/) pytorch implementation of StyleGAN2.
|
||||
In addition the the StyleGAN weights, it is neccessary to have weights for the facial recognition network used in the ID loss.
|
||||
The weights can be downloaded from [here](https://drive.google.com/file/d/1KW7bjndL3QG3sxBbZxreGHigcCCpsDgn/view?usp=sharing).
|
||||
|
||||
The mapper is trained on latent vectors. It is recommended to train on *inverted real images*.
|
||||
To this end, we provide the CelebA-HQ that was inverted by e4e:
|
||||
[train set](https://drive.google.com/file/d/1gof8kYc_gDLUT4wQlmUdAtPnQIlCO26q/view?usp=sharing), [test set](https://drive.google.com/file/d/1j7RIfmrCoisxx3t-r-KC02Qc8barBecr/view?usp=sharing).
|
||||
|
||||
### Usage
|
||||
|
||||
#### Training
|
||||
- The main training script is placed in `mapper/scripts/train.py`.
|
||||
- Training arguments can be found at `mapper/options/train_options.py`.
|
||||
- Intermediate training results are saved to opts.exp_dir. This includes checkpoints, train outputs, and test outputs.
|
||||
Additionally, if you have tensorboard installed, you can visualize tensorboard logs in opts.exp_dir/logs.
|
||||
Note that
|
||||
- To resume a training, please provide `--checkpoint_path`.
|
||||
- `--description` is where you provide the driving text.
|
||||
- If you perform an edit that is not supposed to change "colors" in the image, it is recommended to use the flag `--no_fine_mapper`.
|
||||
|
||||
Example for training a mapper for the moahwk hairstyle:
|
||||
```bash
|
||||
cd mapper
|
||||
python train.py --exp_dir ../results/mohawk_hairstyle --no_fine_mapper --description "mohawk hairstyle"
|
||||
```
|
||||
All configurations for the examples shown in the paper are provided there.
|
||||
|
||||
#### Inference
|
||||
- The main inferece script is placed in `mapper/scripts/inference.py`.
|
||||
- Inference arguments can be found at `mapper/options/test_options.py`.
|
||||
- Adding the flag `--couple_outputs` will save image containing the input and output images side-by-side.
|
||||
|
||||
Pretrained models for variuos edits are provided. Please refer to `utils.py` for the complete links list.
|
||||
|
||||
We also provide a notebook for performing inference with the mapper Mapper notebook: [](https://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/notebooks/mapper_playground.ipynb)
|
||||
|
||||
## Editing via Global Direction
|
||||
|
||||
Here we provide GUI for editing images with the global directions.
|
||||
We provide both a jupyter notebook [](https://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/notebooks/StyleCLIP_global.ipynb),
|
||||
and the GUI used in the [video](https://www.youtube.com/watch?v=5icI0NgALnQ).
|
||||
For both, the linear direction are computed in **real time**.
|
||||
The code is located at `global_directions/`.
|
||||
|
||||
|
||||
### Setup
|
||||
Here, we rely on the [official](https://github.com/NVlabs/stylegan2) TensorFlow implementation of StyleGAN2.
|
||||
|
||||
It is required to have TensorFlow, version 1.14 or 1.15 (`conda install -c anaconda tensorflow-gpu==1.14`).
|
||||
|
||||
### Usage
|
||||
|
||||
|
||||
#### Local GUI
|
||||
|
||||
To start the local GUI please run the following commands:
|
||||
|
||||
```shell script
|
||||
cd global_directions
|
||||
|
||||
# input dataset name
|
||||
dataset_name='ffhq'
|
||||
|
||||
# pretrained StyleGAN2 model from standard [NVlabs implementation](https://github.com/NVlabs/stylegan2) will be download automatically.
|
||||
# pretrained StyleGAN2-ada model could be download from https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/ .
|
||||
# for custom StyleGAN2 or StyleGAN2-ada model, please place the model under ./StyleCLIP/global_directions/model/ folder.
|
||||
|
||||
|
||||
# input prepare data
|
||||
python GetCode.py --dataset_name $dataset_name --code_type 'w'
|
||||
python GetCode.py --dataset_name $dataset_name --code_type 's'
|
||||
python GetCode.py --dataset_name $dataset_name --code_type 's_mean_std'
|
||||
|
||||
# preprocess (this may take a few hours).
|
||||
# we precompute the results for StyleGAN2 on ffhq, StyleGAN2-ada on afhqdog, afhqcat. For these model, we can skip the preprocess step.
|
||||
python SingleChannel.py --dataset_name $dataset_name
|
||||
|
||||
# generated image to be manipulated
|
||||
# this operation will generate and replace the w_plu.npy and .jpg images in './data/dataset_name/' folder.
|
||||
# if you you want to keep the original data, please rename the original folder.
|
||||
# to use custom images, please use e4e encoder to generate latents.pt, and place it in './data/dataset_name/' folder, and add --real flag while running this function.
|
||||
# you may skip this step if you want to manipulate the real human faces we prepare in ./data/ffhq/ folder.
|
||||
python GetGUIData.py --dataset_name $dataset_name
|
||||
|
||||
# interactively manipulation
|
||||
python PlayInteractively.py --dataset_name $dataset_name
|
||||
```
|
||||
|
||||
As shown in the video, to edit an image it is requires to write a _neutral text_ and a _target text_.
|
||||
To operate the GUI, please do the following:
|
||||
- Maximize the window size
|
||||
- Double click on the left square to choose an image. The images are taken from `global_directions/data/ffhq`, and the corresponding latent vectors are in `global_directions/data/ffhq/w_plus.npy`.
|
||||
- Type a neutral text, then press enter
|
||||
- Modify the target text so that it will contain the target edit, then press enter.
|
||||
|
||||
You can now play with:
|
||||
- *Manipulation strength* - positive values correspond to moving along the target direction.
|
||||
- *Disentanglement threshold* - large value means more disentangled edit, just a few channels will be manipulated so only the target attribute will change (for example, grey hair). Small value means less disentangled edit, a large number of channels will be manipulated, related attributes will also change (such as wrinkle, skin color, glasses).
|
||||
|
||||
##### Examples:
|
||||
|
||||
| Edit | Neutral Text | Target Text |
|
||||
| --- | --- | --- |
|
||||
| Smile | face | smiling face |
|
||||
| Gender | female face | male face |
|
||||
| Blonde hair | face with hair | face with blonde hair |
|
||||
| Hi-top fade | face with hair | face with Hi-top fade hair |
|
||||
| Blue eyes | face with eyes | face with blue eyes |
|
||||
|
||||
More examples could be found in the [video](https://www.youtube.com/watch?v=5icI0NgALnQ) and in the paper.
|
||||
|
||||
|
||||
##### Pratice Tips:
|
||||
In the terminal, for every manipulation, the number of channels being manipulated is printed (the number is controlled by the attribute (neutral, target) and the disentanglement threshold).
|
||||
|
||||
1. For color transformation, usually 10-20 channels is enough. For large structure change (for example, Hi-top fade), usually 100-200 channels are required.
|
||||
2. For an attribute (neutral, target), if you give a low disentanglement threshold, there are just few channels (<20) being manipulated, and usually it is not enough for performing the desired edit.
|
||||
|
||||
|
||||
#### Notebook
|
||||
Open the notebook in colab and run all the cells. In the last cell you can play with the image.
|
||||
|
||||
`beta` corresponds to the *disentanglement threshold*, and `alpha` to the *manipulation strength*.
|
||||
|
||||
After you set the desired set of parameters, please run again the last cell to generate the image.
|
||||
|
||||
## Editing Examples
|
||||
|
||||
In the following, we show some results obtained with our methods.
|
||||
All images are real, and were inverted into the StyleGAN's latent space using [e4e](https://github.com/omertov/encoder4editing).
|
||||
The driving text that was used for each edit appears below or above each image.
|
||||
|
||||
#### Latent Optimization
|
||||
|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
#### Latent Mapper
|
||||
|
||||

|
||||
|
||||
#### Global Directions
|
||||
|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
## Related Works
|
||||
|
||||
The global directions we find for editing are direction in the _S Space_, which was introduced and analyzed in [StyleSpace](https://arxiv.org/abs/2011.12799) (Wu et al).
|
||||
|
||||
To edit real images, we inverted them to the StyleGAN's latent space using [e4e](https://arxiv.org/abs/2102.02766) (Tov et al.).
|
||||
|
||||
The code strcuture of the mapper is heavily based on [pSp](https://github.com/eladrich/pixel2style2pixel).
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this code for your research, please cite our paper:
|
||||
|
||||
```
|
||||
@InProceedings{Patashnik_2021_ICCV,
|
||||
author = {Patashnik, Or and Wu, Zongze and Shechtman, Eli and Cohen-Or, Daniel and Lischinski, Dani},
|
||||
title = {StyleCLIP: Text-Driven Manipulation of StyleGAN Imagery},
|
||||
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
|
||||
month = {October},
|
||||
year = {2021},
|
||||
pages = {2085-2094}
|
||||
}
|
||||
```
|
||||
34
cog.yaml
Normal file
34
cog.yaml
Normal file
@ -0,0 +1,34 @@
|
||||
build:
|
||||
gpu: true
|
||||
system_packages:
|
||||
- libgl1-mesa-glx
|
||||
- libglib2.0-0
|
||||
- cmake
|
||||
- zip
|
||||
python_version: 3.7
|
||||
python_packages:
|
||||
- torch==1.7.1
|
||||
- tensorflow==1.15.0
|
||||
- torchvision==0.8.2
|
||||
- torchaudio==0.7.2
|
||||
- ftfy==5.9
|
||||
- regex==2021.4.4
|
||||
- tqdm==4.59.0
|
||||
- requests==2.25.1
|
||||
- matplotlib==3.4.1
|
||||
- opencv-python==4.3.0.38
|
||||
- dlib==19.18.0
|
||||
- scipy==1.6.3
|
||||
- "git+git://github.com/openai/CLIP.git@8a665a683d791ed3491fedadcb3c91878f9eb78d"
|
||||
pre_install:
|
||||
- "mkdir /content"
|
||||
- "git clone https://github.com/omertov/encoder4editing.git /content/encoder4editing"
|
||||
- "wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip"
|
||||
- "unzip ninja-linux.zip -d /usr/local/bin/"
|
||||
- "update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force"
|
||||
- "wget -O /content/shape_predictor_68_face_landmarks.dat.bz2 http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2"
|
||||
- "cd /content && bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2"
|
||||
- "echo > /content/encoder4editing/__init__.py"
|
||||
- |
|
||||
sed -i 's/img = PIL.Image.open(filepath)/img = PIL.Image.open(filepath).convert(\"RGB\")/' /content/encoder4editing/utils/alignment.py
|
||||
predict: cog_predict.py:Predictor
|
||||
196
cog_predict.py
Normal file
196
cog_predict.py
Normal file
@ -0,0 +1,196 @@
|
||||
import copy
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
import clip
|
||||
import cog
|
||||
import dlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import Image
|
||||
|
||||
sys.path.insert(0, "/content")
|
||||
sys.path.insert(0, "/content/encoder4editing")
|
||||
|
||||
from encoder4editing.models.psp import pSp
|
||||
from encoder4editing.utils.alignment import align_face
|
||||
from encoder4editing.utils.common import tensor2im
|
||||
|
||||
os.chdir("global_directions")
|
||||
sys.path.insert(0, ".")
|
||||
|
||||
from dnnlib import tflib
|
||||
from manipulate import Manipulator
|
||||
from MapTS import GetBoundary, GetDt, GetFs
|
||||
|
||||
class Predictor(cog.Predictor):
|
||||
def setup(self):
|
||||
|
||||
print("starting setup")
|
||||
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.model, self.preprocess = clip.load(
|
||||
"ViT-B/32", device=self.device, jit=False
|
||||
)
|
||||
|
||||
self.graph = tf.get_default_graph()
|
||||
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
|
||||
self.sess = tf.Session(
|
||||
graph=self.graph, config=tf.ConfigProto(gpu_options=gpu_options)
|
||||
)
|
||||
|
||||
self.experiment_args = {"model_path": "e4e_ffhq_encode.pt"}
|
||||
self.experiment_args["transform"] = transforms.Compose(
|
||||
[
|
||||
transforms.Resize((256, 256)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
|
||||
]
|
||||
)
|
||||
self.resize_dims = (256, 256)
|
||||
|
||||
model_path = self.experiment_args["model_path"]
|
||||
|
||||
ckpt = torch.load(model_path, map_location="cpu")
|
||||
opts = ckpt["opts"]
|
||||
# pprint.pprint(opts) # Display full options used
|
||||
# update the training options
|
||||
opts["checkpoint_path"] = model_path
|
||||
opts = Namespace(**opts)
|
||||
|
||||
self.net = pSp(opts)
|
||||
self.net.eval()
|
||||
self.net.cuda()
|
||||
|
||||
self.shape_predictor = dlib.shape_predictor(
|
||||
"/content/shape_predictor_68_face_landmarks.dat"
|
||||
)
|
||||
|
||||
with self.graph.as_default(), self.sess.as_default():
|
||||
#tflib.init_tf()
|
||||
|
||||
self.M = Manipulator(dataset_name="ffhq", sess=self.sess)
|
||||
self.fs3 = np.load("npy/ffhq/fs3.npy")
|
||||
np.set_printoptions(suppress=True)
|
||||
|
||||
print("setup complete")
|
||||
|
||||
@cog.input("input", type=Path, help="Input image")
|
||||
@cog.input("neutral", type=str, help="Neutral image description")
|
||||
@cog.input("target", type=str, help="Target image description")
|
||||
@cog.input(
|
||||
"manipulation_strength",
|
||||
type=float,
|
||||
min=-10,
|
||||
max=10,
|
||||
default=4.1,
|
||||
help="The higher the manipulation strength, the closer the generated image becomes to the target description. Negative values moves the generated image further from the target description",
|
||||
)
|
||||
@cog.input(
|
||||
"disentanglement_threshold",
|
||||
type=float,
|
||||
min=0.08,
|
||||
max=0.3,
|
||||
default=0.15,
|
||||
help="The higher the disentanglement threshold, the more specific the changes are to the target attribute. Lower values mean that broader changes are made to the input image",
|
||||
)
|
||||
def predict(
|
||||
self,
|
||||
input,
|
||||
neutral,
|
||||
target,
|
||||
manipulation_strength,
|
||||
disentanglement_threshold,
|
||||
):
|
||||
|
||||
# @title Align image
|
||||
#original_image = Image.open(str(input))
|
||||
#original_image = original_image.convert("RGB")
|
||||
input_image = self.run_alignment(str(input))
|
||||
#input_image = original_image
|
||||
input_image = input_image.resize(self.resize_dims)
|
||||
|
||||
img_transforms = self.experiment_args["transform"]
|
||||
transformed_image = img_transforms(input_image)
|
||||
|
||||
with torch.no_grad():
|
||||
images, latents = self.run_on_batch(transformed_image.unsqueeze(0))
|
||||
result_image, latent = images[0], latents[0]
|
||||
|
||||
print("latents", latents)
|
||||
|
||||
print(transformed_image.shape, result_image.shape)
|
||||
|
||||
w_plus = latents.cpu().detach().numpy()
|
||||
with self.graph.as_default(), self.sess.as_default():
|
||||
dlatents_loaded = self.M.W2S(w_plus)
|
||||
|
||||
#print("w_plus, dlatents_loaded", w_plus, dlatents_loaded)
|
||||
|
||||
img_index = 0
|
||||
w_plus=latents.cpu().detach().numpy()
|
||||
with self.graph.as_default(), self.sess.as_default():
|
||||
dlatents_loaded=self.M.W2S(w_plus)
|
||||
|
||||
img_indexs=[img_index]
|
||||
dlatent_tmp=[tmp[img_indexs] for tmp in dlatents_loaded]
|
||||
with self.graph.as_default(), self.sess.as_default():
|
||||
self.M.num_images = len(img_indexs)
|
||||
self.M.alpha = [0]
|
||||
self.M.manipulate_layers = [0]
|
||||
|
||||
with self.graph.as_default(), self.sess.as_default():
|
||||
codes, out = self.M.EditOneC(0, dlatent_tmp)
|
||||
|
||||
original = Image.fromarray(out[0, 0]).resize((512, 512))
|
||||
with self.graph.as_default(), self.sess.as_default():
|
||||
self.M.manipulate_layers = None
|
||||
|
||||
classnames = [target, neutral]
|
||||
dt = GetDt(classnames, self.model)
|
||||
|
||||
with self.graph.as_default(), self.sess.as_default():
|
||||
self.M.alpha = [manipulation_strength]
|
||||
boundary_tmp2, c = GetBoundary(
|
||||
self.fs3, dt, self.M, threshold=disentanglement_threshold
|
||||
)
|
||||
codes = self.M.MSCode(dlatent_tmp, boundary_tmp2)
|
||||
out = self.M.GenerateImg(codes)
|
||||
generated = Image.fromarray(out[0, 0]) # .resize((512,512))
|
||||
|
||||
out_path = Path(tempfile.mkdtemp()) / "out.jpg"
|
||||
generated.save(str(out_path))
|
||||
|
||||
return out_path
|
||||
|
||||
def run_alignment(self, image_path):
|
||||
aligned_image = align_face(filepath=image_path, predictor=self.shape_predictor)
|
||||
print("Aligned image has shape: {}".format(aligned_image.size))
|
||||
return aligned_image
|
||||
|
||||
def run_on_batch(self, inputs):
|
||||
images, latents = self.net(
|
||||
inputs.to("cuda").float(), randomize_noise=False, return_latents=True
|
||||
)
|
||||
return images, latents
|
||||
|
||||
|
||||
def concat_images(*images):
|
||||
width = 0
|
||||
for im in images:
|
||||
width += im.width
|
||||
height = max([im.height for im in images])
|
||||
concat = Image.new("RGB", (width, height))
|
||||
offset = 0
|
||||
for im in images:
|
||||
concat.paste(im, (offset, 0))
|
||||
offset += im.width
|
||||
return concat
|
||||
0
criteria/__init__.py
Normal file
0
criteria/__init__.py
Normal file
17
criteria/clip_loss.py
Normal file
17
criteria/clip_loss.py
Normal file
@ -0,0 +1,17 @@
|
||||
|
||||
import torch
|
||||
import clip
|
||||
|
||||
|
||||
class CLIPLoss(torch.nn.Module):
|
||||
|
||||
def __init__(self, opts):
|
||||
super(CLIPLoss, self).__init__()
|
||||
self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
|
||||
self.upsample = torch.nn.Upsample(scale_factor=7)
|
||||
self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)
|
||||
|
||||
def forward(self, image, text):
|
||||
image = self.avg_pool(self.upsample(image))
|
||||
similarity = 1 - self.model(image, text)[0] / 100
|
||||
return similarity
|
||||
40
criteria/id_loss.py
Normal file
40
criteria/id_loss.py
Normal file
@ -0,0 +1,40 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from models.facial_recognition.model_irse import Backbone
|
||||
|
||||
|
||||
class IDLoss(nn.Module):
|
||||
def __init__(self, opts):
|
||||
super(IDLoss, self).__init__()
|
||||
print('Loading ResNet ArcFace')
|
||||
self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
|
||||
self.facenet.load_state_dict(torch.load(opts.ir_se50_weights))
|
||||
self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
|
||||
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
|
||||
self.facenet.eval()
|
||||
self.facenet.cuda()
|
||||
self.opts = opts
|
||||
|
||||
def extract_feats(self, x):
|
||||
if x.shape[2] != 256:
|
||||
x = self.pool(x)
|
||||
x = x[:, :, 35:223, 32:220] # Crop interesting region
|
||||
x = self.face_pool(x)
|
||||
x_feats = self.facenet(x)
|
||||
return x_feats
|
||||
|
||||
def forward(self, y_hat, y):
|
||||
n_samples = y.shape[0]
|
||||
y_feats = self.extract_feats(y) # Otherwise use the feature from there
|
||||
y_hat_feats = self.extract_feats(y_hat)
|
||||
y_feats = y_feats.detach()
|
||||
loss = 0
|
||||
sim_improvement = 0
|
||||
count = 0
|
||||
for i in range(n_samples):
|
||||
diff_target = y_hat_feats[i].dot(y_feats[i])
|
||||
loss += 1 - diff_target
|
||||
count += 1
|
||||
|
||||
return loss / count, sim_improvement / count
|
||||
127
global_torch/SingleChannel.py
Normal file
127
global_torch/SingleChannel.py
Normal file
@ -0,0 +1,127 @@
|
||||
|
||||
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from PIL import Image
|
||||
import copy
|
||||
from manipulate import Manipulator
|
||||
import argparse
|
||||
|
||||
import sys
|
||||
sys.path.append('/cs/labs/danix/wuzongze/Tansformer_Manipulation/CLIP/')
|
||||
import clip
|
||||
|
||||
def GetImgF(out,model,preprocess):
|
||||
imgs=out
|
||||
imgs1=imgs.reshape([-1]+list(imgs.shape[2:]))
|
||||
|
||||
tmp=[]
|
||||
for i in range(len(imgs1)):
|
||||
|
||||
img=Image.fromarray(imgs1[i])
|
||||
image = preprocess(img).unsqueeze(0).to(device)
|
||||
tmp.append(image)
|
||||
|
||||
image=torch.cat(tmp)
|
||||
with torch.no_grad():
|
||||
image_features = model.encode_image(image)
|
||||
|
||||
image_features1=image_features.cpu().numpy()
|
||||
image_features1=image_features1.reshape(list(imgs.shape[:2])+[512])
|
||||
|
||||
return image_features1
|
||||
|
||||
def GetFs(fs):
|
||||
tmp=np.linalg.norm(fs,axis=-1)
|
||||
fs1=fs/tmp[:,:,:,None]
|
||||
fs2=fs1[:,:,1,:]-fs1[:,:,0,:] # 5*sigma - (-5)* sigma
|
||||
fs3=fs2/np.linalg.norm(fs2,axis=-1)[:,:,None]
|
||||
fs3=fs3.mean(axis=1)
|
||||
fs3=fs3/np.linalg.norm(fs3,axis=-1)[:,None]
|
||||
return fs3
|
||||
|
||||
#%%
|
||||
if __name__ == "__main__":
|
||||
'''
|
||||
parser = argparse.ArgumentParser(description='Process some integers.')
|
||||
|
||||
parser.add_argument('--dataset_name',type=str,default='cat',
|
||||
help='name of dataset, for example, ffhq')
|
||||
args = parser.parse_args()
|
||||
dataset_name=args.dataset_name
|
||||
'''
|
||||
#%%
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model, preprocess = clip.load("ViT-B/32", device=device,jit=False)
|
||||
#%%
|
||||
|
||||
network_pkl='/cs/labs/danix/wuzongze/Gan_Manipulation/stylegan2/model/stylegan2-human-config-f.pkl'
|
||||
device = torch.device('cuda')
|
||||
M=Manipulator()
|
||||
M.device=device
|
||||
G=M.LoadModel(network_pkl,device)
|
||||
M.G=G
|
||||
M.SetGParameters()
|
||||
num_img=100_000
|
||||
M.GenerateS(num_img=num_img)
|
||||
M.GetCodeMS()
|
||||
|
||||
# M=Manipulator(dataset_name=dataset_name)
|
||||
np.set_printoptions(suppress=True)
|
||||
# print(M.dataset_name)
|
||||
#%%
|
||||
img_sindex=0
|
||||
num_images=100
|
||||
dlatents_o=[]
|
||||
tmp=img_sindex*num_images
|
||||
for i in range(len(M.dlatents)):
|
||||
tmp1=M.dlatents[i][tmp:(tmp+num_images)]
|
||||
dlatents_o.append(tmp1)
|
||||
#%%
|
||||
|
||||
all_f=[]
|
||||
M.alpha=[-5,5] #ffhq 5
|
||||
M.step=2
|
||||
M.num_images=num_images
|
||||
select=np.array(M.mindexs)<=16 #below or equal to 128 resolution
|
||||
mindexs2=np.array(M.mindexs)[select]
|
||||
for lindex in mindexs2: #ignore ToRGB layers
|
||||
print(lindex)
|
||||
num_c=M.dlatents[lindex].shape[1]
|
||||
for cindex in range(num_c):
|
||||
|
||||
M.dlatents=copy.copy(dlatents_o)
|
||||
M.dlatents[lindex][:,cindex]=M.code_mean[lindex][cindex]
|
||||
|
||||
M.manipulate_layers=[lindex]
|
||||
codes,out=M.EditOneC(cindex)
|
||||
image_features1=GetImgF(out,model,preprocess)
|
||||
all_f.append(image_features1)
|
||||
|
||||
all_f=np.array(all_f)
|
||||
|
||||
fs3=GetFs(all_f)
|
||||
|
||||
#%%
|
||||
# file_path='./npy/'+M.dataset_name+'/'
|
||||
file_path='/cs/labs/danix/wuzongze/Gan_Manipulation/stylegan2/results/npy/human/'
|
||||
np.save(file_path+'fs3',fs3)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
246
global_torch/StyleCLIP.py
Normal file
246
global_torch/StyleCLIP.py
Normal file
@ -0,0 +1,246 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Tue Jun 14 09:40:28 2022
|
||||
|
||||
@author: wuzongze
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from PIL import Image
|
||||
import pickle
|
||||
import copy
|
||||
import matplotlib.pyplot as plt
|
||||
from manipulate import Manipulator
|
||||
|
||||
import clip
|
||||
|
||||
|
||||
def SplitS(ds_p,M,if_std):
|
||||
all_ds=[]
|
||||
start=0
|
||||
for i in M.mindexs:
|
||||
tmp=M.dlatents[i].shape[1]
|
||||
end=start+tmp
|
||||
tmp=ds_p[start:end]
|
||||
# tmp=tmp*M.code_std[i]
|
||||
|
||||
all_ds.append(tmp)
|
||||
start=end
|
||||
|
||||
all_ds2=[]
|
||||
tmp_index=0
|
||||
for i in range(len(M.s_names)):
|
||||
if (not 'RGB' in M.s_names[i]) and (not len(all_ds[tmp_index])==0):
|
||||
|
||||
if if_std:
|
||||
tmp=all_ds[tmp_index]*M.code_std[i]
|
||||
else:
|
||||
tmp=all_ds[tmp_index]
|
||||
|
||||
all_ds2.append(tmp)
|
||||
tmp_index+=1
|
||||
else:
|
||||
tmp=np.zeros(len(M.dlatents[i][0]))
|
||||
all_ds2.append(tmp)
|
||||
return all_ds2
|
||||
|
||||
|
||||
imagenet_templates = [
|
||||
'a bad photo of a {}.',
|
||||
# 'a photo of many {}.',
|
||||
'a sculpture of a {}.',
|
||||
'a photo of the hard to see {}.',
|
||||
'a low resolution photo of the {}.',
|
||||
'a rendering of a {}.',
|
||||
'graffiti of a {}.',
|
||||
'a bad photo of the {}.',
|
||||
'a cropped photo of the {}.',
|
||||
'a tattoo of a {}.',
|
||||
'the embroidered {}.',
|
||||
'a photo of a hard to see {}.',
|
||||
'a bright photo of a {}.',
|
||||
'a photo of a clean {}.',
|
||||
'a photo of a dirty {}.',
|
||||
'a dark photo of the {}.',
|
||||
'a drawing of a {}.',
|
||||
'a photo of my {}.',
|
||||
'the plastic {}.',
|
||||
'a photo of the cool {}.',
|
||||
'a close-up photo of a {}.',
|
||||
'a black and white photo of the {}.',
|
||||
'a painting of the {}.',
|
||||
'a painting of a {}.',
|
||||
'a pixelated photo of the {}.',
|
||||
'a sculpture of the {}.',
|
||||
'a bright photo of the {}.',
|
||||
'a cropped photo of a {}.',
|
||||
'a plastic {}.',
|
||||
'a photo of the dirty {}.',
|
||||
'a jpeg corrupted photo of a {}.',
|
||||
'a blurry photo of the {}.',
|
||||
'a photo of the {}.',
|
||||
'a good photo of the {}.',
|
||||
'a rendering of the {}.',
|
||||
'a {} in a video game.',
|
||||
'a photo of one {}.',
|
||||
'a doodle of a {}.',
|
||||
'a close-up photo of the {}.',
|
||||
'a photo of a {}.',
|
||||
'the origami {}.',
|
||||
'the {} in a video game.',
|
||||
'a sketch of a {}.',
|
||||
'a doodle of the {}.',
|
||||
'a origami {}.',
|
||||
'a low resolution photo of a {}.',
|
||||
'the toy {}.',
|
||||
'a rendition of the {}.',
|
||||
'a photo of the clean {}.',
|
||||
'a photo of a large {}.',
|
||||
'a rendition of a {}.',
|
||||
'a photo of a nice {}.',
|
||||
'a photo of a weird {}.',
|
||||
'a blurry photo of a {}.',
|
||||
'a cartoon {}.',
|
||||
'art of a {}.',
|
||||
'a sketch of the {}.',
|
||||
'a embroidered {}.',
|
||||
'a pixelated photo of a {}.',
|
||||
'itap of the {}.',
|
||||
'a jpeg corrupted photo of the {}.',
|
||||
'a good photo of a {}.',
|
||||
'a plushie {}.',
|
||||
'a photo of the nice {}.',
|
||||
'a photo of the small {}.',
|
||||
'a photo of the weird {}.',
|
||||
'the cartoon {}.',
|
||||
'art of the {}.',
|
||||
'a drawing of the {}.',
|
||||
'a photo of the large {}.',
|
||||
'a black and white photo of a {}.',
|
||||
'the plushie {}.',
|
||||
'a dark photo of a {}.',
|
||||
'itap of a {}.',
|
||||
'graffiti of the {}.',
|
||||
'a toy {}.',
|
||||
'itap of my {}.',
|
||||
'a photo of a cool {}.',
|
||||
'a photo of a small {}.',
|
||||
'a tattoo of the {}.',
|
||||
]
|
||||
|
||||
|
||||
def zeroshot_classifier(classnames, templates,model):
|
||||
with torch.no_grad():
|
||||
zeroshot_weights = []
|
||||
for classname in classnames:
|
||||
texts = [template.format(classname) for template in templates] #format with class
|
||||
texts = clip.tokenize(texts).cuda() #tokenize
|
||||
class_embeddings = model.encode_text(texts) #embed with text encoder
|
||||
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
|
||||
class_embedding = class_embeddings.mean(dim=0)
|
||||
class_embedding /= class_embedding.norm()
|
||||
zeroshot_weights.append(class_embedding)
|
||||
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
|
||||
return zeroshot_weights
|
||||
|
||||
|
||||
def GetDt(classnames,model):
|
||||
text_features=zeroshot_classifier(classnames, imagenet_templates,model).t()
|
||||
|
||||
dt=text_features[0]-text_features[1]
|
||||
dt=dt.cpu().numpy()
|
||||
|
||||
|
||||
print(np.linalg.norm(dt))
|
||||
dt=dt/np.linalg.norm(dt)
|
||||
return dt
|
||||
|
||||
|
||||
def GetBoundary(fs3,dt,M,threshold):
|
||||
tmp=np.dot(fs3,dt)
|
||||
|
||||
ds_imp=copy.copy(tmp)
|
||||
select=np.abs(tmp)<threshold
|
||||
num_c=np.sum(~select)
|
||||
|
||||
|
||||
ds_imp[select]=0
|
||||
tmp=np.abs(ds_imp).max()
|
||||
ds_imp/=tmp
|
||||
|
||||
boundary_tmp2=SplitS(ds_imp,M,if_std=True)
|
||||
print('num of channels being manipulated:',num_c)
|
||||
return boundary_tmp2,num_c
|
||||
|
||||
#%%
|
||||
if __name__ == "__main__":
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model, preprocess = clip.load("ViT-B/32", device=device,jit=False)
|
||||
|
||||
# pls download the checkpoint from https://drive.google.com/file/d/1FlAb1rYa0r_--Zj_ML8e6shmaF28hQb5/view
|
||||
network_pkl='/cs/labs/danix/wuzongze/Gan_Manipulation/stylegan2/model/stylegan2-human-config-f.pkl'
|
||||
device = torch.device('cuda')
|
||||
M=Manipulator()
|
||||
M.device=device
|
||||
G=M.LoadModel(network_pkl,device)
|
||||
M.G=G
|
||||
M.SetGParameters()
|
||||
num_img=100_000
|
||||
M.GenerateS(num_img=num_img)
|
||||
M.GetCodeMS()
|
||||
np.set_printoptions(suppress=True)
|
||||
#%%
|
||||
file_path='./npy/human/'
|
||||
fs3=np.load(file_path+'fs3.npy')
|
||||
#%%
|
||||
img_indexs=np.arange(20)
|
||||
|
||||
dlatent_tmp=[tmp[img_indexs] for tmp in M.dlatents]
|
||||
M.num_images=len(img_indexs)
|
||||
#%%
|
||||
|
||||
paras=[
|
||||
['person', 'original', 0, 0],
|
||||
['woman', 'man', 0.2, 3],
|
||||
['person', 'person with T-shirt', 0.15, 4],
|
||||
['person', 'person with jeans', 0.15, 4],
|
||||
['person', 'person with jacket', 0.15, 4],
|
||||
]
|
||||
paras=np.array(paras)
|
||||
#%%
|
||||
|
||||
M.step=1
|
||||
|
||||
|
||||
imgs=[]
|
||||
all_b=[]
|
||||
for i in range(len(paras)):
|
||||
|
||||
neutral,target,beta,alpha=paras[i]
|
||||
beta=np.float32(beta)
|
||||
alpha=np.float32(alpha)
|
||||
M.alpha=[alpha]
|
||||
print()
|
||||
print(target)
|
||||
classnames=[target,neutral]
|
||||
dt=GetDt(classnames,model)
|
||||
boundary_tmp2,num_c=GetBoundary(fs3,dt,M,threshold=beta)
|
||||
all_b.append(boundary_tmp2)
|
||||
codes=M.MSCode(dlatent_tmp,boundary_tmp2)
|
||||
|
||||
out=M.GenerateImg(codes)
|
||||
imgs.append(out)
|
||||
|
||||
|
||||
imgs=np.concatenate(imgs,axis=1)
|
||||
M.step=imgs.shape[1]
|
||||
M.Vis('real','',imgs,colnames=list(paras[:,1]),rownames=img_indexs,viz_size=1024)
|
||||
|
||||
|
||||
|
||||
9
global_torch/dnnlib/__init__.py
Normal file
9
global_torch/dnnlib/__init__.py
Normal file
@ -0,0 +1,9 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
from .util import EasyDict, make_cache_dir_path
|
||||
477
global_torch/dnnlib/util.py
Normal file
477
global_torch/dnnlib/util.py
Normal file
@ -0,0 +1,477 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Miscellaneous utility classes and functions."""
|
||||
|
||||
import ctypes
|
||||
import fnmatch
|
||||
import importlib
|
||||
import inspect
|
||||
import numpy as np
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import types
|
||||
import io
|
||||
import pickle
|
||||
import re
|
||||
import requests
|
||||
import html
|
||||
import hashlib
|
||||
import glob
|
||||
import tempfile
|
||||
import urllib
|
||||
import urllib.request
|
||||
import uuid
|
||||
|
||||
from distutils.util import strtobool
|
||||
from typing import Any, List, Tuple, Union
|
||||
|
||||
|
||||
# Util classes
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
class EasyDict(dict):
|
||||
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
try:
|
||||
return self[name]
|
||||
except KeyError:
|
||||
raise AttributeError(name)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
self[name] = value
|
||||
|
||||
def __delattr__(self, name: str) -> None:
|
||||
del self[name]
|
||||
|
||||
|
||||
class Logger(object):
|
||||
"""Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
|
||||
|
||||
def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
|
||||
self.file = None
|
||||
|
||||
if file_name is not None:
|
||||
self.file = open(file_name, file_mode)
|
||||
|
||||
self.should_flush = should_flush
|
||||
self.stdout = sys.stdout
|
||||
self.stderr = sys.stderr
|
||||
|
||||
sys.stdout = self
|
||||
sys.stderr = self
|
||||
|
||||
def __enter__(self) -> "Logger":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def write(self, text: Union[str, bytes]) -> None:
|
||||
"""Write text to stdout (and a file) and optionally flush."""
|
||||
if isinstance(text, bytes):
|
||||
text = text.decode()
|
||||
if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
|
||||
return
|
||||
|
||||
if self.file is not None:
|
||||
self.file.write(text)
|
||||
|
||||
self.stdout.write(text)
|
||||
|
||||
if self.should_flush:
|
||||
self.flush()
|
||||
|
||||
def flush(self) -> None:
|
||||
"""Flush written text to both stdout and a file, if open."""
|
||||
if self.file is not None:
|
||||
self.file.flush()
|
||||
|
||||
self.stdout.flush()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Flush, close possible files, and remove stdout/stderr mirroring."""
|
||||
self.flush()
|
||||
|
||||
# if using multiple loggers, prevent closing in wrong order
|
||||
if sys.stdout is self:
|
||||
sys.stdout = self.stdout
|
||||
if sys.stderr is self:
|
||||
sys.stderr = self.stderr
|
||||
|
||||
if self.file is not None:
|
||||
self.file.close()
|
||||
self.file = None
|
||||
|
||||
|
||||
# Cache directories
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
_dnnlib_cache_dir = None
|
||||
|
||||
def set_cache_dir(path: str) -> None:
|
||||
global _dnnlib_cache_dir
|
||||
_dnnlib_cache_dir = path
|
||||
|
||||
def make_cache_dir_path(*paths: str) -> str:
|
||||
if _dnnlib_cache_dir is not None:
|
||||
return os.path.join(_dnnlib_cache_dir, *paths)
|
||||
if 'DNNLIB_CACHE_DIR' in os.environ:
|
||||
return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
|
||||
if 'HOME' in os.environ:
|
||||
return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
|
||||
if 'USERPROFILE' in os.environ:
|
||||
return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
|
||||
return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
|
||||
|
||||
# Small util functions
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def format_time(seconds: Union[int, float]) -> str:
|
||||
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
||||
s = int(np.rint(seconds))
|
||||
|
||||
if s < 60:
|
||||
return "{0}s".format(s)
|
||||
elif s < 60 * 60:
|
||||
return "{0}m {1:02}s".format(s // 60, s % 60)
|
||||
elif s < 24 * 60 * 60:
|
||||
return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
|
||||
else:
|
||||
return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
|
||||
|
||||
|
||||
def ask_yes_no(question: str) -> bool:
|
||||
"""Ask the user the question until the user inputs a valid answer."""
|
||||
while True:
|
||||
try:
|
||||
print("{0} [y/n]".format(question))
|
||||
return strtobool(input().lower())
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
def tuple_product(t: Tuple) -> Any:
|
||||
"""Calculate the product of the tuple elements."""
|
||||
result = 1
|
||||
|
||||
for v in t:
|
||||
result *= v
|
||||
|
||||
return result
|
||||
|
||||
|
||||
_str_to_ctype = {
|
||||
"uint8": ctypes.c_ubyte,
|
||||
"uint16": ctypes.c_uint16,
|
||||
"uint32": ctypes.c_uint32,
|
||||
"uint64": ctypes.c_uint64,
|
||||
"int8": ctypes.c_byte,
|
||||
"int16": ctypes.c_int16,
|
||||
"int32": ctypes.c_int32,
|
||||
"int64": ctypes.c_int64,
|
||||
"float32": ctypes.c_float,
|
||||
"float64": ctypes.c_double
|
||||
}
|
||||
|
||||
|
||||
def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
|
||||
"""Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
|
||||
type_str = None
|
||||
|
||||
if isinstance(type_obj, str):
|
||||
type_str = type_obj
|
||||
elif hasattr(type_obj, "__name__"):
|
||||
type_str = type_obj.__name__
|
||||
elif hasattr(type_obj, "name"):
|
||||
type_str = type_obj.name
|
||||
else:
|
||||
raise RuntimeError("Cannot infer type name from input")
|
||||
|
||||
assert type_str in _str_to_ctype.keys()
|
||||
|
||||
my_dtype = np.dtype(type_str)
|
||||
my_ctype = _str_to_ctype[type_str]
|
||||
|
||||
assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
|
||||
|
||||
return my_dtype, my_ctype
|
||||
|
||||
|
||||
def is_pickleable(obj: Any) -> bool:
|
||||
try:
|
||||
with io.BytesIO() as stream:
|
||||
pickle.dump(obj, stream)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
# Functionality to import modules/objects by name, and call functions by name
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
|
||||
"""Searches for the underlying module behind the name to some python object.
|
||||
Returns the module and the object name (original name with module part removed)."""
|
||||
|
||||
# allow convenience shorthands, substitute them by full names
|
||||
obj_name = re.sub("^np.", "numpy.", obj_name)
|
||||
obj_name = re.sub("^tf.", "tensorflow.", obj_name)
|
||||
|
||||
# list alternatives for (module_name, local_obj_name)
|
||||
parts = obj_name.split(".")
|
||||
name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
|
||||
|
||||
# try each alternative in turn
|
||||
for module_name, local_obj_name in name_pairs:
|
||||
try:
|
||||
module = importlib.import_module(module_name) # may raise ImportError
|
||||
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
||||
return module, local_obj_name
|
||||
except:
|
||||
pass
|
||||
|
||||
# maybe some of the modules themselves contain errors?
|
||||
for module_name, _local_obj_name in name_pairs:
|
||||
try:
|
||||
importlib.import_module(module_name) # may raise ImportError
|
||||
except ImportError:
|
||||
if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
|
||||
raise
|
||||
|
||||
# maybe the requested attribute is missing?
|
||||
for module_name, local_obj_name in name_pairs:
|
||||
try:
|
||||
module = importlib.import_module(module_name) # may raise ImportError
|
||||
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# we are out of luck, but we have no idea why
|
||||
raise ImportError(obj_name)
|
||||
|
||||
|
||||
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
|
||||
"""Traverses the object name and returns the last (rightmost) python object."""
|
||||
if obj_name == '':
|
||||
return module
|
||||
obj = module
|
||||
for part in obj_name.split("."):
|
||||
obj = getattr(obj, part)
|
||||
return obj
|
||||
|
||||
|
||||
def get_obj_by_name(name: str) -> Any:
|
||||
"""Finds the python object with the given name."""
|
||||
module, obj_name = get_module_from_obj_name(name)
|
||||
return get_obj_from_module(module, obj_name)
|
||||
|
||||
|
||||
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
|
||||
"""Finds the python object with the given name and calls it as a function."""
|
||||
assert func_name is not None
|
||||
func_obj = get_obj_by_name(func_name)
|
||||
assert callable(func_obj)
|
||||
return func_obj(*args, **kwargs)
|
||||
|
||||
|
||||
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
|
||||
"""Finds the python class with the given name and constructs it with the given arguments."""
|
||||
return call_func_by_name(*args, func_name=class_name, **kwargs)
|
||||
|
||||
|
||||
def get_module_dir_by_obj_name(obj_name: str) -> str:
|
||||
"""Get the directory path of the module containing the given object name."""
|
||||
module, _ = get_module_from_obj_name(obj_name)
|
||||
return os.path.dirname(inspect.getfile(module))
|
||||
|
||||
|
||||
def is_top_level_function(obj: Any) -> bool:
|
||||
"""Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
|
||||
return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
|
||||
|
||||
|
||||
def get_top_level_function_name(obj: Any) -> str:
|
||||
"""Return the fully-qualified name of a top-level function."""
|
||||
assert is_top_level_function(obj)
|
||||
module = obj.__module__
|
||||
if module == '__main__':
|
||||
module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
|
||||
return module + "." + obj.__name__
|
||||
|
||||
|
||||
# File system helpers
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
|
||||
"""List all files recursively in a given directory while ignoring given file and directory names.
|
||||
Returns list of tuples containing both absolute and relative paths."""
|
||||
assert os.path.isdir(dir_path)
|
||||
base_name = os.path.basename(os.path.normpath(dir_path))
|
||||
|
||||
if ignores is None:
|
||||
ignores = []
|
||||
|
||||
result = []
|
||||
|
||||
for root, dirs, files in os.walk(dir_path, topdown=True):
|
||||
for ignore_ in ignores:
|
||||
dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
|
||||
|
||||
# dirs need to be edited in-place
|
||||
for d in dirs_to_remove:
|
||||
dirs.remove(d)
|
||||
|
||||
files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
|
||||
|
||||
absolute_paths = [os.path.join(root, f) for f in files]
|
||||
relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
|
||||
|
||||
if add_base_to_relative:
|
||||
relative_paths = [os.path.join(base_name, p) for p in relative_paths]
|
||||
|
||||
assert len(absolute_paths) == len(relative_paths)
|
||||
result += zip(absolute_paths, relative_paths)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
|
||||
"""Takes in a list of tuples of (src, dst) paths and copies files.
|
||||
Will create all necessary directories."""
|
||||
for file in files:
|
||||
target_dir_name = os.path.dirname(file[1])
|
||||
|
||||
# will create all intermediate-level directories
|
||||
if not os.path.exists(target_dir_name):
|
||||
os.makedirs(target_dir_name)
|
||||
|
||||
shutil.copyfile(file[0], file[1])
|
||||
|
||||
|
||||
# URL helpers
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
|
||||
"""Determine whether the given object is a valid URL string."""
|
||||
if not isinstance(obj, str) or not "://" in obj:
|
||||
return False
|
||||
if allow_file_urls and obj.startswith('file://'):
|
||||
return True
|
||||
try:
|
||||
res = requests.compat.urlparse(obj)
|
||||
if not res.scheme or not res.netloc or not "." in res.netloc:
|
||||
return False
|
||||
res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
|
||||
if not res.scheme or not res.netloc or not "." in res.netloc:
|
||||
return False
|
||||
except:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
|
||||
"""Download the given URL and return a binary-mode file object to access the data."""
|
||||
assert num_attempts >= 1
|
||||
assert not (return_filename and (not cache))
|
||||
|
||||
# Doesn't look like an URL scheme so interpret it as a local filename.
|
||||
if not re.match('^[a-z]+://', url):
|
||||
return url if return_filename else open(url, "rb")
|
||||
|
||||
# Handle file URLs. This code handles unusual file:// patterns that
|
||||
# arise on Windows:
|
||||
#
|
||||
# file:///c:/foo.txt
|
||||
#
|
||||
# which would translate to a local '/c:/foo.txt' filename that's
|
||||
# invalid. Drop the forward slash for such pathnames.
|
||||
#
|
||||
# If you touch this code path, you should test it on both Linux and
|
||||
# Windows.
|
||||
#
|
||||
# Some internet resources suggest using urllib.request.url2pathname() but
|
||||
# but that converts forward slashes to backslashes and this causes
|
||||
# its own set of problems.
|
||||
if url.startswith('file://'):
|
||||
filename = urllib.parse.urlparse(url).path
|
||||
if re.match(r'^/[a-zA-Z]:', filename):
|
||||
filename = filename[1:]
|
||||
return filename if return_filename else open(filename, "rb")
|
||||
|
||||
assert is_url(url)
|
||||
|
||||
# Lookup from cache.
|
||||
if cache_dir is None:
|
||||
cache_dir = make_cache_dir_path('downloads')
|
||||
|
||||
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
|
||||
if cache:
|
||||
cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
|
||||
if len(cache_files) == 1:
|
||||
filename = cache_files[0]
|
||||
return filename if return_filename else open(filename, "rb")
|
||||
|
||||
# Download.
|
||||
url_name = None
|
||||
url_data = None
|
||||
with requests.Session() as session:
|
||||
if verbose:
|
||||
print("Downloading %s ..." % url, end="", flush=True)
|
||||
for attempts_left in reversed(range(num_attempts)):
|
||||
try:
|
||||
with session.get(url) as res:
|
||||
res.raise_for_status()
|
||||
if len(res.content) == 0:
|
||||
raise IOError("No data received")
|
||||
|
||||
if len(res.content) < 8192:
|
||||
content_str = res.content.decode("utf-8")
|
||||
if "download_warning" in res.headers.get("Set-Cookie", ""):
|
||||
links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
|
||||
if len(links) == 1:
|
||||
url = requests.compat.urljoin(url, links[0])
|
||||
raise IOError("Google Drive virus checker nag")
|
||||
if "Google Drive - Quota exceeded" in content_str:
|
||||
raise IOError("Google Drive download quota exceeded -- please try again later")
|
||||
|
||||
match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
|
||||
url_name = match[1] if match else url
|
||||
url_data = res.content
|
||||
if verbose:
|
||||
print(" done")
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except:
|
||||
if not attempts_left:
|
||||
if verbose:
|
||||
print(" failed")
|
||||
raise
|
||||
if verbose:
|
||||
print(".", end="", flush=True)
|
||||
|
||||
# Save to cache.
|
||||
if cache:
|
||||
safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
|
||||
cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
|
||||
temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
with open(temp_file, "wb") as f:
|
||||
f.write(url_data)
|
||||
os.replace(temp_file, cache_file) # atomic
|
||||
if return_filename:
|
||||
return cache_file
|
||||
|
||||
# Return data as file object.
|
||||
assert not return_filename
|
||||
return io.BytesIO(url_data)
|
||||
99
global_torch/html/[6]_501_c.html
Normal file
99
global_torch/html/[6]_501_c.html
Normal file
File diff suppressed because one or more lines are too long
223
global_torch/html/real_.html
Normal file
223
global_torch/html/real_.html
Normal file
File diff suppressed because one or more lines are too long
326
global_torch/legacy.py
Normal file
326
global_torch/legacy.py
Normal file
@ -0,0 +1,326 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
import click
|
||||
import pickle
|
||||
import re
|
||||
import copy
|
||||
import numpy as np
|
||||
import torch
|
||||
import dnnlib
|
||||
from torch_utils import misc
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def load_network_pkl(f, force_fp16=False):
|
||||
data = _LegacyUnpickler(f).load()
|
||||
|
||||
# Legacy TensorFlow pickle => convert.
|
||||
if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
|
||||
tf_G, tf_D, tf_Gs = data
|
||||
G = convert_tf_generator(tf_G)
|
||||
D = convert_tf_discriminator(tf_D)
|
||||
G_ema = convert_tf_generator(tf_Gs)
|
||||
data = dict(G=G, D=D, G_ema=G_ema)
|
||||
|
||||
# Add missing fields.
|
||||
if 'training_set_kwargs' not in data:
|
||||
data['training_set_kwargs'] = None
|
||||
if 'augment_pipe' not in data:
|
||||
data['augment_pipe'] = None
|
||||
|
||||
# Validate contents.
|
||||
assert isinstance(data['G'], torch.nn.Module)
|
||||
assert isinstance(data['D'], torch.nn.Module)
|
||||
assert isinstance(data['G_ema'], torch.nn.Module)
|
||||
assert isinstance(data['training_set_kwargs'], (dict, type(None)))
|
||||
assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
|
||||
|
||||
# Force FP16.
|
||||
if force_fp16:
|
||||
for key in ['G', 'D', 'G_ema']:
|
||||
old = data[key]
|
||||
kwargs = copy.deepcopy(old.init_kwargs)
|
||||
if key.startswith('G'):
|
||||
kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {}))
|
||||
kwargs.synthesis_kwargs.num_fp16_res = 4
|
||||
kwargs.synthesis_kwargs.conv_clamp = 256
|
||||
if key.startswith('D'):
|
||||
kwargs.num_fp16_res = 4
|
||||
kwargs.conv_clamp = 256
|
||||
if kwargs != old.init_kwargs:
|
||||
new = type(old)(**kwargs).eval().requires_grad_(False)
|
||||
misc.copy_params_and_buffers(old, new, require_all=True)
|
||||
data[key] = new
|
||||
return data
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class _TFNetworkStub(dnnlib.EasyDict):
|
||||
pass
|
||||
|
||||
class _LegacyUnpickler(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
if module == 'dnnlib.tflib.network' and name == 'Network':
|
||||
return _TFNetworkStub
|
||||
return super().find_class(module, name)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _collect_tf_params(tf_net):
|
||||
# pylint: disable=protected-access
|
||||
tf_params = dict()
|
||||
def recurse(prefix, tf_net):
|
||||
for name, value in tf_net.variables:
|
||||
tf_params[prefix + name] = value
|
||||
for name, comp in tf_net.components.items():
|
||||
recurse(prefix + name + '/', comp)
|
||||
recurse('', tf_net)
|
||||
return tf_params
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _populate_module_params(module, *patterns):
|
||||
for name, tensor in misc.named_params_and_buffers(module):
|
||||
found = False
|
||||
value = None
|
||||
for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
|
||||
match = re.fullmatch(pattern, name)
|
||||
if match:
|
||||
found = True
|
||||
if value_fn is not None:
|
||||
value = value_fn(*match.groups())
|
||||
break
|
||||
try:
|
||||
assert found
|
||||
if value is not None:
|
||||
tensor.copy_(torch.from_numpy(np.array(value)))
|
||||
except:
|
||||
print(name, list(tensor.shape))
|
||||
raise
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def convert_tf_generator(tf_G):
|
||||
if tf_G.version < 4:
|
||||
raise ValueError('TensorFlow pickle version too low')
|
||||
|
||||
# Collect kwargs.
|
||||
tf_kwargs = tf_G.static_kwargs
|
||||
known_kwargs = set()
|
||||
def kwarg(tf_name, default=None, none=None):
|
||||
known_kwargs.add(tf_name)
|
||||
val = tf_kwargs.get(tf_name, default)
|
||||
return val if val is not None else none
|
||||
|
||||
# Convert kwargs.
|
||||
kwargs = dnnlib.EasyDict(
|
||||
z_dim = kwarg('latent_size', 512),
|
||||
c_dim = kwarg('label_size', 0),
|
||||
w_dim = kwarg('dlatent_size', 512),
|
||||
img_resolution = kwarg('resolution', 1024),
|
||||
img_channels = kwarg('num_channels', 3),
|
||||
mapping_kwargs = dnnlib.EasyDict(
|
||||
num_layers = kwarg('mapping_layers', 8),
|
||||
embed_features = kwarg('label_fmaps', None),
|
||||
layer_features = kwarg('mapping_fmaps', None),
|
||||
activation = kwarg('mapping_nonlinearity', 'lrelu'),
|
||||
lr_multiplier = kwarg('mapping_lrmul', 0.01),
|
||||
w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
|
||||
),
|
||||
synthesis_kwargs = dnnlib.EasyDict(
|
||||
channel_base = kwarg('fmap_base', 16384) * 2,
|
||||
channel_max = kwarg('fmap_max', 512),
|
||||
num_fp16_res = kwarg('num_fp16_res', 0),
|
||||
conv_clamp = kwarg('conv_clamp', None),
|
||||
architecture = kwarg('architecture', 'skip'),
|
||||
resample_filter = kwarg('resample_kernel', [1,3,3,1]),
|
||||
use_noise = kwarg('use_noise', True),
|
||||
activation = kwarg('nonlinearity', 'lrelu'),
|
||||
),
|
||||
)
|
||||
|
||||
# Check for unknown kwargs.
|
||||
kwarg('truncation_psi')
|
||||
kwarg('truncation_cutoff')
|
||||
kwarg('style_mixing_prob')
|
||||
kwarg('structure')
|
||||
if 'resolution_w' in tf_kwargs:
|
||||
tf_kwargs.pop('resolution_w', None)
|
||||
tf_kwargs.pop('resolution_h', None)
|
||||
unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
|
||||
if len(unknown_kwargs) > 0:
|
||||
raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
|
||||
|
||||
# Collect params.
|
||||
tf_params = _collect_tf_params(tf_G)
|
||||
for name, value in list(tf_params.items()):
|
||||
match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
|
||||
if match:
|
||||
r = kwargs.img_resolution // (2 ** int(match.group(1)))
|
||||
tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
|
||||
kwargs.synthesis.kwargs.architecture = 'orig'
|
||||
#for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
|
||||
|
||||
# Convert params.
|
||||
from training import networks
|
||||
G = networks.Generator(**kwargs).eval().requires_grad_(False)
|
||||
# pylint: disable=unnecessary-lambda
|
||||
_populate_module_params(G,
|
||||
r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
|
||||
r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
|
||||
r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
|
||||
r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
|
||||
r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
|
||||
r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
|
||||
r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
|
||||
r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
|
||||
r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
|
||||
r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
|
||||
r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
|
||||
r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
|
||||
r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
|
||||
r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
|
||||
r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
|
||||
r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
|
||||
r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
|
||||
r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
|
||||
r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
|
||||
r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
|
||||
r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
|
||||
r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
|
||||
r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
|
||||
r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
|
||||
r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
|
||||
r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
|
||||
r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
|
||||
r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
|
||||
r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
|
||||
r'.*\.resample_filter', None,
|
||||
)
|
||||
return G
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def convert_tf_discriminator(tf_D):
|
||||
if tf_D.version < 4:
|
||||
raise ValueError('TensorFlow pickle version too low')
|
||||
|
||||
# Collect kwargs.
|
||||
tf_kwargs = tf_D.static_kwargs
|
||||
known_kwargs = set()
|
||||
def kwarg(tf_name, default=None):
|
||||
known_kwargs.add(tf_name)
|
||||
return tf_kwargs.get(tf_name, default)
|
||||
|
||||
# Convert kwargs.
|
||||
kwargs = dnnlib.EasyDict(
|
||||
c_dim = kwarg('label_size', 0),
|
||||
img_resolution = kwarg('resolution', 1024),
|
||||
img_channels = kwarg('num_channels', 3),
|
||||
architecture = kwarg('architecture', 'resnet'),
|
||||
channel_base = kwarg('fmap_base', 16384) * 2,
|
||||
channel_max = kwarg('fmap_max', 512),
|
||||
num_fp16_res = kwarg('num_fp16_res', 0),
|
||||
conv_clamp = kwarg('conv_clamp', None),
|
||||
cmap_dim = kwarg('mapping_fmaps', None),
|
||||
block_kwargs = dnnlib.EasyDict(
|
||||
activation = kwarg('nonlinearity', 'lrelu'),
|
||||
resample_filter = kwarg('resample_kernel', [1,3,3,1]),
|
||||
freeze_layers = kwarg('freeze_layers', 0),
|
||||
),
|
||||
mapping_kwargs = dnnlib.EasyDict(
|
||||
num_layers = kwarg('mapping_layers', 0),
|
||||
embed_features = kwarg('mapping_fmaps', None),
|
||||
layer_features = kwarg('mapping_fmaps', None),
|
||||
activation = kwarg('nonlinearity', 'lrelu'),
|
||||
lr_multiplier = kwarg('mapping_lrmul', 0.1),
|
||||
),
|
||||
epilogue_kwargs = dnnlib.EasyDict(
|
||||
mbstd_group_size = kwarg('mbstd_group_size', None),
|
||||
mbstd_num_channels = kwarg('mbstd_num_features', 1),
|
||||
activation = kwarg('nonlinearity', 'lrelu'),
|
||||
),
|
||||
)
|
||||
|
||||
# Check for unknown kwargs.
|
||||
kwarg('structure')
|
||||
if 'resolution_w' in tf_kwargs:
|
||||
tf_kwargs.pop('resolution_w', None)
|
||||
tf_kwargs.pop('resolution_h', None)
|
||||
unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
|
||||
if len(unknown_kwargs) > 0:
|
||||
raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
|
||||
|
||||
# Collect params.
|
||||
tf_params = _collect_tf_params(tf_D)
|
||||
for name, value in list(tf_params.items()):
|
||||
match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
|
||||
if match:
|
||||
r = kwargs.img_resolution // (2 ** int(match.group(1)))
|
||||
tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
|
||||
kwargs.architecture = 'orig'
|
||||
#for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
|
||||
|
||||
# Convert params.
|
||||
from training import networks
|
||||
D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
|
||||
# pylint: disable=unnecessary-lambda
|
||||
_populate_module_params(D,
|
||||
r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
|
||||
r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
|
||||
r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
|
||||
r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
|
||||
r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
|
||||
r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
|
||||
r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
|
||||
r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
|
||||
r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
|
||||
r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
|
||||
r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
|
||||
r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
|
||||
r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
|
||||
r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
|
||||
r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
|
||||
r'.*\.resample_filter', None,
|
||||
)
|
||||
return D
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@click.command()
|
||||
@click.option('--source', help='Input pickle', required=True, metavar='PATH')
|
||||
@click.option('--dest', help='Output pickle', required=True, metavar='PATH')
|
||||
@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
|
||||
def convert_network_pickle(source, dest, force_fp16):
|
||||
"""Convert legacy network pickle into the native PyTorch format.
|
||||
|
||||
The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
|
||||
It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
|
||||
|
||||
Example:
|
||||
|
||||
\b
|
||||
python legacy.py \\
|
||||
--source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
|
||||
--dest=stylegan2-cat-config-f.pkl
|
||||
"""
|
||||
print(f'Loading "{source}"...')
|
||||
with dnnlib.util.open_url(source) as f:
|
||||
data = load_network_pkl(f, force_fp16=force_fp16)
|
||||
print(f'Saving "{dest}"...')
|
||||
with open(dest, 'wb') as f:
|
||||
pickle.dump(data, f)
|
||||
print('Done.')
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
convert_network_pickle() # pylint: disable=no-value-for-parameter
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
383
global_torch/manipulate.py
Normal file
383
global_torch/manipulate.py
Normal file
@ -0,0 +1,383 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Created on Mon Jul 19 21:03:58 2021
|
||||
|
||||
@author: wuzongze
|
||||
"""
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
import copy
|
||||
import os
|
||||
from time import perf_counter
|
||||
|
||||
import click
|
||||
import imageio
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
|
||||
import dnnlib
|
||||
import legacy
|
||||
import pickle
|
||||
from visualizer import HtmlPageVisualizer
|
||||
|
||||
from torch_utils import misc
|
||||
import types
|
||||
from training.networks import SynthesisNetwork,SynthesisBlock,SynthesisLayer,ToRGBLayer
|
||||
|
||||
|
||||
def change_style_code(codes, layer, channel, step):
|
||||
codes[layer][:, channel] += step
|
||||
return codes
|
||||
|
||||
def Vis(bname,suffix,out,rownames=None,colnames=None,save_path=None,viz_size=256):
|
||||
|
||||
if save_path is None:
|
||||
save_path='./html/'
|
||||
|
||||
|
||||
num_images=out.shape[0]
|
||||
step=out.shape[1]
|
||||
|
||||
if colnames is None:
|
||||
colnames=[f'Step {i:02d}' for i in range(1, step + 1)]
|
||||
if rownames is None:
|
||||
rownames=[str(i) for i in range(num_images)]
|
||||
|
||||
|
||||
visualizer = HtmlPageVisualizer(
|
||||
num_rows=num_images, num_cols=step + 1, viz_size=viz_size)
|
||||
visualizer.set_headers(
|
||||
['Name'] +colnames)
|
||||
|
||||
for i in range(num_images):
|
||||
visualizer.set_cell(i, 0, text=rownames[i])
|
||||
|
||||
for i in range(num_images):
|
||||
for k in range(step):
|
||||
image=out[i,k,:,:,:]
|
||||
visualizer.set_cell(i, 1+k, image=image)
|
||||
|
||||
visualizer.save(save_path+bname+'_'+suffix+'.html')
|
||||
|
||||
def LoadModel(network_pkl,device):
|
||||
with dnnlib.util.open_url(network_pkl) as fp:
|
||||
G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore
|
||||
|
||||
G.synthesis.forward=types.MethodType(SynthesisNetwork.forward,G.synthesis)
|
||||
G.synthesis.W2S=types.MethodType(SynthesisNetwork.W2S,G.synthesis)
|
||||
|
||||
for res in G.synthesis.block_resolutions:
|
||||
block = getattr(G.synthesis, f'b{res}')
|
||||
# print(block)
|
||||
block.forward=types.MethodType(SynthesisBlock.forward,block)
|
||||
|
||||
if res!=4:
|
||||
layer=block.conv0
|
||||
layer.forward=types.MethodType(SynthesisLayer.forward,layer)
|
||||
layer.name='conv0_resolution_'+str(res)
|
||||
|
||||
layer=block.conv1
|
||||
layer.forward=types.MethodType(SynthesisLayer.forward,layer)
|
||||
layer.name='conv1_resolution_'+str(res)
|
||||
|
||||
layer=block.torgb
|
||||
layer.forward=types.MethodType(ToRGBLayer.forward,layer)
|
||||
layer.name='toRGB_resolution_'+str(res)
|
||||
|
||||
|
||||
return G
|
||||
|
||||
|
||||
def S2List(encoded_styles):
|
||||
all_s=[]
|
||||
for name in encoded_styles.keys():
|
||||
tmp=encoded_styles[name].cpu().numpy()
|
||||
all_s.append(tmp)
|
||||
return all_s
|
||||
|
||||
|
||||
|
||||
class Manipulator():
|
||||
def __init__(self,dataset_name='ffhq'):
|
||||
|
||||
self.alpha=[0] #manipulation strength
|
||||
self.num_images=10
|
||||
self.img_index=0 #which image to start
|
||||
# self.viz_size=256
|
||||
self.manipulate_layers=None #which layer to manipulate, list
|
||||
self.truncation_psi=0.7
|
||||
self.truncation_cutoff=8
|
||||
|
||||
# self.G=LoadModel(self.model_path,self.model_name)
|
||||
|
||||
self.LoadModel=LoadModel
|
||||
self.Vis=Vis
|
||||
self.S2List=S2List
|
||||
|
||||
fmaps=[512, 512, 512, 512, 512, 256, 128, 64, 32]
|
||||
self.fmaps=np.repeat(fmaps,3)
|
||||
|
||||
|
||||
def GetSName(self):
|
||||
s_names=[]
|
||||
for res in self.G.synthesis.block_resolutions:
|
||||
if res==4:
|
||||
tmp=f'conv1_resolution_{res}'
|
||||
s_names.append(tmp)
|
||||
|
||||
tmp=f'toRGB_resolution_{res}'
|
||||
s_names.append(tmp)
|
||||
else:
|
||||
tmp=f'conv0_resolution_{res}'
|
||||
s_names.append(tmp)
|
||||
|
||||
tmp=f'conv1_resolution_{res}'
|
||||
s_names.append(tmp)
|
||||
|
||||
tmp=f'toRGB_resolution_{res}'
|
||||
s_names.append(tmp)
|
||||
|
||||
return s_names
|
||||
|
||||
def SL2D(self,tmp_code):
|
||||
encoded_styles={}
|
||||
for i in range(len(self.s_names)):
|
||||
encoded_styles[self.s_names[i]]=torch.from_numpy(tmp_code[i]).to(self.device)
|
||||
|
||||
return encoded_styles
|
||||
|
||||
|
||||
|
||||
def GenerateS(self,num_img=100):
|
||||
seed=5
|
||||
with torch.no_grad():
|
||||
z = torch.from_numpy(np.random.RandomState(seed).randn(num_img, self.G.z_dim)).to(self.device)
|
||||
ws = self.G.mapping(z=z,c=None,truncation_psi=self.truncation_psi,truncation_cutoff=self.truncation_cutoff)
|
||||
encoded_styles=self.G.synthesis.W2S(ws)
|
||||
# encoded_styles=encoded_styles.cpu().numpy()
|
||||
|
||||
self.dlatents=S2List(encoded_styles)
|
||||
|
||||
def GenerateImg(self,codes):
|
||||
|
||||
num_images,step=codes[0].shape[:2]
|
||||
out=np.zeros((num_images,step,self.img_size,self.img_size,3),dtype='uint8')
|
||||
for i in range(num_images):
|
||||
for k in range(step):
|
||||
|
||||
tmp_code=[]
|
||||
for m in range(len(self.s_names)):
|
||||
tmp=codes[m][i,k][None,:]
|
||||
tmp_code.append(tmp)
|
||||
|
||||
encoded_styles=self.SL2D(tmp_code)
|
||||
|
||||
with torch.no_grad():
|
||||
img = self.G.synthesis(None, encoded_styles=encoded_styles,noise_mode='const')
|
||||
img = (img + 1) * (255/2)
|
||||
img = img.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
|
||||
|
||||
|
||||
|
||||
if img.shape[1]==img.shape[0]:
|
||||
out[i,k,:,:,:]=img
|
||||
else:
|
||||
tmp=img.shape[1]
|
||||
tmp1=int((img.shape[0]-tmp)/2)
|
||||
out[i,k,:,tmp1:tmp1+tmp,:]=img
|
||||
return out
|
||||
|
||||
def ShowImg(self,num_img=10):
|
||||
|
||||
codes=[]
|
||||
for i in range(len(self.dlatents)):
|
||||
# print(i)
|
||||
tmp=self.dlatents[i][:num_img,None,:]
|
||||
codes.append(tmp)
|
||||
out=self.GenerateImg(codes)
|
||||
return out
|
||||
|
||||
def SetGParameters(self):
|
||||
self.num_layers=self.G.synthesis.num_ws
|
||||
self.img_size=self.G.synthesis.img_resolution
|
||||
self.s_names=self.GetSName()
|
||||
|
||||
self.img_size=self.G.synthesis.block_resolutions[-1]
|
||||
|
||||
self.mindexs=[0, 2, 3, 5, 6, 8, 9, 11, 12, 14, 15, 17, 18, 20, 21,23,24]
|
||||
|
||||
|
||||
|
||||
def MSCode(self,dlatent_tmp,boundary_tmp):
|
||||
|
||||
step=len(self.alpha)
|
||||
dlatent_tmp1=[tmp.reshape((self.num_images,-1)) for tmp in dlatent_tmp]
|
||||
dlatent_tmp2=[np.tile(tmp[:,None],(1,step,1)) for tmp in dlatent_tmp1] # (10, 7, 512)
|
||||
|
||||
l=np.array(self.alpha)
|
||||
l=l.reshape(
|
||||
[step if axis == 1 else 1 for axis in range(dlatent_tmp2[0].ndim)])
|
||||
|
||||
if type(self.manipulate_layers)==int:
|
||||
tmp=[self.manipulate_layers]
|
||||
elif type(self.manipulate_layers)==list:
|
||||
tmp=self.manipulate_layers
|
||||
elif self.manipulate_layers is None:
|
||||
tmp=np.arange(len(boundary_tmp))
|
||||
else:
|
||||
raise ValueError('manipulate_layers is wrong')
|
||||
|
||||
for i in tmp:
|
||||
dlatent_tmp2[i]+=l*boundary_tmp[i]
|
||||
|
||||
codes=[]
|
||||
for i in range(len(dlatent_tmp2)):
|
||||
tmp=list(dlatent_tmp[i].shape)
|
||||
tmp.insert(1,step)
|
||||
codes.append(dlatent_tmp2[i].reshape(tmp))
|
||||
return codes
|
||||
|
||||
|
||||
def EditOne(self,bname,dlatent_tmp=None):
|
||||
if dlatent_tmp==None:
|
||||
dlatent_tmp=[tmp[self.img_index:(self.img_index+self.num_images)] for tmp in self.dlatents]
|
||||
|
||||
boundary_tmp=[]
|
||||
for i in range(len(self.boundary)):
|
||||
tmp=self.boundary[i]
|
||||
if len(tmp)<=bname:
|
||||
boundary_tmp.append([])
|
||||
else:
|
||||
boundary_tmp.append(tmp[bname])
|
||||
|
||||
codes=self.MSCode(dlatent_tmp,boundary_tmp)
|
||||
|
||||
out=self.GenerateImg(codes)
|
||||
return codes,out
|
||||
|
||||
def EditOneC(self,cindex,dlatent_tmp=None):
|
||||
if dlatent_tmp==None:
|
||||
dlatent_tmp=[tmp[self.img_index:(self.img_index+self.num_images)] for tmp in self.dlatents]
|
||||
|
||||
boundary_tmp=[[] for i in range(len(self.dlatents))]
|
||||
|
||||
#'only manipulate 1 layer and one channel'
|
||||
assert len(self.manipulate_layers)==1
|
||||
|
||||
ml=self.manipulate_layers[0]
|
||||
tmp=dlatent_tmp[ml].shape[1] #ada
|
||||
tmp1=np.zeros(tmp)
|
||||
tmp1[cindex]=self.code_std[ml][cindex] #1
|
||||
boundary_tmp[ml]=tmp1
|
||||
|
||||
codes=self.MSCode(dlatent_tmp,boundary_tmp)
|
||||
out=self.GenerateImg(codes)
|
||||
return codes,out
|
||||
|
||||
def GetFindex(self,lindex,cindex,ignore_RGB=False):
|
||||
|
||||
if ignore_RGB:
|
||||
tmp=np.array(self.mindexs)<lindex
|
||||
tmp=np.sum(tmp)
|
||||
else:
|
||||
tmp=lindex
|
||||
findex=np.sum(self.fmaps[:tmp])+cindex
|
||||
return findex
|
||||
|
||||
def GetLCIndex(self,findex):
|
||||
l_p=[]
|
||||
cfmaps=np.cumsum(self.fmaps)
|
||||
for i in range(len(findex)):
|
||||
# i=-2
|
||||
tmp_index=findex[i]
|
||||
# importance_matrix.max(axis=0)
|
||||
# self.attrib_indices2
|
||||
tmp=tmp_index-cfmaps
|
||||
tmp=tmp[tmp>0]
|
||||
lindex=len(tmp)
|
||||
if lindex==0:
|
||||
cindex=tmp_index
|
||||
else:
|
||||
cindex=tmp[-1]
|
||||
|
||||
if cindex ==self.fmaps[lindex]:
|
||||
cindex=0
|
||||
lindex+=1
|
||||
# print(completeness.index[i],completeness.iloc[i,:].values,lindex,cindex)
|
||||
l_p.append([lindex,cindex])
|
||||
l_p=np.array(l_p)
|
||||
return l_p
|
||||
def GetLCIndex2(self,findex): #input findex without ToRGB
|
||||
fmaps_o=copy.copy(self.fmaps)
|
||||
mindexs=[0, 2, 3, 5, 6, 8, 9, 11, 12, 14, 15, 17, 18, 20, 21,23,24]
|
||||
self.fmaps=fmaps_o[mindexs]
|
||||
|
||||
l_p=self.GetLCIndex(findex)
|
||||
|
||||
l=l_p[:,0]
|
||||
l2=np.array(mindexs)[l]
|
||||
l_p[:,0]=l2
|
||||
self.fmaps=fmaps_o
|
||||
return l_p
|
||||
|
||||
def GetCodeMS(self):
|
||||
m=[]
|
||||
std=[]
|
||||
for i in range(len(self.dlatents)):
|
||||
tmp= self.dlatents[i]
|
||||
tmp_mean=tmp.mean(axis=0)
|
||||
tmp_std=tmp.std(axis=0)
|
||||
m.append(tmp_mean)
|
||||
std.append(tmp_std)
|
||||
|
||||
self.code_mean=m
|
||||
self.code_std=std
|
||||
# return m,std
|
||||
|
||||
|
||||
#%%
|
||||
if __name__ == "__main__":
|
||||
network_pkl='/cs/labs/danix/wuzongze/Gan_Manipulation/stylegan2/model/stylegan2-ffhq-config-f.pkl'
|
||||
device = torch.device('cuda')
|
||||
M=Manipulator()
|
||||
M.device=device
|
||||
G=M.LoadModel(network_pkl,device)
|
||||
M.G=G
|
||||
M.SetGParameters()
|
||||
num_img=100_000
|
||||
M.GenerateS(num_img=num_img)
|
||||
M.GetCodeMS()
|
||||
np.set_printoptions(suppress=True)
|
||||
|
||||
#%%
|
||||
M.alpha=[24,16,8,0,-8,-16,-24]
|
||||
M.step=len(M.alpha)
|
||||
M.img_index=0
|
||||
M.num_images=10
|
||||
lindex,bname=6,501
|
||||
# M.
|
||||
M.manipulate_layers=[lindex]
|
||||
codes,out=M.EditOneC(bname) #dlatent_tmp
|
||||
tmp=str(M.manipulate_layers)+'_'+str(bname)
|
||||
M.Vis(tmp,'c',out)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
BIN
global_torch/npy/ffhq/fs3.npy
Normal file
BIN
global_torch/npy/ffhq/fs3.npy
Normal file
Binary file not shown.
BIN
global_torch/npy/human/fs3.npy
Normal file
BIN
global_torch/npy/human/fs3.npy
Normal file
Binary file not shown.
9
global_torch/torch_utils/__init__.py
Normal file
9
global_torch/torch_utils/__init__.py
Normal file
@ -0,0 +1,9 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
# empty
|
||||
126
global_torch/torch_utils/custom_ops.py
Normal file
126
global_torch/torch_utils/custom_ops.py
Normal file
@ -0,0 +1,126 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
import os
|
||||
import glob
|
||||
import torch
|
||||
import torch.utils.cpp_extension
|
||||
import importlib
|
||||
import hashlib
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from torch.utils.file_baton import FileBaton
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Global options.
|
||||
|
||||
verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Internal helper funcs.
|
||||
|
||||
def _find_compiler_bindir():
|
||||
patterns = [
|
||||
'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
||||
'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
||||
'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
||||
'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
|
||||
]
|
||||
for pattern in patterns:
|
||||
matches = sorted(glob.glob(pattern))
|
||||
if len(matches):
|
||||
return matches[-1]
|
||||
return None
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Main entry point for compiling and loading C++/CUDA plugins.
|
||||
|
||||
_cached_plugins = dict()
|
||||
|
||||
def get_plugin(module_name, sources, **build_kwargs):
|
||||
assert verbosity in ['none', 'brief', 'full']
|
||||
|
||||
# Already cached?
|
||||
if module_name in _cached_plugins:
|
||||
return _cached_plugins[module_name]
|
||||
|
||||
# Print status.
|
||||
if verbosity == 'full':
|
||||
print(f'Setting up PyTorch plugin "{module_name}"...')
|
||||
elif verbosity == 'brief':
|
||||
print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
|
||||
|
||||
try: # pylint: disable=too-many-nested-blocks
|
||||
# Make sure we can find the necessary compiler binaries.
|
||||
if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
|
||||
compiler_bindir = _find_compiler_bindir()
|
||||
if compiler_bindir is None:
|
||||
raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
|
||||
os.environ['PATH'] += ';' + compiler_bindir
|
||||
|
||||
# Compile and load.
|
||||
verbose_build = (verbosity == 'full')
|
||||
|
||||
# Incremental build md5sum trickery. Copies all the input source files
|
||||
# into a cached build directory under a combined md5 digest of the input
|
||||
# source files. Copying is done only if the combined digest has changed.
|
||||
# This keeps input file timestamps and filenames the same as in previous
|
||||
# extension builds, allowing for fast incremental rebuilds.
|
||||
#
|
||||
# This optimization is done only in case all the source files reside in
|
||||
# a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
|
||||
# environment variable is set (we take this as a signal that the user
|
||||
# actually cares about this.)
|
||||
source_dirs_set = set(os.path.dirname(source) for source in sources)
|
||||
if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
|
||||
all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
|
||||
|
||||
# Compute a combined hash digest for all source files in the same
|
||||
# custom op directory (usually .cu, .cpp, .py and .h files).
|
||||
hash_md5 = hashlib.md5()
|
||||
for src in all_source_files:
|
||||
with open(src, 'rb') as f:
|
||||
hash_md5.update(f.read())
|
||||
build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
|
||||
digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
|
||||
|
||||
if not os.path.isdir(digest_build_dir):
|
||||
os.makedirs(digest_build_dir, exist_ok=True)
|
||||
baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
|
||||
if baton.try_acquire():
|
||||
try:
|
||||
for src in all_source_files:
|
||||
shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
|
||||
finally:
|
||||
baton.release()
|
||||
else:
|
||||
# Someone else is copying source files under the digest dir,
|
||||
# wait until done and continue.
|
||||
baton.wait()
|
||||
digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
|
||||
torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
|
||||
verbose=verbose_build, sources=digest_sources, **build_kwargs)
|
||||
else:
|
||||
torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
except:
|
||||
if verbosity == 'brief':
|
||||
print('Failed!')
|
||||
raise
|
||||
|
||||
# Print status and add to cache.
|
||||
if verbosity == 'full':
|
||||
print(f'Done setting up PyTorch plugin "{module_name}".')
|
||||
elif verbosity == 'brief':
|
||||
print('Done.')
|
||||
_cached_plugins[module_name] = module
|
||||
return module
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
262
global_torch/torch_utils/misc.py
Normal file
262
global_torch/torch_utils/misc.py
Normal file
@ -0,0 +1,262 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
import re
|
||||
import contextlib
|
||||
import numpy as np
|
||||
import torch
|
||||
import warnings
|
||||
import dnnlib
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
|
||||
# same constant is used multiple times.
|
||||
|
||||
_constant_cache = dict()
|
||||
|
||||
def constant(value, shape=None, dtype=None, device=None, memory_format=None):
|
||||
value = np.asarray(value)
|
||||
if shape is not None:
|
||||
shape = tuple(shape)
|
||||
if dtype is None:
|
||||
dtype = torch.get_default_dtype()
|
||||
if device is None:
|
||||
device = torch.device('cpu')
|
||||
if memory_format is None:
|
||||
memory_format = torch.contiguous_format
|
||||
|
||||
key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
|
||||
tensor = _constant_cache.get(key, None)
|
||||
if tensor is None:
|
||||
tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
|
||||
if shape is not None:
|
||||
tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
|
||||
tensor = tensor.contiguous(memory_format=memory_format)
|
||||
_constant_cache[key] = tensor
|
||||
return tensor
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Replace NaN/Inf with specified numerical values.
|
||||
|
||||
try:
|
||||
nan_to_num = torch.nan_to_num # 1.8.0a0
|
||||
except AttributeError:
|
||||
def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
|
||||
assert isinstance(input, torch.Tensor)
|
||||
if posinf is None:
|
||||
posinf = torch.finfo(input.dtype).max
|
||||
if neginf is None:
|
||||
neginf = torch.finfo(input.dtype).min
|
||||
assert nan == 0
|
||||
return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Symbolic assert.
|
||||
|
||||
try:
|
||||
symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
|
||||
except AttributeError:
|
||||
symbolic_assert = torch.Assert # 1.7.0
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Context manager to suppress known warnings in torch.jit.trace().
|
||||
|
||||
class suppress_tracer_warnings(warnings.catch_warnings):
|
||||
def __enter__(self):
|
||||
super().__enter__()
|
||||
warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
|
||||
return self
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Assert that the shape of a tensor matches the given list of integers.
|
||||
# None indicates that the size of a dimension is allowed to vary.
|
||||
# Performs symbolic assertion when used in torch.jit.trace().
|
||||
|
||||
def assert_shape(tensor, ref_shape):
|
||||
if tensor.ndim != len(ref_shape):
|
||||
raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
|
||||
for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
|
||||
if ref_size is None:
|
||||
pass
|
||||
elif isinstance(ref_size, torch.Tensor):
|
||||
with suppress_tracer_warnings(): # as_tensor results are registered as constants
|
||||
symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
|
||||
elif isinstance(size, torch.Tensor):
|
||||
with suppress_tracer_warnings(): # as_tensor results are registered as constants
|
||||
symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
|
||||
elif size != ref_size:
|
||||
raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Function decorator that calls torch.autograd.profiler.record_function().
|
||||
|
||||
def profiled_function(fn):
|
||||
def decorator(*args, **kwargs):
|
||||
with torch.autograd.profiler.record_function(fn.__name__):
|
||||
return fn(*args, **kwargs)
|
||||
decorator.__name__ = fn.__name__
|
||||
return decorator
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Sampler for torch.utils.data.DataLoader that loops over the dataset
|
||||
# indefinitely, shuffling items as it goes.
|
||||
|
||||
class InfiniteSampler(torch.utils.data.Sampler):
|
||||
def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
|
||||
assert len(dataset) > 0
|
||||
assert num_replicas > 0
|
||||
assert 0 <= rank < num_replicas
|
||||
assert 0 <= window_size <= 1
|
||||
super().__init__(dataset)
|
||||
self.dataset = dataset
|
||||
self.rank = rank
|
||||
self.num_replicas = num_replicas
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
self.window_size = window_size
|
||||
|
||||
def __iter__(self):
|
||||
order = np.arange(len(self.dataset))
|
||||
rnd = None
|
||||
window = 0
|
||||
if self.shuffle:
|
||||
rnd = np.random.RandomState(self.seed)
|
||||
rnd.shuffle(order)
|
||||
window = int(np.rint(order.size * self.window_size))
|
||||
|
||||
idx = 0
|
||||
while True:
|
||||
i = idx % order.size
|
||||
if idx % self.num_replicas == self.rank:
|
||||
yield order[i]
|
||||
if window >= 2:
|
||||
j = (i - rnd.randint(window)) % order.size
|
||||
order[i], order[j] = order[j], order[i]
|
||||
idx += 1
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Utilities for operating with torch.nn.Module parameters and buffers.
|
||||
|
||||
def params_and_buffers(module):
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
return list(module.parameters()) + list(module.buffers())
|
||||
|
||||
def named_params_and_buffers(module):
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
return list(module.named_parameters()) + list(module.named_buffers())
|
||||
|
||||
def copy_params_and_buffers(src_module, dst_module, require_all=False):
|
||||
assert isinstance(src_module, torch.nn.Module)
|
||||
assert isinstance(dst_module, torch.nn.Module)
|
||||
src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
|
||||
for name, tensor in named_params_and_buffers(dst_module):
|
||||
assert (name in src_tensors) or (not require_all)
|
||||
if name in src_tensors:
|
||||
tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Context manager for easily enabling/disabling DistributedDataParallel
|
||||
# synchronization.
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ddp_sync(module, sync):
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
|
||||
yield
|
||||
else:
|
||||
with module.no_sync():
|
||||
yield
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Check DistributedDataParallel consistency across processes.
|
||||
|
||||
def check_ddp_consistency(module, ignore_regex=None):
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
for name, tensor in named_params_and_buffers(module):
|
||||
fullname = type(module).__name__ + '.' + name
|
||||
if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
|
||||
continue
|
||||
tensor = tensor.detach()
|
||||
other = tensor.clone()
|
||||
torch.distributed.broadcast(tensor=other, src=0)
|
||||
assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Print summary table of module hierarchy.
|
||||
|
||||
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
assert not isinstance(module, torch.jit.ScriptModule)
|
||||
assert isinstance(inputs, (tuple, list))
|
||||
|
||||
# Register hooks.
|
||||
entries = []
|
||||
nesting = [0]
|
||||
def pre_hook(_mod, _inputs):
|
||||
nesting[0] += 1
|
||||
def post_hook(mod, _inputs, outputs):
|
||||
nesting[0] -= 1
|
||||
if nesting[0] <= max_nesting:
|
||||
outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
|
||||
outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
|
||||
entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
|
||||
hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
|
||||
hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
|
||||
|
||||
# Run module.
|
||||
outputs = module(*inputs)
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
# Identify unique outputs, parameters, and buffers.
|
||||
tensors_seen = set()
|
||||
for e in entries:
|
||||
e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
|
||||
e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
|
||||
e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
|
||||
tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
|
||||
|
||||
# Filter out redundant entries.
|
||||
if skip_redundant:
|
||||
entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
|
||||
|
||||
# Construct table.
|
||||
rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
|
||||
rows += [['---'] * len(rows[0])]
|
||||
param_total = 0
|
||||
buffer_total = 0
|
||||
submodule_names = {mod: name for name, mod in module.named_modules()}
|
||||
for e in entries:
|
||||
name = '<top-level>' if e.mod is module else submodule_names[e.mod]
|
||||
param_size = sum(t.numel() for t in e.unique_params)
|
||||
buffer_size = sum(t.numel() for t in e.unique_buffers)
|
||||
output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs]
|
||||
output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
|
||||
rows += [[
|
||||
name + (':0' if len(e.outputs) >= 2 else ''),
|
||||
str(param_size) if param_size else '-',
|
||||
str(buffer_size) if buffer_size else '-',
|
||||
(output_shapes + ['-'])[0],
|
||||
(output_dtypes + ['-'])[0],
|
||||
]]
|
||||
for idx in range(1, len(e.outputs)):
|
||||
rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
|
||||
param_total += param_size
|
||||
buffer_total += buffer_size
|
||||
rows += [['---'] * len(rows[0])]
|
||||
rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
|
||||
|
||||
# Print table.
|
||||
widths = [max(len(cell) for cell in column) for column in zip(*rows)]
|
||||
print()
|
||||
for row in rows:
|
||||
print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
|
||||
print()
|
||||
return outputs
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
9
global_torch/torch_utils/ops/__init__.py
Normal file
9
global_torch/torch_utils/ops/__init__.py
Normal file
@ -0,0 +1,9 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
# empty
|
||||
99
global_torch/torch_utils/ops/bias_act.cpp
Normal file
99
global_torch/torch_utils/ops/bias_act.cpp
Normal file
@ -0,0 +1,99 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "bias_act.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
static bool has_same_layout(torch::Tensor x, torch::Tensor y)
|
||||
{
|
||||
if (x.dim() != y.dim())
|
||||
return false;
|
||||
for (int64_t i = 0; i < x.dim(); i++)
|
||||
{
|
||||
if (x.size(i) != y.size(i))
|
||||
return false;
|
||||
if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
|
||||
{
|
||||
// Validate arguments.
|
||||
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
||||
TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
|
||||
TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
|
||||
TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
|
||||
TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
|
||||
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
|
||||
TORCH_CHECK(b.dim() == 1, "b must have rank 1");
|
||||
TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
|
||||
TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
|
||||
TORCH_CHECK(grad >= 0, "grad must be non-negative");
|
||||
|
||||
// Validate layout.
|
||||
TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
|
||||
TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
|
||||
TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
|
||||
TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
|
||||
TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
|
||||
|
||||
// Create output tensor.
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
torch::Tensor y = torch::empty_like(x);
|
||||
TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
|
||||
|
||||
// Initialize CUDA kernel parameters.
|
||||
bias_act_kernel_params p;
|
||||
p.x = x.data_ptr();
|
||||
p.b = (b.numel()) ? b.data_ptr() : NULL;
|
||||
p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
|
||||
p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
|
||||
p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
|
||||
p.y = y.data_ptr();
|
||||
p.grad = grad;
|
||||
p.act = act;
|
||||
p.alpha = alpha;
|
||||
p.gain = gain;
|
||||
p.clamp = clamp;
|
||||
p.sizeX = (int)x.numel();
|
||||
p.sizeB = (int)b.numel();
|
||||
p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
|
||||
|
||||
// Choose CUDA kernel.
|
||||
void* kernel;
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
|
||||
{
|
||||
kernel = choose_bias_act_kernel<scalar_t>(p);
|
||||
});
|
||||
TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
|
||||
|
||||
// Launch CUDA kernel.
|
||||
p.loopX = 4;
|
||||
int blockSize = 4 * 32;
|
||||
int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
|
||||
void* args[] = {&p};
|
||||
AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
|
||||
return y;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("bias_act", &bias_act);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
173
global_torch/torch_utils/ops/bias_act.cu
Normal file
173
global_torch/torch_utils/ops/bias_act.cu
Normal file
@ -0,0 +1,173 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <c10/util/Half.h>
|
||||
#include "bias_act.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Helpers.
|
||||
|
||||
template <class T> struct InternalType;
|
||||
template <> struct InternalType<double> { typedef double scalar_t; };
|
||||
template <> struct InternalType<float> { typedef float scalar_t; };
|
||||
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel.
|
||||
|
||||
template <class T, int A>
|
||||
__global__ void bias_act_kernel(bias_act_kernel_params p)
|
||||
{
|
||||
typedef typename InternalType<T>::scalar_t scalar_t;
|
||||
int G = p.grad;
|
||||
scalar_t alpha = (scalar_t)p.alpha;
|
||||
scalar_t gain = (scalar_t)p.gain;
|
||||
scalar_t clamp = (scalar_t)p.clamp;
|
||||
scalar_t one = (scalar_t)1;
|
||||
scalar_t two = (scalar_t)2;
|
||||
scalar_t expRange = (scalar_t)80;
|
||||
scalar_t halfExpRange = (scalar_t)40;
|
||||
scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
|
||||
scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
|
||||
|
||||
// Loop over elements.
|
||||
int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
|
||||
for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
|
||||
{
|
||||
// Load.
|
||||
scalar_t x = (scalar_t)((const T*)p.x)[xi];
|
||||
scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
|
||||
scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
|
||||
scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
|
||||
scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
|
||||
scalar_t yy = (gain != 0) ? yref / gain : 0;
|
||||
scalar_t y = 0;
|
||||
|
||||
// Apply bias.
|
||||
((G == 0) ? x : xref) += b;
|
||||
|
||||
// linear
|
||||
if (A == 1)
|
||||
{
|
||||
if (G == 0) y = x;
|
||||
if (G == 1) y = x;
|
||||
}
|
||||
|
||||
// relu
|
||||
if (A == 2)
|
||||
{
|
||||
if (G == 0) y = (x > 0) ? x : 0;
|
||||
if (G == 1) y = (yy > 0) ? x : 0;
|
||||
}
|
||||
|
||||
// lrelu
|
||||
if (A == 3)
|
||||
{
|
||||
if (G == 0) y = (x > 0) ? x : x * alpha;
|
||||
if (G == 1) y = (yy > 0) ? x : x * alpha;
|
||||
}
|
||||
|
||||
// tanh
|
||||
if (A == 4)
|
||||
{
|
||||
if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
|
||||
if (G == 1) y = x * (one - yy * yy);
|
||||
if (G == 2) y = x * (one - yy * yy) * (-two * yy);
|
||||
}
|
||||
|
||||
// sigmoid
|
||||
if (A == 5)
|
||||
{
|
||||
if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
|
||||
if (G == 1) y = x * yy * (one - yy);
|
||||
if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
|
||||
}
|
||||
|
||||
// elu
|
||||
if (A == 6)
|
||||
{
|
||||
if (G == 0) y = (x >= 0) ? x : exp(x) - one;
|
||||
if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
|
||||
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
|
||||
}
|
||||
|
||||
// selu
|
||||
if (A == 7)
|
||||
{
|
||||
if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
|
||||
if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
|
||||
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
|
||||
}
|
||||
|
||||
// softplus
|
||||
if (A == 8)
|
||||
{
|
||||
if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
|
||||
if (G == 1) y = x * (one - exp(-yy));
|
||||
if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
|
||||
}
|
||||
|
||||
// swish
|
||||
if (A == 9)
|
||||
{
|
||||
if (G == 0)
|
||||
y = (x < -expRange) ? 0 : x / (exp(-x) + one);
|
||||
else
|
||||
{
|
||||
scalar_t c = exp(xref);
|
||||
scalar_t d = c + one;
|
||||
if (G == 1)
|
||||
y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
|
||||
else
|
||||
y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
|
||||
yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
|
||||
}
|
||||
}
|
||||
|
||||
// Apply gain.
|
||||
y *= gain * dy;
|
||||
|
||||
// Clamp.
|
||||
if (clamp >= 0)
|
||||
{
|
||||
if (G == 0)
|
||||
y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
|
||||
else
|
||||
y = (yref > -clamp & yref < clamp) ? y : 0;
|
||||
}
|
||||
|
||||
// Store.
|
||||
((T*)p.y)[xi] = (T)y;
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel selection.
|
||||
|
||||
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
|
||||
{
|
||||
if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
|
||||
if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
|
||||
if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
|
||||
if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
|
||||
if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
|
||||
if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
|
||||
if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
|
||||
if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
|
||||
if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
|
||||
return NULL;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Template specializations.
|
||||
|
||||
template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
|
||||
template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
|
||||
template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
38
global_torch/torch_utils/ops/bias_act.h
Normal file
38
global_torch/torch_utils/ops/bias_act.h
Normal file
@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel parameters.
|
||||
|
||||
struct bias_act_kernel_params
|
||||
{
|
||||
const void* x; // [sizeX]
|
||||
const void* b; // [sizeB] or NULL
|
||||
const void* xref; // [sizeX] or NULL
|
||||
const void* yref; // [sizeX] or NULL
|
||||
const void* dy; // [sizeX] or NULL
|
||||
void* y; // [sizeX]
|
||||
|
||||
int grad;
|
||||
int act;
|
||||
float alpha;
|
||||
float gain;
|
||||
float clamp;
|
||||
|
||||
int sizeX;
|
||||
int sizeB;
|
||||
int stepB;
|
||||
int loopX;
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel selection.
|
||||
|
||||
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
212
global_torch/torch_utils/ops/bias_act.py
Normal file
212
global_torch/torch_utils/ops/bias_act.py
Normal file
@ -0,0 +1,212 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Custom PyTorch ops for efficient bias and activation."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
import numpy as np
|
||||
import torch
|
||||
import dnnlib
|
||||
import traceback
|
||||
|
||||
from .. import custom_ops
|
||||
from .. import misc
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
activation_funcs = {
|
||||
'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
|
||||
'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
|
||||
'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
|
||||
'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
|
||||
'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
|
||||
'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
|
||||
'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
|
||||
'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
|
||||
'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
|
||||
}
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_inited = False
|
||||
_plugin = None
|
||||
_null_tensor = torch.empty([0])
|
||||
|
||||
def _init():
|
||||
global _inited, _plugin
|
||||
if not _inited:
|
||||
_inited = True
|
||||
sources = ['bias_act.cpp', 'bias_act.cu']
|
||||
sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
|
||||
try:
|
||||
_plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
|
||||
except:
|
||||
warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
|
||||
return _plugin is not None
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
|
||||
r"""Fused bias and activation function.
|
||||
|
||||
Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
|
||||
and scales the result by `gain`. Each of the steps is optional. In most cases,
|
||||
the fused op is considerably more efficient than performing the same calculation
|
||||
using standard PyTorch ops. It supports first and second order gradients,
|
||||
but not third order gradients.
|
||||
|
||||
Args:
|
||||
x: Input activation tensor. Can be of any shape.
|
||||
b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
|
||||
as `x`. The shape must be known, and it must match the dimension of `x`
|
||||
corresponding to `dim`.
|
||||
dim: The dimension in `x` corresponding to the elements of `b`.
|
||||
The value of `dim` is ignored if `b` is not specified.
|
||||
act: Name of the activation function to evaluate, or `"linear"` to disable.
|
||||
Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
|
||||
See `activation_funcs` for a full list. `None` is not allowed.
|
||||
alpha: Shape parameter for the activation function, or `None` to use the default.
|
||||
gain: Scaling factor for the output tensor, or `None` to use default.
|
||||
See `activation_funcs` for the default scaling of each activation function.
|
||||
If unsure, consider specifying 1.
|
||||
clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
|
||||
the clamping (default).
|
||||
impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
|
||||
|
||||
Returns:
|
||||
Tensor of the same shape and datatype as `x`.
|
||||
"""
|
||||
assert isinstance(x, torch.Tensor)
|
||||
assert impl in ['ref', 'cuda']
|
||||
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
||||
return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
|
||||
return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
|
||||
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
||||
"""Slow reference implementation of `bias_act()` using standard TensorFlow ops.
|
||||
"""
|
||||
assert isinstance(x, torch.Tensor)
|
||||
assert clamp is None or clamp >= 0
|
||||
spec = activation_funcs[act]
|
||||
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
||||
gain = float(gain if gain is not None else spec.def_gain)
|
||||
clamp = float(clamp if clamp is not None else -1)
|
||||
|
||||
# Add bias.
|
||||
if b is not None:
|
||||
assert isinstance(b, torch.Tensor) and b.ndim == 1
|
||||
assert 0 <= dim < x.ndim
|
||||
assert b.shape[0] == x.shape[dim]
|
||||
x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
|
||||
|
||||
# Evaluate activation function.
|
||||
alpha = float(alpha)
|
||||
x = spec.func(x, alpha=alpha)
|
||||
|
||||
# Scale by gain.
|
||||
gain = float(gain)
|
||||
if gain != 1:
|
||||
x = x * gain
|
||||
|
||||
# Clamp.
|
||||
if clamp >= 0:
|
||||
x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_bias_act_cuda_cache = dict()
|
||||
|
||||
def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
||||
"""Fast CUDA implementation of `bias_act()` using custom ops.
|
||||
"""
|
||||
# Parse arguments.
|
||||
assert clamp is None or clamp >= 0
|
||||
spec = activation_funcs[act]
|
||||
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
||||
gain = float(gain if gain is not None else spec.def_gain)
|
||||
clamp = float(clamp if clamp is not None else -1)
|
||||
|
||||
# Lookup from cache.
|
||||
key = (dim, act, alpha, gain, clamp)
|
||||
if key in _bias_act_cuda_cache:
|
||||
return _bias_act_cuda_cache[key]
|
||||
|
||||
# Forward op.
|
||||
class BiasActCuda(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, b): # pylint: disable=arguments-differ
|
||||
ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
|
||||
x = x.contiguous(memory_format=ctx.memory_format)
|
||||
b = b.contiguous() if b is not None else _null_tensor
|
||||
y = x
|
||||
if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
|
||||
y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
|
||||
ctx.save_for_backward(
|
||||
x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
|
||||
b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
|
||||
y if 'y' in spec.ref else _null_tensor)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dy): # pylint: disable=arguments-differ
|
||||
dy = dy.contiguous(memory_format=ctx.memory_format)
|
||||
x, b, y = ctx.saved_tensors
|
||||
dx = None
|
||||
db = None
|
||||
|
||||
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
|
||||
dx = dy
|
||||
if act != 'linear' or gain != 1 or clamp >= 0:
|
||||
dx = BiasActCudaGrad.apply(dy, x, b, y)
|
||||
|
||||
if ctx.needs_input_grad[1]:
|
||||
db = dx.sum([i for i in range(dx.ndim) if i != dim])
|
||||
|
||||
return dx, db
|
||||
|
||||
# Backward op.
|
||||
class BiasActCudaGrad(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
|
||||
ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
|
||||
dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
|
||||
ctx.save_for_backward(
|
||||
dy if spec.has_2nd_grad else _null_tensor,
|
||||
x, b, y)
|
||||
return dx
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, d_dx): # pylint: disable=arguments-differ
|
||||
d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
|
||||
dy, x, b, y = ctx.saved_tensors
|
||||
d_dy = None
|
||||
d_x = None
|
||||
d_b = None
|
||||
d_y = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
|
||||
|
||||
if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
|
||||
d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
|
||||
|
||||
if spec.has_2nd_grad and ctx.needs_input_grad[2]:
|
||||
d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
|
||||
|
||||
return d_dy, d_x, d_b, d_y
|
||||
|
||||
# Add to cache.
|
||||
_bias_act_cuda_cache[key] = BiasActCuda
|
||||
return BiasActCuda
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
170
global_torch/torch_utils/ops/conv2d_gradfix.py
Normal file
170
global_torch/torch_utils/ops/conv2d_gradfix.py
Normal file
@ -0,0 +1,170 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Custom replacement for `torch.nn.functional.conv2d` that supports
|
||||
arbitrarily high order gradients with zero performance penalty."""
|
||||
|
||||
import warnings
|
||||
import contextlib
|
||||
import torch
|
||||
|
||||
# pylint: disable=redefined-builtin
|
||||
# pylint: disable=arguments-differ
|
||||
# pylint: disable=protected-access
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
enabled = False # Enable the custom op by setting this to true.
|
||||
weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
|
||||
|
||||
@contextlib.contextmanager
|
||||
def no_weight_gradients():
|
||||
global weight_gradients_disabled
|
||||
old = weight_gradients_disabled
|
||||
weight_gradients_disabled = True
|
||||
yield
|
||||
weight_gradients_disabled = old
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||
if _should_use_custom_op(input):
|
||||
return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
|
||||
return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
|
||||
def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
|
||||
if _should_use_custom_op(input):
|
||||
return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
|
||||
return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _should_use_custom_op(input):
|
||||
assert isinstance(input, torch.Tensor)
|
||||
if (not enabled) or (not torch.backends.cudnn.enabled):
|
||||
return False
|
||||
if input.device.type != 'cuda':
|
||||
return False
|
||||
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
|
||||
return True
|
||||
warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
|
||||
return False
|
||||
|
||||
def _tuple_of_ints(xs, ndim):
|
||||
xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
|
||||
assert len(xs) == ndim
|
||||
assert all(isinstance(x, int) for x in xs)
|
||||
return xs
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_conv2d_gradfix_cache = dict()
|
||||
|
||||
def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
|
||||
# Parse arguments.
|
||||
ndim = 2
|
||||
weight_shape = tuple(weight_shape)
|
||||
stride = _tuple_of_ints(stride, ndim)
|
||||
padding = _tuple_of_ints(padding, ndim)
|
||||
output_padding = _tuple_of_ints(output_padding, ndim)
|
||||
dilation = _tuple_of_ints(dilation, ndim)
|
||||
|
||||
# Lookup from cache.
|
||||
key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
|
||||
if key in _conv2d_gradfix_cache:
|
||||
return _conv2d_gradfix_cache[key]
|
||||
|
||||
# Validate arguments.
|
||||
assert groups >= 1
|
||||
assert len(weight_shape) == ndim + 2
|
||||
assert all(stride[i] >= 1 for i in range(ndim))
|
||||
assert all(padding[i] >= 0 for i in range(ndim))
|
||||
assert all(dilation[i] >= 0 for i in range(ndim))
|
||||
if not transpose:
|
||||
assert all(output_padding[i] == 0 for i in range(ndim))
|
||||
else: # transpose
|
||||
assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
|
||||
|
||||
# Helpers.
|
||||
common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
def calc_output_padding(input_shape, output_shape):
|
||||
if transpose:
|
||||
return [0, 0]
|
||||
return [
|
||||
input_shape[i + 2]
|
||||
- (output_shape[i + 2] - 1) * stride[i]
|
||||
- (1 - 2 * padding[i])
|
||||
- dilation[i] * (weight_shape[i + 2] - 1)
|
||||
for i in range(ndim)
|
||||
]
|
||||
|
||||
# Forward & backward.
|
||||
class Conv2d(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight, bias):
|
||||
assert weight.shape == weight_shape
|
||||
if not transpose:
|
||||
output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
|
||||
else: # transpose
|
||||
output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
|
||||
ctx.save_for_backward(input, weight)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, weight = ctx.saved_tensors
|
||||
grad_input = None
|
||||
grad_weight = None
|
||||
grad_bias = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
|
||||
grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
|
||||
assert grad_input.shape == input.shape
|
||||
|
||||
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
|
||||
grad_weight = Conv2dGradWeight.apply(grad_output, input)
|
||||
assert grad_weight.shape == weight_shape
|
||||
|
||||
if ctx.needs_input_grad[2]:
|
||||
grad_bias = grad_output.sum([0, 2, 3])
|
||||
|
||||
return grad_input, grad_weight, grad_bias
|
||||
|
||||
# Gradient with respect to the weights.
|
||||
class Conv2dGradWeight(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, grad_output, input):
|
||||
op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
|
||||
flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
|
||||
grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
|
||||
assert grad_weight.shape == weight_shape
|
||||
ctx.save_for_backward(grad_output, input)
|
||||
return grad_weight
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad2_grad_weight):
|
||||
grad_output, input = ctx.saved_tensors
|
||||
grad2_grad_output = None
|
||||
grad2_input = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
|
||||
assert grad2_grad_output.shape == grad_output.shape
|
||||
|
||||
if ctx.needs_input_grad[1]:
|
||||
p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
|
||||
grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
|
||||
assert grad2_input.shape == input.shape
|
||||
|
||||
return grad2_grad_output, grad2_input
|
||||
|
||||
_conv2d_gradfix_cache[key] = Conv2d
|
||||
return Conv2d
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
156
global_torch/torch_utils/ops/conv2d_resample.py
Normal file
156
global_torch/torch_utils/ops/conv2d_resample.py
Normal file
@ -0,0 +1,156 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""2D convolution with optional up/downsampling."""
|
||||
|
||||
import torch
|
||||
|
||||
from .. import misc
|
||||
from . import conv2d_gradfix
|
||||
from . import upfirdn2d
|
||||
from .upfirdn2d import _parse_padding
|
||||
from .upfirdn2d import _get_filter_size
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _get_weight_shape(w):
|
||||
with misc.suppress_tracer_warnings(): # this value will be treated as a constant
|
||||
shape = [int(sz) for sz in w.shape]
|
||||
misc.assert_shape(w, shape)
|
||||
return shape
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
|
||||
"""Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
|
||||
"""
|
||||
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
||||
|
||||
# Flip weight if requested.
|
||||
if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
|
||||
w = w.flip([2, 3])
|
||||
|
||||
# Workaround performance pitfall in cuDNN 8.0.5, triggered when using
|
||||
# 1x1 kernel + memory_format=channels_last + less than 64 channels.
|
||||
if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
|
||||
if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
|
||||
if out_channels <= 4 and groups == 1:
|
||||
in_shape = x.shape
|
||||
x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
|
||||
x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
|
||||
else:
|
||||
x = x.to(memory_format=torch.contiguous_format)
|
||||
w = w.to(memory_format=torch.contiguous_format)
|
||||
x = conv2d_gradfix.conv2d(x, w, groups=groups)
|
||||
return x.to(memory_format=torch.channels_last)
|
||||
|
||||
# Otherwise => execute using conv2d_gradfix.
|
||||
op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
|
||||
return op(x, w, stride=stride, padding=padding, groups=groups)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
|
||||
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
|
||||
r"""2D convolution with optional up/downsampling.
|
||||
|
||||
Padding is performed only once at the beginning, not between the operations.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape
|
||||
`[batch_size, in_channels, in_height, in_width]`.
|
||||
w: Weight tensor of shape
|
||||
`[out_channels, in_channels//groups, kernel_height, kernel_width]`.
|
||||
f: Low-pass filter for up/downsampling. Must be prepared beforehand by
|
||||
calling upfirdn2d.setup_filter(). None = identity (default).
|
||||
up: Integer upsampling factor (default: 1).
|
||||
down: Integer downsampling factor (default: 1).
|
||||
padding: Padding with respect to the upsampled image. Can be a single number
|
||||
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
||||
(default: 0).
|
||||
groups: Split input channels into N groups (default: 1).
|
||||
flip_weight: False = convolution, True = correlation (default: True).
|
||||
flip_filter: False = convolution, True = correlation (default: False).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
||||
"""
|
||||
# Validate arguments.
|
||||
assert isinstance(x, torch.Tensor) and (x.ndim == 4)
|
||||
assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
|
||||
assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
|
||||
assert isinstance(up, int) and (up >= 1)
|
||||
assert isinstance(down, int) and (down >= 1)
|
||||
assert isinstance(groups, int) and (groups >= 1)
|
||||
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
||||
fw, fh = _get_filter_size(f)
|
||||
px0, px1, py0, py1 = _parse_padding(padding)
|
||||
|
||||
# Adjust padding to account for up/downsampling.
|
||||
if up > 1:
|
||||
px0 += (fw + up - 1) // 2
|
||||
px1 += (fw - up) // 2
|
||||
py0 += (fh + up - 1) // 2
|
||||
py1 += (fh - up) // 2
|
||||
if down > 1:
|
||||
px0 += (fw - down + 1) // 2
|
||||
px1 += (fw - down) // 2
|
||||
py0 += (fh - down + 1) // 2
|
||||
py1 += (fh - down) // 2
|
||||
|
||||
# Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
|
||||
if kw == 1 and kh == 1 and (down > 1 and up == 1):
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
|
||||
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
||||
return x
|
||||
|
||||
# Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
|
||||
if kw == 1 and kh == 1 and (up > 1 and down == 1):
|
||||
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
|
||||
return x
|
||||
|
||||
# Fast path: downsampling only => use strided convolution.
|
||||
if down > 1 and up == 1:
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
|
||||
x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
|
||||
return x
|
||||
|
||||
# Fast path: upsampling with optional downsampling => use transpose strided convolution.
|
||||
if up > 1:
|
||||
if groups == 1:
|
||||
w = w.transpose(0, 1)
|
||||
else:
|
||||
w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
|
||||
w = w.transpose(1, 2)
|
||||
w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
|
||||
px0 -= kw - 1
|
||||
px1 -= kw - up
|
||||
py0 -= kh - 1
|
||||
py1 -= kh - up
|
||||
pxt = max(min(-px0, -px1), 0)
|
||||
pyt = max(min(-py0, -py1), 0)
|
||||
x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
|
||||
if down > 1:
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
||||
return x
|
||||
|
||||
# Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
|
||||
if up == 1 and down == 1:
|
||||
if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
|
||||
return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
|
||||
|
||||
# Fallback: Generic reference implementation.
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
|
||||
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
||||
if down > 1:
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
60
global_torch/torch_utils/ops/fma.py
Normal file
60
global_torch/torch_utils/ops/fma.py
Normal file
@ -0,0 +1,60 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
|
||||
|
||||
import torch
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def fma(a, b, c): # => a * b + c
|
||||
return _FusedMultiplyAdd.apply(a, b, c)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
|
||||
@staticmethod
|
||||
def forward(ctx, a, b, c): # pylint: disable=arguments-differ
|
||||
out = torch.addcmul(c, a, b)
|
||||
ctx.save_for_backward(a, b)
|
||||
ctx.c_shape = c.shape
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout): # pylint: disable=arguments-differ
|
||||
a, b = ctx.saved_tensors
|
||||
c_shape = ctx.c_shape
|
||||
da = None
|
||||
db = None
|
||||
dc = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
da = _unbroadcast(dout * b, a.shape)
|
||||
|
||||
if ctx.needs_input_grad[1]:
|
||||
db = _unbroadcast(dout * a, b.shape)
|
||||
|
||||
if ctx.needs_input_grad[2]:
|
||||
dc = _unbroadcast(dout, c_shape)
|
||||
|
||||
return da, db, dc
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _unbroadcast(x, shape):
|
||||
extra_dims = x.ndim - len(shape)
|
||||
assert extra_dims >= 0
|
||||
dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
|
||||
if len(dim):
|
||||
x = x.sum(dim=dim, keepdim=True)
|
||||
if extra_dims:
|
||||
x = x.reshape(-1, *x.shape[extra_dims+1:])
|
||||
assert x.shape == shape
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
34
global_torch/torch_utils/ops/fused_act.py
Normal file
34
global_torch/torch_utils/ops/fused_act.py
Normal file
@ -0,0 +1,34 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.autograd import Function
|
||||
|
||||
|
||||
module_path = os.path.dirname(__file__)
|
||||
|
||||
|
||||
|
||||
class FusedLeakyReLU(nn.Module):
|
||||
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
|
||||
super().__init__()
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(channel))
|
||||
self.negative_slope = negative_slope
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, input):
|
||||
return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
|
||||
|
||||
|
||||
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
|
||||
rest_dim = [1] * (input.ndim - bias.ndim - 1)
|
||||
input = input.cuda()
|
||||
return (
|
||||
F.leaky_relu(
|
||||
input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
|
||||
)
|
||||
* scale
|
||||
)
|
||||
|
||||
83
global_torch/torch_utils/ops/grid_sample_gradfix.py
Normal file
83
global_torch/torch_utils/ops/grid_sample_gradfix.py
Normal file
@ -0,0 +1,83 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Custom replacement for `torch.nn.functional.grid_sample` that
|
||||
supports arbitrarily high order gradients between the input and output.
|
||||
Only works on 2D images and assumes
|
||||
`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
|
||||
|
||||
import warnings
|
||||
import torch
|
||||
|
||||
# pylint: disable=redefined-builtin
|
||||
# pylint: disable=arguments-differ
|
||||
# pylint: disable=protected-access
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
enabled = False # Enable the custom op by setting this to true.
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def grid_sample(input, grid):
|
||||
if _should_use_custom_op():
|
||||
return _GridSample2dForward.apply(input, grid)
|
||||
return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _should_use_custom_op():
|
||||
if not enabled:
|
||||
return False
|
||||
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
|
||||
return True
|
||||
warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
|
||||
return False
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class _GridSample2dForward(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, grid):
|
||||
assert input.ndim == 4
|
||||
assert grid.ndim == 4
|
||||
output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
||||
ctx.save_for_backward(input, grid)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, grid = ctx.saved_tensors
|
||||
grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
|
||||
return grad_input, grad_grid
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class _GridSample2dBackward(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, grad_output, input, grid):
|
||||
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
|
||||
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
|
||||
ctx.save_for_backward(grid)
|
||||
return grad_input, grad_grid
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad2_grad_input, grad2_grad_grid):
|
||||
_ = grad2_grad_grid # unused
|
||||
grid, = ctx.saved_tensors
|
||||
grad2_grad_output = None
|
||||
grad2_input = None
|
||||
grad2_grid = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
|
||||
|
||||
assert not ctx.needs_input_grad[2]
|
||||
return grad2_grad_output, grad2_input, grad2_grid
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
103
global_torch/torch_utils/ops/upfirdn2d.cpp
Normal file
103
global_torch/torch_utils/ops/upfirdn2d.cpp
Normal file
@ -0,0 +1,103 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "upfirdn2d.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
|
||||
{
|
||||
// Validate arguments.
|
||||
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
||||
TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
|
||||
TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
|
||||
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
|
||||
TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
|
||||
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
|
||||
TORCH_CHECK(f.dim() == 2, "f must be rank 2");
|
||||
TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
|
||||
TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
|
||||
TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
|
||||
|
||||
// Create output tensor.
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
|
||||
int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
|
||||
TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
|
||||
torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
|
||||
TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
|
||||
|
||||
// Initialize CUDA kernel parameters.
|
||||
upfirdn2d_kernel_params p;
|
||||
p.x = x.data_ptr();
|
||||
p.f = f.data_ptr<float>();
|
||||
p.y = y.data_ptr();
|
||||
p.up = make_int2(upx, upy);
|
||||
p.down = make_int2(downx, downy);
|
||||
p.pad0 = make_int2(padx0, pady0);
|
||||
p.flip = (flip) ? 1 : 0;
|
||||
p.gain = gain;
|
||||
p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
|
||||
p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
|
||||
p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
|
||||
p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
|
||||
p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
|
||||
p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
|
||||
p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
|
||||
p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
|
||||
|
||||
// Choose CUDA kernel.
|
||||
upfirdn2d_kernel_spec spec;
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
|
||||
{
|
||||
spec = choose_upfirdn2d_kernel<scalar_t>(p);
|
||||
});
|
||||
|
||||
// Set looping options.
|
||||
p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
|
||||
p.loopMinor = spec.loopMinor;
|
||||
p.loopX = spec.loopX;
|
||||
p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
|
||||
p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
|
||||
|
||||
// Compute grid size.
|
||||
dim3 blockSize, gridSize;
|
||||
if (spec.tileOutW < 0) // large
|
||||
{
|
||||
blockSize = dim3(4, 32, 1);
|
||||
gridSize = dim3(
|
||||
((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
|
||||
(p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
|
||||
p.launchMajor);
|
||||
}
|
||||
else // small
|
||||
{
|
||||
blockSize = dim3(256, 1, 1);
|
||||
gridSize = dim3(
|
||||
((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
|
||||
(p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
|
||||
p.launchMajor);
|
||||
}
|
||||
|
||||
// Launch CUDA kernel.
|
||||
void* args[] = {&p};
|
||||
AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
|
||||
return y;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("upfirdn2d", &upfirdn2d);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
350
global_torch/torch_utils/ops/upfirdn2d.cu
Normal file
350
global_torch/torch_utils/ops/upfirdn2d.cu
Normal file
@ -0,0 +1,350 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <c10/util/Half.h>
|
||||
#include "upfirdn2d.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Helpers.
|
||||
|
||||
template <class T> struct InternalType;
|
||||
template <> struct InternalType<double> { typedef double scalar_t; };
|
||||
template <> struct InternalType<float> { typedef float scalar_t; };
|
||||
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
|
||||
|
||||
static __device__ __forceinline__ int floor_div(int a, int b)
|
||||
{
|
||||
int t = 1 - a / b;
|
||||
return (a + t * b) / b - t;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Generic CUDA implementation for large filters.
|
||||
|
||||
template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
|
||||
{
|
||||
typedef typename InternalType<T>::scalar_t scalar_t;
|
||||
|
||||
// Calculate thread index.
|
||||
int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int outY = minorBase / p.launchMinor;
|
||||
minorBase -= outY * p.launchMinor;
|
||||
int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
|
||||
int majorBase = blockIdx.z * p.loopMajor;
|
||||
if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
|
||||
return;
|
||||
|
||||
// Setup Y receptive field.
|
||||
int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
|
||||
int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
|
||||
int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
|
||||
int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
|
||||
if (p.flip)
|
||||
filterY = p.filterSize.y - 1 - filterY;
|
||||
|
||||
// Loop over major, minor, and X.
|
||||
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
|
||||
for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
|
||||
{
|
||||
int nc = major * p.sizeMinor + minor;
|
||||
int n = nc / p.inSize.z;
|
||||
int c = nc - n * p.inSize.z;
|
||||
for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
|
||||
{
|
||||
// Setup X receptive field.
|
||||
int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
|
||||
int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
|
||||
int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
|
||||
int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
|
||||
if (p.flip)
|
||||
filterX = p.filterSize.x - 1 - filterX;
|
||||
|
||||
// Initialize pointers.
|
||||
const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
|
||||
const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
|
||||
int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
|
||||
int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
|
||||
|
||||
// Inner loop.
|
||||
scalar_t v = 0;
|
||||
for (int y = 0; y < h; y++)
|
||||
{
|
||||
for (int x = 0; x < w; x++)
|
||||
{
|
||||
v += (scalar_t)(*xp) * (scalar_t)(*fp);
|
||||
xp += p.inStride.x;
|
||||
fp += filterStepX;
|
||||
}
|
||||
xp += p.inStride.y - w * p.inStride.x;
|
||||
fp += filterStepY - w * filterStepX;
|
||||
}
|
||||
|
||||
// Store result.
|
||||
v *= p.gain;
|
||||
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Specialized CUDA implementation for small filters.
|
||||
|
||||
template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
|
||||
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
|
||||
{
|
||||
typedef typename InternalType<T>::scalar_t scalar_t;
|
||||
const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
|
||||
const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
|
||||
__shared__ volatile scalar_t sf[filterH][filterW];
|
||||
__shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
|
||||
|
||||
// Calculate tile index.
|
||||
int minorBase = blockIdx.x;
|
||||
int tileOutY = minorBase / p.launchMinor;
|
||||
minorBase -= tileOutY * p.launchMinor;
|
||||
minorBase *= loopMinor;
|
||||
tileOutY *= tileOutH;
|
||||
int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
|
||||
int majorBase = blockIdx.z * p.loopMajor;
|
||||
if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
|
||||
return;
|
||||
|
||||
// Load filter (flipped).
|
||||
for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
|
||||
{
|
||||
int fy = tapIdx / filterW;
|
||||
int fx = tapIdx - fy * filterW;
|
||||
scalar_t v = 0;
|
||||
if (fx < p.filterSize.x & fy < p.filterSize.y)
|
||||
{
|
||||
int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
|
||||
int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
|
||||
v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
|
||||
}
|
||||
sf[fy][fx] = v;
|
||||
}
|
||||
|
||||
// Loop over major and X.
|
||||
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
|
||||
{
|
||||
int baseNC = major * p.sizeMinor + minorBase;
|
||||
int n = baseNC / p.inSize.z;
|
||||
int baseC = baseNC - n * p.inSize.z;
|
||||
for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
|
||||
{
|
||||
// Load input pixels.
|
||||
int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
|
||||
int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
|
||||
int tileInX = floor_div(tileMidX, upx);
|
||||
int tileInY = floor_div(tileMidY, upy);
|
||||
__syncthreads();
|
||||
for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
|
||||
{
|
||||
int relC = inIdx;
|
||||
int relInX = relC / loopMinor;
|
||||
int relInY = relInX / tileInW;
|
||||
relC -= relInX * loopMinor;
|
||||
relInX -= relInY * tileInW;
|
||||
int c = baseC + relC;
|
||||
int inX = tileInX + relInX;
|
||||
int inY = tileInY + relInY;
|
||||
scalar_t v = 0;
|
||||
if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
|
||||
v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
|
||||
sx[relInY][relInX][relC] = v;
|
||||
}
|
||||
|
||||
// Loop over output pixels.
|
||||
__syncthreads();
|
||||
for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
|
||||
{
|
||||
int relC = outIdx;
|
||||
int relOutX = relC / loopMinor;
|
||||
int relOutY = relOutX / tileOutW;
|
||||
relC -= relOutX * loopMinor;
|
||||
relOutX -= relOutY * tileOutW;
|
||||
int c = baseC + relC;
|
||||
int outX = tileOutX + relOutX;
|
||||
int outY = tileOutY + relOutY;
|
||||
|
||||
// Setup receptive field.
|
||||
int midX = tileMidX + relOutX * downx;
|
||||
int midY = tileMidY + relOutY * downy;
|
||||
int inX = floor_div(midX, upx);
|
||||
int inY = floor_div(midY, upy);
|
||||
int relInX = inX - tileInX;
|
||||
int relInY = inY - tileInY;
|
||||
int filterX = (inX + 1) * upx - midX - 1; // flipped
|
||||
int filterY = (inY + 1) * upy - midY - 1; // flipped
|
||||
|
||||
// Inner loop.
|
||||
if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
|
||||
{
|
||||
scalar_t v = 0;
|
||||
#pragma unroll
|
||||
for (int y = 0; y < filterH / upy; y++)
|
||||
#pragma unroll
|
||||
for (int x = 0; x < filterW / upx; x++)
|
||||
v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
|
||||
v *= p.gain;
|
||||
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel selection.
|
||||
|
||||
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
|
||||
{
|
||||
int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
|
||||
|
||||
upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
|
||||
if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
|
||||
|
||||
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
|
||||
{
|
||||
if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
|
||||
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
|
||||
if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
|
||||
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
|
||||
if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
|
||||
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
|
||||
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
|
||||
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
|
||||
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
|
||||
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
|
||||
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
|
||||
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
|
||||
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
|
||||
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
|
||||
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
|
||||
}
|
||||
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
|
||||
{
|
||||
if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
|
||||
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
||||
if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
||||
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
||||
if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
||||
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
|
||||
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
|
||||
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
|
||||
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
|
||||
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
|
||||
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
|
||||
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
|
||||
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
|
||||
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
|
||||
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
|
||||
}
|
||||
if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
|
||||
{
|
||||
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
|
||||
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
|
||||
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
|
||||
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
|
||||
}
|
||||
if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
|
||||
{
|
||||
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
|
||||
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
|
||||
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
||||
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
|
||||
}
|
||||
if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
|
||||
{
|
||||
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
|
||||
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
|
||||
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
|
||||
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
|
||||
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
|
||||
}
|
||||
if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
|
||||
{
|
||||
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
|
||||
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
|
||||
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
|
||||
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
|
||||
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
|
||||
}
|
||||
if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
|
||||
{
|
||||
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
|
||||
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
|
||||
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
|
||||
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
|
||||
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
|
||||
}
|
||||
if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
|
||||
{
|
||||
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
|
||||
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
|
||||
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
|
||||
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
|
||||
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
|
||||
}
|
||||
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
|
||||
{
|
||||
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
|
||||
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
|
||||
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
|
||||
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
|
||||
}
|
||||
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
|
||||
{
|
||||
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
|
||||
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
|
||||
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
|
||||
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
|
||||
}
|
||||
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
|
||||
{
|
||||
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
|
||||
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
|
||||
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
|
||||
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
|
||||
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
|
||||
}
|
||||
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
|
||||
{
|
||||
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
|
||||
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
|
||||
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
|
||||
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
|
||||
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
|
||||
}
|
||||
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
|
||||
{
|
||||
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
|
||||
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
|
||||
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
|
||||
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
|
||||
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
|
||||
}
|
||||
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
|
||||
{
|
||||
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
|
||||
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
|
||||
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
|
||||
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
|
||||
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
|
||||
}
|
||||
return spec;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Template specializations.
|
||||
|
||||
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
|
||||
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
|
||||
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
59
global_torch/torch_utils/ops/upfirdn2d.h
Normal file
59
global_torch/torch_utils/ops/upfirdn2d.h
Normal file
@ -0,0 +1,59 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel parameters.
|
||||
|
||||
struct upfirdn2d_kernel_params
|
||||
{
|
||||
const void* x;
|
||||
const float* f;
|
||||
void* y;
|
||||
|
||||
int2 up;
|
||||
int2 down;
|
||||
int2 pad0;
|
||||
int flip;
|
||||
float gain;
|
||||
|
||||
int4 inSize; // [width, height, channel, batch]
|
||||
int4 inStride;
|
||||
int2 filterSize; // [width, height]
|
||||
int2 filterStride;
|
||||
int4 outSize; // [width, height, channel, batch]
|
||||
int4 outStride;
|
||||
int sizeMinor;
|
||||
int sizeMajor;
|
||||
|
||||
int loopMinor;
|
||||
int loopMajor;
|
||||
int loopX;
|
||||
int launchMinor;
|
||||
int launchMajor;
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel specialization.
|
||||
|
||||
struct upfirdn2d_kernel_spec
|
||||
{
|
||||
void* kernel;
|
||||
int tileOutW;
|
||||
int tileOutH;
|
||||
int loopMinor;
|
||||
int loopX;
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel selection.
|
||||
|
||||
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
384
global_torch/torch_utils/ops/upfirdn2d.py
Normal file
384
global_torch/torch_utils/ops/upfirdn2d.py
Normal file
@ -0,0 +1,384 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Custom PyTorch ops for efficient resampling of 2D images."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
import numpy as np
|
||||
import torch
|
||||
import traceback
|
||||
|
||||
from .. import custom_ops
|
||||
from .. import misc
|
||||
from . import conv2d_gradfix
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_inited = False
|
||||
_plugin = None
|
||||
|
||||
def _init():
|
||||
global _inited, _plugin
|
||||
if not _inited:
|
||||
sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
|
||||
sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
|
||||
try:
|
||||
_plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
|
||||
except:
|
||||
warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
|
||||
return _plugin is not None
|
||||
|
||||
def _parse_scaling(scaling):
|
||||
if isinstance(scaling, int):
|
||||
scaling = [scaling, scaling]
|
||||
assert isinstance(scaling, (list, tuple))
|
||||
assert all(isinstance(x, int) for x in scaling)
|
||||
sx, sy = scaling
|
||||
assert sx >= 1 and sy >= 1
|
||||
return sx, sy
|
||||
|
||||
def _parse_padding(padding):
|
||||
if isinstance(padding, int):
|
||||
padding = [padding, padding]
|
||||
assert isinstance(padding, (list, tuple))
|
||||
assert all(isinstance(x, int) for x in padding)
|
||||
if len(padding) == 2:
|
||||
padx, pady = padding
|
||||
padding = [padx, padx, pady, pady]
|
||||
padx0, padx1, pady0, pady1 = padding
|
||||
return padx0, padx1, pady0, pady1
|
||||
|
||||
def _get_filter_size(f):
|
||||
if f is None:
|
||||
return 1, 1
|
||||
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
||||
fw = f.shape[-1]
|
||||
fh = f.shape[0]
|
||||
with misc.suppress_tracer_warnings():
|
||||
fw = int(fw)
|
||||
fh = int(fh)
|
||||
misc.assert_shape(f, [fh, fw][:f.ndim])
|
||||
assert fw >= 1 and fh >= 1
|
||||
return fw, fh
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
|
||||
r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
|
||||
|
||||
Args:
|
||||
f: Torch tensor, numpy array, or python list of the shape
|
||||
`[filter_height, filter_width]` (non-separable),
|
||||
`[filter_taps]` (separable),
|
||||
`[]` (impulse), or
|
||||
`None` (identity).
|
||||
device: Result device (default: cpu).
|
||||
normalize: Normalize the filter so that it retains the magnitude
|
||||
for constant input signal (DC)? (default: True).
|
||||
flip_filter: Flip the filter? (default: False).
|
||||
gain: Overall scaling factor for signal magnitude (default: 1).
|
||||
separable: Return a separable filter? (default: select automatically).
|
||||
|
||||
Returns:
|
||||
Float32 tensor of the shape
|
||||
`[filter_height, filter_width]` (non-separable) or
|
||||
`[filter_taps]` (separable).
|
||||
"""
|
||||
# Validate.
|
||||
if f is None:
|
||||
f = 1
|
||||
f = torch.as_tensor(f, dtype=torch.float32)
|
||||
assert f.ndim in [0, 1, 2]
|
||||
assert f.numel() > 0
|
||||
if f.ndim == 0:
|
||||
f = f[np.newaxis]
|
||||
|
||||
# Separable?
|
||||
if separable is None:
|
||||
separable = (f.ndim == 1 and f.numel() >= 8)
|
||||
if f.ndim == 1 and not separable:
|
||||
f = f.ger(f)
|
||||
assert f.ndim == (1 if separable else 2)
|
||||
|
||||
# Apply normalize, flip, gain, and device.
|
||||
if normalize:
|
||||
f /= f.sum()
|
||||
if flip_filter:
|
||||
f = f.flip(list(range(f.ndim)))
|
||||
f = f * (gain ** (f.ndim / 2))
|
||||
f = f.to(device=device)
|
||||
return f
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
||||
r"""Pad, upsample, filter, and downsample a batch of 2D images.
|
||||
|
||||
Performs the following sequence of operations for each channel:
|
||||
|
||||
1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
|
||||
|
||||
2. Pad the image with the specified number of zeros on each side (`padding`).
|
||||
Negative padding corresponds to cropping the image.
|
||||
|
||||
3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
|
||||
so that the footprint of all output pixels lies within the input image.
|
||||
|
||||
4. Downsample the image by keeping every Nth pixel (`down`).
|
||||
|
||||
This sequence of operations bears close resemblance to scipy.signal.upfirdn().
|
||||
The fused op is considerably more efficient than performing the same calculation
|
||||
using standard PyTorch ops. It supports gradients of arbitrary order.
|
||||
|
||||
Args:
|
||||
x: Float32/float64/float16 input tensor of the shape
|
||||
`[batch_size, num_channels, in_height, in_width]`.
|
||||
f: Float32 FIR filter of the shape
|
||||
`[filter_height, filter_width]` (non-separable),
|
||||
`[filter_taps]` (separable), or
|
||||
`None` (identity).
|
||||
up: Integer upsampling factor. Can be a single int or a list/tuple
|
||||
`[x, y]` (default: 1).
|
||||
down: Integer downsampling factor. Can be a single int or a list/tuple
|
||||
`[x, y]` (default: 1).
|
||||
padding: Padding with respect to the upsampled image. Can be a single number
|
||||
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
||||
(default: 0).
|
||||
flip_filter: False = convolution, True = correlation (default: False).
|
||||
gain: Overall scaling factor for signal magnitude (default: 1).
|
||||
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
||||
"""
|
||||
assert isinstance(x, torch.Tensor)
|
||||
assert impl in ['ref', 'cuda']
|
||||
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
||||
return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
|
||||
return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
|
||||
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
|
||||
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
|
||||
"""
|
||||
# Validate arguments.
|
||||
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
||||
if f is None:
|
||||
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
||||
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
||||
assert f.dtype == torch.float32 and not f.requires_grad
|
||||
batch_size, num_channels, in_height, in_width = x.shape
|
||||
upx, upy = _parse_scaling(up)
|
||||
downx, downy = _parse_scaling(down)
|
||||
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
||||
|
||||
# Upsample by inserting zeros.
|
||||
x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
|
||||
x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
|
||||
x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
|
||||
|
||||
# Pad or crop.
|
||||
x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
|
||||
x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
|
||||
|
||||
# Setup filter.
|
||||
f = f * (gain ** (f.ndim / 2))
|
||||
f = f.to(x.dtype)
|
||||
if not flip_filter:
|
||||
f = f.flip(list(range(f.ndim)))
|
||||
|
||||
# Convolve with the filter.
|
||||
f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
|
||||
if f.ndim == 4:
|
||||
x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
|
||||
else:
|
||||
x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
|
||||
x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
|
||||
|
||||
# Downsample by throwing away pixels.
|
||||
x = x[:, :, ::downy, ::downx]
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_upfirdn2d_cuda_cache = dict()
|
||||
|
||||
def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
|
||||
"""Fast CUDA implementation of `upfirdn2d()` using custom ops.
|
||||
"""
|
||||
# Parse arguments.
|
||||
upx, upy = _parse_scaling(up)
|
||||
downx, downy = _parse_scaling(down)
|
||||
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
||||
|
||||
# Lookup from cache.
|
||||
key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
|
||||
if key in _upfirdn2d_cuda_cache:
|
||||
return _upfirdn2d_cuda_cache[key]
|
||||
|
||||
# Forward op.
|
||||
class Upfirdn2dCuda(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, f): # pylint: disable=arguments-differ
|
||||
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
||||
if f is None:
|
||||
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
||||
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
||||
y = x
|
||||
if f.ndim == 2:
|
||||
y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
|
||||
else:
|
||||
y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
|
||||
y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
|
||||
ctx.save_for_backward(f)
|
||||
ctx.x_shape = x.shape
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dy): # pylint: disable=arguments-differ
|
||||
f, = ctx.saved_tensors
|
||||
_, _, ih, iw = ctx.x_shape
|
||||
_, _, oh, ow = dy.shape
|
||||
fw, fh = _get_filter_size(f)
|
||||
p = [
|
||||
fw - padx0 - 1,
|
||||
iw * upx - ow * downx + padx0 - upx + 1,
|
||||
fh - pady0 - 1,
|
||||
ih * upy - oh * downy + pady0 - upy + 1,
|
||||
]
|
||||
dx = None
|
||||
df = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
|
||||
|
||||
assert not ctx.needs_input_grad[1]
|
||||
return dx, df
|
||||
|
||||
# Add to cache.
|
||||
_upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
|
||||
return Upfirdn2dCuda
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
||||
r"""Filter a batch of 2D images using the given 2D FIR filter.
|
||||
|
||||
By default, the result is padded so that its shape matches the input.
|
||||
User-specified padding is applied on top of that, with negative values
|
||||
indicating cropping. Pixels outside the image are assumed to be zero.
|
||||
|
||||
Args:
|
||||
x: Float32/float64/float16 input tensor of the shape
|
||||
`[batch_size, num_channels, in_height, in_width]`.
|
||||
f: Float32 FIR filter of the shape
|
||||
`[filter_height, filter_width]` (non-separable),
|
||||
`[filter_taps]` (separable), or
|
||||
`None` (identity).
|
||||
padding: Padding with respect to the output. Can be a single number or a
|
||||
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
||||
(default: 0).
|
||||
flip_filter: False = convolution, True = correlation (default: False).
|
||||
gain: Overall scaling factor for signal magnitude (default: 1).
|
||||
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
||||
"""
|
||||
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
||||
fw, fh = _get_filter_size(f)
|
||||
p = [
|
||||
padx0 + fw // 2,
|
||||
padx1 + (fw - 1) // 2,
|
||||
pady0 + fh // 2,
|
||||
pady1 + (fh - 1) // 2,
|
||||
]
|
||||
return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
||||
r"""Upsample a batch of 2D images using the given 2D FIR filter.
|
||||
|
||||
By default, the result is padded so that its shape is a multiple of the input.
|
||||
User-specified padding is applied on top of that, with negative values
|
||||
indicating cropping. Pixels outside the image are assumed to be zero.
|
||||
|
||||
Args:
|
||||
x: Float32/float64/float16 input tensor of the shape
|
||||
`[batch_size, num_channels, in_height, in_width]`.
|
||||
f: Float32 FIR filter of the shape
|
||||
`[filter_height, filter_width]` (non-separable),
|
||||
`[filter_taps]` (separable), or
|
||||
`None` (identity).
|
||||
up: Integer upsampling factor. Can be a single int or a list/tuple
|
||||
`[x, y]` (default: 1).
|
||||
padding: Padding with respect to the output. Can be a single number or a
|
||||
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
||||
(default: 0).
|
||||
flip_filter: False = convolution, True = correlation (default: False).
|
||||
gain: Overall scaling factor for signal magnitude (default: 1).
|
||||
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
||||
"""
|
||||
upx, upy = _parse_scaling(up)
|
||||
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
||||
fw, fh = _get_filter_size(f)
|
||||
p = [
|
||||
padx0 + (fw + upx - 1) // 2,
|
||||
padx1 + (fw - upx) // 2,
|
||||
pady0 + (fh + upy - 1) // 2,
|
||||
pady1 + (fh - upy) // 2,
|
||||
]
|
||||
return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
||||
r"""Downsample a batch of 2D images using the given 2D FIR filter.
|
||||
|
||||
By default, the result is padded so that its shape is a fraction of the input.
|
||||
User-specified padding is applied on top of that, with negative values
|
||||
indicating cropping. Pixels outside the image are assumed to be zero.
|
||||
|
||||
Args:
|
||||
x: Float32/float64/float16 input tensor of the shape
|
||||
`[batch_size, num_channels, in_height, in_width]`.
|
||||
f: Float32 FIR filter of the shape
|
||||
`[filter_height, filter_width]` (non-separable),
|
||||
`[filter_taps]` (separable), or
|
||||
`None` (identity).
|
||||
down: Integer downsampling factor. Can be a single int or a list/tuple
|
||||
`[x, y]` (default: 1).
|
||||
padding: Padding with respect to the input. Can be a single number or a
|
||||
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
||||
(default: 0).
|
||||
flip_filter: False = convolution, True = correlation (default: False).
|
||||
gain: Overall scaling factor for signal magnitude (default: 1).
|
||||
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
||||
"""
|
||||
downx, downy = _parse_scaling(down)
|
||||
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
||||
fw, fh = _get_filter_size(f)
|
||||
p = [
|
||||
padx0 + (fw - downx + 1) // 2,
|
||||
padx1 + (fw - downx) // 2,
|
||||
pady0 + (fh - downy + 1) // 2,
|
||||
pady1 + (fh - downy) // 2,
|
||||
]
|
||||
return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
251
global_torch/torch_utils/persistence.py
Normal file
251
global_torch/torch_utils/persistence.py
Normal file
@ -0,0 +1,251 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Facilities for pickling Python code alongside other data.
|
||||
|
||||
The pickled code is automatically imported into a separate Python module
|
||||
during unpickling. This way, any previously exported pickles will remain
|
||||
usable even if the original code is no longer available, or if the current
|
||||
version of the code is not consistent with what was originally pickled."""
|
||||
|
||||
import sys
|
||||
import pickle
|
||||
import io
|
||||
import inspect
|
||||
import copy
|
||||
import uuid
|
||||
import types
|
||||
import dnnlib
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_version = 6 # internal version number
|
||||
_decorators = set() # {decorator_class, ...}
|
||||
_import_hooks = [] # [hook_function, ...]
|
||||
_module_to_src_dict = dict() # {module: src, ...}
|
||||
_src_to_module_dict = dict() # {src: module, ...}
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def persistent_class(orig_class):
|
||||
r"""Class decorator that extends a given class to save its source code
|
||||
when pickled.
|
||||
|
||||
Example:
|
||||
|
||||
from torch_utils import persistence
|
||||
|
||||
@persistence.persistent_class
|
||||
class MyNetwork(torch.nn.Module):
|
||||
def __init__(self, num_inputs, num_outputs):
|
||||
super().__init__()
|
||||
self.fc = MyLayer(num_inputs, num_outputs)
|
||||
...
|
||||
|
||||
@persistence.persistent_class
|
||||
class MyLayer(torch.nn.Module):
|
||||
...
|
||||
|
||||
When pickled, any instance of `MyNetwork` and `MyLayer` will save its
|
||||
source code alongside other internal state (e.g., parameters, buffers,
|
||||
and submodules). This way, any previously exported pickle will remain
|
||||
usable even if the class definitions have been modified or are no
|
||||
longer available.
|
||||
|
||||
The decorator saves the source code of the entire Python module
|
||||
containing the decorated class. It does *not* save the source code of
|
||||
any imported modules. Thus, the imported modules must be available
|
||||
during unpickling, also including `torch_utils.persistence` itself.
|
||||
|
||||
It is ok to call functions defined in the same module from the
|
||||
decorated class. However, if the decorated class depends on other
|
||||
classes defined in the same module, they must be decorated as well.
|
||||
This is illustrated in the above example in the case of `MyLayer`.
|
||||
|
||||
It is also possible to employ the decorator just-in-time before
|
||||
calling the constructor. For example:
|
||||
|
||||
cls = MyLayer
|
||||
if want_to_make_it_persistent:
|
||||
cls = persistence.persistent_class(cls)
|
||||
layer = cls(num_inputs, num_outputs)
|
||||
|
||||
As an additional feature, the decorator also keeps track of the
|
||||
arguments that were used to construct each instance of the decorated
|
||||
class. The arguments can be queried via `obj.init_args` and
|
||||
`obj.init_kwargs`, and they are automatically pickled alongside other
|
||||
object state. A typical use case is to first unpickle a previous
|
||||
instance of a persistent class, and then upgrade it to use the latest
|
||||
version of the source code:
|
||||
|
||||
with open('old_pickle.pkl', 'rb') as f:
|
||||
old_net = pickle.load(f)
|
||||
new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
|
||||
misc.copy_params_and_buffers(old_net, new_net, require_all=True)
|
||||
"""
|
||||
assert isinstance(orig_class, type)
|
||||
if is_persistent(orig_class):
|
||||
return orig_class
|
||||
|
||||
assert orig_class.__module__ in sys.modules
|
||||
orig_module = sys.modules[orig_class.__module__]
|
||||
orig_module_src = _module_to_src(orig_module)
|
||||
|
||||
class Decorator(orig_class):
|
||||
_orig_module_src = orig_module_src
|
||||
_orig_class_name = orig_class.__name__
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._init_args = copy.deepcopy(args)
|
||||
self._init_kwargs = copy.deepcopy(kwargs)
|
||||
assert orig_class.__name__ in orig_module.__dict__
|
||||
_check_pickleable(self.__reduce__())
|
||||
|
||||
@property
|
||||
def init_args(self):
|
||||
return copy.deepcopy(self._init_args)
|
||||
|
||||
@property
|
||||
def init_kwargs(self):
|
||||
return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
|
||||
|
||||
def __reduce__(self):
|
||||
fields = list(super().__reduce__())
|
||||
fields += [None] * max(3 - len(fields), 0)
|
||||
if fields[0] is not _reconstruct_persistent_obj:
|
||||
meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
|
||||
fields[0] = _reconstruct_persistent_obj # reconstruct func
|
||||
fields[1] = (meta,) # reconstruct args
|
||||
fields[2] = None # state dict
|
||||
return tuple(fields)
|
||||
|
||||
Decorator.__name__ = orig_class.__name__
|
||||
_decorators.add(Decorator)
|
||||
return Decorator
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def is_persistent(obj):
|
||||
r"""Test whether the given object or class is persistent, i.e.,
|
||||
whether it will save its source code when pickled.
|
||||
"""
|
||||
try:
|
||||
if obj in _decorators:
|
||||
return True
|
||||
except TypeError:
|
||||
pass
|
||||
return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def import_hook(hook):
|
||||
r"""Register an import hook that is called whenever a persistent object
|
||||
is being unpickled. A typical use case is to patch the pickled source
|
||||
code to avoid errors and inconsistencies when the API of some imported
|
||||
module has changed.
|
||||
|
||||
The hook should have the following signature:
|
||||
|
||||
hook(meta) -> modified meta
|
||||
|
||||
`meta` is an instance of `dnnlib.EasyDict` with the following fields:
|
||||
|
||||
type: Type of the persistent object, e.g. `'class'`.
|
||||
version: Internal version number of `torch_utils.persistence`.
|
||||
module_src Original source code of the Python module.
|
||||
class_name: Class name in the original Python module.
|
||||
state: Internal state of the object.
|
||||
|
||||
Example:
|
||||
|
||||
@persistence.import_hook
|
||||
def wreck_my_network(meta):
|
||||
if meta.class_name == 'MyNetwork':
|
||||
print('MyNetwork is being imported. I will wreck it!')
|
||||
meta.module_src = meta.module_src.replace("True", "False")
|
||||
return meta
|
||||
"""
|
||||
assert callable(hook)
|
||||
_import_hooks.append(hook)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _reconstruct_persistent_obj(meta):
|
||||
r"""Hook that is called internally by the `pickle` module to unpickle
|
||||
a persistent object.
|
||||
"""
|
||||
meta = dnnlib.EasyDict(meta)
|
||||
meta.state = dnnlib.EasyDict(meta.state)
|
||||
for hook in _import_hooks:
|
||||
meta = hook(meta)
|
||||
assert meta is not None
|
||||
|
||||
assert meta.version == _version
|
||||
module = _src_to_module(meta.module_src)
|
||||
|
||||
assert meta.type == 'class'
|
||||
orig_class = module.__dict__[meta.class_name]
|
||||
decorator_class = persistent_class(orig_class)
|
||||
obj = decorator_class.__new__(decorator_class)
|
||||
|
||||
setstate = getattr(obj, '__setstate__', None)
|
||||
if callable(setstate):
|
||||
setstate(meta.state) # pylint: disable=not-callable
|
||||
else:
|
||||
obj.__dict__.update(meta.state)
|
||||
return obj
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _module_to_src(module):
|
||||
r"""Query the source code of a given Python module.
|
||||
"""
|
||||
src = _module_to_src_dict.get(module, None)
|
||||
if src is None:
|
||||
src = inspect.getsource(module)
|
||||
_module_to_src_dict[module] = src
|
||||
_src_to_module_dict[src] = module
|
||||
return src
|
||||
|
||||
def _src_to_module(src):
|
||||
r"""Get or create a Python module for the given source code.
|
||||
"""
|
||||
module = _src_to_module_dict.get(src, None)
|
||||
if module is None:
|
||||
module_name = "_imported_module_" + uuid.uuid4().hex
|
||||
module = types.ModuleType(module_name)
|
||||
sys.modules[module_name] = module
|
||||
_module_to_src_dict[module] = src
|
||||
_src_to_module_dict[src] = module
|
||||
exec(src, module.__dict__) # pylint: disable=exec-used
|
||||
return module
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _check_pickleable(obj):
|
||||
r"""Check that the given object is pickleable, raising an exception if
|
||||
it is not. This function is expected to be considerably more efficient
|
||||
than actually pickling the object.
|
||||
"""
|
||||
def recurse(obj):
|
||||
if isinstance(obj, (list, tuple, set)):
|
||||
return [recurse(x) for x in obj]
|
||||
if isinstance(obj, dict):
|
||||
return [[recurse(x), recurse(y)] for x, y in obj.items()]
|
||||
if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
|
||||
return None # Python primitive types are pickleable.
|
||||
if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
|
||||
return None # NumPy arrays and PyTorch tensors are pickleable.
|
||||
if is_persistent(obj):
|
||||
return None # Persistent objects are pickleable, by virtue of the constructor check.
|
||||
return obj
|
||||
with io.BytesIO() as f:
|
||||
pickle.dump(recurse(obj), f)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
268
global_torch/torch_utils/training_stats.py
Normal file
268
global_torch/torch_utils/training_stats.py
Normal file
@ -0,0 +1,268 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Facilities for reporting and collecting training statistics across
|
||||
multiple processes and devices. The interface is designed to minimize
|
||||
synchronization overhead as well as the amount of boilerplate in user
|
||||
code."""
|
||||
|
||||
import re
|
||||
import numpy as np
|
||||
import torch
|
||||
import dnnlib
|
||||
|
||||
from . import misc
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
|
||||
_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
|
||||
_counter_dtype = torch.float64 # Data type to use for the internal counters.
|
||||
_rank = 0 # Rank of the current process.
|
||||
_sync_device = None # Device to use for multiprocess communication. None = single-process.
|
||||
_sync_called = False # Has _sync() been called yet?
|
||||
_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
|
||||
_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def init_multiprocessing(rank, sync_device):
|
||||
r"""Initializes `torch_utils.training_stats` for collecting statistics
|
||||
across multiple processes.
|
||||
|
||||
This function must be called after
|
||||
`torch.distributed.init_process_group()` and before `Collector.update()`.
|
||||
The call is not necessary if multi-process collection is not needed.
|
||||
|
||||
Args:
|
||||
rank: Rank of the current process.
|
||||
sync_device: PyTorch device to use for inter-process
|
||||
communication, or None to disable multi-process
|
||||
collection. Typically `torch.device('cuda', rank)`.
|
||||
"""
|
||||
global _rank, _sync_device
|
||||
assert not _sync_called
|
||||
_rank = rank
|
||||
_sync_device = sync_device
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
|
||||
def report(name, value):
|
||||
r"""Broadcasts the given set of scalars to all interested instances of
|
||||
`Collector`, across device and process boundaries.
|
||||
|
||||
This function is expected to be extremely cheap and can be safely
|
||||
called from anywhere in the training loop, loss function, or inside a
|
||||
`torch.nn.Module`.
|
||||
|
||||
Warning: The current implementation expects the set of unique names to
|
||||
be consistent across processes. Please make sure that `report()` is
|
||||
called at least once for each unique name by each process, and in the
|
||||
same order. If a given process has no scalars to broadcast, it can do
|
||||
`report(name, [])` (empty list).
|
||||
|
||||
Args:
|
||||
name: Arbitrary string specifying the name of the statistic.
|
||||
Averages are accumulated separately for each unique name.
|
||||
value: Arbitrary set of scalars. Can be a list, tuple,
|
||||
NumPy array, PyTorch tensor, or Python scalar.
|
||||
|
||||
Returns:
|
||||
The same `value` that was passed in.
|
||||
"""
|
||||
if name not in _counters:
|
||||
_counters[name] = dict()
|
||||
|
||||
elems = torch.as_tensor(value)
|
||||
if elems.numel() == 0:
|
||||
return value
|
||||
|
||||
elems = elems.detach().flatten().to(_reduce_dtype)
|
||||
moments = torch.stack([
|
||||
torch.ones_like(elems).sum(),
|
||||
elems.sum(),
|
||||
elems.square().sum(),
|
||||
])
|
||||
assert moments.ndim == 1 and moments.shape[0] == _num_moments
|
||||
moments = moments.to(_counter_dtype)
|
||||
|
||||
device = moments.device
|
||||
if device not in _counters[name]:
|
||||
_counters[name][device] = torch.zeros_like(moments)
|
||||
_counters[name][device].add_(moments)
|
||||
return value
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def report0(name, value):
|
||||
r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
|
||||
but ignores any scalars provided by the other processes.
|
||||
See `report()` for further details.
|
||||
"""
|
||||
report(name, value if _rank == 0 else [])
|
||||
return value
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class Collector:
|
||||
r"""Collects the scalars broadcasted by `report()` and `report0()` and
|
||||
computes their long-term averages (mean and standard deviation) over
|
||||
user-defined periods of time.
|
||||
|
||||
The averages are first collected into internal counters that are not
|
||||
directly visible to the user. They are then copied to the user-visible
|
||||
state as a result of calling `update()` and can then be queried using
|
||||
`mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
|
||||
internal counters for the next round, so that the user-visible state
|
||||
effectively reflects averages collected between the last two calls to
|
||||
`update()`.
|
||||
|
||||
Args:
|
||||
regex: Regular expression defining which statistics to
|
||||
collect. The default is to collect everything.
|
||||
keep_previous: Whether to retain the previous averages if no
|
||||
scalars were collected on a given round
|
||||
(default: True).
|
||||
"""
|
||||
def __init__(self, regex='.*', keep_previous=True):
|
||||
self._regex = re.compile(regex)
|
||||
self._keep_previous = keep_previous
|
||||
self._cumulative = dict()
|
||||
self._moments = dict()
|
||||
self.update()
|
||||
self._moments.clear()
|
||||
|
||||
def names(self):
|
||||
r"""Returns the names of all statistics broadcasted so far that
|
||||
match the regular expression specified at construction time.
|
||||
"""
|
||||
return [name for name in _counters if self._regex.fullmatch(name)]
|
||||
|
||||
def update(self):
|
||||
r"""Copies current values of the internal counters to the
|
||||
user-visible state and resets them for the next round.
|
||||
|
||||
If `keep_previous=True` was specified at construction time, the
|
||||
operation is skipped for statistics that have received no scalars
|
||||
since the last update, retaining their previous averages.
|
||||
|
||||
This method performs a number of GPU-to-CPU transfers and one
|
||||
`torch.distributed.all_reduce()`. It is intended to be called
|
||||
periodically in the main training loop, typically once every
|
||||
N training steps.
|
||||
"""
|
||||
if not self._keep_previous:
|
||||
self._moments.clear()
|
||||
for name, cumulative in _sync(self.names()):
|
||||
if name not in self._cumulative:
|
||||
self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
|
||||
delta = cumulative - self._cumulative[name]
|
||||
self._cumulative[name].copy_(cumulative)
|
||||
if float(delta[0]) != 0:
|
||||
self._moments[name] = delta
|
||||
|
||||
def _get_delta(self, name):
|
||||
r"""Returns the raw moments that were accumulated for the given
|
||||
statistic between the last two calls to `update()`, or zero if
|
||||
no scalars were collected.
|
||||
"""
|
||||
assert self._regex.fullmatch(name)
|
||||
if name not in self._moments:
|
||||
self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
|
||||
return self._moments[name]
|
||||
|
||||
def num(self, name):
|
||||
r"""Returns the number of scalars that were accumulated for the given
|
||||
statistic between the last two calls to `update()`, or zero if
|
||||
no scalars were collected.
|
||||
"""
|
||||
delta = self._get_delta(name)
|
||||
return int(delta[0])
|
||||
|
||||
def mean(self, name):
|
||||
r"""Returns the mean of the scalars that were accumulated for the
|
||||
given statistic between the last two calls to `update()`, or NaN if
|
||||
no scalars were collected.
|
||||
"""
|
||||
delta = self._get_delta(name)
|
||||
if int(delta[0]) == 0:
|
||||
return float('nan')
|
||||
return float(delta[1] / delta[0])
|
||||
|
||||
def std(self, name):
|
||||
r"""Returns the standard deviation of the scalars that were
|
||||
accumulated for the given statistic between the last two calls to
|
||||
`update()`, or NaN if no scalars were collected.
|
||||
"""
|
||||
delta = self._get_delta(name)
|
||||
if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
|
||||
return float('nan')
|
||||
if int(delta[0]) == 1:
|
||||
return float(0)
|
||||
mean = float(delta[1] / delta[0])
|
||||
raw_var = float(delta[2] / delta[0])
|
||||
return np.sqrt(max(raw_var - np.square(mean), 0))
|
||||
|
||||
def as_dict(self):
|
||||
r"""Returns the averages accumulated between the last two calls to
|
||||
`update()` as an `dnnlib.EasyDict`. The contents are as follows:
|
||||
|
||||
dnnlib.EasyDict(
|
||||
NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
|
||||
...
|
||||
)
|
||||
"""
|
||||
stats = dnnlib.EasyDict()
|
||||
for name in self.names():
|
||||
stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
|
||||
return stats
|
||||
|
||||
def __getitem__(self, name):
|
||||
r"""Convenience getter.
|
||||
`collector[name]` is a synonym for `collector.mean(name)`.
|
||||
"""
|
||||
return self.mean(name)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _sync(names):
|
||||
r"""Synchronize the global cumulative counters across devices and
|
||||
processes. Called internally by `Collector.update()`.
|
||||
"""
|
||||
if len(names) == 0:
|
||||
return []
|
||||
global _sync_called
|
||||
_sync_called = True
|
||||
|
||||
# Collect deltas within current rank.
|
||||
deltas = []
|
||||
device = _sync_device if _sync_device is not None else torch.device('cpu')
|
||||
for name in names:
|
||||
delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
|
||||
for counter in _counters[name].values():
|
||||
delta.add_(counter.to(device))
|
||||
counter.copy_(torch.zeros_like(counter))
|
||||
deltas.append(delta)
|
||||
deltas = torch.stack(deltas)
|
||||
|
||||
# Sum deltas across ranks.
|
||||
if _sync_device is not None:
|
||||
torch.distributed.all_reduce(deltas)
|
||||
|
||||
# Update cumulative values.
|
||||
deltas = deltas.cpu()
|
||||
for idx, name in enumerate(names):
|
||||
if name not in _cumulative:
|
||||
_cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
|
||||
_cumulative[name].add_(deltas[idx])
|
||||
|
||||
# Return name-value pairs.
|
||||
return [(name, _cumulative[name]) for name in names]
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
9
global_torch/training/__init__.py
Normal file
9
global_torch/training/__init__.py
Normal file
@ -0,0 +1,9 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
# empty
|
||||
809
global_torch/training/networks.py
Normal file
809
global_torch/training/networks.py
Normal file
@ -0,0 +1,809 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch_utils import misc
|
||||
from torch_utils import persistence
|
||||
from torch_utils.ops import conv2d_resample
|
||||
from torch_utils.ops import upfirdn2d
|
||||
from torch_utils.ops import bias_act
|
||||
from torch_utils.ops import fma
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
|
||||
def normalize_2nd_moment(x, dim=1, eps=1e-8):
|
||||
return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
|
||||
def modulated_conv2d(
|
||||
x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
|
||||
weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
|
||||
styles, # Modulation coefficients of shape [batch_size, in_channels].
|
||||
noise = None, # Optional noise tensor to add to the output activations.
|
||||
up = 1, # Integer upsampling factor.
|
||||
down = 1, # Integer downsampling factor.
|
||||
padding = 0, # Padding with respect to the upsampled image.
|
||||
resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
|
||||
demodulate = True, # Apply weight demodulation?
|
||||
flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
|
||||
fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation?
|
||||
):
|
||||
batch_size = x.shape[0]
|
||||
out_channels, in_channels, kh, kw = weight.shape
|
||||
misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
|
||||
misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
|
||||
misc.assert_shape(styles, [batch_size, in_channels]) # [NI]
|
||||
|
||||
# Pre-normalize inputs to avoid FP16 overflow.
|
||||
if x.dtype == torch.float16 and demodulate:
|
||||
weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk
|
||||
styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I
|
||||
|
||||
# Calculate per-sample weights and demodulation coefficients.
|
||||
w = None
|
||||
dcoefs = None
|
||||
if demodulate or fused_modconv:
|
||||
w = weight.unsqueeze(0) # [NOIkk]
|
||||
w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
|
||||
if demodulate:
|
||||
dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
|
||||
if demodulate and fused_modconv:
|
||||
w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
|
||||
|
||||
# Execute by scaling the activations before and after the convolution.
|
||||
if not fused_modconv:
|
||||
x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
|
||||
x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
|
||||
if demodulate and noise is not None:
|
||||
x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
|
||||
elif demodulate:
|
||||
x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
|
||||
elif noise is not None:
|
||||
x = x.add_(noise.to(x.dtype))
|
||||
return x
|
||||
|
||||
# Execute as one fused op using grouped convolution.
|
||||
with misc.suppress_tracer_warnings(): # this value will be treated as a constant
|
||||
batch_size = int(batch_size)
|
||||
misc.assert_shape(x, [batch_size, in_channels, None, None])
|
||||
x = x.reshape(1, -1, *x.shape[2:])
|
||||
w = w.reshape(-1, in_channels, kh, kw)
|
||||
x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
|
||||
x = x.reshape(batch_size, -1, *x.shape[2:])
|
||||
if noise is not None:
|
||||
x = x.add_(noise)
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@persistence.persistent_class
|
||||
class FullyConnectedLayer(torch.nn.Module):
|
||||
def __init__(self,
|
||||
in_features, # Number of input features.
|
||||
out_features, # Number of output features.
|
||||
bias = True, # Apply additive bias before the activation function?
|
||||
activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
|
||||
lr_multiplier = 1, # Learning rate multiplier.
|
||||
bias_init = 0, # Initial value for the additive bias.
|
||||
):
|
||||
super().__init__()
|
||||
self.activation = activation
|
||||
self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
|
||||
self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
|
||||
self.weight_gain = lr_multiplier / np.sqrt(in_features)
|
||||
self.bias_gain = lr_multiplier
|
||||
|
||||
def forward(self, x):
|
||||
w = self.weight.to(x.dtype) * self.weight_gain
|
||||
b = self.bias
|
||||
if b is not None:
|
||||
b = b.to(x.dtype)
|
||||
if self.bias_gain != 1:
|
||||
b = b * self.bias_gain
|
||||
|
||||
if self.activation == 'linear' and b is not None:
|
||||
x = torch.addmm(b.unsqueeze(0), x, w.t())
|
||||
else:
|
||||
x = x.matmul(w.t())
|
||||
x = bias_act.bias_act(x, b, act=self.activation)
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@persistence.persistent_class
|
||||
class Conv2dLayer(torch.nn.Module):
|
||||
def __init__(self,
|
||||
in_channels, # Number of input channels.
|
||||
out_channels, # Number of output channels.
|
||||
kernel_size, # Width and height of the convolution kernel.
|
||||
bias = True, # Apply additive bias before the activation function?
|
||||
activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
|
||||
up = 1, # Integer upsampling factor.
|
||||
down = 1, # Integer downsampling factor.
|
||||
resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
|
||||
conv_clamp = None, # Clamp the output to +-X, None = disable clamping.
|
||||
channels_last = False, # Expect the input to have memory_format=channels_last?
|
||||
trainable = True, # Update the weights of this layer during training?
|
||||
):
|
||||
super().__init__()
|
||||
self.activation = activation
|
||||
self.up = up
|
||||
self.down = down
|
||||
self.conv_clamp = conv_clamp
|
||||
self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
|
||||
self.padding = kernel_size // 2
|
||||
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
|
||||
self.act_gain = bias_act.activation_funcs[activation].def_gain
|
||||
|
||||
memory_format = torch.channels_last if channels_last else torch.contiguous_format
|
||||
weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
|
||||
bias = torch.zeros([out_channels]) if bias else None
|
||||
if trainable:
|
||||
self.weight = torch.nn.Parameter(weight)
|
||||
self.bias = torch.nn.Parameter(bias) if bias is not None else None
|
||||
else:
|
||||
self.register_buffer('weight', weight)
|
||||
if bias is not None:
|
||||
self.register_buffer('bias', bias)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, x, gain=1):
|
||||
w = self.weight * self.weight_gain
|
||||
b = self.bias.to(x.dtype) if self.bias is not None else None
|
||||
flip_weight = (self.up == 1) # slightly faster
|
||||
x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight)
|
||||
|
||||
act_gain = self.act_gain * gain
|
||||
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
|
||||
x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@persistence.persistent_class
|
||||
class MappingNetwork(torch.nn.Module):
|
||||
def __init__(self,
|
||||
z_dim, # Input latent (Z) dimensionality, 0 = no latent.
|
||||
c_dim, # Conditioning label (C) dimensionality, 0 = no label.
|
||||
w_dim, # Intermediate latent (W) dimensionality.
|
||||
num_ws, # Number of intermediate latents to output, None = do not broadcast.
|
||||
num_layers = 8, # Number of mapping layers.
|
||||
embed_features = None, # Label embedding dimensionality, None = same as w_dim.
|
||||
layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim.
|
||||
activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
|
||||
lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
|
||||
w_avg_beta = 0.995, # Decay for tracking the moving average of W during training, None = do not track.
|
||||
):
|
||||
super().__init__()
|
||||
self.z_dim = z_dim
|
||||
self.c_dim = c_dim
|
||||
self.w_dim = w_dim
|
||||
self.num_ws = num_ws
|
||||
self.num_layers = num_layers
|
||||
self.w_avg_beta = w_avg_beta
|
||||
|
||||
if embed_features is None:
|
||||
embed_features = w_dim
|
||||
if c_dim == 0:
|
||||
embed_features = 0
|
||||
if layer_features is None:
|
||||
layer_features = w_dim
|
||||
features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
|
||||
|
||||
if c_dim > 0:
|
||||
self.embed = FullyConnectedLayer(c_dim, embed_features)
|
||||
for idx in range(num_layers):
|
||||
in_features = features_list[idx]
|
||||
out_features = features_list[idx + 1]
|
||||
layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
|
||||
setattr(self, f'fc{idx}', layer)
|
||||
|
||||
if num_ws is not None and w_avg_beta is not None:
|
||||
self.register_buffer('w_avg', torch.zeros([w_dim]))
|
||||
|
||||
def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
|
||||
# Embed, normalize, and concat inputs.
|
||||
x = None
|
||||
with torch.autograd.profiler.record_function('input'):
|
||||
if self.z_dim > 0:
|
||||
misc.assert_shape(z, [None, self.z_dim])
|
||||
x = normalize_2nd_moment(z.to(torch.float32))
|
||||
if self.c_dim > 0:
|
||||
misc.assert_shape(c, [None, self.c_dim])
|
||||
y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
|
||||
x = torch.cat([x, y], dim=1) if x is not None else y
|
||||
|
||||
# Main layers.
|
||||
for idx in range(self.num_layers):
|
||||
layer = getattr(self, f'fc{idx}')
|
||||
x = layer(x)
|
||||
|
||||
# Update moving average of W.
|
||||
if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
|
||||
with torch.autograd.profiler.record_function('update_w_avg'):
|
||||
self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
|
||||
|
||||
# Broadcast.
|
||||
if self.num_ws is not None:
|
||||
with torch.autograd.profiler.record_function('broadcast'):
|
||||
x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
|
||||
|
||||
# Apply truncation.
|
||||
if truncation_psi != 1:
|
||||
with torch.autograd.profiler.record_function('truncate'):
|
||||
assert self.w_avg_beta is not None
|
||||
if self.num_ws is None or truncation_cutoff is None:
|
||||
x = self.w_avg.lerp(x, truncation_psi)
|
||||
else:
|
||||
x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@persistence.persistent_class
|
||||
class SynthesisLayer(torch.nn.Module):
|
||||
def __init__(self,
|
||||
in_channels, # Number of input channels.
|
||||
out_channels, # Number of output channels.
|
||||
w_dim, # Intermediate latent (W) dimensionality.
|
||||
resolution, # Resolution of this layer.
|
||||
kernel_size = 3, # Convolution kernel size.
|
||||
up = 1, # Integer upsampling factor.
|
||||
use_noise = True, # Enable noise input?
|
||||
activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
|
||||
resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
|
||||
conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
||||
channels_last = False, # Use channels_last format for the weights?
|
||||
name = ''
|
||||
):
|
||||
super().__init__()
|
||||
self.resolution = resolution
|
||||
self.up = up
|
||||
self.use_noise = use_noise
|
||||
self.activation = activation
|
||||
self.conv_clamp = conv_clamp
|
||||
self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
|
||||
self.padding = kernel_size // 2
|
||||
self.act_gain = bias_act.activation_funcs[activation].def_gain
|
||||
self.name = name
|
||||
self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
|
||||
memory_format = torch.channels_last if channels_last else torch.contiguous_format
|
||||
self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
|
||||
if use_noise:
|
||||
self.register_buffer('noise_const', torch.randn([resolution, resolution]))
|
||||
self.noise_strength = torch.nn.Parameter(torch.zeros([]))
|
||||
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
|
||||
print(f"name:{name} Resolution: {resolution}, InC: {in_channels}, OutC:{out_channels}, w_dim: {w_dim}")
|
||||
|
||||
def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1, encoded_styles=None):
|
||||
assert noise_mode in ['random', 'const', 'none']
|
||||
in_resolution = self.resolution // self.up
|
||||
# misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution]) # not need to be squre
|
||||
if encoded_styles is None:
|
||||
styles = self.affine(w)
|
||||
else:
|
||||
styles = encoded_styles[self.name]
|
||||
|
||||
noise = None
|
||||
if self.use_noise and noise_mode == 'random':
|
||||
noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
|
||||
if self.use_noise and noise_mode == 'const':
|
||||
noise = self.noise_const * self.noise_strength
|
||||
|
||||
flip_weight = (self.up == 1) # slightly faster
|
||||
x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
|
||||
padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv)
|
||||
|
||||
act_gain = self.act_gain * gain
|
||||
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
|
||||
x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@persistence.persistent_class
|
||||
class ToRGBLayer(torch.nn.Module):
|
||||
def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False, name=''):
|
||||
super().__init__()
|
||||
self.conv_clamp = conv_clamp
|
||||
self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
|
||||
memory_format = torch.channels_last if channels_last else torch.contiguous_format
|
||||
self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
|
||||
self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
|
||||
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
|
||||
self.name = name
|
||||
print(f"name:{name} InC: {in_channels}, OutC:{out_channels}, w_dim: {w_dim}")
|
||||
|
||||
|
||||
def forward(self, x, w, fused_modconv=True, encoded_styles=None):
|
||||
if encoded_styles is None:
|
||||
styles = self.affine(w) #* self.weight_gain
|
||||
|
||||
else:
|
||||
styles = encoded_styles[self.name]
|
||||
tmp_s=styles* self.weight_gain
|
||||
|
||||
x = modulated_conv2d(x=x, weight=self.weight, styles=tmp_s, demodulate=False, fused_modconv=fused_modconv)
|
||||
x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@persistence.persistent_class
|
||||
class SynthesisBlock(torch.nn.Module):
|
||||
def __init__(self,
|
||||
in_channels, # Number of input channels, 0 = first block.
|
||||
out_channels, # Number of output channels.
|
||||
w_dim, # Intermediate latent (W) dimensionality.
|
||||
resolution, # Resolution of this block.
|
||||
img_channels, # Number of output color channels.
|
||||
is_last, # Is this the last block?
|
||||
architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'.
|
||||
resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
|
||||
conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
||||
use_fp16 = False, # Use FP16 for this block?
|
||||
fp16_channels_last = False, # Use channels-last memory format with FP16?
|
||||
**layer_kwargs, # Arguments for SynthesisLayer.
|
||||
):
|
||||
assert architecture in ['orig', 'skip', 'resnet']
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.w_dim = w_dim
|
||||
self.resolution = resolution
|
||||
self.img_channels = img_channels
|
||||
self.is_last = is_last
|
||||
self.architecture = architecture
|
||||
self.use_fp16 = use_fp16
|
||||
self.channels_last = (use_fp16 and fp16_channels_last)
|
||||
self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
|
||||
self.num_conv = 0
|
||||
self.num_torgb = 0
|
||||
|
||||
|
||||
if in_channels == 0:
|
||||
self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))
|
||||
|
||||
if in_channels != 0:
|
||||
self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2,
|
||||
resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, name=f'conv0_resolution_{resolution}', **layer_kwargs)
|
||||
self.num_conv += 1
|
||||
|
||||
self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution,
|
||||
conv_clamp=conv_clamp, channels_last=self.channels_last, name=f'conv1_resolution_{resolution}', **layer_kwargs)
|
||||
self.num_conv += 1
|
||||
|
||||
if is_last or architecture == 'skip':
|
||||
self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
|
||||
conv_clamp=conv_clamp, channels_last=self.channels_last, name=f'toRGB_resolution_{resolution}')
|
||||
self.num_torgb += 1
|
||||
|
||||
if in_channels != 0 and architecture == 'resnet':
|
||||
self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
|
||||
resample_filter=resample_filter, channels_last=self.channels_last)
|
||||
|
||||
def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, encoded_styles=None, **layer_kwargs):
|
||||
|
||||
class NoneIter:
|
||||
def __init__(self):
|
||||
pass
|
||||
def __iter__(self):
|
||||
return self
|
||||
def __next__(self):
|
||||
return None
|
||||
|
||||
if encoded_styles is None:
|
||||
misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
|
||||
w_iter = iter(ws.unbind(dim=1))
|
||||
else:
|
||||
w_iter = iter(NoneIter())
|
||||
|
||||
dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
|
||||
memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
|
||||
if fused_modconv is None:
|
||||
with misc.suppress_tracer_warnings(): # this value will be treated as a constant
|
||||
fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)
|
||||
|
||||
# Input.
|
||||
if self.in_channels == 0:
|
||||
x = self.const.to(dtype=dtype, memory_format=memory_format)
|
||||
if encoded_styles is None:
|
||||
x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
|
||||
else:
|
||||
x = x.unsqueeze(0).repeat([encoded_styles['conv1_resolution_4'].shape[0], 1, 1, 1])
|
||||
else:
|
||||
# misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2]) # not need to be squre
|
||||
x = x.to(dtype=dtype, memory_format=memory_format)
|
||||
|
||||
# Main layers.
|
||||
if self.in_channels == 0:
|
||||
x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, encoded_styles=encoded_styles, **layer_kwargs)
|
||||
elif self.architecture == 'resnet':
|
||||
y = self.skip(x, gain=np.sqrt(0.5))
|
||||
x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, encoded_styles=encoded_styles, **layer_kwargs)
|
||||
x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, encoded_styles=encoded_styles, gain=np.sqrt(0.5), **layer_kwargs)
|
||||
x = y.add_(x)
|
||||
else:
|
||||
x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, encoded_styles=encoded_styles, **layer_kwargs)
|
||||
x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, encoded_styles=encoded_styles, **layer_kwargs)
|
||||
|
||||
# ToRGB.
|
||||
if img is not None:
|
||||
# misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2]) ## not need to be squre
|
||||
img = upfirdn2d.upsample2d(img, self.resample_filter)
|
||||
if self.is_last or self.architecture == 'skip':
|
||||
y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv, encoded_styles=encoded_styles, )
|
||||
y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
|
||||
img = img.add_(y) if img is not None else y
|
||||
|
||||
assert x.dtype == dtype
|
||||
assert img is None or img.dtype == torch.float32
|
||||
return x, img
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@persistence.persistent_class
|
||||
class SynthesisNetwork(torch.nn.Module):
|
||||
def __init__(self,
|
||||
w_dim, # Intermediate latent (W) dimensionality.
|
||||
img_resolution, # Output image resolution.
|
||||
img_channels, # Number of color channels.
|
||||
channel_base = 32768, # Overall multiplier for the number of channels.
|
||||
channel_max = 512, # Maximum number of channels in any layer.
|
||||
num_fp16_res = 0, # Use FP16 for the N highest resolutions.
|
||||
**block_kwargs, # Arguments for SynthesisBlock.
|
||||
):
|
||||
assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
|
||||
super().__init__()
|
||||
self.w_dim = w_dim
|
||||
self.img_resolution = img_resolution
|
||||
self.img_resolution_log2 = int(np.log2(img_resolution))
|
||||
self.img_channels = img_channels
|
||||
self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)]
|
||||
channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
|
||||
fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
|
||||
|
||||
self.num_ws = 0
|
||||
for res in self.block_resolutions:
|
||||
in_channels = channels_dict[res // 2] if res > 4 else 0
|
||||
out_channels = channels_dict[res]
|
||||
use_fp16 = (res >= fp16_resolution)
|
||||
is_last = (res == self.img_resolution)
|
||||
block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
|
||||
img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
|
||||
self.num_ws += block.num_conv
|
||||
if is_last:
|
||||
self.num_ws += block.num_torgb
|
||||
setattr(self, f'b{res}', block)
|
||||
|
||||
def forward(self, ws, encoded_styles=None, **block_kwargs):
|
||||
if encoded_styles is None:
|
||||
block_ws = []
|
||||
with torch.autograd.profiler.record_function('split_ws'):
|
||||
misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
|
||||
ws = ws.to(torch.float32)
|
||||
w_idx = 0
|
||||
for res in self.block_resolutions:
|
||||
block = getattr(self, f'b{res}')
|
||||
block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
|
||||
w_idx += block.num_conv
|
||||
|
||||
x = img = None
|
||||
for res, cur_ws in zip(self.block_resolutions, block_ws):
|
||||
block = getattr(self, f'b{res}')
|
||||
x, img = block(x, img, cur_ws, encoded_styles=encoded_styles, **block_kwargs)
|
||||
else:
|
||||
x = img = None
|
||||
for res in self.block_resolutions:
|
||||
block = getattr(self, f'b{res}')
|
||||
x, img = block(x, img, None, encoded_styles=encoded_styles, **block_kwargs)
|
||||
return img
|
||||
|
||||
def W2S(self,ws):
|
||||
|
||||
i=0
|
||||
encoded_styles={}
|
||||
for res in self.block_resolutions:
|
||||
block = getattr(self, f'b{res}')
|
||||
if res==4:
|
||||
s=block.conv1.affine(ws[:,i])
|
||||
encoded_styles[f'conv1_resolution_{res}'] =s
|
||||
i+=1
|
||||
s=block.torgb.affine(ws[:,i]) #* block.torgb.weight_gain
|
||||
encoded_styles[f'toRGB_resolution_{res}'] =s
|
||||
# i+=1
|
||||
else:
|
||||
# print(res,i)
|
||||
s=block.conv0.affine(ws[:,i])
|
||||
encoded_styles[f'conv0_resolution_{res}'] =s
|
||||
i+=1
|
||||
# print(res,i)
|
||||
s=block.conv1.affine(ws[:,i])
|
||||
encoded_styles[f'conv1_resolution_{res}'] =s
|
||||
i+=1
|
||||
# toRGB and next layer conv0 use the same w
|
||||
s=block.torgb.affine(ws[:,i])#* block.torgb.weight_gain
|
||||
encoded_styles[f'toRGB_resolution_{res}'] =s
|
||||
# i+=1
|
||||
# print(i)
|
||||
|
||||
|
||||
|
||||
|
||||
return encoded_styles
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@persistence.persistent_class
|
||||
class Generator(torch.nn.Module):
|
||||
def __init__(self,
|
||||
z_dim, # Input latent (Z) dimensionality.
|
||||
c_dim, # Conditioning label (C) dimensionality.
|
||||
w_dim, # Intermediate latent (W) dimensionality.
|
||||
img_resolution, # Output resolution.
|
||||
img_channels, # Number of output color channels.
|
||||
mapping_kwargs = {}, # Arguments for MappingNetwork.
|
||||
synthesis_kwargs = {}, # Arguments for SynthesisNetwork.
|
||||
):
|
||||
super().__init__()
|
||||
self.z_dim = z_dim
|
||||
self.c_dim = c_dim
|
||||
self.w_dim = w_dim
|
||||
self.img_resolution = img_resolution
|
||||
self.img_channels = img_channels
|
||||
self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
|
||||
self.num_ws = self.synthesis.num_ws
|
||||
self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
|
||||
|
||||
def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, encoded_styles=None, **synthesis_kwargs):
|
||||
if encoded_styles is None:
|
||||
ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
|
||||
else:
|
||||
ws = None
|
||||
img = self.synthesis(ws, encoded_styles=encoded_styles, **synthesis_kwargs)
|
||||
return img
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@persistence.persistent_class
|
||||
class DiscriminatorBlock(torch.nn.Module):
|
||||
def __init__(self,
|
||||
in_channels, # Number of input channels, 0 = first block.
|
||||
tmp_channels, # Number of intermediate channels.
|
||||
out_channels, # Number of output channels.
|
||||
resolution, # Resolution of this block.
|
||||
img_channels, # Number of input color channels.
|
||||
first_layer_idx, # Index of the first layer.
|
||||
architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
|
||||
activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
|
||||
resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
|
||||
conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
||||
use_fp16 = False, # Use FP16 for this block?
|
||||
fp16_channels_last = False, # Use channels-last memory format with FP16?
|
||||
freeze_layers = 0, # Freeze-D: Number of layers to freeze.
|
||||
):
|
||||
assert in_channels in [0, tmp_channels]
|
||||
assert architecture in ['orig', 'skip', 'resnet']
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.resolution = resolution
|
||||
self.img_channels = img_channels
|
||||
self.first_layer_idx = first_layer_idx
|
||||
self.architecture = architecture
|
||||
self.use_fp16 = use_fp16
|
||||
self.channels_last = (use_fp16 and fp16_channels_last)
|
||||
self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
|
||||
|
||||
self.num_layers = 0
|
||||
def trainable_gen():
|
||||
while True:
|
||||
layer_idx = self.first_layer_idx + self.num_layers
|
||||
trainable = (layer_idx >= freeze_layers)
|
||||
self.num_layers += 1
|
||||
yield trainable
|
||||
trainable_iter = trainable_gen()
|
||||
|
||||
if in_channels == 0 or architecture == 'skip':
|
||||
self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation,
|
||||
trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)
|
||||
|
||||
self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
|
||||
trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)
|
||||
|
||||
self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
|
||||
trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last)
|
||||
|
||||
if architecture == 'resnet':
|
||||
self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
|
||||
trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last)
|
||||
|
||||
def forward(self, x, img, force_fp32=False):
|
||||
dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
|
||||
memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
|
||||
|
||||
# Input.
|
||||
if x is not None:
|
||||
misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
|
||||
x = x.to(dtype=dtype, memory_format=memory_format)
|
||||
|
||||
# FromRGB.
|
||||
if self.in_channels == 0 or self.architecture == 'skip':
|
||||
misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
|
||||
img = img.to(dtype=dtype, memory_format=memory_format)
|
||||
y = self.fromrgb(img)
|
||||
x = x + y if x is not None else y
|
||||
img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None
|
||||
|
||||
# Main layers.
|
||||
if self.architecture == 'resnet':
|
||||
y = self.skip(x, gain=np.sqrt(0.5))
|
||||
x = self.conv0(x)
|
||||
x = self.conv1(x, gain=np.sqrt(0.5))
|
||||
x = y.add_(x)
|
||||
else:
|
||||
x = self.conv0(x)
|
||||
x = self.conv1(x)
|
||||
|
||||
assert x.dtype == dtype
|
||||
return x, img
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@persistence.persistent_class
|
||||
class MinibatchStdLayer(torch.nn.Module):
|
||||
def __init__(self, group_size, num_channels=1):
|
||||
super().__init__()
|
||||
self.group_size = group_size
|
||||
self.num_channels = num_channels
|
||||
|
||||
def forward(self, x):
|
||||
N, C, H, W = x.shape
|
||||
with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants
|
||||
G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N
|
||||
F = self.num_channels
|
||||
c = C // F
|
||||
|
||||
y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
|
||||
y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
|
||||
y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
|
||||
y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
|
||||
y = y.mean(dim=[2,3,4]) # [nF] Take average over channels and pixels.
|
||||
y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
|
||||
y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
|
||||
x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@persistence.persistent_class
|
||||
class DiscriminatorEpilogue(torch.nn.Module):
|
||||
def __init__(self,
|
||||
in_channels, # Number of input channels.
|
||||
cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label.
|
||||
resolution, # Resolution of this block.
|
||||
img_channels, # Number of input color channels.
|
||||
architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
|
||||
mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
|
||||
mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable.
|
||||
activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
|
||||
conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
||||
):
|
||||
assert architecture in ['orig', 'skip', 'resnet']
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.cmap_dim = cmap_dim
|
||||
self.resolution = resolution
|
||||
self.img_channels = img_channels
|
||||
self.architecture = architecture
|
||||
|
||||
if architecture == 'skip':
|
||||
self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation)
|
||||
self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
|
||||
self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp)
|
||||
self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation)
|
||||
self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim)
|
||||
|
||||
def forward(self, x, img, cmap, force_fp32=False):
|
||||
misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW]
|
||||
_ = force_fp32 # unused
|
||||
dtype = torch.float32
|
||||
memory_format = torch.contiguous_format
|
||||
|
||||
# FromRGB.
|
||||
x = x.to(dtype=dtype, memory_format=memory_format)
|
||||
if self.architecture == 'skip':
|
||||
misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
|
||||
img = img.to(dtype=dtype, memory_format=memory_format)
|
||||
x = x + self.fromrgb(img)
|
||||
|
||||
# Main layers.
|
||||
if self.mbstd is not None:
|
||||
x = self.mbstd(x)
|
||||
x = self.conv(x)
|
||||
x = self.fc(x.flatten(1))
|
||||
x = self.out(x)
|
||||
|
||||
# Conditioning.
|
||||
if self.cmap_dim > 0:
|
||||
misc.assert_shape(cmap, [None, self.cmap_dim])
|
||||
x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
|
||||
|
||||
assert x.dtype == dtype
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@persistence.persistent_class
|
||||
class Discriminator(torch.nn.Module):
|
||||
def __init__(self,
|
||||
c_dim, # Conditioning label (C) dimensionality.
|
||||
img_resolution, # Input resolution.
|
||||
img_channels, # Number of input color channels.
|
||||
architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
|
||||
channel_base = 32768, # Overall multiplier for the number of channels.
|
||||
channel_max = 512, # Maximum number of channels in any layer.
|
||||
num_fp16_res = 0, # Use FP16 for the N highest resolutions.
|
||||
conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
||||
cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
|
||||
block_kwargs = {}, # Arguments for DiscriminatorBlock.
|
||||
mapping_kwargs = {}, # Arguments for MappingNetwork.
|
||||
epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
|
||||
):
|
||||
super().__init__()
|
||||
self.c_dim = c_dim
|
||||
self.img_resolution = img_resolution
|
||||
self.img_resolution_log2 = int(np.log2(img_resolution))
|
||||
self.img_channels = img_channels
|
||||
self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
|
||||
channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
|
||||
fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
|
||||
|
||||
if cmap_dim is None:
|
||||
cmap_dim = channels_dict[4]
|
||||
if c_dim == 0:
|
||||
cmap_dim = 0
|
||||
|
||||
common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
|
||||
cur_layer_idx = 0
|
||||
for res in self.block_resolutions:
|
||||
in_channels = channels_dict[res] if res < img_resolution else 0
|
||||
tmp_channels = channels_dict[res]
|
||||
out_channels = channels_dict[res // 2]
|
||||
use_fp16 = (res >= fp16_resolution)
|
||||
block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
|
||||
first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
|
||||
setattr(self, f'b{res}', block)
|
||||
cur_layer_idx += block.num_layers
|
||||
if c_dim > 0:
|
||||
self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
|
||||
self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
|
||||
|
||||
def forward(self, img, c, **block_kwargs):
|
||||
x = None
|
||||
for res in self.block_resolutions:
|
||||
block = getattr(self, f'b{res}')
|
||||
x, img = block(x, img, **block_kwargs)
|
||||
|
||||
cmap = None
|
||||
if self.c_dim > 0:
|
||||
cmap = self.mapping(None, c)
|
||||
x = self.b4(x, img, cmap)
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
605
global_torch/visualizer.py
Normal file
605
global_torch/visualizer.py
Normal file
@ -0,0 +1,605 @@
|
||||
# python 3.7
|
||||
"""Utility functions for visualizing results on html page."""
|
||||
|
||||
import base64
|
||||
import os.path
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
__all__ = [
|
||||
'get_grid_shape', 'get_blank_image', 'load_image', 'save_image',
|
||||
'resize_image', 'add_text_to_image', 'fuse_images', 'HtmlPageVisualizer',
|
||||
'VideoReader', 'VideoWriter', 'adjust_pixel_range'
|
||||
]
|
||||
|
||||
|
||||
def adjust_pixel_range(images, min_val=-1.0, max_val=1.0, channel_order='NCHW'):
|
||||
"""Adjusts the pixel range of the input images.
|
||||
|
||||
This function assumes the input array (image batch) is with shape [batch_size,
|
||||
channel, height, width] if `channel_order = NCHW`, or with shape [batch_size,
|
||||
height, width] if `channel_order = NHWC`. The returned images are with shape
|
||||
[batch_size, height, width, channel] and pixel range [0, 255].
|
||||
|
||||
NOTE: The channel order of output images will remain the same as the input.
|
||||
|
||||
Args:
|
||||
images: Input images to adjust pixel range.
|
||||
min_val: Min value of the input images. (default: -1.0)
|
||||
max_val: Max value of the input images. (default: 1.0)
|
||||
channel_order: Channel order of the input array. (default: NCHW)
|
||||
|
||||
Returns:
|
||||
The postprocessed images with dtype `numpy.uint8` and range [0, 255].
|
||||
|
||||
Raises:
|
||||
ValueError: If the input `images` are not with type `numpy.ndarray` or the
|
||||
shape is invalid according to `channel_order`.
|
||||
"""
|
||||
if not isinstance(images, np.ndarray):
|
||||
raise ValueError(f'Images should be with type `numpy.ndarray`!')
|
||||
|
||||
channel_order = channel_order.upper()
|
||||
if channel_order not in ['NCHW', 'NHWC']:
|
||||
raise ValueError(f'Invalid channel order `{channel_order}`!')
|
||||
|
||||
if images.ndim != 4:
|
||||
raise ValueError(f'Input images are expected to be with shape `NCHW` or '
|
||||
f'`NHWC`, but `{images.shape}` is received!')
|
||||
if channel_order == 'NCHW' and images.shape[1] not in [1, 3]:
|
||||
raise ValueError(f'Input images should have 1 or 3 channels under `NCHW` '
|
||||
f'channel order!')
|
||||
if channel_order == 'NHWC' and images.shape[3] not in [1, 3]:
|
||||
raise ValueError(f'Input images should have 1 or 3 channels under `NHWC` '
|
||||
f'channel order!')
|
||||
|
||||
images = images.astype(np.float32)
|
||||
images = (images - min_val) * 255 / (max_val - min_val)
|
||||
images = np.clip(images + 0.5, 0, 255).astype(np.uint8)
|
||||
if channel_order == 'NCHW':
|
||||
images = images.transpose(0, 2, 3, 1)
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def get_grid_shape(size, row=0, col=0, is_portrait=False):
|
||||
"""Gets the shape of a grid based on the size.
|
||||
|
||||
This function makes greatest effort on making the output grid square if
|
||||
neither `row` nor `col` is set. If `is_portrait` is set as `False`, the height
|
||||
will always be equal to or smaller than the width. For example, if input
|
||||
`size = 16`, output shape will be `(4, 4)`; if input `size = 15`, output shape
|
||||
will be (3, 5). Otherwise, the height will always be equal to or larger than
|
||||
the width.
|
||||
|
||||
Args:
|
||||
size: Size (height * width) of the target grid.
|
||||
is_portrait: Whether to return a portrait size of a landscape size.
|
||||
(default: False)
|
||||
|
||||
Returns:
|
||||
A two-element tuple, representing height and width respectively.
|
||||
"""
|
||||
assert isinstance(size, int)
|
||||
assert isinstance(row, int)
|
||||
assert isinstance(col, int)
|
||||
if size == 0:
|
||||
return (0, 0)
|
||||
|
||||
if row > 0 and col > 0 and row * col != size:
|
||||
row = 0
|
||||
col = 0
|
||||
|
||||
if row > 0 and size % row == 0:
|
||||
return (row, size // row)
|
||||
if col > 0 and size % col == 0:
|
||||
return (size // col, col)
|
||||
|
||||
row = int(np.sqrt(size))
|
||||
while row > 0:
|
||||
if size % row == 0:
|
||||
col = size // row
|
||||
break
|
||||
row = row - 1
|
||||
|
||||
return (col, row) if is_portrait else (row, col)
|
||||
|
||||
|
||||
def get_blank_image(height, width, channels=3, is_black=True):
|
||||
"""Gets a blank image, either white of black.
|
||||
|
||||
NOTE: This function will always return an image with `RGB` channel order for
|
||||
color image and pixel range [0, 255].
|
||||
|
||||
Args:
|
||||
height: Height of the returned image.
|
||||
width: Width of the returned image.
|
||||
channels: Number of channels. (default: 3)
|
||||
is_black: Whether to return a black image or white image. (default: True)
|
||||
"""
|
||||
shape = (height, width, channels)
|
||||
if is_black:
|
||||
return np.zeros(shape, dtype=np.uint8)
|
||||
return np.ones(shape, dtype=np.uint8) * 255
|
||||
|
||||
|
||||
def load_image(path):
|
||||
"""Loads an image from disk.
|
||||
|
||||
NOTE: This function will always return an image with `RGB` channel order for
|
||||
color image and pixel range [0, 255].
|
||||
|
||||
Args:
|
||||
path: Path to load the image from.
|
||||
|
||||
Returns:
|
||||
An image with dtype `np.ndarray` or `None` if input `path` does not exist.
|
||||
"""
|
||||
if not os.path.isfile(path):
|
||||
return None
|
||||
|
||||
image = cv2.imread(path)
|
||||
return image[:, :, ::-1]
|
||||
|
||||
|
||||
def save_image(path, image):
|
||||
"""Saves an image to disk.
|
||||
|
||||
NOTE: The input image (if colorful) is assumed to be with `RGB` channel order
|
||||
and pixel range [0, 255].
|
||||
|
||||
Args:
|
||||
path: Path to save the image to.
|
||||
image: Image to save.
|
||||
"""
|
||||
if image is None:
|
||||
return
|
||||
|
||||
assert len(image.shape) == 3 and image.shape[2] in [1, 3]
|
||||
cv2.imwrite(path, image[:, :, ::-1])
|
||||
|
||||
|
||||
def resize_image(image, *args, **kwargs):
|
||||
"""Resizes image.
|
||||
|
||||
This is a wrap of `cv2.resize()`.
|
||||
|
||||
NOTE: THe channel order of the input image will not be changed.
|
||||
|
||||
Args:
|
||||
image: Image to resize.
|
||||
"""
|
||||
if image is None:
|
||||
return None
|
||||
|
||||
assert image.ndim == 3 and image.shape[2] in [1, 3]
|
||||
image = cv2.resize(image, *args, **kwargs)
|
||||
if image.ndim == 2:
|
||||
return image[:, :, np.newaxis]
|
||||
return image
|
||||
|
||||
|
||||
def add_text_to_image(image,
|
||||
text='',
|
||||
position=None,
|
||||
font=cv2.FONT_HERSHEY_TRIPLEX,
|
||||
font_size=1.0,
|
||||
line_type=cv2.LINE_8,
|
||||
line_width=1,
|
||||
color=(255, 255, 255)):
|
||||
"""Overlays text on given image.
|
||||
|
||||
NOTE: The input image is assumed to be with `RGB` channel order.
|
||||
|
||||
Args:
|
||||
image: The image to overlay text on.
|
||||
text: Text content to overlay on the image. (default: '')
|
||||
position: Target position (bottom-left corner) to add text. If not set,
|
||||
center of the image will be used by default. (default: None)
|
||||
font: Font of the text added. (default: cv2.FONT_HERSHEY_TRIPLEX)
|
||||
font_size: Font size of the text added. (default: 1.0)
|
||||
line_type: Line type used to depict the text. (default: cv2.LINE_8)
|
||||
line_width: Line width used to depict the text. (default: 1)
|
||||
color: Color of the text added in `RGB` channel order. (default:
|
||||
(255, 255, 255))
|
||||
|
||||
Returns:
|
||||
An image with target text overlayed on.
|
||||
"""
|
||||
if image is None or not text:
|
||||
return image
|
||||
|
||||
cv2.putText(img=image,
|
||||
text=text,
|
||||
org=position,
|
||||
fontFace=font,
|
||||
fontScale=font_size,
|
||||
color=color,
|
||||
thickness=line_width,
|
||||
lineType=line_type,
|
||||
bottomLeftOrigin=False)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def fuse_images(images,
|
||||
image_size=None,
|
||||
row=0,
|
||||
col=0,
|
||||
is_row_major=True,
|
||||
is_portrait=False,
|
||||
row_spacing=0,
|
||||
col_spacing=0,
|
||||
border_left=0,
|
||||
border_right=0,
|
||||
border_top=0,
|
||||
border_bottom=0,
|
||||
black_background=True):
|
||||
"""Fuses a collection of images into an entire image.
|
||||
|
||||
Args:
|
||||
images: A collection of images to fuse. Should be with shape [num, height,
|
||||
width, channels].
|
||||
image_size: Int or two-element tuple. This field is used to resize the image
|
||||
before fusing. `None` disables resizing. (default: None)
|
||||
row: Number of rows used for image fusion. If not set, this field will be
|
||||
automatically assigned based on `col` and total number of images.
|
||||
(default: None)
|
||||
col: Number of columns used for image fusion. If not set, this field will be
|
||||
automatically assigned based on `row` and total number of images.
|
||||
(default: None)
|
||||
is_row_major: Whether the input images should be arranged row-major or
|
||||
column-major. (default: True)
|
||||
is_portrait: Only active when both `row` and `col` should be assigned
|
||||
automatically. (default: False)
|
||||
row_spacing: Space between rows. (default: 0)
|
||||
col_spacing: Space between columns. (default: 0)
|
||||
border_left: Width of left border. (default: 0)
|
||||
border_right: Width of right border. (default: 0)
|
||||
border_top: Width of top border. (default: 0)
|
||||
border_bottom: Width of bottom border. (default: 0)
|
||||
|
||||
Returns:
|
||||
The fused image.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input `images` is not with shape [num, height, width,
|
||||
width].
|
||||
"""
|
||||
if images is None:
|
||||
return images
|
||||
|
||||
if not images.ndim == 4:
|
||||
raise ValueError(f'Input `images` should be with shape [num, height, '
|
||||
f'width, channels], but {images.shape} is received!')
|
||||
|
||||
num, image_height, image_width, channels = images.shape
|
||||
if image_size is not None:
|
||||
if isinstance(image_size, int):
|
||||
image_size = (image_size, image_size)
|
||||
assert isinstance(image_size, (list, tuple)) and len(image_size) == 2
|
||||
width, height = image_size
|
||||
else:
|
||||
height, width = image_height, image_width
|
||||
row, col = get_grid_shape(num, row=row, col=col, is_portrait=is_portrait)
|
||||
fused_height = (
|
||||
height * row + row_spacing * (row - 1) + border_top + border_bottom)
|
||||
fused_width = (
|
||||
width * col + col_spacing * (col - 1) + border_left + border_right)
|
||||
fused_image = get_blank_image(
|
||||
fused_height, fused_width, channels=channels, is_black=black_background)
|
||||
images = images.reshape(row, col, image_height, image_width, channels)
|
||||
if not is_row_major:
|
||||
images = images.transpose(1, 0, 2, 3, 4)
|
||||
|
||||
for i in range(row):
|
||||
y = border_top + i * (height + row_spacing)
|
||||
for j in range(col):
|
||||
x = border_left + j * (width + col_spacing)
|
||||
if image_size is not None:
|
||||
image = cv2.resize(images[i, j], image_size)
|
||||
else:
|
||||
image = images[i, j]
|
||||
fused_image[y:y + height, x:x + width] = image
|
||||
|
||||
return fused_image
|
||||
|
||||
|
||||
def get_sortable_html_header(column_name_list, sort_by_ascending=False):
|
||||
"""Gets header for sortable html page.
|
||||
|
||||
Basically, the html page contains a sortable table, where user can sort the
|
||||
rows by a particular column by clicking the column head.
|
||||
|
||||
Example:
|
||||
|
||||
column_name_list = [name_1, name_2, name_3]
|
||||
header = get_sortable_html_header(column_name_list)
|
||||
footer = get_sortable_html_footer()
|
||||
sortable_table = ...
|
||||
html_page = header + sortable_table + footer
|
||||
|
||||
Args:
|
||||
column_name_list: List of column header names.
|
||||
sort_by_ascending: Default sorting order. If set as `True`, the html page
|
||||
will be sorted by ascending order when the header is clicked for the first
|
||||
time.
|
||||
|
||||
Returns:
|
||||
A string, which represents for the header for a sortable html page.
|
||||
"""
|
||||
header = '\n'.join([
|
||||
'<script type="text/javascript">',
|
||||
'var column_idx;',
|
||||
'var sort_by_ascending = ' + str(sort_by_ascending).lower() + ';',
|
||||
'',
|
||||
'function sorting(tbody, column_idx){',
|
||||
' this.column_idx = column_idx;',
|
||||
' Array.from(tbody.rows)',
|
||||
' .sort(compareCells)',
|
||||
' .forEach(function(row) { tbody.appendChild(row); })',
|
||||
' sort_by_ascending = !sort_by_ascending;',
|
||||
'}',
|
||||
'',
|
||||
'function compareCells(row_a, row_b) {',
|
||||
' var val_a = row_a.cells[column_idx].innerText;',
|
||||
' var val_b = row_b.cells[column_idx].innerText;',
|
||||
' var flag = sort_by_ascending ? 1 : -1;',
|
||||
' return flag * (val_a > val_b ? 1 : -1);',
|
||||
'}',
|
||||
'</script>',
|
||||
'',
|
||||
'<html>',
|
||||
'',
|
||||
'<head>',
|
||||
'<style>',
|
||||
' table {',
|
||||
' border-spacing: 0;',
|
||||
' border: 1px solid black;',
|
||||
' }',
|
||||
' th {',
|
||||
' cursor: pointer;',
|
||||
' }',
|
||||
' th, td {',
|
||||
' text-align: left;',
|
||||
' vertical-align: middle;',
|
||||
' border-collapse: collapse;',
|
||||
' border: 0.5px solid black;',
|
||||
' padding: 8px;',
|
||||
' }',
|
||||
' tr:nth-child(even) {',
|
||||
' background-color: #d2d2d2;',
|
||||
' }',
|
||||
'</style>',
|
||||
'</head>',
|
||||
'',
|
||||
'<body>',
|
||||
'',
|
||||
'<table>',
|
||||
'<thead>',
|
||||
'<tr>',
|
||||
''])
|
||||
for idx, column_name in enumerate(column_name_list):
|
||||
header += f' <th onclick="sorting(tbody, {idx})">{column_name}</th>\n'
|
||||
header += '</tr>\n'
|
||||
header += '</thead>\n'
|
||||
header += '<tbody id="tbody">\n'
|
||||
|
||||
return header
|
||||
|
||||
|
||||
def get_sortable_html_footer():
|
||||
"""Gets footer for sortable html page.
|
||||
|
||||
Check function `get_sortable_html_header()` for more details.
|
||||
"""
|
||||
return '</tbody>\n</table>\n\n</body>\n</html>\n'
|
||||
|
||||
|
||||
def encode_image_to_html_str(image, image_size=None):
|
||||
"""Encodes an image to html language.
|
||||
|
||||
Args:
|
||||
image: The input image to encode. Should be with `RGB` channel order.
|
||||
image_size: Int or two-element tuple. This field is used to resize the image
|
||||
before encoding. `None` disables resizing. (default: None)
|
||||
|
||||
Returns:
|
||||
A string which represents the encoded image.
|
||||
"""
|
||||
if image is None:
|
||||
return ''
|
||||
|
||||
assert len(image.shape) == 3 and image.shape[2] in [1, 3]
|
||||
|
||||
# Change channel order to `BGR`, which is opencv-friendly.
|
||||
image = image[:, :, ::-1]
|
||||
|
||||
# Resize the image if needed.
|
||||
if image_size is not None:
|
||||
if isinstance(image_size, int):
|
||||
image_size = (image_size, image_size)
|
||||
assert isinstance(image_size, (list, tuple)) and len(image_size) == 2
|
||||
image = cv2.resize(image, image_size)
|
||||
|
||||
# Encode the image to html-format string.
|
||||
encoded_image = cv2.imencode(".jpg", image)[1].tostring()
|
||||
encoded_image_base64 = base64.b64encode(encoded_image).decode('utf-8')
|
||||
html_str = f'<img src="data:image/jpeg;base64, {encoded_image_base64}"/>'
|
||||
|
||||
return html_str
|
||||
|
||||
|
||||
class HtmlPageVisualizer(object):
|
||||
"""Defines the html page visualizer.
|
||||
|
||||
This class can be used to visualize image results as html page. Basically, it
|
||||
is based on an html-format sorted table with helper functions
|
||||
`get_sortable_html_header()`, `get_sortable_html_footer()`, and
|
||||
`encode_image_to_html_str()`. To simplify the usage, specifying the following
|
||||
fields is enough to create a visualization page:
|
||||
|
||||
(1) num_rows: Number of rows of the table (header-row exclusive).
|
||||
(2) num_cols: Number of columns of the table.
|
||||
(3) header contents (optional): Title of each column.
|
||||
|
||||
NOTE: `grid_size` can be used to assign `num_rows` and `num_cols`
|
||||
automatically.
|
||||
|
||||
Example:
|
||||
|
||||
html = HtmlPageVisualizer(num_rows, num_cols)
|
||||
html.set_headers([...])
|
||||
for i in range(num_rows):
|
||||
for j in range(num_cols):
|
||||
html.set_cell(i, j, text=..., image=...)
|
||||
html.save('visualize.html')
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_rows=0,
|
||||
num_cols=0,
|
||||
grid_size=0,
|
||||
is_portrait=False,
|
||||
viz_size=None):
|
||||
if grid_size > 0:
|
||||
num_rows, num_cols = get_grid_shape(
|
||||
grid_size, row=num_rows, col=num_cols, is_portrait=is_portrait)
|
||||
assert num_rows > 0 and num_cols > 0
|
||||
|
||||
self.num_rows = num_rows
|
||||
self.num_cols = num_cols
|
||||
self.viz_size = viz_size
|
||||
self.headers = ['' for _ in range(self.num_cols)]
|
||||
self.cells = [[{
|
||||
'text': '',
|
||||
'image': '',
|
||||
} for _ in range(self.num_cols)] for _ in range(self.num_rows)]
|
||||
|
||||
def set_header(self, column_idx, content):
|
||||
"""Sets the content of a particular header by column index."""
|
||||
self.headers[column_idx] = content
|
||||
|
||||
def set_headers(self, contents):
|
||||
"""Sets the contents of all headers."""
|
||||
if isinstance(contents, str):
|
||||
contents = [contents]
|
||||
assert isinstance(contents, (list, tuple))
|
||||
assert len(contents) == self.num_cols
|
||||
for column_idx, content in enumerate(contents):
|
||||
self.set_header(column_idx, content)
|
||||
|
||||
def set_cell(self, row_idx, column_idx, text='', image=None):
|
||||
"""Sets the content of a particular cell.
|
||||
|
||||
Basically, a cell contains some text as well as an image. Both text and
|
||||
image can be empty.
|
||||
|
||||
Args:
|
||||
row_idx: Row index of the cell to edit.
|
||||
column_idx: Column index of the cell to edit.
|
||||
text: Text to add into the target cell.
|
||||
image: Image to show in the target cell. Should be with `RGB` channel
|
||||
order.
|
||||
"""
|
||||
self.cells[row_idx][column_idx]['text'] = text
|
||||
self.cells[row_idx][column_idx]['image'] = encode_image_to_html_str(
|
||||
image, self.viz_size)
|
||||
|
||||
def save(self, save_path):
|
||||
"""Saves the html page."""
|
||||
html = ''
|
||||
for i in range(self.num_rows):
|
||||
html += f'<tr>\n'
|
||||
for j in range(self.num_cols):
|
||||
text = self.cells[i][j]['text']
|
||||
image = self.cells[i][j]['image']
|
||||
if text:
|
||||
html += f' <td>{text}<br><br>{image}</td>\n'
|
||||
else:
|
||||
html += f' <td>{image}</td>\n'
|
||||
html += f'</tr>\n'
|
||||
|
||||
header = get_sortable_html_header(self.headers)
|
||||
footer = get_sortable_html_footer()
|
||||
|
||||
with open(save_path, 'w') as f:
|
||||
f.write(header + html + footer)
|
||||
|
||||
|
||||
class VideoReader(object):
|
||||
"""Defines the video reader.
|
||||
|
||||
This class can be used to read frames from a given video.
|
||||
"""
|
||||
|
||||
def __init__(self, path):
|
||||
"""Initializes the video reader by loading the video from disk."""
|
||||
if not os.path.isfile(path):
|
||||
raise ValueError(f'Video `{path}` does not exist!')
|
||||
|
||||
self.path = path
|
||||
self.video = cv2.VideoCapture(path)
|
||||
assert self.video.isOpened()
|
||||
self.position = 0
|
||||
|
||||
self.length = int(self.video.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
self.frame_height = int(self.video.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
self.frame_width = int(self.video.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
self.fps = self.video.get(cv2.CAP_PROP_FPS)
|
||||
|
||||
def __del__(self):
|
||||
"""Releases the opened video."""
|
||||
self.video.release()
|
||||
|
||||
def read(self, position=None):
|
||||
"""Reads a certain frame.
|
||||
|
||||
NOTE: The returned frame is assumed to be with `RGB` channel order.
|
||||
|
||||
Args:
|
||||
position: Optional. If set, the reader will read frames from the exact
|
||||
position. Otherwise, the reader will read next frames. (default: None)
|
||||
"""
|
||||
if position is not None and position < self.length:
|
||||
self.video.set(cv2.CAP_PROP_POS_FRAMES, position)
|
||||
self.position = position
|
||||
|
||||
success, frame = self.video.read()
|
||||
self.position = self.position + 1
|
||||
|
||||
return frame[:, :, ::-1] if success else None
|
||||
|
||||
|
||||
class VideoWriter(object):
|
||||
"""Defines the video writer.
|
||||
|
||||
This class can be used to create a video.
|
||||
|
||||
NOTE: `.avi` and `DIVX` is the most recommended codec format since it does not
|
||||
rely on other dependencies.
|
||||
"""
|
||||
|
||||
def __init__(self, path, frame_height, frame_width, fps=24, codec='DIVX'):
|
||||
"""Creates the video writer."""
|
||||
self.path = path
|
||||
self.frame_height = frame_height
|
||||
self.frame_width = frame_width
|
||||
self.fps = fps
|
||||
self.codec = codec
|
||||
|
||||
self.video = cv2.VideoWriter(filename=path,
|
||||
fourcc=cv2.VideoWriter_fourcc(*codec),
|
||||
fps=fps,
|
||||
frameSize=(frame_width, frame_height))
|
||||
|
||||
def __del__(self):
|
||||
"""Releases the opened video."""
|
||||
self.video.release()
|
||||
|
||||
def write(self, frame):
|
||||
"""Writes a target frame.
|
||||
|
||||
NOTE: The input frame is assumed to be with `RGB` channel order.
|
||||
"""
|
||||
self.video.write(frame[:, :, ::-1])
|
||||
BIN
latents_test/example_celebs.pt
Normal file
BIN
latents_test/example_celebs.pt
Normal file
Binary file not shown.
21
licenses/LICENSE-CLIP
Normal file
21
licenses/LICENSE-CLIP
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2021 OpenAI
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
21
licenses/LICENSE-stylegan2-pytorch
Normal file
21
licenses/LICENSE-stylegan2-pytorch
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2019 Kim Seonghyeon
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
0
models/__init__.py
Normal file
0
models/__init__.py
Normal file
0
models/facial_recognition/__init__.py
Normal file
0
models/facial_recognition/__init__.py
Normal file
119
models/facial_recognition/helpers.py
Normal file
119
models/facial_recognition/helpers.py
Normal file
@ -0,0 +1,119 @@
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
|
||||
|
||||
"""
|
||||
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
|
||||
"""
|
||||
|
||||
|
||||
class Flatten(Module):
|
||||
def forward(self, input):
|
||||
return input.view(input.size(0), -1)
|
||||
|
||||
|
||||
def l2_norm(input, axis=1):
|
||||
norm = torch.norm(input, 2, axis, True)
|
||||
output = torch.div(input, norm)
|
||||
return output
|
||||
|
||||
|
||||
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
|
||||
""" A named tuple describing a ResNet block. """
|
||||
|
||||
|
||||
def get_block(in_channel, depth, num_units, stride=2):
|
||||
return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
|
||||
|
||||
|
||||
def get_blocks(num_layers):
|
||||
if num_layers == 50:
|
||||
blocks = [
|
||||
get_block(in_channel=64, depth=64, num_units=3),
|
||||
get_block(in_channel=64, depth=128, num_units=4),
|
||||
get_block(in_channel=128, depth=256, num_units=14),
|
||||
get_block(in_channel=256, depth=512, num_units=3)
|
||||
]
|
||||
elif num_layers == 100:
|
||||
blocks = [
|
||||
get_block(in_channel=64, depth=64, num_units=3),
|
||||
get_block(in_channel=64, depth=128, num_units=13),
|
||||
get_block(in_channel=128, depth=256, num_units=30),
|
||||
get_block(in_channel=256, depth=512, num_units=3)
|
||||
]
|
||||
elif num_layers == 152:
|
||||
blocks = [
|
||||
get_block(in_channel=64, depth=64, num_units=3),
|
||||
get_block(in_channel=64, depth=128, num_units=8),
|
||||
get_block(in_channel=128, depth=256, num_units=36),
|
||||
get_block(in_channel=256, depth=512, num_units=3)
|
||||
]
|
||||
else:
|
||||
raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
|
||||
return blocks
|
||||
|
||||
|
||||
class SEModule(Module):
|
||||
def __init__(self, channels, reduction):
|
||||
super(SEModule, self).__init__()
|
||||
self.avg_pool = AdaptiveAvgPool2d(1)
|
||||
self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
|
||||
self.relu = ReLU(inplace=True)
|
||||
self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
|
||||
self.sigmoid = Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
module_input = x
|
||||
x = self.avg_pool(x)
|
||||
x = self.fc1(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc2(x)
|
||||
x = self.sigmoid(x)
|
||||
return module_input * x
|
||||
|
||||
|
||||
class bottleneck_IR(Module):
|
||||
def __init__(self, in_channel, depth, stride):
|
||||
super(bottleneck_IR, self).__init__()
|
||||
if in_channel == depth:
|
||||
self.shortcut_layer = MaxPool2d(1, stride)
|
||||
else:
|
||||
self.shortcut_layer = Sequential(
|
||||
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
|
||||
BatchNorm2d(depth)
|
||||
)
|
||||
self.res_layer = Sequential(
|
||||
BatchNorm2d(in_channel),
|
||||
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
|
||||
Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
shortcut = self.shortcut_layer(x)
|
||||
res = self.res_layer(x)
|
||||
return res + shortcut
|
||||
|
||||
|
||||
class bottleneck_IR_SE(Module):
|
||||
def __init__(self, in_channel, depth, stride):
|
||||
super(bottleneck_IR_SE, self).__init__()
|
||||
if in_channel == depth:
|
||||
self.shortcut_layer = MaxPool2d(1, stride)
|
||||
else:
|
||||
self.shortcut_layer = Sequential(
|
||||
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
|
||||
BatchNorm2d(depth)
|
||||
)
|
||||
self.res_layer = Sequential(
|
||||
BatchNorm2d(in_channel),
|
||||
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
|
||||
PReLU(depth),
|
||||
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
|
||||
BatchNorm2d(depth),
|
||||
SEModule(depth, 16)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
shortcut = self.shortcut_layer(x)
|
||||
res = self.res_layer(x)
|
||||
return res + shortcut
|
||||
86
models/facial_recognition/model_irse.py
Normal file
86
models/facial_recognition/model_irse.py
Normal file
@ -0,0 +1,86 @@
|
||||
import sys
|
||||
sys.path.append('/home/ly/StyleCLIP-main/models/facial_recognition')
|
||||
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
|
||||
from helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
|
||||
|
||||
"""
|
||||
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
|
||||
"""
|
||||
|
||||
|
||||
class Backbone(Module):
|
||||
def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
|
||||
super(Backbone, self).__init__()
|
||||
assert input_size in [112, 224], "input_size should be 112 or 224"
|
||||
assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
|
||||
assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
|
||||
blocks = get_blocks(num_layers)
|
||||
if mode == 'ir':
|
||||
unit_module = bottleneck_IR
|
||||
elif mode == 'ir_se':
|
||||
unit_module = bottleneck_IR_SE
|
||||
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
||||
BatchNorm2d(64),
|
||||
PReLU(64))
|
||||
if input_size == 112:
|
||||
self.output_layer = Sequential(BatchNorm2d(512),
|
||||
Dropout(drop_ratio),
|
||||
Flatten(),
|
||||
Linear(512 * 7 * 7, 512),
|
||||
BatchNorm1d(512, affine=affine))
|
||||
else:
|
||||
self.output_layer = Sequential(BatchNorm2d(512),
|
||||
Dropout(drop_ratio),
|
||||
Flatten(),
|
||||
Linear(512 * 14 * 14, 512),
|
||||
BatchNorm1d(512, affine=affine))
|
||||
|
||||
modules = []
|
||||
for block in blocks:
|
||||
for bottleneck in block:
|
||||
modules.append(unit_module(bottleneck.in_channel,
|
||||
bottleneck.depth,
|
||||
bottleneck.stride))
|
||||
self.body = Sequential(*modules)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.input_layer(x)
|
||||
x = self.body(x)
|
||||
x = self.output_layer(x)
|
||||
return l2_norm(x)
|
||||
|
||||
|
||||
def IR_50(input_size):
|
||||
"""Constructs a ir-50 model."""
|
||||
model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
|
||||
def IR_101(input_size):
|
||||
"""Constructs a ir-101 model."""
|
||||
model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
|
||||
def IR_152(input_size):
|
||||
"""Constructs a ir-152 model."""
|
||||
model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
|
||||
def IR_SE_50(input_size):
|
||||
"""Constructs a ir_se-50 model."""
|
||||
model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
|
||||
def IR_SE_101(input_size):
|
||||
"""Constructs a ir_se-101 model."""
|
||||
model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
|
||||
def IR_SE_152(input_size):
|
||||
"""Constructs a ir_se-152 model."""
|
||||
model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
0
models/stylegan2/__init__.py
Normal file
0
models/stylegan2/__init__.py
Normal file
715
models/stylegan2/model.py
Normal file
715
models/stylegan2/model.py
Normal file
@ -0,0 +1,715 @@
|
||||
import math
|
||||
import random
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
|
||||
|
||||
|
||||
class PixelNorm(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
#normalizes了特征向量的每个元素到单位长度附近,阻止了信号幅度signal magnitudes导致的在训练过程中逐步失控的风险。
|
||||
def forward(self, input):
|
||||
return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
|
||||
|
||||
|
||||
def make_kernel(k):
|
||||
k = torch.tensor(k, dtype=torch.float32)
|
||||
|
||||
if k.ndim == 1:
|
||||
k = k[None, :] * k[:, None]
|
||||
|
||||
k /= k.sum()
|
||||
|
||||
return k
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, kernel, factor=2):
|
||||
super().__init__()
|
||||
|
||||
self.factor = factor
|
||||
kernel = make_kernel(kernel) * (factor ** 2)
|
||||
self.register_buffer('kernel', kernel)
|
||||
|
||||
p = kernel.shape[0] - factor
|
||||
|
||||
pad0 = (p + 1) // 2 + factor - 1
|
||||
pad1 = p // 2
|
||||
|
||||
self.pad = (pad0, pad1)
|
||||
|
||||
def forward(self, input):
|
||||
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, kernel, factor=2):
|
||||
super().__init__()
|
||||
|
||||
self.factor = factor
|
||||
kernel = make_kernel(kernel)
|
||||
self.register_buffer('kernel', kernel)
|
||||
|
||||
p = kernel.shape[0] - factor
|
||||
|
||||
pad0 = (p + 1) // 2
|
||||
pad1 = p // 2
|
||||
|
||||
self.pad = (pad0, pad1)
|
||||
|
||||
def forward(self, input):
|
||||
out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Blur(nn.Module):
|
||||
def __init__(self, kernel, pad, upsample_factor=1):
|
||||
super().__init__()
|
||||
|
||||
kernel = make_kernel(kernel)
|
||||
|
||||
if upsample_factor > 1:
|
||||
kernel = kernel * (upsample_factor ** 2)
|
||||
|
||||
self.register_buffer('kernel', kernel)
|
||||
|
||||
self.pad = pad
|
||||
|
||||
def forward(self, input):
|
||||
out = upfirdn2d(input, self.kernel, pad=self.pad)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class EqualConv2d(nn.Module):
|
||||
def __init__(
|
||||
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.randn(out_channel, in_channel, kernel_size, kernel_size)
|
||||
)
|
||||
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
|
||||
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_channel))
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, input):
|
||||
out = F.conv2d(
|
||||
input,
|
||||
self.weight * self.scale,
|
||||
bias=self.bias,
|
||||
stride=self.stride,
|
||||
padding=self.padding,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
|
||||
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
|
||||
)
|
||||
|
||||
#定义了一个线性激活层
|
||||
class EqualLinear(nn.Module):
|
||||
def __init__(
|
||||
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.activation = activation
|
||||
|
||||
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
|
||||
self.lr_mul = lr_mul
|
||||
|
||||
def forward(self, input):
|
||||
|
||||
if self.activation:
|
||||
out = F.linear(input, self.weight * self.scale)
|
||||
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
||||
|
||||
else:
|
||||
out = F.linear(
|
||||
input, self.weight * self.scale, bias=self.bias * self.lr_mul
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
|
||||
)
|
||||
|
||||
|
||||
class ScaledLeakyReLU(nn.Module):
|
||||
def __init__(self, negative_slope=0.2):
|
||||
super().__init__()
|
||||
|
||||
self.negative_slope = negative_slope
|
||||
|
||||
def forward(self, input):
|
||||
out = F.leaky_relu(input, negative_slope=self.negative_slope)
|
||||
|
||||
return out * math.sqrt(2)
|
||||
|
||||
|
||||
class ModulatedConv2d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
style_dim,
|
||||
demodulate=True,
|
||||
upsample=False,
|
||||
#给卷积核乘以放缩参数
|
||||
downsample=False,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.eps = 1e-8
|
||||
self.kernel_size = kernel_size
|
||||
self.in_channel = in_channel
|
||||
self.out_channel = out_channel
|
||||
self.upsample = upsample
|
||||
self.downsample = downsample
|
||||
|
||||
if upsample:
|
||||
factor = 2
|
||||
p = (len(blur_kernel) - factor) - (kernel_size - 1)
|
||||
pad0 = (p + 1) // 2 + factor - 1
|
||||
pad1 = p // 2 + 1
|
||||
|
||||
self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
|
||||
|
||||
if downsample:
|
||||
factor = 2
|
||||
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
||||
pad0 = (p + 1) // 2
|
||||
pad1 = p // 2
|
||||
|
||||
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
|
||||
|
||||
fan_in = in_channel * kernel_size ** 2
|
||||
self.scale = 1 / math.sqrt(fan_in)
|
||||
self.padding = kernel_size // 2
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
|
||||
)
|
||||
|
||||
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
|
||||
|
||||
self.demodulate = demodulate
|
||||
|
||||
def __repr__(self):
|
||||
#返回了模型的各个参数的字符串
|
||||
return (
|
||||
f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
|
||||
f'upsample={self.upsample}, downsample={self.downsample})'
|
||||
)
|
||||
|
||||
def forward(self, input, style, input_is_stylespace=False):
|
||||
batch, in_channel, height, width = input.shape
|
||||
|
||||
if not input_is_stylespace:
|
||||
style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
|
||||
weight = self.scale * self.weight * style
|
||||
|
||||
#对权重进行解调
|
||||
if self.demodulate:
|
||||
#类似标准差计算,平方求和再反平方,目的是计算每个权重向量的解调因子
|
||||
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
|
||||
#demod是一个解调因子矩阵,通过demod.view()将其形状调整为与权重矩阵相同的形状,以便进行逐元素的相乘操作。
|
||||
weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
|
||||
|
||||
weight = weight.view(
|
||||
batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
||||
)
|
||||
|
||||
if self.upsample:
|
||||
input = input.view(1, batch * in_channel, height, width)
|
||||
weight = weight.view(
|
||||
batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
||||
)
|
||||
weight = weight.transpose(1, 2).reshape(
|
||||
batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
|
||||
)
|
||||
out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
|
||||
_, _, height, width = out.shape
|
||||
out = out.view(batch, self.out_channel, height, width)
|
||||
out = self.blur(out)
|
||||
|
||||
elif self.downsample:
|
||||
input = self.blur(input)
|
||||
_, _, height, width = input.shape
|
||||
input = input.view(1, batch * in_channel, height, width)
|
||||
out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
|
||||
_, _, height, width = out.shape
|
||||
out = out.view(batch, self.out_channel, height, width)
|
||||
|
||||
else:
|
||||
input = input.view(1, batch * in_channel, height, width)
|
||||
out = F.conv2d(input, weight, padding=self.padding, groups=batch)
|
||||
_, _, height, width = out.shape
|
||||
out = out.view(batch, self.out_channel, height, width)
|
||||
|
||||
return out, style
|
||||
|
||||
# 用噪声 ( noise ) 来影响头发丝、皱纹、肤色等细节部分。
|
||||
class NoiseInjection(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(torch.zeros(1))
|
||||
|
||||
def forward(self, image, noise=None):
|
||||
if noise is None:
|
||||
batch, _, height, width = image.shape
|
||||
noise = image.new_empty(batch, 1, height, width).normal_()
|
||||
|
||||
return image + self.weight * noise
|
||||
|
||||
|
||||
class ConstantInput(nn.Module):
|
||||
def __init__(self, channel, size=4):
|
||||
super().__init__()
|
||||
|
||||
self.input = nn.Parameter(torch.randn(1, channel, size, size))
|
||||
|
||||
def forward(self, input):
|
||||
batch = input.shape[0]
|
||||
out = self.input.repeat(batch, 1, 1, 1)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class StyledConv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
style_dim,
|
||||
upsample=False,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
demodulate=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv = ModulatedConv2d(
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
style_dim,
|
||||
upsample=upsample,
|
||||
blur_kernel=blur_kernel,
|
||||
demodulate=demodulate,
|
||||
)
|
||||
|
||||
self.noise = NoiseInjection()
|
||||
# self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
|
||||
# self.activate = ScaledLeakyReLU(0.2)
|
||||
self.activate = FusedLeakyReLU(out_channel)
|
||||
|
||||
def forward(self, input, style, noise=None, input_is_stylespace=False):
|
||||
out, style = self.conv(input, style, input_is_stylespace=input_is_stylespace)
|
||||
out = self.noise(out, noise=noise)
|
||||
# out = out + self.bias
|
||||
out = self.activate(out)
|
||||
|
||||
return out, style
|
||||
|
||||
|
||||
class ToRGB(nn.Module):
|
||||
def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
|
||||
super().__init__()
|
||||
|
||||
if upsample:
|
||||
self.upsample = Upsample(blur_kernel)
|
||||
|
||||
#ToRGB层不进行demodulate处理
|
||||
self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
|
||||
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
||||
|
||||
def forward(self, input, style, skip=None, input_is_stylespace=False):
|
||||
out, style = self.conv(input, style, input_is_stylespace=input_is_stylespace)
|
||||
out = out + self.bias
|
||||
|
||||
if skip is not None:
|
||||
skip = self.upsample(skip)
|
||||
|
||||
out = out + skip
|
||||
|
||||
return out, style
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
style_dim,
|
||||
n_mlp,
|
||||
channel_multiplier=2,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
lr_mlp=0.01,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.size = size
|
||||
|
||||
self.style_dim = style_dim
|
||||
|
||||
layers = [PixelNorm()]
|
||||
|
||||
for i in range(n_mlp):
|
||||
layers.append(
|
||||
EqualLinear(
|
||||
style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
|
||||
)
|
||||
)
|
||||
|
||||
self.style = nn.Sequential(*layers)
|
||||
|
||||
self.channels = {
|
||||
4: 512,
|
||||
8: 512,
|
||||
16: 512,
|
||||
32: 512,
|
||||
64: 256 * channel_multiplier,
|
||||
128: 128 * channel_multiplier,
|
||||
256: 64 * channel_multiplier,
|
||||
512: 32 * channel_multiplier,
|
||||
1024: 16 * channel_multiplier,
|
||||
}
|
||||
|
||||
self.input = ConstantInput(self.channels[4])
|
||||
self.conv1 = StyledConv(
|
||||
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
|
||||
)
|
||||
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
|
||||
|
||||
self.log_size = int(math.log(size, 2)) #log(1024,2) = 10
|
||||
self.num_layers = (self.log_size - 2) * 2 + 1
|
||||
|
||||
self.convs = nn.ModuleList()
|
||||
self.upsamples = nn.ModuleList()
|
||||
self.to_rgbs = nn.ModuleList()
|
||||
self.noises = nn.Module()
|
||||
|
||||
in_channel = self.channels[4]
|
||||
|
||||
for layer_idx in range(self.num_layers):
|
||||
res = (layer_idx + 5) // 2
|
||||
shape = [1, 1, 2 ** res, 2 ** res]
|
||||
self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
|
||||
|
||||
for i in range(3, self.log_size + 1):
|
||||
out_channel = self.channels[2 ** i]
|
||||
|
||||
self.convs.append(
|
||||
StyledConv(
|
||||
in_channel,
|
||||
out_channel,
|
||||
3,
|
||||
style_dim,
|
||||
upsample=True,
|
||||
blur_kernel=blur_kernel,
|
||||
)
|
||||
)
|
||||
|
||||
self.convs.append(
|
||||
StyledConv(
|
||||
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
|
||||
)
|
||||
)
|
||||
|
||||
self.to_rgbs.append(ToRGB(out_channel, style_dim))
|
||||
|
||||
in_channel = out_channel
|
||||
# w+ repeat的倍数,例如1024计算为18,实际上就是上采样层1+8*2+1,因为第一层只需要一个style最后又多了一层to_rgb用了style,其中8个block每个上采样层之前均要加入两次style
|
||||
self.n_latent = self.log_size * 2 - 2
|
||||
|
||||
|
||||
def make_noise(self):
|
||||
device = self.input.input.device
|
||||
|
||||
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
|
||||
|
||||
for i in range(3, self.log_size + 1):
|
||||
for _ in range(2):
|
||||
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
|
||||
|
||||
return noises
|
||||
|
||||
def mean_latent(self, n_latent):
|
||||
latent_in = torch.randn(
|
||||
n_latent, self.style_dim, device=self.input.input.device
|
||||
)
|
||||
latent = self.style(latent_in).mean(0, keepdim=True)
|
||||
|
||||
return latent
|
||||
|
||||
def get_latent(self, input):
|
||||
return self.style(input)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
styles,
|
||||
return_latents=False,
|
||||
inject_index=None,
|
||||
truncation=1,
|
||||
truncation_latent=None,
|
||||
input_is_latent=False,
|
||||
input_is_stylespace=False,
|
||||
noise=None,
|
||||
randomize_noise=True,
|
||||
):
|
||||
if not input_is_latent and not input_is_stylespace:
|
||||
styles = [self.style(s) for s in styles]
|
||||
|
||||
if noise is None:
|
||||
if randomize_noise:
|
||||
noise = [None] * self.num_layers
|
||||
else:
|
||||
noise = [
|
||||
getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
|
||||
]
|
||||
|
||||
if truncation < 1 and not input_is_stylespace:
|
||||
style_t = []
|
||||
|
||||
for style in styles:
|
||||
style_t.append(
|
||||
truncation_latent + truncation * (style - truncation_latent)
|
||||
)
|
||||
|
||||
styles = style_t
|
||||
|
||||
if input_is_stylespace:
|
||||
latent = styles[0]
|
||||
elif len(styles) < 2:
|
||||
inject_index = self.n_latent
|
||||
|
||||
if styles[0].ndim < 3:
|
||||
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
||||
|
||||
else:
|
||||
latent = styles[0]
|
||||
|
||||
else:
|
||||
if inject_index is None:
|
||||
inject_index = random.randint(1, self.n_latent - 1)
|
||||
|
||||
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
||||
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
|
||||
|
||||
latent = torch.cat([latent, latent2], 1)
|
||||
|
||||
|
||||
style_vector = []
|
||||
|
||||
if not input_is_stylespace:
|
||||
out = self.input(latent)
|
||||
# print('laten:',latent.shape) # torch.Size([1, 18, 512])
|
||||
out, out_style = self.conv1(out, latent[:, 0], noise=noise[0])
|
||||
style_vector.append(out_style)
|
||||
|
||||
skip, out_style = self.to_rgb1(out, latent[:, 1])
|
||||
style_vector.append(out_style)
|
||||
|
||||
i = 1
|
||||
else:
|
||||
out = self.input(latent[0])
|
||||
out, out_style = self.conv1(out, latent[0], noise=noise[0], input_is_stylespace=input_is_stylespace)
|
||||
style_vector.append(out_style)
|
||||
|
||||
skip, out_style = self.to_rgb1(out, latent[1], input_is_stylespace=input_is_stylespace)
|
||||
style_vector.append(out_style)
|
||||
|
||||
i = 2
|
||||
|
||||
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
||||
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
|
||||
):
|
||||
if not input_is_stylespace:
|
||||
out, out_style1 = conv1(out, latent[:, i], noise=noise1)
|
||||
out, out_style2 = conv2(out, latent[:, i + 1], noise=noise2)
|
||||
skip, rgb_style = to_rgb(out, latent[:, i + 2], skip)
|
||||
|
||||
style_vector.extend([out_style1, out_style2, rgb_style])
|
||||
|
||||
i += 2
|
||||
else:
|
||||
out, out_style1 = conv1(out, latent[i], noise=noise1, input_is_stylespace=input_is_stylespace)
|
||||
out, out_style2 = conv2(out, latent[i + 1], noise=noise2, input_is_stylespace=input_is_stylespace)
|
||||
skip, rgb_style = to_rgb(out, latent[i + 2], skip, input_is_stylespace=input_is_stylespace)
|
||||
|
||||
style_vector.extend([out_style1, out_style2, rgb_style])
|
||||
|
||||
i += 3
|
||||
|
||||
image = skip
|
||||
|
||||
if return_latents:
|
||||
return image, latent, style_vector
|
||||
|
||||
else:
|
||||
return image, None
|
||||
|
||||
|
||||
class ConvLayer(nn.Sequential):
|
||||
def __init__(
|
||||
self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
downsample=False,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
bias=True,
|
||||
activate=True,
|
||||
):
|
||||
layers = []
|
||||
|
||||
if downsample:
|
||||
factor = 2
|
||||
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
||||
pad0 = (p + 1) // 2
|
||||
pad1 = p // 2
|
||||
|
||||
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
||||
|
||||
stride = 2
|
||||
self.padding = 0
|
||||
|
||||
else:
|
||||
stride = 1
|
||||
self.padding = kernel_size // 2
|
||||
|
||||
layers.append(
|
||||
EqualConv2d(
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
padding=self.padding,
|
||||
stride=stride,
|
||||
bias=bias and not activate,
|
||||
)
|
||||
)
|
||||
|
||||
if activate:
|
||||
if bias:
|
||||
layers.append(FusedLeakyReLU(out_channel))
|
||||
|
||||
else:
|
||||
layers.append(ScaledLeakyReLU(0.2))
|
||||
|
||||
super().__init__(*layers)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
|
||||
super().__init__()
|
||||
|
||||
self.conv1 = ConvLayer(in_channel, in_channel, 3)
|
||||
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
|
||||
|
||||
self.skip = ConvLayer(
|
||||
in_channel, out_channel, 1, downsample=True, activate=False, bias=False
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
out = self.conv1(input)
|
||||
out = self.conv2(out)
|
||||
|
||||
skip = self.skip(input)
|
||||
out = (out + skip) / math.sqrt(2)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
|
||||
super().__init__()
|
||||
|
||||
channels = {
|
||||
4: 512,
|
||||
8: 512,
|
||||
16: 512,
|
||||
32: 512,
|
||||
64: 256 * channel_multiplier,
|
||||
128: 128 * channel_multiplier,
|
||||
256: 64 * channel_multiplier,
|
||||
512: 32 * channel_multiplier,
|
||||
1024: 16 * channel_multiplier,
|
||||
}
|
||||
|
||||
convs = [ConvLayer(3, channels[size], 1)]
|
||||
|
||||
log_size = int(math.log(size, 2))
|
||||
|
||||
in_channel = channels[size]
|
||||
|
||||
#这里代码是8个大残差block,让feature map大小从1024到4
|
||||
for i in range(log_size, 2, -1):
|
||||
out_channel = channels[2 ** (i - 1)]
|
||||
|
||||
convs.append(ResBlock(in_channel, out_channel, blur_kernel))
|
||||
|
||||
in_channel = out_channel
|
||||
|
||||
self.convs = nn.Sequential(*convs)
|
||||
|
||||
self.stddev_group = 4
|
||||
self.stddev_feat = 1
|
||||
|
||||
self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
|
||||
self.final_linear = nn.Sequential(
|
||||
EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
|
||||
EqualLinear(channels[4], 1),
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
out = self.convs(input)
|
||||
|
||||
batch, channel, height, width = out.shape
|
||||
group = min(batch, self.stddev_group)
|
||||
stddev = out.view(
|
||||
group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
|
||||
)
|
||||
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
|
||||
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
|
||||
stddev = stddev.repeat(group, 1, height, width)
|
||||
out = torch.cat([out, stddev], 1)
|
||||
|
||||
out = self.final_conv(out)
|
||||
|
||||
out = out.view(batch, -1)
|
||||
out = self.final_linear(out)
|
||||
|
||||
return out
|
||||
|
||||
2
models/stylegan2/op/__init__.py
Normal file
2
models/stylegan2/op/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .fused_act import FusedLeakyReLU, fused_leaky_relu
|
||||
from .upfirdn2d import upfirdn2d
|
||||
BIN
models/stylegan2/op/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
models/stylegan2/op/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
models/stylegan2/op/__pycache__/fused_act.cpython-310.pyc
Normal file
BIN
models/stylegan2/op/__pycache__/fused_act.cpython-310.pyc
Normal file
Binary file not shown.
BIN
models/stylegan2/op/__pycache__/upfirdn2d.cpython-310.pyc
Normal file
BIN
models/stylegan2/op/__pycache__/upfirdn2d.cpython-310.pyc
Normal file
Binary file not shown.
40
models/stylegan2/op/fused_act.py
Normal file
40
models/stylegan2/op/fused_act.py
Normal file
@ -0,0 +1,40 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
module_path = os.path.dirname(__file__)
|
||||
|
||||
|
||||
|
||||
class FusedLeakyReLU(nn.Module):
|
||||
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
|
||||
super().__init__()
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(channel))
|
||||
self.negative_slope = negative_slope
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, input):
|
||||
return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
|
||||
|
||||
|
||||
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
|
||||
rest_dim = [1] * (input.ndim - bias.ndim - 1)
|
||||
input = input.cuda()
|
||||
if input.ndim == 3:
|
||||
return (
|
||||
F.leaky_relu(
|
||||
input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope
|
||||
)
|
||||
* scale #增益值,激活函数里的 gain(torch中scale) 是一个增益值,增益值是指的非线性函数稳态时输入幅度与输出幅度的比值,通常被用来乘在激活函数之后使激活函数更加稳定。
|
||||
)
|
||||
else:
|
||||
return (
|
||||
F.leaky_relu(
|
||||
input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
|
||||
)
|
||||
* scale
|
||||
)
|
||||
|
||||
60
models/stylegan2/op/upfirdn2d.py
Normal file
60
models/stylegan2/op/upfirdn2d.py
Normal file
@ -0,0 +1,60 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
module_path = os.path.dirname(__file__)
|
||||
|
||||
|
||||
|
||||
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
|
||||
out = upfirdn2d_native(
|
||||
input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def upfirdn2d_native(
|
||||
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
||||
):
|
||||
_, channel, in_h, in_w = input.shape
|
||||
input = input.reshape(-1, in_h, in_w, 1)
|
||||
|
||||
_, in_h, in_w, minor = input.shape
|
||||
kernel_h, kernel_w = kernel.shape
|
||||
|
||||
out = input.view(-1, in_h, 1, in_w, 1, minor)
|
||||
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
||||
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
||||
|
||||
out = F.pad(
|
||||
out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
|
||||
)
|
||||
out = out[
|
||||
:,
|
||||
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
||||
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
|
||||
:,
|
||||
]
|
||||
|
||||
out = out.permute(0, 3, 1, 2)
|
||||
out = out.reshape(
|
||||
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
|
||||
)
|
||||
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
||||
out = F.conv2d(out, w)
|
||||
out = out.reshape(
|
||||
-1,
|
||||
minor,
|
||||
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
||||
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
||||
)
|
||||
out = out.permute(0, 2, 3, 1)
|
||||
out = out[:, ::down_y, ::down_x, :]
|
||||
|
||||
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
||||
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
||||
|
||||
return out.view(-1, channel, out_h, out_w)
|
||||
9
models/stylegan3/dnnlib/__init__.py
Normal file
9
models/stylegan3/dnnlib/__init__.py
Normal file
@ -0,0 +1,9 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
from .util import EasyDict, make_cache_dir_path
|
||||
BIN
models/stylegan3/dnnlib/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
models/stylegan3/dnnlib/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
models/stylegan3/dnnlib/__pycache__/util.cpython-310.pyc
Normal file
BIN
models/stylegan3/dnnlib/__pycache__/util.cpython-310.pyc
Normal file
Binary file not shown.
491
models/stylegan3/dnnlib/util.py
Normal file
491
models/stylegan3/dnnlib/util.py
Normal file
@ -0,0 +1,491 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Miscellaneous utility classes and functions."""
|
||||
|
||||
import ctypes
|
||||
import fnmatch
|
||||
import importlib
|
||||
import inspect
|
||||
import numpy as np
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import types
|
||||
import io
|
||||
import pickle
|
||||
import re
|
||||
import requests
|
||||
import html
|
||||
import hashlib
|
||||
import glob
|
||||
import tempfile
|
||||
import urllib
|
||||
import urllib.request
|
||||
import uuid
|
||||
|
||||
from distutils.util import strtobool
|
||||
from typing import Any, List, Tuple, Union
|
||||
|
||||
|
||||
# Util classes
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
class EasyDict(dict):
|
||||
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
try:
|
||||
return self[name]
|
||||
except KeyError:
|
||||
raise AttributeError(name)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
self[name] = value
|
||||
|
||||
def __delattr__(self, name: str) -> None:
|
||||
del self[name]
|
||||
|
||||
|
||||
class Logger(object):
|
||||
"""Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
|
||||
|
||||
def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
|
||||
self.file = None
|
||||
|
||||
if file_name is not None:
|
||||
self.file = open(file_name, file_mode)
|
||||
|
||||
self.should_flush = should_flush
|
||||
self.stdout = sys.stdout
|
||||
self.stderr = sys.stderr
|
||||
|
||||
sys.stdout = self
|
||||
sys.stderr = self
|
||||
|
||||
def __enter__(self) -> "Logger":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def write(self, text: Union[str, bytes]) -> None:
|
||||
"""Write text to stdout (and a file) and optionally flush."""
|
||||
if isinstance(text, bytes):
|
||||
text = text.decode()
|
||||
if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
|
||||
return
|
||||
|
||||
if self.file is not None:
|
||||
self.file.write(text)
|
||||
|
||||
self.stdout.write(text)
|
||||
|
||||
if self.should_flush:
|
||||
self.flush()
|
||||
|
||||
def flush(self) -> None:
|
||||
"""Flush written text to both stdout and a file, if open."""
|
||||
if self.file is not None:
|
||||
self.file.flush()
|
||||
|
||||
self.stdout.flush()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Flush, close possible files, and remove stdout/stderr mirroring."""
|
||||
self.flush()
|
||||
|
||||
# if using multiple loggers, prevent closing in wrong order
|
||||
if sys.stdout is self:
|
||||
sys.stdout = self.stdout
|
||||
if sys.stderr is self:
|
||||
sys.stderr = self.stderr
|
||||
|
||||
if self.file is not None:
|
||||
self.file.close()
|
||||
self.file = None
|
||||
|
||||
|
||||
# Cache directories
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
_dnnlib_cache_dir = None
|
||||
|
||||
def set_cache_dir(path: str) -> None:
|
||||
global _dnnlib_cache_dir
|
||||
_dnnlib_cache_dir = path
|
||||
|
||||
def make_cache_dir_path(*paths: str) -> str:
|
||||
if _dnnlib_cache_dir is not None:
|
||||
return os.path.join(_dnnlib_cache_dir, *paths)
|
||||
if 'DNNLIB_CACHE_DIR' in os.environ:
|
||||
return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
|
||||
if 'HOME' in os.environ:
|
||||
return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
|
||||
if 'USERPROFILE' in os.environ:
|
||||
return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
|
||||
return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
|
||||
|
||||
# Small util functions
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def format_time(seconds: Union[int, float]) -> str:
|
||||
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
||||
s = int(np.rint(seconds))
|
||||
|
||||
if s < 60:
|
||||
return "{0}s".format(s)
|
||||
elif s < 60 * 60:
|
||||
return "{0}m {1:02}s".format(s // 60, s % 60)
|
||||
elif s < 24 * 60 * 60:
|
||||
return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
|
||||
else:
|
||||
return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
|
||||
|
||||
|
||||
def format_time_brief(seconds: Union[int, float]) -> str:
|
||||
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
||||
s = int(np.rint(seconds))
|
||||
|
||||
if s < 60:
|
||||
return "{0}s".format(s)
|
||||
elif s < 60 * 60:
|
||||
return "{0}m {1:02}s".format(s // 60, s % 60)
|
||||
elif s < 24 * 60 * 60:
|
||||
return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
|
||||
else:
|
||||
return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
|
||||
|
||||
|
||||
def ask_yes_no(question: str) -> bool:
|
||||
"""Ask the user the question until the user inputs a valid answer."""
|
||||
while True:
|
||||
try:
|
||||
print("{0} [y/n]".format(question))
|
||||
return strtobool(input().lower())
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
def tuple_product(t: Tuple) -> Any:
|
||||
"""Calculate the product of the tuple elements."""
|
||||
result = 1
|
||||
|
||||
for v in t:
|
||||
result *= v
|
||||
|
||||
return result
|
||||
|
||||
|
||||
_str_to_ctype = {
|
||||
"uint8": ctypes.c_ubyte,
|
||||
"uint16": ctypes.c_uint16,
|
||||
"uint32": ctypes.c_uint32,
|
||||
"uint64": ctypes.c_uint64,
|
||||
"int8": ctypes.c_byte,
|
||||
"int16": ctypes.c_int16,
|
||||
"int32": ctypes.c_int32,
|
||||
"int64": ctypes.c_int64,
|
||||
"float32": ctypes.c_float,
|
||||
"float64": ctypes.c_double
|
||||
}
|
||||
|
||||
|
||||
def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
|
||||
"""Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
|
||||
type_str = None
|
||||
|
||||
if isinstance(type_obj, str):
|
||||
type_str = type_obj
|
||||
elif hasattr(type_obj, "__name__"):
|
||||
type_str = type_obj.__name__
|
||||
elif hasattr(type_obj, "name"):
|
||||
type_str = type_obj.name
|
||||
else:
|
||||
raise RuntimeError("Cannot infer type name from input")
|
||||
|
||||
assert type_str in _str_to_ctype.keys()
|
||||
|
||||
my_dtype = np.dtype(type_str)
|
||||
my_ctype = _str_to_ctype[type_str]
|
||||
|
||||
assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
|
||||
|
||||
return my_dtype, my_ctype
|
||||
|
||||
|
||||
def is_pickleable(obj: Any) -> bool:
|
||||
try:
|
||||
with io.BytesIO() as stream:
|
||||
pickle.dump(obj, stream)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
# Functionality to import modules/objects by name, and call functions by name
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
|
||||
"""Searches for the underlying module behind the name to some python object.
|
||||
Returns the module and the object name (original name with module part removed)."""
|
||||
|
||||
# allow convenience shorthands, substitute them by full names
|
||||
obj_name = re.sub("^np.", "numpy.", obj_name)
|
||||
obj_name = re.sub("^tf.", "tensorflow.", obj_name)
|
||||
|
||||
# list alternatives for (module_name, local_obj_name)
|
||||
parts = obj_name.split(".")
|
||||
name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
|
||||
|
||||
# try each alternative in turn
|
||||
for module_name, local_obj_name in name_pairs:
|
||||
try:
|
||||
module = importlib.import_module(module_name) # may raise ImportError
|
||||
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
||||
return module, local_obj_name
|
||||
except:
|
||||
pass
|
||||
|
||||
# maybe some of the modules themselves contain errors?
|
||||
for module_name, _local_obj_name in name_pairs:
|
||||
try:
|
||||
importlib.import_module(module_name) # may raise ImportError
|
||||
except ImportError:
|
||||
if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
|
||||
raise
|
||||
|
||||
# maybe the requested attribute is missing?
|
||||
for module_name, local_obj_name in name_pairs:
|
||||
try:
|
||||
module = importlib.import_module(module_name) # may raise ImportError
|
||||
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# we are out of luck, but we have no idea why
|
||||
raise ImportError(obj_name)
|
||||
|
||||
|
||||
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
|
||||
"""Traverses the object name and returns the last (rightmost) python object."""
|
||||
if obj_name == '':
|
||||
return module
|
||||
obj = module
|
||||
for part in obj_name.split("."):
|
||||
obj = getattr(obj, part)
|
||||
return obj
|
||||
|
||||
|
||||
def get_obj_by_name(name: str) -> Any:
|
||||
"""Finds the python object with the given name."""
|
||||
module, obj_name = get_module_from_obj_name(name)
|
||||
return get_obj_from_module(module, obj_name)
|
||||
|
||||
|
||||
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
|
||||
"""Finds the python object with the given name and calls it as a function."""
|
||||
assert func_name is not None
|
||||
func_obj = get_obj_by_name(func_name)
|
||||
assert callable(func_obj)
|
||||
return func_obj(*args, **kwargs)
|
||||
|
||||
|
||||
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
|
||||
"""Finds the python class with the given name and constructs it with the given arguments."""
|
||||
return call_func_by_name(*args, func_name=class_name, **kwargs)
|
||||
|
||||
|
||||
def get_module_dir_by_obj_name(obj_name: str) -> str:
|
||||
"""Get the directory path of the module containing the given object name."""
|
||||
module, _ = get_module_from_obj_name(obj_name)
|
||||
return os.path.dirname(inspect.getfile(module))
|
||||
|
||||
|
||||
def is_top_level_function(obj: Any) -> bool:
|
||||
"""Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
|
||||
return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
|
||||
|
||||
|
||||
def get_top_level_function_name(obj: Any) -> str:
|
||||
"""Return the fully-qualified name of a top-level function."""
|
||||
assert is_top_level_function(obj)
|
||||
module = obj.__module__
|
||||
if module == '__main__':
|
||||
module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
|
||||
return module + "." + obj.__name__
|
||||
|
||||
|
||||
# File system helpers
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
|
||||
"""List all files recursively in a given directory while ignoring given file and directory names.
|
||||
Returns list of tuples containing both absolute and relative paths."""
|
||||
assert os.path.isdir(dir_path)
|
||||
base_name = os.path.basename(os.path.normpath(dir_path))
|
||||
|
||||
if ignores is None:
|
||||
ignores = []
|
||||
|
||||
result = []
|
||||
|
||||
for root, dirs, files in os.walk(dir_path, topdown=True):
|
||||
for ignore_ in ignores:
|
||||
dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
|
||||
|
||||
# dirs need to be edited in-place
|
||||
for d in dirs_to_remove:
|
||||
dirs.remove(d)
|
||||
|
||||
files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
|
||||
|
||||
absolute_paths = [os.path.join(root, f) for f in files]
|
||||
relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
|
||||
|
||||
if add_base_to_relative:
|
||||
relative_paths = [os.path.join(base_name, p) for p in relative_paths]
|
||||
|
||||
assert len(absolute_paths) == len(relative_paths)
|
||||
result += zip(absolute_paths, relative_paths)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
|
||||
"""Takes in a list of tuples of (src, dst) paths and copies files.
|
||||
Will create all necessary directories."""
|
||||
for file in files:
|
||||
target_dir_name = os.path.dirname(file[1])
|
||||
|
||||
# will create all intermediate-level directories
|
||||
if not os.path.exists(target_dir_name):
|
||||
os.makedirs(target_dir_name)
|
||||
|
||||
shutil.copyfile(file[0], file[1])
|
||||
|
||||
|
||||
# URL helpers
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
|
||||
"""Determine whether the given object is a valid URL string."""
|
||||
if not isinstance(obj, str) or not "://" in obj:
|
||||
return False
|
||||
if allow_file_urls and obj.startswith('file://'):
|
||||
return True
|
||||
try:
|
||||
res = requests.compat.urlparse(obj)
|
||||
if not res.scheme or not res.netloc or not "." in res.netloc:
|
||||
return False
|
||||
res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
|
||||
if not res.scheme or not res.netloc or not "." in res.netloc:
|
||||
return False
|
||||
except:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
|
||||
"""Download the given URL and return a binary-mode file object to access the data."""
|
||||
assert num_attempts >= 1
|
||||
assert not (return_filename and (not cache))
|
||||
|
||||
# Doesn't look like an URL scheme so interpret it as a local filename.
|
||||
if not re.match('^[a-z]+://', url):
|
||||
return url if return_filename else open(url, "rb")
|
||||
|
||||
# Handle file URLs. This code handles unusual file:// patterns that
|
||||
# arise on Windows:
|
||||
#
|
||||
# file:///c:/foo.txt
|
||||
#
|
||||
# which would translate to a local '/c:/foo.txt' filename that's
|
||||
# invalid. Drop the forward slash for such pathnames.
|
||||
#
|
||||
# If you touch this code path, you should test it on both Linux and
|
||||
# Windows.
|
||||
#
|
||||
# Some internet resources suggest using urllib.request.url2pathname() but
|
||||
# but that converts forward slashes to backslashes and this causes
|
||||
# its own set of problems.
|
||||
if url.startswith('file://'):
|
||||
filename = urllib.parse.urlparse(url).path
|
||||
if re.match(r'^/[a-zA-Z]:', filename):
|
||||
filename = filename[1:]
|
||||
return filename if return_filename else open(filename, "rb")
|
||||
|
||||
assert is_url(url)
|
||||
|
||||
# Lookup from cache.
|
||||
if cache_dir is None:
|
||||
cache_dir = make_cache_dir_path('downloads')
|
||||
|
||||
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
|
||||
if cache:
|
||||
cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
|
||||
if len(cache_files) == 1:
|
||||
filename = cache_files[0]
|
||||
return filename if return_filename else open(filename, "rb")
|
||||
|
||||
# Download.
|
||||
url_name = None
|
||||
url_data = None
|
||||
with requests.Session() as session:
|
||||
if verbose:
|
||||
print("Downloading %s ..." % url, end="", flush=True)
|
||||
for attempts_left in reversed(range(num_attempts)):
|
||||
try:
|
||||
with session.get(url) as res:
|
||||
res.raise_for_status()
|
||||
if len(res.content) == 0:
|
||||
raise IOError("No data received")
|
||||
|
||||
if len(res.content) < 8192:
|
||||
content_str = res.content.decode("utf-8")
|
||||
if "download_warning" in res.headers.get("Set-Cookie", ""):
|
||||
links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
|
||||
if len(links) == 1:
|
||||
url = requests.compat.urljoin(url, links[0])
|
||||
raise IOError("Google Drive virus checker nag")
|
||||
if "Google Drive - Quota exceeded" in content_str:
|
||||
raise IOError("Google Drive download quota exceeded -- please try again later")
|
||||
|
||||
match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
|
||||
url_name = match[1] if match else url
|
||||
url_data = res.content
|
||||
if verbose:
|
||||
print(" done")
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except:
|
||||
if not attempts_left:
|
||||
if verbose:
|
||||
print(" failed")
|
||||
raise
|
||||
if verbose:
|
||||
print(".", end="", flush=True)
|
||||
|
||||
# Save to cache.
|
||||
if cache:
|
||||
safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
|
||||
cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
|
||||
temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
with open(temp_file, "wb") as f:
|
||||
f.write(url_data)
|
||||
os.replace(temp_file, cache_file) # atomic
|
||||
if return_filename:
|
||||
return cache_file
|
||||
|
||||
# Return data as file object.
|
||||
assert not return_filename
|
||||
return io.BytesIO(url_data)
|
||||
529
models/stylegan3/model_3.py
Normal file
529
models/stylegan3/model_3.py
Normal file
@ -0,0 +1,529 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Generator architecture from the paper
|
||||
"Alias-Free Generative Adversarial Networks"."""
|
||||
|
||||
import numpy as np
|
||||
import scipy.signal
|
||||
import scipy.optimize
|
||||
import torch
|
||||
from torch_utils import misc
|
||||
from torch_utils import persistence
|
||||
from torch_utils.ops import conv2d_gradfix
|
||||
from torch_utils.ops import filtered_lrelu
|
||||
from torch_utils.ops import bias_act
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
|
||||
def modulated_conv2d(
|
||||
x, # Input tensor: [batch_size, in_channels, in_height, in_width]
|
||||
w, # Weight tensor: [out_channels, in_channels, kernel_height, kernel_width]
|
||||
s, # Style tensor: [batch_size, in_channels]
|
||||
demodulate = True, # Apply weight demodulation?
|
||||
padding = 0, # Padding: int or [padH, padW]
|
||||
input_gain = None, # Optional scale factors for the input channels: [], [in_channels], or [batch_size, in_channels]
|
||||
):
|
||||
with misc.suppress_tracer_warnings(): # this value will be treated as a constant
|
||||
batch_size = int(x.shape[0])
|
||||
out_channels, in_channels, kh, kw = w.shape
|
||||
misc.assert_shape(w, [out_channels, in_channels, kh, kw]) # [OIkk]
|
||||
misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
|
||||
misc.assert_shape(s, [batch_size, in_channels]) # [NI]
|
||||
|
||||
# Pre-normalize inputs.
|
||||
if demodulate:
|
||||
w = w * w.square().mean([1,2,3], keepdim=True).rsqrt()
|
||||
s = s * s.square().mean().rsqrt()
|
||||
|
||||
# Modulate weights.
|
||||
w = w.unsqueeze(0) # [NOIkk]
|
||||
w = w * s.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk]
|
||||
|
||||
# Demodulate weights.
|
||||
if demodulate:
|
||||
dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
|
||||
w = w * dcoefs.unsqueeze(2).unsqueeze(3).unsqueeze(4) # [NOIkk]
|
||||
|
||||
# Apply input scaling.
|
||||
if input_gain is not None:
|
||||
input_gain = input_gain.expand(batch_size, in_channels) # [NI]
|
||||
w = w * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk]
|
||||
|
||||
# Execute as one fused op using grouped convolution.
|
||||
x = x.reshape(1, -1, *x.shape[2:])
|
||||
w = w.reshape(-1, in_channels, kh, kw)
|
||||
x = conv2d_gradfix.conv2d(input=x, weight=w.to(x.dtype), padding=padding, groups=batch_size)
|
||||
x = x.reshape(batch_size, -1, *x.shape[2:])
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@persistence.persistent_class
|
||||
class FullyConnectedLayer(torch.nn.Module):
|
||||
def __init__(self,
|
||||
in_features, # Number of input features.
|
||||
out_features, # Number of output features.
|
||||
activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
|
||||
bias = True, # Apply additive bias before the activation function?
|
||||
lr_multiplier = 1, # Learning rate multiplier.
|
||||
weight_init = 1, # Initial standard deviation of the weight tensor.
|
||||
bias_init = 0, # Initial value of the additive bias.
|
||||
):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.activation = activation
|
||||
self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * (weight_init / lr_multiplier))
|
||||
bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features])
|
||||
self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None
|
||||
self.weight_gain = lr_multiplier / np.sqrt(in_features)
|
||||
self.bias_gain = lr_multiplier
|
||||
|
||||
def forward(self, x):
|
||||
w = self.weight.to(x.dtype) * self.weight_gain
|
||||
b = self.bias
|
||||
if b is not None:
|
||||
b = b.to(x.dtype)
|
||||
if self.bias_gain != 1:
|
||||
b = b * self.bias_gain
|
||||
if self.activation == 'linear' and b is not None:
|
||||
x = torch.addmm(b.unsqueeze(0), x, w.t())
|
||||
else:
|
||||
x = x.matmul(w.t())
|
||||
x = bias_act.bias_act(x, b, act=self.activation)
|
||||
return x
|
||||
|
||||
def extra_repr(self):
|
||||
return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@persistence.persistent_class
|
||||
class MappingNetwork(torch.nn.Module):
|
||||
def __init__(self,
|
||||
z_dim, # Input latent (Z) dimensionality.
|
||||
c_dim, # Conditioning label (C) dimensionality, 0 = no labels.
|
||||
w_dim, # Intermediate latent (W) dimensionality.
|
||||
num_ws, # Number of intermediate latents to output.
|
||||
num_layers = 2, # Number of mapping layers.
|
||||
lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
|
||||
w_avg_beta = 0.998, # Decay for tracking the moving average of W during training.
|
||||
):
|
||||
super().__init__()
|
||||
self.z_dim = z_dim
|
||||
self.c_dim = c_dim
|
||||
self.w_dim = w_dim
|
||||
self.num_ws = num_ws
|
||||
self.num_layers = num_layers
|
||||
self.w_avg_beta = w_avg_beta
|
||||
|
||||
# Construct layers.
|
||||
self.embed = FullyConnectedLayer(self.c_dim, self.w_dim) if self.c_dim > 0 else None
|
||||
features = [self.z_dim + (self.w_dim if self.c_dim > 0 else 0)] + [self.w_dim] * self.num_layers
|
||||
for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]):
|
||||
layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier)
|
||||
setattr(self, f'fc{idx}', layer)
|
||||
self.register_buffer('w_avg', torch.zeros([w_dim]))
|
||||
|
||||
def forward(self, z, c=0, truncation_psi=1, truncation_cutoff=None, update_emas=False):
|
||||
#将传入的z由list改为tensor 好像改得不对,还是别改把
|
||||
# z = torch.tensor( [item.cpu().detach().numpy() for item in z] )
|
||||
misc.assert_shape(z, [None, self.z_dim])
|
||||
if truncation_cutoff is None:
|
||||
truncation_cutoff = self.num_ws
|
||||
|
||||
# Embed, normalize, and concatenate inputs.
|
||||
x = z.to(torch.float32)
|
||||
x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt()
|
||||
if self.c_dim > 0:
|
||||
misc.assert_shape(c, [None, self.c_dim])
|
||||
y = self.embed(c.to(torch.float32))
|
||||
y = y * (y.square().mean(1, keepdim=True) + 1e-8).rsqrt()
|
||||
x = torch.cat([x, y], dim=1) if x is not None else y
|
||||
|
||||
# Execute layers.
|
||||
for idx in range(self.num_layers):
|
||||
x = getattr(self, f'fc{idx}')(x)
|
||||
|
||||
# Update moving average of W.
|
||||
if update_emas:
|
||||
self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
|
||||
|
||||
# Broadcast and apply truncation.
|
||||
x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
|
||||
if truncation_psi != 1:
|
||||
x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
|
||||
return x
|
||||
|
||||
def extra_repr(self):
|
||||
return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}'
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@persistence.persistent_class
|
||||
class SynthesisInput(torch.nn.Module):
|
||||
def __init__(self,
|
||||
w_dim, # Intermediate latent (W) dimensionality.
|
||||
channels, # Number of output channels.
|
||||
size, # Output spatial size: int or [width, height].
|
||||
sampling_rate, # Output sampling rate.
|
||||
bandwidth, # Output bandwidth.
|
||||
):
|
||||
super().__init__()
|
||||
self.w_dim = w_dim
|
||||
self.channels = channels
|
||||
self.size = np.broadcast_to(np.asarray(size), [2])
|
||||
self.sampling_rate = sampling_rate
|
||||
self.bandwidth = bandwidth
|
||||
|
||||
# Draw random frequencies from uniform 2D disc.
|
||||
freqs = torch.randn([self.channels, 2])
|
||||
radii = freqs.square().sum(dim=1, keepdim=True).sqrt()
|
||||
freqs /= radii * radii.square().exp().pow(0.25)
|
||||
freqs *= bandwidth
|
||||
phases = torch.rand([self.channels]) - 0.5
|
||||
|
||||
# Setup parameters and buffers.
|
||||
self.weight = torch.nn.Parameter(torch.randn([self.channels, self.channels]))
|
||||
self.affine = FullyConnectedLayer(w_dim, 4, weight_init=0, bias_init=[1,0,0,0])
|
||||
self.register_buffer('transform', torch.eye(3, 3)) # User-specified inverse transform wrt. resulting image.
|
||||
self.register_buffer('freqs', freqs)
|
||||
self.register_buffer('phases', phases)
|
||||
|
||||
def forward(self, w):
|
||||
# Introduce batch dimension.
|
||||
transforms = self.transform.unsqueeze(0) # [batch, row, col]
|
||||
freqs = self.freqs.unsqueeze(0) # [batch, channel, xy]
|
||||
phases = self.phases.unsqueeze(0) # [batch, channel]
|
||||
|
||||
# Apply learned transformation.
|
||||
t = self.affine(w) # t = (r_c, r_s, t_x, t_y)
|
||||
t = t / t[:, :2].norm(dim=1, keepdim=True) # t' = (r'_c, r'_s, t'_x, t'_y)
|
||||
m_r = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse rotation wrt. resulting image.
|
||||
m_r[:, 0, 0] = t[:, 0] # r'_c
|
||||
m_r[:, 0, 1] = -t[:, 1] # r'_s
|
||||
m_r[:, 1, 0] = t[:, 1] # r'_s
|
||||
m_r[:, 1, 1] = t[:, 0] # r'_c
|
||||
m_t = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse translation wrt. resulting image.
|
||||
m_t[:, 0, 2] = -t[:, 2] # t'_x
|
||||
m_t[:, 1, 2] = -t[:, 3] # t'_y
|
||||
transforms = m_r @ m_t @ transforms # First rotate resulting image, then translate, and finally apply user-specified transform.
|
||||
|
||||
# Transform frequencies.
|
||||
phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2)
|
||||
freqs = freqs @ transforms[:, :2, :2]
|
||||
|
||||
# Dampen out-of-band frequencies that may occur due to the user-specified transform.
|
||||
amplitudes = (1 - (freqs.norm(dim=2) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1)
|
||||
|
||||
# Construct sampling grid.
|
||||
theta = torch.eye(2, 3, device=w.device)
|
||||
theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate
|
||||
theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate
|
||||
grids = torch.nn.functional.affine_grid(theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]], align_corners=False)
|
||||
|
||||
# Compute Fourier features.
|
||||
x = (grids.unsqueeze(3) @ freqs.permute(0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(3) # [batch, height, width, channel]
|
||||
x = x + phases.unsqueeze(1).unsqueeze(2)
|
||||
x = torch.sin(x * (np.pi * 2))
|
||||
x = x * amplitudes.unsqueeze(1).unsqueeze(2)
|
||||
|
||||
# Apply trainable mapping.
|
||||
weight = self.weight / np.sqrt(self.channels)
|
||||
x = x @ weight.t()
|
||||
|
||||
# Ensure correct shape.
|
||||
x = x.permute(0, 3, 1, 2) # [batch, channel, height, width]
|
||||
misc.assert_shape(x, [w.shape[0], self.channels, int(self.size[1]), int(self.size[0])])
|
||||
return x
|
||||
|
||||
def extra_repr(self):
|
||||
return '\n'.join([
|
||||
f'w_dim={self.w_dim:d}, channels={self.channels:d}, size={list(self.size)},',
|
||||
f'sampling_rate={self.sampling_rate:g}, bandwidth={self.bandwidth:g}'])
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@persistence.persistent_class
|
||||
class SynthesisLayer(torch.nn.Module):
|
||||
def __init__(self,
|
||||
w_dim, # Intermediate latent (W) dimensionality.
|
||||
is_torgb, # Is this the final ToRGB layer?
|
||||
is_critically_sampled, # Does this layer use critical sampling?
|
||||
use_fp16, # Does this layer use FP16?
|
||||
|
||||
# Input & output specifications.
|
||||
in_channels, # Number of input channels.
|
||||
out_channels, # Number of output channels.
|
||||
in_size, # Input spatial size: int or [width, height].
|
||||
out_size, # Output spatial size: int or [width, height].
|
||||
in_sampling_rate, # Input sampling rate (s).
|
||||
out_sampling_rate, # Output sampling rate (s).
|
||||
in_cutoff, # Input cutoff frequency (f_c).
|
||||
out_cutoff, # Output cutoff frequency (f_c).
|
||||
in_half_width, # Input transition band half-width (f_h).
|
||||
out_half_width, # Output Transition band half-width (f_h).
|
||||
|
||||
# Hyperparameters.
|
||||
conv_kernel = 3, # Convolution kernel size. Ignored for final the ToRGB layer.
|
||||
filter_size = 6, # Low-pass filter size relative to the lower resolution when up/downsampling.
|
||||
lrelu_upsampling = 2, # Relative sampling rate for leaky ReLU. Ignored for final the ToRGB layer.
|
||||
use_radial_filters = False, # Use radially symmetric downsampling filter? Ignored for critically sampled layers.
|
||||
conv_clamp = 256, # Clamp the output to [-X, +X], None = disable clamping.
|
||||
magnitude_ema_beta = 0.999, # Decay rate for the moving average of input magnitudes.
|
||||
):
|
||||
super().__init__()
|
||||
self.w_dim = w_dim
|
||||
self.is_torgb = is_torgb
|
||||
self.is_critically_sampled = is_critically_sampled
|
||||
self.use_fp16 = use_fp16
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.in_size = np.broadcast_to(np.asarray(in_size), [2])
|
||||
self.out_size = np.broadcast_to(np.asarray(out_size), [2])
|
||||
self.in_sampling_rate = in_sampling_rate
|
||||
self.out_sampling_rate = out_sampling_rate
|
||||
self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling)
|
||||
self.in_cutoff = in_cutoff
|
||||
self.out_cutoff = out_cutoff
|
||||
self.in_half_width = in_half_width
|
||||
self.out_half_width = out_half_width
|
||||
self.conv_kernel = 1 if is_torgb else conv_kernel
|
||||
self.conv_clamp = conv_clamp
|
||||
self.magnitude_ema_beta = magnitude_ema_beta
|
||||
|
||||
# Setup parameters and buffers.
|
||||
self.affine = FullyConnectedLayer(self.w_dim, self.in_channels, bias_init=1)
|
||||
self.weight = torch.nn.Parameter(torch.randn([self.out_channels, self.in_channels, self.conv_kernel, self.conv_kernel]))
|
||||
self.bias = torch.nn.Parameter(torch.zeros([self.out_channels]))
|
||||
self.register_buffer('magnitude_ema', torch.ones([]))
|
||||
|
||||
# Design upsampling filter.
|
||||
self.up_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate))
|
||||
assert self.in_sampling_rate * self.up_factor == self.tmp_sampling_rate
|
||||
self.up_taps = filter_size * self.up_factor if self.up_factor > 1 and not self.is_torgb else 1
|
||||
self.register_buffer('up_filter', self.design_lowpass_filter(
|
||||
numtaps=self.up_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate))
|
||||
|
||||
# Design downsampling filter.
|
||||
self.down_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate))
|
||||
assert self.out_sampling_rate * self.down_factor == self.tmp_sampling_rate
|
||||
self.down_taps = filter_size * self.down_factor if self.down_factor > 1 and not self.is_torgb else 1
|
||||
self.down_radial = use_radial_filters and not self.is_critically_sampled
|
||||
self.register_buffer('down_filter', self.design_lowpass_filter(
|
||||
numtaps=self.down_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate, radial=self.down_radial))
|
||||
|
||||
# Compute padding.
|
||||
pad_total = (self.out_size - 1) * self.down_factor + 1 # Desired output size before downsampling.
|
||||
pad_total -= (self.in_size + self.conv_kernel - 1) * self.up_factor # Input size after upsampling.
|
||||
pad_total += self.up_taps + self.down_taps - 2 # Size reduction caused by the filters.
|
||||
pad_lo = (pad_total + self.up_factor) // 2 # Shift sample locations according to the symmetric interpretation (Appendix C.3).
|
||||
pad_hi = pad_total - pad_lo
|
||||
self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])]
|
||||
|
||||
def forward(self, x, w, noise_mode='random', force_fp32=False, update_emas=False):
|
||||
assert noise_mode in ['random', 'const', 'none'] # unused
|
||||
misc.assert_shape(x, [None, self.in_channels, int(self.in_size[1]), int(self.in_size[0])])
|
||||
misc.assert_shape(w, [x.shape[0], self.w_dim])
|
||||
|
||||
# Track input magnitude.
|
||||
if update_emas:
|
||||
with torch.autograd.profiler.record_function('update_magnitude_ema'):
|
||||
magnitude_cur = x.detach().to(torch.float32).square().mean()
|
||||
self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema, self.magnitude_ema_beta))
|
||||
input_gain = self.magnitude_ema.rsqrt()
|
||||
|
||||
# Execute affine layer.
|
||||
styles = self.affine(w)
|
||||
if self.is_torgb:
|
||||
weight_gain = 1 / np.sqrt(self.in_channels * (self.conv_kernel ** 2))
|
||||
styles = styles * weight_gain
|
||||
|
||||
# Execute modulated conv2d.
|
||||
dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
|
||||
x = modulated_conv2d(x=x.to(dtype), w=self.weight, s=styles,
|
||||
padding=self.conv_kernel-1, demodulate=(not self.is_torgb), input_gain=input_gain)
|
||||
|
||||
# Execute bias, filtered leaky ReLU, and clamping.
|
||||
gain = 1 if self.is_torgb else np.sqrt(2)
|
||||
slope = 1 if self.is_torgb else 0.2
|
||||
x = filtered_lrelu.filtered_lrelu(x=x, fu=self.up_filter, fd=self.down_filter, b=self.bias.to(x.dtype),
|
||||
up=self.up_factor, down=self.down_factor, padding=self.padding, gain=gain, slope=slope, clamp=self.conv_clamp)
|
||||
|
||||
# Ensure correct shape and dtype.
|
||||
misc.assert_shape(x, [None, self.out_channels, int(self.out_size[1]), int(self.out_size[0])])
|
||||
assert x.dtype == dtype
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False):
|
||||
assert numtaps >= 1
|
||||
|
||||
# Identity filter.
|
||||
if numtaps == 1:
|
||||
return None
|
||||
|
||||
# Separable Kaiser low-pass filter.
|
||||
if not radial:
|
||||
f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs)
|
||||
return torch.as_tensor(f, dtype=torch.float32)
|
||||
|
||||
# Radially symmetric jinc-based filter.
|
||||
x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs
|
||||
r = np.hypot(*np.meshgrid(x, x))
|
||||
f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r)
|
||||
beta = scipy.signal.kaiser_beta(scipy.signal.kaiser_atten(numtaps, width / (fs / 2)))
|
||||
w = np.kaiser(numtaps, beta)
|
||||
f *= np.outer(w, w)
|
||||
f /= np.sum(f)
|
||||
return torch.as_tensor(f, dtype=torch.float32)
|
||||
|
||||
def extra_repr(self):
|
||||
return '\n'.join([
|
||||
f'w_dim={self.w_dim:d}, is_torgb={self.is_torgb},',
|
||||
f'is_critically_sampled={self.is_critically_sampled}, use_fp16={self.use_fp16},',
|
||||
f'in_sampling_rate={self.in_sampling_rate:g}, out_sampling_rate={self.out_sampling_rate:g},',
|
||||
f'in_cutoff={self.in_cutoff:g}, out_cutoff={self.out_cutoff:g},',
|
||||
f'in_half_width={self.in_half_width:g}, out_half_width={self.out_half_width:g},',
|
||||
f'in_size={list(self.in_size)}, out_size={list(self.out_size)},',
|
||||
f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}'])
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@persistence.persistent_class
|
||||
class SynthesisNetwork(torch.nn.Module):
|
||||
def __init__(self,
|
||||
w_dim, # Intermediate latent (W) dimensionality. 512
|
||||
img_resolution, # Output image resolution. 1024
|
||||
img_channels, # Number of color channels. 3
|
||||
channel_base = 32768, # Overall multiplier for the number of channels.通道总体倍增因子
|
||||
channel_max = 512, # Maximum number of channels in any layer.
|
||||
num_layers = 14, # Total number of layers, excluding Fourier features and ToRGB.
|
||||
num_critical = 2, # Number of critically sampled layers at the end.
|
||||
first_cutoff = 2, # Cutoff frequency of the first layer (f_{c,0}).
|
||||
first_stopband = 2**2.1, # Minimum stopband of the first layer (f_{t,0}).
|
||||
last_stopband_rel = 2**0.3, # Minimum stopband of the last layer, expressed relative to the cutoff.
|
||||
margin_size = 10, # Number of additional pixels outside the image.
|
||||
output_scale = 0.25, # Scale factor for the output image.
|
||||
num_fp16_res = 4, # Use FP16 for the N highest resolutions.
|
||||
**layer_kwargs, # Arguments for SynthesisLayer.
|
||||
):
|
||||
super().__init__()
|
||||
self.w_dim = w_dim
|
||||
self.num_ws = num_layers + 2
|
||||
self.img_resolution = img_resolution
|
||||
self.img_channels = img_channels
|
||||
self.num_layers = num_layers
|
||||
self.num_critical = num_critical
|
||||
self.margin_size = margin_size
|
||||
self.output_scale = output_scale
|
||||
self.num_fp16_res = num_fp16_res
|
||||
|
||||
# Geometric progression of layer cutoffs and min. stopbands.
|
||||
last_cutoff = self.img_resolution / 2 # f_{c,N}
|
||||
last_stopband = last_cutoff * last_stopband_rel # f_{t,N}
|
||||
exponents = np.minimum(np.arange(self.num_layers + 1) / (self.num_layers - self.num_critical), 1)
|
||||
cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents # f_c[i] [ 2. 3.1748021 5.0396842 8. 12.69920842, 20.1587368 32. 50.79683366 80.63494719 128., 203.18733465 322.53978877 512. 512. 512. ]
|
||||
stopbands = first_stopband * (last_stopband / first_stopband) ** exponents # f_t[i]
|
||||
|
||||
# Compute remaining layer parameters.
|
||||
sampling_rates = np.exp2(np.ceil(np.log2(np.minimum(stopbands * 2, self.img_resolution)))) # s[i]
|
||||
half_widths = np.maximum(stopbands, sampling_rates / 2) - cutoffs # f_h[i]
|
||||
sizes = sampling_rates + self.margin_size * 2
|
||||
sizes[-2:] = self.img_resolution
|
||||
channels = np.rint(np.minimum((channel_base / 2) / cutoffs, channel_max))
|
||||
channels[-1] = self.img_channels
|
||||
|
||||
# Construct layers.
|
||||
self.input = SynthesisInput(
|
||||
w_dim=self.w_dim, channels=int(channels[0]), size=int(sizes[0]), #sizes:[ 36. 36. 52. 52. 84. 148. 148. 276. 276. 532. 1044. 1044., 1044. 1024. 1024.]
|
||||
sampling_rate=sampling_rates[0], bandwidth=cutoffs[0]) #sampling_rates :[ 16. 16. 32. 32. 64. 128. 128. 256. 256. 512. 1024. 1024., 1024. 1024. 1024.]
|
||||
self.layer_names = []
|
||||
for idx in range(self.num_layers + 1):
|
||||
prev = max(idx - 1, 0)
|
||||
is_torgb = (idx == self.num_layers)
|
||||
is_critically_sampled = (idx >= self.num_layers - self.num_critical)
|
||||
use_fp16 = (sampling_rates[idx] * (2 ** self.num_fp16_res) > self.img_resolution)
|
||||
layer = SynthesisLayer(
|
||||
w_dim=self.w_dim, is_torgb=is_torgb, is_critically_sampled=is_critically_sampled, use_fp16=use_fp16,
|
||||
in_channels=int(channels[prev]), out_channels= int(channels[idx]),
|
||||
in_size=int(sizes[prev]), out_size=int(sizes[idx]),
|
||||
in_sampling_rate=int(sampling_rates[prev]), out_sampling_rate=int(sampling_rates[idx]),
|
||||
in_cutoff=cutoffs[prev], out_cutoff=cutoffs[idx],
|
||||
in_half_width=half_widths[prev], out_half_width=half_widths[idx],
|
||||
**layer_kwargs)
|
||||
name = f'L{idx}_{layer.out_size[0]}_{layer.out_channels}'
|
||||
setattr(self, name, layer)
|
||||
self.layer_names.append(name)
|
||||
|
||||
def forward(self, ws, **layer_kwargs):
|
||||
misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
|
||||
ws = ws.to(torch.float32).unbind(dim=1)
|
||||
|
||||
# Execute layers.
|
||||
x = self.input(ws[0])
|
||||
for name, w in zip(self.layer_names, ws[1:]):
|
||||
x = getattr(self, name)(x, w, **layer_kwargs)
|
||||
if self.output_scale != 1:
|
||||
x = x * self.output_scale
|
||||
|
||||
# Ensure correct shape and dtype.
|
||||
misc.assert_shape(x, [None, self.img_channels, self.img_resolution, self.img_resolution])
|
||||
x = x.to(torch.float32)
|
||||
return x
|
||||
|
||||
def extra_repr(self):
|
||||
return '\n'.join([
|
||||
f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},',
|
||||
f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
|
||||
f'num_layers={self.num_layers:d}, num_critical={self.num_critical:d},',
|
||||
f'margin_size={self.margin_size:d}, num_fp16_res={self.num_fp16_res:d}'])
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@persistence.persistent_class
|
||||
class Generator(torch.nn.Module):
|
||||
def __init__(self,
|
||||
z_dim, # Input latent (Z) dimensionality.
|
||||
c_dim, # Conditioning label (C) dimensionality.
|
||||
w_dim, # Intermediate latent (W) dimensionality.
|
||||
img_resolution, # Output resolution.
|
||||
img_channels, # Number of output color channels.
|
||||
mapping_kwargs = {}, # Arguments for MappingNetwork.
|
||||
**synthesis_kwargs, # Arguments for SynthesisNetwork.
|
||||
):
|
||||
super().__init__()
|
||||
self.z_dim = z_dim #512
|
||||
self.c_dim = c_dim #0
|
||||
self.w_dim = w_dim #512
|
||||
self.img_resolution = img_resolution
|
||||
self.img_channels = img_channels
|
||||
self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
|
||||
self.num_ws = self.synthesis.num_ws #16
|
||||
self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
|
||||
|
||||
# def mean_latent(self, n_latent):
|
||||
# latent_in = torch.randn(
|
||||
# #此处的style_dim应与w_dim对应
|
||||
# n_latent, self.w_dim, device=self.synthesis.input.weight.device
|
||||
# )
|
||||
# latent = self.synthesis.styles(latent_in).mean(0, keepdim=True)
|
||||
#
|
||||
# return latent
|
||||
|
||||
def forward(self, z, c=None, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
|
||||
# print("-----------------------------------")
|
||||
# print(z)
|
||||
# print("-----------------------------------")
|
||||
ws = self.mapping(z, c = None, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
|
||||
img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
|
||||
return img
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
267
models/stylegan3/run_optimization3.py
Normal file
267
models/stylegan3/run_optimization3.py
Normal file
@ -0,0 +1,267 @@
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import torchvision
|
||||
from torch import optim
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
import clip
|
||||
|
||||
|
||||
class CLIPLoss(torch.nn.Module):
|
||||
|
||||
def __init__(self, opts):
|
||||
super(CLIPLoss, self).__init__()
|
||||
self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
|
||||
self.upsample = torch.nn.Upsample(scale_factor=7)
|
||||
self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)
|
||||
|
||||
def forward(self, image, text):
|
||||
image = self.avg_pool(self.upsample(image))
|
||||
similarity = 1 - self.model(image, text)[0] / 100
|
||||
return similarity
|
||||
|
||||
|
||||
from torch import nn
|
||||
import sys
|
||||
sys.path.append('/home/ly/StyleCLIP-main/models/facial_recognition')
|
||||
from model_irse import Backbone
|
||||
|
||||
|
||||
class IDLoss(nn.Module):
|
||||
def __init__(self, opts):
|
||||
super(IDLoss, self).__init__()
|
||||
print('Loading ResNet ArcFace')
|
||||
self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
|
||||
self.facenet.load_state_dict(torch.load(opts.ir_se50_weights))
|
||||
self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
|
||||
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
|
||||
self.facenet.eval()
|
||||
self.facenet.cuda()
|
||||
self.opts = opts
|
||||
|
||||
def extract_feats(self, x):
|
||||
if x.shape[2] != 256:
|
||||
x = self.pool(x)
|
||||
x = x[:, :, 35:223, 32:220] # Crop interesting region
|
||||
x = self.face_pool(x)
|
||||
x_feats = self.facenet(x)
|
||||
return x_feats
|
||||
|
||||
def forward(self, y_hat, y):
|
||||
n_samples = y.shape[0]
|
||||
y_feats = self.extract_feats(y) # Otherwise use the feature from there
|
||||
y_hat_feats = self.extract_feats(y_hat)
|
||||
y_feats = y_feats.detach()
|
||||
loss = 0
|
||||
sim_improvement = 0
|
||||
count = 0
|
||||
for i in range(n_samples):
|
||||
diff_target = y_hat_feats[i].dot(y_feats[i])
|
||||
loss += 1 - diff_target
|
||||
count += 1
|
||||
|
||||
return loss / count, sim_improvement / count
|
||||
sys.path.append('/home/ly/StyleCLIP-main/mapper/training')
|
||||
from train_utils import STYLESPACE_DIMENSIONS
|
||||
from model_3 import Generator
|
||||
from model_3 import SynthesisNetwork
|
||||
from model_3 import SynthesisLayer
|
||||
|
||||
|
||||
sys.path.append('/home/ly/StyleCLIP-main')
|
||||
from utils import ensure_checkpoint_exists
|
||||
|
||||
STYLESPACE_INDICES_WITHOUT_TORGB = [i for i in range(len(STYLESPACE_DIMENSIONS)) if i not in list(range(1, len(STYLESPACE_DIMENSIONS), 3))]
|
||||
|
||||
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
|
||||
lr_ramp = min(1, (1 - t) / rampdown)
|
||||
lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
|
||||
lr_ramp = lr_ramp * min(1, t / rampup)
|
||||
|
||||
return initial_lr * lr_ramp
|
||||
|
||||
|
||||
def main(args):
|
||||
ensure_checkpoint_exists(args.ckpt)
|
||||
# 把描述加载进clip预训练模型里面去
|
||||
text_inputs = torch.cat([clip.tokenize(args.description)]).cuda()
|
||||
# print('text_input是: ', text_inputs)
|
||||
#tokenizer clip分词的机制 依据规则
|
||||
#以及词汇表的总量
|
||||
'''
|
||||
--description "a person with purple hair"
|
||||
tensor([[49406, 320, 2533, 593, 5496, 2225, 49407, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0]], device='cuda:0',
|
||||
dtype=torch.int32)
|
||||
--description "a person with red hair"
|
||||
tensor([[49406, 320, 2533, 593, 736, 2225, 49407, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0]], device='cuda:0',
|
||||
dtype=torch.int32)
|
||||
'''
|
||||
|
||||
os.makedirs(args.results_dir, exist_ok=True)
|
||||
#改成stylegan3的输入
|
||||
|
||||
# with open('/home/ly/StyleCLIP-main/models/stylegan3/torch_utils/stylegan3-r-afhqv2-512x512.pkl', 'rb') as f:
|
||||
# G = pickle.load(f)['G_ema'].cuda() # torch.nn.Module
|
||||
# z = torch.randn([1, G.z_dim]).cuda() # latent codes
|
||||
# c = None # class labels (not used in this example)
|
||||
# img = G(z, c) # NCHW, float32, dynamic range [-1, +1], no truncation
|
||||
|
||||
# g_ema = Generator(512, 0, 512,args.stylegan_size, 3) #512,0,512,1024,3
|
||||
# with open('/home/ly/StyleCLIP-main/models/stylegan3/torch_utils/stylegan3-r-afhqv2-512x512.pkl', 'rb') as f:
|
||||
#stylegan3-r-ffhqu-1024x1024.pkl 生成图片的效果欠佳 别用
|
||||
#stylegan3-t-ffhq-1024x1024.pkl 生成效果一般 loss值较好
|
||||
#stylegan3-r-ffhq-1024x1024.pkl 折中
|
||||
#stylegan3-t-ffhqu-1024x1024.pkl 生成图片可以 loss较差
|
||||
with open('/home/ly/StyleCLIP-main/pretrained_models/stylegan3-t-ffhq-1024x1024.pkl', 'rb') as f: #stylespace_dimensions [512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 256, 256, 256, 128, 128, 128, 64, 64, 64, 32, 32]
|
||||
# new_p = pickle.load(f)
|
||||
# print(new_p)
|
||||
# print("new_p")
|
||||
# print(new_p.keys())
|
||||
# G_ema.load_state_dict(pickle.load(f)['G_ema'].cuda(), strict=False) 这种方式模型加载不进来
|
||||
g_ema = pickle.load(f)['G_ema'].cuda() # torch.nn.Module 这种方式推演三百步的图片平均要4分钟
|
||||
z = torch.randn([1, g_ema.z_dim]).cuda() # latent codes
|
||||
c = None # class labels (not used in this example)
|
||||
#g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
|
||||
# 将模型对象设置为评估模式
|
||||
g_ema.eval()
|
||||
#更改cuda卡号
|
||||
g_ema = g_ema.cuda()
|
||||
# device = torch.cuda.current_device()
|
||||
# print('cuda:',device)
|
||||
mean_latent = torch.randn([1, g_ema.z_dim]).cuda()
|
||||
torch.save(mean_latent,'/home/ly/StyleCLIP-main/pretrained_models/latent_code/style3.pt')
|
||||
# print('mean_latent: ', mean_latent)
|
||||
|
||||
if args.latent_path:
|
||||
latent_code_init = torch.load(args.latent_path).cuda()
|
||||
# elif args.mode == "edit":
|
||||
# latent_code_init_not_trunc = torch.randn(1, 512).cuda()
|
||||
# with torch.no_grad():
|
||||
# _, latent_code_init, _ = g_ema([latent_code_init_not_trunc], return_latents=True,
|
||||
# truncation=args.truncation, truncation_latent=mean_latent)
|
||||
else:
|
||||
# latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1) #在维度1上重复18次
|
||||
latent_code_init = mean_latent.detach().clone()
|
||||
# def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
|
||||
with torch.no_grad():
|
||||
print("mean_latent ", mean_latent.shape)
|
||||
# img_orig, _ = g_ema([latent_code_init], c, input_is_latent=True, randomize_noise=False)
|
||||
img_orig = g_ema(latent_code_init, c)
|
||||
|
||||
if args.work_in_stylespace:
|
||||
with torch.no_grad():
|
||||
_, _, latent_code_init = g_ema([latent_code_init], input_is_latent=True, return_latents=True)
|
||||
latent = [s.detach().clone() for s in latent_code_init]
|
||||
for c, s in enumerate(latent):
|
||||
if c in STYLESPACE_INDICES_WITHOUT_TORGB:
|
||||
s.requires_grad = True
|
||||
else:
|
||||
latent = latent_code_init.detach().clone()
|
||||
latent.requires_grad = True
|
||||
|
||||
clip_loss = CLIPLoss(args)
|
||||
id_loss = IDLoss(args)
|
||||
|
||||
if args.work_in_stylespace:
|
||||
optimizer = optim.Adam(latent, lr=args.lr)
|
||||
else:
|
||||
optimizer = optim.Adam([latent], lr=args.lr)
|
||||
|
||||
pbar = tqdm(range(args.step))
|
||||
|
||||
for i in pbar:
|
||||
t = i / args.step
|
||||
lr = get_lr(t, args.lr)
|
||||
optimizer.param_groups[0]["lr"] = lr
|
||||
|
||||
img_gen = g_ema(latent,c)
|
||||
|
||||
c_loss = clip_loss(img_gen, text_inputs)
|
||||
|
||||
if args.id_lambda > 0:
|
||||
#身份损失
|
||||
i_loss = id_loss(img_gen, img_orig)[0]
|
||||
else:
|
||||
i_loss = 0
|
||||
|
||||
if args.mode == "edit":
|
||||
if args.work_in_stylespace:
|
||||
l2_loss = sum([((latent_code_init[c] - latent[c]) ** 2).sum() for c in range(len(latent_code_init))])
|
||||
else:
|
||||
#与潜在空间的L2距离
|
||||
l2_loss = ((latent_code_init - latent) ** 2).sum()
|
||||
loss = c_loss + args.l2_lambda * l2_loss + args.id_lambda * i_loss
|
||||
else:
|
||||
loss = c_loss
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
pbar.set_description(
|
||||
(
|
||||
f"loss: {loss.item():.4f};"
|
||||
)
|
||||
)
|
||||
if args.save_intermediate_image_every > 0 and i % args.save_intermediate_image_every == 0:
|
||||
with torch.no_grad():
|
||||
img_gen = g_ema(latent, c)
|
||||
|
||||
torchvision.utils.save_image(img_gen, f"results/stygan3Clip/{str(i).zfill(5)}.jpg", normalize=True, range=(-1, 1))
|
||||
|
||||
if args.mode == "edit":
|
||||
final_result = torch.cat([img_orig, img_gen])
|
||||
else:
|
||||
final_result = img_gen
|
||||
|
||||
return final_result
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--description", type=str, default="a person with purple hair", help="the text that guides the editing/generation")
|
||||
parser.add_argument("--ckpt", type=str, default="../pretrained_models/stylegan2-ffhq-config-f.pt", help="pretrained StyleGAN2 weights")
|
||||
parser.add_argument("--stylegan_size", type=int, default=1024, help="StyleGAN resolution")
|
||||
parser.add_argument("--lr_rampup", type=float, default=0.05)
|
||||
parser.add_argument("--lr", type=float, default=0.1)
|
||||
parser.add_argument("--step", type=int, default=300, help="number of optimization steps")
|
||||
parser.add_argument("--mode", type=str, default="edit", choices=["edit", "free_generation"], help="choose between edit an image an generate a free one")
|
||||
parser.add_argument("--l2_lambda", type=float, default=0.008, help="weight of the latent distance (used for editing only)")
|
||||
parser.add_argument("--id_lambda", type=float, default=0.000, help="weight of id loss (used for editing only)")
|
||||
parser.add_argument("--latent_path", type=str, default=None, help="starts the optimization from the given latent code if provided. Otherwose, starts from"
|
||||
"the mean latent in a free generation, and from a random one in editing. "
|
||||
"Expects a .pt format")
|
||||
parser.add_argument("--truncation", type=float, default=1, help="used only for the initial latent vector, and only when a latent code path is"
|
||||
"not provided")
|
||||
parser.add_argument('--work_in_stylespace', default=False, action='store_true')
|
||||
parser.add_argument("--save_intermediate_image_every", type=int, default=20, help="if > 0 then saves intermidate results during the optimization")
|
||||
parser.add_argument("--results_dir", type=str, default="results")
|
||||
parser.add_argument('--ir_se50_weights', default='../pretrained_models/model_ir_se50.pth', type=str,
|
||||
help="Path to facial recognition network used in ID loss")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
result_image = main(args)
|
||||
|
||||
torchvision.utils.save_image(result_image.detach().cpu(), os.path.join(args.results_dir, "final_result.jpg"), normalize=True, scale_each=True, range=(-1, 1))
|
||||
|
||||
|
||||
194
models/stylegan3/show_pkl.py
Normal file
194
models/stylegan3/show_pkl.py
Normal file
@ -0,0 +1,194 @@
|
||||
# show_pkl.py
|
||||
|
||||
import pickle
|
||||
import sys
|
||||
import torch
|
||||
sys.path.append('/home/ly/StyleCLIP-main/models/stylegan3/torch_utils')
|
||||
|
||||
#
|
||||
# path = '/home/ly/StyleCLIP-main/models/stylegan3/torch_utils/stylegan3-r-afhqv2-512x512.pkl' # path='/root/……/aus_openface.pkl' pkl文件所在路径
|
||||
#
|
||||
# f = open(path, 'rb')
|
||||
# data = pickle.load(f)
|
||||
#
|
||||
# print(data)
|
||||
# print(len(data))
|
||||
# print(data.shape)
|
||||
|
||||
with open('/home/ly/StyleCLIP-main/models/stylegan3/torch_utils/stylegan3-r-afhqv2-512x512.pkl', 'rb') as f:
|
||||
G = pickle.load(f)['G_ema'].cuda() # torch.nn.Module
|
||||
z = torch.randn([1, G.z_dim]).cuda() # latent codes
|
||||
c = None # class labels (not used in this example)
|
||||
img = G(z, c) # NCHW, float32, dynamic range [-1, +1], no truncation
|
||||
print(G)
|
||||
|
||||
|
||||
#输出
|
||||
# Generator(
|
||||
# (synthesis): SynthesisNetwork(
|
||||
# w_dim=512, num_ws=16,
|
||||
# img_resolution=512, img_channels=3,
|
||||
# num_layers=14, num_critical=2,
|
||||
# margin_size=10, num_fp16_res=4
|
||||
# (input): SynthesisInput(
|
||||
# w_dim=512, channels=1024, size=[36, 36],
|
||||
# sampling_rate=16, bandwidth=2
|
||||
# (affine): FullyConnectedLayer(in_features=512, out_features=4, activation=linear)
|
||||
# )
|
||||
# (L0_36_1024): SynthesisLayer(
|
||||
# w_dim=512, is_torgb=False,
|
||||
# is_critically_sampled=False, use_fp16=False,
|
||||
# in_sampling_rate=16, out_sampling_rate=16,
|
||||
# in_cutoff=2, out_cutoff=2,
|
||||
# in_half_width=6, out_half_width=6,
|
||||
# in_size=[36, 36], out_size=[36, 36],
|
||||
# in_channels=1024, out_channels=1024
|
||||
# (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
|
||||
# )
|
||||
# (L1_36_1024): SynthesisLayer(
|
||||
# w_dim=512, is_torgb=False,
|
||||
# is_critically_sampled=False, use_fp16=False,
|
||||
# in_sampling_rate=16, out_sampling_rate=16,
|
||||
# in_cutoff=2, out_cutoff=2.99661,
|
||||
# in_half_width=6, out_half_width=5.00339,
|
||||
# in_size=[36, 36], out_size=[36, 36],
|
||||
# in_channels=1024, out_channels=1024
|
||||
# (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
|
||||
# )
|
||||
# (L2_52_1024): SynthesisLayer(
|
||||
# w_dim=512, is_torgb=False,
|
||||
# is_critically_sampled=False, use_fp16=False,
|
||||
# in_sampling_rate=16, out_sampling_rate=32,
|
||||
# in_cutoff=2.99661, out_cutoff=4.48985,
|
||||
# in_half_width=5.00339, out_half_width=11.5102,
|
||||
# in_size=[36, 36], out_size=[52, 52],
|
||||
# in_channels=1024, out_channels=1024
|
||||
# (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
|
||||
# )
|
||||
# (L3_52_1024): SynthesisLayer(
|
||||
# w_dim=512, is_torgb=False,
|
||||
# is_critically_sampled=False, use_fp16=False,
|
||||
# in_sampling_rate=32, out_sampling_rate=32,
|
||||
# in_cutoff=4.48985, out_cutoff=6.72717,
|
||||
# in_half_width=11.5102, out_half_width=9.27283,
|
||||
# in_size=[52, 52], out_size=[52, 52],
|
||||
# in_channels=1024, out_channels=1024
|
||||
# (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
|
||||
# )
|
||||
# (L4_84_1024): SynthesisLayer(
|
||||
# w_dim=512, is_torgb=False,
|
||||
# is_critically_sampled=False, use_fp16=True,
|
||||
# in_sampling_rate=32, out_sampling_rate=64,
|
||||
# in_cutoff=6.72717, out_cutoff=10.0794,
|
||||
# in_half_width=9.27283, out_half_width=21.9206,
|
||||
# in_size=[52, 52], out_size=[84, 84],
|
||||
# in_channels=1024, out_channels=1024
|
||||
# (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
|
||||
# )
|
||||
# (L5_84_1024): SynthesisLayer(
|
||||
# w_dim=512, is_torgb=False,
|
||||
# is_critically_sampled=False, use_fp16=True,
|
||||
# in_sampling_rate=64, out_sampling_rate=64,
|
||||
# in_cutoff=10.0794, out_cutoff=15.102,
|
||||
# in_half_width=21.9206, out_half_width=16.898,
|
||||
# in_size=[84, 84], out_size=[84, 84],
|
||||
# in_channels=1024, out_channels=1024
|
||||
# (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
|
||||
# )
|
||||
# (L6_148_1024): SynthesisLayer(
|
||||
# w_dim=512, is_torgb=False,
|
||||
# is_critically_sampled=False, use_fp16=True,
|
||||
# in_sampling_rate=64, out_sampling_rate=128,
|
||||
# in_cutoff=15.102, out_cutoff=22.6274,
|
||||
# in_half_width=16.898, out_half_width=41.3726,
|
||||
# in_size=[84, 84], out_size=[148, 148],
|
||||
# in_channels=1024, out_channels=1024
|
||||
# (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
|
||||
# )
|
||||
# (L7_148_967): SynthesisLayer(
|
||||
# w_dim=512, is_torgb=False,
|
||||
# is_critically_sampled=False, use_fp16=True,
|
||||
# in_sampling_rate=128, out_sampling_rate=128,
|
||||
# in_cutoff=22.6274, out_cutoff=33.9028,
|
||||
# in_half_width=41.3726, out_half_width=30.0972,
|
||||
# in_size=[148, 148], out_size=[148, 148],
|
||||
# in_channels=1024, out_channels=967
|
||||
# (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
|
||||
# )
|
||||
# (L8_276_645): SynthesisLayer(
|
||||
# w_dim=512, is_torgb=False,
|
||||
# is_critically_sampled=False, use_fp16=True,
|
||||
# in_sampling_rate=128, out_sampling_rate=256,
|
||||
# in_cutoff=33.9028, out_cutoff=50.7968,
|
||||
# in_half_width=30.0972, out_half_width=77.2032,
|
||||
# in_size=[148, 148], out_size=[276, 276],
|
||||
# in_channels=967, out_channels=645
|
||||
# (affine): FullyConnectedLayer(in_features=512, out_features=967, activation=linear)
|
||||
# )
|
||||
# (L9_276_431): SynthesisLayer(
|
||||
# w_dim=512, is_torgb=False,
|
||||
# is_critically_sampled=False, use_fp16=True,
|
||||
# in_sampling_rate=256, out_sampling_rate=256,
|
||||
# in_cutoff=50.7968, out_cutoff=76.1093,
|
||||
# in_half_width=77.2032, out_half_width=51.8907,
|
||||
# in_size=[276, 276], out_size=[276, 276],
|
||||
# in_channels=645, out_channels=431
|
||||
# (affine): FullyConnectedLayer(in_features=512, out_features=645, activation=linear)
|
||||
# )
|
||||
# (L10_532_287): SynthesisLayer(
|
||||
# w_dim=512, is_torgb=False,
|
||||
# is_critically_sampled=False, use_fp16=True,
|
||||
# in_sampling_rate=256, out_sampling_rate=512,
|
||||
# in_cutoff=76.1093, out_cutoff=114.035,
|
||||
# in_half_width=51.8907, out_half_width=141.965,
|
||||
# in_size=[276, 276], out_size=[532, 532],
|
||||
# in_channels=431, out_channels=287
|
||||
# (affine): FullyConnectedLayer(in_features=512, out_features=431, activation=linear)
|
||||
# )
|
||||
# (L11_532_192): SynthesisLayer(
|
||||
# w_dim=512, is_torgb=False,
|
||||
# is_critically_sampled=False, use_fp16=True,
|
||||
# in_sampling_rate=512, out_sampling_rate=512,
|
||||
# in_cutoff=114.035, out_cutoff=170.86,
|
||||
# in_half_width=141.965, out_half_width=85.1405,
|
||||
# in_size=[532, 532], out_size=[532, 532],
|
||||
# in_channels=287, out_channels=192
|
||||
# (affine): FullyConnectedLayer(in_features=512, out_features=287, activation=linear)
|
||||
# )
|
||||
# (L12_532_128): SynthesisLayer(
|
||||
# w_dim=512, is_torgb=False,
|
||||
# is_critically_sampled=True, use_fp16=True,
|
||||
# in_sampling_rate=512, out_sampling_rate=512,
|
||||
# in_cutoff=170.86, out_cutoff=256,
|
||||
# in_half_width=85.1405, out_half_width=59.173,
|
||||
# in_size=[532, 532], out_size=[532, 532],
|
||||
# in_channels=192, out_channels=128
|
||||
# (affine): FullyConnectedLayer(in_features=512, out_features=192, activation=linear)
|
||||
# )
|
||||
# (L13_512_128): SynthesisLayer(
|
||||
# w_dim=512, is_torgb=False,
|
||||
# is_critically_sampled=True, use_fp16=True,
|
||||
# in_sampling_rate=512, out_sampling_rate=512,
|
||||
# in_cutoff=256, out_cutoff=256,
|
||||
# in_half_width=59.173, out_half_width=59.173,
|
||||
# in_size=[532, 532], out_size=[512, 512],
|
||||
# in_channels=128, out_channels=128
|
||||
# (affine): FullyConnectedLayer(in_features=512, out_features=128, activation=linear)
|
||||
# )
|
||||
# (L14_512_3): SynthesisLayer(
|
||||
# w_dim=512, is_torgb=True,
|
||||
# is_critically_sampled=True, use_fp16=True,
|
||||
# in_sampling_rate=512, out_sampling_rate=512,
|
||||
# in_cutoff=256, out_cutoff=256,
|
||||
# in_half_width=59.173, out_half_width=59.173,
|
||||
# in_size=[512, 512], out_size=[512, 512],
|
||||
# in_channels=128, out_channels=3
|
||||
# (affine): FullyConnectedLayer(in_features=512, out_features=128, activation=linear)
|
||||
# )
|
||||
# )
|
||||
# (mapping): MappingNetwork(
|
||||
# z_dim=512, c_dim=0, w_dim=512, num_ws=16
|
||||
# (fc0): FullyConnectedLayer(in_features=512, out_features=512, activation=lrelu)
|
||||
# (fc1): FullyConnectedLayer(in_features=512, out_features=512, activation=lrelu)
|
||||
# )
|
||||
# )
|
||||
37
models/stylegan3/test001_s3.py
Normal file
37
models/stylegan3/test001_s3.py
Normal file
@ -0,0 +1,37 @@
|
||||
import torchvision
|
||||
import argparse
|
||||
from argparse import Namespace
|
||||
from run_optimization3 import main
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
# parser.add_argument("--description", type=str, default="a person with purple hair",
|
||||
parser.add_argument("--description", type=str, default="a person with purple hair",
|
||||
help="the text that guides the editing/generation")
|
||||
parser.add_argument("--ckpt", type=str, default="/home/ly/StyleCLIP-main/pretrained_models/stylegan3-r-ffhqu-1024x1024.pkl",
|
||||
help="pretrained StyleGAN3 weights")
|
||||
parser.add_argument("--stylegan_size", type=int, default=1024, help="StyleGAN resolution")
|
||||
parser.add_argument("--lr_rampup", type=float, default=0.05)
|
||||
parser.add_argument("--lr", type=float, default=0.1)
|
||||
parser.add_argument("--step", type=int, default=300, help="number of optimization steps")
|
||||
parser.add_argument("--mode", type=str, default="edit", choices=["edit", "free_generation"],
|
||||
help="choose between edit an image an generate a free one")
|
||||
parser.add_argument("--l2_lambda", type=float, default=0.008,
|
||||
help="weight of the latent distance (used for editing only)")
|
||||
parser.add_argument("--latent_path", type=str, default=None, #"/home/ly/StyleCLIP-main/latents_test/example_celebs.pt"
|
||||
help="starts the optimization from the given latent code if provided. Otherwise, starts from"
|
||||
"the mean latent in a free generation, and from a random one in editing. "
|
||||
"Expects a .pt format")
|
||||
parser.add_argument("--truncation", type=float, default=0.5,
|
||||
help="used only for the initial latent vector, and only when a latent code path is"
|
||||
"not provided")
|
||||
parser.add_argument("--save_intermediate_image_every", type=int, default=20,
|
||||
help="if > 0 then saves intermidate results during the optimization")
|
||||
parser.add_argument("--results_dir", type=str, default="/home/ly/StyleCLIP-main/results/stygan3Clip")
|
||||
parser.add_argument('--work_in_stylespace', default=False, action='store_true', help="trains a mapper in S instead of W+")
|
||||
parser.add_argument('--ir_se50_weights', default='/home/ly/StyleCLIP-main/pretrained_models/model_ir_se50.pth', type=str, help="Path to facial recognition network used in ID loss")
|
||||
parser.add_argument('--id_lambda', default=0.10, type=float, help='ID loss multiplier factor')
|
||||
|
||||
args = vars(parser.parse_args())
|
||||
result_image = main(Namespace(**args))
|
||||
torchvision.utils.save_image(result_image.detach().cpu(), f"/home/ly/StyleCLIP-main/results/stygan3Clip/final_result.png", normalize=True, scale_each=True,
|
||||
range=(-1, 1))
|
||||
9
models/stylegan3/torch_utils/__init__.py
Normal file
9
models/stylegan3/torch_utils/__init__.py
Normal file
@ -0,0 +1,9 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
# empty
|
||||
157
models/stylegan3/torch_utils/custom_ops.py
Normal file
157
models/stylegan3/torch_utils/custom_ops.py
Normal file
@ -0,0 +1,157 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
import glob
|
||||
import hashlib
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import uuid
|
||||
|
||||
import torch
|
||||
import torch.utils.cpp_extension
|
||||
from torch.utils.file_baton import FileBaton
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Global options.
|
||||
|
||||
verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Internal helper funcs.
|
||||
|
||||
def _find_compiler_bindir():
|
||||
patterns = [
|
||||
'C:/Program Files*/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
||||
'C:/Program Files*/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
||||
'C:/Program Files*/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
||||
'C:/Program Files*/Microsoft Visual Studio */vc/bin',
|
||||
]
|
||||
for pattern in patterns:
|
||||
matches = sorted(glob.glob(pattern))
|
||||
if len(matches):
|
||||
return matches[-1]
|
||||
return None
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _get_mangled_gpu_name():
|
||||
name = torch.cuda.get_device_name().lower()
|
||||
out = []
|
||||
for c in name:
|
||||
if re.match('[a-z0-9_-]+', c):
|
||||
out.append(c)
|
||||
else:
|
||||
out.append('-')
|
||||
return ''.join(out)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Main entry point for compiling and loading C++/CUDA plugins.
|
||||
|
||||
_cached_plugins = dict()
|
||||
|
||||
def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
|
||||
assert verbosity in ['none', 'brief', 'full']
|
||||
if headers is None:
|
||||
headers = []
|
||||
if source_dir is not None:
|
||||
sources = [os.path.join(source_dir, fname) for fname in sources]
|
||||
headers = [os.path.join(source_dir, fname) for fname in headers]
|
||||
|
||||
# Already cached?
|
||||
if module_name in _cached_plugins:
|
||||
return _cached_plugins[module_name]
|
||||
|
||||
# Print status.
|
||||
if verbosity == 'full':
|
||||
print(f'Setting up PyTorch plugin "{module_name}"...')
|
||||
elif verbosity == 'brief':
|
||||
print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
|
||||
verbose_build = (verbosity == 'full')
|
||||
|
||||
# Compile and load.
|
||||
try: # pylint: disable=too-many-nested-blocks
|
||||
# Make sure we can find the necessary compiler binaries.
|
||||
if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
|
||||
compiler_bindir = _find_compiler_bindir()
|
||||
if compiler_bindir is None:
|
||||
raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
|
||||
os.environ['PATH'] += ';' + compiler_bindir
|
||||
|
||||
# Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
|
||||
# break the build or unnecessarily restrict what's available to nvcc.
|
||||
# Unset it to let nvcc decide based on what's available on the
|
||||
# machine.
|
||||
os.environ['TORCH_CUDA_ARCH_LIST'] = ''
|
||||
|
||||
# Incremental build md5sum trickery. Copies all the input source files
|
||||
# into a cached build directory under a combined md5 digest of the input
|
||||
# source files. Copying is done only if the combined digest has changed.
|
||||
# This keeps input file timestamps and filenames the same as in previous
|
||||
# extension builds, allowing for fast incremental rebuilds.
|
||||
#
|
||||
# This optimization is done only in case all the source files reside in
|
||||
# a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
|
||||
# environment variable is set (we take this as a signal that the user
|
||||
# actually cares about this.)
|
||||
#
|
||||
# EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work
|
||||
# around the *.cu dependency bug in ninja config.
|
||||
#
|
||||
all_source_files = sorted(sources + headers)
|
||||
all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
|
||||
if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
|
||||
|
||||
# Compute combined hash digest for all source files.
|
||||
hash_md5 = hashlib.md5()
|
||||
for src in all_source_files:
|
||||
with open(src, 'rb') as f:
|
||||
hash_md5.update(f.read())
|
||||
|
||||
# Select cached build directory name.
|
||||
source_digest = hash_md5.hexdigest()
|
||||
build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
|
||||
cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
|
||||
|
||||
if not os.path.isdir(cached_build_dir):
|
||||
tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
|
||||
os.makedirs(tmpdir)
|
||||
for src in all_source_files:
|
||||
shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
|
||||
try:
|
||||
os.replace(tmpdir, cached_build_dir) # atomic
|
||||
except OSError:
|
||||
# source directory already exists, delete tmpdir and its contents.
|
||||
shutil.rmtree(tmpdir)
|
||||
if not os.path.isdir(cached_build_dir): raise
|
||||
|
||||
# Compile.
|
||||
cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
|
||||
torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
|
||||
verbose=verbose_build, sources=cached_sources, **build_kwargs)
|
||||
else:
|
||||
torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
|
||||
|
||||
# Load.
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
except:
|
||||
if verbosity == 'brief':
|
||||
print('Failed!')
|
||||
raise
|
||||
|
||||
# Print status and add to cache dict.
|
||||
if verbosity == 'full':
|
||||
print(f'Done setting up PyTorch plugin "{module_name}".')
|
||||
elif verbosity == 'brief':
|
||||
print('Done.')
|
||||
_cached_plugins[module_name] = module
|
||||
return module
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
267
models/stylegan3/torch_utils/misc.py
Normal file
267
models/stylegan3/torch_utils/misc.py
Normal file
@ -0,0 +1,267 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
import re
|
||||
import contextlib
|
||||
import numpy as np
|
||||
import torch
|
||||
import warnings
|
||||
import dnnlib
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
|
||||
# same constant is used multiple times.
|
||||
|
||||
_constant_cache = dict()
|
||||
|
||||
def constant(value, shape=None, dtype=None, device=None, memory_format=None):
|
||||
value = np.asarray(value)
|
||||
if shape is not None:
|
||||
shape = tuple(shape)
|
||||
if dtype is None:
|
||||
dtype = torch.get_default_dtype()
|
||||
if device is None:
|
||||
device = torch.device('cpu')
|
||||
if memory_format is None:
|
||||
memory_format = torch.contiguous_format
|
||||
|
||||
key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
|
||||
tensor = _constant_cache.get(key, None)
|
||||
if tensor is None:
|
||||
tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
|
||||
if shape is not None:
|
||||
tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
|
||||
tensor = tensor.contiguous(memory_format=memory_format)
|
||||
_constant_cache[key] = tensor
|
||||
return tensor
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Replace NaN/Inf with specified numerical values.
|
||||
|
||||
try:
|
||||
nan_to_num = torch.nan_to_num # 1.8.0a0
|
||||
except AttributeError:
|
||||
def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
|
||||
assert isinstance(input, torch.Tensor)
|
||||
if posinf is None:
|
||||
posinf = torch.finfo(input.dtype).max
|
||||
if neginf is None:
|
||||
neginf = torch.finfo(input.dtype).min
|
||||
assert nan == 0
|
||||
return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Symbolic assert.
|
||||
|
||||
try:
|
||||
symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
|
||||
except AttributeError:
|
||||
symbolic_assert = torch.Assert # 1.7.0
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Context manager to temporarily suppress known warnings in torch.jit.trace().
|
||||
# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_tracer_warnings():
|
||||
flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
|
||||
warnings.filters.insert(0, flt)
|
||||
yield
|
||||
warnings.filters.remove(flt)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Assert that the shape of a tensor matches the given list of integers.
|
||||
# None indicates that the size of a dimension is allowed to vary.
|
||||
# Performs symbolic assertion when used in torch.jit.trace().
|
||||
|
||||
def assert_shape(tensor, ref_shape):
|
||||
#使用ndim报错:AttributeError: 'list' object has no attribute 'ndim'
|
||||
if tensor.ndim != len(ref_shape):
|
||||
raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
|
||||
for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
|
||||
if ref_size is None:
|
||||
pass
|
||||
elif isinstance(ref_size, torch.Tensor):
|
||||
with suppress_tracer_warnings(): # as_tensor results are registered as constants
|
||||
symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
|
||||
elif isinstance(size, torch.Tensor):
|
||||
with suppress_tracer_warnings(): # as_tensor results are registered as constants
|
||||
symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
|
||||
elif size != ref_size:
|
||||
raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Function decorator that calls torch.autograd.profiler.record_function().
|
||||
|
||||
def profiled_function(fn):
|
||||
def decorator(*args, **kwargs):
|
||||
with torch.autograd.profiler.record_function(fn.__name__):
|
||||
return fn(*args, **kwargs)
|
||||
decorator.__name__ = fn.__name__
|
||||
return decorator
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Sampler for torch.utils.data.DataLoader that loops over the dataset
|
||||
# indefinitely, shuffling items as it goes.
|
||||
|
||||
class InfiniteSampler(torch.utils.data.Sampler):
|
||||
def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
|
||||
assert len(dataset) > 0
|
||||
assert num_replicas > 0
|
||||
assert 0 <= rank < num_replicas
|
||||
assert 0 <= window_size <= 1
|
||||
super().__init__(dataset)
|
||||
self.dataset = dataset
|
||||
self.rank = rank
|
||||
self.num_replicas = num_replicas
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
self.window_size = window_size
|
||||
|
||||
def __iter__(self):
|
||||
order = np.arange(len(self.dataset))
|
||||
rnd = None
|
||||
window = 0
|
||||
if self.shuffle:
|
||||
rnd = np.random.RandomState(self.seed)
|
||||
rnd.shuffle(order)
|
||||
window = int(np.rint(order.size * self.window_size))
|
||||
|
||||
idx = 0
|
||||
while True:
|
||||
i = idx % order.size
|
||||
if idx % self.num_replicas == self.rank:
|
||||
yield order[i]
|
||||
if window >= 2:
|
||||
j = (i - rnd.randint(window)) % order.size
|
||||
order[i], order[j] = order[j], order[i]
|
||||
idx += 1
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Utilities for operating with torch.nn.Module parameters and buffers.
|
||||
|
||||
def params_and_buffers(module):
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
return list(module.parameters()) + list(module.buffers())
|
||||
|
||||
def named_params_and_buffers(module):
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
return list(module.named_parameters()) + list(module.named_buffers())
|
||||
|
||||
def copy_params_and_buffers(src_module, dst_module, require_all=False):
|
||||
assert isinstance(src_module, torch.nn.Module)
|
||||
assert isinstance(dst_module, torch.nn.Module)
|
||||
src_tensors = dict(named_params_and_buffers(src_module))
|
||||
for name, tensor in named_params_and_buffers(dst_module):
|
||||
assert (name in src_tensors) or (not require_all)
|
||||
if name in src_tensors:
|
||||
tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Context manager for easily enabling/disabling DistributedDataParallel
|
||||
# synchronization.
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ddp_sync(module, sync):
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
|
||||
yield
|
||||
else:
|
||||
with module.no_sync():
|
||||
yield
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Check DistributedDataParallel consistency across processes.
|
||||
|
||||
def check_ddp_consistency(module, ignore_regex=None):
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
for name, tensor in named_params_and_buffers(module):
|
||||
fullname = type(module).__name__ + '.' + name
|
||||
if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
|
||||
continue
|
||||
tensor = tensor.detach()
|
||||
if tensor.is_floating_point():
|
||||
tensor = nan_to_num(tensor)
|
||||
other = tensor.clone()
|
||||
torch.distributed.broadcast(tensor=other, src=0)
|
||||
assert (tensor == other).all(), fullname
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Print summary table of module hierarchy.
|
||||
|
||||
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
assert not isinstance(module, torch.jit.ScriptModule)
|
||||
assert isinstance(inputs, (tuple, list))
|
||||
|
||||
# Register hooks.
|
||||
entries = []
|
||||
nesting = [0]
|
||||
def pre_hook(_mod, _inputs):
|
||||
nesting[0] += 1
|
||||
def post_hook(mod, _inputs, outputs):
|
||||
nesting[0] -= 1
|
||||
if nesting[0] <= max_nesting:
|
||||
outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
|
||||
outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
|
||||
entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
|
||||
hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
|
||||
hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
|
||||
|
||||
# Run module.
|
||||
outputs = module(*inputs)
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
# Identify unique outputs, parameters, and buffers.
|
||||
tensors_seen = set()
|
||||
for e in entries:
|
||||
e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
|
||||
e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
|
||||
e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
|
||||
tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
|
||||
|
||||
# Filter out redundant entries.
|
||||
if skip_redundant:
|
||||
entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
|
||||
|
||||
# Construct table.
|
||||
rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
|
||||
rows += [['---'] * len(rows[0])]
|
||||
param_total = 0
|
||||
buffer_total = 0
|
||||
submodule_names = {mod: name for name, mod in module.named_modules()}
|
||||
for e in entries:
|
||||
name = '<top-level>' if e.mod is module else submodule_names[e.mod]
|
||||
param_size = sum(t.numel() for t in e.unique_params)
|
||||
buffer_size = sum(t.numel() for t in e.unique_buffers)
|
||||
output_shapes = [str(list(t.shape)) for t in e.outputs]
|
||||
output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
|
||||
rows += [[
|
||||
name + (':0' if len(e.outputs) >= 2 else ''),
|
||||
str(param_size) if param_size else '-',
|
||||
str(buffer_size) if buffer_size else '-',
|
||||
(output_shapes + ['-'])[0],
|
||||
(output_dtypes + ['-'])[0],
|
||||
]]
|
||||
for idx in range(1, len(e.outputs)):
|
||||
rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
|
||||
param_total += param_size
|
||||
buffer_total += buffer_size
|
||||
rows += [['---'] * len(rows[0])]
|
||||
rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
|
||||
|
||||
# Print table.
|
||||
widths = [max(len(cell) for cell in column) for column in zip(*rows)]
|
||||
print()
|
||||
for row in rows:
|
||||
print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
|
||||
print()
|
||||
return outputs
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
9
models/stylegan3/torch_utils/ops/__init__.py
Normal file
9
models/stylegan3/torch_utils/ops/__init__.py
Normal file
@ -0,0 +1,9 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
# empty
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
models/stylegan3/torch_utils/ops/__pycache__/fma.cpython-310.pyc
Normal file
BIN
models/stylegan3/torch_utils/ops/__pycache__/fma.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
99
models/stylegan3/torch_utils/ops/bias_act.cpp
Normal file
99
models/stylegan3/torch_utils/ops/bias_act.cpp
Normal file
@ -0,0 +1,99 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "bias_act.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
static bool has_same_layout(torch::Tensor x, torch::Tensor y)
|
||||
{
|
||||
if (x.dim() != y.dim())
|
||||
return false;
|
||||
for (int64_t i = 0; i < x.dim(); i++)
|
||||
{
|
||||
if (x.size(i) != y.size(i))
|
||||
return false;
|
||||
if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
|
||||
{
|
||||
// Validate arguments.
|
||||
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
||||
TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
|
||||
TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
|
||||
TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
|
||||
TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
|
||||
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
|
||||
TORCH_CHECK(b.dim() == 1, "b must have rank 1");
|
||||
TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
|
||||
TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
|
||||
TORCH_CHECK(grad >= 0, "grad must be non-negative");
|
||||
|
||||
// Validate layout.
|
||||
TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
|
||||
TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
|
||||
TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
|
||||
TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
|
||||
TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
|
||||
|
||||
// Create output tensor.
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
torch::Tensor y = torch::empty_like(x);
|
||||
TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
|
||||
|
||||
// Initialize CUDA kernel parameters.
|
||||
bias_act_kernel_params p;
|
||||
p.x = x.data_ptr();
|
||||
p.b = (b.numel()) ? b.data_ptr() : NULL;
|
||||
p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
|
||||
p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
|
||||
p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
|
||||
p.y = y.data_ptr();
|
||||
p.grad = grad;
|
||||
p.act = act;
|
||||
p.alpha = alpha;
|
||||
p.gain = gain;
|
||||
p.clamp = clamp;
|
||||
p.sizeX = (int)x.numel();
|
||||
p.sizeB = (int)b.numel();
|
||||
p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
|
||||
|
||||
// Choose CUDA kernel.
|
||||
void* kernel;
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
|
||||
{
|
||||
kernel = choose_bias_act_kernel<scalar_t>(p);
|
||||
});
|
||||
TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
|
||||
|
||||
// Launch CUDA kernel.
|
||||
p.loopX = 4;
|
||||
int blockSize = 4 * 32;
|
||||
int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
|
||||
void* args[] = {&p};
|
||||
AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
|
||||
return y;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("bias_act", &bias_act);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
173
models/stylegan3/torch_utils/ops/bias_act.cu
Normal file
173
models/stylegan3/torch_utils/ops/bias_act.cu
Normal file
@ -0,0 +1,173 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <c10/util/Half.h>
|
||||
#include "bias_act.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Helpers.
|
||||
|
||||
template <class T> struct InternalType;
|
||||
template <> struct InternalType<double> { typedef double scalar_t; };
|
||||
template <> struct InternalType<float> { typedef float scalar_t; };
|
||||
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel.
|
||||
|
||||
template <class T, int A>
|
||||
__global__ void bias_act_kernel(bias_act_kernel_params p)
|
||||
{
|
||||
typedef typename InternalType<T>::scalar_t scalar_t;
|
||||
int G = p.grad;
|
||||
scalar_t alpha = (scalar_t)p.alpha;
|
||||
scalar_t gain = (scalar_t)p.gain;
|
||||
scalar_t clamp = (scalar_t)p.clamp;
|
||||
scalar_t one = (scalar_t)1;
|
||||
scalar_t two = (scalar_t)2;
|
||||
scalar_t expRange = (scalar_t)80;
|
||||
scalar_t halfExpRange = (scalar_t)40;
|
||||
scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
|
||||
scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
|
||||
|
||||
// Loop over elements.
|
||||
int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
|
||||
for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
|
||||
{
|
||||
// Load.
|
||||
scalar_t x = (scalar_t)((const T*)p.x)[xi];
|
||||
scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
|
||||
scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
|
||||
scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
|
||||
scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
|
||||
scalar_t yy = (gain != 0) ? yref / gain : 0;
|
||||
scalar_t y = 0;
|
||||
|
||||
// Apply bias.
|
||||
((G == 0) ? x : xref) += b;
|
||||
|
||||
// linear
|
||||
if (A == 1)
|
||||
{
|
||||
if (G == 0) y = x;
|
||||
if (G == 1) y = x;
|
||||
}
|
||||
|
||||
// relu
|
||||
if (A == 2)
|
||||
{
|
||||
if (G == 0) y = (x > 0) ? x : 0;
|
||||
if (G == 1) y = (yy > 0) ? x : 0;
|
||||
}
|
||||
|
||||
// lrelu
|
||||
if (A == 3)
|
||||
{
|
||||
if (G == 0) y = (x > 0) ? x : x * alpha;
|
||||
if (G == 1) y = (yy > 0) ? x : x * alpha;
|
||||
}
|
||||
|
||||
// tanh
|
||||
if (A == 4)
|
||||
{
|
||||
if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
|
||||
if (G == 1) y = x * (one - yy * yy);
|
||||
if (G == 2) y = x * (one - yy * yy) * (-two * yy);
|
||||
}
|
||||
|
||||
// sigmoid
|
||||
if (A == 5)
|
||||
{
|
||||
if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
|
||||
if (G == 1) y = x * yy * (one - yy);
|
||||
if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
|
||||
}
|
||||
|
||||
// elu
|
||||
if (A == 6)
|
||||
{
|
||||
if (G == 0) y = (x >= 0) ? x : exp(x) - one;
|
||||
if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
|
||||
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
|
||||
}
|
||||
|
||||
// selu
|
||||
if (A == 7)
|
||||
{
|
||||
if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
|
||||
if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
|
||||
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
|
||||
}
|
||||
|
||||
// softplus
|
||||
if (A == 8)
|
||||
{
|
||||
if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
|
||||
if (G == 1) y = x * (one - exp(-yy));
|
||||
if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
|
||||
}
|
||||
|
||||
// swish
|
||||
if (A == 9)
|
||||
{
|
||||
if (G == 0)
|
||||
y = (x < -expRange) ? 0 : x / (exp(-x) + one);
|
||||
else
|
||||
{
|
||||
scalar_t c = exp(xref);
|
||||
scalar_t d = c + one;
|
||||
if (G == 1)
|
||||
y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
|
||||
else
|
||||
y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
|
||||
yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
|
||||
}
|
||||
}
|
||||
|
||||
// Apply gain.
|
||||
y *= gain * dy;
|
||||
|
||||
// Clamp.
|
||||
if (clamp >= 0)
|
||||
{
|
||||
if (G == 0)
|
||||
y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
|
||||
else
|
||||
y = (yref > -clamp & yref < clamp) ? y : 0;
|
||||
}
|
||||
|
||||
// Store.
|
||||
((T*)p.y)[xi] = (T)y;
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel selection.
|
||||
|
||||
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
|
||||
{
|
||||
if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
|
||||
if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
|
||||
if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
|
||||
if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
|
||||
if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
|
||||
if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
|
||||
if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
|
||||
if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
|
||||
if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
|
||||
return NULL;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Template specializations.
|
||||
|
||||
template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
|
||||
template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
|
||||
template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
38
models/stylegan3/torch_utils/ops/bias_act.h
Normal file
38
models/stylegan3/torch_utils/ops/bias_act.h
Normal file
@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel parameters.
|
||||
|
||||
struct bias_act_kernel_params
|
||||
{
|
||||
const void* x; // [sizeX]
|
||||
const void* b; // [sizeB] or NULL
|
||||
const void* xref; // [sizeX] or NULL
|
||||
const void* yref; // [sizeX] or NULL
|
||||
const void* dy; // [sizeX] or NULL
|
||||
void* y; // [sizeX]
|
||||
|
||||
int grad;
|
||||
int act;
|
||||
float alpha;
|
||||
float gain;
|
||||
float clamp;
|
||||
|
||||
int sizeX;
|
||||
int sizeB;
|
||||
int stepB;
|
||||
int loopX;
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel selection.
|
||||
|
||||
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
209
models/stylegan3/torch_utils/ops/bias_act.py
Normal file
209
models/stylegan3/torch_utils/ops/bias_act.py
Normal file
@ -0,0 +1,209 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Custom PyTorch ops for efficient bias and activation."""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import dnnlib
|
||||
|
||||
from .. import custom_ops
|
||||
from .. import misc
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
activation_funcs = {
|
||||
'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
|
||||
'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
|
||||
'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
|
||||
'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
|
||||
'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
|
||||
'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
|
||||
'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
|
||||
'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
|
||||
'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
|
||||
}
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_plugin = None
|
||||
_null_tensor = torch.empty([0])
|
||||
|
||||
def _init():
|
||||
global _plugin
|
||||
if _plugin is None:
|
||||
_plugin = custom_ops.get_plugin(
|
||||
module_name='bias_act_plugin',
|
||||
sources=['bias_act.cpp', 'bias_act.cu'],
|
||||
headers=['bias_act.h'],
|
||||
source_dir=os.path.dirname(__file__),
|
||||
extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
|
||||
)
|
||||
return True
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
|
||||
r"""Fused bias and activation function.
|
||||
|
||||
Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
|
||||
and scales the result by `gain`. Each of the steps is optional. In most cases,
|
||||
the fused op is considerably more efficient than performing the same calculation
|
||||
using standard PyTorch ops. It supports first and second order gradients,
|
||||
but not third order gradients.
|
||||
|
||||
Args:
|
||||
x: Input activation tensor. Can be of any shape.
|
||||
b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
|
||||
as `x`. The shape must be known, and it must match the dimension of `x`
|
||||
corresponding to `dim`.
|
||||
dim: The dimension in `x` corresponding to the elements of `b`.
|
||||
The value of `dim` is ignored if `b` is not specified.
|
||||
act: Name of the activation function to evaluate, or `"linear"` to disable.
|
||||
Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
|
||||
See `activation_funcs` for a full list. `None` is not allowed.
|
||||
alpha: Shape parameter for the activation function, or `None` to use the default.
|
||||
gain: Scaling factor for the output tensor, or `None` to use default.
|
||||
See `activation_funcs` for the default scaling of each activation function.
|
||||
If unsure, consider specifying 1.
|
||||
clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
|
||||
the clamping (default).
|
||||
impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
|
||||
|
||||
Returns:
|
||||
Tensor of the same shape and datatype as `x`.
|
||||
"""
|
||||
assert isinstance(x, torch.Tensor)
|
||||
assert impl in ['ref', 'cuda']
|
||||
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
||||
return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
|
||||
return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
|
||||
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
||||
"""Slow reference implementation of `bias_act()` using standard TensorFlow ops.
|
||||
"""
|
||||
assert isinstance(x, torch.Tensor)
|
||||
assert clamp is None or clamp >= 0
|
||||
spec = activation_funcs[act]
|
||||
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
||||
gain = float(gain if gain is not None else spec.def_gain)
|
||||
clamp = float(clamp if clamp is not None else -1)
|
||||
|
||||
# Add bias.
|
||||
if b is not None:
|
||||
assert isinstance(b, torch.Tensor) and b.ndim == 1
|
||||
assert 0 <= dim < x.ndim
|
||||
assert b.shape[0] == x.shape[dim]
|
||||
x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
|
||||
|
||||
# Evaluate activation function.
|
||||
alpha = float(alpha)
|
||||
x = spec.func(x, alpha=alpha)
|
||||
|
||||
# Scale by gain.
|
||||
gain = float(gain)
|
||||
if gain != 1:
|
||||
x = x * gain
|
||||
|
||||
# Clamp.
|
||||
if clamp >= 0:
|
||||
x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_bias_act_cuda_cache = dict()
|
||||
|
||||
def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
||||
"""Fast CUDA implementation of `bias_act()` using custom ops.
|
||||
"""
|
||||
# Parse arguments.
|
||||
assert clamp is None or clamp >= 0
|
||||
spec = activation_funcs[act]
|
||||
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
||||
gain = float(gain if gain is not None else spec.def_gain)
|
||||
clamp = float(clamp if clamp is not None else -1)
|
||||
|
||||
# Lookup from cache.
|
||||
key = (dim, act, alpha, gain, clamp)
|
||||
if key in _bias_act_cuda_cache:
|
||||
return _bias_act_cuda_cache[key]
|
||||
|
||||
# Forward op.
|
||||
class BiasActCuda(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, b): # pylint: disable=arguments-differ
|
||||
ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
|
||||
x = x.contiguous(memory_format=ctx.memory_format)
|
||||
b = b.contiguous() if b is not None else _null_tensor
|
||||
y = x
|
||||
if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
|
||||
y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
|
||||
ctx.save_for_backward(
|
||||
x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
|
||||
b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
|
||||
y if 'y' in spec.ref else _null_tensor)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dy): # pylint: disable=arguments-differ
|
||||
dy = dy.contiguous(memory_format=ctx.memory_format)
|
||||
x, b, y = ctx.saved_tensors
|
||||
dx = None
|
||||
db = None
|
||||
|
||||
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
|
||||
dx = dy
|
||||
if act != 'linear' or gain != 1 or clamp >= 0:
|
||||
dx = BiasActCudaGrad.apply(dy, x, b, y)
|
||||
|
||||
if ctx.needs_input_grad[1]:
|
||||
db = dx.sum([i for i in range(dx.ndim) if i != dim])
|
||||
|
||||
return dx, db
|
||||
|
||||
# Backward op.
|
||||
class BiasActCudaGrad(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
|
||||
ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
|
||||
dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
|
||||
ctx.save_for_backward(
|
||||
dy if spec.has_2nd_grad else _null_tensor,
|
||||
x, b, y)
|
||||
return dx
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, d_dx): # pylint: disable=arguments-differ
|
||||
d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
|
||||
dy, x, b, y = ctx.saved_tensors
|
||||
d_dy = None
|
||||
d_x = None
|
||||
d_b = None
|
||||
d_y = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
|
||||
|
||||
if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
|
||||
d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
|
||||
|
||||
if spec.has_2nd_grad and ctx.needs_input_grad[2]:
|
||||
d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
|
||||
|
||||
return d_dy, d_x, d_b, d_y
|
||||
|
||||
# Add to cache.
|
||||
_bias_act_cuda_cache[key] = BiasActCuda
|
||||
return BiasActCuda
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
203
models/stylegan3/torch_utils/ops/conv2d_gradfix.py
Normal file
203
models/stylegan3/torch_utils/ops/conv2d_gradfix.py
Normal file
@ -0,0 +1,203 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Custom replacement for `torch.nn.functional.conv2d` that supports
|
||||
arbitrarily high order gradients with zero performance penalty."""
|
||||
|
||||
import contextlib
|
||||
import torch
|
||||
from pkg_resources import parse_version
|
||||
|
||||
# pylint: disable=redefined-builtin
|
||||
# pylint: disable=arguments-differ
|
||||
# pylint: disable=protected-access
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
enabled = False # Enable the custom op by setting this to true.
|
||||
weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
|
||||
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
|
||||
|
||||
@contextlib.contextmanager
|
||||
def no_weight_gradients(disable=True):
|
||||
global weight_gradients_disabled
|
||||
old = weight_gradients_disabled
|
||||
if disable:
|
||||
weight_gradients_disabled = True
|
||||
yield
|
||||
weight_gradients_disabled = old
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||
if _should_use_custom_op(input):
|
||||
return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
|
||||
return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
|
||||
def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
|
||||
if _should_use_custom_op(input):
|
||||
return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
|
||||
return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _should_use_custom_op(input):
|
||||
assert isinstance(input, torch.Tensor)
|
||||
if (not enabled) or (not torch.backends.cudnn.enabled):
|
||||
return False
|
||||
if _use_pytorch_1_11_api:
|
||||
# The work-around code doesn't work on PyTorch 1.11.0 onwards
|
||||
return False
|
||||
if input.device.type != 'cuda':
|
||||
return False
|
||||
return True
|
||||
|
||||
def _tuple_of_ints(xs, ndim):
|
||||
xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
|
||||
assert len(xs) == ndim
|
||||
assert all(isinstance(x, int) for x in xs)
|
||||
return xs
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_conv2d_gradfix_cache = dict()
|
||||
_null_tensor = torch.empty([0])
|
||||
|
||||
def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
|
||||
# Parse arguments.
|
||||
ndim = 2
|
||||
weight_shape = tuple(weight_shape)
|
||||
stride = _tuple_of_ints(stride, ndim)
|
||||
padding = _tuple_of_ints(padding, ndim)
|
||||
output_padding = _tuple_of_ints(output_padding, ndim)
|
||||
dilation = _tuple_of_ints(dilation, ndim)
|
||||
|
||||
# Lookup from cache.
|
||||
key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
|
||||
if key in _conv2d_gradfix_cache:
|
||||
return _conv2d_gradfix_cache[key]
|
||||
|
||||
# Validate arguments.
|
||||
assert groups >= 1
|
||||
assert len(weight_shape) == ndim + 2
|
||||
assert all(stride[i] >= 1 for i in range(ndim))
|
||||
assert all(padding[i] >= 0 for i in range(ndim))
|
||||
assert all(dilation[i] >= 0 for i in range(ndim))
|
||||
if not transpose:
|
||||
assert all(output_padding[i] == 0 for i in range(ndim))
|
||||
else: # transpose
|
||||
assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
|
||||
|
||||
# Helpers.
|
||||
common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
def calc_output_padding(input_shape, output_shape):
|
||||
if transpose:
|
||||
return [0, 0]
|
||||
return [
|
||||
input_shape[i + 2]
|
||||
- (output_shape[i + 2] - 1) * stride[i]
|
||||
- (1 - 2 * padding[i])
|
||||
- dilation[i] * (weight_shape[i + 2] - 1)
|
||||
for i in range(ndim)
|
||||
]
|
||||
|
||||
# Forward & backward.
|
||||
class Conv2d(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight, bias):
|
||||
assert weight.shape == weight_shape
|
||||
ctx.save_for_backward(
|
||||
input if weight.requires_grad else _null_tensor,
|
||||
weight if input.requires_grad else _null_tensor,
|
||||
)
|
||||
ctx.input_shape = input.shape
|
||||
|
||||
# Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
|
||||
if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
|
||||
a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
|
||||
b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
|
||||
c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
|
||||
c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
|
||||
c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
|
||||
return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
|
||||
|
||||
# General case => cuDNN.
|
||||
if transpose:
|
||||
return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
|
||||
return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, weight = ctx.saved_tensors
|
||||
input_shape = ctx.input_shape
|
||||
grad_input = None
|
||||
grad_weight = None
|
||||
grad_bias = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
|
||||
op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
|
||||
grad_input = op.apply(grad_output, weight, None)
|
||||
assert grad_input.shape == input_shape
|
||||
|
||||
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
|
||||
grad_weight = Conv2dGradWeight.apply(grad_output, input)
|
||||
assert grad_weight.shape == weight_shape
|
||||
|
||||
if ctx.needs_input_grad[2]:
|
||||
grad_bias = grad_output.sum([0, 2, 3])
|
||||
|
||||
return grad_input, grad_weight, grad_bias
|
||||
|
||||
# Gradient with respect to the weights.
|
||||
class Conv2dGradWeight(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, grad_output, input):
|
||||
ctx.save_for_backward(
|
||||
grad_output if input.requires_grad else _null_tensor,
|
||||
input if grad_output.requires_grad else _null_tensor,
|
||||
)
|
||||
ctx.grad_output_shape = grad_output.shape
|
||||
ctx.input_shape = input.shape
|
||||
|
||||
# Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
|
||||
if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
|
||||
a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
|
||||
b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
|
||||
c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
|
||||
return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
|
||||
|
||||
# General case => cuDNN.
|
||||
name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight'
|
||||
flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
|
||||
return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad2_grad_weight):
|
||||
grad_output, input = ctx.saved_tensors
|
||||
grad_output_shape = ctx.grad_output_shape
|
||||
input_shape = ctx.input_shape
|
||||
grad2_grad_output = None
|
||||
grad2_input = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
|
||||
assert grad2_grad_output.shape == grad_output_shape
|
||||
|
||||
if ctx.needs_input_grad[1]:
|
||||
p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
|
||||
op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
|
||||
grad2_input = op.apply(grad_output, grad2_grad_weight, None)
|
||||
assert grad2_input.shape == input_shape
|
||||
|
||||
return grad2_grad_output, grad2_input
|
||||
|
||||
_conv2d_gradfix_cache[key] = Conv2d
|
||||
return Conv2d
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
143
models/stylegan3/torch_utils/ops/conv2d_resample.py
Normal file
143
models/stylegan3/torch_utils/ops/conv2d_resample.py
Normal file
@ -0,0 +1,143 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""2D convolution with optional up/downsampling."""
|
||||
|
||||
import torch
|
||||
|
||||
from .. import misc
|
||||
from . import conv2d_gradfix
|
||||
from . import upfirdn2d
|
||||
from .upfirdn2d import _parse_padding
|
||||
from .upfirdn2d import _get_filter_size
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _get_weight_shape(w):
|
||||
with misc.suppress_tracer_warnings(): # this value will be treated as a constant
|
||||
shape = [int(sz) for sz in w.shape]
|
||||
misc.assert_shape(w, shape)
|
||||
return shape
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
|
||||
"""Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
|
||||
"""
|
||||
_out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
|
||||
|
||||
# Flip weight if requested.
|
||||
# Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
|
||||
if not flip_weight and (kw > 1 or kh > 1):
|
||||
w = w.flip([2, 3])
|
||||
|
||||
# Execute using conv2d_gradfix.
|
||||
op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
|
||||
return op(x, w, stride=stride, padding=padding, groups=groups)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
|
||||
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
|
||||
r"""2D convolution with optional up/downsampling.
|
||||
|
||||
Padding is performed only once at the beginning, not between the operations.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape
|
||||
`[batch_size, in_channels, in_height, in_width]`.
|
||||
w: Weight tensor of shape
|
||||
`[out_channels, in_channels//groups, kernel_height, kernel_width]`.
|
||||
f: Low-pass filter for up/downsampling. Must be prepared beforehand by
|
||||
calling upfirdn2d.setup_filter(). None = identity (default).
|
||||
up: Integer upsampling factor (default: 1).
|
||||
down: Integer downsampling factor (default: 1).
|
||||
padding: Padding with respect to the upsampled image. Can be a single number
|
||||
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
||||
(default: 0).
|
||||
groups: Split input channels into N groups (default: 1).
|
||||
flip_weight: False = convolution, True = correlation (default: True).
|
||||
flip_filter: False = convolution, True = correlation (default: False).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
||||
"""
|
||||
# Validate arguments.
|
||||
assert isinstance(x, torch.Tensor) and (x.ndim == 4)
|
||||
assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
|
||||
assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
|
||||
assert isinstance(up, int) and (up >= 1)
|
||||
assert isinstance(down, int) and (down >= 1)
|
||||
assert isinstance(groups, int) and (groups >= 1)
|
||||
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
||||
fw, fh = _get_filter_size(f)
|
||||
px0, px1, py0, py1 = _parse_padding(padding)
|
||||
|
||||
# Adjust padding to account for up/downsampling.
|
||||
if up > 1:
|
||||
px0 += (fw + up - 1) // 2
|
||||
px1 += (fw - up) // 2
|
||||
py0 += (fh + up - 1) // 2
|
||||
py1 += (fh - up) // 2
|
||||
if down > 1:
|
||||
px0 += (fw - down + 1) // 2
|
||||
px1 += (fw - down) // 2
|
||||
py0 += (fh - down + 1) // 2
|
||||
py1 += (fh - down) // 2
|
||||
|
||||
# Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
|
||||
if kw == 1 and kh == 1 and (down > 1 and up == 1):
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
|
||||
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
||||
return x
|
||||
|
||||
# Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
|
||||
if kw == 1 and kh == 1 and (up > 1 and down == 1):
|
||||
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
|
||||
return x
|
||||
|
||||
# Fast path: downsampling only => use strided convolution.
|
||||
if down > 1 and up == 1:
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
|
||||
x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
|
||||
return x
|
||||
|
||||
# Fast path: upsampling with optional downsampling => use transpose strided convolution.
|
||||
if up > 1:
|
||||
if groups == 1:
|
||||
w = w.transpose(0, 1)
|
||||
else:
|
||||
w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
|
||||
w = w.transpose(1, 2)
|
||||
w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
|
||||
px0 -= kw - 1
|
||||
px1 -= kw - up
|
||||
py0 -= kh - 1
|
||||
py1 -= kh - up
|
||||
pxt = max(min(-px0, -px1), 0)
|
||||
pyt = max(min(-py0, -py1), 0)
|
||||
x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
|
||||
if down > 1:
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
||||
return x
|
||||
|
||||
# Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
|
||||
if up == 1 and down == 1:
|
||||
if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
|
||||
return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
|
||||
|
||||
# Fallback: Generic reference implementation.
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
|
||||
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
||||
if down > 1:
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
300
models/stylegan3/torch_utils/ops/filtered_lrelu.cpp
Normal file
300
models/stylegan3/torch_utils/ops/filtered_lrelu.cpp
Normal file
@ -0,0 +1,300 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "filtered_lrelu.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
|
||||
torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
|
||||
int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
|
||||
{
|
||||
// Set CUDA device.
|
||||
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
|
||||
// Validate arguments.
|
||||
TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
|
||||
TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
|
||||
TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
|
||||
TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
|
||||
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
|
||||
TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
|
||||
TORCH_CHECK(x.numel() > 0, "x is empty");
|
||||
TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
|
||||
TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
|
||||
TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
|
||||
TORCH_CHECK(fu.numel() > 0, "fu is empty");
|
||||
TORCH_CHECK(fd.numel() > 0, "fd is empty");
|
||||
TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
|
||||
TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
|
||||
|
||||
// Figure out how much shared memory is available on the device.
|
||||
int maxSharedBytes = 0;
|
||||
AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
|
||||
int sharedKB = maxSharedBytes >> 10;
|
||||
|
||||
// Populate enough launch parameters to check if a CUDA kernel exists.
|
||||
filtered_lrelu_kernel_params p;
|
||||
p.up = up;
|
||||
p.down = down;
|
||||
p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
|
||||
p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
|
||||
filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
|
||||
if (!test_spec.exec)
|
||||
{
|
||||
// No kernel found - return empty tensors and indicate missing kernel with return code of -1.
|
||||
return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
|
||||
}
|
||||
|
||||
// Input/output element size.
|
||||
int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
|
||||
|
||||
// Input sizes.
|
||||
int64_t xw = (int)x.size(3);
|
||||
int64_t xh = (int)x.size(2);
|
||||
int64_t fut_w = (int)fu.size(-1) - 1;
|
||||
int64_t fut_h = (int)fu.size(0) - 1;
|
||||
int64_t fdt_w = (int)fd.size(-1) - 1;
|
||||
int64_t fdt_h = (int)fd.size(0) - 1;
|
||||
|
||||
// Logical size of upsampled buffer.
|
||||
int64_t cw = xw * up + (px0 + px1) - fut_w;
|
||||
int64_t ch = xh * up + (py0 + py1) - fut_h;
|
||||
TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
|
||||
TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
|
||||
|
||||
// Compute output size and allocate.
|
||||
int64_t yw = (cw - fdt_w + (down - 1)) / down;
|
||||
int64_t yh = (ch - fdt_h + (down - 1)) / down;
|
||||
TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
|
||||
TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
|
||||
torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
|
||||
|
||||
// Allocate sign tensor.
|
||||
torch::Tensor so;
|
||||
torch::Tensor s = si;
|
||||
bool readSigns = !!s.numel();
|
||||
int64_t sw_active = 0; // Active width of sign tensor.
|
||||
if (writeSigns)
|
||||
{
|
||||
sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
|
||||
int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
|
||||
int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
|
||||
TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
|
||||
s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
|
||||
}
|
||||
else if (readSigns)
|
||||
sw_active = s.size(3) << 2;
|
||||
|
||||
// Validate sign tensor if in use.
|
||||
if (readSigns || writeSigns)
|
||||
{
|
||||
TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
|
||||
TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
|
||||
TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
|
||||
TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
|
||||
TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
|
||||
TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
|
||||
}
|
||||
|
||||
// Populate rest of CUDA kernel parameters.
|
||||
p.x = x.data_ptr();
|
||||
p.y = y.data_ptr();
|
||||
p.b = b.data_ptr();
|
||||
p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
|
||||
p.fu = fu.data_ptr<float>();
|
||||
p.fd = fd.data_ptr<float>();
|
||||
p.pad0 = make_int2(px0, py0);
|
||||
p.gain = gain;
|
||||
p.slope = slope;
|
||||
p.clamp = clamp;
|
||||
p.flip = (flip_filters) ? 1 : 0;
|
||||
p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
|
||||
p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
|
||||
p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
|
||||
p.sOfs = make_int2(sx, sy);
|
||||
p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
|
||||
|
||||
// x, y, b strides are in bytes.
|
||||
p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
|
||||
p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
|
||||
p.bStride = sz * b.stride(0);
|
||||
|
||||
// fu, fd strides are in elements.
|
||||
p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
|
||||
p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
|
||||
|
||||
// Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
|
||||
bool index64b = false;
|
||||
if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
|
||||
if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
|
||||
if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
|
||||
if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
|
||||
if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
|
||||
if (s.numel() > INT_MAX) index64b = true;
|
||||
|
||||
// Choose CUDA kernel.
|
||||
filtered_lrelu_kernel_spec spec = { 0 };
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
|
||||
{
|
||||
if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
|
||||
{
|
||||
// Choose kernel based on index type, datatype and sign read/write modes.
|
||||
if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true, false>(p, sharedKB);
|
||||
else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true >(p, sharedKB);
|
||||
else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);
|
||||
else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true, false>(p, sharedKB);
|
||||
else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true >(p, sharedKB);
|
||||
else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);
|
||||
}
|
||||
});
|
||||
TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists.
|
||||
|
||||
// Launch CUDA kernel.
|
||||
void* args[] = {&p};
|
||||
int bx = spec.numWarps * 32;
|
||||
int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
|
||||
int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
|
||||
int gz = p.yShape.z * p.yShape.w;
|
||||
|
||||
// Repeat multiple horizontal tiles in a CTA?
|
||||
if (spec.xrep)
|
||||
{
|
||||
p.tilesXrep = spec.xrep;
|
||||
p.tilesXdim = gx;
|
||||
|
||||
gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
|
||||
std::swap(gx, gy);
|
||||
}
|
||||
else
|
||||
{
|
||||
p.tilesXrep = 0;
|
||||
p.tilesXdim = 0;
|
||||
}
|
||||
|
||||
// Launch filter setup kernel.
|
||||
AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
|
||||
|
||||
// Copy kernels to constant memory.
|
||||
if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true, false>(at::cuda::getCurrentCUDAStream())));
|
||||
else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters<false, true >(at::cuda::getCurrentCUDAStream())));
|
||||
else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));
|
||||
|
||||
// Set cache and shared memory configurations for main kernel.
|
||||
AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
|
||||
if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
|
||||
AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
|
||||
AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
|
||||
|
||||
// Launch main kernel.
|
||||
const int maxSubGz = 65535; // CUDA maximum for block z dimension.
|
||||
for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
|
||||
{
|
||||
p.blockZofs = zofs;
|
||||
int subGz = std::min(maxSubGz, gz - zofs);
|
||||
AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
|
||||
}
|
||||
|
||||
// Done.
|
||||
return std::make_tuple(y, so, 0);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
|
||||
{
|
||||
// Set CUDA device.
|
||||
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
|
||||
// Validate arguments.
|
||||
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
|
||||
TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
|
||||
TORCH_CHECK(x.numel() > 0, "x is empty");
|
||||
TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
|
||||
|
||||
// Output signs if we don't have sign input.
|
||||
torch::Tensor so;
|
||||
torch::Tensor s = si;
|
||||
bool readSigns = !!s.numel();
|
||||
if (writeSigns)
|
||||
{
|
||||
int64_t sw = x.size(3);
|
||||
sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
|
||||
s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
|
||||
}
|
||||
|
||||
// Validate sign tensor if in use.
|
||||
if (readSigns || writeSigns)
|
||||
{
|
||||
TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
|
||||
TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
|
||||
TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
|
||||
TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
|
||||
TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
|
||||
TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
|
||||
}
|
||||
|
||||
// Initialize CUDA kernel parameters.
|
||||
filtered_lrelu_act_kernel_params p;
|
||||
p.x = x.data_ptr();
|
||||
p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
|
||||
p.gain = gain;
|
||||
p.slope = slope;
|
||||
p.clamp = clamp;
|
||||
p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
|
||||
p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
|
||||
p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
|
||||
p.sOfs = make_int2(sx, sy);
|
||||
|
||||
// Choose CUDA kernel.
|
||||
void* func = 0;
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
|
||||
{
|
||||
if (writeSigns)
|
||||
func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();
|
||||
else if (readSigns)
|
||||
func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();
|
||||
else
|
||||
func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
|
||||
});
|
||||
TORCH_CHECK(func, "internal error - CUDA kernel not found");
|
||||
|
||||
// Launch CUDA kernel.
|
||||
void* args[] = {&p};
|
||||
int bx = 128; // 4 warps per block.
|
||||
|
||||
// Logical size of launch = writeSigns ? p.s : p.x
|
||||
uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
|
||||
uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
|
||||
uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
|
||||
gx = (gx - 1) / bx + 1;
|
||||
|
||||
// Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
|
||||
const uint32_t gmax = 65535;
|
||||
gy = std::min(gy, gmax);
|
||||
gz = std::min(gz, gmax);
|
||||
|
||||
// Launch.
|
||||
AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
|
||||
return so;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
|
||||
m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
1284
models/stylegan3/torch_utils/ops/filtered_lrelu.cu
Normal file
1284
models/stylegan3/torch_utils/ops/filtered_lrelu.cu
Normal file
File diff suppressed because it is too large
Load Diff
90
models/stylegan3/torch_utils/ops/filtered_lrelu.h
Normal file
90
models/stylegan3/torch_utils/ops/filtered_lrelu.h
Normal file
@ -0,0 +1,90 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel parameters.
|
||||
|
||||
struct filtered_lrelu_kernel_params
|
||||
{
|
||||
// These parameters decide which kernel to use.
|
||||
int up; // upsampling ratio (1, 2, 4)
|
||||
int down; // downsampling ratio (1, 2, 4)
|
||||
int2 fuShape; // [size, 1] | [size, size]
|
||||
int2 fdShape; // [size, 1] | [size, size]
|
||||
|
||||
int _dummy; // Alignment.
|
||||
|
||||
// Rest of the parameters.
|
||||
const void* x; // Input tensor.
|
||||
void* y; // Output tensor.
|
||||
const void* b; // Bias tensor.
|
||||
unsigned char* s; // Sign tensor in/out. NULL if unused.
|
||||
const float* fu; // Upsampling filter.
|
||||
const float* fd; // Downsampling filter.
|
||||
|
||||
int2 pad0; // Left/top padding.
|
||||
float gain; // Additional gain factor.
|
||||
float slope; // Leaky ReLU slope on negative side.
|
||||
float clamp; // Clamp after nonlinearity.
|
||||
int flip; // Filter kernel flip for gradient computation.
|
||||
|
||||
int tilesXdim; // Original number of horizontal output tiles.
|
||||
int tilesXrep; // Number of horizontal tiles per CTA.
|
||||
int blockZofs; // Block z offset to support large minibatch, channel dimensions.
|
||||
|
||||
int4 xShape; // [width, height, channel, batch]
|
||||
int4 yShape; // [width, height, channel, batch]
|
||||
int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
|
||||
int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
|
||||
int swLimit; // Active width of sign tensor in bytes.
|
||||
|
||||
longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
|
||||
longlong4 yStride; //
|
||||
int64_t bStride; //
|
||||
longlong3 fuStride; //
|
||||
longlong3 fdStride; //
|
||||
};
|
||||
|
||||
struct filtered_lrelu_act_kernel_params
|
||||
{
|
||||
void* x; // Input/output, modified in-place.
|
||||
unsigned char* s; // Sign tensor in/out. NULL if unused.
|
||||
|
||||
float gain; // Additional gain factor.
|
||||
float slope; // Leaky ReLU slope on negative side.
|
||||
float clamp; // Clamp after nonlinearity.
|
||||
|
||||
int4 xShape; // [width, height, channel, batch]
|
||||
longlong4 xStride; // Input/output tensor strides, same order as in shape.
|
||||
int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
|
||||
int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel specialization.
|
||||
|
||||
struct filtered_lrelu_kernel_spec
|
||||
{
|
||||
void* setup; // Function for filter kernel setup.
|
||||
void* exec; // Function for main operation.
|
||||
int2 tileOut; // Width/height of launch tile.
|
||||
int numWarps; // Number of warps per thread block, determines launch block size.
|
||||
int xrep; // For processing multiple horizontal tiles per thread block.
|
||||
int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel selection.
|
||||
|
||||
template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
|
||||
template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
274
models/stylegan3/torch_utils/ops/filtered_lrelu.py
Normal file
274
models/stylegan3/torch_utils/ops/filtered_lrelu.py
Normal file
@ -0,0 +1,274 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import warnings
|
||||
|
||||
from .. import custom_ops
|
||||
from .. import misc
|
||||
from . import upfirdn2d
|
||||
from . import bias_act
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_plugin = None
|
||||
|
||||
def _init():
|
||||
global _plugin
|
||||
if _plugin is None:
|
||||
_plugin = custom_ops.get_plugin(
|
||||
module_name='filtered_lrelu_plugin',
|
||||
sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
|
||||
headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
|
||||
source_dir=os.path.dirname(__file__),
|
||||
extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
|
||||
)
|
||||
return True
|
||||
|
||||
def _get_filter_size(f):
|
||||
if f is None:
|
||||
return 1, 1
|
||||
assert isinstance(f, torch.Tensor)
|
||||
assert 1 <= f.ndim <= 2
|
||||
return f.shape[-1], f.shape[0] # width, height
|
||||
|
||||
def _parse_padding(padding):
|
||||
if isinstance(padding, int):
|
||||
padding = [padding, padding]
|
||||
assert isinstance(padding, (list, tuple))
|
||||
assert all(isinstance(x, (int, np.integer)) for x in padding)
|
||||
padding = [int(x) for x in padding]
|
||||
if len(padding) == 2:
|
||||
px, py = padding
|
||||
padding = [px, px, py, py]
|
||||
px0, px1, py0, py1 = padding
|
||||
return px0, px1, py0, py1
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
|
||||
r"""Filtered leaky ReLU for a batch of 2D images.
|
||||
|
||||
Performs the following sequence of operations for each channel:
|
||||
|
||||
1. Add channel-specific bias if provided (`b`).
|
||||
|
||||
2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
|
||||
|
||||
3. Pad the image with the specified number of zeros on each side (`padding`).
|
||||
Negative padding corresponds to cropping the image.
|
||||
|
||||
4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
|
||||
so that the footprint of all output pixels lies within the input image.
|
||||
|
||||
5. Multiply each value by the provided gain factor (`gain`).
|
||||
|
||||
6. Apply leaky ReLU activation function to each value.
|
||||
|
||||
7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.
|
||||
|
||||
8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
|
||||
it so that the footprint of all output pixels lies within the input image.
|
||||
|
||||
9. Downsample the image by keeping every Nth pixel (`down`).
|
||||
|
||||
The fused op is considerably more efficient than performing the same calculation
|
||||
using standard PyTorch ops. It supports gradients of arbitrary order.
|
||||
|
||||
Args:
|
||||
x: Float32/float16/float64 input tensor of the shape
|
||||
`[batch_size, num_channels, in_height, in_width]`.
|
||||
fu: Float32 upsampling FIR filter of the shape
|
||||
`[filter_height, filter_width]` (non-separable),
|
||||
`[filter_taps]` (separable), or
|
||||
`None` (identity).
|
||||
fd: Float32 downsampling FIR filter of the shape
|
||||
`[filter_height, filter_width]` (non-separable),
|
||||
`[filter_taps]` (separable), or
|
||||
`None` (identity).
|
||||
b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
|
||||
as `x`. The length of vector must must match the channel dimension of `x`.
|
||||
up: Integer upsampling factor (default: 1).
|
||||
down: Integer downsampling factor. (default: 1).
|
||||
padding: Padding with respect to the upsampled image. Can be a single number
|
||||
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
||||
(default: 0).
|
||||
gain: Overall scaling factor for signal magnitude (default: sqrt(2)).
|
||||
slope: Slope on the negative side of leaky ReLU (default: 0.2).
|
||||
clamp: Maximum magnitude for leaky ReLU output (default: None).
|
||||
flip_filter: False = convolution, True = correlation (default: False).
|
||||
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
||||
"""
|
||||
assert isinstance(x, torch.Tensor)
|
||||
assert impl in ['ref', 'cuda']
|
||||
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
||||
return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
|
||||
return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
|
||||
def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
|
||||
"""Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
|
||||
existing `upfirdn2n()` and `bias_act()` ops.
|
||||
"""
|
||||
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
||||
fu_w, fu_h = _get_filter_size(fu)
|
||||
fd_w, fd_h = _get_filter_size(fd)
|
||||
if b is not None:
|
||||
assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
|
||||
misc.assert_shape(b, [x.shape[1]])
|
||||
assert isinstance(up, int) and up >= 1
|
||||
assert isinstance(down, int) and down >= 1
|
||||
px0, px1, py0, py1 = _parse_padding(padding)
|
||||
assert gain == float(gain) and gain > 0
|
||||
assert slope == float(slope) and slope >= 0
|
||||
assert clamp is None or (clamp == float(clamp) and clamp >= 0)
|
||||
|
||||
# Calculate output size.
|
||||
batch_size, channels, in_h, in_w = x.shape
|
||||
in_dtype = x.dtype
|
||||
out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
|
||||
out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down
|
||||
|
||||
# Compute using existing ops.
|
||||
x = bias_act.bias_act(x=x, b=b) # Apply bias.
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
|
||||
x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
|
||||
x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.
|
||||
|
||||
# Check output shape & dtype.
|
||||
misc.assert_shape(x, [batch_size, channels, out_h, out_w])
|
||||
assert x.dtype == in_dtype
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_filtered_lrelu_cuda_cache = dict()
|
||||
|
||||
def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
|
||||
"""Fast CUDA implementation of `filtered_lrelu()` using custom ops.
|
||||
"""
|
||||
assert isinstance(up, int) and up >= 1
|
||||
assert isinstance(down, int) and down >= 1
|
||||
px0, px1, py0, py1 = _parse_padding(padding)
|
||||
assert gain == float(gain) and gain > 0
|
||||
gain = float(gain)
|
||||
assert slope == float(slope) and slope >= 0
|
||||
slope = float(slope)
|
||||
assert clamp is None or (clamp == float(clamp) and clamp >= 0)
|
||||
clamp = float(clamp if clamp is not None else 'inf')
|
||||
|
||||
# Lookup from cache.
|
||||
key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
|
||||
if key in _filtered_lrelu_cuda_cache:
|
||||
return _filtered_lrelu_cuda_cache[key]
|
||||
|
||||
# Forward op.
|
||||
class FilteredLReluCuda(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
|
||||
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
||||
|
||||
# Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
|
||||
if fu is None:
|
||||
fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
||||
if fd is None:
|
||||
fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
||||
assert 1 <= fu.ndim <= 2
|
||||
assert 1 <= fd.ndim <= 2
|
||||
|
||||
# Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
|
||||
if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
|
||||
fu = fu.square()[None]
|
||||
if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
|
||||
fd = fd.square()[None]
|
||||
|
||||
# Missing sign input tensor.
|
||||
if si is None:
|
||||
si = torch.empty([0])
|
||||
|
||||
# Missing bias tensor.
|
||||
if b is None:
|
||||
b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)
|
||||
|
||||
# Construct internal sign tensor only if gradients are needed.
|
||||
write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)
|
||||
|
||||
# Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
|
||||
strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
|
||||
if any(a < b for a, b in zip(strides[:-1], strides[1:])):
|
||||
warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)
|
||||
|
||||
# Call C++/Cuda plugin if datatype is supported.
|
||||
if x.dtype in [torch.float16, torch.float32]:
|
||||
if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
|
||||
warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
|
||||
y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
|
||||
else:
|
||||
return_code = -1
|
||||
|
||||
# No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
|
||||
# only the bit-packed sign tensor is retained for gradient computation.
|
||||
if return_code < 0:
|
||||
warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)
|
||||
|
||||
y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
|
||||
y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
|
||||
so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
|
||||
y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.
|
||||
|
||||
# Prepare for gradient computation.
|
||||
ctx.save_for_backward(fu, fd, (si if si.numel() else so))
|
||||
ctx.x_shape = x.shape
|
||||
ctx.y_shape = y.shape
|
||||
ctx.s_ofs = sx, sy
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dy): # pylint: disable=arguments-differ
|
||||
fu, fd, si = ctx.saved_tensors
|
||||
_, _, xh, xw = ctx.x_shape
|
||||
_, _, yh, yw = ctx.y_shape
|
||||
sx, sy = ctx.s_ofs
|
||||
dx = None # 0
|
||||
dfu = None; assert not ctx.needs_input_grad[1]
|
||||
dfd = None; assert not ctx.needs_input_grad[2]
|
||||
db = None # 3
|
||||
dsi = None; assert not ctx.needs_input_grad[4]
|
||||
dsx = None; assert not ctx.needs_input_grad[5]
|
||||
dsy = None; assert not ctx.needs_input_grad[6]
|
||||
|
||||
if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
|
||||
pp = [
|
||||
(fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
|
||||
xw * up - yw * down + px0 - (up - 1),
|
||||
(fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
|
||||
xh * up - yh * down + py0 - (up - 1),
|
||||
]
|
||||
gg = gain * (up ** 2) / (down ** 2)
|
||||
ff = (not flip_filter)
|
||||
sx = sx - (fu.shape[-1] - 1) + px0
|
||||
sy = sy - (fu.shape[0] - 1) + py0
|
||||
dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)
|
||||
|
||||
if ctx.needs_input_grad[3]:
|
||||
db = dx.sum([0, 2, 3])
|
||||
|
||||
return dx, dfu, dfd, db, dsi, dsx, dsy
|
||||
|
||||
# Add to cache.
|
||||
_filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
|
||||
return FilteredLReluCuda
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
27
models/stylegan3/torch_utils/ops/filtered_lrelu_ns.cu
Normal file
27
models/stylegan3/torch_utils/ops/filtered_lrelu_ns.cu
Normal file
@ -0,0 +1,27 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include "filtered_lrelu.cu"
|
||||
|
||||
// Template/kernel specializations for no signs mode (no gradients required).
|
||||
|
||||
// Full op, 32-bit indexing.
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
|
||||
// Full op, 64-bit indexing.
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
|
||||
// Activation/signs only for generic variant. 64-bit indexing.
|
||||
template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
|
||||
template void* choose_filtered_lrelu_act_kernel<float, false, false>(void);
|
||||
template void* choose_filtered_lrelu_act_kernel<double, false, false>(void);
|
||||
|
||||
// Copy filters to constant memory.
|
||||
template cudaError_t copy_filters<false, false>(cudaStream_t stream);
|
||||
27
models/stylegan3/torch_utils/ops/filtered_lrelu_rd.cu
Normal file
27
models/stylegan3/torch_utils/ops/filtered_lrelu_rd.cu
Normal file
@ -0,0 +1,27 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include "filtered_lrelu.cu"
|
||||
|
||||
// Template/kernel specializations for sign read mode.
|
||||
|
||||
// Full op, 32-bit indexing.
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
|
||||
// Full op, 64-bit indexing.
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
|
||||
// Activation/signs only for generic variant. 64-bit indexing.
|
||||
template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
|
||||
template void* choose_filtered_lrelu_act_kernel<float, false, true>(void);
|
||||
template void* choose_filtered_lrelu_act_kernel<double, false, true>(void);
|
||||
|
||||
// Copy filters to constant memory.
|
||||
template cudaError_t copy_filters<false, true>(cudaStream_t stream);
|
||||
27
models/stylegan3/torch_utils/ops/filtered_lrelu_wr.cu
Normal file
27
models/stylegan3/torch_utils/ops/filtered_lrelu_wr.cu
Normal file
@ -0,0 +1,27 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include "filtered_lrelu.cu"
|
||||
|
||||
// Template/kernel specializations for sign write mode.
|
||||
|
||||
// Full op, 32-bit indexing.
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
|
||||
// Full op, 64-bit indexing.
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
||||
|
||||
// Activation/signs only for generic variant. 64-bit indexing.
|
||||
template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
|
||||
template void* choose_filtered_lrelu_act_kernel<float, true, false>(void);
|
||||
template void* choose_filtered_lrelu_act_kernel<double, true, false>(void);
|
||||
|
||||
// Copy filters to constant memory.
|
||||
template cudaError_t copy_filters<true, false>(cudaStream_t stream);
|
||||
60
models/stylegan3/torch_utils/ops/fma.py
Normal file
60
models/stylegan3/torch_utils/ops/fma.py
Normal file
@ -0,0 +1,60 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
|
||||
|
||||
import torch
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def fma(a, b, c): # => a * b + c
|
||||
return _FusedMultiplyAdd.apply(a, b, c)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
|
||||
@staticmethod
|
||||
def forward(ctx, a, b, c): # pylint: disable=arguments-differ
|
||||
out = torch.addcmul(c, a, b)
|
||||
ctx.save_for_backward(a, b)
|
||||
ctx.c_shape = c.shape
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout): # pylint: disable=arguments-differ
|
||||
a, b = ctx.saved_tensors
|
||||
c_shape = ctx.c_shape
|
||||
da = None
|
||||
db = None
|
||||
dc = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
da = _unbroadcast(dout * b, a.shape)
|
||||
|
||||
if ctx.needs_input_grad[1]:
|
||||
db = _unbroadcast(dout * a, b.shape)
|
||||
|
||||
if ctx.needs_input_grad[2]:
|
||||
dc = _unbroadcast(dout, c_shape)
|
||||
|
||||
return da, db, dc
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _unbroadcast(x, shape):
|
||||
extra_dims = x.ndim - len(shape)
|
||||
assert extra_dims >= 0
|
||||
dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
|
||||
if len(dim):
|
||||
x = x.sum(dim=dim, keepdim=True)
|
||||
if extra_dims:
|
||||
x = x.reshape(-1, *x.shape[extra_dims+1:])
|
||||
assert x.shape == shape
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
86
models/stylegan3/torch_utils/ops/grid_sample_gradfix.py
Normal file
86
models/stylegan3/torch_utils/ops/grid_sample_gradfix.py
Normal file
@ -0,0 +1,86 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Custom replacement for `torch.nn.functional.grid_sample` that
|
||||
supports arbitrarily high order gradients between the input and output.
|
||||
Only works on 2D images and assumes
|
||||
`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
|
||||
|
||||
import torch
|
||||
from pkg_resources import parse_version
|
||||
|
||||
# pylint: disable=redefined-builtin
|
||||
# pylint: disable=arguments-differ
|
||||
# pylint: disable=protected-access
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
enabled = False # Enable the custom op by setting this to true.
|
||||
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
|
||||
_use_pytorch_1_12_api = parse_version(torch.__version__) >= parse_version('1.12.0a') # Allow prerelease builds of 1.12
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def grid_sample(input, grid):
|
||||
if _should_use_custom_op():
|
||||
return _GridSample2dForward.apply(input, grid)
|
||||
return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _should_use_custom_op():
|
||||
return enabled
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class _GridSample2dForward(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, grid):
|
||||
assert input.ndim == 4
|
||||
assert grid.ndim == 4
|
||||
output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
||||
ctx.save_for_backward(input, grid)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, grid = ctx.saved_tensors
|
||||
grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
|
||||
return grad_input, grad_grid
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class _GridSample2dBackward(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, grad_output, input, grid):
|
||||
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
|
||||
if _use_pytorch_1_12_api:
|
||||
op = op[0]
|
||||
if _use_pytorch_1_11_api:
|
||||
output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
|
||||
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
|
||||
else:
|
||||
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
|
||||
ctx.save_for_backward(grid)
|
||||
return grad_input, grad_grid
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad2_grad_input, grad2_grad_grid):
|
||||
_ = grad2_grad_grid # unused
|
||||
grid, = ctx.saved_tensors
|
||||
grad2_grad_output = None
|
||||
grad2_input = None
|
||||
grad2_grid = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
|
||||
|
||||
assert not ctx.needs_input_grad[2]
|
||||
return grad2_grad_output, grad2_input, grad2_grid
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
107
models/stylegan3/torch_utils/ops/upfirdn2d.cpp
Normal file
107
models/stylegan3/torch_utils/ops/upfirdn2d.cpp
Normal file
@ -0,0 +1,107 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "upfirdn2d.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
|
||||
{
|
||||
// Validate arguments.
|
||||
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
||||
TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
|
||||
TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
|
||||
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
|
||||
TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
|
||||
TORCH_CHECK(x.numel() > 0, "x has zero size");
|
||||
TORCH_CHECK(f.numel() > 0, "f has zero size");
|
||||
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
|
||||
TORCH_CHECK(f.dim() == 2, "f must be rank 2");
|
||||
TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
|
||||
TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
|
||||
TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
|
||||
TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
|
||||
|
||||
// Create output tensor.
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
|
||||
int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
|
||||
TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
|
||||
torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
|
||||
TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
|
||||
TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
|
||||
|
||||
// Initialize CUDA kernel parameters.
|
||||
upfirdn2d_kernel_params p;
|
||||
p.x = x.data_ptr();
|
||||
p.f = f.data_ptr<float>();
|
||||
p.y = y.data_ptr();
|
||||
p.up = make_int2(upx, upy);
|
||||
p.down = make_int2(downx, downy);
|
||||
p.pad0 = make_int2(padx0, pady0);
|
||||
p.flip = (flip) ? 1 : 0;
|
||||
p.gain = gain;
|
||||
p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
|
||||
p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
|
||||
p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
|
||||
p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
|
||||
p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
|
||||
p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
|
||||
p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
|
||||
p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
|
||||
|
||||
// Choose CUDA kernel.
|
||||
upfirdn2d_kernel_spec spec;
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
|
||||
{
|
||||
spec = choose_upfirdn2d_kernel<scalar_t>(p);
|
||||
});
|
||||
|
||||
// Set looping options.
|
||||
p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
|
||||
p.loopMinor = spec.loopMinor;
|
||||
p.loopX = spec.loopX;
|
||||
p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
|
||||
p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
|
||||
|
||||
// Compute grid size.
|
||||
dim3 blockSize, gridSize;
|
||||
if (spec.tileOutW < 0) // large
|
||||
{
|
||||
blockSize = dim3(4, 32, 1);
|
||||
gridSize = dim3(
|
||||
((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
|
||||
(p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
|
||||
p.launchMajor);
|
||||
}
|
||||
else // small
|
||||
{
|
||||
blockSize = dim3(256, 1, 1);
|
||||
gridSize = dim3(
|
||||
((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
|
||||
(p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
|
||||
p.launchMajor);
|
||||
}
|
||||
|
||||
// Launch CUDA kernel.
|
||||
void* args[] = {&p};
|
||||
AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
|
||||
return y;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("upfirdn2d", &upfirdn2d);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
384
models/stylegan3/torch_utils/ops/upfirdn2d.cu
Normal file
384
models/stylegan3/torch_utils/ops/upfirdn2d.cu
Normal file
@ -0,0 +1,384 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <c10/util/Half.h>
|
||||
#include "upfirdn2d.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Helpers.
|
||||
|
||||
template <class T> struct InternalType;
|
||||
template <> struct InternalType<double> { typedef double scalar_t; };
|
||||
template <> struct InternalType<float> { typedef float scalar_t; };
|
||||
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
|
||||
|
||||
static __device__ __forceinline__ int floor_div(int a, int b)
|
||||
{
|
||||
int t = 1 - a / b;
|
||||
return (a + t * b) / b - t;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Generic CUDA implementation for large filters.
|
||||
|
||||
template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
|
||||
{
|
||||
typedef typename InternalType<T>::scalar_t scalar_t;
|
||||
|
||||
// Calculate thread index.
|
||||
int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int outY = minorBase / p.launchMinor;
|
||||
minorBase -= outY * p.launchMinor;
|
||||
int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
|
||||
int majorBase = blockIdx.z * p.loopMajor;
|
||||
if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
|
||||
return;
|
||||
|
||||
// Setup Y receptive field.
|
||||
int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
|
||||
int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
|
||||
int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
|
||||
int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
|
||||
if (p.flip)
|
||||
filterY = p.filterSize.y - 1 - filterY;
|
||||
|
||||
// Loop over major, minor, and X.
|
||||
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
|
||||
for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
|
||||
{
|
||||
int nc = major * p.sizeMinor + minor;
|
||||
int n = nc / p.inSize.z;
|
||||
int c = nc - n * p.inSize.z;
|
||||
for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
|
||||
{
|
||||
// Setup X receptive field.
|
||||
int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
|
||||
int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
|
||||
int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
|
||||
int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
|
||||
if (p.flip)
|
||||
filterX = p.filterSize.x - 1 - filterX;
|
||||
|
||||
// Initialize pointers.
|
||||
const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
|
||||
const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
|
||||
int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
|
||||
int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
|
||||
|
||||
// Inner loop.
|
||||
scalar_t v = 0;
|
||||
for (int y = 0; y < h; y++)
|
||||
{
|
||||
for (int x = 0; x < w; x++)
|
||||
{
|
||||
v += (scalar_t)(*xp) * (scalar_t)(*fp);
|
||||
xp += p.inStride.x;
|
||||
fp += filterStepX;
|
||||
}
|
||||
xp += p.inStride.y - w * p.inStride.x;
|
||||
fp += filterStepY - w * filterStepX;
|
||||
}
|
||||
|
||||
// Store result.
|
||||
v *= p.gain;
|
||||
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Specialized CUDA implementation for small filters.
|
||||
|
||||
template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
|
||||
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
|
||||
{
|
||||
typedef typename InternalType<T>::scalar_t scalar_t;
|
||||
const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
|
||||
const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
|
||||
__shared__ volatile scalar_t sf[filterH][filterW];
|
||||
__shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
|
||||
|
||||
// Calculate tile index.
|
||||
int minorBase = blockIdx.x;
|
||||
int tileOutY = minorBase / p.launchMinor;
|
||||
minorBase -= tileOutY * p.launchMinor;
|
||||
minorBase *= loopMinor;
|
||||
tileOutY *= tileOutH;
|
||||
int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
|
||||
int majorBase = blockIdx.z * p.loopMajor;
|
||||
if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
|
||||
return;
|
||||
|
||||
// Load filter (flipped).
|
||||
for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
|
||||
{
|
||||
int fy = tapIdx / filterW;
|
||||
int fx = tapIdx - fy * filterW;
|
||||
scalar_t v = 0;
|
||||
if (fx < p.filterSize.x & fy < p.filterSize.y)
|
||||
{
|
||||
int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
|
||||
int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
|
||||
v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
|
||||
}
|
||||
sf[fy][fx] = v;
|
||||
}
|
||||
|
||||
// Loop over major and X.
|
||||
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
|
||||
{
|
||||
int baseNC = major * p.sizeMinor + minorBase;
|
||||
int n = baseNC / p.inSize.z;
|
||||
int baseC = baseNC - n * p.inSize.z;
|
||||
for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
|
||||
{
|
||||
// Load input pixels.
|
||||
int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
|
||||
int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
|
||||
int tileInX = floor_div(tileMidX, upx);
|
||||
int tileInY = floor_div(tileMidY, upy);
|
||||
__syncthreads();
|
||||
for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
|
||||
{
|
||||
int relC = inIdx;
|
||||
int relInX = relC / loopMinor;
|
||||
int relInY = relInX / tileInW;
|
||||
relC -= relInX * loopMinor;
|
||||
relInX -= relInY * tileInW;
|
||||
int c = baseC + relC;
|
||||
int inX = tileInX + relInX;
|
||||
int inY = tileInY + relInY;
|
||||
scalar_t v = 0;
|
||||
if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
|
||||
v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
|
||||
sx[relInY][relInX][relC] = v;
|
||||
}
|
||||
|
||||
// Loop over output pixels.
|
||||
__syncthreads();
|
||||
for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
|
||||
{
|
||||
int relC = outIdx;
|
||||
int relOutX = relC / loopMinor;
|
||||
int relOutY = relOutX / tileOutW;
|
||||
relC -= relOutX * loopMinor;
|
||||
relOutX -= relOutY * tileOutW;
|
||||
int c = baseC + relC;
|
||||
int outX = tileOutX + relOutX;
|
||||
int outY = tileOutY + relOutY;
|
||||
|
||||
// Setup receptive field.
|
||||
int midX = tileMidX + relOutX * downx;
|
||||
int midY = tileMidY + relOutY * downy;
|
||||
int inX = floor_div(midX, upx);
|
||||
int inY = floor_div(midY, upy);
|
||||
int relInX = inX - tileInX;
|
||||
int relInY = inY - tileInY;
|
||||
int filterX = (inX + 1) * upx - midX - 1; // flipped
|
||||
int filterY = (inY + 1) * upy - midY - 1; // flipped
|
||||
|
||||
// Inner loop.
|
||||
if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
|
||||
{
|
||||
scalar_t v = 0;
|
||||
#pragma unroll
|
||||
for (int y = 0; y < filterH / upy; y++)
|
||||
#pragma unroll
|
||||
for (int x = 0; x < filterW / upx; x++)
|
||||
v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
|
||||
v *= p.gain;
|
||||
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel selection.
|
||||
|
||||
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
|
||||
{
|
||||
int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
|
||||
upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
|
||||
if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
|
||||
|
||||
// No up/downsampling.
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
|
||||
if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
|
||||
if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
|
||||
if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
|
||||
if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
|
||||
}
|
||||
|
||||
// 2x upsampling.
|
||||
if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
|
||||
if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
|
||||
if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
|
||||
}
|
||||
if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
|
||||
if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
|
||||
if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
|
||||
}
|
||||
if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
|
||||
}
|
||||
|
||||
// 2x downsampling.
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 32,16,1>, 32,16,1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 32,16,1>, 32,16,1, 1};
|
||||
if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
|
||||
if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
|
||||
if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
|
||||
if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 16,16,1>, 16,16,1, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 16,16,1>, 16,16,1, 1};
|
||||
if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
|
||||
if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
|
||||
if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
|
||||
if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
|
||||
}
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
|
||||
if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
|
||||
if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
|
||||
}
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
|
||||
}
|
||||
|
||||
// 4x upsampling.
|
||||
if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 64,32,1>, 64,32,1, 1};
|
||||
if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 64,32,1>, 64,32,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 32,32,1>, 32,32,1, 1};
|
||||
if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 32,32,1>, 32,32,1, 1};
|
||||
}
|
||||
if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,8,1>, 128,8,1, 1};
|
||||
if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,8,1>, 128,8,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,1,16>, 128,1,16, 1};
|
||||
if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,1,16>, 128,1,16, 1};
|
||||
}
|
||||
if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 32,32,1>, 32,32,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 32,32,1>, 32,32,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 1,128,16>, 1,128,16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 1,128,16>, 1,128,16, 1};
|
||||
}
|
||||
|
||||
// 4x downsampling (inefficient).
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,8,1>, 32,8,1, 1};
|
||||
if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,8,1>, 32,8,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,1,8>, 32,1,8, 1};
|
||||
if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,1,8>, 32,1,8, 1};
|
||||
}
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 32,8,1>, 32,8,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 32,8,1>, 32,8,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 1,32,8>, 1,32,8, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 1,32,8>, 1,32,8, 1};
|
||||
}
|
||||
return spec;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Template specializations.
|
||||
|
||||
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
|
||||
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
|
||||
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
59
models/stylegan3/torch_utils/ops/upfirdn2d.h
Normal file
59
models/stylegan3/torch_utils/ops/upfirdn2d.h
Normal file
@ -0,0 +1,59 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel parameters.
|
||||
|
||||
struct upfirdn2d_kernel_params
|
||||
{
|
||||
const void* x;
|
||||
const float* f;
|
||||
void* y;
|
||||
|
||||
int2 up;
|
||||
int2 down;
|
||||
int2 pad0;
|
||||
int flip;
|
||||
float gain;
|
||||
|
||||
int4 inSize; // [width, height, channel, batch]
|
||||
int4 inStride;
|
||||
int2 filterSize; // [width, height]
|
||||
int2 filterStride;
|
||||
int4 outSize; // [width, height, channel, batch]
|
||||
int4 outStride;
|
||||
int sizeMinor;
|
||||
int sizeMajor;
|
||||
|
||||
int loopMinor;
|
||||
int loopMajor;
|
||||
int loopX;
|
||||
int launchMinor;
|
||||
int launchMajor;
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel specialization.
|
||||
|
||||
struct upfirdn2d_kernel_spec
|
||||
{
|
||||
void* kernel;
|
||||
int tileOutW;
|
||||
int tileOutH;
|
||||
int loopMinor;
|
||||
int loopX;
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel selection.
|
||||
|
||||
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
389
models/stylegan3/torch_utils/ops/upfirdn2d.py
Normal file
389
models/stylegan3/torch_utils/ops/upfirdn2d.py
Normal file
@ -0,0 +1,389 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Custom PyTorch ops for efficient resampling of 2D images."""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .. import custom_ops
|
||||
from .. import misc
|
||||
from . import conv2d_gradfix
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_plugin = None
|
||||
|
||||
def _init():
|
||||
global _plugin
|
||||
if _plugin is None:
|
||||
_plugin = custom_ops.get_plugin(
|
||||
module_name='upfirdn2d_plugin',
|
||||
sources=['upfirdn2d.cpp', 'upfirdn2d.cu'],
|
||||
headers=['upfirdn2d.h'],
|
||||
source_dir=os.path.dirname(__file__),
|
||||
extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
|
||||
)
|
||||
return True
|
||||
|
||||
def _parse_scaling(scaling):
|
||||
if isinstance(scaling, int):
|
||||
scaling = [scaling, scaling]
|
||||
assert isinstance(scaling, (list, tuple))
|
||||
assert all(isinstance(x, int) for x in scaling)
|
||||
sx, sy = scaling
|
||||
assert sx >= 1 and sy >= 1
|
||||
return sx, sy
|
||||
|
||||
def _parse_padding(padding):
|
||||
if isinstance(padding, int):
|
||||
padding = [padding, padding]
|
||||
assert isinstance(padding, (list, tuple))
|
||||
assert all(isinstance(x, int) for x in padding)
|
||||
if len(padding) == 2:
|
||||
padx, pady = padding
|
||||
padding = [padx, padx, pady, pady]
|
||||
padx0, padx1, pady0, pady1 = padding
|
||||
return padx0, padx1, pady0, pady1
|
||||
|
||||
def _get_filter_size(f):
|
||||
if f is None:
|
||||
return 1, 1
|
||||
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
||||
fw = f.shape[-1]
|
||||
fh = f.shape[0]
|
||||
with misc.suppress_tracer_warnings():
|
||||
fw = int(fw)
|
||||
fh = int(fh)
|
||||
misc.assert_shape(f, [fh, fw][:f.ndim])
|
||||
assert fw >= 1 and fh >= 1
|
||||
return fw, fh
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
|
||||
r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
|
||||
|
||||
Args:
|
||||
f: Torch tensor, numpy array, or python list of the shape
|
||||
`[filter_height, filter_width]` (non-separable),
|
||||
`[filter_taps]` (separable),
|
||||
`[]` (impulse), or
|
||||
`None` (identity).
|
||||
device: Result device (default: cpu).
|
||||
normalize: Normalize the filter so that it retains the magnitude
|
||||
for constant input signal (DC)? (default: True).
|
||||
flip_filter: Flip the filter? (default: False).
|
||||
gain: Overall scaling factor for signal magnitude (default: 1).
|
||||
separable: Return a separable filter? (default: select automatically).
|
||||
|
||||
Returns:
|
||||
Float32 tensor of the shape
|
||||
`[filter_height, filter_width]` (non-separable) or
|
||||
`[filter_taps]` (separable).
|
||||
"""
|
||||
# Validate.
|
||||
if f is None:
|
||||
f = 1
|
||||
f = torch.as_tensor(f, dtype=torch.float32)
|
||||
assert f.ndim in [0, 1, 2]
|
||||
assert f.numel() > 0
|
||||
if f.ndim == 0:
|
||||
f = f[np.newaxis]
|
||||
|
||||
# Separable?
|
||||
if separable is None:
|
||||
separable = (f.ndim == 1 and f.numel() >= 8)
|
||||
if f.ndim == 1 and not separable:
|
||||
f = f.ger(f)
|
||||
assert f.ndim == (1 if separable else 2)
|
||||
|
||||
# Apply normalize, flip, gain, and device.
|
||||
if normalize:
|
||||
f /= f.sum()
|
||||
if flip_filter:
|
||||
f = f.flip(list(range(f.ndim)))
|
||||
f = f * (gain ** (f.ndim / 2))
|
||||
f = f.to(device=device)
|
||||
return f
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
||||
r"""Pad, upsample, filter, and downsample a batch of 2D images.
|
||||
|
||||
Performs the following sequence of operations for each channel:
|
||||
|
||||
1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
|
||||
|
||||
2. Pad the image with the specified number of zeros on each side (`padding`).
|
||||
Negative padding corresponds to cropping the image.
|
||||
|
||||
3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
|
||||
so that the footprint of all output pixels lies within the input image.
|
||||
|
||||
4. Downsample the image by keeping every Nth pixel (`down`).
|
||||
|
||||
This sequence of operations bears close resemblance to scipy.signal.upfirdn().
|
||||
The fused op is considerably more efficient than performing the same calculation
|
||||
using standard PyTorch ops. It supports gradients of arbitrary order.
|
||||
|
||||
Args:
|
||||
x: Float32/float64/float16 input tensor of the shape
|
||||
`[batch_size, num_channels, in_height, in_width]`.
|
||||
f: Float32 FIR filter of the shape
|
||||
`[filter_height, filter_width]` (non-separable),
|
||||
`[filter_taps]` (separable), or
|
||||
`None` (identity).
|
||||
up: Integer upsampling factor. Can be a single int or a list/tuple
|
||||
`[x, y]` (default: 1).
|
||||
down: Integer downsampling factor. Can be a single int or a list/tuple
|
||||
`[x, y]` (default: 1).
|
||||
padding: Padding with respect to the upsampled image. Can be a single number
|
||||
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
||||
(default: 0).
|
||||
flip_filter: False = convolution, True = correlation (default: False).
|
||||
gain: Overall scaling factor for signal magnitude (default: 1).
|
||||
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
||||
"""
|
||||
assert isinstance(x, torch.Tensor)
|
||||
assert impl in ['ref', 'cuda']
|
||||
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
||||
return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
|
||||
return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
|
||||
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
|
||||
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
|
||||
"""
|
||||
# Validate arguments.
|
||||
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
||||
if f is None:
|
||||
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
||||
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
||||
assert f.dtype == torch.float32 and not f.requires_grad
|
||||
batch_size, num_channels, in_height, in_width = x.shape
|
||||
upx, upy = _parse_scaling(up)
|
||||
downx, downy = _parse_scaling(down)
|
||||
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
||||
|
||||
# Check that upsampled buffer is not smaller than the filter.
|
||||
upW = in_width * upx + padx0 + padx1
|
||||
upH = in_height * upy + pady0 + pady1
|
||||
assert upW >= f.shape[-1] and upH >= f.shape[0]
|
||||
|
||||
# Upsample by inserting zeros.
|
||||
x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
|
||||
x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
|
||||
x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
|
||||
|
||||
# Pad or crop.
|
||||
x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
|
||||
x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
|
||||
|
||||
# Setup filter.
|
||||
f = f * (gain ** (f.ndim / 2))
|
||||
f = f.to(x.dtype)
|
||||
if not flip_filter:
|
||||
f = f.flip(list(range(f.ndim)))
|
||||
|
||||
# Convolve with the filter.
|
||||
f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
|
||||
if f.ndim == 4:
|
||||
x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
|
||||
else:
|
||||
x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
|
||||
x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
|
||||
|
||||
# Downsample by throwing away pixels.
|
||||
x = x[:, :, ::downy, ::downx]
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_upfirdn2d_cuda_cache = dict()
|
||||
|
||||
def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
|
||||
"""Fast CUDA implementation of `upfirdn2d()` using custom ops.
|
||||
"""
|
||||
# Parse arguments.
|
||||
upx, upy = _parse_scaling(up)
|
||||
downx, downy = _parse_scaling(down)
|
||||
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
||||
|
||||
# Lookup from cache.
|
||||
key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
|
||||
if key in _upfirdn2d_cuda_cache:
|
||||
return _upfirdn2d_cuda_cache[key]
|
||||
|
||||
# Forward op.
|
||||
class Upfirdn2dCuda(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, f): # pylint: disable=arguments-differ
|
||||
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
||||
if f is None:
|
||||
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
||||
if f.ndim == 1 and f.shape[0] == 1:
|
||||
f = f.square().unsqueeze(0) # Convert separable-1 into full-1x1.
|
||||
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
||||
y = x
|
||||
if f.ndim == 2:
|
||||
y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
|
||||
else:
|
||||
y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, 1.0)
|
||||
y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, gain)
|
||||
ctx.save_for_backward(f)
|
||||
ctx.x_shape = x.shape
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dy): # pylint: disable=arguments-differ
|
||||
f, = ctx.saved_tensors
|
||||
_, _, ih, iw = ctx.x_shape
|
||||
_, _, oh, ow = dy.shape
|
||||
fw, fh = _get_filter_size(f)
|
||||
p = [
|
||||
fw - padx0 - 1,
|
||||
iw * upx - ow * downx + padx0 - upx + 1,
|
||||
fh - pady0 - 1,
|
||||
ih * upy - oh * downy + pady0 - upy + 1,
|
||||
]
|
||||
dx = None
|
||||
df = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
|
||||
|
||||
assert not ctx.needs_input_grad[1]
|
||||
return dx, df
|
||||
|
||||
# Add to cache.
|
||||
_upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
|
||||
return Upfirdn2dCuda
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
||||
r"""Filter a batch of 2D images using the given 2D FIR filter.
|
||||
|
||||
By default, the result is padded so that its shape matches the input.
|
||||
User-specified padding is applied on top of that, with negative values
|
||||
indicating cropping. Pixels outside the image are assumed to be zero.
|
||||
|
||||
Args:
|
||||
x: Float32/float64/float16 input tensor of the shape
|
||||
`[batch_size, num_channels, in_height, in_width]`.
|
||||
f: Float32 FIR filter of the shape
|
||||
`[filter_height, filter_width]` (non-separable),
|
||||
`[filter_taps]` (separable), or
|
||||
`None` (identity).
|
||||
padding: Padding with respect to the output. Can be a single number or a
|
||||
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
||||
(default: 0).
|
||||
flip_filter: False = convolution, True = correlation (default: False).
|
||||
gain: Overall scaling factor for signal magnitude (default: 1).
|
||||
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
||||
"""
|
||||
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
||||
fw, fh = _get_filter_size(f)
|
||||
p = [
|
||||
padx0 + fw // 2,
|
||||
padx1 + (fw - 1) // 2,
|
||||
pady0 + fh // 2,
|
||||
pady1 + (fh - 1) // 2,
|
||||
]
|
||||
return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
||||
r"""Upsample a batch of 2D images using the given 2D FIR filter.
|
||||
|
||||
By default, the result is padded so that its shape is a multiple of the input.
|
||||
User-specified padding is applied on top of that, with negative values
|
||||
indicating cropping. Pixels outside the image are assumed to be zero.
|
||||
|
||||
Args:
|
||||
x: Float32/float64/float16 input tensor of the shape
|
||||
`[batch_size, num_channels, in_height, in_width]`.
|
||||
f: Float32 FIR filter of the shape
|
||||
`[filter_height, filter_width]` (non-separable),
|
||||
`[filter_taps]` (separable), or
|
||||
`None` (identity).
|
||||
up: Integer upsampling factor. Can be a single int or a list/tuple
|
||||
`[x, y]` (default: 1).
|
||||
padding: Padding with respect to the output. Can be a single number or a
|
||||
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
||||
(default: 0).
|
||||
flip_filter: False = convolution, True = correlation (default: False).
|
||||
gain: Overall scaling factor for signal magnitude (default: 1).
|
||||
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
||||
"""
|
||||
upx, upy = _parse_scaling(up)
|
||||
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
||||
fw, fh = _get_filter_size(f)
|
||||
p = [
|
||||
padx0 + (fw + upx - 1) // 2,
|
||||
padx1 + (fw - upx) // 2,
|
||||
pady0 + (fh + upy - 1) // 2,
|
||||
pady1 + (fh - upy) // 2,
|
||||
]
|
||||
return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
||||
r"""Downsample a batch of 2D images using the given 2D FIR filter.
|
||||
|
||||
By default, the result is padded so that its shape is a fraction of the input.
|
||||
User-specified padding is applied on top of that, with negative values
|
||||
indicating cropping. Pixels outside the image are assumed to be zero.
|
||||
|
||||
Args:
|
||||
x: Float32/float64/float16 input tensor of the shape
|
||||
`[batch_size, num_channels, in_height, in_width]`.
|
||||
f: Float32 FIR filter of the shape
|
||||
`[filter_height, filter_width]` (non-separable),
|
||||
`[filter_taps]` (separable), or
|
||||
`None` (identity).
|
||||
down: Integer downsampling factor. Can be a single int or a list/tuple
|
||||
`[x, y]` (default: 1).
|
||||
padding: Padding with respect to the input. Can be a single number or a
|
||||
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
||||
(default: 0).
|
||||
flip_filter: False = convolution, True = correlation (default: False).
|
||||
gain: Overall scaling factor for signal magnitude (default: 1).
|
||||
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
||||
"""
|
||||
downx, downy = _parse_scaling(down)
|
||||
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
||||
fw, fh = _get_filter_size(f)
|
||||
p = [
|
||||
padx0 + (fw - downx + 1) // 2,
|
||||
padx1 + (fw - downx) // 2,
|
||||
pady0 + (fh - downy + 1) // 2,
|
||||
pady1 + (fh - downy) // 2,
|
||||
]
|
||||
return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
251
models/stylegan3/torch_utils/persistence.py
Normal file
251
models/stylegan3/torch_utils/persistence.py
Normal file
@ -0,0 +1,251 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Facilities for pickling Python code alongside other data.
|
||||
|
||||
The pickled code is automatically imported into a separate Python module
|
||||
during unpickling. This way, any previously exported pickles will remain
|
||||
usable even if the original code is no longer available, or if the current
|
||||
version of the code is not consistent with what was originally pickled."""
|
||||
|
||||
import sys
|
||||
import pickle
|
||||
import io
|
||||
import inspect
|
||||
import copy
|
||||
import uuid
|
||||
import types
|
||||
import dnnlib
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_version = 6 # internal version number
|
||||
_decorators = set() # {decorator_class, ...}
|
||||
_import_hooks = [] # [hook_function, ...]
|
||||
_module_to_src_dict = dict() # {module: src, ...}
|
||||
_src_to_module_dict = dict() # {src: module, ...}
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def persistent_class(orig_class):
|
||||
r"""Class decorator that extends a given class to save its source code
|
||||
when pickled.
|
||||
|
||||
Example:
|
||||
|
||||
from torch_utils import persistence
|
||||
|
||||
@persistence.persistent_class
|
||||
class MyNetwork(torch.nn.Module):
|
||||
def __init__(self, num_inputs, num_outputs):
|
||||
super().__init__()
|
||||
self.fc = MyLayer(num_inputs, num_outputs)
|
||||
...
|
||||
|
||||
@persistence.persistent_class
|
||||
class MyLayer(torch.nn.Module):
|
||||
...
|
||||
|
||||
When pickled, any instance of `MyNetwork` and `MyLayer` will save its
|
||||
source code alongside other internal state (e.g., parameters, buffers,
|
||||
and submodules). This way, any previously exported pickle will remain
|
||||
usable even if the class definitions have been modified or are no
|
||||
longer available.
|
||||
|
||||
The decorator saves the source code of the entire Python module
|
||||
containing the decorated class. It does *not* save the source code of
|
||||
any imported modules. Thus, the imported modules must be available
|
||||
during unpickling, also including `torch_utils.persistence` itself.
|
||||
|
||||
It is ok to call functions defined in the same module from the
|
||||
decorated class. However, if the decorated class depends on other
|
||||
classes defined in the same module, they must be decorated as well.
|
||||
This is illustrated in the above example in the case of `MyLayer`.
|
||||
|
||||
It is also possible to employ the decorator just-in-time before
|
||||
calling the constructor. For example:
|
||||
|
||||
cls = MyLayer
|
||||
if want_to_make_it_persistent:
|
||||
cls = persistence.persistent_class(cls)
|
||||
layer = cls(num_inputs, num_outputs)
|
||||
|
||||
As an additional feature, the decorator also keeps track of the
|
||||
arguments that were used to construct each instance of the decorated
|
||||
class. The arguments can be queried via `obj.init_args` and
|
||||
`obj.init_kwargs`, and they are automatically pickled alongside other
|
||||
object state. A typical use case is to first unpickle a previous
|
||||
instance of a persistent class, and then upgrade it to use the latest
|
||||
version of the source code:
|
||||
|
||||
with open('old_pickle.pkl', 'rb') as f:
|
||||
old_net = pickle.load(f)
|
||||
new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
|
||||
misc.copy_params_and_buffers(old_net, new_net, require_all=True)
|
||||
"""
|
||||
assert isinstance(orig_class, type)
|
||||
if is_persistent(orig_class):
|
||||
return orig_class
|
||||
|
||||
assert orig_class.__module__ in sys.modules
|
||||
orig_module = sys.modules[orig_class.__module__]
|
||||
orig_module_src = _module_to_src(orig_module)
|
||||
|
||||
class Decorator(orig_class):
|
||||
_orig_module_src = orig_module_src
|
||||
_orig_class_name = orig_class.__name__
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._init_args = copy.deepcopy(args)
|
||||
self._init_kwargs = copy.deepcopy(kwargs)
|
||||
assert orig_class.__name__ in orig_module.__dict__
|
||||
_check_pickleable(self.__reduce__())
|
||||
|
||||
@property
|
||||
def init_args(self):
|
||||
return copy.deepcopy(self._init_args)
|
||||
|
||||
@property
|
||||
def init_kwargs(self):
|
||||
return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
|
||||
|
||||
def __reduce__(self):
|
||||
fields = list(super().__reduce__())
|
||||
fields += [None] * max(3 - len(fields), 0)
|
||||
if fields[0] is not _reconstruct_persistent_obj:
|
||||
meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
|
||||
fields[0] = _reconstruct_persistent_obj # reconstruct func
|
||||
fields[1] = (meta,) # reconstruct args
|
||||
fields[2] = None # state dict
|
||||
return tuple(fields)
|
||||
|
||||
Decorator.__name__ = orig_class.__name__
|
||||
_decorators.add(Decorator)
|
||||
return Decorator
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def is_persistent(obj):
|
||||
r"""Test whether the given object or class is persistent, i.e.,
|
||||
whether it will save its source code when pickled.
|
||||
"""
|
||||
try:
|
||||
if obj in _decorators:
|
||||
return True
|
||||
except TypeError:
|
||||
pass
|
||||
return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def import_hook(hook):
|
||||
r"""Register an import hook that is called whenever a persistent object
|
||||
is being unpickled. A typical use case is to patch the pickled source
|
||||
code to avoid errors and inconsistencies when the API of some imported
|
||||
module has changed.
|
||||
|
||||
The hook should have the following signature:
|
||||
|
||||
hook(meta) -> modified meta
|
||||
|
||||
`meta` is an instance of `dnnlib.EasyDict` with the following fields:
|
||||
|
||||
type: Type of the persistent object, e.g. `'class'`.
|
||||
version: Internal version number of `torch_utils.persistence`.
|
||||
module_src Original source code of the Python module.
|
||||
class_name: Class name in the original Python module.
|
||||
state: Internal state of the object.
|
||||
|
||||
Example:
|
||||
|
||||
@persistence.import_hook
|
||||
def wreck_my_network(meta):
|
||||
if meta.class_name == 'MyNetwork':
|
||||
print('MyNetwork is being imported. I will wreck it!')
|
||||
meta.module_src = meta.module_src.replace("True", "False")
|
||||
return meta
|
||||
"""
|
||||
assert callable(hook)
|
||||
_import_hooks.append(hook)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _reconstruct_persistent_obj(meta):
|
||||
r"""Hook that is called internally by the `pickle` module to unpickle
|
||||
a persistent object.
|
||||
"""
|
||||
meta = dnnlib.EasyDict(meta)
|
||||
meta.state = dnnlib.EasyDict(meta.state)
|
||||
for hook in _import_hooks:
|
||||
meta = hook(meta)
|
||||
assert meta is not None
|
||||
|
||||
assert meta.version == _version
|
||||
module = _src_to_module(meta.module_src)
|
||||
|
||||
assert meta.type == 'class'
|
||||
orig_class = module.__dict__[meta.class_name]
|
||||
decorator_class = persistent_class(orig_class)
|
||||
obj = decorator_class.__new__(decorator_class)
|
||||
|
||||
setstate = getattr(obj, '__setstate__', None)
|
||||
if callable(setstate):
|
||||
setstate(meta.state) # pylint: disable=not-callable
|
||||
else:
|
||||
obj.__dict__.update(meta.state)
|
||||
return obj
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _module_to_src(module):
|
||||
r"""Query the source code of a given Python module.
|
||||
"""
|
||||
src = _module_to_src_dict.get(module, None)
|
||||
if src is None:
|
||||
src = inspect.getsource(module)
|
||||
_module_to_src_dict[module] = src
|
||||
_src_to_module_dict[src] = module
|
||||
return src
|
||||
|
||||
def _src_to_module(src):
|
||||
r"""Get or create a Python module for the given source code.
|
||||
"""
|
||||
module = _src_to_module_dict.get(src, None)
|
||||
if module is None:
|
||||
module_name = "_imported_module_" + uuid.uuid4().hex
|
||||
module = types.ModuleType(module_name)
|
||||
sys.modules[module_name] = module
|
||||
_module_to_src_dict[module] = src
|
||||
_src_to_module_dict[src] = module
|
||||
exec(src, module.__dict__) # pylint: disable=exec-used
|
||||
return module
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _check_pickleable(obj):
|
||||
r"""Check that the given object is pickleable, raising an exception if
|
||||
it is not. This function is expected to be considerably more efficient
|
||||
than actually pickling the object.
|
||||
"""
|
||||
def recurse(obj):
|
||||
if isinstance(obj, (list, tuple, set)):
|
||||
return [recurse(x) for x in obj]
|
||||
if isinstance(obj, dict):
|
||||
return [[recurse(x), recurse(y)] for x, y in obj.items()]
|
||||
if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
|
||||
return None # Python primitive types are pickleable.
|
||||
if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
|
||||
return None # NumPy arrays and PyTorch tensors are pickleable.
|
||||
if is_persistent(obj):
|
||||
return None # Persistent objects are pickleable, by virtue of the constructor check.
|
||||
return obj
|
||||
with io.BytesIO() as f:
|
||||
pickle.dump(recurse(obj), f)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
268
models/stylegan3/torch_utils/training_stats.py
Normal file
268
models/stylegan3/torch_utils/training_stats.py
Normal file
@ -0,0 +1,268 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Facilities for reporting and collecting training statistics across
|
||||
multiple processes and devices. The interface is designed to minimize
|
||||
synchronization overhead as well as the amount of boilerplate in user
|
||||
code."""
|
||||
|
||||
import re
|
||||
import numpy as np
|
||||
import torch
|
||||
import dnnlib
|
||||
|
||||
from . import misc
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
|
||||
_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
|
||||
_counter_dtype = torch.float64 # Data type to use for the internal counters.
|
||||
_rank = 0 # Rank of the current process.
|
||||
_sync_device = None # Device to use for multiprocess communication. None = single-process.
|
||||
_sync_called = False # Has _sync() been called yet?
|
||||
_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
|
||||
_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def init_multiprocessing(rank, sync_device):
|
||||
r"""Initializes `torch_utils.training_stats` for collecting statistics
|
||||
across multiple processes.
|
||||
|
||||
This function must be called after
|
||||
`torch.distributed.init_process_group()` and before `Collector.update()`.
|
||||
The call is not necessary if multi-process collection is not needed.
|
||||
|
||||
Args:
|
||||
rank: Rank of the current process.
|
||||
sync_device: PyTorch device to use for inter-process
|
||||
communication, or None to disable multi-process
|
||||
collection. Typically `torch.device('cuda', rank)`.
|
||||
"""
|
||||
global _rank, _sync_device
|
||||
assert not _sync_called
|
||||
_rank = rank
|
||||
_sync_device = sync_device
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
|
||||
def report(name, value):
|
||||
r"""Broadcasts the given set of scalars to all interested instances of
|
||||
`Collector`, across device and process boundaries.
|
||||
|
||||
This function is expected to be extremely cheap and can be safely
|
||||
called from anywhere in the training loop, loss function, or inside a
|
||||
`torch.nn.Module`.
|
||||
|
||||
Warning: The current implementation expects the set of unique names to
|
||||
be consistent across processes. Please make sure that `report()` is
|
||||
called at least once for each unique name by each process, and in the
|
||||
same order. If a given process has no scalars to broadcast, it can do
|
||||
`report(name, [])` (empty list).
|
||||
|
||||
Args:
|
||||
name: Arbitrary string specifying the name of the statistic.
|
||||
Averages are accumulated separately for each unique name.
|
||||
value: Arbitrary set of scalars. Can be a list, tuple,
|
||||
NumPy array, PyTorch tensor, or Python scalar.
|
||||
|
||||
Returns:
|
||||
The same `value` that was passed in.
|
||||
"""
|
||||
if name not in _counters:
|
||||
_counters[name] = dict()
|
||||
|
||||
elems = torch.as_tensor(value)
|
||||
if elems.numel() == 0:
|
||||
return value
|
||||
|
||||
elems = elems.detach().flatten().to(_reduce_dtype)
|
||||
moments = torch.stack([
|
||||
torch.ones_like(elems).sum(),
|
||||
elems.sum(),
|
||||
elems.square().sum(),
|
||||
])
|
||||
assert moments.ndim == 1 and moments.shape[0] == _num_moments
|
||||
moments = moments.to(_counter_dtype)
|
||||
|
||||
device = moments.device
|
||||
if device not in _counters[name]:
|
||||
_counters[name][device] = torch.zeros_like(moments)
|
||||
_counters[name][device].add_(moments)
|
||||
return value
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def report0(name, value):
|
||||
r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
|
||||
but ignores any scalars provided by the other processes.
|
||||
See `report()` for further details.
|
||||
"""
|
||||
report(name, value if _rank == 0 else [])
|
||||
return value
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class Collector:
|
||||
r"""Collects the scalars broadcasted by `report()` and `report0()` and
|
||||
computes their long-term averages (mean and standard deviation) over
|
||||
user-defined periods of time.
|
||||
|
||||
The averages are first collected into internal counters that are not
|
||||
directly visible to the user. They are then copied to the user-visible
|
||||
state as a result of calling `update()` and can then be queried using
|
||||
`mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
|
||||
internal counters for the next round, so that the user-visible state
|
||||
effectively reflects averages collected between the last two calls to
|
||||
`update()`.
|
||||
|
||||
Args:
|
||||
regex: Regular expression defining which statistics to
|
||||
collect. The default is to collect everything.
|
||||
keep_previous: Whether to retain the previous averages if no
|
||||
scalars were collected on a given round
|
||||
(default: True).
|
||||
"""
|
||||
def __init__(self, regex='.*', keep_previous=True):
|
||||
self._regex = re.compile(regex)
|
||||
self._keep_previous = keep_previous
|
||||
self._cumulative = dict()
|
||||
self._moments = dict()
|
||||
self.update()
|
||||
self._moments.clear()
|
||||
|
||||
def names(self):
|
||||
r"""Returns the names of all statistics broadcasted so far that
|
||||
match the regular expression specified at construction time.
|
||||
"""
|
||||
return [name for name in _counters if self._regex.fullmatch(name)]
|
||||
|
||||
def update(self):
|
||||
r"""Copies current values of the internal counters to the
|
||||
user-visible state and resets them for the next round.
|
||||
|
||||
If `keep_previous=True` was specified at construction time, the
|
||||
operation is skipped for statistics that have received no scalars
|
||||
since the last update, retaining their previous averages.
|
||||
|
||||
This method performs a number of GPU-to-CPU transfers and one
|
||||
`torch.distributed.all_reduce()`. It is intended to be called
|
||||
periodically in the main training loop, typically once every
|
||||
N training steps.
|
||||
"""
|
||||
if not self._keep_previous:
|
||||
self._moments.clear()
|
||||
for name, cumulative in _sync(self.names()):
|
||||
if name not in self._cumulative:
|
||||
self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
|
||||
delta = cumulative - self._cumulative[name]
|
||||
self._cumulative[name].copy_(cumulative)
|
||||
if float(delta[0]) != 0:
|
||||
self._moments[name] = delta
|
||||
|
||||
def _get_delta(self, name):
|
||||
r"""Returns the raw moments that were accumulated for the given
|
||||
statistic between the last two calls to `update()`, or zero if
|
||||
no scalars were collected.
|
||||
"""
|
||||
assert self._regex.fullmatch(name)
|
||||
if name not in self._moments:
|
||||
self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
|
||||
return self._moments[name]
|
||||
|
||||
def num(self, name):
|
||||
r"""Returns the number of scalars that were accumulated for the given
|
||||
statistic between the last two calls to `update()`, or zero if
|
||||
no scalars were collected.
|
||||
"""
|
||||
delta = self._get_delta(name)
|
||||
return int(delta[0])
|
||||
|
||||
def mean(self, name):
|
||||
r"""Returns the mean of the scalars that were accumulated for the
|
||||
given statistic between the last two calls to `update()`, or NaN if
|
||||
no scalars were collected.
|
||||
"""
|
||||
delta = self._get_delta(name)
|
||||
if int(delta[0]) == 0:
|
||||
return float('nan')
|
||||
return float(delta[1] / delta[0])
|
||||
|
||||
def std(self, name):
|
||||
r"""Returns the standard deviation of the scalars that were
|
||||
accumulated for the given statistic between the last two calls to
|
||||
`update()`, or NaN if no scalars were collected.
|
||||
"""
|
||||
delta = self._get_delta(name)
|
||||
if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
|
||||
return float('nan')
|
||||
if int(delta[0]) == 1:
|
||||
return float(0)
|
||||
mean = float(delta[1] / delta[0])
|
||||
raw_var = float(delta[2] / delta[0])
|
||||
return np.sqrt(max(raw_var - np.square(mean), 0))
|
||||
|
||||
def as_dict(self):
|
||||
r"""Returns the averages accumulated between the last two calls to
|
||||
`update()` as an `dnnlib.EasyDict`. The contents are as follows:
|
||||
|
||||
dnnlib.EasyDict(
|
||||
NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
|
||||
...
|
||||
)
|
||||
"""
|
||||
stats = dnnlib.EasyDict()
|
||||
for name in self.names():
|
||||
stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
|
||||
return stats
|
||||
|
||||
def __getitem__(self, name):
|
||||
r"""Convenience getter.
|
||||
`collector[name]` is a synonym for `collector.mean(name)`.
|
||||
"""
|
||||
return self.mean(name)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _sync(names):
|
||||
r"""Synchronize the global cumulative counters across devices and
|
||||
processes. Called internally by `Collector.update()`.
|
||||
"""
|
||||
if len(names) == 0:
|
||||
return []
|
||||
global _sync_called
|
||||
_sync_called = True
|
||||
|
||||
# Collect deltas within current rank.
|
||||
deltas = []
|
||||
device = _sync_device if _sync_device is not None else torch.device('cpu')
|
||||
for name in names:
|
||||
delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
|
||||
for counter in _counters[name].values():
|
||||
delta.add_(counter.to(device))
|
||||
counter.copy_(torch.zeros_like(counter))
|
||||
deltas.append(delta)
|
||||
deltas = torch.stack(deltas)
|
||||
|
||||
# Sum deltas across ranks.
|
||||
if _sync_device is not None:
|
||||
torch.distributed.all_reduce(deltas)
|
||||
|
||||
# Update cumulative values.
|
||||
deltas = deltas.cpu()
|
||||
for idx, name in enumerate(names):
|
||||
if name not in _cumulative:
|
||||
_cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
|
||||
_cumulative[name].add_(deltas[idx])
|
||||
|
||||
# Return name-value pairs.
|
||||
return [(name, _cumulative[name]) for name in names]
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
181
optimization/run_optimization.py
Normal file
181
optimization/run_optimization.py
Normal file
@ -0,0 +1,181 @@
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import optim
|
||||
from tqdm import tqdm
|
||||
|
||||
from criteria.clip_loss import CLIPLoss
|
||||
from criteria.id_loss import IDLoss
|
||||
from mapper.training.train_utils import STYLESPACE_DIMENSIONS
|
||||
from models.stylegan2.model import Generator
|
||||
import clip
|
||||
from utils import ensure_checkpoint_exists
|
||||
|
||||
STYLESPACE_INDICES_WITHOUT_TORGB = [i for i in range(len(STYLESPACE_DIMENSIONS)) if i not in list(range(1, len(STYLESPACE_DIMENSIONS), 3))]
|
||||
|
||||
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
|
||||
lr_ramp = min(1, (1 - t) / rampdown)
|
||||
lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
|
||||
lr_ramp = lr_ramp * min(1, t / rampup)
|
||||
|
||||
return initial_lr * lr_ramp
|
||||
|
||||
|
||||
def main(args):
|
||||
ensure_checkpoint_exists(args.ckpt)
|
||||
# 把描述加载进clip预训练模型里面去
|
||||
text_inputs = torch.cat([clip.tokenize(args.description)]).cuda()
|
||||
# print('text_input是: ', text_inputs)
|
||||
'''
|
||||
--description "a person with purple hair"
|
||||
tensor([[49406, 320, 2533, 593, 5496, 2225, 49407, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0]], device='cuda:0',
|
||||
dtype=torch.int32)
|
||||
--description "a person with red hair"
|
||||
tensor([[49406, 320, 2533, 593, 736, 2225, 49407, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0]], device='cuda:0',
|
||||
dtype=torch.int32)
|
||||
'''
|
||||
|
||||
os.makedirs(args.results_dir, exist_ok=True)
|
||||
|
||||
g_ema = Generator(args.stylegan_size, 512, 8)
|
||||
g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
|
||||
# 将模型对象设置为评估模式
|
||||
g_ema.eval()
|
||||
#更改cuda卡号
|
||||
g_ema = g_ema.cuda()
|
||||
# device = torch.cuda.current_device()
|
||||
# print('cuda:',device)
|
||||
mean_latent = g_ema.mean_latent(4096)
|
||||
# print('mean_latent: ', mean_latent.shape ) #[1,512]
|
||||
|
||||
|
||||
if args.latent_path:
|
||||
latent_code_init = torch.load(args.latent_path).cuda()
|
||||
with torch.no_grad():
|
||||
_, latent_code_init, _ = g_ema([latent_code_init], return_latents=True,
|
||||
truncation=args.truncation, truncation_latent=mean_latent)
|
||||
elif args.mode == "edit":
|
||||
latent_code_init_not_trunc = torch.randn(1, 512).cuda()
|
||||
with torch.no_grad():
|
||||
_, latent_code_init, _ = g_ema([latent_code_init_not_trunc], return_latents=True,
|
||||
truncation=args.truncation, truncation_latent=mean_latent)
|
||||
else:
|
||||
latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1)
|
||||
print(latent_code_init) #在维度1上重复18次 torch.Size([1, 18, 512])
|
||||
with torch.no_grad():
|
||||
img_orig, _ = g_ema([latent_code_init], input_is_latent=True, randomize_noise=False)
|
||||
|
||||
if args.work_in_stylespace:
|
||||
with torch.no_grad():
|
||||
_, _, latent_code_init = g_ema([latent_code_init], input_is_latent=True, return_latents=True)
|
||||
latent = [s.detach().clone() for s in latent_code_init]
|
||||
for c, s in enumerate(latent):
|
||||
if c in STYLESPACE_INDICES_WITHOUT_TORGB:
|
||||
s.requires_grad = True
|
||||
else:
|
||||
latent = latent_code_init.detach().clone()
|
||||
latent.requires_grad = True
|
||||
|
||||
clip_loss = CLIPLoss(args)
|
||||
id_loss = IDLoss(args)
|
||||
|
||||
if args.work_in_stylespace:
|
||||
optimizer = optim.Adam(latent, lr=args.lr)
|
||||
else:
|
||||
optimizer = optim.Adam([latent], lr=args.lr)
|
||||
|
||||
pbar = tqdm(range(args.step))
|
||||
|
||||
for i in pbar:
|
||||
t = i / args.step
|
||||
lr = get_lr(t, args.lr)
|
||||
optimizer.param_groups[0]["lr"] = lr
|
||||
|
||||
img_gen, _ = g_ema([latent], input_is_latent=True, randomize_noise=False, input_is_stylespace=args.work_in_stylespace)
|
||||
|
||||
c_loss = clip_loss(img_gen, text_inputs)
|
||||
|
||||
if args.id_lambda > 0:
|
||||
i_loss = id_loss(img_gen, img_orig)[0]
|
||||
else:
|
||||
i_loss = 0
|
||||
|
||||
if args.mode == "edit":
|
||||
if args.work_in_stylespace:
|
||||
l2_loss = sum([((latent_code_init[c] - latent[c]) ** 2).sum() for c in range(len(latent_code_init))])
|
||||
else:
|
||||
l2_loss = ((latent_code_init - latent) ** 2).sum()
|
||||
loss = c_loss + args.l2_lambda * l2_loss + args.id_lambda * i_loss
|
||||
else:
|
||||
loss = c_loss
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
pbar.set_description(
|
||||
(
|
||||
f"loss: {loss.item():.4f};"
|
||||
)
|
||||
)
|
||||
if args.save_intermediate_image_every > 0 and i % args.save_intermediate_image_every == 0:
|
||||
with torch.no_grad():
|
||||
img_gen, _ = g_ema([latent], input_is_latent=True, randomize_noise=False, input_is_stylespace=args.work_in_stylespace)
|
||||
|
||||
torchvision.utils.save_image(img_gen, f"results/{str(i).zfill(5)}.jpg", normalize=True, range=(-1, 1))
|
||||
|
||||
if args.mode == "edit":
|
||||
final_result = torch.cat([img_orig, img_gen])
|
||||
else:
|
||||
final_result = img_gen
|
||||
|
||||
return final_result
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--description", type=str, default="a person with purple hair", help="the text that guides the editing/generation")
|
||||
parser.add_argument("--ckpt", type=str, default="../pretrained_models/stylegan2-ffhq-config-f.pt", help="pretrained StyleGAN2 weights")
|
||||
parser.add_argument("--stylegan_size", type=int, default=1024, help="StyleGAN resolution")
|
||||
parser.add_argument("--lr_rampup", type=float, default=0.05)
|
||||
parser.add_argument("--lr", type=float, default=0.1)
|
||||
parser.add_argument("--step", type=int, default=300, help="number of optimization steps")
|
||||
parser.add_argument("--mode", type=str, default="edit", choices=["edit", "free_generation"], help="choose between edit an image an generate a free one")
|
||||
parser.add_argument("--l2_lambda", type=float, default=0.008, help="weight of the latent distance (used for editing only)")
|
||||
parser.add_argument("--id_lambda", type=float, default=0.000, help="weight of id loss (used for editing only)")
|
||||
parser.add_argument("--latent_path", type=str, default=None, help="starts the optimization from the given latent code if provided. Otherwose, starts from"
|
||||
"the mean latent in a free generation, and from a random one in editing. "
|
||||
"Expects a .pt format")
|
||||
parser.add_argument("--truncation", type=float, default=0.7, help="used only for the initial latent vector, and only when a latent code path is"
|
||||
"not provided")
|
||||
parser.add_argument('--work_in_stylespace', default=False, action='store_true')
|
||||
parser.add_argument("--save_intermediate_image_every", type=int, default=20, help="if > 0 then saves intermidate results during the optimization")
|
||||
parser.add_argument("--results_dir", type=str, default="results")
|
||||
parser.add_argument('--ir_se50_weights', default='../pretrained_models/model_ir_se50.pth', type=str,
|
||||
help="Path to facial recognition network used in ID loss")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
result_image = main(args)
|
||||
|
||||
torchvision.utils.save_image(result_image.detach().cpu(), os.path.join(args.results_dir, "final_result.jpg"), normalize=True, scale_each=True, range=(-1, 1))
|
||||
|
||||
|
||||
36
test001.py
Normal file
36
test001.py
Normal file
@ -0,0 +1,36 @@
|
||||
import torchvision
|
||||
import argparse
|
||||
from argparse import Namespace
|
||||
from optimization.run_optimization import main
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--description", type=str, default="a person with purple hair",
|
||||
help="the text that guides the editing/generation")
|
||||
parser.add_argument("--ckpt", type=str, default="./pretrained_models/stylegan2-ffhq-config-f.pt",
|
||||
help="pretrained StyleGAN2 weights")
|
||||
parser.add_argument("--stylegan_size", type=int, default=1024, help="StyleGAN resolution")
|
||||
parser.add_argument("--lr_rampup", type=float, default=0.05)
|
||||
parser.add_argument("--lr", type=float, default=0.1)
|
||||
parser.add_argument("--step", type=int, default=300, help="number of optimization steps")
|
||||
parser.add_argument("--mode", type=str, default="edit", choices=["edit", "free_generation"],
|
||||
help="choose between edit an image an generate a free one")
|
||||
parser.add_argument("--l2_lambda", type=float, default=0.008,
|
||||
help="weight of the latent distance (used for editing only)")
|
||||
parser.add_argument("--latent_path", type=str, default="/home/ly/StyleCLIP-main/pretrained_models/latent_code/style3.pt",
|
||||
help="starts the optimization from the given latent code if provided. Otherwise, starts from"
|
||||
"the mean latent in a free generation, and from a random one in editing. "
|
||||
"Expects a .pt format")
|
||||
parser.add_argument("--truncation", type=float, default=0.7,
|
||||
help="used only for the initial latent vector, and only when a latent code path is"
|
||||
"not provided")
|
||||
parser.add_argument("--save_intermediate_image_every", type=int, default=20,
|
||||
help="if > 0 then saves intermidate results during the optimization")
|
||||
parser.add_argument("--results_dir", type=str, default="results")
|
||||
parser.add_argument('--work_in_stylespace', default=False, action='store_true', help="trains a mapper in S instead of W+")
|
||||
parser.add_argument('--ir_se50_weights', default='pretrained_models/model_ir_se50.pth', type=str, help="Path to facial recognition network used in ID loss")
|
||||
parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor')
|
||||
|
||||
args = vars(parser.parse_args())
|
||||
result_image = main(Namespace(**args))
|
||||
torchvision.utils.save_image(result_image.detach().cpu(), f"results/final_result.png", normalize=True, scale_each=True,
|
||||
range=(-1, 1))
|
||||
27
test002.py
Normal file
27
test002.py
Normal file
@ -0,0 +1,27 @@
|
||||
import torchvision
|
||||
import argparse
|
||||
from argparse import Namespace
|
||||
from PIL import Image
|
||||
|
||||
from utils import ensure_checkpoint_exists
|
||||
from mapper.scripts.inference import run
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--exp_dir', default="./results", type=str, help='Path to experiment output directory')
|
||||
parser.add_argument('--checkpoint_path', default="./pretrained_models/mapper/purple_hair.pt", type=str,
|
||||
help='Path to model checkpoint')
|
||||
parser.add_argument('--couple_outputs', default=True, action='store_true',
|
||||
help='Whether to also save inputs + outputs side-by-side')
|
||||
parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use')
|
||||
parser.add_argument('--no_coarse_mapper', default=False, action="store_true")
|
||||
parser.add_argument('--no_medium_mapper', default=False, action="store_true")
|
||||
parser.add_argument('--no_fine_mapper', default=False, action="store_true")
|
||||
parser.add_argument('--stylegan_size', default=1024, type=int)
|
||||
parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference')
|
||||
parser.add_argument('--latents_test_path', default="./latents_test/example_celebs.pt", type=str,
|
||||
help="The latents for the validation")
|
||||
parser.add_argument('--test_workers', default=0, type=int, help='Number of test/inference dataloader workers')
|
||||
parser.add_argument('--n_images', type=int, default=None, help='Number of images to output. If None, run on all data')
|
||||
|
||||
args = vars(parser.parse_args())
|
||||
run(Namespace(**args))
|
||||
49
utils.py
Normal file
49
utils.py
Normal file
@ -0,0 +1,49 @@
|
||||
import os
|
||||
|
||||
|
||||
google_drive_paths = {
|
||||
"stylegan2-ffhq-config-f.pt": "https://drive.google.com/uc?id=1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT",
|
||||
|
||||
"mapper/pretrained/afro.pt": "https://drive.google.com/uc?id=1i5vAqo4z0I-Yon3FNft_YZOq7ClWayQJ",
|
||||
"mapper/pretrained/angry.pt": "https://drive.google.com/uc?id=1g82HEH0jFDrcbCtn3M22gesWKfzWV_ma",
|
||||
"mapper/pretrained/beyonce.pt": "https://drive.google.com/uc?id=1KJTc-h02LXs4zqCyo7pzCp0iWeO6T9fz",
|
||||
"mapper/pretrained/bobcut.pt": "https://drive.google.com/uc?id=1IvyqjZzKS-vNdq_OhwapAcwrxgLAY8UF",
|
||||
"mapper/pretrained/bowlcut.pt": "https://drive.google.com/uc?id=1xwdxI2YCewSt05dEHgkpmmzoauPjEnnZ",
|
||||
"mapper/pretrained/curly_hair.pt": "https://drive.google.com/uc?id=1xZ7fFB12Ci6rUbUfaHPpo44xUFzpWQ6M",
|
||||
"mapper/pretrained/depp.pt": "https://drive.google.com/uc?id=1FPiJkvFPG_y-bFanxLLP91wUKuy-l3IV",
|
||||
"mapper/pretrained/hilary_clinton.pt": "https://drive.google.com/uc?id=1X7U2zj2lt0KFifIsTfOOzVZXqYyCWVll",
|
||||
"mapper/pretrained/mohawk.pt": "https://drive.google.com/uc?id=1oMMPc8iQZ7dhyWavZ7VNWLwzf9aX4C09",
|
||||
"mapper/pretrained/purple_hair.pt": "https://drive.google.com/uc?id=14H0CGXWxePrrKIYmZnDD2Ccs65EEww75",
|
||||
"mapper/pretrained/surprised.pt": "https://drive.google.com/uc?id=1F-mPrhO-UeWrV1QYMZck63R43aLtPChI",
|
||||
"mapper/pretrained/taylor_swift.pt": "https://drive.google.com/uc?id=10jHuHsKKJxuf3N0vgQbX_SMEQgFHDrZa",
|
||||
"mapper/pretrained/trump.pt": "https://drive.google.com/uc?id=14v8D0uzy4tOyfBU3ca9T0AzTt3v-dNyh",
|
||||
"mapper/pretrained/zuckerberg.pt": "https://drive.google.com/uc?id=1NjDcMUL8G-pO3i_9N6EPpQNXeMc3Ar1r",
|
||||
|
||||
"example_celebs.pt": "https://drive.google.com/uc?id=1VL3lP4avRhz75LxSza6jgDe-pHd2veQG"
|
||||
}
|
||||
|
||||
|
||||
def ensure_checkpoint_exists(model_weights_filename):
|
||||
if not os.path.isfile(model_weights_filename) and (
|
||||
model_weights_filename in google_drive_paths
|
||||
):
|
||||
gdrive_url = google_drive_paths[model_weights_filename]
|
||||
try:
|
||||
from gdown import download as drive_download
|
||||
|
||||
drive_download(gdrive_url, model_weights_filename, quiet=False)
|
||||
except ModuleNotFoundError:
|
||||
print(
|
||||
"gdown module not found.",
|
||||
"pip3 install gdown or, manually download the checkpoint file:",
|
||||
gdrive_url
|
||||
)
|
||||
|
||||
if not os.path.isfile(model_weights_filename) and (
|
||||
model_weights_filename not in google_drive_paths
|
||||
):
|
||||
print(
|
||||
model_weights_filename,
|
||||
" not found, you may need to manually download the model weights."
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user