File size: 1,778 Bytes
885b434
 
 
9e25427
885b434
9e25427
cb34ab7
9e25427
5f17867
9e25427
 
 
885b434
 
 
9e25427
 
885b434
9e25427
40c4a66
b9bec37
9e25427
885b434
 
 
 
9e25427
 
885b434
9e25427
 
 
885b434
091f6c9
943a7c3
885b434
9e25427
885b434
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import gradio as gr

from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import torch

# Hugging Face Hub repo id of the fine-tuned Llama 3.1 classification checkpoint.
MODEL_URL = "kingabzpro/Llama-3.1-8B-Instruct-Mental-Health-Classification"

tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
# Reuse the EOS token as the padding token so the tokenizer can pad inputs.
# NOTE(review): presumably the checkpoint's tokenizer ships without a pad
# token (as stock Llama 3.1 does) — confirm.
tokenizer.pad_token_id = tokenizer.eos_token_id

# Load the sequence-classification model in half precision on the CPU.
# NOTE(review): float16 matmuls on CPU depend on the installed torch version;
# older releases raise "... not implemented for 'Half'" — use float32 if so.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_URL,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map="cpu",
)

# The Llama classification head locates the last real token of each sequence
# via config.pad_token_id and refuses batches larger than 1 when it is unset.
# Keep it consistent with the tokenizer, but never override a value the
# fine-tuned checkpoint already provides.
if model.config.pad_token_id is None:
    model.config.pad_token_id = tokenizer.pad_token_id

# Classification pipeline shared by all requests; built on first use.
_classifier = None


def _get_classifier():
    """Return the shared text-classification pipeline, creating it once."""
    global _classifier
    if _classifier is None:
        # The model carries a sequence-classification head, so the matching
        # pipeline task is "text-classification" (not "text-generation").
        # The model was already loaded with its dtype and device, so no
        # dtype/device arguments are repeated here.
        _classifier = pipeline("text-classification", tokenizer=tokenizer, model=model)
    return _classifier


def prediction(news):
    """Classify *news* and return per-label confidence scores for ``gr.Label``.

    Args:
        news: Free text entered by the user in the Gradio textbox.

    Returns:
        A dict mapping each label name to its confidence score (as produced
        by the pipeline). An empty dict when the input is blank.
    """
    if not news or not news.strip():
        return {}

    # top_k=None requests a score for every label instead of only the best
    # one, so gr.Label can display its top classes.
    results = _get_classifier()(news, top_k=None)

    # Defensive: some transformers versions wrap a single input's results in
    # an extra list.
    if results and isinstance(results[0], list):
        results = results[0]

    return {item["label"]: float(item["score"]) for item in results}


# Sample inputs offered under the textbox; selecting one fills it in.
_EXAMPLES = [
    ["trouble sleeping, confused mind, restless heart. All out of tune"],
    ["In the quiet hours, even the shadows seem too heavy to bear."],
    ["Riding a tempest of emotions, where ecstatic highs crash into desolate lows without warning."],
]

# Text shown beneath the title, naming the model in use.
_DESCRIPTION = (
    "Input the text to generate a Mental Health Disorder.\n"
    f" For this classification, the {MODEL_URL} model was used."
)

# Footer HTML pointing readers to the fine-tuning tutorial.
_ARTICLE = (
    "<p style='text-align: center'>Please read the tutorial to fine-tune the Llama 3.1 model on Mental Health Classification "
    "<a href='https://www.datacamp.com/tutorial/fine-tuning-llama-3-1' target='_blank'>https://www.datacamp.com/tutorial/fine-tuning-llama-3-1</a></p>"
)

# Wire the classifier into a single-textbox -> label UI.
gradio_ui = gr.Interface(
    fn=prediction,
    inputs=gr.Textbox(lines=10, label="Write the text here"),
    outputs=gr.Label(num_top_classes=4, label="Mental Health Disorder Category"),
    examples=_EXAMPLES,
    title="Mental Health Disorder Classification",
    description=_DESCRIPTION,
    article=_ARTICLE,
    theme="huggingface",
)

# Start the Gradio web server.
gradio_ui.launch()