Spaces:
Sleeping
Sleeping
# type: ignore | |
from typing import List, Tuple | |
import gradio as gr | |
import pandas as pd | |
import torch | |
from langchain_community.llms import CTransformers | |
from langchain_core.output_parsers import PydanticOutputParser | |
from langchain_core.prompts import PromptTemplate | |
from langchain_core.pydantic_v1 import BaseModel, Field | |
from loguru import logger | |
from transformers import pipeline | |
logger.add("logs/file_{time}.log") | |
# asr model | |
device = "cuda:0" if torch.cuda.is_available() else "cpu" | |
logger.info(f"Device: {device}") | |
pipe = pipeline( | |
"automatic-speech-recognition", | |
model="openai/whisper-medium", | |
chunk_length_s=30, | |
# device=device, | |
generate_kwargs={"language": "russian"}, | |
) | |
# qa model | |
class Result(BaseModel): | |
"""Извлечь вопрос и ответ из аудио записи колл-центра""" | |
question: str = Field(..., description="Вопрос клиента") | |
answer: str = Field(..., description="Ответ оператора") | |
class Results(BaseModel): | |
results: List[Result] = Field(..., description="Пары вопрос-ответ") | |
config = { | |
"max_new_tokens": 1000, | |
"context_length": 3000, | |
"temperature": 0, | |
# "gpu_layers": 50, | |
} | |
llm = CTransformers( | |
model="TheBloke/saiga_mistral_7b-GGUF", | |
config=config, | |
) | |
# accelerator = Accelerator() | |
# llm, config = accelerator.prepare(llm, config) | |
def asr(audio_file) -> str: | |
transcribed_text = pipe(audio_file, batch_size=16) | |
logger.info(f"Transcribed text: {transcribed_text}") | |
return transcribed_text["text"] | |
# return "Здравствуйте, меня зовут Александр, чем могу помочь? До скольки вы работаете? До 20:00. Спасибо, до свидания!" | |
def qa(transcribed_text: str) -> Tuple[str, str]: | |
parser = PydanticOutputParser(pydantic_object=Results) | |
prompt = PromptTemplate( | |
template="На основе транскрипции звонка из колл-центра определите пары вопросов и ответов, выделив конкретные вопросы, которые задал клиент, и ответы, которые предоставил оператор.\n{format_instructions}\nТекст аудио записи: {transcribed_text}\n", | |
# template="Какой вопрос задал клиент? Какой ответ дал оператор?\n{format_instructions}\nТекст аудио записи: {transcribed_text}\n", | |
input_variables=["transcribed_text"], | |
partial_variables={"format_instructions": parser.get_format_instructions()}, | |
) | |
prompt_and_model = prompt | llm | |
output = prompt_and_model.invoke({"transcribed_text": transcribed_text}) | |
logger.info(f"Output: {output}") | |
results = parser.invoke(output) | |
logger.info(f"Result: {results}") | |
logger.info(f"Dict: {results.dict()}") | |
results = ( | |
pd.DataFrame(results.dict()) | |
.results.apply(pd.Series) | |
.rename({"question": "Вопрос", "answer": "Ответ"}, axis=1) | |
) | |
return transcribed_text, results | |
def inference(audio_file): | |
transcribed_text = asr(audio_file) | |
return qa(transcribed_text) | |
demo = gr.Interface( | |
fn=inference, | |
inputs=[ | |
gr.Audio( | |
label="Аудио запись для обработки", | |
sources="upload", | |
type="filepath", | |
) | |
], | |
outputs=[ | |
gr.components.Textbox(label="Транскрибированный текст"), | |
gr.DataFrame(headers=["Вопрос", "Ответ"], label="Вопросы и ответы"), | |
], | |
submit_btn="Обработать", | |
clear_btn="Очистить", | |
allow_flagging="never", | |
title="Обработчик аудиозаписей колл-центра", | |
description="Распознавание речи и определение вопроса клиента и ответа оператора.", | |
css="footer {visibility: hidden}", | |
examples=["samples/out_olga2.mp3"], | |
cache_examples=False, | |
) | |
demo.launch() | |