Hot_reload_debug (#4519)

* Add a hot-reload function decorator, allowing faster debugging without reloading models and data.
This commit is contained in:
Xuan-Phi Nguyen 2022-06-28 12:06:52 -07:00 committed by GitHub
parent 58c8041c17
commit fe56de410c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -840,3 +840,104 @@ def safe_getattr(obj, k, default=None):
def safe_hasattr(obj, k):
    """Report whether attribute ``k`` is present on ``obj`` with a non-None value.

    A missing attribute and an attribute explicitly set to ``None`` are
    treated the same way: both yield ``False``.
    """
    value = getattr(obj, k, None)
    return value is not None
def hotreload_function(name=None):
    """
    Decorator factory that enables hot-reload of a function for debugging.

    It allows you to debug a function without having to reload all heavy models,
    dataset loading and preprocessing, allowing faster debugging.
    If you want to change model or dataset loading, consider relaunching your code.
    -----------------------------------
    This will run the decorated function func:
        if func runs successfully:
            It will pause, allow the user to edit code, and prompt the user to:
                Press enter to re-run the function with updated code
                Type "done" to finish the function and return its output
                Type "disable" to stop pausing this function and let code
                    continue without pause
                Ctrl + C to terminate
        if func raises an error:
            it will prompt the user to
                1. Edit code, and press enter to retry
                2. Ctrl + C to terminate
                3. Type "raise" to raise that exception
    * Requirements:
        0. Fairseq was installed with `pip install --editable .`
        1. pip install jurigged[develoop]
        2. set environment HOTRELOAD_PAUSE=1 CUDA_LAUNCH_BLOCKING=1
           (unset, empty, "0", "false", "no" and "off" all mean disabled)
        3. Run on only 1 GPU (no distributed)
    * How to use:
        1. in python, import and decorate the top-level function to be re-run
           after code edits:
            ```python
            from fairseq.utils import hotreload_function
            ....
            @hotreload_function("train_step")
            def train_step(self, sample ....):
                ....
            ....
            ```
        2. in bash run scripts:
            ```bash
            watch_dir=<home>/fairseq-py/fairseq/tasks # directory to watch for file changes
            export CUDA_VISIBLE_DEVICES=0 # single-gpu
            HOTRELOAD_PAUSE=1 CUDA_LAUNCH_BLOCKING=1 python -m jurigged -w ${watch_dir} --poll 2 -v train.py ......
            ```
    * NOTE:
        1. -w ${watch_dir} specifies all the files to be watched for changes;
            once functions, classes, ... are changed, all instances in the
            process will get updated (hot-reload)
    * Limitation:
        * Currently distributed debugging is not working
        * Need to launch train.py locally (cannot submit jobs)

    Args:
        name: optional label used in log messages and prompts; defaults to the
            decorated function's ``__name__``.

    Returns:
        A decorator. The wrapper it produces calls the function directly when
        HOTRELOAD_PAUSE is disabled (read once, at decoration time) or after
        the user typed "disable"; otherwise it runs the interactive loop
        described above.

    Raises:
        ImportError: at call time of this factory, if jurigged is not installed.
        AssertionError: from the decorator if its argument is not callable, or
            from the wrapper if pausing is enabled and the global world size
            is greater than 1.
    """
    try:
        # Imported only to verify that jurigged is installed; the actual code
        # patching is done by launching the program with `python -m jurigged`.
        import jurigged  # noqa: F401
    except ImportError:
        logger.warning("Please install jurigged: pip install jurigged[develoop]")
        raise
    from fairseq.distributed import utils as distributed_utils
    import functools
    import traceback

    def _env_flag(var):
        """Interpret environment variable ``var`` as an on/off switch."""
        # bool("0") is True, so a plain bool() over the raw string would treat
        # HOTRELOAD_PAUSE=0 as enabled; compare against explicit "off" values.
        value = os.environ.get(var, "").strip().lower()
        return value not in ("", "0", "false", "no", "off")

    def hotreload_decorator(func):
        assert callable(func), f"not callable: {func}"
        jname = name or func.__name__
        logger.info(f"jurigged-hotreload:Apply jurigged on {jname}:{func.__name__}")
        HOTRELOAD_PAUSE = _env_flag("HOTRELOAD_PAUSE")
        cublk = _env_flag("CUDA_LAUNCH_BLOCKING")
        prefix = f"HOTRELOAD:{jname}:[cublk={cublk}]"
        # Mutable cell shared by every call of the wrapper; flipped to True
        # once the user types "disable".
        hot_reload_state = {"disable": False}

        @functools.wraps(func)  # keep __name__/__doc__ of the decorated function
        def func_wrapper(*args, **kwargs):
            if not HOTRELOAD_PAUSE or hot_reload_state["disable"]:
                return func(*args, **kwargs)
            world_size = distributed_utils.get_global_world_size()
            assert (
                world_size <= 1
            ), f"HOTRELOAD_PAUSE:{jname} currently cannot do distributed training"

            while True:
                # Only the call itself is guarded, so that a failure of the
                # prompt (e.g. EOFError on a closed stdin) is not reported as
                # a failure of the decorated function.
                try:
                    output = func(*args, **kwargs)
                except Exception:
                    action = input(
                        f"{prefix}:ERROR: \n{traceback.format_exc()}\n"
                        f'Edit code to try again: enter to continue, ctrl+C to terminate, or type "raise" to raise the exception: '
                    )
                    if action.strip().lower() == "raise":
                        raise
                    continue

                end_action = input(
                    f"{prefix}: PAUSE, you may edit code now. Enter to re-run, ctrl+C to terminate, "
                    f'type "done" to continue (function still being watched), or type "disable" to stop pausing this function :'
                )
                choice = end_action.strip().lower()
                if choice in ("disable", "done"):
                    break
                logger.warning(
                    f"{prefix}: action={end_action} function will re-run now."
                )

            if choice == "disable":
                logger.warning(
                    f"{prefix}: Stop pausing {jname}. The function is still being watched and newly editted code will take effect "
                    f"if the {jname} is called again later."
                    f' "unset HOTRELOAD_PAUSE" before relaunch to disable hotreload and'
                    f" remove @hotreload_function decorator in the code."
                )
                hot_reload_state["disable"] = True
            return output

        return func_wrapper

    return hotreload_decorator