import os

import gradio as gr
import torch
import numpy as np
import librosa

from efficientat.models.MobileNetV3 import get_model as get_mobilenet, get_ensemble_model
from efficientat.models.preprocess import AugmentMelSTFT
from efficientat.helpers.utils import NAME_TO_WIDTH, labels

from torch import autocast
from contextlib import nullcontext

from langchain import OpenAI, ConversationChain, LLMChain, PromptTemplate
from langchain.chains.conversation.memory import ConversationalBufferWindowMemory
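
# Assumptions for running this script: the EfficientAT repository is importable
# (e.g. on PYTHONPATH), an OpenAI API key is available via the OPENAI_API_KEY
# environment variable, and the installed gradio/langchain versions expose the
# APIs used below (gr.Audio(source=...), langchain's OpenAI/LLMChain/PromptTemplate).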

MODEL_NAME = "mn40_as"

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = get_mobilenet(width_mult=NAME_TO_WIDTH(MODEL_NAME), pretrained_name=MODEL_NAME)
model.to(device)
model.eval()

# Module-level state shared between the audio tagger and the chat chain.
cached_audio_class = "c"
template = None
prompt = None
chain = None
formatted_classname = "tree"

# Assumed: the OpenAI API key is supplied through the environment.
session_token = os.environ.get("OPENAI_API_KEY")


def format_classname(classname):
    return classname.capitalize()


def audio_tag(
    audio_path,
    human_input,
    sample_rate=32000,
    window_size=800,
    hop_size=320,
    n_mels=128,
    cuda=True,
):
    global chain, formatted_classname

    # Load the clip and compute its mel spectrogram.
    (waveform, _) = librosa.core.load(audio_path, sr=sample_rate, mono=True)
    mel = AugmentMelSTFT(n_mels=n_mels, sr=sample_rate, win_length=window_size, hopsize=hop_size)
    mel.to(device)
    mel.eval()
    waveform = torch.from_numpy(waveform[None, :]).to(device)

    with torch.no_grad(), autocast(device_type=device.type) if cuda and torch.cuda.is_available() else nullcontext():
        spec = mel(waveform)
        preds, features = model(spec.unsqueeze(0))
        preds = torch.sigmoid(preds.float()).squeeze().cpu().numpy()

    # Take the highest-scoring AudioSet label and build the chat chain for it.
    sorted_indexes = np.argsort(preds)[::-1]
    label = labels[sorted_indexes[0]]
    formatted_classname = format_classname(label)
    chain = construct_langchain(formatted_classname)
    return formatted_classname


def construct_langchain(audio_class):
    global cached_audio_class, chain

    # Rebuild the chain only when the detected audio class changes.
    if chain is None or cached_audio_class != audio_class:
        cached_audio_class = audio_class

        prefix = f"""You are going to act as a magical tool that allows humans to communicate with non-human entities like
rocks, crackling fire, trees, animals, and the wind. In order to do this, we're going to provide you with the human's text input for the conversation.
The goal is for you to embody that non-human entity and converse with the human.

Examples:

Non-human Entity: Tree
Human Input: Hello tree
Tree: Hello human, I am a tree

Let's begin:
Non-human Entity: {audio_class}"""

        suffix = f'''
{{history}}
Source: {audio_class}
Length of Audio in Seconds: 2 seconds
Human Input: {{human_input}}
{audio_class} Response:'''

        template = prefix + suffix

        prompt = PromptTemplate(
            input_variables=["history", "human_input"],
            template=template,
        )

        chatgpt_chain = LLMChain(
            llm=OpenAI(temperature=0.5, openai_api_key=session_token),
            prompt=prompt,
            verbose=True,
            memory=ConversationalBufferWindowMemory(k=2, ai_prefix=audio_class),
        )
        chain = chatgpt_chain

    return chain


def predict(audio_path, input, history=[]):
    # Tag the uploaded audio first so the chain speaks as the detected entity,
    # then let the LLM respond and record the exchange for the chatbot display.
    audio_tag(audio_path, input)
    formatted_message = chain.predict(human_input=input)
    history.append((input, formatted_message))
    return history, history


demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Audio(source="upload", type="filepath", label="Your audio"),
        "text",
        "state",
    ],
    outputs=["chatbot", "state"],
    title="AnyChat",
    description="Non-Human entities have many things to say, listen to them!",
)
demo.launch(debug=True)
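
# Usage sketch (assuming this file is saved as app.py and the key is exported):
#   OPENAI_API_KEY=<your key> python app.py
# Upload a short clip, type a message, and the top-1 tagged entity replies.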