CFCF2181L.LLM Training

省选/NOI-

通过率:0%

AC君温馨提醒

该题目为【codeforces】题库的题目,您提交的代码将被提交至codeforces进行远程评测,并由ACGO抓取测评结果后进行展示。由于远程测评的测评机由其他平台提供,我们无法保证该服务的稳定性,若提交后无反应,请等待一段时间后再进行重试。

题目描述

给定一个文本数据集。你的任务是训练一个大型语言模型(LLM),并找到最小可能损失。不开玩笑。

一个文本数据集由若干文本 t1,t2,,tnt_1, t_2, \ldots, t_n 组成。每个文本 tit_i 是一个由多个 token 组成的序列。我们定义 token 集合 TT 为至少在一个 tit_i 中出现过的所有 token 的集合。此外,对于每个文本 tit_i,还有一个位置集合 Li{1,2,,ti}L_i \subseteq \{1, 2, \ldots, |t_i|\}。当 jLij \in L_i 时,token ti[j]t_i[j] 由 LLM 生成。当 jLij \notin L_i 时,该 token 由用户输入。

我们定义上下文长度为 kk 的 LLM 为概率模型 PkP_k,它定义了下一个 token 的概率分布,并依赖于上下文 ww——一个长度在 00kk(含)之间的、元素来自 TT 的序列。因此概率模型 PkP_k 本质上是一个巨大的概率表,对每个上下文 wTw \in T^{*}0wk0 \leq |w| \leq k 和每个 token nextT\text{next} \in T,都给定 Pk(nextw)P_k(\text{next} | w)。这些概率应满足 0Pk(nextw)10 \leq P_k(\text{next} | w) \leq 1,且 nextTPk(nextw)=1\sum\limits_{\text{next} \in T} P_k(\text{next} | w) = 1

LLM 损失函数如下定义,对于 PkP_k

Lk(Pk)=i=1njLilog2Pk ⁣(ti[j]下一个 token | ti[max(1,jk)j1]上下文)\mathcal{L}_k(P_k) = \sum_{i=1}^{n} \sum_{j\in L_i} -\log_2 P_k\!\left( \underbrace{t_i[j]}_{\text{下一个 token}} \ \middle|\ \underbrace{t_i[\max(1, j-k)\ldots j-1]}_{\text{上下文}} \right)

这里 ti[l..r]=ti[l]ti[l+1]ti[r]t_i[l\,..\,r] = t_i[l] t_i[l+1] \ldots t_i[r] 是从第 ll 到第 rr 个 token 的子串,ti[1..0]t_i[1..0] 表示空串。所以对于每个文本、每个由 LLM 生成的位置,我们将根据前 kk 个 token(或整个前缀,如果长度不足 kk)的子串,加上当前 token 的概率的负对数(以 2 为底)到损失中。如果概率为 0,则负对数视为 ++\infty。该损失函数称为(以 2 为底)的交叉熵损失(Cross Entropy Loss),只针对 LLM 生成的位置。Lk(Pk)\mathcal{L}_k(P_k) 越小,LLM PkP_k 越好。

对于每个 0k<maxi=1..nti0 \leq k < \max\limits_{i=1..n} |t_i|,请计算某个具有上下文长度 kk 的 LLM 所能达到的最小可能的损失 Lk(Pk)\mathcal{L}_k(P_k)。可以证明,这个最小值是可达的且不是无穷大。

输入格式

第一行包含一个整数 nn1n1051 \leq n \leq 10^5),表示数据集中文本的数量。接下来是每个文本的描述。

ii 个文本的描述第一行为一个整数 mim_i1mi31051 \leq m_i \leq 3 \cdot 10^5),表示 tit_i 的长度(mi=tim_i = |t_i|)。

下一行包含 mim_i 个字符串 ti[1]t_{i}[1]ti[2]t_{i}[2]\ldotsti[mi]t_{i}[m_i]1ti[j]51 \leq |t_{i}[j]| \leq 5),为该文本的每个 token。每个 token 由 ASCII 码 33 至 126(可打印字符)的字符构成。

下一行包含一个长度为 mim_i 的字符串 i\ell_i,由字母 U 和 L 组成,编码了 LiL_i。所有位置中,字母 L 处对应由 LLM 生成,U 则由用户输入。所以 Li={ji[j]=L}L_i = \{j\,|\,\ell_i[j] = \texttt{L}\}。保证每个文本最后一个 token 都是 LLM 生成的,即 i[mi]=L\ell_i[m_i] = \texttt{L}

保证所有 mim_i 之和不超过 31053 \cdot 10^5

输出格式

输出 M=maxi=1..nmiM = \max\limits_{i=1..n} m_i 个实数:对于每个 k=0,1,,M1k = 0, 1, \ldots, M-1,输出最小可能损失 Lk(Pk)\mathcal{L}_k(P_k),即所有可能的上下文长度为 kk 的 LLM 的最小损失。

如果你的答案的绝对误差或相对误差不超过 10610^{-6},即对于你的答案 pp 和标准答案 qq,有 pqmax{1,q}106\frac{|p - q|}{\max\{1, |q|\}} \leq 10^{-6},即视为正确。

输入输出样例

  • 输入#1

    4
    5
    1 + 1 = 2
    UUUUL
    5
    1 + 2 = 3
    UUUUL
    5
    2 + 1 = 3
    UUUUL
    5
    2 + 2 = 4
    UUUUL

    输出#1

    6.000000000000
    6.000000000000
    4.000000000000
    4.000000000000
    0.000000000000
首页