2022-09-21 06:50:29 +03:00
|
|
|
import torch
|
|
|
|
import tkinter as tk
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Main window and header
# ---------------------------------------------------------------------------
window = tk.Tk()
window.title(string="Model Merger")

# Header labels are packed straight into the root window.
tk.Label(master=window, text="Model Merger", font=("Arial", 25)).pack()
tk.Label(master=window, text="GUI by antrobot1234").pack()

# ---------------------------------------------------------------------------
# Row containers: one frame per input row, one for the slider, one for the
# run button. They are only created here; they are packed further down.
# ---------------------------------------------------------------------------
frame1 = tk.Frame(master=window)
frame2 = tk.Frame(master=window)
frame3 = tk.Frame(master=window)
frameSlider = tk.Frame(master=window)
frameButton = tk.Frame(master=window)

# ---------------------------------------------------------------------------
# Path entry rows: a caption on the left followed by a text entry.
# ---------------------------------------------------------------------------
tk.Label(master=frame1, text="File 1:").pack(side="left")
file1text = tk.Entry(master=frame1, width=40)
file1text.pack(side="left")

tk.Label(master=frame2, text="File 2:").pack(side="left")
file2text = tk.Entry(master=frame2, width=40)
file2text.pack(side="left")

tk.Label(master=frame3, text="File Out:").pack(side="left")
fileOtext = tk.Entry(master=frame3, width=38)
fileOtext.pack(side="left")

# ---------------------------------------------------------------------------
# Weight slider: 0..100, read later as the percentage weight of file 1.
# ---------------------------------------------------------------------------
tk.Label(master=frameSlider, text="Weight of file 1").pack(side="left")
scale = tk.Scale(
    master=frameSlider,
    from_=0,
    to=100,
    orient="horizontal",
    tickinterval=10,
    length=450,
)
scale.pack(side="left")

# ---------------------------------------------------------------------------
# Run button (created here, packed and bound after the handlers are defined).
# ---------------------------------------------------------------------------
goButton = tk.Button(
    master=frameButton,
    text="RUN",
    height=2,
    width=20,
    bg="green",
)
|
2022-09-21 06:50:29 +03:00
|
|
|
|
2023-06-23 05:58:20 +03:00
|
|
|
|
|
|
|
def _ensure_ckpt_suffix(path):
    """Return *path* with ".ckpt" appended unless it already ends with it."""
    return path if path.endswith(".ckpt") else path + ".ckpt"


def merge(file1, file2, out, a):
    """Blend two checkpoints and write the result to *out*.

    Every entry of the first checkpoint's ``"state_dict"`` whose key contains
    ``"model"`` and also exists in the second checkpoint is replaced by the
    weighted sum ``alpha * first + (1 - alpha) * second``. Entries whose key
    contains ``"model"`` and exist only in the second checkpoint are copied
    over unchanged. The first checkpoint object (with its updated state dict)
    is then saved.

    Args:
        file1: Path of the first checkpoint; ".ckpt" is appended if missing.
        file2: Path of the second checkpoint; ".ckpt" is appended if missing.
        out: Path of the merged checkpoint; ".ckpt" is appended if missing.
        a: Weight of the first checkpoint as a percentage (the GUI slider
            supplies 0..100); it is divided by 100 to obtain ``alpha``.

    Side effects:
        Updates ``goButton`` and forces a ``window.update()`` between the two
        passes so the GUI shows the "STAGE 2" status.
    """
    alpha = a / 100

    file1 = _ensure_ckpt_suffix(file1)
    file2 = _ensure_ckpt_suffix(file2)
    out = _ensure_ckpt_suffix(out)

    # Load Models
    # map_location="cpu" keeps the merge independent of the device the
    # checkpoints were saved from: it works without CUDA and uses no VRAM.
    # NOTE: torch.load unpickles the file, so only open trusted checkpoints.
    model_0 = torch.load(file1, map_location="cpu")
    model_1 = torch.load(file2, map_location="cpu")
    theta_0 = model_0["state_dict"]
    theta_1 = model_1["state_dict"]

    # Stage 1: interpolate the weights both checkpoints share. Only existing
    # keys are reassigned, so iterating theta_0 while writing to it is safe.
    for key in theta_0:
        if "model" in key and key in theta_1:
            theta_0[key] = alpha * theta_0[key] + (1 - alpha) * theta_1[key]

    goButton.config(bg="red", text="RUNNING...\n(STAGE 2)")
    window.update()

    # Stage 2: carry over weights that only the second checkpoint provides.
    for key in theta_1:
        if "model" in key and key not in theta_0:
            theta_0[key] = theta_1[key]

    torch.save(model_0, out)
|
2023-06-23 05:58:20 +03:00
|
|
|
|
2022-09-21 06:50:29 +03:00
|
|
|
|
|
|
|
def handleClick(event):
    """Run a merge using the current GUI inputs when the run button is clicked.

    The button is switched to its red "RUNNING" state, the window is redrawn
    so the change is visible, and ``merge`` is called with the two input
    paths, the output path and the slider value.

    Args:
        event: The tkinter event delivered by the ``<Button-1>`` binding
            (unused).
    """
    goButton.config(bg="red", text="RUNNING...\n(STAGE 1)")
    # Force a redraw now: merge() blocks the event loop until it finishes.
    window.update()
    try:
        merge(file1text.get(), file2text.get(), fileOtext.get(), scale.get())
    finally:
        # Always restore the idle look, even when merge() raises (missing
        # file, malformed checkpoint, ...), so the button is never left
        # stuck in the "RUNNING" state. The exception still propagates.
        goButton.config(bg="green", text="RUN")
|
|
|
|
|
|
|
|
|
2022-09-21 06:50:29 +03:00
|
|
|
# Stack the rows top-to-bottom in the root window: the three path rows,
# then the weight slider, then the run button's container.
frame1.pack()
frame2.pack()
frame3.pack()
frameSlider.pack()
frameButton.pack()

# Place the run button inside its container and start merging on left click.
goButton.pack()
goButton.bind("<Button-1>", handleClick)

# Hand control to tkinter; this blocks until the window is closed.
window.mainloop()
|