// src/services/ModelService.tsx
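//
// React context service that tracks download/load state for the app's three on-device
// models (LLM, speech-to-text, text-to-speech) and exposes actions that download and
// load each one through the RunAnywhere SDK.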
import React, { createContext, useContext, useState, useCallback } from 'react';
import { RunAnywhere, ModelCategory } from '@runanywhere/core';
import { LlamaCPP } from '@runanywhere/llamacpp';
import { ONNX, ModelArtifactType } from '@runanywhere/onnx';
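
// Model identifiers used throughout this service; they must match the ids registered
// with the framework adapters in registerModels() below.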
const MODEL_IDS = {
  llm: 'lfm2-350m-q8_0',
  stt: 'sherpa-onnx-whisper-tiny.en',
  tts: 'vits-piper-en_US-lessac-medium',
};
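
// Shape of the context value: loaded/downloading flags per model, download progress
// as a percentage (0-100), and the download-and-load actions exposed to the UI.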
interface ModelState {
  isLLMLoaded: boolean;
  isSTTLoaded: boolean;
  isTTSLoaded: boolean;
  isLLMDownloading: boolean;
  isSTTDownloading: boolean;
  isTTSDownloading: boolean;
  llmDownloadProgress: number;
  sttDownloadProgress: number;
  ttsDownloadProgress: number;
  downloadAndLoadLLM: () => Promise<void>;
  downloadAndLoadSTT: () => Promise<void>;
  downloadAndLoadTTS: () => Promise<void>;
}
const ModelContext = createContext<ModelState | null>(null);
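
/**
 * Hook for reading model state and triggering downloads from any component rendered
 * below ModelServiceProvider.
 *
 * @example
 * // Illustrative sketch only; "VoiceScreen" is a hypothetical consumer component.
 * const VoiceScreen = () => {
 *   const { isLLMLoaded, isLLMDownloading, downloadAndLoadLLM } = useModelService();
 *   React.useEffect(() => {
 *     if (!isLLMLoaded && !isLLMDownloading) downloadAndLoadLLM();
 *   }, [isLLMLoaded, isLLMDownloading, downloadAndLoadLLM]);
 *   return null;
 * };
 */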
export const useModelService = () => {
  const ctx = useContext(ModelContext);
  if (!ctx) throw new Error('useModelService must be used within a ModelServiceProvider');
  return ctx;
};
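
// Provider component that owns the download/load state for all three models and
// exposes it, together with the download-and-load actions, via ModelContext.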
export const ModelServiceProvider: React.FC<{children: React.ReactNode}> = ({ children }) => {
  const [isLLMLoaded, setIsLLMLoaded] = useState(false);
  const [isSTTLoaded, setIsSTTLoaded] = useState(false);
  const [isTTSLoaded, setIsTTSLoaded] = useState(false);
  const [isLLMDownloading, setIsLLMDownloading] = useState(false);
  const [isSTTDownloading, setIsSTTDownloading] = useState(false);
  const [isTTSDownloading, setIsTTSDownloading] = useState(false);
  const [llmDownloadProgress, setLLMDownloadProgress] = useState(0);
  const [sttDownloadProgress, setSTTDownloadProgress] = useState(0);
  const [ttsDownloadProgress, setTTSDownloadProgress] = useState(0);
  // Download the LLM weights if they are not on disk yet, then load them into the runtime.
  const downloadAndLoadLLM = useCallback(async () => {
    const info = await RunAnywhere.getModelInfo(MODEL_IDS.llm);
    if (!info?.localPath) {
      setIsLLMDownloading(true);
      try {
        await RunAnywhere.downloadModel(MODEL_IDS.llm, (p) => setLLMDownloadProgress(p.progress * 100));
      } finally {
        // Reset the flag even if the download fails, so the UI never gets stuck on "downloading".
        setIsLLMDownloading(false);
      }
    }
    const updated = await RunAnywhere.getModelInfo(MODEL_IDS.llm);
    if (!updated?.localPath) {
      throw new Error(`Model ${MODEL_IDS.llm} is still missing a local path after download`);
    }
    await RunAnywhere.loadModel(updated.localPath);
    setIsLLMLoaded(true);
  }, []);
  // Same flow for the speech-to-text (Whisper) model.
  const downloadAndLoadSTT = useCallback(async () => {
    const info = await RunAnywhere.getModelInfo(MODEL_IDS.stt);
    if (!info?.localPath) {
      setIsSTTDownloading(true);
      try {
        await RunAnywhere.downloadModel(MODEL_IDS.stt, (p) => setSTTDownloadProgress(p.progress * 100));
      } finally {
        setIsSTTDownloading(false);
      }
    }
    const updated = await RunAnywhere.getModelInfo(MODEL_IDS.stt);
    if (!updated?.localPath) {
      throw new Error(`Model ${MODEL_IDS.stt} is still missing a local path after download`);
    }
    await RunAnywhere.loadSTTModel(updated.localPath, 'whisper');
    setIsSTTLoaded(true);
  }, []);
  // Same flow for the text-to-speech (Piper) model.
  const downloadAndLoadTTS = useCallback(async () => {
    const info = await RunAnywhere.getModelInfo(MODEL_IDS.tts);
    if (!info?.localPath) {
      setIsTTSDownloading(true);
      try {
        await RunAnywhere.downloadModel(MODEL_IDS.tts, (p) => setTTSDownloadProgress(p.progress * 100));
      } finally {
        setIsTTSDownloading(false);
      }
    }
    const updated = await RunAnywhere.getModelInfo(MODEL_IDS.tts);
    if (!updated?.localPath) {
      throw new Error(`Model ${MODEL_IDS.tts} is still missing a local path after download`);
    }
    await RunAnywhere.loadTTSModel(updated.localPath, 'piper');
    setIsTTSLoaded(true);
  }, []);
  return (
    <ModelContext.Provider value={{
      isLLMLoaded, isSTTLoaded, isTTSLoaded,
      isLLMDownloading, isSTTDownloading, isTTSDownloading,
      llmDownloadProgress, sttDownloadProgress, ttsDownloadProgress,
      downloadAndLoadLLM, downloadAndLoadSTT, downloadAndLoadTTS,
    }}>
      {children}
    </ModelContext.Provider>
  );
};
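
/**
 * Registers the three models with their framework adapters (llama.cpp for the LLM,
 * ONNX/sherpa-onnx for STT and TTS), declaring where each artifact is downloaded from
 * and its approximate memory requirement. Intended to run once at startup, before any
 * downloadAndLoad* action is invoked.
 */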
export async function registerModels() {
  await LlamaCPP.addModel({
    id: MODEL_IDS.llm,
    name: 'LiquidAI LFM2 350M',
    url: 'https://huggingface.co/LiquidAI/LFM2-350M-GGUF/resolve/main/LFM2-350M-Q8_0.gguf',
    memoryRequirement: 400_000_000,
  });
  await ONNX.addModel({
    id: MODEL_IDS.stt,
    name: 'Whisper Tiny EN',
    url: 'https://github.com/RunanywhereAI/sherpa-onnx/releases/download/runanywhere-models-v1/sherpa-onnx-whisper-tiny.en.tar.gz',
    modality: ModelCategory.SpeechRecognition,
    artifactType: ModelArtifactType.TarGzArchive,
    memoryRequirement: 75_000_000,
  });
  await ONNX.addModel({
    id: MODEL_IDS.tts,
    name: 'Piper TTS EN-US',
    url: 'https://github.com/RunanywhereAI/sherpa-onnx/releases/download/runanywhere-models-v1/vits-piper-en_US-lessac-medium.tar.gz',
    modality: ModelCategory.SpeechSynthesis,
    artifactType: ModelArtifactType.TarGzArchive,
    memoryRequirement: 65_000_000,
  });
}
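
// Usage sketch (illustrative; "App", "Root" and the import path are assumptions, not
// part of this module):
//
//   import { ModelServiceProvider, registerModels } from './services/ModelService';
//
//   // Register the model catalog once at startup, before any screen triggers a download.
//   registerModels().catch(console.error);
//
//   export default function Root() {
//     return (
//       <ModelServiceProvider>
//         <App />
//       </ModelServiceProvider>
//     );
//   }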