improve interrogate

This commit is contained in:
Vladimir Mandic 2023-01-23 09:03:17 -05:00 committed by GitHub
parent 59146621e2
commit 925dd09c91
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 12 deletions

View File

@ -20,6 +20,7 @@ Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re.compile(r"\.top(\d+)\.") re_topn = re.compile(r"\.top(\d+)\.")
category_types = ["artists", "flavors", "mediums", "movements"]
def download_default_clip_interrogate_categories(content_dir): def download_default_clip_interrogate_categories(content_dir):
print("Downloading CLIP categories...") print("Downloading CLIP categories...")
@ -27,12 +28,8 @@ def download_default_clip_interrogate_categories(content_dir):
tmpdir = content_dir + "_tmp" tmpdir = content_dir + "_tmp"
try: try:
os.makedirs(tmpdir) os.makedirs(tmpdir)
for category_type in category_types:
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt")) torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt"))
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt"))
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt"))
os.rename(tmpdir, content_dir) os.rename(tmpdir, content_dir)
except Exception as e: except Exception as e:
@ -51,12 +48,13 @@ class InterrogateModels:
def __init__(self, content_dir): def __init__(self, content_dir):
self.loaded_categories = None self.loaded_categories = None
self.selected_categories = []
self.content_dir = content_dir self.content_dir = content_dir
self.running_on_cpu = devices.device_interrogate == torch.device("cpu") self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
def categories(self): def categories(self):
if self.loaded_categories is not None: if self.loaded_categories is not None and self.selected_categories == shared.opts.interrogate_clip_categories:
return self.loaded_categories return self.loaded_categories
self.loaded_categories = [] self.loaded_categories = []
@ -64,14 +62,19 @@ class InterrogateModels:
download_default_clip_interrogate_categories(self.content_dir) download_default_clip_interrogate_categories(self.content_dir)
if os.path.exists(self.content_dir): if os.path.exists(self.content_dir):
for filename in os.listdir(self.content_dir): self.selected_categories = shared.opts.interrogate_clip_categories
for category_type in category_types:
if 'all' not in self.selected_categories and category_type not in self.selected_categories:
continue
filename = os.path.join(self.content_dir, f"{category_type}.txt")
if not os.path.isfile(filename):
continue
m = re_topn.search(filename) m = re_topn.search(filename)
topn = 1 if m is None else int(m.group(1)) topn = 1 if m is None else int(m.group(1))
with open(filename, "r", encoding="utf8") as file:
with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
lines = [x.strip() for x in file.readlines()] lines = [x.strip() for x in file.readlines()]
self.loaded_categories.append(Category(name=filename, topn=topn, items=lines)) self.loaded_categories.append(Category(name=category_type, topn=topn, items=lines))
return self.loaded_categories return self.loaded_categories
@ -139,6 +142,8 @@ class InterrogateModels:
def rank(self, image_features, text_array, top_count=1): def rank(self, image_features, text_array, top_count=1):
import clip import clip
devices.torch_gc()
if shared.opts.interrogate_clip_dict_limit != 0: if shared.opts.interrogate_clip_dict_limit != 0:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)] text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]

View File

@ -424,6 +424,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"), "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
"interrogate_clip_categories": OptionInfo(modules.interrogate.category_types, "CLIP: select which categories to inquire", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types}),
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"), "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),