Spaces:
Running
Running
| import json | |
| from huggingnft.lightweight_gan.train import timestamped_filename | |
| from streamlit_option_menu import option_menu | |
| from huggingface_hub import hf_hub_download, file_download | |
| from huggingface_hub.hf_api import HfApi | |
| import streamlit as st | |
| from huggingnft.lightweight_gan.lightweight_gan import Generator, LightweightGAN, evaluate_in_chunks, Trainer | |
| from accelerate import Accelerator | |
# Hub client used to enumerate the published models of the "huggingnft" account.
hfapi = HfApi()
# Repo names of every model authored by "huggingnft", with the owner prefix
# stripped (everything up to and including the first "/").
model_names = [
    info.modelId[info.modelId.index("/") + 1:]
    for info in hfapi.list_models(author="huggingnft")
]

# Copy rendered under each page title (mostly placeholders for now).
ABOUT_TEXT = "🤗 Hugging NFT - Generate NFT by OpenSea collection name."
CONTACT_TEXT = "Here is some contact info"
GENERATE_IMAGE_TEXT = "Text about generation"
INTERPOLATION_TEXT = "Text about Interpolation"
COLLECTION2COLLECTION_TEXT = "Text about Collection2Collection"
# Models whose name contains any of these substrings are hidden from every page.
STOPWORDS = ["-old"]
# Models whose name contains any of these substrings are listed only on the
# "Collection2Collection" page and excluded from the single-collection pages.
COLLECTION2COLLECTION_KEYS = ["2"]
def load_lightweight_model(model_name):
    """Build a lightweight-GAN ``Trainer`` for a Hub repo and prepare it for inference.

    Downloads the repo's ``config.json``, constructs a ``Trainer`` from it,
    calls ``Trainer.load(use_cpu=True)`` and attaches a fresh ``Accelerator``.

    Args:
        model_name: Full Hub repo id of the form ``"<organization>/<name>"``,
            e.g. ``"huggingnft/<collection>"``.

    Returns:
        The prepared ``Trainer`` instance.

    Raises:
        ValueError: if ``model_name`` does not split into exactly two parts on "/".
    """
    file_path = file_download.hf_hub_download(
        repo_id=model_name,
        filename="config.json",
    )
    # Close the handle deterministically (the old bare open() leaked it) and
    # decode as UTF-8, the JSON interchange encoding, regardless of platform locale.
    with open(file_path, encoding="utf-8") as config_file:
        config = json.load(config_file)
    organization_name, name = model_name.split("/")
    model = Trainer(**config, organization_name=organization_name, name=name)
    # NOTE(review): presumably loads the checkpoint for organization_name/name
    # onto the CPU — confirm in huggingnft.lightweight_gan.Trainer.load.
    model.load(use_cpu=True)
    # NOTE(review): attached after loading; presumably the generate_* methods
    # rely on it — confirm in huggingnft.lightweight_gan.
    model.accelerator = Accelerator()
    return model
def clean_models(model_names, stopwords):
    """Return the model names that contain none of the given substrings.

    Args:
        model_names: Iterable of model-name strings.
        stopwords: Iterable of substrings; a name containing at least one of
            them is dropped. Any iterable (including a generator) is accepted.

    Returns:
        A new list of the surviving names, in their original order.
    """
    # Materialize once: the original loop re-iterated ``stopwords`` per name,
    # so a one-shot iterator was silently exhausted after the first names.
    stopwords = tuple(stopwords)
    return [
        name
        for name in model_names
        if not any(stopword in name for stopword in stopwords)
    ]
# Hide every model whose name contains a STOPWORDS entry (e.g. "-old") from all pages.
model_names = clean_models(model_names, stopwords=STOPWORDS)
with st.sidebar:
    # Page selector. ``icons`` / ``menu_icon`` take Bootstrap Icons names
    # (hyphenated, no "bi bi-" prefix); the previous values
    # ('camera fill', 'bi bi-youtube', 'person lines fill') were not valid names.
    choose = option_menu(
        "Hugging NFT",
        ["About", "Generate image", "Interpolation", "Collection2Collection", "Contact"],
        icons=["house", "camera-fill", "youtube", "book", "person-lines-fill"],
        menu_icon="app-indicator",
        default_index=0,
    )
# Project links and GitHub/Twitter badges under the menu. The HTML lines are
# kept unindented so Markdown does not turn them into a code block.
st.sidebar.markdown(
    """
<style>
.aligncenter {
text-align: center;
}
</style>
<p style='text-align: center'>
<a href="https://github.com/AlekseyKorshuk/huggingnft" target="_blank">Project Repository</a>
</p>
<p class="aligncenter">
<a href="https://github.com/AlekseyKorshuk/huggingnft" target="_blank">
<img src="https://img.shields.io/github/stars/AlekseyKorshuk/huggingnft?style=social"/>
</a>
</p>
<p class="aligncenter">
<a href="https://twitter.com/alekseykorshuk" target="_blank">
<img src="https://img.shields.io/twitter/follow/alekseykorshuk?style=social"/>
</a>
</p>
""",
    unsafe_allow_html=True,
)
# Static text pages: the page title followed by its copy.
if choose in ("About", "Contact"):
    st.title(choose)
    st.markdown(ABOUT_TEXT if choose == "About" else CONTACT_TEXT)
if choose == "Generate image":
    st.title(choose)
    st.markdown(GENERATE_IMAGE_TEXT)
    # Collection2Collection models have their own page, so they are excluded here.
    single_collection_models = clean_models(model_names, COLLECTION2COLLECTION_KEYS)
    model_name = st.selectbox("Choose model:", single_collection_models)
    generation_type = st.selectbox("Select generation type:", ["default", "ema"])
    # Forwarded to generate_app as ``nrow``.
    nrows = st.number_input(
        "Number of rows:",
        min_value=1,
        max_value=10,
        step=1,
        value=8,
    )
    generate_image_button = st.button("Generate")
    if generate_image_button:
        with st.spinner(text="Downloading selected model..."):
            model = load_lightweight_model(f"huggingnft/{model_name}")
        with st.spinner(text="Generating..."):
            # NOTE(review): checkpoint=-1 presumably selects the latest
            # checkpoint — confirm in Trainer.generate_app.
            generated = model.generate_app(
                num=timestamped_filename(),
                nrow=nrows,
                checkpoint=-1,
                types=generation_type,
            )
            st.image(generated)
if choose == "Interpolation":
    st.title(choose)
    st.markdown(INTERPOLATION_TEXT)
    # Collection2Collection models have their own page, so they are excluded here.
    interpolatable_models = clean_models(model_names, COLLECTION2COLLECTION_KEYS)
    model_name = st.selectbox("Choose model:", interpolatable_models)
    nrows = st.number_input("Number of rows:", min_value=1, max_value=10, step=1, value=1)
    num_steps = st.number_input("Number of steps:", min_value=1, max_value=1000, step=1, value=100)
    generate_image_button = st.button("Generate")
    if generate_image_button:
        with st.spinner(text="Downloading selected model..."):
            model = load_lightweight_model(f"huggingnft/{model_name}")
        # The bar is handed to generate_interpolation, presumably so it can report progress.
        progress_bar = st.progress(0)
        interpolation = model.generate_interpolation(
            num=timestamped_filename(),
            num_image_tiles=nrows,
            num_steps=num_steps,
            save_frames=False,
            progress_bar=progress_bar,
        )
        # Remove the bar once generation has returned.
        progress_bar.empty()
        with st.spinner(text="Uploading result..."):
            st.image(interpolation)
if choose == "Collection2Collection":
    st.title(choose)
    # This page's own copy (previously INTERPOLATION_TEXT was shown by mistake).
    st.markdown(COLLECTION2COLLECTION_TEXT)
    # Collection2Collection models are those that clean_models would drop for
    # COLLECTION2COLLECTION_KEYS. Build an ordered, de-duplicated list instead of a
    # bare set difference: st.selectbox documents a sequence for ``options``, and
    # set order is arbitrary and can change between server restarts.
    single_collection_models = set(clean_models(model_names, COLLECTION2COLLECTION_KEYS))
    c2c_models = list(
        dict.fromkeys(name for name in model_names if name not in single_collection_models)
    )
    model_name = st.selectbox("Choose model:", c2c_models)
    generate_image_button = st.button("Generate")
    if generate_image_button:
        st.markdown("generating Collection2Collection")