# Hugging Face Spaces status (captured with this source): Runtime error
import streamlit as st
import torch
from transformers import PreTrainedTokenizerFast, BartForConditionalGeneration

# Hub id / path of the KoBART checkpoint used to generate subheadings.
model_name = 'jian1114/jian_KoBART_title'

# Streamlit re-executes this whole script on every widget interaction, so a
# plain module-level `from_pretrained` call would reload the model on each
# click. Cache the loaded objects across reruns with `st.cache_resource` when
# the installed Streamlit provides it; older releases fall back to no caching.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_tokenizer_and_model(name):
    """Load the fast tokenizer and BART seq2seq model for ``name``."""
    loaded_tokenizer = PreTrainedTokenizerFast.from_pretrained(name)
    loaded_model = BartForConditionalGeneration.from_pretrained(name)
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_tokenizer_and_model(model_name)
# Markup-like fragments (an HTML `em class` attribute and colour class names).
# If any of them appears in the generated text, the result is rejected.
_MARKUP_MARKERS = ("em class", "violet_text", "green_text", "red_text", "blue_text")

# User-facing failure messages (Korean). English meaning:
#   "Subheading generation failed: more detailed content is needed."
_MSG_NEED_MORE_DETAIL = "😢소제목 생성 실패: 더 자세한 내용이 필요합니다."
#   "Subheading generation failed: please correct the grammar and try again."
_MSG_FIX_GRAMMAR = "😢소제목 생성 실패: 문법 교정 후 다시 시도해 보세요."


def _postprocess_subheading(subheading):
    """Map the raw decoded model output to the string shown to the user."""
    # "O" / "OO" placeholder outputs are treated as "not enough content".
    if subheading == "O" or "OO" in subheading:
        return _MSG_NEED_MORE_DETAIL
    if any(marker in subheading for marker in _MARKUP_MARKERS):
        return _MSG_FIX_GRAMMAR
    return subheading


def process_paragraph(paragraph, *, max_input_length=1024, max_output_length=32, num_beams=10):
    """Generate a subheading for ``paragraph`` using the module-level model.

    Args:
        paragraph: Input text. It is tokenized and truncated to
            ``max_input_length`` tokens.
        max_input_length: Maximum number of input tokens kept (default 1024).
        max_output_length: ``max_length`` passed to ``model.generate``
            (default 32).
        num_beams: Beam-search width passed to ``model.generate`` (default 10).

    Returns:
        The decoded subheading, or a Korean failure message when:
        - the input produces no tokens;
        - the output is an "O"/"OO" placeholder;
        - the output contains markup-like fragments.
    """
    # truncation=True makes the cut to max_input_length explicit instead of
    # relying on transformers' implicit-truncation fallback (which warns).
    input_ids_list = tokenizer.encode(paragraph, max_length=max_input_length, truncation=True)
    if not input_ids_list:
        # torch.tensor([[]]) would be an empty *float* tensor, not valid ids.
        return _MSG_NEED_MORE_DETAIL
    # Batch of one sequence; token ids must be integer (long) dtype.
    input_ids = torch.tensor([input_ids_list], dtype=torch.long)
    output = model.generate(
        input_ids,
        max_length=max_output_length,
        num_beams=num_beams,
        early_stopping=True,
    )
    subheading = tokenizer.decode(output[0], skip_special_tokens=True)
    return _postprocess_subheading(subheading)
def main():
    """Render the Streamlit UI and show a generated subheading on request."""
    # Injected CSS that makes the paragraph text area taller.
    css = """
<style>
textarea {
height: 300px;
}
</style>
"""
    st.markdown(css, unsafe_allow_html=True)
    st.title("Subheading Generator")
    user_input = st.text_area("Enter a paragraph: ")
    if st.button("Generate"):
        # Whitespace-only text passes a plain truthiness check but gives the
        # model nothing to work with, so it gets the same warning as empty input.
        if user_input and user_input.strip():
            with st.spinner('Generating...'):
                result = process_paragraph(user_input)
            st.write(f'Subheading: {result}')
        else:
            st.warning('Please enter a paragraph.')
# Entry point: build the UI when this file is executed as a script
# (e.g. via `streamlit run`), but not when it is imported as a module.
if __name__ == "__main__":
    main()