From 5f1dfbbc959855fd90ba80c0c76301d2063772fa Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sat, 24 Dec 2022 18:02:22 -0500 Subject: [PATCH] implement train api --- modules/api/api.py | 94 ++++++++++++++++++++++++++- modules/api/models.py | 9 +++ modules/hypernetworks/hypernetwork.py | 26 ++++++++ modules/hypernetworks/ui.py | 31 ++------- 4 files changed, 132 insertions(+), 28 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index b43dd16b..1ceba75d 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -10,13 +10,17 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru +from modules import sd_samplers, deepbooru, sd_hijack from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.extras import run_extras, run_pnginfo +from modules.textual_inversion.textual_inversion import create_embedding, train_embedding +from modules.textual_inversion.preprocess import preprocess +from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image from modules.sd_models import checkpoints_list from modules.realesrgan_model import get_realesrgan_models +from modules import devices from typing import List def upscaler_to_index(name: str): @@ -97,6 +101,11 @@ class Api: self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) + self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse) + self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse) + self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse) + self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse) + self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse) def add_api_route(self, path: str, endpoint, **kwargs): if shared.cmd_opts.api_auth: @@ -326,6 +335,89 @@ class Api: def refresh_checkpoints(self): shared.refresh_checkpoints() + def create_embedding(self, args: dict): + try: + shared.state.begin() + filename = create_embedding(**args) # create empty embedding + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used + shared.state.end() + return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename)) + except AssertionError as e: + shared.state.end() + return TrainResponse(info = "create embedding error: {error}".format(error = e)) + + def create_hypernetwork(self, args: dict): + try: + shared.state.begin() + filename = create_hypernetwork(**args) # create empty embedding + shared.state.end() + return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename)) + except AssertionError as e: + shared.state.end() + return TrainResponse(info = "create hypernetwork error: {error}".format(error = e)) + + def preprocess(self, args: dict): + try: + shared.state.begin() + preprocess(**args) # quick operation unless blip/booru interrogation is enabled + shared.state.end() + return PreprocessResponse(info = 'preprocess complete') + except KeyError as e: + shared.state.end() + return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e)) + except AssertionError as e: + shared.state.end() + return PreprocessResponse(info = "preprocess error: {error}".format(error = e)) + except FileNotFoundError as e: + shared.state.end() + return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e)) + + def train_embedding(self, args: dict): + try: + shared.state.begin() + apply_optimizations = shared.opts.training_xattention_optimizations + error = None + filename = '' + if not apply_optimizations: + sd_hijack.undo_optimizations() + try: + embedding, filename = train_embedding(**args) # can take a long time to complete + except Exception as e: + error = e + finally: + if not apply_optimizations: + sd_hijack.apply_optimizations() + shared.state.end() + return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error)) + except AssertionError as msg: + shared.state.end() + return TrainResponse(info = "train embedding error: {msg}".format(msg = msg)) + + def train_hypernetwork(self, args: dict): + try: + shared.state.begin() + initial_hypernetwork = shared.loaded_hypernetwork + apply_optimizations = shared.opts.training_xattention_optimizations + error = None + filename = '' + if not apply_optimizations: + sd_hijack.undo_optimizations() + try: + hypernetwork, filename = train_hypernetwork(*args) + except Exception as e: + error = e + finally: + shared.loaded_hypernetwork = initial_hypernetwork + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + if not apply_optimizations: + sd_hijack.apply_optimizations() + shared.state.end() + return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error)) + except AssertionError as msg: + shared.state.end() + return TrainResponse(info = "train embedding error: {error}".format(error = error)) + def launch(self, server_name, port): self.app.include_router(self.router) uvicorn.run(self.app, host=server_name, port=port) diff --git a/modules/api/models.py b/modules/api/models.py index a22bc6b3..c446ce7a 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -175,6 +175,15 @@ class InterrogateRequest(BaseModel): class InterrogateResponse(BaseModel): caption: str = Field(default=None, title="Caption", description="The generated caption for the image.") +class TrainResponse(BaseModel): + info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.") + +class CreateResponse(BaseModel): + info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.") + +class PreprocessResponse(BaseModel): + info: str = Field(title="Preprocess info", description="Response string from preprocessing task.") + fields = {} for key, metadata in opts.data_labels.items(): value = opts.data.get(key) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index c406ffb3..3182ff03 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -378,6 +378,32 @@ def report_statistics(loss_info:dict): print(e) +def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): + # Remove illegal characters from name. + name = "".join( x for x in name if (x.isalnum() or x in "._- ")) + + fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") + if not overwrite_old: + assert not os.path.exists(fn), f"file {fn} already exists" + + if type(layer_structure) == str: + layer_structure = [float(x.strip()) for x in layer_structure.split(",")] + + hypernet = modules.hypernetworks.hypernetwork.Hypernetwork( + name=name, + enable_sizes=[int(x) for x in enable_sizes], + layer_structure=layer_structure, + activation_func=activation_func, + weight_init=weight_init, + add_layer_norm=add_layer_norm, + use_dropout=use_dropout, + ) + hypernet.save(fn) + + shared.reload_hypernetworks() + + return fn + def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index c2d4b51c..e7f9e593 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -3,39 +3,16 @@ import os import re import gradio as gr -import modules.textual_inversion.preprocess -import modules.textual_inversion.textual_inversion +import modules.hypernetworks.hypernetwork from modules import devices, sd_hijack, shared -from modules.hypernetworks import hypernetwork not_available = ["hardswish", "multiheadattention"] -keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) +keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): - # Remove illegal characters from name. - name = "".join( x for x in name if (x.isalnum() or x in "._- ")) + filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout) - fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") - if not overwrite_old: - assert not os.path.exists(fn), f"file {fn} already exists" - - if type(layer_structure) == str: - layer_structure = [float(x.strip()) for x in layer_structure.split(",")] - - hypernet = modules.hypernetworks.hypernetwork.Hypernetwork( - name=name, - enable_sizes=[int(x) for x in enable_sizes], - layer_structure=layer_structure, - activation_func=activation_func, - weight_init=weight_init, - add_layer_norm=add_layer_norm, - use_dropout=use_dropout, - ) - hypernet.save(fn) - - shared.reload_hypernetworks() - - return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", "" + return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", "" def train_hypernetwork(*args):