"""Streamlit app that visualizes DictaBERT syntax-dependency parses.

The user enters a sentence; the ``dicta-il/dictabert-joint`` model parses it,
the dependency tree is drawn with Graphviz, and the raw parse is shown either
as a UD (CoNLL-U style) table or as JSON, depending on the chosen output style.
"""
import base64
import json

import streamlit as st
from graphviz import Digraph
from transformers import AutoModel, AutoTokenizer

MODEL_NAME = 'dicta-il/dictabert-joint'

# Column names of a CoNLL-U token line, in order.
UD_COLUMNS = ["ID", "FORM", "LEMMA", "UPOS", "XPOS", "FEATS", "HEAD", "DEPREL", "DEPS", "MISC"]


def display_tree(output):
    """Draw a dependency tree with Graphviz and show it in the Streamlit page.

    Args:
        output: Sequence of dicts, one per word in sentence order, each with
            keys ``'word'`` (surface form), ``'dep_head_idx'`` (0-based index
            of the head word, or -1 for the root) and ``'dep_func'`` (the
            dependency-relation label drawn on the arc).
    """
    if not output:
        # Nothing to draw; Graphviz would also reject a zero-width size.
        return

    # Graph width grows with the number of words.
    size = f'{len(output)},5'
    dpi = '300'
    fmt = 'svg'

    dot = Digraph(engine='dot', format=fmt)
    dot.attr('graph', rankdir='LR', rank='same', size=size, dpi=dpi)

    for i, word_info in enumerate(output):
        head_idx = word_info['dep_head_idx']
        dot.node(str(i), word_info['word'])
        # Invisible edge from each word to the previous one: with rankdir=LR
        # this fixes the word order, laying the sentence out right-to-left.
        if i > 0:
            dot.edge(str(i), str(i - 1), style='invis')
        if head_idx != -1:
            # constraint=false keeps dependency arcs from disturbing that order.
            dot.edge(str(i), str(head_idx), label=word_info['dep_func'], constraint='false')

    # Render in memory instead of to a fixed file on disk, so concurrent
    # sessions cannot overwrite each other's image.
    svg_b64 = base64.b64encode(dot.pipe(format=fmt)).decode('ascii')

    # Show the image in a horizontally scrollable container.
    st.markdown(
        '<div style="width: 100%; overflow-x: auto;">'
        f'<img src="data:image/svg+xml;base64,{svg_b64}" alt="syntax tree" style="max-width: none;"/>'
        '</div>',
        unsafe_allow_html=True,
    )


def _ud_to_tree(ud_lines):
    """Convert CoNLL-U token lines into the word dicts expected by display_tree.

    Blank lines, comment lines and multi-word token ranges (IDs like ``1-2``)
    are skipped.
    """
    tree = []
    for line in ud_lines:
        if not line.strip() or line.startswith('#'):
            continue
        parts = line.split('\t')
        if '-' in parts[0]:
            continue
        # HEAD is 1-based with 0 meaning root; shift to a 0-based index (root -> -1).
        tree.append(dict(word=parts[1], dep_head_idx=int(parts[6]) - 1, dep_func=parts[7]))
    return tree


def _ud_to_markdown(ud_output):
    """Build an RTL markdown table from UD output (two header lines, then token lines)."""
    rows = [
        '<div dir="rtl" style="text-align: right;">',
        '',  # blank line so the markdown inside the div is still parsed
        '##' + ud_output[0],
        '##' + ud_output[1],
        '| ' + ' | '.join(UD_COLUMNS) + ' |',
        '| ' + ' | '.join(['---'] * len(UD_COLUMNS)) + ' |',
    ]
    for line in ud_output[2:]:
        if not line.strip():
            continue
        # Escape '_' (markdown emphasis) and '|' (FEATS separator, which would
        # otherwise split the table cell).
        cells = line.replace('_', '\\_').replace('|', '\\|').split('\t')
        rows.append('| ' + ' | '.join(cells) + ' |')
    rows.append('')  # end the table before closing the div
    rows.append('</div>')
    return '\n'.join(rows)


@st.cache_resource
def load_model(hf_token):
    """Load tokenizer and model once per process (Streamlit reruns this script on every interaction)."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=hf_token)
    model = AutoModel.from_pretrained(MODEL_NAME, use_auth_token=hf_token, trust_remote_code=True)
    model.eval()
    return tokenizer, model


# Streamlit app title
st.title('BERT Syntax Dependency Tree Visualizer')

# Hugging Face token, expected to be configured in Streamlit secrets.
hf_token = st.secrets["HF_TOKEN"]
tokenizer, model = load_model(hf_token)

# Checkbox for the compute_mst parameter
compute_mst = st.checkbox('Compute Maximum Spanning Tree', value=True)
output_style = st.selectbox(
    'Output Style: ',
    ('JSON', 'UD', 'IAHLT_UD'), index=1).lower()

# User input
sentence = st.text_input('Enter a sentence to analyze:')

if sentence:
    # Echo the input sentence
    st.text(sentence)

    # Model prediction for a single sentence
    output = model.predict([sentence], tokenizer, compute_syntax_mst=compute_mst,
                           output_style=output_style)[0]

    if output_style in ('ud', 'iahlt_ud'):
        # UD styles return a list of lines: two header lines, then token lines.
        ud_output = output
        display_tree(_ud_to_tree(ud_output[2:]))
        st.markdown(_ud_to_markdown(ud_output), unsafe_allow_html=True)
    else:
        # JSON style: each token carries its syntax info under 'syntax'.
        display_tree([w['syntax'] for w in output['tokens']])
        # and the full json
        st.markdown("```json\n" + json.dumps(output, ensure_ascii=False, indent=2) + "\n```")