#include <iostream>
#include <string>
#include <unordered_map>

#include "ort_genai.h"
#include "common.h"
void CXX_API(
GeneratorParamsArgs& generator_params_args,
const std::string& model_path,
const std::string& ep,
const std::string& system_prompt,
bool verbose,
bool interactive,
bool rewind) {
RegisterEP(ep, ep_path);
// Create model and tokenizer
std::unordered_map<std::string, std::string> ep_options;
auto config = GetConfig(model_path, ep, ep_options, generator_params_args);
auto model = OgaModel::Create(*config);
auto tokenizer = OgaTokenizer::Create(*model);
auto stream = OgaTokenizerStream::Create(*tokenizer);
// Set search options for generator params
auto params = OgaGeneratorParams::Create(*model);
SetSearchOptions(*params, generator_params_args, verbose);
// Create system message
nlohmann::ordered_json message = nlohmann::ordered_json::array();
message.push_back({{"role", "system"}, {"content", system_prompt}});
// Create generator
auto generator = OgaGenerator::Create(*model, *params);
// Apply chat template and encode system prompt
std::string prompt = ApplyChatTemplate(
model_path, *tokenizer, message.dump(), false
);
auto sequences = OgaSequences::Create();
tokenizer->Encode(prompt.c_str(), *sequences);
generator->AppendTokenSequences(*sequences);
const int prompt_tokens_length = generator->TokenCount();
// Interactive conversation loop
while (true) {
// Get user input
std::string text;
std::cout << "Prompt (Use quit() to exit):" << std::endl;
std::getline(std::cin, text);
if (text == "quit()") {
break;
}
// Create user message
message = nlohmann::ordered_json::array();
message.push_back({{"role", "user"}, {"content", text}});
// Apply chat template and encode
prompt = ApplyChatTemplate(model_path, *tokenizer, message.dump(), true);
sequences = OgaSequences::Create();
tokenizer->Encode(prompt.c_str(), *sequences);
generator->AppendTokenSequences(*sequences);
// Generate response
std::cout << "Output: ";
const int current_token_count = generator->TokenCount();
try {
while (!generator->IsDone()) {
generator->GenerateNextToken();
const auto new_token = generator->GetNextTokens()[0];
std::cout << stream->Decode(new_token) << std::flush;
}
} catch (const std::exception& e) {
std::cout << "\nTerminating generation: " << e.what() << std::endl;
// Rewind to the last valid state
generator->RewindTo(current_token_count);
}
std::cout << "\n\n" << std::endl;
// Optionally rewind to system prompt (clears chat history)
if (rewind) {
generator->RewindTo(prompt_tokens_length);
}
}
}