import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
big_text = """ | |
<div style='text-align: center;'> | |
<h1 style='font-size: 30x;'>Knowledge Extraction B</h1> | |
</div> | |
""" | |
st.markdown(big_text, unsafe_allow_html=True) | |
st.markdown(
    '<a href="https://ikmtechnology.github.io/ikmtechnology/questions_answers.json" target="_blank">questions and answers used to fine-tune the LLM</a>',
    unsafe_allow_html=True)
st.markdown("Sample queries for the above file: <br/> What is a pretty amazing thing to say about your life? What is one of the best teachers in all of life? What does a wise person say?", unsafe_allow_html=True)
if 'is_initialized' not in st.session_state:
    st.session_state['is_initialized'] = True
    model_name = "EleutherAI/gpt-neo-125M"
    st.session_state.model_name = "EleutherAI/gpt-neo-125M"
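    # The tokenizer comes from the base GPT-Neo checkpoint; the weights are the fine-tuned variant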
    st.session_state.tokenizer = AutoTokenizer.from_pretrained(model_name)
    st.session_state.model = AutoModelForCausalLM.from_pretrained("zmbfeng/gpt-neo-125M_untethered_100_epochs_multiple_paragraph")
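    # Use the GPU when available, otherwise fall back to the CPU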
    if torch.cuda.is_available():
        st.session_state.device = torch.device("cuda")
        print("Using GPU:", torch.cuda.get_device_name(0))
    else:
        st.session_state.device = torch.device("cpu")
        print("GPU is not available, using CPU instead.")
    st.session_state.model.to(st.session_state.device)
#prompt = "Discuss the impact of artificial intelligence on modern society." | |
#prompt = "What is one of the best teachers in all of life?" | |
#prompt = "What is the necessary awareness for deep and meaningful relationships?" | |
#prompt = "What would happen if you knew you were going to die within a week or month?" | |
#prompt = "question: What is one of the best teachers in all of life? " | |
#prompt = "question: What would happen if death were to happen in an hour, week, or year?" | |
#============= | |
#prompt = "question: What if you live life fully?" | |
#prompt = "question: What does death do to you?" | |
#============ | |
#prompt = "question: Do you understand that every minute you're on the verge of death?" | |
#most recent: | |
#prompt = "question: Are you going to wait until the last moment to let death be your teacher?" | |
temperature = st.slider("Select Temperature", min_value=0.01, max_value=2.0, value=0.01, step=0.01)
query = st.text_input("Enter your query")
if query:
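    # Prefix the query with the same "question: " tag used in the fine-tuning data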
prompt = "question: "+query | |
with st.spinner('Generating text...'): | |
input_ids = st.session_state.tokenizer(prompt, return_tensors="pt").input_ids.to(st.session_state.device) | |
        # Generate a response
        #output = st.session_state.model.generate(input_ids, max_length=2048, do_sample=True, temperature=0.01, pad_token_id=st.session_state.tokenizer.eos_token_id)  # exact result for single paragraph
        output = st.session_state.model.generate(input_ids, max_length=2048, do_sample=True, temperature=temperature,
                                                 pad_token_id=st.session_state.tokenizer.eos_token_id)  # exact result for single paragraph
        # Decode the output
        response = st.session_state.tokenizer.decode(output[0], skip_special_tokens=True)
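        # Strip the echoed prompt and the leading " answer: " tag from the decoded text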
        if response.startswith(prompt):
            response = response[len(prompt):]
        if response.startswith(" answer: "):
            response = response[len(" answer: "):]
        response = response.replace("<P>", "\n\n")
        st.write(response)