call-center-asr / app.py
aidoskanapyanov
Don't cache examples
d43abcf
raw
history blame
4.14 kB
# type: ignore
from typing import List, Tuple
import gradio as gr
import pandas as pd
import torch
from langchain_community.llms import CTransformers
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from loguru import logger
from transformers import pipeline
logger.add("logs/file_{time}.log")
# asr model
device = "cuda:0" if torch.cuda.is_available() else "cpu"
logger.info(f"Device: {device}")
pipe = pipeline(
"automatic-speech-recognition",
model="openai/whisper-medium",
chunk_length_s=30,
# device=device,
generate_kwargs={"language": "russian"},
)
# qa model
class Result(BaseModel):
"""Извлечь вопрос и ответ из аудио записи колл-центра"""
question: str = Field(..., description="Вопрос клиента")
answer: str = Field(..., description="Ответ оператора")
class Results(BaseModel):
results: List[Result] = Field(..., description="Пары вопрос-ответ")
config = {
"max_new_tokens": 1000,
"context_length": 3000,
"temperature": 0,
# "gpu_layers": 50,
}
llm = CTransformers(
model="TheBloke/saiga_mistral_7b-GGUF",
config=config,
)
# accelerator = Accelerator()
# llm, config = accelerator.prepare(llm, config)
def asr(audio_file) -> str:
transcribed_text = pipe(audio_file, batch_size=16)
logger.info(f"Transcribed text: {transcribed_text}")
return transcribed_text["text"]
# return "Здравствуйте, меня зовут Александр, чем могу помочь? До скольки вы работаете? До 20:00. Спасибо, до свидания!"
def qa(transcribed_text: str) -> Tuple[str, str]:
parser = PydanticOutputParser(pydantic_object=Results)
prompt = PromptTemplate(
template="На основе транскрипции звонка из колл-центра определите пары вопросов и ответов, выделив конкретные вопросы, которые задал клиент, и ответы, которые предоставил оператор.\n{format_instructions}\nТекст аудио записи: {transcribed_text}\n",
# template="Какой вопрос задал клиент? Какой ответ дал оператор?\n{format_instructions}\nТекст аудио записи: {transcribed_text}\n",
input_variables=["transcribed_text"],
partial_variables={"format_instructions": parser.get_format_instructions()},
)
prompt_and_model = prompt | llm
output = prompt_and_model.invoke({"transcribed_text": transcribed_text})
logger.info(f"Output: {output}")
results = parser.invoke(output)
logger.info(f"Result: {results}")
logger.info(f"Dict: {results.dict()}")
results = (
pd.DataFrame(results.dict())
.results.apply(pd.Series)
.rename({"question": "Вопрос", "answer": "Ответ"}, axis=1)
)
return transcribed_text, results
@logger.catch
def inference(audio_file):
transcribed_text = asr(audio_file)
return qa(transcribed_text)
demo = gr.Interface(
fn=inference,
inputs=[
gr.Audio(
label="Аудио запись для обработки",
sources="upload",
type="filepath",
)
],
outputs=[
gr.components.Textbox(label="Транскрибированный текст"),
gr.DataFrame(headers=["Вопрос", "Ответ"], label="Вопросы и ответы"),
],
submit_btn="Обработать",
clear_btn="Очистить",
allow_flagging="never",
title="Обработчик аудиозаписей колл-центра",
description="Распознавание речи и определение вопроса клиента и ответа оператора.",
css="footer {visibility: hidden}",
examples=["samples/out_olga2.mp3"],
cache_examples=False,
)
demo.launch()