mirror of
https://github.com/zed-industries/zed.git
synced 2024-09-18 18:08:07 +03:00
Add server methods for creating chat domain objects
Also, consolidate all sql into a `db` module
This commit is contained in:
parent
2b9b9b8f1f
commit
109d8271e0
@ -1,7 +1,6 @@
|
||||
use crate::{auth::RequestExt as _, AppState, DbPool, LayoutData, Request, RequestExt as _};
|
||||
use crate::{auth::RequestExt as _, db, AppState, LayoutData, Request, RequestExt as _};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{Executor, FromRow};
|
||||
use std::sync::Arc;
|
||||
use surf::http::mime;
|
||||
|
||||
@ -41,23 +40,8 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>) {
|
||||
struct AdminData {
|
||||
#[serde(flatten)]
|
||||
layout: Arc<LayoutData>,
|
||||
users: Vec<User>,
|
||||
signups: Vec<Signup>,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow, Serialize)]
|
||||
pub struct User {
|
||||
pub id: i32,
|
||||
pub github_login: String,
|
||||
pub admin: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow, Serialize)]
|
||||
pub struct Signup {
|
||||
pub id: i32,
|
||||
pub github_login: String,
|
||||
pub email_address: String,
|
||||
pub about: String,
|
||||
users: Vec<db::User>,
|
||||
signups: Vec<db::Signup>,
|
||||
}
|
||||
|
||||
async fn get_admin_page(mut request: Request) -> tide::Result {
|
||||
@ -65,12 +49,8 @@ async fn get_admin_page(mut request: Request) -> tide::Result {
|
||||
|
||||
let data = AdminData {
|
||||
layout: request.layout_data().await?,
|
||||
users: sqlx::query_as("SELECT * FROM users ORDER BY github_login ASC")
|
||||
.fetch_all(request.db())
|
||||
.await?,
|
||||
signups: sqlx::query_as("SELECT * FROM signups ORDER BY id DESC")
|
||||
.fetch_all(request.db())
|
||||
.await?,
|
||||
users: request.db().get_all_users().await?,
|
||||
signups: request.db().get_all_signups().await?,
|
||||
};
|
||||
|
||||
Ok(tide::Response::builder(200)
|
||||
@ -96,7 +76,7 @@ async fn post_user(mut request: Request) -> tide::Result {
|
||||
.unwrap_or(&form.github_login);
|
||||
|
||||
if !github_login.is_empty() {
|
||||
create_user(request.db(), github_login, form.admin).await?;
|
||||
request.db().create_user(github_login, form.admin).await?;
|
||||
}
|
||||
|
||||
Ok(tide::Redirect::new("/admin").into())
|
||||
@ -116,11 +96,7 @@ async fn put_user(mut request: Request) -> tide::Result {
|
||||
|
||||
request
|
||||
.db()
|
||||
.execute(
|
||||
sqlx::query("UPDATE users SET admin = $1 WHERE id = $2;")
|
||||
.bind(body.admin)
|
||||
.bind(user_id),
|
||||
)
|
||||
.set_user_is_admin(db::UserId(user_id), body.admin)
|
||||
.await?;
|
||||
|
||||
Ok(tide::Response::builder(200).build())
|
||||
@ -128,33 +104,14 @@ async fn put_user(mut request: Request) -> tide::Result {
|
||||
|
||||
async fn delete_user(request: Request) -> tide::Result {
|
||||
request.require_admin().await?;
|
||||
|
||||
let user_id = request.param("id")?.parse::<i32>()?;
|
||||
request
|
||||
.db()
|
||||
.execute(sqlx::query("DELETE FROM users WHERE id = $1;").bind(user_id))
|
||||
.await?;
|
||||
|
||||
let user_id = db::UserId(request.param("id")?.parse::<i32>()?);
|
||||
request.db().delete_user(user_id).await?;
|
||||
Ok(tide::Redirect::new("/admin").into())
|
||||
}
|
||||
|
||||
pub async fn create_user(db: &DbPool, github_login: &str, admin: bool) -> tide::Result<i32> {
|
||||
let id: i32 =
|
||||
sqlx::query_scalar("INSERT INTO users (github_login, admin) VALUES ($1, $2) RETURNING id;")
|
||||
.bind(github_login)
|
||||
.bind(admin)
|
||||
.fetch_one(db)
|
||||
.await?;
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
async fn delete_signup(request: Request) -> tide::Result {
|
||||
request.require_admin().await?;
|
||||
let signup_id = request.param("id")?.parse::<i32>()?;
|
||||
request
|
||||
.db()
|
||||
.execute(sqlx::query("DELETE FROM signups WHERE id = $1;").bind(signup_id))
|
||||
.await?;
|
||||
|
||||
let signup_id = db::SignupId(request.param("id")?.parse::<i32>()?);
|
||||
request.db().delete_signup(signup_id).await?;
|
||||
Ok(tide::Redirect::new("/admin").into())
|
||||
}
|
||||
|
@ -1,7 +1,9 @@
|
||||
use super::errors::TideResultExt;
|
||||
use crate::{github, rpc, AppState, DbPool, Request, RequestExt as _};
|
||||
use super::{
|
||||
db::{self, UserId},
|
||||
errors::TideResultExt,
|
||||
};
|
||||
use crate::{github, rpc, AppState, Request, RequestExt as _};
|
||||
use anyhow::{anyhow, Context};
|
||||
use async_std::stream::StreamExt;
|
||||
use async_trait::async_trait;
|
||||
pub use oauth2::basic::BasicClient as Client;
|
||||
use oauth2::{
|
||||
@ -14,7 +16,6 @@ use scrypt::{
|
||||
Scrypt,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::FromRow;
|
||||
use std::{borrow::Cow, convert::TryFrom, sync::Arc};
|
||||
use surf::Url;
|
||||
use tide::Server;
|
||||
@ -34,9 +35,6 @@ pub struct User {
|
||||
|
||||
pub struct VerifyToken;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct UserId(pub i32);
|
||||
|
||||
#[async_trait]
|
||||
impl tide::Middleware<Arc<AppState>> for VerifyToken {
|
||||
async fn handle(
|
||||
@ -51,33 +49,28 @@ impl tide::Middleware<Arc<AppState>> for VerifyToken {
|
||||
.as_str()
|
||||
.split_whitespace();
|
||||
|
||||
let user_id: i32 = auth_header
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("missing user id in authorization header"))?
|
||||
.parse()?;
|
||||
let user_id = UserId(
|
||||
auth_header
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("missing user id in authorization header"))?
|
||||
.parse()?,
|
||||
);
|
||||
let access_token = auth_header
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("missing access token in authorization header"))?;
|
||||
|
||||
let state = request.state().clone();
|
||||
|
||||
let mut password_hashes =
|
||||
sqlx::query_scalar::<_, String>("SELECT hash FROM access_tokens WHERE user_id = $1")
|
||||
.bind(&user_id)
|
||||
.fetch_many(&state.db);
|
||||
|
||||
let mut credentials_valid = false;
|
||||
while let Some(password_hash) = password_hashes.next().await {
|
||||
if let either::Either::Right(password_hash) = password_hash? {
|
||||
if verify_access_token(&access_token, &password_hash)? {
|
||||
credentials_valid = true;
|
||||
break;
|
||||
}
|
||||
for password_hash in state.db.get_access_token_hashes(user_id).await? {
|
||||
if verify_access_token(&access_token, &password_hash)? {
|
||||
credentials_valid = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if credentials_valid {
|
||||
request.set_ext(UserId(user_id));
|
||||
request.set_ext(user_id);
|
||||
Ok(next.run(request).await)
|
||||
} else {
|
||||
Err(anyhow!("invalid credentials").into())
|
||||
@ -94,25 +87,12 @@ pub trait RequestExt {
|
||||
impl RequestExt for Request {
|
||||
async fn current_user(&self) -> tide::Result<Option<User>> {
|
||||
if let Some(details) = self.session().get::<github::User>(CURRENT_GITHUB_USER) {
|
||||
#[derive(FromRow)]
|
||||
struct UserRow {
|
||||
admin: bool,
|
||||
}
|
||||
|
||||
let user_row: Option<UserRow> =
|
||||
sqlx::query_as("SELECT admin FROM users WHERE github_login = $1")
|
||||
.bind(&details.login)
|
||||
.fetch_optional(self.db())
|
||||
.await?;
|
||||
|
||||
let is_insider = user_row.is_some();
|
||||
let is_admin = user_row.map_or(false, |row| row.admin);
|
||||
|
||||
let user = self.db().get_user_by_github_login(&details.login).await?;
|
||||
Ok(Some(User {
|
||||
github_login: details.login,
|
||||
avatar_url: details.avatar_url,
|
||||
is_insider,
|
||||
is_admin,
|
||||
is_insider: user.is_some(),
|
||||
is_admin: user.map_or(false, |user| user.admin),
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
@ -265,9 +245,9 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
|
||||
.await
|
||||
.context("failed to fetch user")?;
|
||||
|
||||
let user_id: Option<i32> = sqlx::query_scalar("SELECT id from users where github_login = $1")
|
||||
.bind(&user_details.login)
|
||||
.fetch_optional(request.db())
|
||||
let user = request
|
||||
.db()
|
||||
.get_user_by_github_login(&user_details.login)
|
||||
.await?;
|
||||
|
||||
request
|
||||
@ -276,8 +256,8 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
|
||||
|
||||
// When signing in from the native app, generate a new access token for the current user. Return
|
||||
// a redirect so that the user's browser sends this access token to the locally-running app.
|
||||
if let Some((user_id, app_sign_in_params)) = user_id.zip(query.native_app_sign_in_params) {
|
||||
let access_token = create_access_token(request.db(), user_id).await?;
|
||||
if let Some((user, app_sign_in_params)) = user.zip(query.native_app_sign_in_params) {
|
||||
let access_token = create_access_token(request.db(), user.id()).await?;
|
||||
let native_app_public_key =
|
||||
zed_auth::PublicKey::try_from(app_sign_in_params.native_app_public_key.clone())
|
||||
.context("failed to parse app public key")?;
|
||||
@ -287,7 +267,9 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
|
||||
|
||||
return Ok(tide::Redirect::new(&format!(
|
||||
"http://127.0.0.1:{}?user_id={}&access_token={}",
|
||||
app_sign_in_params.native_app_port, user_id, encrypted_access_token,
|
||||
app_sign_in_params.native_app_port,
|
||||
user.id().0,
|
||||
encrypted_access_token,
|
||||
))
|
||||
.into());
|
||||
}
|
||||
@ -300,14 +282,11 @@ async fn post_sign_out(mut request: Request) -> tide::Result {
|
||||
Ok(tide::Redirect::new("/").into())
|
||||
}
|
||||
|
||||
pub async fn create_access_token(db: &DbPool, user_id: i32) -> tide::Result<String> {
|
||||
pub async fn create_access_token(db: &db::Db, user_id: UserId) -> tide::Result<String> {
|
||||
let access_token = zed_auth::random_token();
|
||||
let access_token_hash =
|
||||
hash_access_token(&access_token).context("failed to hash access token")?;
|
||||
sqlx::query("INSERT INTO access_tokens (user_id, hash) values ($1, $2)")
|
||||
.bind(user_id)
|
||||
.bind(access_token_hash)
|
||||
.fetch_optional(db)
|
||||
db.create_access_token_hash(user_id, access_token_hash)
|
||||
.await?;
|
||||
Ok(access_token)
|
||||
}
|
||||
|
276
server/src/db.rs
Normal file
276
server/src/db.rs
Normal file
@ -0,0 +1,276 @@
|
||||
use serde::Serialize;
|
||||
use sqlx::{FromRow, Result};
|
||||
|
||||
pub use async_sqlx_session::PostgresSessionStore as SessionStore;
|
||||
pub use sqlx::postgres::PgPoolOptions as DbOptions;
|
||||
|
||||
pub struct Db(pub sqlx::PgPool);
|
||||
|
||||
#[derive(Debug, FromRow, Serialize)]
|
||||
pub struct User {
|
||||
id: i32,
|
||||
pub github_login: String,
|
||||
pub admin: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow, Serialize)]
|
||||
pub struct Signup {
|
||||
id: i32,
|
||||
pub github_login: String,
|
||||
pub email_address: String,
|
||||
pub about: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct ChannelMessage {
|
||||
id: i32,
|
||||
sender_id: i32,
|
||||
body: String,
|
||||
sent_at: i64,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct UserId(pub i32);
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct OrgId(pub i32);
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct ChannelId(pub i32);
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct SignupId(pub i32);
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct MessageId(pub i32);
|
||||
|
||||
impl Db {
|
||||
// signups
|
||||
|
||||
pub async fn create_signup(
|
||||
&self,
|
||||
github_login: &str,
|
||||
email_address: &str,
|
||||
about: &str,
|
||||
) -> Result<SignupId> {
|
||||
let query = "
|
||||
INSERT INTO signups (github_login, email_address, about)
|
||||
VALUES ($1, $2, $3)
|
||||
RETURNING id
|
||||
";
|
||||
sqlx::query_scalar(query)
|
||||
.bind(github_login)
|
||||
.bind(email_address)
|
||||
.bind(about)
|
||||
.fetch_one(&self.0)
|
||||
.await
|
||||
.map(SignupId)
|
||||
}
|
||||
|
||||
pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
|
||||
let query = "SELECT * FROM users ORDER BY github_login ASC";
|
||||
sqlx::query_as(query).fetch_all(&self.0).await
|
||||
}
|
||||
|
||||
pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
|
||||
let query = "DELETE FROM signups WHERE id = $1";
|
||||
sqlx::query(query)
|
||||
.bind(id.0)
|
||||
.execute(&self.0)
|
||||
.await
|
||||
.map(drop)
|
||||
}
|
||||
|
||||
// users
|
||||
|
||||
pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
|
||||
let query = "
|
||||
INSERT INTO users (github_login, admin)
|
||||
VALUES ($1, $2)
|
||||
RETURNING id
|
||||
";
|
||||
sqlx::query_scalar(query)
|
||||
.bind(github_login)
|
||||
.bind(admin)
|
||||
.fetch_one(&self.0)
|
||||
.await
|
||||
.map(UserId)
|
||||
}
|
||||
|
||||
pub async fn get_all_users(&self) -> Result<Vec<User>> {
|
||||
let query = "SELECT * FROM users ORDER BY github_login ASC";
|
||||
sqlx::query_as(query).fetch_all(&self.0).await
|
||||
}
|
||||
|
||||
pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
|
||||
let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
|
||||
sqlx::query_as(query)
|
||||
.bind(github_login)
|
||||
.fetch_optional(&self.0)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
|
||||
let query = "UPDATE users SET admin = $1 WHERE id = $2";
|
||||
sqlx::query(query)
|
||||
.bind(is_admin)
|
||||
.bind(id.0)
|
||||
.execute(&self.0)
|
||||
.await
|
||||
.map(drop)
|
||||
}
|
||||
|
||||
pub async fn delete_user(&self, id: UserId) -> Result<()> {
|
||||
let query = "DELETE FROM users WHERE id = $1;";
|
||||
sqlx::query(query)
|
||||
.bind(id.0)
|
||||
.execute(&self.0)
|
||||
.await
|
||||
.map(drop)
|
||||
}
|
||||
|
||||
// access tokens
|
||||
|
||||
pub async fn create_access_token_hash(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
access_token_hash: String,
|
||||
) -> Result<()> {
|
||||
let query = "
|
||||
INSERT INTO access_tokens (user_id, hash)
|
||||
VALUES ($1, $2)
|
||||
";
|
||||
sqlx::query(query)
|
||||
.bind(user_id.0 as i32)
|
||||
.bind(access_token_hash)
|
||||
.execute(&self.0)
|
||||
.await
|
||||
.map(drop)
|
||||
}
|
||||
|
||||
pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
|
||||
let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
|
||||
sqlx::query_scalar::<_, String>(query)
|
||||
.bind(user_id.0 as i32)
|
||||
.fetch_all(&self.0)
|
||||
.await
|
||||
}
|
||||
|
||||
// orgs
|
||||
|
||||
pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
|
||||
let query = "
|
||||
INSERT INTO orgs (name, slug)
|
||||
VALUES ($1, $2)
|
||||
RETURNING id
|
||||
";
|
||||
sqlx::query_scalar(query)
|
||||
.bind(name)
|
||||
.bind(slug)
|
||||
.fetch_one(&self.0)
|
||||
.await
|
||||
.map(OrgId)
|
||||
}
|
||||
|
||||
pub async fn add_org_member(&self, org_id: OrgId, user_id: UserId) -> Result<()> {
|
||||
let query = "
|
||||
INSERT INTO org_memberships (org_id, user_id)
|
||||
VALUES ($1, $2)
|
||||
";
|
||||
sqlx::query(query)
|
||||
.bind(org_id.0)
|
||||
.bind(user_id.0)
|
||||
.execute(&self.0)
|
||||
.await
|
||||
.map(drop)
|
||||
}
|
||||
|
||||
// channels
|
||||
|
||||
pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
|
||||
let query = "
|
||||
INSERT INTO channels (owner_id, owner_is_user, name)
|
||||
VALUES ($1, false, $2)
|
||||
RETURNING id
|
||||
";
|
||||
sqlx::query_scalar(query)
|
||||
.bind(org_id.0)
|
||||
.bind(name)
|
||||
.fetch_one(&self.0)
|
||||
.await
|
||||
.map(ChannelId)
|
||||
}
|
||||
|
||||
pub async fn add_channel_member(
|
||||
&self,
|
||||
channel_id: ChannelId,
|
||||
user_id: UserId,
|
||||
is_admin: bool,
|
||||
) -> Result<()> {
|
||||
let query = "
|
||||
INSERT INTO channel_memberships (channel_id, user_id, admin)
|
||||
VALUES ($1, $2, $3)
|
||||
";
|
||||
sqlx::query(query)
|
||||
.bind(channel_id.0)
|
||||
.bind(user_id.0)
|
||||
.bind(is_admin)
|
||||
.execute(&self.0)
|
||||
.await
|
||||
.map(drop)
|
||||
}
|
||||
|
||||
// messages
|
||||
|
||||
pub async fn create_channel_message(
|
||||
&self,
|
||||
channel_id: ChannelId,
|
||||
sender_id: UserId,
|
||||
body: &str,
|
||||
) -> Result<MessageId> {
|
||||
let query = "
|
||||
INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
|
||||
VALUES ($1, $2, $3, NOW()::timestamp)
|
||||
RETURNING id
|
||||
";
|
||||
sqlx::query_scalar(query)
|
||||
.bind(channel_id.0)
|
||||
.bind(sender_id.0)
|
||||
.bind(body)
|
||||
.fetch_one(&self.0)
|
||||
.await
|
||||
.map(MessageId)
|
||||
}
|
||||
|
||||
pub async fn get_recent_channel_messages(
|
||||
&self,
|
||||
channel_id: ChannelId,
|
||||
count: usize,
|
||||
) -> Result<Vec<ChannelMessage>> {
|
||||
let query = "
|
||||
SELECT id, sender_id, body, sent_at
|
||||
FROM channel_messages
|
||||
WHERE channel_id = $1
|
||||
LIMIT $2
|
||||
";
|
||||
sqlx::query_as(query)
|
||||
.bind(channel_id.0)
|
||||
.bind(count as i64)
|
||||
.fetch_all(&self.0)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for Db {
|
||||
type Target = sqlx::PgPool;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl User {
|
||||
pub fn id(&self) -> UserId {
|
||||
UserId(self.id)
|
||||
}
|
||||
}
|
@ -3,7 +3,6 @@ use crate::{
|
||||
};
|
||||
use comrak::ComrakOptions;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::Executor as _;
|
||||
use std::sync::Arc;
|
||||
use tide::{http::mime, log, Server};
|
||||
|
||||
@ -76,14 +75,7 @@ async fn post_signup(mut request: Request) -> tide::Result {
|
||||
// Save signup in the database
|
||||
request
|
||||
.db()
|
||||
.execute(
|
||||
sqlx::query(
|
||||
"INSERT INTO signups (github_login, email_address, about) VALUES ($1, $2, $3);",
|
||||
)
|
||||
.bind(&form.github_login)
|
||||
.bind(&form.email_address)
|
||||
.bind(&form.about),
|
||||
)
|
||||
.create_signup(&form.github_login, &form.email_address, &form.about)
|
||||
.await?;
|
||||
|
||||
let layout_data = request.layout_data().await?;
|
||||
|
@ -1,6 +1,7 @@
|
||||
mod admin;
|
||||
mod assets;
|
||||
mod auth;
|
||||
mod db;
|
||||
mod env;
|
||||
mod errors;
|
||||
mod expiring;
|
||||
@ -13,15 +14,14 @@ mod tests;
|
||||
|
||||
use self::errors::TideResultExt as _;
|
||||
use anyhow::{Context, Result};
|
||||
use async_sqlx_session::PostgresSessionStore;
|
||||
use async_std::{net::TcpListener, sync::RwLock as AsyncRwLock};
|
||||
use async_trait::async_trait;
|
||||
use auth::RequestExt as _;
|
||||
use db::{Db, DbOptions};
|
||||
use handlebars::{Handlebars, TemplateRenderError};
|
||||
use parking_lot::RwLock;
|
||||
use rust_embed::RustEmbed;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::postgres::{PgPool, PgPoolOptions};
|
||||
use std::sync::Arc;
|
||||
use surf::http::cookies::SameSite;
|
||||
use tide::{log, sessions::SessionMiddleware};
|
||||
@ -29,7 +29,6 @@ use tide_compress::CompressMiddleware;
|
||||
use zrpc::Peer;
|
||||
|
||||
type Request = tide::Request<Arc<AppState>>;
|
||||
type DbPool = PgPool;
|
||||
|
||||
#[derive(RustEmbed)]
|
||||
#[folder = "templates"]
|
||||
@ -47,7 +46,7 @@ pub struct Config {
|
||||
}
|
||||
|
||||
pub struct AppState {
|
||||
db: sqlx::PgPool,
|
||||
db: Db,
|
||||
handlebars: RwLock<Handlebars<'static>>,
|
||||
auth_client: auth::Client,
|
||||
github_client: Arc<github::AppClient>,
|
||||
@ -58,11 +57,11 @@ pub struct AppState {
|
||||
|
||||
impl AppState {
|
||||
async fn new(config: Config) -> tide::Result<Arc<Self>> {
|
||||
let db = PgPoolOptions::new()
|
||||
let db = Db(DbOptions::new()
|
||||
.max_connections(5)
|
||||
.connect(&config.database_url)
|
||||
.await
|
||||
.context("failed to connect to postgres database")?;
|
||||
.context("failed to connect to postgres database")?);
|
||||
|
||||
let github_client =
|
||||
github::AppClient::new(config.github_app_id, config.github_private_key.clone());
|
||||
@ -117,7 +116,7 @@ impl AppState {
|
||||
#[async_trait]
|
||||
trait RequestExt {
|
||||
async fn layout_data(&mut self) -> tide::Result<Arc<LayoutData>>;
|
||||
fn db(&self) -> &DbPool;
|
||||
fn db(&self) -> &Db;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@ -131,7 +130,7 @@ impl RequestExt for Request {
|
||||
Ok(self.ext::<Arc<LayoutData>>().unwrap().clone())
|
||||
}
|
||||
|
||||
fn db(&self) -> &DbPool {
|
||||
fn db(&self) -> &Db {
|
||||
&self.state().db
|
||||
}
|
||||
}
|
||||
@ -173,7 +172,7 @@ pub async fn run_server(
|
||||
web.with(CompressMiddleware::new());
|
||||
web.with(
|
||||
SessionMiddleware::new(
|
||||
PostgresSessionStore::new_with_table_name(&state.config.database_url, "sessions")
|
||||
db::SessionStore::new_with_table_name(&state.config.database_url, "sessions")
|
||||
.await
|
||||
.unwrap(),
|
||||
state.config.session_secret.as_bytes(),
|
||||
|
@ -1,6 +1,8 @@
|
||||
use crate::auth::{self, UserId};
|
||||
|
||||
use super::{auth::PeerExt as _, AppState};
|
||||
use super::{
|
||||
auth::{self, PeerExt as _},
|
||||
db::UserId,
|
||||
AppState,
|
||||
};
|
||||
use anyhow::anyhow;
|
||||
use async_std::task;
|
||||
use async_tungstenite::{
|
||||
@ -37,7 +39,7 @@ pub struct State {
|
||||
}
|
||||
|
||||
struct ConnectionState {
|
||||
_user_id: i32,
|
||||
_user_id: UserId,
|
||||
worktrees: HashSet<u64>,
|
||||
}
|
||||
|
||||
@ -68,7 +70,7 @@ impl WorktreeState {
|
||||
|
||||
impl State {
|
||||
// Add a new connection associated with a given user.
|
||||
pub fn add_connection(&mut self, connection_id: ConnectionId, _user_id: i32) {
|
||||
pub fn add_connection(&mut self, connection_id: ConnectionId, _user_id: UserId) {
|
||||
self.connections.insert(
|
||||
connection_id,
|
||||
ConnectionState {
|
||||
@ -291,7 +293,7 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
|
||||
let upgrade_receiver = http_res.recv_upgrade().await;
|
||||
let addr = request.remote().unwrap_or("unknown").to_string();
|
||||
let state = request.state().clone();
|
||||
let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?.0;
|
||||
let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
|
||||
task::spawn(async move {
|
||||
if let Some(stream) = upgrade_receiver.await {
|
||||
let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
|
||||
@ -310,7 +312,7 @@ pub async fn handle_connection<Conn>(
|
||||
state: Arc<AppState>,
|
||||
addr: String,
|
||||
stream: Conn,
|
||||
user_id: i32,
|
||||
user_id: UserId,
|
||||
) where
|
||||
Conn: 'static
|
||||
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
|
||||
|
@ -1,5 +1,5 @@
|
||||
use crate::{
|
||||
admin, auth, github,
|
||||
auth, db, github,
|
||||
rpc::{self, add_rpc_routes},
|
||||
AppState, Config,
|
||||
};
|
||||
@ -9,7 +9,6 @@ use rand::prelude::*;
|
||||
use serde_json::json;
|
||||
use sqlx::{
|
||||
migrate::{MigrateDatabase, Migrator},
|
||||
postgres::PgPoolOptions,
|
||||
Executor as _, Postgres,
|
||||
};
|
||||
use std::{path::Path, sync::Arc};
|
||||
@ -499,9 +498,7 @@ impl TestServer {
|
||||
}
|
||||
|
||||
async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> Client {
|
||||
let user_id = admin::create_user(&self.app_state.db, name, false)
|
||||
.await
|
||||
.unwrap();
|
||||
let user_id = self.app_state.db.create_user(name, false).await.unwrap();
|
||||
let lang_registry = Arc::new(LanguageRegistry::new());
|
||||
let client = Client::new(lang_registry.clone());
|
||||
let mut client_router = ForegroundRouter::new();
|
||||
@ -532,18 +529,20 @@ impl TestServer {
|
||||
config.database_url = format!("postgres://postgres@localhost/{}", db_name);
|
||||
|
||||
Self::create_db(&config.database_url).await;
|
||||
let db = PgPoolOptions::new()
|
||||
.max_connections(5)
|
||||
.connect(&config.database_url)
|
||||
.await
|
||||
.expect("failed to connect to postgres database");
|
||||
let db = db::Db(
|
||||
db::DbOptions::new()
|
||||
.max_connections(5)
|
||||
.connect(&config.database_url)
|
||||
.await
|
||||
.expect("failed to connect to postgres database"),
|
||||
);
|
||||
let migrator = Migrator::new(Path::new(concat!(
|
||||
env!("CARGO_MANIFEST_DIR"),
|
||||
"/migrations"
|
||||
)))
|
||||
.await
|
||||
.unwrap();
|
||||
migrator.run(&db).await.unwrap();
|
||||
migrator.run(&db.0).await.unwrap();
|
||||
|
||||
let github_client = github::AppClient::test();
|
||||
Arc::new(AppState {
|
||||
|
Loading…
Reference in New Issue
Block a user