Add server methods for creating chat domain objects

Also, consolidate all sql into a `db` module
This commit is contained in:
Max Brunsfeld 2021-08-05 19:06:50 -07:00
parent 2b9b9b8f1f
commit 109d8271e0
7 changed files with 344 additions and 140 deletions

View File

@ -1,7 +1,6 @@
use crate::{auth::RequestExt as _, AppState, DbPool, LayoutData, Request, RequestExt as _};
use crate::{auth::RequestExt as _, db, AppState, LayoutData, Request, RequestExt as _};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use sqlx::{Executor, FromRow};
use std::sync::Arc;
use surf::http::mime;
@ -41,23 +40,8 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>) {
struct AdminData {
#[serde(flatten)]
layout: Arc<LayoutData>,
users: Vec<User>,
signups: Vec<Signup>,
}
#[derive(Debug, FromRow, Serialize)]
pub struct User {
pub id: i32,
pub github_login: String,
pub admin: bool,
}
#[derive(Debug, FromRow, Serialize)]
pub struct Signup {
pub id: i32,
pub github_login: String,
pub email_address: String,
pub about: String,
users: Vec<db::User>,
signups: Vec<db::Signup>,
}
async fn get_admin_page(mut request: Request) -> tide::Result {
@ -65,12 +49,8 @@ async fn get_admin_page(mut request: Request) -> tide::Result {
let data = AdminData {
layout: request.layout_data().await?,
users: sqlx::query_as("SELECT * FROM users ORDER BY github_login ASC")
.fetch_all(request.db())
.await?,
signups: sqlx::query_as("SELECT * FROM signups ORDER BY id DESC")
.fetch_all(request.db())
.await?,
users: request.db().get_all_users().await?,
signups: request.db().get_all_signups().await?,
};
Ok(tide::Response::builder(200)
@ -96,7 +76,7 @@ async fn post_user(mut request: Request) -> tide::Result {
.unwrap_or(&form.github_login);
if !github_login.is_empty() {
create_user(request.db(), github_login, form.admin).await?;
request.db().create_user(github_login, form.admin).await?;
}
Ok(tide::Redirect::new("/admin").into())
@ -116,11 +96,7 @@ async fn put_user(mut request: Request) -> tide::Result {
request
.db()
.execute(
sqlx::query("UPDATE users SET admin = $1 WHERE id = $2;")
.bind(body.admin)
.bind(user_id),
)
.set_user_is_admin(db::UserId(user_id), body.admin)
.await?;
Ok(tide::Response::builder(200).build())
@ -128,33 +104,14 @@ async fn put_user(mut request: Request) -> tide::Result {
async fn delete_user(request: Request) -> tide::Result {
request.require_admin().await?;
let user_id = request.param("id")?.parse::<i32>()?;
request
.db()
.execute(sqlx::query("DELETE FROM users WHERE id = $1;").bind(user_id))
.await?;
let user_id = db::UserId(request.param("id")?.parse::<i32>()?);
request.db().delete_user(user_id).await?;
Ok(tide::Redirect::new("/admin").into())
}
pub async fn create_user(db: &DbPool, github_login: &str, admin: bool) -> tide::Result<i32> {
let id: i32 =
sqlx::query_scalar("INSERT INTO users (github_login, admin) VALUES ($1, $2) RETURNING id;")
.bind(github_login)
.bind(admin)
.fetch_one(db)
.await?;
Ok(id)
}
async fn delete_signup(request: Request) -> tide::Result {
request.require_admin().await?;
let signup_id = request.param("id")?.parse::<i32>()?;
request
.db()
.execute(sqlx::query("DELETE FROM signups WHERE id = $1;").bind(signup_id))
.await?;
let signup_id = db::SignupId(request.param("id")?.parse::<i32>()?);
request.db().delete_signup(signup_id).await?;
Ok(tide::Redirect::new("/admin").into())
}

View File

@ -1,7 +1,9 @@
use super::errors::TideResultExt;
use crate::{github, rpc, AppState, DbPool, Request, RequestExt as _};
use super::{
db::{self, UserId},
errors::TideResultExt,
};
use crate::{github, rpc, AppState, Request, RequestExt as _};
use anyhow::{anyhow, Context};
use async_std::stream::StreamExt;
use async_trait::async_trait;
pub use oauth2::basic::BasicClient as Client;
use oauth2::{
@ -14,7 +16,6 @@ use scrypt::{
Scrypt,
};
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use std::{borrow::Cow, convert::TryFrom, sync::Arc};
use surf::Url;
use tide::Server;
@ -34,9 +35,6 @@ pub struct User {
pub struct VerifyToken;
#[derive(Clone, Copy)]
pub struct UserId(pub i32);
#[async_trait]
impl tide::Middleware<Arc<AppState>> for VerifyToken {
async fn handle(
@ -51,33 +49,28 @@ impl tide::Middleware<Arc<AppState>> for VerifyToken {
.as_str()
.split_whitespace();
let user_id: i32 = auth_header
.next()
.ok_or_else(|| anyhow!("missing user id in authorization header"))?
.parse()?;
let user_id = UserId(
auth_header
.next()
.ok_or_else(|| anyhow!("missing user id in authorization header"))?
.parse()?,
);
let access_token = auth_header
.next()
.ok_or_else(|| anyhow!("missing access token in authorization header"))?;
let state = request.state().clone();
let mut password_hashes =
sqlx::query_scalar::<_, String>("SELECT hash FROM access_tokens WHERE user_id = $1")
.bind(&user_id)
.fetch_many(&state.db);
let mut credentials_valid = false;
while let Some(password_hash) = password_hashes.next().await {
if let either::Either::Right(password_hash) = password_hash? {
if verify_access_token(&access_token, &password_hash)? {
credentials_valid = true;
break;
}
for password_hash in state.db.get_access_token_hashes(user_id).await? {
if verify_access_token(&access_token, &password_hash)? {
credentials_valid = true;
break;
}
}
if credentials_valid {
request.set_ext(UserId(user_id));
request.set_ext(user_id);
Ok(next.run(request).await)
} else {
Err(anyhow!("invalid credentials").into())
@ -94,25 +87,12 @@ pub trait RequestExt {
impl RequestExt for Request {
async fn current_user(&self) -> tide::Result<Option<User>> {
if let Some(details) = self.session().get::<github::User>(CURRENT_GITHUB_USER) {
#[derive(FromRow)]
struct UserRow {
admin: bool,
}
let user_row: Option<UserRow> =
sqlx::query_as("SELECT admin FROM users WHERE github_login = $1")
.bind(&details.login)
.fetch_optional(self.db())
.await?;
let is_insider = user_row.is_some();
let is_admin = user_row.map_or(false, |row| row.admin);
let user = self.db().get_user_by_github_login(&details.login).await?;
Ok(Some(User {
github_login: details.login,
avatar_url: details.avatar_url,
is_insider,
is_admin,
is_insider: user.is_some(),
is_admin: user.map_or(false, |user| user.admin),
}))
} else {
Ok(None)
@ -265,9 +245,9 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
.await
.context("failed to fetch user")?;
let user_id: Option<i32> = sqlx::query_scalar("SELECT id from users where github_login = $1")
.bind(&user_details.login)
.fetch_optional(request.db())
let user = request
.db()
.get_user_by_github_login(&user_details.login)
.await?;
request
@ -276,8 +256,8 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
// When signing in from the native app, generate a new access token for the current user. Return
// a redirect so that the user's browser sends this access token to the locally-running app.
if let Some((user_id, app_sign_in_params)) = user_id.zip(query.native_app_sign_in_params) {
let access_token = create_access_token(request.db(), user_id).await?;
if let Some((user, app_sign_in_params)) = user.zip(query.native_app_sign_in_params) {
let access_token = create_access_token(request.db(), user.id()).await?;
let native_app_public_key =
zed_auth::PublicKey::try_from(app_sign_in_params.native_app_public_key.clone())
.context("failed to parse app public key")?;
@ -287,7 +267,9 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
return Ok(tide::Redirect::new(&format!(
"http://127.0.0.1:{}?user_id={}&access_token={}",
app_sign_in_params.native_app_port, user_id, encrypted_access_token,
app_sign_in_params.native_app_port,
user.id().0,
encrypted_access_token,
))
.into());
}
@ -300,14 +282,11 @@ async fn post_sign_out(mut request: Request) -> tide::Result {
Ok(tide::Redirect::new("/").into())
}
pub async fn create_access_token(db: &DbPool, user_id: i32) -> tide::Result<String> {
pub async fn create_access_token(db: &db::Db, user_id: UserId) -> tide::Result<String> {
let access_token = zed_auth::random_token();
let access_token_hash =
hash_access_token(&access_token).context("failed to hash access token")?;
sqlx::query("INSERT INTO access_tokens (user_id, hash) values ($1, $2)")
.bind(user_id)
.bind(access_token_hash)
.fetch_optional(db)
db.create_access_token_hash(user_id, access_token_hash)
.await?;
Ok(access_token)
}

276
server/src/db.rs Normal file
View File

@ -0,0 +1,276 @@
use serde::Serialize;
use sqlx::{FromRow, Result};
pub use async_sqlx_session::PostgresSessionStore as SessionStore;
pub use sqlx::postgres::PgPoolOptions as DbOptions;
pub struct Db(pub sqlx::PgPool);
#[derive(Debug, FromRow, Serialize)]
pub struct User {
id: i32,
pub github_login: String,
pub admin: bool,
}
#[derive(Debug, FromRow, Serialize)]
pub struct Signup {
id: i32,
pub github_login: String,
pub email_address: String,
pub about: String,
}
#[derive(Debug, FromRow)]
pub struct ChannelMessage {
id: i32,
sender_id: i32,
body: String,
sent_at: i64,
}
#[derive(Clone, Copy)]
pub struct UserId(pub i32);
#[derive(Clone, Copy)]
pub struct OrgId(pub i32);
#[derive(Clone, Copy)]
pub struct ChannelId(pub i32);
#[derive(Clone, Copy)]
pub struct SignupId(pub i32);
#[derive(Clone, Copy)]
pub struct MessageId(pub i32);
impl Db {
// signups
pub async fn create_signup(
&self,
github_login: &str,
email_address: &str,
about: &str,
) -> Result<SignupId> {
let query = "
INSERT INTO signups (github_login, email_address, about)
VALUES ($1, $2, $3)
RETURNING id
";
sqlx::query_scalar(query)
.bind(github_login)
.bind(email_address)
.bind(about)
.fetch_one(&self.0)
.await
.map(SignupId)
}
pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
let query = "SELECT * FROM users ORDER BY github_login ASC";
sqlx::query_as(query).fetch_all(&self.0).await
}
pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
let query = "DELETE FROM signups WHERE id = $1";
sqlx::query(query)
.bind(id.0)
.execute(&self.0)
.await
.map(drop)
}
// users
pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
let query = "
INSERT INTO users (github_login, admin)
VALUES ($1, $2)
RETURNING id
";
sqlx::query_scalar(query)
.bind(github_login)
.bind(admin)
.fetch_one(&self.0)
.await
.map(UserId)
}
pub async fn get_all_users(&self) -> Result<Vec<User>> {
let query = "SELECT * FROM users ORDER BY github_login ASC";
sqlx::query_as(query).fetch_all(&self.0).await
}
pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
sqlx::query_as(query)
.bind(github_login)
.fetch_optional(&self.0)
.await
}
pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
let query = "UPDATE users SET admin = $1 WHERE id = $2";
sqlx::query(query)
.bind(is_admin)
.bind(id.0)
.execute(&self.0)
.await
.map(drop)
}
pub async fn delete_user(&self, id: UserId) -> Result<()> {
let query = "DELETE FROM users WHERE id = $1;";
sqlx::query(query)
.bind(id.0)
.execute(&self.0)
.await
.map(drop)
}
// access tokens
pub async fn create_access_token_hash(
&self,
user_id: UserId,
access_token_hash: String,
) -> Result<()> {
let query = "
INSERT INTO access_tokens (user_id, hash)
VALUES ($1, $2)
";
sqlx::query(query)
.bind(user_id.0 as i32)
.bind(access_token_hash)
.execute(&self.0)
.await
.map(drop)
}
pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
sqlx::query_scalar::<_, String>(query)
.bind(user_id.0 as i32)
.fetch_all(&self.0)
.await
}
// orgs
pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
let query = "
INSERT INTO orgs (name, slug)
VALUES ($1, $2)
RETURNING id
";
sqlx::query_scalar(query)
.bind(name)
.bind(slug)
.fetch_one(&self.0)
.await
.map(OrgId)
}
pub async fn add_org_member(&self, org_id: OrgId, user_id: UserId) -> Result<()> {
let query = "
INSERT INTO org_memberships (org_id, user_id)
VALUES ($1, $2)
";
sqlx::query(query)
.bind(org_id.0)
.bind(user_id.0)
.execute(&self.0)
.await
.map(drop)
}
// channels
pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
let query = "
INSERT INTO channels (owner_id, owner_is_user, name)
VALUES ($1, false, $2)
RETURNING id
";
sqlx::query_scalar(query)
.bind(org_id.0)
.bind(name)
.fetch_one(&self.0)
.await
.map(ChannelId)
}
pub async fn add_channel_member(
&self,
channel_id: ChannelId,
user_id: UserId,
is_admin: bool,
) -> Result<()> {
let query = "
INSERT INTO channel_memberships (channel_id, user_id, admin)
VALUES ($1, $2, $3)
";
sqlx::query(query)
.bind(channel_id.0)
.bind(user_id.0)
.bind(is_admin)
.execute(&self.0)
.await
.map(drop)
}
// messages
pub async fn create_channel_message(
&self,
channel_id: ChannelId,
sender_id: UserId,
body: &str,
) -> Result<MessageId> {
let query = "
INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
VALUES ($1, $2, $3, NOW()::timestamp)
RETURNING id
";
sqlx::query_scalar(query)
.bind(channel_id.0)
.bind(sender_id.0)
.bind(body)
.fetch_one(&self.0)
.await
.map(MessageId)
}
pub async fn get_recent_channel_messages(
&self,
channel_id: ChannelId,
count: usize,
) -> Result<Vec<ChannelMessage>> {
let query = "
SELECT id, sender_id, body, sent_at
FROM channel_messages
WHERE channel_id = $1
LIMIT $2
";
sqlx::query_as(query)
.bind(channel_id.0)
.bind(count as i64)
.fetch_all(&self.0)
.await
}
}
impl std::ops::Deref for Db {
type Target = sqlx::PgPool;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl User {
pub fn id(&self) -> UserId {
UserId(self.id)
}
}

View File

@ -3,7 +3,6 @@ use crate::{
};
use comrak::ComrakOptions;
use serde::{Deserialize, Serialize};
use sqlx::Executor as _;
use std::sync::Arc;
use tide::{http::mime, log, Server};
@ -76,14 +75,7 @@ async fn post_signup(mut request: Request) -> tide::Result {
// Save signup in the database
request
.db()
.execute(
sqlx::query(
"INSERT INTO signups (github_login, email_address, about) VALUES ($1, $2, $3);",
)
.bind(&form.github_login)
.bind(&form.email_address)
.bind(&form.about),
)
.create_signup(&form.github_login, &form.email_address, &form.about)
.await?;
let layout_data = request.layout_data().await?;

View File

@ -1,6 +1,7 @@
mod admin;
mod assets;
mod auth;
mod db;
mod env;
mod errors;
mod expiring;
@ -13,15 +14,14 @@ mod tests;
use self::errors::TideResultExt as _;
use anyhow::{Context, Result};
use async_sqlx_session::PostgresSessionStore;
use async_std::{net::TcpListener, sync::RwLock as AsyncRwLock};
use async_trait::async_trait;
use auth::RequestExt as _;
use db::{Db, DbOptions};
use handlebars::{Handlebars, TemplateRenderError};
use parking_lot::RwLock;
use rust_embed::RustEmbed;
use serde::{Deserialize, Serialize};
use sqlx::postgres::{PgPool, PgPoolOptions};
use std::sync::Arc;
use surf::http::cookies::SameSite;
use tide::{log, sessions::SessionMiddleware};
@ -29,7 +29,6 @@ use tide_compress::CompressMiddleware;
use zrpc::Peer;
type Request = tide::Request<Arc<AppState>>;
type DbPool = PgPool;
#[derive(RustEmbed)]
#[folder = "templates"]
@ -47,7 +46,7 @@ pub struct Config {
}
pub struct AppState {
db: sqlx::PgPool,
db: Db,
handlebars: RwLock<Handlebars<'static>>,
auth_client: auth::Client,
github_client: Arc<github::AppClient>,
@ -58,11 +57,11 @@ pub struct AppState {
impl AppState {
async fn new(config: Config) -> tide::Result<Arc<Self>> {
let db = PgPoolOptions::new()
let db = Db(DbOptions::new()
.max_connections(5)
.connect(&config.database_url)
.await
.context("failed to connect to postgres database")?;
.context("failed to connect to postgres database")?);
let github_client =
github::AppClient::new(config.github_app_id, config.github_private_key.clone());
@ -117,7 +116,7 @@ impl AppState {
#[async_trait]
trait RequestExt {
async fn layout_data(&mut self) -> tide::Result<Arc<LayoutData>>;
fn db(&self) -> &DbPool;
fn db(&self) -> &Db;
}
#[async_trait]
@ -131,7 +130,7 @@ impl RequestExt for Request {
Ok(self.ext::<Arc<LayoutData>>().unwrap().clone())
}
fn db(&self) -> &DbPool {
fn db(&self) -> &Db {
&self.state().db
}
}
@ -173,7 +172,7 @@ pub async fn run_server(
web.with(CompressMiddleware::new());
web.with(
SessionMiddleware::new(
PostgresSessionStore::new_with_table_name(&state.config.database_url, "sessions")
db::SessionStore::new_with_table_name(&state.config.database_url, "sessions")
.await
.unwrap(),
state.config.session_secret.as_bytes(),

View File

@ -1,6 +1,8 @@
use crate::auth::{self, UserId};
use super::{auth::PeerExt as _, AppState};
use super::{
auth::{self, PeerExt as _},
db::UserId,
AppState,
};
use anyhow::anyhow;
use async_std::task;
use async_tungstenite::{
@ -37,7 +39,7 @@ pub struct State {
}
struct ConnectionState {
_user_id: i32,
_user_id: UserId,
worktrees: HashSet<u64>,
}
@ -68,7 +70,7 @@ impl WorktreeState {
impl State {
// Add a new connection associated with a given user.
pub fn add_connection(&mut self, connection_id: ConnectionId, _user_id: i32) {
pub fn add_connection(&mut self, connection_id: ConnectionId, _user_id: UserId) {
self.connections.insert(
connection_id,
ConnectionState {
@ -291,7 +293,7 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
let upgrade_receiver = http_res.recv_upgrade().await;
let addr = request.remote().unwrap_or("unknown").to_string();
let state = request.state().clone();
let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?.0;
let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
task::spawn(async move {
if let Some(stream) = upgrade_receiver.await {
let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
@ -310,7 +312,7 @@ pub async fn handle_connection<Conn>(
state: Arc<AppState>,
addr: String,
stream: Conn,
user_id: i32,
user_id: UserId,
) where
Conn: 'static
+ futures::Sink<WebSocketMessage, Error = WebSocketError>

View File

@ -1,5 +1,5 @@
use crate::{
admin, auth, github,
auth, db, github,
rpc::{self, add_rpc_routes},
AppState, Config,
};
@ -9,7 +9,6 @@ use rand::prelude::*;
use serde_json::json;
use sqlx::{
migrate::{MigrateDatabase, Migrator},
postgres::PgPoolOptions,
Executor as _, Postgres,
};
use std::{path::Path, sync::Arc};
@ -499,9 +498,7 @@ impl TestServer {
}
async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> Client {
let user_id = admin::create_user(&self.app_state.db, name, false)
.await
.unwrap();
let user_id = self.app_state.db.create_user(name, false).await.unwrap();
let lang_registry = Arc::new(LanguageRegistry::new());
let client = Client::new(lang_registry.clone());
let mut client_router = ForegroundRouter::new();
@ -532,18 +529,20 @@ impl TestServer {
config.database_url = format!("postgres://postgres@localhost/{}", db_name);
Self::create_db(&config.database_url).await;
let db = PgPoolOptions::new()
.max_connections(5)
.connect(&config.database_url)
.await
.expect("failed to connect to postgres database");
let db = db::Db(
db::DbOptions::new()
.max_connections(5)
.connect(&config.database_url)
.await
.expect("failed to connect to postgres database"),
);
let migrator = Migrator::new(Path::new(concat!(
env!("CARGO_MANIFEST_DIR"),
"/migrations"
)))
.await
.unwrap();
migrator.run(&db).await.unwrap();
migrator.run(&db.0).await.unwrap();
let github_client = github::AppClient::test();
Arc::new(AppState {