
Machine Learning with Rust's Linfa and Polars Libraries: Linear Regression

WBOY
Release: 2024-03-01 17:25:02

In this article, we use Rust's Linfa and Polars libraries to implement linear regression, a basic machine learning algorithm.

The Linfa crate aims to provide a comprehensive toolkit for building machine learning applications using Rust.

Polars is a Rust DataFrame library built on the Apache Arrow memory model. Apache Arrow defines efficient columnar in-memory data structures and has become the de facto standard for in-memory columnar data.
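The sketch below makes the difference between row and column layouts concrete. It is minimal and dependency-free, written in plain Rust. It does not use Polars or Arrow types, and the names PatientRow, PatientColumns and column_mean are hypothetical. Each field is stored as one contiguous Vec<f64>, which is the basic idea behind a columnar layout: computing over one column, such as a mean, scans a single buffer. Real Arrow arrays add more on top of this, such as validity bitmaps for missing values.

```rust
/// Row-oriented layout: every record keeps all of its fields together.
#[derive(Debug, Clone, Copy)]
struct PatientRow {
    age: f64,
    bmi: f64,
}

/// Column-oriented layout (hypothetical): each field lives in its own contiguous vector.
#[derive(Debug, Default)]
struct PatientColumns {
    age: Vec<f64>,
    bmi: Vec<f64>,
}

impl PatientColumns {
    /// Transposes row records into one vector per column.
    fn from_rows(rows: &[PatientRow]) -> Self {
        let mut cols = PatientColumns::default();
        for row in rows {
            cols.age.push(row.age);
            cols.bmi.push(row.bmi);
        }
        cols
    }
}

/// Mean of a single column; it only reads that column's buffer.
fn column_mean(values: &[f64]) -> Option<f64> {
    if values.is_empty() {
        return None;
    }
    Some(values.iter().sum::<f64>() / values.len() as f64)
}

fn main() {
    // AGE and BMI of the first three patients in the diabetes data set
    let rows = [
        PatientRow { age: 59.0, bmi: 32.1 },
        PatientRow { age: 48.0, bmi: 21.6 },
        PatientRow { age: 72.0, bmi: 30.5 },
    ];
    let cols = PatientColumns::from_rows(&rows);
    if let Some(mean_bmi) = column_mean(&cols.bmi) {
        // prints: mean BMI = 28.07
        println!("mean BMI = {:.2}", mean_bmi);
    }
}
```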

In the example below, we use the diabetes data set to train a linear regression model.
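The standard ordinary least-squares formulation behind linear regression is shown below for reference. In this article's setup, the target y is AGE. The p = 10 features are SEX, BMI, BP, S1 through S6, and Y. According to its documentation, linfa-linear's LinearRegression fits the intercept b by default.

$$
\hat{y}_i = b + \sum_{j=1}^{p} w_j x_{ij},
\qquad
(\hat{w}, \hat{b}) = \arg\min_{w,\, b} \sum_{i=1}^{n} \left( y_i - b - \sum_{j=1}^{p} w_j x_{ij} \right)^2
$$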

Create a new Rust project using the following command:

cargo new machine_learning_linfa

Add the following dependencies to the Cargo.toml file:

[dependencies]
linfa = "0.7.0"
linfa-linear = "0.7.0"
ndarray = "0.15.6"
polars = { version = "0.35.4", features = ["ndarray"] }

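Create a diabetes_file.csv file in the project root directory and paste the data set into it. Despite the .csv extension, the columns are separated by tabs, as in the original diabetes.tab.txt file.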

AGE	SEX	BMI	BP	S1	S2	S3	S4	S5	S6	Y
59	2	32.1	101	157	93.2	38	4	4.8598	87	151
48	1	21.6	87	183	103.2	70	3	3.8918	69	75
72	2	30.5	93	156	93.6	41	4	4.6728	85	141
24	1	25.3	84	198	131.4	40	5	4.8903	89	206
50	1	23	101	192	125.4	52	4	4.2905	80	135
23	1	22.6	89	139	64.8	61	2	4.1897	68	97
36	2	22	90	160	99.6	50	3	3.9512	82	138
66	2	26.2	114	255	185	56	4.55	4.2485	92	63
60	2	32.1	83	179	119.4	42	4	4.4773	94	110
...

The data set can be downloaded from here: https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt
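The downloaded file is tab-separated, so the reader must split fields on '\t' rather than on commas. The Polars code in src/main.rs configures this with with_separator(b'\t'). The sketch below shows, without Polars, what that parsing amounts to in plain Rust. parse_tsv is a hypothetical helper that assumes a header row, unique column names, no quoted fields, and no missing values. It parses every field as f64, which has the same effect as casting each column to Float64.

```rust
use std::collections::HashMap;

/// Parses tab-separated text with a header row into named f64 columns.
/// Assumes unique column names, no quoted fields, and no missing values.
fn parse_tsv(text: &str) -> Result<(Vec<String>, HashMap<String, Vec<f64>>), String> {
    // Skip blank lines, e.g. a trailing newline at the end of the file
    let mut lines = text.lines().filter(|line| !line.trim().is_empty());

    // The first non-blank line holds the column names
    let header: Vec<String> = lines
        .next()
        .ok_or("input is empty")?
        .split('\t')
        .map(|name| name.trim().to_string())
        .collect();

    let mut columns: HashMap<String, Vec<f64>> =
        header.iter().map(|name| (name.clone(), Vec::new())).collect();

    for (row_idx, line) in lines.enumerate() {
        let fields: Vec<&str> = line.split('\t').collect();
        if fields.len() != header.len() {
            return Err(format!(
                "row {}: expected {} fields, found {}",
                row_idx + 1,
                header.len(),
                fields.len()
            ));
        }
        for (name, field) in header.iter().zip(fields) {
            // Integer fields such as "59" parse to 59.0
            let value: f64 = field
                .trim()
                .parse()
                .map_err(|e| format!("row {}, column {}: {}", row_idx + 1, name, e))?;
            columns
                .get_mut(name)
                .expect("a column exists for every header name")
                .push(value);
        }
    }
    Ok((header, columns))
}

fn main() {
    // Header plus the first two data rows of the diabetes file
    let text = "AGE\tSEX\tBMI\tBP\tS1\tS2\tS3\tS4\tS5\tS6\tY\n\
                59\t2\t32.1\t101\t157\t93.2\t38\t4\t4.8598\t87\t151\n\
                48\t1\t21.6\t87\t183\t103.2\t70\t3\t3.8918\t69\t75\n";

    match parse_tsv(text) {
        Ok((header, columns)) => {
            // prints: 11 columns, 2 rows
            println!("{} columns, {} rows", header.len(), columns["AGE"].len());
            // prints: AGE = [59.0, 48.0]
            println!("AGE = {:?}", columns["AGE"]);
        }
        Err(err) => eprintln!("parse error: {}", err),
    }
}
```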

Write the following code in the src/main.rs file:

use linfa::prelude::*;
use linfa::traits::Fit;
use linfa_linear::LinearRegression;
use ndarray::{ArrayBase, OwnedRepr};
use polars::prelude::*; // Import polars

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Use the tab character as the separator
    let separator = b'\t';
    let df = polars::prelude::CsvReader::from_path("./diabetes_file.csv")?
        .infer_schema(None)
        .with_separator(separator)
        .has_header(true)
        .finish()?;
    println!("{:?}", df);

    // Extract the target column and cast it to Float64
    let age_series = df.column("AGE")?.cast(&DataType::Float64)?;
    let target = age_series.f64()?;

    println!("Creating features dataset");
    let mut features = df.drop("AGE")?;

    // Iterate over the columns and cast each one to Float64
    for col_name in features.get_column_names_owned() {
        let casted_col = df
            .column(&col_name)?
            .cast(&DataType::Float64)
            .expect("Failed to cast column");
        features.with_column(casted_col)?;
    }

    // Note: this prints the original `df`, which is why the second table in the
    // output below matches the first; print `features` to see the Float64 features
    println!("{:?}", df);

    let features_ndarray: ArrayBase<OwnedRepr<f64>, _> =
        features.to_ndarray::<Float64Type>(IndexOrder::C)?;
    let target_ndarray = target.to_ndarray()?.to_owned();

    // split_with_ratio does not shuffle: the first 80% of rows form the training set
    let (dataset_training, dataset_validation) =
        Dataset::new(features_ndarray, target_ndarray).split_with_ratio(0.80);

    // Train the model
    let model = LinearRegression::default().fit(&dataset_training)?;

    // Predict
    let pred = model.predict(&dataset_validation);

    // Evaluate the model
    let r2 = pred.r2(&dataset_validation)?;
    println!("r2 from prediction: {}", r2);

    Ok(())
}

  • Use Polars' CsvReader to read the tab-separated file, with a header row and an inferred schema.
  • Print the DataFrame to the console for inspection.
  • Extract the "AGE" column from the DataFrame as the target variable for linear regression. Cast it to Float64 (a double-precision floating-point number), which is a common format for numeric data in machine learning.
  • Drop "AGE" from the DataFrame to form the feature set, and cast every remaining feature column to Float64.
  • Convert the features DataFrame to an ndarray::ArrayBase (a multidimensional array) and the target series to a one-dimensional array, so that both are compatible with Linfa.
  • Split the data set into training and validation sets with an 80/20 ratio. This is common practice for evaluating a model on data it has not seen during training.
  • Train a linear regression model on the training set with Linfa's LinearRegression.
  • Use the trained model to make predictions on the validation set.
  • Compute the R² (coefficient of determination) on the validation set to evaluate the model. R² measures how closely the predictions match the actual values: 1.0 is a perfect fit, and 0.0 is no better than always predicting the mean.
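The split, fit, predict, and evaluate steps listed above can be sketched without Linfa for the simplest case of a single feature. In that case, ordinary least squares has a closed-form slope and intercept. This is an illustrative simplification, not Linfa's implementation: Linfa's LinearRegression solves the multi-feature least-squares problem. The helper names split_ordered, fit_simple, r2_score and SimpleModel are hypothetical. As in the article's code, the split keeps row order. The data are the (BMI, AGE) pairs from the first nine rows of the data set, which is far too few for a meaningful score.

```rust
/// A fitted one-feature linear model: y = intercept + slope * x (hypothetical type).
#[derive(Debug)]
struct SimpleModel {
    slope: f64,
    intercept: f64,
}

impl SimpleModel {
    fn predict(&self, x: &[f64]) -> Vec<f64> {
        x.iter().map(|xi| self.intercept + self.slope * xi).collect()
    }
}

type Pairs = (Vec<f64>, Vec<f64>);

/// Splits (x, y) pairs in their original order: the first `ratio` share goes to training.
fn split_ordered(x: &[f64], y: &[f64], ratio: f64) -> (Pairs, Pairs) {
    let n_train = (((x.len() as f64) * ratio).round() as usize).min(x.len());
    (
        (x[..n_train].to_vec(), y[..n_train].to_vec()),
        (x[n_train..].to_vec(), y[n_train..].to_vec()),
    )
}

/// Ordinary least squares for one feature:
/// slope = Sxy / Sxx, intercept = mean_y - slope * mean_x.
fn fit_simple(x: &[f64], y: &[f64]) -> Option<SimpleModel> {
    if x.len() != y.len() || x.len() < 2 {
        return None;
    }
    let n = x.len() as f64;
    let mean_x = x.iter().sum::<f64>() / n;
    let mean_y = y.iter().sum::<f64>() / n;
    let sxy: f64 = x.iter().zip(y).map(|(xi, yi)| (xi - mean_x) * (yi - mean_y)).sum();
    let sxx: f64 = x.iter().map(|xi| (xi - mean_x).powi(2)).sum();
    if sxx == 0.0 {
        return None; // all x values identical: the slope is undefined
    }
    let slope = sxy / sxx;
    Some(SimpleModel { slope, intercept: mean_y - slope * mean_x })
}

/// R² = 1 - SS_res / SS_tot, measured against the true values.
fn r2_score(pred: &[f64], truth: &[f64]) -> f64 {
    let mean = truth.iter().sum::<f64>() / truth.len() as f64;
    let ss_res: f64 = pred.iter().zip(truth).map(|(p, t)| (t - p).powi(2)).sum();
    let ss_tot: f64 = truth.iter().map(|t| (t - mean).powi(2)).sum();
    1.0 - ss_res / ss_tot // NaN or infinite if all true values are equal
}

fn main() {
    // BMI (feature) and AGE (target) of the first nine rows of the diabetes data
    let bmi = [32.1, 21.6, 30.5, 25.3, 23.0, 22.6, 22.0, 26.2, 32.1];
    let age = [59.0, 48.0, 72.0, 24.0, 50.0, 23.0, 36.0, 66.0, 60.0];

    let ((x_train, y_train), (x_valid, y_valid)) = split_ordered(&bmi, &age, 0.8);
    match fit_simple(&x_train, &y_train) {
        Some(model) => {
            let pred = model.predict(&x_valid);
            println!("{:?}", model);
            println!("validation R² = {:.3}", r2_score(&pred, &y_valid));
        }
        None => eprintln!("not enough distinct training points"),
    }
}
```

With more than one feature, the closed-form slope generalizes to a matrix least-squares solve. In the article's code, Linfa's LinearRegression does that work.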

Running cargo run produces the following output:

shape: (442, 11)┌─────┬─────┬──────┬───────┬───┬──────┬────────┬─────┬─────┐│ AGE ┆ SEX ┆ BMI┆ BP┆ … ┆ S4 ┆ S5 ┆ S6┆ Y ││ --- ┆ --- ┆ ---┆ --- ┆ ┆ ---┆ ---┆ --- ┆ --- ││ i64 ┆ i64 ┆ f64┆ f64 ┆ ┆ f64┆ f64┆ i64 ┆ i64 │╞═════╪═════╪══════╪═══════╪═══╪══════╪════════╪═════╪═════╡│ 59┆ 2 ┆ 32.1 ┆ 101.0 ┆ … ┆ 4.0┆ 4.8598 ┆ 87┆ 151 ││ 48┆ 1 ┆ 21.6 ┆ 87.0┆ … ┆ 3.0┆ 3.8918 ┆ 69┆ 75││ 72┆ 2 ┆ 30.5 ┆ 93.0┆ … ┆ 4.0┆ 4.6728 ┆ 85┆ 141 ││ 24┆ 1 ┆ 25.3 ┆ 84.0┆ … ┆ 5.0┆ 4.8903 ┆ 89┆ 206 ││ … ┆ … ┆ …┆ … ┆ … ┆ …┆ …┆ … ┆ … ││ 47┆ 2 ┆ 24.9 ┆ 75.0┆ … ┆ 5.0┆ 4.4427 ┆ 102 ┆ 104 ││ 60┆ 2 ┆ 24.9 ┆ 99.67 ┆ … ┆ 3.77 ┆ 4.1271 ┆ 95┆ 132 ││ 36┆ 1 ┆ 30.0 ┆ 95.0┆ … ┆ 4.79 ┆ 5.1299 ┆ 85┆ 220 ││ 36┆ 1 ┆ 19.6 ┆ 71.0┆ … ┆ 3.0┆ 4.5951 ┆ 92┆ 57│└─────┴─────┴──────┴───────┴───┴──────┴────────┴─────┴─────┘Creating features datasetshape: (442, 11)┌─────┬─────┬──────┬───────┬───┬──────┬────────┬─────┬─────┐│ AGE ┆ SEX ┆ BMI┆ BP┆ … ┆ S4 ┆ S5 ┆ S6┆ Y ││ --- ┆ --- ┆ ---┆ --- ┆ ┆ ---┆ ---┆ --- ┆ --- ││ i64 ┆ i64 ┆ f64┆ f64 ┆ ┆ f64┆ f64┆ i64 ┆ i64 │╞═════╪═════╪══════╪═══════╪═══╪══════╪════════╪═════╪═════╡│ 59┆ 2 ┆ 32.1 ┆ 101.0 ┆ … ┆ 4.0┆ 4.8598 ┆ 87┆ 151 ││ 48┆ 1 ┆ 21.6 ┆ 87.0┆ … ┆ 3.0┆ 3.8918 ┆ 69┆ 75││ 72┆ 2 ┆ 30.5 ┆ 93.0┆ … ┆ 4.0┆ 4.6728 ┆ 85┆ 141 ││ 24┆ 1 ┆ 25.3 ┆ 84.0┆ … ┆ 5.0┆ 4.8903 ┆ 89┆ 206 ││ … ┆ … ┆ …┆ … ┆ … ┆ …┆ …┆ … ┆ … ││ 47┆ 2 ┆ 24.9 ┆ 75.0┆ … ┆ 5.0┆ 4.4427 ┆ 102 ┆ 104 ││ 60┆ 2 ┆ 24.9 ┆ 99.67 ┆ … ┆ 3.77 ┆ 4.1271 ┆ 95┆ 132 ││ 36┆ 1 ┆ 30.0 ┆ 95.0┆ … ┆ 4.79 ┆ 5.1299 ┆ 85┆ 220 ││ 36┆ 1 ┆ 19.6 ┆ 71.0┆ … ┆ 3.0┆ 4.5951 ┆ 92┆ 57│└─────┴─────┴──────┴───────┴───┴──────┴────────┴─────┴─────┘r2 from prediction: 0.15937814745521017
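For reference, the standard definition of the coefficient of determination is given below. It is computed over the n validation rows, with true targets y_i (here AGE), their mean \bar{y}, and predictions \hat{y}_i. Under this definition, the printed score of about 0.16 means the features explain roughly 16% of the variance of AGE on the validation rows. A model that always predicts the mean scores 0, and on held-out data the score can even be negative.

$$
R^2 = 1 - \frac{\sum_{i=1}^{n} (y_i - \hat{y}_i)^2}{\sum_{i=1}^{n} (y_i - \bar{y})^2}
$$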

For data scientists who value rapid iteration and prototyping, Rust's compile times can be a drawback. Rust's strong static type system helps ensure type safety and reduces runtime errors, but it also adds complexity to the coding process.

Source: 51cto.com