import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
from threading import Thread
import os
import json
import uuid
from datasets import Dataset
from huggingface_hub import login
import time

# Extra dependency beyond transformers/datasets: gradio_modal (pip install gradio-modal)
from gradio_modal import Modal
# Model setup
checkpoint = "WillHeld/soft-raccoon"
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

# Constants for the feedback dataset
DATASET_REPO = "WillHeld/model-feedback"  # Replace with your username
DATASET_PATH = "./feedback_data"          # Local path to store feedback
DATASET_FILENAME = "feedback.jsonl"       # Filename for feedback data

# Ensure the feedback directory exists
os.makedirs(DATASET_PATH, exist_ok=True)
# Feedback storage functions
def save_feedback_locally(conversation, satisfaction, feedback_text):
    """Save feedback to a local JSONL file."""
    # Create a unique ID for this feedback entry
    feedback_id = str(uuid.uuid4())

    # Create a timestamp
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

    # Prepare the feedback record
    feedback_data = {
        "id": feedback_id,
        "timestamp": timestamp,
        "conversation": conversation,
        "satisfaction": satisfaction,
        "feedback": feedback_text,
    }

    # Append to the local JSONL file
    feedback_file = os.path.join(DATASET_PATH, DATASET_FILENAME)
    with open(feedback_file, "a") as f:
        f.write(json.dumps(feedback_data) + "\n")

    return feedback_id
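
# Each line of feedback.jsonl is one JSON object; illustrative values only:
# {"id": "<uuid4>", "timestamp": "2025-01-01 12:00:00",
#  "conversation": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}],
#  "satisfaction": "Neutral", "feedback": "..."}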

def push_feedback_to_hub(hf_token=None):
    """Push the local feedback data to the Hugging Face Hub as a dataset."""
    # Fall back to the HF_TOKEN environment variable if no token was passed
    if hf_token is None:
        hf_token = os.environ.get("HF_TOKEN")
        if hf_token is None:
            print("No Hugging Face token provided. Cannot push to Hub.")
            return False

    try:
        # Log in to Hugging Face
        login(token=hf_token)

        # Check that we have data to push
        feedback_file = os.path.join(DATASET_PATH, DATASET_FILENAME)
        if not os.path.exists(feedback_file):
            print("No feedback data to push.")
            return False

        # Load records from the JSONL file
        with open(feedback_file, "r") as f:
            feedback_data = [json.loads(line) for line in f]

        # Build a dataset from the feedback records and push it
        dataset = Dataset.from_list(feedback_data)
        dataset.push_to_hub(
            DATASET_REPO,
            private=True,  # Set to False if you want the dataset to be public
        )

        print(f"Feedback data pushed to {DATASET_REPO} successfully.")
        return True
    except Exception as e:
        print(f"Error pushing feedback data to Hub: {e}")
        return False
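
# Illustrative (not executed) sketch of exercising the two functions above outside
# the UI; the conversation uses the same "messages" format as the chat below:
# save_feedback_locally(
#     [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}],
#     "Satisfied",
#     "Streaming felt responsive.",
# )
# push_feedback_to_hub()  # picks up HF_TOKEN from the environment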

# Modified predict function that updates the conversation history as it streams
@spaces.GPU  # Assumption: ZeroGPU Spaces require this decorator on GPU code, implied by the `spaces` import
def predict(message, history, temperature, top_p):
    # Record the user message in the history
    history.append({"role": "user", "content": message})

    input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Stream tokens back as they are generated
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Set up generation parameters
    generation_kwargs = {
        "input_ids": inputs,
        "max_new_tokens": 1024,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "do_sample": True,
        "streamer": streamer,
    }

    # Run generation in a background thread so we can consume the streamer here
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield the accumulated text as tokens arrive
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

    # After generation completes, record the assistant's response in the history
    history.append({"role": "assistant", "content": partial_text})
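
# Sketch of driving predict() directly (assumes a GPU is available in this process):
# history = []
# for partial in predict("Tell me a joke.", history, temperature=0.7, top_p=0.9):
#     pass  # `partial` holds the response accumulated so far
# print(history[-1]["content"])  # the completed assistant reply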

# Handle the research feedback submission
def submit_research_feedback(conversation_state, satisfaction, feedback_text):
    """Save user feedback both locally and to the Hugging Face Hub."""
    # Save locally first
    feedback_id = save_feedback_locally(conversation_state, satisfaction, feedback_text)

    # Push using the token from the environment
    env_token = os.environ.get("HF_TOKEN")
    push_success = push_feedback_to_hub(env_token)

    if push_success:
        status_msg = "Thank you for your valuable feedback! Your insights have been saved to the dataset."
    else:
        status_msg = "Thank you for your feedback! It was saved locally but could not be pushed to the dataset; please check the server logs."

    return status_msg

# Create the Gradio Blocks interface
with gr.Blocks() as demo:
    # State tracking the conversation history across events
    conversation_state = gr.State([])

    with gr.Row():
        with gr.Column(scale=3):
            # Wrapper around predict() that mirrors the chat history into our state
            def chat_with_state(message, history, state, temperature, top_p):
                for partial_response in predict(message, history, temperature, top_p):
                    # Update our state with each yield
                    state = history.copy()
                    yield partial_response, state
                # One final yield so the state includes the assistant's completed reply
                state = history.copy()
                yield partial_response, state

            # Create the ChatInterface
            chatbot = gr.ChatInterface(
                chat_with_state,
                additional_inputs=[
                    conversation_state,
                    gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
                    gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P"),
                ],
                additional_outputs=[conversation_state],
                type="messages",
            )
        with gr.Column(scale=1):
            report_button = gr.Button("Share Feedback", variant="primary")

    # Modal holding the feedback form
    with Modal(visible=False) as feedback_modal:
        with gr.Column():
            gr.Markdown("## Research Preview Feedback")
            gr.Markdown("Thank you for testing our research model. Your feedback (positive or negative) helps us improve!")

            satisfaction = gr.Radio(
                ["Very satisfied", "Satisfied", "Neutral", "Unsatisfied", "Very unsatisfied"],
                label="How would you rate your experience with this research model?",
                value="Neutral",
            )
            feedback_text = gr.Textbox(
                lines=5,
                label="Share your observations (strengths, weaknesses, suggestions):",
                placeholder="We welcome both positive feedback and constructive criticism to help improve this research prototype...",
            )
            submit_button = gr.Button("Submit Research Feedback", variant="primary")
            response_text = gr.Textbox(label="Status", interactive=False)
# Connect the "Share Feedback" button to show the modal | |
report_button.click( | |
lambda: Modal(visible=True), | |
None, | |
feedback_modal | |
) | |
# Connect the submit button to the submit_research_feedback function with the current conversation state | |
submit_button.click( | |
submit_research_feedback, | |
inputs=[conversation_state, satisfaction, feedback_text], | |
outputs=response_text | |
) | |
# Launch the demo | |
demo.launch() |