File size: 2,122 Bytes
885b434 b8774c1 9e25427 885b434 9e25427 cb34ab7 9e25427 5f17867 9e25427 b8774c1 7be036b 885b434 d48bb43 885b434 d48bb43 9e25427 885b434 d48bb43 9e25427 40c4a66 b9bec37 9e25427 885b434 9e25427 885b434 9e25427 885b434 091f6c9 943a7c3 1025e47 9e25427 885b434 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import functools

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Hub repository id of the fine-tuned Llama 3.1 classifier; passed to both loaders.
MODEL_URL = "kingabzpro/Llama-3.1-8B-Instruct-Mental-Health-Classification"

tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
# Reuse the EOS id as the padding id (presumably the tokenizer defines no pad
# token of its own -- NOTE(review): confirm against the tokenizer config).
tokenizer.pad_token_id = tokenizer.eos_token_id

# Loader options kept in one place: half-precision weights, kept on the CPU,
# loaded with reduced peak memory.
_model_load_kwargs = dict(
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map="cpu",
)
model = AutoModelForCausalLM.from_pretrained(MODEL_URL, **_model_load_kwargs)
# The four categories named in the classification prompt; used to normalise
# slightly noisy generations (e.g. "depression." -> "Depression").
LABELS = ("Normal", "Depression", "Anxiety", "Bipolar")


@functools.lru_cache(maxsize=1)
def _get_pipeline():
    """Build the text-generation pipeline once and reuse it for every request."""
    # `model` is already instantiated, so dtype/device options given to
    # `pipeline()` would have no effect; the model keeps its own settings.
    return pipeline("text-generation", model=model, tokenizer=tokenizer)


def _build_prompt(text):
    """Return the classification prompt for *text* (ends with "label:", no trailing space)."""
    return (
        "Classify the text into Normal, Depression, Anxiety, Bipolar, and return the answer "
        "as the corresponding mental health disorder label.\n"
        f"text: {text}\n"
        "label:"
    )


def _normalize_label(generated):
    """Map a raw generation onto a canonical label, or return it stripped if none matches."""
    answer = generated.strip()
    folded = answer.casefold()
    for label in LABELS:
        if folded.startswith(label.casefold()):
            return label
    return answer


def prediction(text):
    """Classify *text* into a mental-health category with the fine-tuned model.

    Args:
        text: Free-form text entered by the user.

    Returns:
        The matching entry of ``LABELS`` when the generated answer starts with
        one of them (case-insensitive), otherwise the stripped raw generation.
        ``None`` for empty or whitespace-only input, so no generation is run
        and the output component stays empty.
    """
    if text is None or not text.strip():
        return None

    outputs = _get_pipeline()(
        _build_prompt(text),
        max_new_tokens=2,
        # Greedy decoding: a classifier should give the same answer every time.
        do_sample=False,
        # Return only the newly generated tokens instead of echoing the prompt,
        # so user text containing "label: " cannot confuse parsing.
        return_full_text=False,
    )
    return _normalize_label(outputs[0]["generated_text"])
# Web UI: one free-text input, the predicted category shown as a label.
gradio_ui = gr.Interface(
    fn=prediction,
    title="Mental Health Disorder Classification",
    description=f"Input the text to generate a Mental Health Disorder.\n For this classification, the {MODEL_URL} model was used.",
    examples=[
        ['trouble sleeping, confused mind, restless heart. All out of tune'],
        ["In the quiet hours, even the shadows seem too heavy to bear."],
        ["Riding a tempest of emotions, where ecstatic highs crash into desolate lows without warning."],
    ],
    inputs=gr.Textbox(lines=10, label="Write the text here"),
    outputs=gr.Label(num_top_classes=4, label="Mental Health Disorder Category"),
    theme=gr.themes.Soft(),
    article="<p style='text-align: center'>Please read the tutorial to fine-tune the Llama 3.1 model on Mental Health Classification <a href='https://www.datacamp.com/tutorial/fine-tuning-llama-3-1' target='_blank'>https://www.datacamp.com/tutorial/fine-tuning-llama-3-1</a></p>",
)

# Start the server only when run as a script (`python app.py`), not on import.
if __name__ == "__main__":
    gradio_ui.launch()