diff --git a/scripts/img2txt.py b/scripts/img2txt.py index 25b6378..778ffe7 100644 --- a/scripts/img2txt.py +++ b/scripts/img2txt.py @@ -329,6 +329,24 @@ def img2txt(): models.append('ViT-H-14') if st.session_state["ViT-g-14"]: models.append('ViT-g-14') + + if st.session_state["ViTB32"]: + models.append('ViT-B/32') + if st.session_state['ViTB16']: + models.append('ViT-B/16') + + if st.session_state["ViTL14_336px"]: + models.append('ViT-L/14@336px') + if st.session_state["RN101"]: + models.append('RN101') + if st.session_state["RN50"]: + models.append('RN50') + if st.session_state["RN50x4"]: + models.append('RN50x4') + if st.session_state["RN50x16"]: + models.append('RN50x16') + if st.session_state["RN50x64"]: + models.append('RN50x64') # if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'): #image = Image.open(requests.get(image_path_or_url, stream=True).raw).convert('RGB') @@ -371,6 +389,20 @@ def layout(): st.session_state["ViT-H-14"] = st.checkbox("ViT-H-14", value=False, help="ViT-H-14 model.") st.session_state["ViT-g-14"] = st.checkbox("ViT-g-14", value=False, help="ViT-g-14 model.") + + + with st.expander("Others"): + st.info("For DiscoDiffusion and JAX enable all the same models here as you intend to use when generating your images.") + + st.session_state["ViTL14_336px"] = st.checkbox("ViTL14_336px", value=False, help="ViTL14_336px model.") + st.session_state["ViTB16"] = st.checkbox("ViTB16", value=False, help="ViTB16 model.") + st.session_state["ViTB32"] = st.checkbox("ViTB32", value=False, help="ViTB32 model.") + st.session_state["RN50"] = st.checkbox("RN50", value=False, help="RN50 model.") + st.session_state["RN50x4"] = st.checkbox("RN50x4", value=False, help="RN50x4 model.") + st.session_state["RN50x16"] = st.checkbox("RN50x16", value=False, help="RN50x16 model.") + st.session_state["RN50x64"] = st.checkbox("RN50x64", value=False, help="RN50x64 model.") + st.session_state["RN101"] = st.checkbox("RN101", value=False, help="RN101 model.") + # # st.subheader("Logs:")