from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import os
import shutil

# Primary checkpoint: a fine-tuned emotional-support chatbot.
model_name = "kalyani2599/emotional_support_bot"

# Clear the local Hugging Face cache so the model is re-downloaded fresh.
# The cache contains directories as well as files, so directory entries
# need shutil.rmtree(); plain os.remove() would raise on them.
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
if os.path.exists(cache_dir):
    for entry in os.listdir(cache_dir):
        entry_path = os.path.join(cache_dir, entry)
        if os.path.isdir(entry_path):
            shutil.rmtree(entry_path)
        else:
            os.remove(entry_path)

# Load the model and tokenizer; if the primary checkpoint is unavailable,
# fall back to the public facebook/blenderbot-3B checkpoint.
try:
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
except Exception as e:
    print(f"Error loading model or tokenizer: {e}")
    model_name = "facebook/blenderbot-3B"
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
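
# Report which checkpoint ended up loaded (primary or fallback).
print(f"Loaded model: {model_name}")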


def chatbot_response(input_text):
    """Generate a reply to a single user message."""
    try:
        inputs = tokenizer(input_text, return_tensors="pt")
        # Cap the reply at 100 tokens to keep responses short.
        outputs = model.generate(**inputs, max_length=100)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
    except Exception as e:
        return f"Error in generating response: {e}"
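
# Note: for more varied replies, generate() also accepts sampling options,
# e.g. model.generate(**inputs, max_length=100, do_sample=True, top_p=0.9,
# temperature=0.8); these particular values are illustrative, not tuned.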


# Quick smoke test.
input_text = "Hello, how are you?"
response = chatbot_response(input_text)
print(response)
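
# Optional: a minimal interactive chat loop (a sketch; the "quit" sentinel
# is an arbitrary convention chosen here, not something the model requires).
while True:
    user_input = input("You: ")
    if user_input.strip().lower() == "quit":
        break
    print(f"Bot: {chatbot_response(user_input)}")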