# test/app.py
# Author: ngrigg — commit 2d0dc09 ("updated"), 2.59 kB, malware scan: no virus
import streamlit as st
import pandas as pd
import asyncio
from llama_models import process_text
from dotenv import load_dotenv
import os
# Load environment variables (such as HUGGINGFACE_API_KEY) from a local .env file.
load_dotenv()

# Check that the API key is available, but never echo the secret itself:
# anything printed here ends up in the app's logs.
api_key = os.getenv("HUGGINGFACE_API_KEY")
if api_key:
    print("Hugging Face API key loaded.")
else:
    print("Warning: HUGGINGFACE_API_KEY is not set.")
async def process_csv(file, sample_size=5, model_name="instruction-pretrain/finance-Llama3-8B", processor=None):
    """Run the model over the first ``sample_size`` descriptions of a header-less CSV.

    The CSV is read with ``header=None`` and its first column is treated as
    the list of descriptions. Only the first ``sample_size`` rows are sent to
    the model. Every remaining row gets an empty-string prediction so that the
    new ``predictions`` column lines up with all rows of the DataFrame.

    Args:
        file: Anything ``pandas.read_csv`` accepts (path or file-like object,
            e.g. the object returned by ``st.file_uploader``).
        sample_size: Maximum number of rows to send to the model. Values
            below 0 are treated as 0.
        model_name: Model identifier forwarded to ``processor``.
        processor: Async callable ``(model_name, text) -> result``. Defaults
            to ``process_text`` from ``llama_models``.

    Returns:
        The DataFrame read from ``file`` with an added ``predictions`` column.
        All column labels are converted to strings.

    Raises:
        pandas.errors.EmptyDataError: If the CSV has no content (raised by
            ``pd.read_csv``).
    """
    if processor is None:
        # Resolved at call time so the default is the module-level import.
        processor = process_text

    print("Reading CSV file...")
    df = pd.read_csv(file, header=None)  # no header row: column labels are 0, 1, ...
    print("CSV file read successfully.")

    descriptions = df[0].tolist()
    # Clamp to [0, len] so a negative value cannot turn the slice into [:-k].
    n_samples = max(0, min(sample_size, len(descriptions)))

    print(f"Model name: {model_name}")
    print(f"Processing {n_samples} descriptions out of {len(descriptions)} total descriptions.")

    results = []
    for i, desc in enumerate(descriptions[:n_samples], start=1):
        if pd.isna(desc):
            # Empty cells come back from pandas as NaN; there is nothing to send to the model.
            print(f"Description {i} is empty; skipping.")
            results.append('')
            continue
        print(f"Processing description {i}/{n_samples}...")
        # str(): read_csv may infer a numeric dtype for the column.
        result = await processor(model_name, str(desc))
        # str() guards against processors that return None or non-string results.
        print(f"Description {i} processed. Result: {str(result)[:50]}...")
        results.append(result)

    # Pad unprocessed rows so the column length matches the DataFrame.
    results.extend([''] * (len(descriptions) - n_samples))

    print("Assigning results to DataFrame...")
    df['predictions'] = results
    df.columns = df.columns.astype(str)  # mixed int/str labels trigger pandas warnings
    print("Results assigned to DataFrame successfully.")
    print(df.head())
    return df
# --- Streamlit UI: upload a CSV, run the model, show and offer the predictions ---
st.title("Finance Model Deployment")
st.write("""
### Upload a CSV file with company descriptions to extract key products, geographies, and important keywords:
""")
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

if uploaded_file is not None:
    if st.button("Predict"):
        df = None
        with st.spinner("Processing..."):
            print("Starting CSV processing...")
            try:
                df = asyncio.run(process_csv(uploaded_file))
            except pd.errors.EmptyDataError:
                # pd.read_csv raises this for an upload with no content.
                st.error("The uploaded CSV file is empty.")
            except Exception as exc:
                # Top-level UI boundary: report the failure on the page instead of crashing it.
                print(f"CSV processing failed: {exc!r}")
                st.error(f"Processing failed: {exc}")

        if df is not None:
            print("CSV processing completed. Displaying results.")
            st.write(df)
            st.download_button(
                label="Download Predictions as CSV",
                data=df.to_csv(index=False).encode('utf-8'),
                file_name='predictions.csv',
                mime='text/csv'
            )
            print("Results displayed and download button created.")