import torch
from transformers import pipeline
from PIL import Image
import soundfile as sf
from datasets import load_dataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class AutomaticSpeechRecognition():
    """Transcribe a spoken question to text with the Whisper ASR pipeline.

    Usage: set ``QuestionAudio`` and call :meth:`run`; the transcript is
    stored in ``QuestionText``.
    """

    # Input: audio to transcribe (the __main__ block passes a file path).
    QuestionAudio = None
    # Output: transcript text, populated by run().
    QuestionText = None

    # Lazily-built pipeline shared by all instances, so the model weights are
    # loaded once instead of on every transcription call.
    _speech_reco = None

    @classmethod
    def _get_pipeline(cls):
        """Return the shared Whisper ASR pipeline, building it on first use."""
        if cls._speech_reco is None:
            cls._speech_reco = pipeline(
                "automatic-speech-recognition", model="openai/whisper-base", device=device
            )
        return cls._speech_reco

    def funcAutomaticSpeechRecognition(self, input):
        """Transcribe ``input`` and return the recognized text.

        :param input: audio in a form accepted by the transformers ASR
            pipeline (e.g. a path to an audio file).
        :return: the ``"text"`` field of the pipeline result.
        """
        res = self._get_pipeline()(input)
        return res["text"]

    def run(self):
        """Transcribe ``QuestionAudio`` into ``QuestionText``."""
        self.QuestionText = self.funcAutomaticSpeechRecognition(self.QuestionAudio)

class TextandImageQuery():
    """Answer a natural-language question about an image with BLIP VQA.

    Usage: set ``QuestionText`` and ``RawImage`` (a path to an image file),
    then call :meth:`run`; the answer is stored in ``AnswerText``.
    """

    # Inputs: question string and path to the context image.
    QuestionText = None
    RawImage = None

    # Output: answer text, populated by run().
    AnswerText = None

    # Lazily-built pipeline shared by all instances (model loaded once).
    _vqa_pipe = None

    @classmethod
    def _get_pipeline(cls):
        """Return the shared BLIP visual-question-answering pipeline."""
        if cls._vqa_pipe is None:
            # blip-vqa-base is a VQA checkpoint: the "visual-question-answering"
            # task takes (image, question), accepts ``top_k`` and returns a list
            # of ``{"answer": ...}`` dicts. The generic "image-text-to-text"
            # task does not accept ``top_k`` and does not return an answer.
            cls._vqa_pipe = pipeline(
                "visual-question-answering",
                model="Salesforce/blip-vqa-base",
                device=device,
            )
        return cls._vqa_pipe

    def funcTextandImageQuery(self, raw_image_path, question):
        """Answer ``question`` about the image at ``raw_image_path``.

        :param raw_image_path: path to an image file readable by Pillow.
        :param question: the question text.
        :return: the top answer string produced by the VQA pipeline.
        """
        # The context manager closes the underlying file; convert() returns a
        # fully loaded copy, so it stays usable after the with-block.
        with Image.open(raw_image_path) as img:
            raw_image = img.convert("RGB")

        output = self._get_pipeline()(image=raw_image, question=question, top_k=1)[0]
        return output["answer"]

    def run(self):
        """Answer ``QuestionText`` about ``RawImage`` into ``AnswerText``."""
        self.AnswerText = self.funcTextandImageQuery(self.RawImage, self.QuestionText)

class TextToSpeech():
    """Synthesize speech for an answer with SpeechT5 and write it to a WAV file.

    Usage: set ``AnswerText`` and call :meth:`run`; the path of the written
    audio file is stored in ``AnswerAudio``.
    """

    # Input: text to speak.
    AnswerText = None
    # Output: path of the written WAV file, populated by run().
    AnswerAudio = None

    # Lazily-loaded resources shared by all instances, so the model and the
    # speaker-embedding dataset are loaded once rather than on every call.
    _synthesiser = None
    _embeddings_dataset = None

    @classmethod
    def _get_synthesiser(cls):
        """Return the shared SpeechT5 text-to-speech pipeline."""
        if cls._synthesiser is None:
            cls._synthesiser = pipeline("text-to-speech", "microsoft/speecht5_tts", device=device)
        return cls._synthesiser

    @classmethod
    def _get_speaker_embedding(cls, speaker_index):
        """Return the x-vector at ``speaker_index`` as a (1, dim) tensor."""
        if cls._embeddings_dataset is None:
            cls._embeddings_dataset = load_dataset(
                "Matthijs/cmu-arctic-xvectors", split="validation"
            )
        # unsqueeze(0) adds the leading batch dimension the model expects.
        return torch.tensor(cls._embeddings_dataset[speaker_index]["xvector"]).unsqueeze(0)

    def funcTextToSpeech(self, input, path_output="AudioAnswer.wav", speaker_index=7306):
        """Synthesize ``input`` as speech and save it as an audio file.

        :param input: text to synthesize.
        :param path_output: destination file path (default ``"AudioAnswer.wav"``).
        :param speaker_index: row of the CMU Arctic x-vector validation split
            used as the speaker embedding (default 7306). Any other row, or
            your own embedding, can be used to change the voice.
        :return: ``path_output``.
        """
        synthesiser = self._get_synthesiser()
        speaker_embedding = self._get_speaker_embedding(speaker_index)

        speech = synthesiser(input,
                             forward_params={"speaker_embeddings": speaker_embedding})

        sf.write(path_output, speech["audio"], samplerate=speech["sampling_rate"])
        return path_output

    def run(self):
        """Speak ``AnswerText`` and store the output file path in ``AnswerAudio``."""
        self.AnswerAudio = self.funcTextToSpeech(self.AnswerText)

if __name__ == '__main__':
    # Stage 1: turn the spoken question into text.
    asr_stage = AutomaticSpeechRecognition()
    asr_stage.QuestionAudio = "path/to/audio/question"
    asr_stage.run()

    # Stage 2: answer the transcribed question about the context image.
    # To bypass speech recognition, assign a plain question string to
    # QuestionText instead of the ASR transcript.
    vqa_stage = TextandImageQuery()
    vqa_stage.QuestionText = asr_stage.QuestionText
    vqa_stage.RawImage = "path/to/context/image"
    vqa_stage.run()
    print(vqa_stage.AnswerText)

    # Stage 3: read the answer aloud into an audio file.
    tts_stage = TextToSpeech()
    tts_stage.AnswerText = vqa_stage.AnswerText
    tts_stage.run()
