Autokeras中標籤編碼、隨機種子對模型性能的影響及復現性策略
Autokeras中的标签处理机制
在机器学习分类任务中,标签编码是数据预处理的关键一步。常见的编码方式包括One-Hot编码和整数编码。对于Autokeras的StructuredDataClassifier,它被设计为处理分类任务,通常期望接收整数形式的类别标签。即使您提供One-Hot编码的标签,Autokeras在内部处理时也会将其视为分类问题,并在其内部管道中进行相应的转换和处理。
实际上,autokeras在接收到整数标签后,会自行将其转换为One-Hot编码形式,以便与通常用于多分类任务的损失函数(如CategoricalCrossentropy)兼容。您可以通过检查clf.outputs[0].in_blocks[0].get_hyper_preprocessors()来验证其预处理器链中是否存在OneHotEncoder对象,以及通过clf.outputs[0].in_blocks[0].loss来确认所使用的损失函数。这意味着,无论您是提供原始的One-Hot编码还是转换后的整数标签,最终模型训练使用的内部标签表示和损失函数通常是一致的。因此,当观察到两者之间存在巨大性能差异(例如从0.40到0.97)时,问题往往不在于标签编码的“正确性”,而在于其他因素。
随机种子与模型复现性
Autokeras作为一种自动化机器学习(AutoML)工具,在寻找最佳模型架构和超参数时,会执行大量的随机操作,例如:
- 超参数搜索空间探索: 不同的随机初始化可能导致搜索算法探索不同的超参数组合。
- 模型权重初始化: 神经网络的初始权重通常是随机的。
- 数据洗牌: 训练数据在每个epoch开始前通常会被随机洗牌。
- Dropout层: Dropout操作本身具有随机性。
这些随机性在每次运行代码时都可能产生不同的结果,尤其是在max_trials(最大尝试次数)参数较小的情况下。当随机性导致模型在超参数搜索阶段选择了一个次优架构或初始化了一个不利的权重集时,即使输入数据和标签处理方式看似正确,也可能导致性能急剧下降。这正是本案例中观察到One-Hot编码直接输入导致低准确率(0.40)而整数编码导致高准确率(0.97)的根本原因——不同的随机种子导致了不同的超参数搜索路径和最终模型。
确保Autokeras模型复现性的策略
为了解决随机性带来的性能波动问题,并确保实验结果的可复现性,我们需要显式地设置随机种子。仅仅在StructuredDataClassifier构造函数中设置seed参数可能不足以完全控制所有随机源。更全面的方法是使用Keras提供的工具来设置全局随机种子。
以下是确保Autokeras模型复现性的推荐步骤:
-
全局设置随机种子: 在脚本的开头,使用keras.utils.set_random_seed()来设置所有涉及Keras和TensorFlow操作的随机种子。
import numpy as np import tensorflow as tf import os import autokeras as ak import keras # 导入keras # 设置随机种子以确保复现性 random_seed = 42 # 选择一个你喜欢的整数 keras.utils.set_random_seed(random_seed) tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True) # 如果使用GPU,可选
-
初始化Autokeras分类器时指定种子和覆盖模式: 在初始化StructuredDataClassifier时,除了设置seed参数外,还建议设置overwrite=True。overwrite=True可以确保每次运行时都会从头开始进行超参数搜索,而不会加载之前运行的结果,从而避免潜在的干扰。
# 初始化结构化数据分类器 # overwrite=True 确保每次运行都重新开始搜索,不加载之前的结果 # seed 参数进一步确保 autokeras 内部的随机性可控 clf = ak.StructuredDataClassifier(overwrite=True, max_trials=10, seed=random_seed)
增加max_trials以稳定结果(可选但推荐):max_trials参数决定了Autokeras尝试的不同模型架构和超参数组合的数量。当max_trials较小(例如默认的10)时,超参数搜索可能不够充分,导致结果对随机种子非常敏感。增加max_trials(例如设置为50或100)可以使搜索过程更全面,从而提高找到稳定且高性能模型的概率,减少不同随机种子带来的结果波动。
优化标签编码实践
尽管Autokeras能够内部处理One-Hot编码,但为了代码的清晰性和与大多数分类API的约定保持一致,建议在将数据传递给StructuredDataClassifier之前,将One-Hot编码的标签转换为整数标签。这简化了tf.data.Dataset.from_generator的output_signature定义,并使标签的含义更加直观。
以下是转换为整数标签的示例代码片段:
import numpy as np import tensorflow as tf import os import autokeras as ak import keras # 设置随机种子 random_seed = 42 keras.utils.set_random_seed(random_seed) N_FEATURES = 8 N_CLASSES = 3 BATCH_SIZE = 100 def get_data_generator(folder_path, batch_size, n_features): """ 获取一个数据生成器,从指定文件夹的.npy文件中分批返回数据。 特征的形状为 (batch_size, n_features)。 标签的形状为 (batch_size,),为整数形式。 """ def data_generator(): files = os.listdir(folder_path) npy_files = [f for f in files if f.endswith('.npy')] for npy_file in npy_files: data = np.load(os.path.join(folder_path, npy_file)) x = data[:, :n_features] y_ohe = data[:, n_features:] y_int = np.argmax(y_ohe, axis=1) # 将One-Hot编码转换为整数标签 for i in range(0, len(x), batch_size): yield x[i:i batch_size], y_int[i:i batch_size] return data_generator train_data_folder = '/home/my_user_name/original_data/train_data_npy' validation_data_folder = '/home/my_user_name/original_data/valid_data_npy' # 创建训练数据集,标签为1D整数 train_dataset = tf.data.Dataset.from_generator( get_data_generator(train_data_folder, BATCH_SIZE, N_FEATURES), output_signature=( tf.TensorSpec(shape=(None, N_FEATURES), dtype=tf.float32), tf.TensorSpec(shape=(None,), dtype=tf.int32) # 标签现在是1D整数 ) ) # 创建验证数据集,标签为1D整数 validation_dataset = tf.data.Dataset.from_generator( get_data_generator(validation_data_folder, BATCH_SIZE, N_FEATURES), output_signature=( tf.TensorSpec(shape=(None, N_FEATURES), dtype=tf.float32), tf.TensorSpec(shape=(None,), dtype=tf.int32) # 标签现在是1D整数 ) ) # 初始化分类器,并设置随机种子和覆盖模式 clf = ak.StructuredDataClassifier(overwrite=True, max_trials=10, seed=random_seed) # 训练分类器 clf.fit(train_dataset, epochs=100) # 评估模型 print("Model evaluation results:", clf.evaluate(validation_dataset)) # 导出并保存模型 (可选) model = clf.export_model() model.save("heca_v2_model_reproducible", save_format='tf')
总结
当Autokeras模型在不同运行中表现出显著性能差异时,即使标签编码方式看似合理,其根本原因也往往是随机种子未被妥善管理。Autokeras的StructuredDataClassifier能够内部处理整数标签并进行One-Hot转换,因此直接提供One-Hot编码的标签通常不是性能低下的直接原因。通过在脚本开头全局设置随机种子、在分类器初始化时指定种子并设置overwrite=True,可以有效地提高模型训练的复现性。此外,适当地增加max_trials参数,以及始终将One-Hot编码的标签转换为整数形式再输入模型,是构建稳定、可信赖的AutoML工作流的最佳实践。
以上是Autokeras中標籤編碼、隨機種子對模型性能的影響及復現性策略的詳細內容。更多資訊請關注PHP中文網其他相關文章!

熱AI工具

Undress AI Tool
免費脫衣圖片

Undresser.AI Undress
人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover
用於從照片中去除衣服的線上人工智慧工具。

Stock Market GPT
人工智慧支援投資研究,做出更明智的決策

熱門文章

熱工具

記事本++7.3.1
好用且免費的程式碼編輯器

SublimeText3漢化版
中文版,非常好用

禪工作室 13.0.1
強大的PHP整合開發環境

Dreamweaver CS6
視覺化網頁開發工具

SublimeText3 Mac版
神級程式碼編輯軟體(SublimeText3)

運行pipinstall-rrequirements.txt可安裝依賴包,建議先創建並激活虛擬環境以避免衝突,確保文件路徑正確且pip已更新,必要時使用--no-deps或--user等選項調整安裝行為。

本教程詳細介紹瞭如何將PEFT LoRA適配器與基礎模型高效合併,生成一個完全獨立的模型。文章指出直接使用transformers.AutoModel加載適配器並手動合併權重是錯誤的,並提供了使用peft庫中merge_and_unload方法的正確流程。此外,教程還強調了處理分詞器的重要性,並討論了PEFT版本兼容性問題及解決方案。

Pytest是Python中簡單強大的測試工具,安裝後按命名規則自動發現測試文件。編寫以test_開頭的函數進行斷言測試,使用@pytest.fixture創建可複用的測試數據,通過pytest.raises驗證異常,支持運行指定測試和多種命令行選項,提升測試效率。

theargparsemodulestherecommondedwaywaytohandlecommand-lineargumentsInpython,提供式刺激,typeValidation,helpmessages anderrornhandling; useSudys.argvforsimplecasesRequeRequeRingminimalSetup。

本文旨在探討Python及NumPy中浮點數計算精度不足的常見問題,解釋其根源在於標準64位浮點數的表示限制。針對需要更高精度的計算場景,文章將詳細介紹並對比mpmath、SymPy和gmpy等高精度數學庫的使用方法、特點及適用場景,幫助讀者選擇合適的工具來解決複雜的精度需求。

獲取當前時間在Python中可通過datetime模塊實現,1.使用datetime.now()獲取本地當前時間,2.用strftime("%Y-%m-%d%H:%M:%S")格式化輸出年月日時分秒,3.通過datetime.now().time()獲取僅時間部分,4.推薦使用datetime.now(timezone.utc)獲取UTC時間,避免使用已棄用的utcnow(),日常操作以datetime.now()結合格式化字符串即可滿足需求。

PyPDF2、pdfplumber和FPDF是Python處理PDF的核心庫。使用PyPDF2可進行文本提取、合併、拆分及加密,如通過PdfReader讀取頁面並調用extract_text()獲取內容;pdfplumber更適合保留佈局的文本提取和表格識別,支持extract_tables()精準抓取表格數據;FPDF(推薦fpdf2)用於生成PDF,通過add_page()、set_font()和cell()構建文檔並輸出。合併PDF時,PdfWriter的append()方法可集成多個文件

Import@contextmanagerfromcontextlibanddefineageneratorfunctionthatyieldsexactlyonce,wherecodebeforeyieldactsasenterandcodeafteryield(preferablyinfinally)actsas__exit__.2.Usethefunctioninawithstatement,wheretheyieldedvalueisaccessibleviaas,andthesetup
