Spaces:
Sleeping
Sleeping
import streamlit as st | |
from hub_name import LORA_HUB_NAMES | |
from random import shuffle | |
import pandas as pd | |
import streamlit as st | |
import contextlib | |
from functools import wraps | |
from io import StringIO | |
import contextlib | |
import redirect as rd | |
import torch | |
import shutil | |
import os | |
# Stylesheet injected once at import time so that every data frame rendered
# by the app stretches to the full width of its container.
css = (
    "\n"
    "<style>\n"
    ".stDataFrame { width: 100% !important; }\n"
    "</style>\n"
)
st.markdown(css, unsafe_allow_html=True)
def main():
    """Render the LoraHub Streamlit demo page.

    Sidebar: lists every entry of ``LORA_HUB_NAMES`` in a read-only table and
    offers a multiselect (session-state key ``"select_names"``) holding the
    *indices* of the modules chosen as composition candidates, plus a button
    that picks up to 20 random modules.

    Main area: two text areas for few-shot inputs/outputs (one example per
    line), a slider for the maximum iteration step, and a start button. When
    learning finishes, the recommended per-module weights are shown in a
    table, the composed LoRA object is saved to ``lora/adapter_model.bin``,
    the ``lora`` directory is zipped into ``lora_module.zip`` and offered as
    a download.
    """
    st.title("LoraHub")
    st.markdown("Low-rank adaptations (LoRA) are techniques for fine-tuning large language models on new tasks. We propose LoraHub, a framework that allows composing multiple LoRA modules trained on different tasks. The goal is to achieve good performance on unseen tasks using just a few examples, without needing extra parameters or training. And we want to build a marketplace where users can share their trained LoRA modules, thereby facilitating the application of these modules to new tasks.")
    st.markdown("In this demo, you will use available lora modules selected in the left sidebar to tackle your few-shot examples. When the LoraHub learning is done, you can download the final LoRA module and use it for your new task. You can check out more details in our [paper](https://huggingface.co/papers/2307.13269).")

    with st.sidebar:
        st.title("LoRA Module Pool")
        st.markdown(
            "The following modules are available for you to compose for your new task. Every module name is a peft repository in Huggingface Hub, and you can find them [here](https://huggingface.co/models?search=lorahub).")
        module_table = pd.DataFrame({
            "Index": list(range(len(LORA_HUB_NAMES))),
            "Module Name": LORA_HUB_NAMES,
        })
        # The table is informational only. The names in `disabled` must match
        # the DataFrame's column names exactly; otherwise that column stays editable.
        st.data_editor(module_table,
                       disabled=["Module Name", "Index"],
                       hide_index=True)
        # The selection is stored as module *indices* under "select_names".
        st.multiselect(
            'Select your favorite modules as the candidate for LoRA composition',
            list(range(len(LORA_HUB_NAMES))),
            [],
            key="select_names")

        def set_lucky_modules():
            """Replace the current selection with up to 20 random module indices."""
            names = list(range(len(LORA_HUB_NAMES)))
            shuffle(names)
            st.session_state["select_names"] = names[:20]

        # Run as an on_click callback so the session state is updated before the
        # script reruns and the multiselect is redrawn.
        st.button(":game_die: Give 20 Lucky Modules",
                  on_click=set_lucky_modules)
        # Resolve the selected indices to module names once; the selection does
        # not change for the rest of this script run.
        selected_modules = [LORA_HUB_NAMES[i]
                            for i in st.session_state["select_names"]]
        st.write('We will use the following modules', selected_modules)

    st.subheader("Prepare your few-shot examples")
    txt_input = st.text_area('Examples Inputs (One Line One Input)',
                             '''
Infer the date from context. Q: Today, 8/3/1997, is a day that we will never forget. What is the date one week ago from today in MM/DD/YYYY? Options: (A) 03/27/1998 (B) 09/02/1997 (C) 07/27/1997 (D) 06/29/1997 (E) 07/27/1973 (F) 12/27/1997 A:
Infer the date from context. Q: May 6, 1992 is like yesterday to Jane, but that is actually ten years ago. What is the date tomorrow in MM/DD/YYYY? Options: (A) 04/16/2002 (B) 04/07/2003 (C) 05/07/2036 (D) 05/28/2002 (E) 05/07/2002 A:
Infer the date from context. Q: Today is the second day of the third month of 1966. What is the date one week ago from today in MM/DD/YYYY? Options: (A) 02/26/1966 (B) 01/13/1966 (C) 02/02/1966 (D) 10/23/1966 (E) 02/23/1968 (F) 02/23/1966 A:
'''.strip())
    txt_output = st.text_area('Examples Outputs (One Line One Output)', '''
(C)
(E)
(F)
'''.strip())
    max_step = st.slider('Maximum iteration step', 10, 1000, step=10)
    buffer = st.expander("Learning Logs")

    if st.button(':rocket: Start!'):
        if not selected_modules:
            st.error("Please select at least 1 module!")
        elif max_step < len(selected_modules):
            st.error(
                "Please specify a larger maximum iteration step than the number of selected modules!")
        else:
            buffer.text("* begin to perform lorahub learning *")
            # Imported here rather than at module top. NOTE(review): presumably
            # this defers loading the heavy learning code until it is needed; confirm.
            from util import lorahub_learning
            # Send stderr produced during learning into the "Learning Logs" expander.
            with rd.stderr(to=buffer):
                recommendation, final_lora = lorahub_learning(
                    selected_modules, txt_input, txt_output,
                    max_inference_step=max_step)
            st.success("Lorahub learning finished! You got the following recommendation:")
            result_table = {
                "modules": selected_modules,
                "weights": recommendation.value,
            }
            st.table(result_table)
            # Save the composed module. Create the target directory first so
            # torch.save does not fail when "lora/" is missing.
            os.makedirs("lora", exist_ok=True)
            torch.save(final_lora, "lora/adapter_model.bin")
            # Archive the whole "lora" directory as lora_module.zip.
            shutil.make_archive("lora_module", 'zip', "lora")
            with open("lora_module.zip", "rb") as fp:
                st.download_button(
                    label="Download ZIP",
                    data=fp,
                    file_name="lora_module.zip",
                    mime="application/zip"
                )
if __name__ == "__main__":
    # Build the page only when this file runs as the main script,
    # not when it is imported as a module.
    main()